1package gen
2
3import (
4	"encoding"
5	"encoding/json"
6	"fmt"
7	"reflect"
8	"strings"
9	"unicode"
10
11	"github.com/mailru/easyjson"
12)
13
// minSliceBytes is roughly how many bytes worth of elements a freshly decoded
// slice is pre-allocated to hold, so fewer intermediate backing arrays are
// thrown away (and later collected) while appending.
const minSliceBytes = 64
16
17func (g *Generator) getDecoderName(t reflect.Type) string {
18	return g.functionName("decode", t)
19}
20
// primitiveDecoders maps a reflect.Kind to the lexer expression that reads a
// JSON value of that kind. The caller wraps the expression in a conversion to
// the concrete Go type, so named types with these kinds work as well.
var primitiveDecoders = map[reflect.Kind]string{
	reflect.Bool:   "in.Bool()",
	reflect.String: "in.String()",

	reflect.Float32: "in.Float32()",
	reflect.Float64: "in.Float64()",

	reflect.Int:   "in.Int()",
	reflect.Int8:  "in.Int8()",
	reflect.Int16: "in.Int16()",
	reflect.Int32: "in.Int32()",
	reflect.Int64: "in.Int64()",

	reflect.Uint:   "in.Uint()",
	reflect.Uint8:  "in.Uint8()",
	reflect.Uint16: "in.Uint16()",
	reflect.Uint32: "in.Uint32()",
	reflect.Uint64: "in.Uint64()",
}
37
// primitiveStringDecoders maps a reflect.Kind to the lexer expression that
// reads a value of that kind from a JSON string. It is consulted for fields
// carrying the "string" tag option and for map keys.
var primitiveStringDecoders = map[reflect.Kind]string{
	reflect.String: "in.String()",

	reflect.Float32: "in.Float32Str()",
	reflect.Float64: "in.Float64Str()",

	reflect.Int:   "in.IntStr()",
	reflect.Int8:  "in.Int8Str()",
	reflect.Int16: "in.Int16Str()",
	reflect.Int32: "in.Int32Str()",
	reflect.Int64: "in.Int64Str()",

	reflect.Uint:    "in.UintStr()",
	reflect.Uint8:   "in.Uint8Str()",
	reflect.Uint16:  "in.Uint16Str()",
	reflect.Uint32:  "in.Uint32Str()",
	reflect.Uint64:  "in.Uint64Str()",
	reflect.Uintptr: "in.UintptrStr()",
}
54
// customDecoders maps the reflect.Type.String() form of specific types to the
// lexer expression that decodes them. It is checked before the kind-based
// tables, so these types bypass the generic handling of their kind.
var customDecoders = map[string]string{
	"json.Number": "in.JsonNumber()", // decoded via the lexer's dedicated method
}
58
59// genTypeDecoder generates decoding code for the type t, but uses unmarshaler interface if implemented by t.
60func (g *Generator) genTypeDecoder(t reflect.Type, out string, tags fieldTags, indent int) error {
61	ws := strings.Repeat("  ", indent)
62
63	unmarshalerIface := reflect.TypeOf((*easyjson.Unmarshaler)(nil)).Elem()
64	if reflect.PtrTo(t).Implements(unmarshalerIface) {
65		fmt.Fprintln(g.out, ws+"("+out+").UnmarshalEasyJSON(in)")
66		return nil
67	}
68
69	unmarshalerIface = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
70	if reflect.PtrTo(t).Implements(unmarshalerIface) {
71		fmt.Fprintln(g.out, ws+"if data := in.Raw(); in.Ok() {")
72		fmt.Fprintln(g.out, ws+"  in.AddError( ("+out+").UnmarshalJSON(data) )")
73		fmt.Fprintln(g.out, ws+"}")
74		return nil
75	}
76
77	unmarshalerIface = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
78	if reflect.PtrTo(t).Implements(unmarshalerIface) {
79		fmt.Fprintln(g.out, ws+"if data := in.UnsafeBytes(); in.Ok() {")
80		fmt.Fprintln(g.out, ws+"  in.AddError( ("+out+").UnmarshalText(data) )")
81		fmt.Fprintln(g.out, ws+"}")
82		return nil
83	}
84
85	err := g.genTypeDecoderNoCheck(t, out, tags, indent)
86	return err
87}
88
89// returns true of the type t implements one of the custom unmarshaler interfaces
90func hasCustomUnmarshaler(t reflect.Type) bool {
91	t = reflect.PtrTo(t)
92	return t.Implements(reflect.TypeOf((*easyjson.Unmarshaler)(nil)).Elem()) ||
93		t.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) ||
94		t.Implements(reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem())
95}
96
97// genTypeDecoderNoCheck generates decoding code for the type t.
98func (g *Generator) genTypeDecoderNoCheck(t reflect.Type, out string, tags fieldTags, indent int) error {
99	ws := strings.Repeat("  ", indent)
100	// Check whether type is primitive, needs to be done after interface check.
101	if dec := customDecoders[t.String()]; dec != "" {
102		fmt.Fprintln(g.out, ws+out+" = "+dec)
103		return nil
104	} else if dec := primitiveStringDecoders[t.Kind()]; dec != "" && tags.asString {
105		fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")")
106		return nil
107	} else if dec := primitiveDecoders[t.Kind()]; dec != "" {
108		fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")")
109		return nil
110	}
111
112	switch t.Kind() {
113	case reflect.Slice:
114		tmpVar := g.uniqueVarName()
115		elem := t.Elem()
116
117		if elem.Kind() == reflect.Uint8 && elem.Name() == "uint8" {
118			fmt.Fprintln(g.out, ws+"if in.IsNull() {")
119			fmt.Fprintln(g.out, ws+"  in.Skip()")
120			fmt.Fprintln(g.out, ws+"  "+out+" = nil")
121			fmt.Fprintln(g.out, ws+"} else {")
122			fmt.Fprintln(g.out, ws+"  "+out+" = in.Bytes()")
123			fmt.Fprintln(g.out, ws+"}")
124
125		} else {
126
127			capacity := minSliceBytes / elem.Size()
128			if capacity == 0 {
129				capacity = 1
130			}
131
132			fmt.Fprintln(g.out, ws+"if in.IsNull() {")
133			fmt.Fprintln(g.out, ws+"  in.Skip()")
134			fmt.Fprintln(g.out, ws+"  "+out+" = nil")
135			fmt.Fprintln(g.out, ws+"} else {")
136			fmt.Fprintln(g.out, ws+"  in.Delim('[')")
137			fmt.Fprintln(g.out, ws+"  if "+out+" == nil {")
138			fmt.Fprintln(g.out, ws+"    if !in.IsDelim(']') {")
139			fmt.Fprintln(g.out, ws+"      "+out+" = make("+g.getType(t)+", 0, "+fmt.Sprint(capacity)+")")
140			fmt.Fprintln(g.out, ws+"    } else {")
141			fmt.Fprintln(g.out, ws+"      "+out+" = "+g.getType(t)+"{}")
142			fmt.Fprintln(g.out, ws+"    }")
143			fmt.Fprintln(g.out, ws+"  } else { ")
144			fmt.Fprintln(g.out, ws+"    "+out+" = ("+out+")[:0]")
145			fmt.Fprintln(g.out, ws+"  }")
146			fmt.Fprintln(g.out, ws+"  for !in.IsDelim(']') {")
147			fmt.Fprintln(g.out, ws+"    var "+tmpVar+" "+g.getType(elem))
148
149			if err := g.genTypeDecoder(elem, tmpVar, tags, indent+2); err != nil {
150				return err
151			}
152
153			fmt.Fprintln(g.out, ws+"    "+out+" = append("+out+", "+tmpVar+")")
154			fmt.Fprintln(g.out, ws+"    in.WantComma()")
155			fmt.Fprintln(g.out, ws+"  }")
156			fmt.Fprintln(g.out, ws+"  in.Delim(']')")
157			fmt.Fprintln(g.out, ws+"}")
158		}
159
160	case reflect.Array:
161		iterVar := g.uniqueVarName()
162		elem := t.Elem()
163
164		if elem.Kind() == reflect.Uint8 && elem.Name() == "uint8" {
165			fmt.Fprintln(g.out, ws+"if in.IsNull() {")
166			fmt.Fprintln(g.out, ws+"  in.Skip()")
167			fmt.Fprintln(g.out, ws+"} else {")
168			fmt.Fprintln(g.out, ws+"  copy("+out+"[:], in.Bytes())")
169			fmt.Fprintln(g.out, ws+"}")
170
171		} else {
172
173			length := t.Len()
174
175			fmt.Fprintln(g.out, ws+"if in.IsNull() {")
176			fmt.Fprintln(g.out, ws+"  in.Skip()")
177			fmt.Fprintln(g.out, ws+"} else {")
178			fmt.Fprintln(g.out, ws+"  in.Delim('[')")
179			fmt.Fprintln(g.out, ws+"  "+iterVar+" := 0")
180			fmt.Fprintln(g.out, ws+"  for !in.IsDelim(']') {")
181			fmt.Fprintln(g.out, ws+"    if "+iterVar+" < "+fmt.Sprint(length)+" {")
182
183			if err := g.genTypeDecoder(elem, "("+out+")["+iterVar+"]", tags, indent+3); err != nil {
184				return err
185			}
186
187			fmt.Fprintln(g.out, ws+"      "+iterVar+"++")
188			fmt.Fprintln(g.out, ws+"    } else {")
189			fmt.Fprintln(g.out, ws+"      in.SkipRecursive()")
190			fmt.Fprintln(g.out, ws+"    }")
191			fmt.Fprintln(g.out, ws+"    in.WantComma()")
192			fmt.Fprintln(g.out, ws+"  }")
193			fmt.Fprintln(g.out, ws+"  in.Delim(']')")
194			fmt.Fprintln(g.out, ws+"}")
195		}
196
197	case reflect.Struct:
198		dec := g.getDecoderName(t)
199		g.addType(t)
200
201		if len(out) > 0 && out[0] == '*' {
202			// NOTE: In order to remove an extra reference to a pointer
203			fmt.Fprintln(g.out, ws+dec+"(in, "+out[1:]+")")
204		} else {
205			fmt.Fprintln(g.out, ws+dec+"(in, &"+out+")")
206		}
207
208	case reflect.Ptr:
209		fmt.Fprintln(g.out, ws+"if in.IsNull() {")
210		fmt.Fprintln(g.out, ws+"  in.Skip()")
211		fmt.Fprintln(g.out, ws+"  "+out+" = nil")
212		fmt.Fprintln(g.out, ws+"} else {")
213		fmt.Fprintln(g.out, ws+"  if "+out+" == nil {")
214		fmt.Fprintln(g.out, ws+"    "+out+" = new("+g.getType(t.Elem())+")")
215		fmt.Fprintln(g.out, ws+"  }")
216
217		if err := g.genTypeDecoder(t.Elem(), "*"+out, tags, indent+1); err != nil {
218			return err
219		}
220
221		fmt.Fprintln(g.out, ws+"}")
222
223	case reflect.Map:
224		key := t.Key()
225		keyDec, ok := primitiveStringDecoders[key.Kind()]
226		if !ok && !hasCustomUnmarshaler(key) {
227			return fmt.Errorf("map type %v not supported: only string and integer keys and types implementing json.Unmarshaler are allowed", key)
228		} // else assume the caller knows what they are doing and that the custom unmarshaler performs the translation from string or integer keys to the key type
229		elem := t.Elem()
230		tmpVar := g.uniqueVarName()
231
232		fmt.Fprintln(g.out, ws+"if in.IsNull() {")
233		fmt.Fprintln(g.out, ws+"  in.Skip()")
234		fmt.Fprintln(g.out, ws+"} else {")
235		fmt.Fprintln(g.out, ws+"  in.Delim('{')")
236		fmt.Fprintln(g.out, ws+"  if !in.IsDelim('}') {")
237		fmt.Fprintln(g.out, ws+"  "+out+" = make("+g.getType(t)+")")
238		fmt.Fprintln(g.out, ws+"  } else {")
239		fmt.Fprintln(g.out, ws+"  "+out+" = nil")
240		fmt.Fprintln(g.out, ws+"  }")
241
242		fmt.Fprintln(g.out, ws+"  for !in.IsDelim('}') {")
243		// NOTE: extra check for TextUnmarshaler. It overrides default methods.
244		if reflect.PtrTo(key).Implements(reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()) {
245			fmt.Fprintln(g.out, ws+"    var key "+g.getType(key))
246			fmt.Fprintln(g.out, ws+"if data := in.UnsafeBytes(); in.Ok() {")
247			fmt.Fprintln(g.out, ws+"  in.AddError(key.UnmarshalText(data) )")
248			fmt.Fprintln(g.out, ws+"}")
249		} else if keyDec != "" {
250			fmt.Fprintln(g.out, ws+"    key := "+g.getType(key)+"("+keyDec+")")
251		} else {
252			fmt.Fprintln(g.out, ws+"    var key "+g.getType(key))
253			if err := g.genTypeDecoder(key, "key", tags, indent+2); err != nil {
254				return err
255			}
256		}
257
258		fmt.Fprintln(g.out, ws+"    in.WantColon()")
259		fmt.Fprintln(g.out, ws+"    var "+tmpVar+" "+g.getType(elem))
260
261		if err := g.genTypeDecoder(elem, tmpVar, tags, indent+2); err != nil {
262			return err
263		}
264
265		fmt.Fprintln(g.out, ws+"    ("+out+")[key] = "+tmpVar)
266		fmt.Fprintln(g.out, ws+"    in.WantComma()")
267		fmt.Fprintln(g.out, ws+"  }")
268		fmt.Fprintln(g.out, ws+"  in.Delim('}')")
269		fmt.Fprintln(g.out, ws+"}")
270
271	case reflect.Interface:
272		if t.NumMethod() != 0 {
273			return fmt.Errorf("interface type %v not supported: only interface{} is allowed", t)
274		}
275		fmt.Fprintln(g.out, ws+"if m, ok := "+out+".(easyjson.Unmarshaler); ok {")
276		fmt.Fprintln(g.out, ws+"m.UnmarshalEasyJSON(in)")
277		fmt.Fprintln(g.out, ws+"} else if m, ok := "+out+".(json.Unmarshaler); ok {")
278		fmt.Fprintln(g.out, ws+"_ = m.UnmarshalJSON(in.Raw())")
279		fmt.Fprintln(g.out, ws+"} else {")
280		fmt.Fprintln(g.out, ws+"  "+out+" = in.Interface()")
281		fmt.Fprintln(g.out, ws+"}")
282	default:
283		return fmt.Errorf("don't know how to decode %v", t)
284	}
285	return nil
286
287}
288
289func (g *Generator) genStructFieldDecoder(t reflect.Type, f reflect.StructField) error {
290	jsonName := g.fieldNamer.GetJSONFieldName(t, f)
291	tags := parseFieldTags(f)
292
293	if tags.omit {
294		return nil
295	}
296
297	fmt.Fprintf(g.out, "    case %q:\n", jsonName)
298	if err := g.genTypeDecoder(f.Type, "out."+f.Name, tags, 3); err != nil {
299		return err
300	}
301
302	if tags.required {
303		fmt.Fprintf(g.out, "%sSet = true\n", f.Name)
304	}
305
306	return nil
307}
308
309func (g *Generator) genRequiredFieldSet(t reflect.Type, f reflect.StructField) {
310	tags := parseFieldTags(f)
311
312	if !tags.required {
313		return
314	}
315
316	fmt.Fprintf(g.out, "var %sSet bool\n", f.Name)
317}
318
319func (g *Generator) genRequiredFieldCheck(t reflect.Type, f reflect.StructField) {
320	jsonName := g.fieldNamer.GetJSONFieldName(t, f)
321	tags := parseFieldTags(f)
322
323	if !tags.required {
324		return
325	}
326
327	g.imports["fmt"] = "fmt"
328
329	fmt.Fprintf(g.out, "if !%sSet {\n", f.Name)
330	fmt.Fprintf(g.out, "    in.AddError(fmt.Errorf(\"key '%s' is required\"))\n", jsonName)
331	fmt.Fprintf(g.out, "}\n")
332}
333
// mergeStructFields returns every field of fields2, in order, followed by the
// fields of fields1 whose Name does not occur in fields2. On a name clash the
// field from fields2 therefore wins. The result is nil when both inputs are empty.
func mergeStructFields(fields1, fields2 []reflect.StructField) []reflect.StructField {
	var merged []reflect.StructField
	taken := make(map[string]struct{}, len(fields2))

	for _, f := range fields2 {
		taken[f.Name] = struct{}{}
		merged = append(merged, f)
	}
	for _, f := range fields1 {
		if _, clash := taken[f.Name]; clash {
			continue
		}
		merged = append(merged, f)
	}
	return merged
}
348
349func getStructFields(t reflect.Type) ([]reflect.StructField, error) {
350	if t.Kind() != reflect.Struct {
351		return nil, fmt.Errorf("got %v; expected a struct", t)
352	}
353
354	var efields []reflect.StructField
355	for i := 0; i < t.NumField(); i++ {
356		f := t.Field(i)
357		tags := parseFieldTags(f)
358		if !f.Anonymous || tags.name != "" {
359			continue
360		}
361
362		t1 := f.Type
363		if t1.Kind() == reflect.Ptr {
364			t1 = t1.Elem()
365		}
366
367		fs, err := getStructFields(t1)
368		if err != nil {
369			return nil, fmt.Errorf("error processing embedded field: %v", err)
370		}
371		efields = mergeStructFields(efields, fs)
372	}
373
374	var fields []reflect.StructField
375	for i := 0; i < t.NumField(); i++ {
376		f := t.Field(i)
377		tags := parseFieldTags(f)
378		if f.Anonymous && tags.name == "" {
379			continue
380		}
381
382		c := []rune(f.Name)[0]
383		if unicode.IsUpper(c) {
384			fields = append(fields, f)
385		}
386	}
387	return mergeStructFields(efields, fields), nil
388}
389
390func (g *Generator) genDecoder(t reflect.Type) error {
391	switch t.Kind() {
392	case reflect.Slice, reflect.Array, reflect.Map:
393		return g.genSliceArrayDecoder(t)
394	default:
395		return g.genStructDecoder(t)
396	}
397}
398
399func (g *Generator) genSliceArrayDecoder(t reflect.Type) error {
400	switch t.Kind() {
401	case reflect.Slice, reflect.Array, reflect.Map:
402	default:
403		return fmt.Errorf("cannot generate encoder/decoder for %v, not a slice/array/map type", t)
404	}
405
406	fname := g.getDecoderName(t)
407	typ := g.getType(t)
408
409	fmt.Fprintln(g.out, "func "+fname+"(in *jlexer.Lexer, out *"+typ+") {")
410	fmt.Fprintln(g.out, " isTopLevel := in.IsStart()")
411	err := g.genTypeDecoderNoCheck(t, "*out", fieldTags{}, 1)
412	if err != nil {
413		return err
414	}
415	fmt.Fprintln(g.out, "  if isTopLevel {")
416	fmt.Fprintln(g.out, "    in.Consumed()")
417	fmt.Fprintln(g.out, "  }")
418	fmt.Fprintln(g.out, "}")
419
420	return nil
421}
422
423func (g *Generator) genStructDecoder(t reflect.Type) error {
424	if t.Kind() != reflect.Struct {
425		return fmt.Errorf("cannot generate encoder/decoder for %v, not a struct type", t)
426	}
427
428	fname := g.getDecoderName(t)
429	typ := g.getType(t)
430
431	fmt.Fprintln(g.out, "func "+fname+"(in *jlexer.Lexer, out *"+typ+") {")
432	fmt.Fprintln(g.out, "  isTopLevel := in.IsStart()")
433	fmt.Fprintln(g.out, "  if in.IsNull() {")
434	fmt.Fprintln(g.out, "    if isTopLevel {")
435	fmt.Fprintln(g.out, "      in.Consumed()")
436	fmt.Fprintln(g.out, "    }")
437	fmt.Fprintln(g.out, "    in.Skip()")
438	fmt.Fprintln(g.out, "    return")
439	fmt.Fprintln(g.out, "  }")
440
441	// Init embedded pointer fields.
442	for i := 0; i < t.NumField(); i++ {
443		f := t.Field(i)
444		if !f.Anonymous || f.Type.Kind() != reflect.Ptr {
445			continue
446		}
447		fmt.Fprintln(g.out, "  out."+f.Name+" = new("+g.getType(f.Type.Elem())+")")
448	}
449
450	fs, err := getStructFields(t)
451	if err != nil {
452		return fmt.Errorf("cannot generate decoder for %v: %v", t, err)
453	}
454
455	for _, f := range fs {
456		g.genRequiredFieldSet(t, f)
457	}
458
459	fmt.Fprintln(g.out, "  in.Delim('{')")
460	fmt.Fprintln(g.out, "  for !in.IsDelim('}') {")
461	fmt.Fprintln(g.out, "    key := in.UnsafeString()")
462	fmt.Fprintln(g.out, "    in.WantColon()")
463	fmt.Fprintln(g.out, "    if in.IsNull() {")
464	fmt.Fprintln(g.out, "       in.Skip()")
465	fmt.Fprintln(g.out, "       in.WantComma()")
466	fmt.Fprintln(g.out, "       continue")
467	fmt.Fprintln(g.out, "    }")
468
469	fmt.Fprintln(g.out, "    switch key {")
470	for _, f := range fs {
471		if err := g.genStructFieldDecoder(t, f); err != nil {
472			return err
473		}
474	}
475
476	fmt.Fprintln(g.out, "    default:")
477	if g.disallowUnknownFields {
478		fmt.Fprintln(g.out, `      in.AddError(&jlexer.LexerError{
479          Offset: in.GetPos(),
480          Reason: "unknown field",
481          Data: key,
482      })`)
483	} else {
484		fmt.Fprintln(g.out, "      in.SkipRecursive()")
485	}
486	fmt.Fprintln(g.out, "    }")
487	fmt.Fprintln(g.out, "    in.WantComma()")
488	fmt.Fprintln(g.out, "  }")
489	fmt.Fprintln(g.out, "  in.Delim('}')")
490	fmt.Fprintln(g.out, "  if isTopLevel {")
491	fmt.Fprintln(g.out, "    in.Consumed()")
492	fmt.Fprintln(g.out, "  }")
493
494	for _, f := range fs {
495		g.genRequiredFieldCheck(t, f)
496	}
497
498	fmt.Fprintln(g.out, "}")
499
500	return nil
501}
502
503func (g *Generator) genStructUnmarshaler(t reflect.Type) error {
504	switch t.Kind() {
505	case reflect.Slice, reflect.Array, reflect.Map, reflect.Struct:
506	default:
507		return fmt.Errorf("cannot generate encoder/decoder for %v, not a struct/slice/array/map type", t)
508	}
509
510	fname := g.getDecoderName(t)
511	typ := g.getType(t)
512
513	if !g.noStdMarshalers {
514		fmt.Fprintln(g.out, "// UnmarshalJSON supports json.Unmarshaler interface")
515		fmt.Fprintln(g.out, "func (v *"+typ+") UnmarshalJSON(data []byte) error {")
516		fmt.Fprintln(g.out, "  r := jlexer.Lexer{Data: data}")
517		fmt.Fprintln(g.out, "  "+fname+"(&r, v)")
518		fmt.Fprintln(g.out, "  return r.Error()")
519		fmt.Fprintln(g.out, "}")
520	}
521
522	fmt.Fprintln(g.out, "// UnmarshalEasyJSON supports easyjson.Unmarshaler interface")
523	fmt.Fprintln(g.out, "func (v *"+typ+") UnmarshalEasyJSON(l *jlexer.Lexer) {")
524	fmt.Fprintln(g.out, "  "+fname+"(l, v)")
525	fmt.Fprintln(g.out, "}")
526
527	return nil
528}
529