1/*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package thrift
18
import (
	"encoding/base64"
	"fmt"
	"io"
)
23
// THRIFT_JSON_PROTOCOL_VERSION is the envelope version written as the first
// element of every message by WriteMessageBegin and required by
// ReadMessageBegin.
const THRIFT_JSON_PROTOCOL_VERSION = 1
27
// For references to _ParseContext, see simple_json_protocol.go.
29
30// JSONProtocol is the Compact JSON protocol implementation for thrift.
31//
32// This protocol produces/consumes a compact JSON output with field numbers as
33// object keys and field values lightly encoded.
34//
35// Example: With the Message definition
36//
37//   struct Message {
38//     1: bool aBool
39//     2: map<string, bool> aBoolStringMap
40//   },
41//
42//   Message(aBool=True, aBoolStringMap={"key1": True, "key2": False})
43//
44// will be encoded as:
45//
46//   {"1":{"tf":1},"2":{"map":["str","tf",2,{"key1": 1,"key2":0}]}}'
47type JSONProtocol struct {
48	*SimpleJSONProtocol
49}
50
51// Constructor
52func NewJSONProtocol(t Transport) *JSONProtocol {
53	v := &JSONProtocol{SimpleJSONProtocol: NewSimpleJSONProtocol(t)}
54	v.parseContextStack = append(v.parseContextStack, int(_CONTEXT_IN_TOPLEVEL))
55	v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL))
56	return v
57}
58
59// Factory
60type JSONProtocolFactory struct{}
61
62func (p *JSONProtocolFactory) GetProtocol(trans Transport) Protocol {
63	return NewJSONProtocol(trans)
64}
65
66func NewJSONProtocolFactory() *JSONProtocolFactory {
67	return &JSONProtocolFactory{}
68}
69
70func (p *JSONProtocol) WriteMessageBegin(name string, typeId MessageType, seqId int32) error {
71	p.resetContextStack() // THRIFT-3735
72	if e := p.OutputListBegin(); e != nil {
73		return e
74	}
75	if e := p.WriteI32(THRIFT_JSON_PROTOCOL_VERSION); e != nil {
76		return e
77	}
78	if e := p.WriteString(name); e != nil {
79		return e
80	}
81	if e := p.WriteByte(byte(typeId)); e != nil {
82		return e
83	}
84	if e := p.WriteI32(seqId); e != nil {
85		return e
86	}
87	return nil
88}
89
90func (p *JSONProtocol) WriteMessageEnd() error {
91	return p.OutputListEnd()
92}
93
94func (p *JSONProtocol) WriteStructBegin(name string) error {
95	if e := p.OutputObjectBegin(); e != nil {
96		return e
97	}
98	return nil
99}
100
101func (p *JSONProtocol) WriteStructEnd() error {
102	return p.OutputObjectEnd()
103}
104
105func (p *JSONProtocol) WriteFieldBegin(name string, typeId Type, id int16) error {
106	if e := p.WriteI16(id); e != nil {
107		return e
108	}
109	if e := p.OutputObjectBegin(); e != nil {
110		return e
111	}
112	s, e1 := p.TypeIdToString(typeId)
113	if e1 != nil {
114		return e1
115	}
116	if e := p.WriteString(s); e != nil {
117		return e
118	}
119	return nil
120}
121
122func (p *JSONProtocol) WriteFieldEnd() error {
123	return p.OutputObjectEnd()
124}
125
126func (p *JSONProtocol) WriteFieldStop() error { return nil }
127
128func (p *JSONProtocol) WriteMapBegin(keyType Type, valueType Type, size int) error {
129	if e := p.OutputListBegin(); e != nil {
130		return e
131	}
132	s, e1 := p.TypeIdToString(keyType)
133	if e1 != nil {
134		return e1
135	}
136	if e := p.WriteString(s); e != nil {
137		return e
138	}
139	s, e1 = p.TypeIdToString(valueType)
140	if e1 != nil {
141		return e1
142	}
143	if e := p.WriteString(s); e != nil {
144		return e
145	}
146	if e := p.WriteI64(int64(size)); e != nil {
147		return e
148	}
149	return p.OutputObjectBegin()
150}
151
152func (p *JSONProtocol) WriteMapEnd() error {
153	if e := p.OutputObjectEnd(); e != nil {
154		return e
155	}
156	return p.OutputListEnd()
157}
158
159func (p *JSONProtocol) WriteListBegin(elemType Type, size int) error {
160	return p.OutputElemListBegin(elemType, size)
161}
162
163func (p *JSONProtocol) WriteListEnd() error {
164	return p.OutputListEnd()
165}
166
167func (p *JSONProtocol) WriteSetBegin(elemType Type, size int) error {
168	return p.OutputElemListBegin(elemType, size)
169}
170
171func (p *JSONProtocol) WriteSetEnd() error {
172	return p.OutputListEnd()
173}
174
175func (p *JSONProtocol) WriteBool(b bool) error {
176	if b {
177		return p.WriteI32(1)
178	}
179	return p.WriteI32(0)
180}
181
182func (p *JSONProtocol) WriteByte(b byte) error {
183	return p.WriteI32(int32(b))
184}
185
186func (p *JSONProtocol) WriteI16(v int16) error {
187	return p.WriteI32(int32(v))
188}
189
190func (p *JSONProtocol) WriteI32(v int32) error {
191	return p.OutputI64(int64(v))
192}
193
194func (p *JSONProtocol) WriteI64(v int64) error {
195	return p.OutputI64(int64(v))
196}
197
198func (p *JSONProtocol) WriteDouble(v float64) error {
199	return p.OutputF64(v)
200}
201
202func (p *JSONProtocol) WriteFloat(v float32) error {
203	return p.OutputF32(v)
204}
205
206func (p *JSONProtocol) WriteString(v string) error {
207	return p.OutputString(v)
208}
209
210func (p *JSONProtocol) WriteBinary(v []byte) error {
211	// JSON library only takes in a string,
212	// not an arbitrary byte array, to ensure bytes are transmitted
213	// efficiently we must convert this into a valid JSON string
214	// therefore we use base64 encoding to avoid excessive escaping/quoting
215	if e := p.OutputPreValue(); e != nil {
216		return e
217	}
218	if _, e := p.write(JSON_QUOTE_BYTES); e != nil {
219		return NewProtocolException(e)
220	}
221	writer := base64.NewEncoder(base64.StdEncoding, p.writer)
222	if _, e := writer.Write(v); e != nil {
223		p.writer.Reset(p.trans) // THRIFT-3735
224		return NewProtocolException(e)
225	}
226	if e := writer.Close(); e != nil {
227		return NewProtocolException(e)
228	}
229	if _, e := p.write(JSON_QUOTE_BYTES); e != nil {
230		return NewProtocolException(e)
231	}
232	return p.OutputPostValue()
233}
234
235// Reading methods.
236func (p *JSONProtocol) ReadMessageBegin() (name string, typeId MessageType, seqId int32, err error) {
237	p.resetContextStack() // THRIFT-3735
238	if isNull, err := p.ParseListBegin(); isNull || err != nil {
239		return name, typeId, seqId, err
240	}
241	version, err := p.ReadI32()
242	if err != nil {
243		return name, typeId, seqId, err
244	}
245	if version != THRIFT_JSON_PROTOCOL_VERSION {
246		e := fmt.Errorf("Unknown Protocol version %d, expected version %d", version, THRIFT_JSON_PROTOCOL_VERSION)
247		return name, typeId, seqId, NewProtocolExceptionWithType(INVALID_DATA, e)
248
249	}
250	if name, err = p.ReadString(); err != nil {
251		return name, typeId, seqId, err
252	}
253	bTypeId, err := p.ReadByte()
254	typeId = MessageType(bTypeId)
255	if err != nil {
256		return name, typeId, seqId, err
257	}
258	if seqId, err = p.ReadI32(); err != nil {
259		return name, typeId, seqId, err
260	}
261	return name, typeId, seqId, nil
262}
263
264func (p *JSONProtocol) ReadMessageEnd() error {
265	err := p.ParseListEnd()
266	return err
267}
268
269func (p *JSONProtocol) ReadStructBegin() (name string, err error) {
270	_, err = p.ParseObjectStart()
271	return "", err
272}
273
274func (p *JSONProtocol) ReadStructEnd() error {
275	return p.ParseObjectEnd()
276}
277
278func (p *JSONProtocol) ReadFieldBegin() (string, Type, int16, error) {
279	b, _ := p.reader.Peek(1)
280	if len(b) < 1 || b[0] == JSON_RBRACE[0] || b[0] == JSON_RBRACKET[0] {
281		return "", STOP, -1, nil
282	}
283	fieldId, err := p.ReadI16()
284	if err != nil {
285		return "", STOP, fieldId, err
286	}
287	if _, err = p.ParseObjectStart(); err != nil {
288		return "", STOP, fieldId, err
289	}
290	sType, err := p.ReadString()
291	if err != nil {
292		return "", STOP, fieldId, err
293	}
294	fType, err := p.StringToTypeId(sType)
295	return "", fType, fieldId, err
296}
297
298func (p *JSONProtocol) ReadFieldEnd() error {
299	return p.ParseObjectEnd()
300}
301
302func (p *JSONProtocol) ReadMapBegin() (keyType Type, valueType Type, size int, e error) {
303	if isNull, e := p.ParseListBegin(); isNull || e != nil {
304		return VOID, VOID, 0, e
305	}
306
307	// read keyType
308	sKeyType, e := p.ReadString()
309	if e != nil {
310		return keyType, valueType, size, e
311	}
312	keyType, e = p.StringToTypeId(sKeyType)
313	if e != nil {
314		return keyType, valueType, size, e
315	}
316
317	// read valueType
318	sValueType, e := p.ReadString()
319	if e != nil {
320		return keyType, valueType, size, e
321	}
322	valueType, e = p.StringToTypeId(sValueType)
323	if e != nil {
324		return keyType, valueType, size, e
325	}
326
327	// read size
328	iSize, e := p.ReadI64()
329	if e != nil {
330		return keyType, valueType, size, e
331	}
332	size = int(iSize)
333
334	_, e = p.ParseObjectStart()
335	return keyType, valueType, size, e
336}
337
338func (p *JSONProtocol) ReadMapEnd() error {
339	e := p.ParseObjectEnd()
340	if e != nil {
341		return e
342	}
343	return p.ParseListEnd()
344}
345
346func (p *JSONProtocol) ReadListBegin() (elemType Type, size int, e error) {
347	return p.ParseElemListBegin()
348}
349
350func (p *JSONProtocol) ReadListEnd() error {
351	return p.ParseListEnd()
352}
353
354func (p *JSONProtocol) ReadSetBegin() (elemType Type, size int, e error) {
355	return p.ParseElemListBegin()
356}
357
358func (p *JSONProtocol) ReadSetEnd() error {
359	return p.ParseListEnd()
360}
361
362func (p *JSONProtocol) ReadBool() (bool, error) {
363	value, err := p.ReadI32()
364	return (value != 0), err
365}
366
367func (p *JSONProtocol) ReadByte() (byte, error) {
368	v, err := p.ReadI64()
369	return byte(v), err
370}
371
372func (p *JSONProtocol) ReadI16() (int16, error) {
373	v, err := p.ReadI64()
374	return int16(v), err
375}
376
377func (p *JSONProtocol) ReadI32() (int32, error) {
378	v, err := p.ReadI64()
379	return int32(v), err
380}
381
382func (p *JSONProtocol) ReadI64() (int64, error) {
383	v, _, err := p.ParseI64()
384	return v, err
385}
386
387func (p *JSONProtocol) ReadDouble() (float64, error) {
388	v, _, err := p.ParseF64()
389	return v, err
390}
391
392func (p *JSONProtocol) ReadFloat() (float32, error) {
393	v, _, err := p.ParseF32()
394	return v, err
395}
396
397func (p *JSONProtocol) ReadString() (string, error) {
398	var v string
399	if err := p.ParsePreValue(); err != nil {
400		return v, err
401	}
402	f, _ := p.reader.Peek(1)
403	if len(f) > 0 && f[0] == JSON_QUOTE {
404		p.reader.ReadByte()
405		value, err := p.ParseStringBody()
406		v = value
407		if err != nil {
408			return v, err
409		}
410	} else if len(f) > 0 && f[0] == JSON_NULL[0] {
411		b := make([]byte, len(JSON_NULL))
412		_, err := p.reader.Read(b)
413		if err != nil {
414			return v, NewProtocolException(err)
415		}
416		if string(b) != string(JSON_NULL) {
417			e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b))
418			return v, NewProtocolExceptionWithType(INVALID_DATA, e)
419		}
420	} else {
421		e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f))
422		return v, NewProtocolExceptionWithType(INVALID_DATA, e)
423	}
424	return v, p.ParsePostValue()
425}
426
427func (p *JSONProtocol) ReadBinary() ([]byte, error) {
428	var v []byte
429	if err := p.ParsePreValue(); err != nil {
430		return nil, err
431	}
432	f, _ := p.reader.Peek(1)
433	if len(f) > 0 && f[0] == JSON_QUOTE {
434		p.reader.ReadByte()
435		value, err := p.ParseBase64EncodedBody()
436		v = value
437		if err != nil {
438			return v, err
439		}
440	} else if len(f) > 0 && f[0] == JSON_NULL[0] {
441		b := make([]byte, len(JSON_NULL))
442		_, err := p.reader.Read(b)
443		if err != nil {
444			return v, NewProtocolException(err)
445		}
446		if string(b) != string(JSON_NULL) {
447			e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b))
448			return v, NewProtocolExceptionWithType(INVALID_DATA, e)
449		}
450	} else {
451		e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f))
452		return v, NewProtocolExceptionWithType(INVALID_DATA, e)
453	}
454
455	return v, p.ParsePostValue()
456}
457
458func (p *JSONProtocol) Flush() (err error) {
459	err = p.writer.Flush()
460	if err == nil {
461		err = p.trans.Flush()
462	}
463	return NewProtocolException(err)
464}
465
466func (p *JSONProtocol) Skip(fieldType Type) (err error) {
467	return SkipDefaultDepth(p, fieldType)
468}
469
470func (p *JSONProtocol) Transport() Transport {
471	return p.trans
472}
473
474func (p *JSONProtocol) OutputElemListBegin(elemType Type, size int) error {
475	if e := p.OutputListBegin(); e != nil {
476		return e
477	}
478	s, e1 := p.TypeIdToString(elemType)
479	if e1 != nil {
480		return e1
481	}
482	if e := p.WriteString(s); e != nil {
483		return e
484	}
485	if e := p.WriteI64(int64(size)); e != nil {
486		return e
487	}
488	return nil
489}
490
491func (p *JSONProtocol) ParseElemListBegin() (elemType Type, size int, e error) {
492	if isNull, e := p.ParseListBegin(); isNull || e != nil {
493		return VOID, 0, e
494	}
495	sElemType, err := p.ReadString()
496	if err != nil {
497		return VOID, size, err
498	}
499	elemType, err = p.StringToTypeId(sElemType)
500	if err != nil {
501		return elemType, size, err
502	}
503	nSize, err2 := p.ReadI64()
504	size = int(nSize)
505	return elemType, size, err2
506}
507
508func (p *JSONProtocol) readElemListBegin() (elemType Type, size int, e error) {
509	if isNull, e := p.ParseListBegin(); isNull || e != nil {
510		return VOID, 0, e
511	}
512	sElemType, err := p.ReadString()
513	if err != nil {
514		return VOID, size, err
515	}
516	elemType, err = p.StringToTypeId(sElemType)
517	if err != nil {
518		return elemType, size, err
519	}
520	nSize, err2 := p.ReadI64()
521	size = int(nSize)
522	return elemType, size, err2
523}
524
525func (p *JSONProtocol) writeElemListBegin(elemType Type, size int) error {
526	if e := p.OutputListBegin(); e != nil {
527		return e
528	}
529	s, e1 := p.TypeIdToString(elemType)
530	if e1 != nil {
531		return e1
532	}
533	if e := p.OutputString(s); e != nil {
534		return e
535	}
536	if e := p.OutputI64(int64(size)); e != nil {
537		return e
538	}
539	return nil
540}
541
542func (p *JSONProtocol) TypeIdToString(fieldType Type) (string, error) {
543	switch byte(fieldType) {
544	case BOOL:
545		return "tf", nil
546	case BYTE:
547		return "i8", nil
548	case I16:
549		return "i16", nil
550	case I32:
551		return "i32", nil
552	case I64:
553		return "i64", nil
554	case DOUBLE:
555		return "dbl", nil
556	case FLOAT:
557		return "flt", nil
558	case STRING:
559		return "str", nil
560	case STRUCT:
561		return "rec", nil
562	case MAP:
563		return "map", nil
564	case SET:
565		return "set", nil
566	case LIST:
567		return "lst", nil
568	}
569
570	e := fmt.Errorf("Unknown fieldType: %d", int(fieldType))
571	return "", NewProtocolExceptionWithType(INVALID_DATA, e)
572}
573
574func (p *JSONProtocol) StringToTypeId(fieldType string) (Type, error) {
575	switch fieldType {
576	case "tf":
577		return Type(BOOL), nil
578	case "i8":
579		return Type(BYTE), nil
580	case "i16":
581		return Type(I16), nil
582	case "i32":
583		return Type(I32), nil
584	case "i64":
585		return Type(I64), nil
586	case "dbl":
587		return Type(DOUBLE), nil
588	case "flt":
589		return Type(FLOAT), nil
590	case "str":
591		return Type(STRING), nil
592	case "rec":
593		return Type(STRUCT), nil
594	case "map":
595		return Type(MAP), nil
596	case "set":
597		return Type(SET), nil
598	case "lst":
599		return Type(LIST), nil
600	}
601
602	e := fmt.Errorf("Unknown type identifier: %s", fieldType)
603	return Type(STOP), NewProtocolExceptionWithType(INVALID_DATA, e)
604}
605