1--
2-- Licensed to the Apache Software Foundation (ASF) under one
3-- or more contributor license agreements. See the NOTICE file
4-- distributed with this work for additional information
5-- regarding copyright ownership. The ASF licenses this file
6-- to you under the Apache License, Version 2.0 (the
7-- "License"), you may not use this file except in compliance
8-- with the License. You may obtain a copy of the License at
9--
10--   http://www.apache.org/licenses/LICENSE-2.0
11--
12-- Unless required by applicable law or agreed to in writing,
13-- software distributed under the License is distributed on an
14-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15-- KIND, either express or implied. See the License for the
16-- specific language governing permissions and limitations
17-- under the License.
18--
19
20require 'TProtocol'
21require 'libluabpack'
22require 'libluabitwise'
23
-- Default field values for the JSON protocol object.
--   jsonContext      : stack of saved contexts (filled by contextPush)
--   jsonContextVal   : the active context; ttype 1 marks a JSON object,
--                      ttype 2 a JSON array, and null = true disables
--                      separator handling at the top level
--   jsonContextIndex : index of the top slot in jsonContext
--   hasReadByte      : one byte of read look-ahead, "" when nothing is pending
local jsonProtocolDefaults = {
  __type = 'TJSONProtocol',
  THRIFT_JSON_PROTOCOL_VERSION = 1,
  jsonContext = {},
  jsonContextVal = {first = true, colon = true, ttype = 2, null = true},
  jsonContextIndex = 1,
  hasReadByte = ""
}

TJSONProtocol = __TObject.new(TProtocolBase, jsonProtocolDefaults)
32
-- Maps a Thrift type id to the short type tag written into the JSON stream.
TTypeToString = {
  [TType.BOOL]   = "tf",
  [TType.BYTE]   = "i8",
  [TType.I16]    = "i16",
  [TType.I32]    = "i32",
  [TType.I64]    = "i64",
  [TType.DOUBLE] = "dbl",
  [TType.STRING] = "str",
  [TType.STRUCT] = "rec",
  [TType.LIST]   = "lst",
  [TType.SET]    = "set",
  [TType.MAP]    = "map"
}
45
-- Maps the short type tag found in the JSON stream back to a Thrift type id.
StringToTType = {}
StringToTType["tf"]  = TType.BOOL
StringToTType["i8"]  = TType.BYTE
StringToTType["i16"] = TType.I16
StringToTType["i32"] = TType.I32
StringToTType["i64"] = TType.I64
StringToTType["dbl"] = TType.DOUBLE
StringToTType["str"] = TType.STRING
StringToTType["rec"] = TType.STRUCT
StringToTType["map"] = TType.MAP
StringToTType["set"] = TType.SET
StringToTType["lst"] = TType.LIST
59
-- JSON syntax tokens and special literals used by the reader and writer.
JSONNode = {}

-- Structural characters
JSONNode.ObjectBegin = '{'
JSONNode.ObjectEnd = '}'
JSONNode.ArrayBegin = '['
JSONNode.ArrayEnd = ']'
JSONNode.PairSeparator = ':'
JSONNode.ElemSeparator = ','

-- String syntax
JSONNode.Backslash = '\\'
JSONNode.StringDelimiter = '"'
JSONNode.ZeroChar = '0'
JSONNode.EscapeChar = 'u'

-- Special double values (written as quoted strings)
JSONNode.Nan = 'NaN'
JSONNode.Infinity = 'Infinity'
JSONNode.NegativeInfinity = '-Infinity'

-- Characters that may follow a backslash; positions line up with
-- EscapeCharVals. EscapePrefix starts a \u00XX escape.
JSONNode.EscapeChars = "\"\\bfnrt"
JSONNode.EscapePrefix = "\\u00"
77
-- Decoded value for each escape letter, indexed by the letter's position
-- in JSONNode.EscapeChars.
EscapeCharVals = {}
EscapeCharVals[1] = '"'
EscapeCharVals[2] = '\\'
EscapeCharVals[3] = '\b'
EscapeCharVals[4] = '\f'
EscapeCharVals[5] = '\n'
EscapeCharVals[6] = '\r'
EscapeCharVals[7] = '\t'
81
-- Escape classification for bytes 0x00..0x2F, indexed by byte value + 1:
--   0  -> written as a \u00XX escape
--   1  -> written unchanged
--   >1 -> written as a backslash followed by the character with this code
JSONCharTable = {}
for code = 0x00, 0x1F do
  JSONCharTable[code + 1] = 0
