1package gen
2
3import (
4	"encoding"
5	"encoding/json"
6	"fmt"
7	"reflect"
8	"strings"
9	"unicode"
10
11	"github.com/mailru/easyjson"
12)
13
const (
	// minSliceBytes is the approximate number of bytes of backing storage that
	// generated decoders reserve up front for a non-byte slice. Starting near
	// this size avoids several small grow-and-copy steps (and the garbage they
	// leave behind) while elements are appended.
	minSliceBytes = 64
)
16
17func (g *Generator) getDecoderName(t reflect.Type) string {
18	return g.functionName("decode", t)
19}
20
// primitiveDecoders maps a reflect.Kind to the jlexer call that reads a value
// of that kind from its plain JSON representation. The generated code converts
// the result to the concrete (possibly named) Go type.
var primitiveDecoders = map[reflect.Kind]string{
	// Text and booleans.
	reflect.String: "in.String()",
	reflect.Bool:   "in.Bool()",

	// Signed integers.
	reflect.Int:   "in.Int()",
	reflect.Int8:  "in.Int8()",
	reflect.Int16: "in.Int16()",
	reflect.Int32: "in.Int32()",
	reflect.Int64: "in.Int64()",

	// Unsigned integers.
	reflect.Uint:   "in.Uint()",
	reflect.Uint8:  "in.Uint8()",
	reflect.Uint16: "in.Uint16()",
	reflect.Uint32: "in.Uint32()",
	reflect.Uint64: "in.Uint64()",

	// Floating point.
	reflect.Float32: "in.Float32()",
	reflect.Float64: "in.Float64()",
}
37
// primitiveStringDecoders maps a reflect.Kind to the jlexer call that reads a
// value of that kind from a quoted JSON string. These are used for fields
// decoded with tags.asString and for map keys, which JSON always encodes as
// strings.
var primitiveStringDecoders = map[reflect.Kind]string{
	reflect.String: "in.String()",

	// Signed integers.
	reflect.Int:   "in.IntStr()",
	reflect.Int8:  "in.Int8Str()",
	reflect.Int16: "in.Int16Str()",
	reflect.Int32: "in.Int32Str()",
	reflect.Int64: "in.Int64Str()",

	// Unsigned integers, including uintptr.
	reflect.Uint:    "in.UintStr()",
	reflect.Uint8:   "in.Uint8Str()",
	reflect.Uint16:  "in.Uint16Str()",
	reflect.Uint32:  "in.Uint32Str()",
	reflect.Uint64:  "in.Uint64Str()",
	reflect.Uintptr: "in.UintptrStr()",

	// Floating point.
	reflect.Float32: "in.Float32Str()",
	reflect.Float64: "in.Float64Str()",
}
54
// customDecoders maps the fully qualified Go type name (reflect.Type.String())
// of specially handled types to the jlexer call that decodes them.
var customDecoders = map[string]string{"json.Number": "in.JsonNumber()"}
58
59// genTypeDecoder generates decoding code for the type t, but uses unmarshaler interface if implemented by t.
60func (g *Generator) genTypeDecoder(t reflect.Type, out string, tags fieldTags, indent int) error {
61	ws := strings.Repeat("  ", indent)
62
63	unmarshalerIface := reflect.TypeOf((*easyjson.Unmarshaler)(nil)).Elem()
64	if reflect.PtrTo(t).Implements(unmarshalerIface) {
65		fmt.Fprintln(g.out, ws+"("+out+").UnmarshalEasyJSON(in)")
66		return nil
67	}
68
69	unmarshalerIface = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
70	if reflect.PtrTo(t).Implements(unmarshalerIface) {
71		fmt.Fprintln(g.out, ws+"if data := in.Raw(); in.Ok() {")
72		fmt.Fprintln(g.out, ws+"  in.AddError( ("+out+").UnmarshalJSON(data) )")
73		fmt.Fprintln(g.out, ws+"}")
74		return nil
75	}
76
77	unmarshalerIface = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
78	if reflect.PtrTo(t).Implements(unmarshalerIface) {
79		fmt.Fprintln(g.out, ws+"if data := in.UnsafeBytes(); in.Ok() {")
80		fmt.Fprintln(g.out, ws+"  in.AddError( ("+out+").UnmarshalText(data) )")
81		fmt.Fprintln(g.out, ws+"}")
82		return nil
83	}
84
85	err := g.genTypeDecoderNoCheck(t, out, tags, indent)
86	return err
87}
88
89// returns true of the type t implements one of the custom unmarshaler interfaces
90func hasCustomUnmarshaler(t reflect.Type) bool {
91	t = reflect.PtrTo(t)
92	return t.Implements(reflect.TypeOf((*easyjson.Unmarshaler)(nil)).Elem()) ||
93		t.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) ||
94		t.Implements(reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem())
95}
96
97func hasUnknownsUnmarshaler(t reflect.Type) bool {
98	t = reflect.PtrTo(t)
99	return t.Implements(reflect.TypeOf((*easyjson.UnknownsUnmarshaler)(nil)).Elem())
100}
101
102func hasUnknownsMarshaler(t reflect.Type) bool {
103	t = reflect.PtrTo(t)
104	return t.Implements(reflect.TypeOf((*easyjson.UnknownsMarshaler)(nil)).Elem())
105}
106
107// genTypeDecoderNoCheck generates decoding code for the type t.
108func (g *Generator) genTypeDecoderNoCheck(t reflect.Type, out string, tags fieldTags, indent int) error {
109	ws := strings.Repeat("  ", indent)
110	// Check whether type is primitive, needs to be done after interface check.
111	if dec := customDecoders[t.String()]; dec != "" {
112		fmt.Fprintln(g.out, ws+out+" = "+dec)
113		return nil
114	} else if dec := primitiveStringDecoders[t.Kind()]; dec != "" && tags.asString {
115		fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")")
116		return nil
117	} else if dec := primitiveDecoders[t.Kind()]; dec != "" {
118		fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")")
119		return nil
120	}
121
122	switch t.Kind() {
123	case reflect.Slice:
124		tmpVar := g.uniqueVarName()
125		elem := t.Elem()
126
127		if elem.Kind() == reflect.Uint8 && elem.Name() == "uint8" {
128			fmt.Fprintln(g.out, ws+"if in.IsNull() {")
129			fmt.Fprintln(g.out, ws+"  in.Skip()")
130			fmt.Fprintln(g.out, ws+"  "+out+" = nil")
131			fmt.Fprintln(g.out, ws+"} else {")
132			fmt.Fprintln(g.out, ws+"  "+out+" = in.Bytes()")
133			fmt.Fprintln(g.out, ws+"}")
134
135		} else {
136
137			capacity := minSliceBytes / elem.Size()
138			if capacity == 0 {
139				capacity = 1
140			}
141
142			fmt.Fprintln(g.out, ws+"if in.IsNull() {")
143			fmt.Fprintln(g.out, ws+"  in.Skip()")
144			fmt.Fprintln(g.out, ws+"  "+out+" = nil")
145			fmt.Fprintln(g.out, ws+"} else {")
146			fmt.Fprintln(g.out, ws+"  in.Delim('[')")
147			fmt.Fprintln(g.out, ws+"  if "+out+" == nil {")
148			fmt.Fprintln(g.out, ws+"    if !in.IsDelim(']') {")
149			fmt.Fprintln(g.out, ws+"      "+out+" = make("+g.getType(t)+", 0, "+fmt.Sprint(capacity)+")")
150			fmt.Fprintln(g.out, ws+"    } else {")
151			fmt.Fprintln(g.out, ws+"      "+out+" = "+g.getType(t)+"{}")
152			fmt.Fprintln(g.out, ws+"    }")
153			fmt.Fprintln(g.out, ws+"  } else { ")
154			fmt.Fprintln(g.out, ws+"    "+out+" = ("+out+")[:0]")
155			fmt.Fprintln(g.out, ws+"  }")
156			fmt.Fprintln(g.out, ws+"  for !in.IsDelim(']') {")
157			fmt.Fprintln(g.out, ws+"    var "+tmpVar+" "+g.getType(elem))
158
159			if err := g.genTypeDecoder(elem, tmpVar, tags, indent+2); err != nil {
160				return err
161			}
162
163			fmt.Fprintln(g.out, ws+"    "+out+" = append("+out+", "+tmpVar+")")
164			fmt.Fprintln(g.out, ws+"    in.WantComma()")
165			fmt.Fprintln(g.out, ws+"  }")
166			fmt.Fprintln(g.out, ws+"  in.Delim(']')")
167			fmt.Fprintln(g.out, ws+"}")
168		}
169
170	case reflect.Array:
171		iterVar := g.uniqueVarName()
172		elem := t.Elem()
173
174		if elem.Kind() == reflect.Uint8 && elem.Name() == "uint8" {
175			fmt.Fprintln(g.out, ws+"if in.IsNull() {")
176			fmt.Fprintln(g.out, ws+"  in.Skip()")
177			fmt.Fprintln(g.out, ws+"} else {")
178			fmt.Fprintln(g.out, ws+"  copy("+out+"[:], in.Bytes())")
179			fmt.Fprintln(g.out, ws+"}")
180
181		} else {
182
183			length := t.Len()
184
185			fmt.Fprintln(g.out, ws+"if in.IsNull() {")
186			fmt.Fprintln(g.out, ws+"  in.Skip()")
187			fmt.Fprintln(g.out, ws+"} else {")
188			fmt.Fprintln(g.out, ws+"  in.Delim('[')")
189			fmt.Fprintln(g.out, ws+"  "+iterVar+" := 0")
190			fmt.Fprintln(g.out, ws+"  for !in.IsDelim(']') {")
191			fmt.Fprintln(g.out, ws+"    if "+iterVar+" < "+fmt.Sprint(length)+" {")
192
193			if err := g.genTypeDecoder(elem, "("+out+")["+iterVar+"]", tags, indent+3); err != nil {
194				return err
195			}
196
197			fmt.Fprintln(g.out, ws+"      "+iterVar+"++")
198			fmt.Fprintln(g.out, ws+"    } else {")
199			fmt.Fprintln(g.out, ws+"      in.SkipRecursive()")
200			fmt.Fprintln(g.out, ws+"    }")
201			fmt.Fprintln(g.out, ws+"    in.WantComma()")
202			fmt.Fprintln(g.out, ws+"  }")
203			fmt.Fprintln(g.out, ws+"  in.Delim(']')")
204			fmt.Fprintln(g.out, ws+"}")
205		}
206
207	case reflect.Struct:
208		dec := g.getDecoderName(t)
209		g.addType(t)
210
211		if len(out) > 0 && out[0] == '*' {
212			// NOTE: In order to remove an extra reference to a pointer
213			fmt.Fprintln(g.out, ws+dec+"(in, "+out[1:]+")")
214		} else {
215			fmt.Fprintln(g.out, ws+dec+"(in, &"+out+")")
216		}
217
218	case reflect.Ptr:
219		fmt.Fprintln(g.out, ws+"if in.IsNull() {")
220		fmt.Fprintln(g.out, ws+"  in.Skip()")
221		fmt.Fprintln(g.out, ws+"  "+out+" = nil")
222		fmt.Fprintln(g.out, ws+"} else {")
223		fmt.Fprintln(g.out, ws+"  if "+out+" == nil {")
224		fmt.Fprintln(g.out, ws+"    "+out+" = new("+g.getType(t.Elem())+")")
225		fmt.Fprintln(g.out, ws+"  }")
226
227		if err := g.genTypeDecoder(t.Elem(), "*"+out, tags, indent+1); err != nil {
228			return err
229		}
230
231		fmt.Fprintln(g.out, ws+"}")
232
233	case reflect.Map:
234		key := t.Key()
235		keyDec, ok := primitiveStringDecoders[key.Kind()]
236		if !ok && !hasCustomUnmarshaler(key) {
237			return fmt.Errorf("map type %v not supported: only string and integer keys and types implementing json.Unmarshaler are allowed", key)
238		} // else assume the caller knows what they are doing and that the custom unmarshaler performs the translation from string or integer keys to the key type
239		elem := t.Elem()
240		tmpVar := g.uniqueVarName()
241		keepEmpty := tags.required || tags.noOmitEmpty || (!g.omitEmpty && !tags.omitEmpty)
242
243		fmt.Fprintln(g.out, ws+"if in.IsNull() {")
244		fmt.Fprintln(g.out, ws+"  in.Skip()")
245		fmt.Fprintln(g.out, ws+"} else {")
246		fmt.Fprintln(g.out, ws+"  in.Delim('{')")
247		if !keepEmpty {
248			fmt.Fprintln(g.out, ws+"  if !in.IsDelim('}') {")
249		}
250		fmt.Fprintln(g.out, ws+"  "+out+" = make("+g.getType(t)+")")
251		if !keepEmpty {
252			fmt.Fprintln(g.out, ws+"  } else {")
253			fmt.Fprintln(g.out, ws+"  "+out+" = nil")
254			fmt.Fprintln(g.out, ws+"  }")
255		}
256
257		fmt.Fprintln(g.out, ws+"  for !in.IsDelim('}') {")
258		// NOTE: extra check for TextUnmarshaler. It overrides default methods.
259		if reflect.PtrTo(key).Implements(reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()) {
260			fmt.Fprintln(g.out, ws+"    var key "+g.getType(key))
261			fmt.Fprintln(g.out, ws+"if data := in.UnsafeBytes(); in.Ok() {")
262			fmt.Fprintln(g.out, ws+"  in.AddError(key.UnmarshalText(data) )")
263			fmt.Fprintln(g.out, ws+"}")
264		} else if keyDec != "" {
265			fmt.Fprintln(g.out, ws+"    key := "+g.getType(key)+"("+keyDec+")")
266		} else {
267			fmt.Fprintln(g.out, ws+"    var key "+g.getType(key))
268			if err := g.genTypeDecoder(key, "key", tags, indent+2); err != nil {
269				return err
270			}
271		}
272
273		fmt.Fprintln(g.out, ws+"    in.WantColon()")
274		fmt.Fprintln(g.out, ws+"    var "+tmpVar+" "+g.getType(elem))
275
276		if err := g.genTypeDecoder(elem, tmpVar, tags, indent+2); err != nil {
277			return err
278		}
279
280		fmt.Fprintln(g.out, ws+"    ("+out+")[key] = "+tmpVar)
281		fmt.Fprintln(g.out, ws+"    in.WantComma()")
282		fmt.Fprintln(g.out, ws+"  }")
283		fmt.Fprintln(g.out, ws+"  in.Delim('}')")
284		fmt.Fprintln(g.out, ws+"}")
285
286	case reflect.Interface:
287		if t.NumMethod() != 0 {
288			return fmt.Errorf("interface type %v not supported: only interface{} is allowed", t)
289		}
290		fmt.Fprintln(g.out, ws+"if m, ok := "+out+".(easyjson.Unmarshaler); ok {")
291		fmt.Fprintln(g.out, ws+"m.UnmarshalEasyJSON(in)")
292		fmt.Fprintln(g.out, ws+"} else if m, ok := "+out+".(json.Unmarshaler); ok {")
293		fmt.Fprintln(g.out, ws+"_ = m.UnmarshalJSON(in.Raw())")
294		fmt.Fprintln(g.out, ws+"} else {")
295		fmt.Fprintln(g.out, ws+"  "+out+" = in.Interface()")
296		fmt.Fprintln(g.out, ws+"}")
297	default:
298		return fmt.Errorf("don't know how to decode %v", t)
299	}
300	return nil
301
302}
303
304func (g *Generator) genStructFieldDecoder(t reflect.Type, f reflect.StructField) error {
305	jsonName := g.fieldNamer.GetJSONFieldName(t, f)
306	tags := parseFieldTags(f)
307
308	if tags.omit {
309		return nil
310	}
311
312	fmt.Fprintf(g.out, "    case %q:\n", jsonName)
313	if err := g.genTypeDecoder(f.Type, "out."+f.Name, tags, 3); err != nil {
314		return err
315	}
316
317	if tags.required {
318		fmt.Fprintf(g.out, "%sSet = true\n", f.Name)
319	}
320
321	return nil
322}
323
324func (g *Generator) genRequiredFieldSet(t reflect.Type, f reflect.StructField) {
325	tags := parseFieldTags(f)
326
327	if !tags.required {
328		return
329	}
330
331	fmt.Fprintf(g.out, "var %sSet bool\n", f.Name)
332}
333
334func (g *Generator) genRequiredFieldCheck(t reflect.Type, f reflect.StructField) {
335	jsonName := g.fieldNamer.GetJSONFieldName(t, f)
336	tags := parseFieldTags(f)
337
338	if !tags.required {
339		return
340	}
341
342	g.imports["fmt"] = "fmt"
343
344	fmt.Fprintf(g.out, "if !%sSet {\n", f.Name)
345	fmt.Fprintf(g.out, "    in.AddError(fmt.Errorf(\"key '%s' is required\"))\n", jsonName)
346	fmt.Fprintf(g.out, "}\n")
347}
348
// mergeStructFields returns every field of fields2, followed by the fields of
// fields1 whose names do not appear in fields2. Fields in fields2 thus shadow
// same-named fields in fields1. The result is nil when nothing is kept.
func mergeStructFields(fields1, fields2 []reflect.StructField) []reflect.StructField {
	shadowed := make(map[string]bool, len(fields2))
	for _, f := range fields2 {
		shadowed[f.Name] = true
	}

	var merged []reflect.StructField
	merged = append(merged, fields2...)
	for _, f := range fields1 {
		if shadowed[f.Name] {
			continue
		}
		merged = append(merged, f)
	}
	return merged
}
363
364func getStructFields(t reflect.Type) ([]reflect.StructField, error) {
365	if t.Kind() != reflect.Struct {
366		return nil, fmt.Errorf("got %v; expected a struct", t)
367	}
368
369	var efields []reflect.StructField
370	for i := 0; i < t.NumField(); i++ {
371		f := t.Field(i)
372		tags := parseFieldTags(f)
373		if !f.Anonymous || tags.name != "" {
374			continue
375		}
376
377		t1 := f.Type
378		if t1.Kind() == reflect.Ptr {
379			t1 = t1.Elem()
380		}
381
382		fs, err := getStructFields(t1)
383		if err != nil {
384			return nil, fmt.Errorf("error processing embedded field: %v", err)
385		}
386		efields = mergeStructFields(efields, fs)
387	}
388
389	var fields []reflect.StructField
390	for i := 0; i < t.NumField(); i++ {
391		f := t.Field(i)
392		tags := parseFieldTags(f)
393		if f.Anonymous && tags.name == "" {
394			continue
395		}
396
397		c := []rune(f.Name)[0]
398		if unicode.IsUpper(c) {
399			fields = append(fields, f)
400		}
401	}
402	return mergeStructFields(efields, fields), nil
403}
404
405func (g *Generator) genDecoder(t reflect.Type) error {
406	switch t.Kind() {
407	case reflect.Slice, reflect.Array, reflect.Map:
408		return g.genSliceArrayDecoder(t)
409	default:
410		return g.genStructDecoder(t)
411	}
412}
413
414func (g *Generator) genSliceArrayDecoder(t reflect.Type) error {
415	switch t.Kind() {
416	case reflect.Slice, reflect.Array, reflect.Map:
417	default:
418		return fmt.Errorf("cannot generate encoder/decoder for %v, not a slice/array/map type", t)
419	}
420
421	fname := g.getDecoderName(t)
422	typ := g.getType(t)
423
424	fmt.Fprintln(g.out, "func "+fname+"(in *jlexer.Lexer, out *"+typ+") {")
425	fmt.Fprintln(g.out, " isTopLevel := in.IsStart()")
426	err := g.genTypeDecoderNoCheck(t, "*out", fieldTags{}, 1)
427	if err != nil {
428		return err
429	}
430	fmt.Fprintln(g.out, "  if isTopLevel {")
431	fmt.Fprintln(g.out, "    in.Consumed()")
432	fmt.Fprintln(g.out, "  }")
433	fmt.Fprintln(g.out, "}")
434
435	return nil
436}
437
438func (g *Generator) genStructDecoder(t reflect.Type) error {
439	if t.Kind() != reflect.Struct {
440		return fmt.Errorf("cannot generate encoder/decoder for %v, not a struct type", t)
441	}
442
443	fname := g.getDecoderName(t)
444	typ := g.getType(t)
445
446	fmt.Fprintln(g.out, "func "+fname+"(in *jlexer.Lexer, out *"+typ+") {")
447	fmt.Fprintln(g.out, "  isTopLevel := in.IsStart()")
448	fmt.Fprintln(g.out, "  if in.IsNull() {")
449	fmt.Fprintln(g.out, "    if isTopLevel {")
450	fmt.Fprintln(g.out, "      in.Consumed()")
451	fmt.Fprintln(g.out, "    }")
452	fmt.Fprintln(g.out, "    in.Skip()")
453	fmt.Fprintln(g.out, "    return")
454	fmt.Fprintln(g.out, "  }")
455
456	// Init embedded pointer fields.
457	for i := 0; i < t.NumField(); i++ {
458		f := t.Field(i)
459		if !f.Anonymous || f.Type.Kind() != reflect.Ptr {
460			continue
461		}
462		fmt.Fprintln(g.out, "  out."+f.Name+" = new("+g.getType(f.Type.Elem())+")")
463	}
464
465	fs, err := getStructFields(t)
466	if err != nil {
467		return fmt.Errorf("cannot generate decoder for %v: %v", t, err)
468	}
469
470	for _, f := range fs {
471		g.genRequiredFieldSet(t, f)
472	}
473
474	fmt.Fprintln(g.out, "  in.Delim('{')")
475	fmt.Fprintln(g.out, "  for !in.IsDelim('}') {")
476	fmt.Fprintln(g.out, "    key := in.UnsafeString()")
477	fmt.Fprintln(g.out, "    in.WantColon()")
478	fmt.Fprintln(g.out, "    if in.IsNull() {")
479	fmt.Fprintln(g.out, "       in.Skip()")
480	fmt.Fprintln(g.out, "       in.WantComma()")
481	fmt.Fprintln(g.out, "       continue")
482	fmt.Fprintln(g.out, "    }")
483
484	fmt.Fprintln(g.out, "    switch key {")
485	for _, f := range fs {
486		if err := g.genStructFieldDecoder(t, f); err != nil {
487			return err
488		}
489	}
490
491	fmt.Fprintln(g.out, "    default:")
492	if g.disallowUnknownFields {
493		fmt.Fprintln(g.out, `      in.AddError(&jlexer.LexerError{
494          Offset: in.GetPos(),
495          Reason: "unknown field",
496          Data: key,
497      })`)
498	} else if hasUnknownsUnmarshaler(t) {
499		fmt.Fprintln(g.out, "      out.UnmarshalUnknown(in, key)")
500	} else {
501		fmt.Fprintln(g.out, "      in.SkipRecursive()")
502	}
503	fmt.Fprintln(g.out, "    }")
504	fmt.Fprintln(g.out, "    in.WantComma()")
505	fmt.Fprintln(g.out, "  }")
506	fmt.Fprintln(g.out, "  in.Delim('}')")
507	fmt.Fprintln(g.out, "  if isTopLevel {")
508	fmt.Fprintln(g.out, "    in.Consumed()")
509	fmt.Fprintln(g.out, "  }")
510
511	for _, f := range fs {
512		g.genRequiredFieldCheck(t, f)
513	}
514
515	fmt.Fprintln(g.out, "}")
516
517	return nil
518}
519
520func (g *Generator) genStructUnmarshaler(t reflect.Type) error {
521	switch t.Kind() {
522	case reflect.Slice, reflect.Array, reflect.Map, reflect.Struct:
523	default:
524		return fmt.Errorf("cannot generate encoder/decoder for %v, not a struct/slice/array/map type", t)
525	}
526
527	fname := g.getDecoderName(t)
528	typ := g.getType(t)
529
530	if !g.noStdMarshalers {
531		fmt.Fprintln(g.out, "// UnmarshalJSON supports json.Unmarshaler interface")
532		fmt.Fprintln(g.out, "func (v *"+typ+") UnmarshalJSON(data []byte) error {")
533		fmt.Fprintln(g.out, "  r := jlexer.Lexer{Data: data}")
534		fmt.Fprintln(g.out, "  "+fname+"(&r, v)")
535		fmt.Fprintln(g.out, "  return r.Error()")
536		fmt.Fprintln(g.out, "}")
537	}
538
539	fmt.Fprintln(g.out, "// UnmarshalEasyJSON supports easyjson.Unmarshaler interface")
540	fmt.Fprintln(g.out, "func (v *"+typ+") UnmarshalEasyJSON(l *jlexer.Lexer) {")
541	fmt.Fprintln(g.out, "  "+fname+"(l, v)")
542	fmt.Fprintln(g.out, "}")
543
544	return nil
545}
546