end
for code = 0x20, 0x2F do
  JSONCharTable[code + 1] = 1
end
JSONCharTable[0x08 + 1] = 98   -- backspace       -> \b
JSONCharTable[0x09 + 1] = 116  -- tab             -> \t
JSONCharTable[0x0A + 1] = 110  -- line feed       -> \n
JSONCharTable[0x0C + 1] = 102  -- form feed       -> \f
JSONCharTable[0x0D + 1] = 114  -- carriage return -> \r
JSONCharTable[0x22 + 1] = 34   -- double quote    -> \"
88
-- Base64 alphabet; a symbol's value is its position in this string minus one.
local b='ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/'

-- Encode a byte string as padded base64.
-- Input is consumed three bytes at a time; a trailing group of one or two
-- bytes yields two or three symbols followed by '==' or '='.
function base64_encode(data)
  local padding = { '', '==', '=' }
  local total = #data
  local out = {}
  for pos = 1, total, 3 do
    local b1, b2, b3 = data:byte(pos, pos + 2)
    -- number of real bytes in this group (1..3)
    local present = 1 + (b2 and 1 or 0) + (b3 and 1 or 0)
    local triple = b1 * 65536 + (b2 or 0) * 256 + (b3 or 0)
    -- n input bytes produce n + 1 output symbols
    for k = 1, present + 1 do
      local idx = math.floor(triple / 64 ^ (4 - k)) % 64
      out[#out + 1] = b:sub(idx + 1, idx + 1)
    end
  end
  out[#out + 1] = padding[total % 3 + 1]
  return table.concat(out)
end

-- Decode base64 text into a byte string.
-- Characters outside the alphabet (and '=') carry no data and are skipped;
-- trailing bits that do not fill a whole byte are dropped.
function base64_decode(data)
  local cleaned = string.gsub(data, '[^'..b..'=]', '')
  local out = {}
  local acc, pending = 0, 0   -- bit accumulator and number of bits it holds
  for pos = 1, #cleaned do
    local symbol = cleaned:sub(pos, pos)
    if symbol ~= '=' then
      acc = acc * 64 + (b:find(symbol, 1, true) - 1)
      pending = pending + 6
      if pending >= 8 then
        pending = pending - 8
        local scale = 2 ^ pending
        local byte = math.floor(acc / scale)
        out[#out + 1] = string.char(byte)
        acc = acc - byte * scale
      end
    end
  end
  return table.concat(out)
end
121
-- Drop all saved contexts and restore the top-level context, in which
-- separator handling is disabled (null = true).
function TJSONProtocol:resetContext()
  self.jsonContextIndex = 1
  self.jsonContextVal = {first = true, colon = true, ttype = 2, null = true}
  self.jsonContext = {}
end
127
-- Save the active context on the stack and make `context` the active one.
function TJSONProtocol:contextPush(context)
  local slot = self.jsonContextIndex + 1
  self.jsonContextIndex = slot
  self.jsonContext[slot] = self.jsonContextVal
  self.jsonContextVal = context
end
133
-- Restore the most recently saved context as the active one.
function TJSONProtocol:contextPop()
  local top = self.jsonContextIndex
  self.jsonContextVal = self.jsonContext[top]
  self.jsonContextIndex = top - 1
end
138
-- Tell whether a number written or read now must be wrapped in quotes.
-- Only object contexts (ttype 1) ever quote numbers; there the answer is
-- the context's `colon` flag. Every other context returns false.
function TJSONProtocol:escapeNum()
  local ctx = self.jsonContextVal
  if ctx.ttype ~= 1 then
    return false
  end
  return ctx.colon
end
146
-- Write whatever separator must precede the next value in the active context.
-- Nothing is written at the top level (null) or before the first value.
-- Object contexts alternate between ':' and ','; other contexts always use ','.
function TJSONProtocol:writeElemSeparator()
  local ctx = self.jsonContextVal
  if ctx.null then
    return
  end
  if ctx.first then
    ctx.first = false
    return
  end
  if ctx.ttype ~= 1 then
    self.trans:write(JSONNode.ElemSeparator)
    return
  end
  if ctx.colon then
    self.trans:write(JSONNode.PairSeparator)
    ctx.colon = false
  else
    self.trans:write(JSONNode.ElemSeparator)
    ctx.colon = true
  end
end
167
-- Return the character code of the lowercase hex digit for the low four
-- bits of `val` (48.. for 0-9, 97.. for a-f).
function TJSONProtocol:hexChar(val)
  local nibble = libluabitwise.band(val, 0x0f)
  local base = (nibble < 10) and 48 or 87
  return nibble + base
end
176
-- Write the byte value `ch` (a number) as a \u00XX escape sequence.
-- The two hex digits are produced by the hexChar method; calling it as a
-- bare global function would raise, since no such global exists.
function TJSONProtocol:writeJSONEscapeChar(ch)
  self.trans:write(JSONNode.EscapePrefix)
  -- high nibble first, then low nibble
  local outCh = self:hexChar(libluabitwise.shiftr(ch, 4))
  local buff = libluabpack.bpack('c', outCh)
  self.trans:write(buff)
  outCh = self:hexChar(ch)
  buff = libluabpack.bpack('c', outCh)
  self.trans:write(buff)
end
186
-- Write one character of a JSON string body, escaping it when required.
-- `byte` is a one-character string.
--   * codes >= 0x30: written unchanged, except a backslash which is doubled
--   * codes <  0x30: classified through JSONCharTable (literal, short
--     backslash escape, or \u00XX escape)
function TJSONProtocol:writeJSONChar(byte)
  -- keep the code point local; it must not leak into the global table
  local ch = string.byte(byte)
  if ch >= 0x30 then
    -- compare the character itself: JSONNode.Backslash is a string, so
    -- comparing it with the numeric code would never match
    if byte == JSONNode.Backslash then
      self.trans:write(JSONNode.Backslash)
      self.trans:write(JSONNode.Backslash)
    else
      self.trans:write(byte)
    end
  else
    local outCh = JSONCharTable[ch+1]
    if outCh == 1 then
      self.trans:write(byte)
    elseif outCh > 1 then
      -- short escape: backslash followed by the letter stored in the table
      self.trans:write(JSONNode.Backslash)
      local buff = libluabpack.bpack('c', outCh)
      self.trans:write(buff)
    else
      self:writeJSONEscapeChar(ch)
    end
  end
end
209
-- Write `str` as a quoted JSON string, preceded by any required separator.
-- Each character goes through writeJSONChar, which performs the escaping.
function TJSONProtocol:writeJSONString(str)
  self:writeElemSeparator()
  self.trans:write(JSONNode.StringDelimiter)
  for pos = 1, string.len(str) do
    self:writeJSONChar(string.sub(str, pos, pos))
  end
  self.trans:write(JSONNode.StringDelimiter)
end
222
-- Write `str` as a quoted, base64-encoded JSON string, preceded by any
-- required separator.
-- The input is encoded in groups of exactly three bytes, so only the final
-- group can be short and only it can carry '=' padding.
function TJSONProtocol:writeJSONBase64(str)
  self:writeElemSeparator()
  self.trans:write(JSONNode.StringDelimiter)
  local length = string.len(str)
  local offset = 1
  while offset <= length do
    -- string.sub is inclusive at both ends: offset..offset+2 is three bytes
    -- (fewer for the last group, where the end index is clamped)
    self.trans:write(base64_encode(string.sub(str, offset, offset + 2)))
    offset = offset + 3
  end
  self.trans:write(JSONNode.StringDelimiter)
end
241
-- Write `num` as a JSON integer, preceded by any required separator.
-- Only the leading optionally-signed run of digits of the value's string
-- form is emitted. The number is quoted when escapeNum() says so.
function TJSONProtocol:writeJSONInteger(num)
  self:writeElemSeparator()
  local quoted = self:escapeNum()
  if quoted then
    self.trans:write(JSONNode.StringDelimiter)
  end
  local text = "" .. num
  local first, last = string.find(text, "^[+-]?%d+")
  self.trans:write(string.sub(text, first, last))
  if quoted then
    self.trans:write(JSONNode.StringDelimiter)
  end
end
254
-- Write `dub` (a number or its string form) as a JSON double, preceded by
-- any required separator.
-- Values whose text starts with n/N or i/I (optionally after a '-') are
-- replaced by the NaN / Infinity / -Infinity literals and always quoted.
-- Ordinary numbers are quoted only when escapeNum() says so.
function TJSONProtocol:writeJSONDouble(dub)
  self:writeElemSeparator()
  local val = "" .. dub
  local negative = string.sub(val, 1, 1) == '-'
  local leadPos = negative and 2 or 1
  local lead = string.lower(string.sub(val, leadPos, leadPos))
  local special = true
  if lead == 'n' then
    -- a sign in front of NaN (e.g. "-nan") carries no meaning; without this
    -- case the raw text would be written as an unquoted, invalid token
    val = JSONNode.Nan
  elseif lead == 'i' then
    val = negative and JSONNode.NegativeInfinity or JSONNode.Infinity
  else
    special = false
  end

  -- decide once so the opening and closing quotes always match
  local quoted = special or self:escapeNum()
  if quoted then
    self.trans:write(JSONNode.StringDelimiter)
  end
  self.trans:write(val)
  if quoted then
    self.trans:write(JSONNode.StringDelimiter)
  end
end
282
-- Open a JSON object: write any separator, then '{', then enter a fresh
-- object context (ttype 1).
function TJSONProtocol:writeJSONObjectBegin()
  self:writeElemSeparator()
  self.trans:write(JSONNode.ObjectBegin)
  local objectContext = {first = true, colon = true, ttype = 1, null = false}
  self:contextPush(objectContext)
end
288
-- Close a JSON object: leave the object context, then write '}'.
function TJSONProtocol:writeJSONObjectEnd()
  self:contextPop()
  local closing = JSONNode.ObjectEnd
  self.trans:write(closing)
end
293
-- Open a JSON array: write any separator, then '[', then enter a fresh
-- array context (ttype 2).
function TJSONProtocol:writeJSONArrayBegin()
  self:writeElemSeparator()
  self.trans:write(JSONNode.ArrayBegin)
  local arrayContext = {first = true, colon = true, ttype = 2, null = false}
  self:contextPush(arrayContext)
end
299
-- Close a JSON array: leave the array context, then write ']'.
function TJSONProtocol:writeJSONArrayEnd()
  self:contextPop()
  local closing = JSONNode.ArrayEnd
  self.trans:write(closing)
end
304
-- Start a message. Resets the context stack, opens the enclosing array and
-- writes, in order: protocol version, message name, message type, sequence id.
function TJSONProtocol:writeMessageBegin(name, ttype, seqid)
  self:resetContext()
  self:writeJSONArrayBegin()
  local version = TJSONProtocol.THRIFT_JSON_PROTOCOL_VERSION
  self:writeJSONInteger(version)
  self:writeJSONString(name)
  self:writeJSONInteger(ttype)
  self:writeJSONInteger(seqid)
end
313
-- Finish a message by closing the array opened in writeMessageBegin.
function TJSONProtocol:writeMessageEnd()
  self:writeJSONArrayEnd()
end
317
-- Start a struct as a JSON object. `name` is accepted but not written.
function TJSONProtocol:writeStructBegin(name)
  self:writeJSONObjectBegin()
end
321
-- Finish a struct by closing its JSON object.
function TJSONProtocol:writeStructEnd()
  self:writeJSONObjectEnd()
end
325
-- Start a field: write the field id, open a nested object and write the
-- type tag for `ttype` as its first key. `name` is accepted but not written.
function TJSONProtocol:writeFieldBegin(name, ttype, id)
  self:writeJSONInteger(id)
  self:writeJSONObjectBegin()
  local typeTag = TTypeToString[ttype]
  self:writeJSONString(typeTag)
end
331
-- Finish a field by closing the object opened in writeFieldBegin.
function TJSONProtocol:writeFieldEnd()
  self:writeJSONObjectEnd()
end
335
-- Intentionally empty: this protocol writes nothing for the field-stop marker.
function TJSONProtocol:writeFieldStop()
  return
end
338
-- Start a map: open an array, write the key type tag, value type tag and
-- size, then open the object that will hold the entries.
function TJSONProtocol:writeMapBegin(ktype, vtype, size)
  self:writeJSONArrayBegin()
  local keyTag = TTypeToString[ktype]
  self:writeJSONString(keyTag)
  local valueTag = TTypeToString[vtype]
  self:writeJSONString(valueTag)
  self:writeJSONInteger(size)
  return self:writeJSONObjectBegin()
end
346
-- Finish a map: close the entries object, then the enclosing array.
function TJSONProtocol:writeMapEnd()
  self:writeJSONObjectEnd()
  self:writeJSONArrayEnd()
end
351
-- Start a list: open an array and write the element type tag and the size.
function TJSONProtocol:writeListBegin(etype, size)
  self:writeJSONArrayBegin()
  local elemTag = TTypeToString[etype]
  self:writeJSONString(elemTag)
  self:writeJSONInteger(size)
end
357
-- Finish a list by closing its array.
function TJSONProtocol:writeListEnd()
  self:writeJSONArrayEnd()
end
361
-- Start a set: open an array and write the element type tag and the size
-- (same layout as a list).
function TJSONProtocol:writeSetBegin(etype, size)
  self:writeJSONArrayBegin()
  local elemTag = TTypeToString[etype]
  self:writeJSONString(elemTag)
  self:writeJSONInteger(size)
end
367
-- Finish a set by closing its array.
function TJSONProtocol:writeSetEnd()
  self:writeJSONArrayEnd()
end
371
-- Write a boolean as the integer 1 (truthy) or 0 (falsy).
function TJSONProtocol:writeBool(bool)
  local flag = bool and 1 or 0
  self:writeJSONInteger(flag)
end
379
-- Write a byte value. The value is round-tripped through bpack/bunpack with
-- the 'c' format before being written as an integer.
function TJSONProtocol:writeByte(byte)
  local packed = libluabpack.bpack('c', byte)
  -- parentheses keep only the first value returned by bunpack
  self:writeJSONInteger((libluabpack.bunpack('c', packed)))
end
385
-- Write a 16-bit integer. The value is round-tripped through bpack/bunpack
-- with the 's' format before being written as an integer.
function TJSONProtocol:writeI16(i16)
  local packed = libluabpack.bpack('s', i16)
  -- parentheses keep only the first value returned by bunpack
  self:writeJSONInteger((libluabpack.bunpack('s', packed)))
end
391
-- Write a 32-bit integer. The value is round-tripped through bpack/bunpack
-- with the 'i' format before being written as an integer.
function TJSONProtocol:writeI32(i32)
  local packed = libluabpack.bpack('i', i32)
  -- parentheses keep only the first value returned by bunpack
  self:writeJSONInteger((libluabpack.bunpack('i', packed)))
end
397
-- Write a 64-bit integer. The value is round-tripped through bpack/bunpack
-- with the 'l' format and handed on in its tostring() form.
function TJSONProtocol:writeI64(i64)
  local packed = libluabpack.bpack('l', i64)
  local unpacked = libluabpack.bunpack('l', packed)
  local text = tostring(unpacked)
  self:writeJSONInteger(text)
end
403
-- Write a double, formatted with 16 digits after the decimal point.
function TJSONProtocol:writeDouble(dub)
  local text = string.format("%.16f", dub)
  self:writeJSONDouble(text)
end
407
-- Write a string value as a quoted, escaped JSON string.
function TJSONProtocol:writeString(str)
  self:writeJSONString(str)
end
411
-- Write binary data as a quoted base64 string.
function TJSONProtocol:writeBinary(str)
  self:writeJSONBase64(str)
end
416
-- Consume one character and require it to equal `ch`.
-- A pending look-ahead byte (hasReadByte) is used first and cleared;
-- otherwise one byte is read from the transport. A mismatch raises a
-- TProtocolException through terror.
function TJSONProtocol:readJSONSyntaxChar(ch)
  local actual = self.hasReadByte
  if actual == "" then
    actual = self.trans:readAll(1)
  else
    self.hasReadByte = ""
  end
  if actual ~= ch then
    terror(TProtocolException:new{message = "Expected ".. ch .. ", got " .. actual})
  end
end
429
-- Consume whatever separator must precede the next value in the active
-- context. Nothing is read at the top level (null) or before the first value.
-- Object contexts alternate between ':' and ','; other contexts expect ','.
function TJSONProtocol:readElemSeparator()
  local ctx = self.jsonContextVal
  if ctx.null then
    return
  end
  if ctx.first then
    ctx.first = false
    return
  end
  if ctx.ttype ~= 1 then
    self:readJSONSyntaxChar(JSONNode.ElemSeparator)
    return
  end
  if ctx.colon then
    self:readJSONSyntaxChar(JSONNode.PairSeparator)
    ctx.colon = false
  else
    self:readJSONSyntaxChar(JSONNode.ElemSeparator)
    ctx.colon = true
  end
end
450
-- Convert one lowercase hex digit character ([0-9a-f]) to its value 0..15.
-- Any other character raises a TProtocolException through terror.
function TJSONProtocol:hexVal(ch)
  local code = string.byte(ch)
  if code >= 48 and code <= 57 then
    return code - 48
  end
  if code >= 97 and code <= 102 then
    return code - 87
  end
  terror(TProtocolException:new{message = "Expected hex val ([0-9a-f]); got " .. ch})
end
461
-- Read the "00XX" tail of a \u00XX escape and return the byte value as a
-- number. The two leading zeros are required; `ch` is not used.
function TJSONProtocol:readJSONEscapeChar(ch)
  self:readJSONSyntaxChar(JSONNode.ZeroChar)
  self:readJSONSyntaxChar(JSONNode.ZeroChar)
  local highDigit = self.trans:readAll(1)
  local lowDigit = self.trans:readAll(1)
  local high = libluabitwise.shiftl(self:hexVal(highDigit), 4)
  return high + self:hexVal(lowDigit)
end
469
470
-- Read a quoted JSON string (after any required separator) and return its
-- decoded contents.
-- Handles the short escapes listed in JSONNode.EscapeChars and \u00XX
-- escapes; any other character after a backslash raises a
-- TProtocolException through terror.
function TJSONProtocol:readJSONString()
  self:readElemSeparator()
  self:readJSONSyntaxChar(JSONNode.StringDelimiter)
  -- collect pieces and join once instead of growing a string per character
  local pieces = {}
  while true do
    local ch = self.trans:readAll(1)
    if ch == JSONNode.StringDelimiter then
      break
    end
    if ch == JSONNode.Backslash then
      ch = self.trans:readAll(1)
      if ch == JSONNode.EscapeChar then
        -- readJSONEscapeChar returns the byte value as a number; turn it
        -- into the character it stands for
        ch = string.char(self:readJSONEscapeChar(ch))
      else
        -- plain search: the input byte must not be interpreted as a pattern
        local pos = string.find(JSONNode.EscapeChars, ch, 1, true)
        if pos == nil then
          terror(TProtocolException:new{message = "Expected control char, got " .. ch})
        end
        ch = EscapeCharVals[pos]
      end
    end
    pieces[#pieces + 1] = ch
  end
  return table.concat(pieces)
end
496
-- Read a quoted JSON string and return its base64-decoded bytes.
-- The whole string is decoded in one call. Four base64 symbols are exactly
-- three bytes, so this matches group-by-group decoding without any slicing.
function TJSONProtocol:readJSONBase64()
  local encoded = self:readJSONString()
  return base64_decode(encoded)
end
513
-- Read characters from the transport while they can belong to a number
-- (sign, digit, '.', 'E' or 'e') and return them as a string.
-- The first character that does not qualify is kept in hasReadByte so the
-- next syntax read can consume it.
function TJSONProtocol:readJSONNumericChars()
  local collected = ""
  local ch = self.trans:readAll(1)
  while string.find(ch, '[-+0-9.Ee]') do
    collected = collected .. ch
    ch = self.trans:readAll(1)
  end
  self.hasReadByte = ch
  return collected
end
527
-- Read an integer (after any required separator) and return its digits as
-- a string. Surrounding quotes are consumed when escapeNum() says so.
function TJSONProtocol:readJSONLongInteger()
  self:readElemSeparator()
  local quoted = self:escapeNum()
  if quoted then
    self:readJSONSyntaxChar(JSONNode.StringDelimiter)
  end
  local digits = self:readJSONNumericChars()
  if quoted then
    self:readJSONSyntaxChar(JSONNode.StringDelimiter)
  end
  return digits
end
539
-- Read an integer and return it converted with tonumber.
function TJSONProtocol:readJSONInteger()
  local digits = self:readJSONLongInteger()
  return tonumber(digits)
end
543
-- Read a JSON double (after any required separator) and return it as a
-- number.
-- A quoted value may be one of the NaN / Infinity / -Infinity literals or an
-- ordinary number in quotes; an unquoted value is read as numeric characters.
-- An unquoted value where escapeNum() requires quotes raises a
-- TProtocolException through terror.
function TJSONProtocol:readJSONDouble()
  self:readElemSeparator()
  -- take the first byte from the look-ahead slot when one is pending
  local lead = self.hasReadByte
  if lead ~= "" then
    self.hasReadByte = ""
  else
    lead = self.trans:readAll(1)
  end

  if lead ~= JSONNode.StringDelimiter then
    if self:escapeNum() then
      terror(TProtocolException:new{message = "Expected ".. JSONNode.StringDelimiter .. ", got " .. lead})
    end
    return tonumber(lead .. self:readJSONNumericChars())
  end

  -- The opening quote is already consumed, so read up to the closing quote
  -- here. readJSONString cannot be used: it would look for the separator
  -- and the opening quote a second time.
  local chars = {}
  while true do
    local ch = self.trans:readAll(1)
    if ch == JSONNode.StringDelimiter then
      break
    end
    chars[#chars + 1] = ch
  end
  local str = table.concat(chars)

  if str == JSONNode.Nan then
    return 0/0
  elseif str == JSONNode.Infinity then
    return math.huge
  elseif str == JSONNode.NegativeInfinity then
    return -math.huge
  end
  return tonumber(str)
end
568
-- Consume any separator and '{', then enter a fresh object context (ttype 1).
function TJSONProtocol:readJSONObjectBegin()
  self:readElemSeparator()
  self:readJSONSyntaxChar(JSONNode.ObjectBegin)
  local objectContext = {first = true, colon = true, ttype = 1, null = false}
  self:contextPush(objectContext)
end
574
-- Consume '}' and then leave the object context.
function TJSONProtocol:readJSONObjectEnd()
  local closing = JSONNode.ObjectEnd
  self:readJSONSyntaxChar(closing)
  self:contextPop()
end
579
-- Consume any separator and '[', then enter a fresh array context (ttype 2).
function TJSONProtocol:readJSONArrayBegin()
  self:readElemSeparator()
  self:readJSONSyntaxChar(JSONNode.ArrayBegin)
  local arrayContext = {first = true, colon = true, ttype = 2, null = false}
  self:contextPush(arrayContext)
end
585
-- Consume ']' and then leave the array context.
function TJSONProtocol:readJSONArrayEnd()
  local closing = JSONNode.ArrayEnd
  self:readJSONSyntaxChar(closing)
  self:contextPop()
end
590
-- Read a message header and return name, message type, sequence id.
-- Resets the context stack, opens the enclosing array and checks the
-- protocol version; a version mismatch raises a TProtocolException
-- through terror.
function TJSONProtocol:readMessageBegin()
  self:resetContext()
  self:readJSONArrayBegin()
  if self:readJSONInteger() ~= self.THRIFT_JSON_PROTOCOL_VERSION then
    terror(TProtocolException:new{message = "Message contained bad version."})
  end
  local messageName = self:readJSONString()
  local messageType = self:readJSONInteger()
  local sequenceId = self:readJSONInteger()
  return messageName, messageType, sequenceId
end
603
-- Finish a message by consuming the closing bracket of its array.
function TJSONProtocol:readMessageEnd()
  self:readJSONArrayEnd()
end
607
-- Start reading a struct by opening its JSON object. Always returns nil
-- (no struct name is read).
function TJSONProtocol:readStructBegin()
  self:readJSONObjectBegin()
  local structName = nil
  return structName
end
612
-- Finish a struct by consuming the closing brace of its object.
function TJSONProtocol:readStructEnd()
  self:readJSONObjectEnd()
end
616
-- Read a field header and return nil, type id, field id.
-- One byte is read and kept as look-ahead. If it is '}' the struct has no
-- more fields and TType.STOP with id 0 is returned. Otherwise the field id,
-- the nested object and the type tag are read.
function TJSONProtocol:readFieldBegin()
  local peeked = self.trans:readAll(1)
  self.hasReadByte = peeked
  if peeked == JSONNode.ObjectEnd then
    return nil, TType.STOP, 0
  end
  local id = self:readJSONInteger()
  self:readJSONObjectBegin()
  local ttype = StringToTType[self:readJSONString()]
  return nil, ttype, id
end
630
-- Finish a field by consuming the closing brace of its object.
function TJSONProtocol:readFieldEnd()
  self:readJSONObjectEnd()
end
634
-- Read a map header and return key type id, value type id, size.
-- Opens the enclosing array, reads the two type tags and the size, then
-- opens the object that holds the entries.
function TJSONProtocol:readMapBegin()
  self:readJSONArrayBegin()
  local ktype = StringToTType[self:readJSONString()]
  local vtype = StringToTType[self:readJSONString()]
  local size = self:readJSONInteger()
  self:readJSONObjectBegin()
  return ktype, vtype, size
end
645
-- Finish a map: consume the end of the entries object, then of the array.
function TJSONProtocol:readMapEnd()
  self:readJSONObjectEnd()
  self:readJSONArrayEnd()
end
650
-- Read a list header and return element type id, size.
function TJSONProtocol:readListBegin()
  self:readJSONArrayBegin()
  local etype = StringToTType[self:readJSONString()]
  local size = self:readJSONInteger()
  return etype, size
end
658
-- Finish a list by consuming the closing bracket of its array.
function TJSONProtocol:readListEnd()
  return self:readJSONArrayEnd()
end
662
-- Read a set header; delegates to readListBegin and returns its results
-- (element type id, size).
function TJSONProtocol:readSetBegin()
  return self:readListBegin()
end
666
-- Finish a set by consuming the closing bracket of its array.
function TJSONProtocol:readSetEnd()
  return self:readJSONArrayEnd()
end
670
-- Read a boolean: true when the integer read is 1, false otherwise.
function TJSONProtocol:readBool()
  local value = self:readJSONInteger()
  return value == 1
end
679
-- Read a byte value. A value of 256 or more raises a TProtocolException
-- through terror.
function TJSONProtocol:readByte()
  local value = self:readJSONInteger()
  if value >= 256 then
    terror(TProtocolException:new{message = "UnExpected Byte " .. value})
  end
  return value
end
687
-- Read a 16-bit integer (no range check is performed here).
function TJSONProtocol:readI16()
  return self:readJSONInteger()
end
691
-- Read a 32-bit integer (no range check is performed here).
function TJSONProtocol:readI32()
  return self:readJSONInteger()
end
695
-- Read a 64-bit integer: the digit string is passed to liblualongnumber.new
-- and the resulting value is returned.
function TJSONProtocol:readI64()
  return liblualongnumber.new(self:readJSONLongInteger())
end
700
-- Read a double value.
function TJSONProtocol:readDouble()
  return self:readJSONDouble()
end
704
-- Read a string value from a quoted JSON string.
function TJSONProtocol:readString()
  return self:readJSONString()
end
708
-- Read binary data from a quoted base64 string.
function TJSONProtocol:readBinary()
  return self:readJSONBase64()
end
712
-- Factory object that produces TJSONProtocol instances (see getProtocol).
TJSONProtocolFactory = TProtocolFactory:new({
  __type = 'TJSONProtocolFactory',
})
716
-- Create a TJSONProtocol bound to the transport `trans`.
-- A nil or false transport raises a TProtocolException through terror.
function TJSONProtocolFactory:getProtocol(trans)
  -- TODO Enforce that this must be a transport class (ie not a bool)
  if not trans then
    local problem = 'Must supply a transport to ' .. ttype(self)
    terror(TProtocolException:new{
      message = problem
    })
  end
  local protocol = TJSONProtocol:new{
    trans = trans
  }
  return protocol
end
728