1package gen
2
3import (
4	"fmt"
5	"strings"
6)
7
// identPrefix is the prefix of every identifier returned by randIdent.
var identPrefix = "za"

// identNext counts the identifiers handed out since the last resetIdent.
var identNext int
12
13func resetIdent(prefix string) {
14	identPrefix = prefix
15	identNext = 0
16}
17
18// generate a random identifier name
19func randIdent() string {
20	identNext++
21	return fmt.Sprintf("%s%04d", identPrefix, identNext)
22}
23
24// This code defines the type declaration tree.
25//
26// Consider the following:
27//
28// type Marshaler struct {
29// 	  Thing1 *float64 `msg:"thing1"`
30// 	  Body   []byte   `msg:"body"`
31// }
32//
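// A parser using this generator as a backend
// should parse the above into something like the following.
// The listing is schematic: the variable names shown as "name"
// are really stored in the embedded common struct and assigned
// by SetVarname, and a named struct's type name is carried as
// its alias rather than in a Name field: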
35//
36// var val Elem = &Ptr{
37// 	name: "z",
38// 	Value: &Struct{
39// 		Name: "Marshaler",
40// 		Fields: []StructField{
41// 			{
42// 				FieldTag: "thing1",
43// 				FieldElem: &Ptr{
44// 					name: "z.Thing1",
45// 					Value: &BaseElem{
46// 						name:    "*z.Thing1",
47// 						Value:   Float64,
48//						Convert: false,
49// 					},
50// 				},
51// 			},
52// 			{
53// 				FieldTag: "body",
54// 				FieldElem: &BaseElem{
55// 					name:    "z.Body",
56// 					Value:   Bytes,
57// 					Convert: false,
58// 				},
59// 			},
60// 		},
61// 	},
62// }
63
// Primitive identifies one of the base types, plus the
// Invalid (zero) and IDENT markers.
type Primitive uint8

// Primitive values. Apart from Invalid and IDENT, this is
// effectively the list of currently available
// ReadXxxx / WriteXxxx methods.
const (
	Invalid Primitive = iota
	Bytes
	String
	Float32
	Float64
	Complex64
	Complex128
	Uint
	Uint8
	Uint16
	Uint32
	Uint64
	Byte
	Int
	Int8
	Int16
	Int32
	Int64
	Bool
	Intf // interface{}
	Time // time.Time
	Ext  // extension

	IDENT // IDENT means an unrecognized identifier
)
97
98// all of the recognized identities
99// that map to primitive types
100var primitives = map[string]Primitive{
101	"[]byte":         Bytes,
102	"string":         String,
103	"float32":        Float32,
104	"float64":        Float64,
105	"complex64":      Complex64,
106	"complex128":     Complex128,
107	"uint":           Uint,
108	"uint8":          Uint8,
109	"uint16":         Uint16,
110	"uint32":         Uint32,
111	"uint64":         Uint64,
112	"byte":           Byte,
113	"rune":           Int32,
114	"int":            Int,
115	"int8":           Int8,
116	"int16":          Int16,
117	"int32":          Int32,
118	"int64":          Int64,
119	"bool":           Bool,
120	"interface{}":    Intf,
121	"time.Time":      Time,
122	"msgp.Extension": Ext,
123}
124
// builtins is the set of types provided by the msgp library itself
// that already satisfy all of the generated interfaces.
var builtins = map[string]struct{}{
	"msgp.Number": {},
	"msgp.Raw":    {},
}
132
// common holds the state shared by every Elem implementation.
type common struct {
	vname string // variable name, set by SetVarname
	alias string // type-name override, set by Alias
}

// SetVarname records s as the variable name.
func (c *common) SetVarname(s string) { c.vname = s }

// Varname returns the recorded variable name.
func (c *common) Varname() string { return c.vname }

// Alias records typ as the type name.
func (c *common) Alias(typ string) { c.alias = typ }

// hidden is the unexported marker method required by Elem.
func (c *common) hidden() {}
140
141func IsPrintable(e Elem) bool {
142	if be, ok := e.(*BaseElem); ok && !be.Printable() {
143		return false
144	}
145	return true
146}
147
// Elem is a node in the type declaration tree: a Go type that can be
// serialized into MessagePack. It is implemented by *Ptr, *Struct,
// *Array, *Slice, *Map, and *BaseElem.
type Elem interface {
	// SetVarname assigns this node's variable name and, recursively,
	// the names of its children. It is normally called only on the
	// root of a tree.
	SetVarname(s string)

	// Varname returns the node's variable name.
	Varname() string

	// TypeName returns the alias, if one has been set, or else the
	// canonical Go type name of the node
	// (e.g. "string", "int", "map[string]float64").
	TypeName() string

	// Alias overrides the node's type name.
	Alias(typ string)

	// Copy returns a deep copy of the node.
	Copy() Elem

	// Complexity returns a measure (>= 1) of how complex the node is.
	Complexity() int

	// ZeroExpr returns an expression for the node's zero/empty value,
	// suitable for assignment, or "" when that is unsupported.
	ZeroExpr() string

	// IfZeroExpr returns an expression, usable as an if-statement
	// condition, that tests whether the node is zero/empty. It may
	// start with a simple statement followed by a semicolon. It
	// returns "" when that is unsupported.
	IfZeroExpr() string

	// hidden is an unexported marker method.
	hidden()
}
195
196// Ident returns the *BaseElem that corresponds
197// to the provided identity.
198func Ident(id string) *BaseElem {
199	p, ok := primitives[id]
200	if ok {
201		return &BaseElem{Value: p}
202	}
203	be := &BaseElem{Value: IDENT}
204	be.Alias(id)
205	return be
206}
207
208type Array struct {
209	common
210	Index string // index variable name
211	Size  string // array size
212	Els   Elem   // child
213}
214
215func (a *Array) SetVarname(s string) {
216	a.common.SetVarname(s)
217ridx:
218	a.Index = randIdent()
219
220	// try to avoid using the same
221	// index as a parent slice
222	if strings.Contains(a.Varname(), a.Index) {
223		goto ridx
224	}
225
226	a.Els.SetVarname(fmt.Sprintf("%s[%s]", a.Varname(), a.Index))
227}
228
229func (a *Array) TypeName() string {
230	if a.common.alias != "" {
231		return a.common.alias
232	}
233	a.common.Alias(fmt.Sprintf("[%s]%s", a.Size, a.Els.TypeName()))
234	return a.common.alias
235}
236
237func (a *Array) Copy() Elem {
238	b := *a
239	b.Els = a.Els.Copy()
240	return &b
241}
242
243func (a *Array) Complexity() int { return 1 + a.Els.Complexity() }
244
245// ZeroExpr returns the zero/empty expression or empty string if not supported.  Unsupported for this case.
246func (a *Array) ZeroExpr() string { return "" }
247
248// IfZeroExpr unsupported
249func (a *Array) IfZeroExpr() string { return "" }
250
251// Map is a map[string]Elem
252type Map struct {
253	common
254	Keyidx string // key variable name
255	Validx string // value variable name
256	Value  Elem   // value element
257}
258
259func (m *Map) SetVarname(s string) {
260	m.common.SetVarname(s)
261ridx:
262	m.Keyidx = randIdent()
263	m.Validx = randIdent()
264
265	// just in case
266	if m.Keyidx == m.Validx {
267		goto ridx
268	}
269
270	m.Value.SetVarname(m.Validx)
271}
272
273func (m *Map) TypeName() string {
274	if m.common.alias != "" {
275		return m.common.alias
276	}
277	m.common.Alias("map[string]" + m.Value.TypeName())
278	return m.common.alias
279}
280
281func (m *Map) Copy() Elem {
282	g := *m
283	g.Value = m.Value.Copy()
284	return &g
285}
286
287func (m *Map) Complexity() int { return 2 + m.Value.Complexity() }
288
289// ZeroExpr returns the zero/empty expression or empty string if not supported.  Always "nil" for this case.
290func (m *Map) ZeroExpr() string { return "nil" }
291
292// IfZeroExpr returns the expression to compare to zero/empty.
293func (m *Map) IfZeroExpr() string { return m.Varname() + " == nil" }
294
295type Slice struct {
296	common
297	Index string
298	Els   Elem // The type of each element
299}
300
301func (s *Slice) SetVarname(a string) {
302	s.common.SetVarname(a)
303	s.Index = randIdent()
304	varName := s.Varname()
305	if varName[0] == '*' {
306		// Pointer-to-slice requires parenthesis for slicing.
307		varName = "(" + varName + ")"
308	}
309	s.Els.SetVarname(fmt.Sprintf("%s[%s]", varName, s.Index))
310}
311
312func (s *Slice) TypeName() string {
313	if s.common.alias != "" {
314		return s.common.alias
315	}
316	s.common.Alias("[]" + s.Els.TypeName())
317	return s.common.alias
318}
319
320func (s *Slice) Copy() Elem {
321	z := *s
322	z.Els = s.Els.Copy()
323	return &z
324}
325
326func (s *Slice) Complexity() int {
327	return 1 + s.Els.Complexity()
328}
329
330// ZeroExpr returns the zero/empty expression or empty string if not supported.  Always "nil" for this case.
331func (s *Slice) ZeroExpr() string { return "nil" }
332
333// IfZeroExpr returns the expression to compare to zero/empty.
334func (s *Slice) IfZeroExpr() string { return s.Varname() + " == nil" }
335
336type Ptr struct {
337	common
338	Value Elem
339}
340
341func (s *Ptr) SetVarname(a string) {
342	s.common.SetVarname(a)
343
344	// struct fields are dereferenced
345	// automatically...
346	switch x := s.Value.(type) {
347	case *Struct:
348		// struct fields are automatically dereferenced
349		x.SetVarname(a)
350		return
351
352	case *BaseElem:
353		// identities have pointer receivers
354		if x.Value == IDENT {
355			x.SetVarname(a)
356		} else {
357			x.SetVarname("*" + a)
358		}
359		return
360
361	default:
362		s.Value.SetVarname("*" + a)
363		return
364	}
365}
366
367func (s *Ptr) TypeName() string {
368	if s.common.alias != "" {
369		return s.common.alias
370	}
371	s.common.Alias("*" + s.Value.TypeName())
372	return s.common.alias
373}
374
375func (s *Ptr) Copy() Elem {
376	v := *s
377	v.Value = s.Value.Copy()
378	return &v
379}
380
381func (s *Ptr) Complexity() int { return 1 + s.Value.Complexity() }
382
383func (s *Ptr) Needsinit() bool {
384	if be, ok := s.Value.(*BaseElem); ok && be.needsref {
385		return false
386	}
387	return true
388}
389
390// ZeroExpr returns the zero/empty expression or empty string if not supported.  Always "nil" for this case.
391func (s *Ptr) ZeroExpr() string { return "nil" }
392
393// IfZeroExpr returns the expression to compare to zero/empty.
394func (s *Ptr) IfZeroExpr() string { return s.Varname() + " == nil" }
395
396type Struct struct {
397	common
398	Fields  []StructField // field list
399	AsTuple bool          // write as an array instead of a map
400}
401
402func (s *Struct) TypeName() string {
403	if s.common.alias != "" {
404		return s.common.alias
405	}
406	str := "struct{\n"
407	for i := range s.Fields {
408		str += s.Fields[i].FieldName +
409			" " + s.Fields[i].FieldElem.TypeName() +
410			" " + s.Fields[i].RawTag + ";\n"
411	}
412	str += "}"
413	s.common.Alias(str)
414	return s.common.alias
415}
416
417func (s *Struct) SetVarname(a string) {
418	s.common.SetVarname(a)
419	writeStructFields(s.Fields, a)
420}
421
422func (s *Struct) Copy() Elem {
423	g := *s
424	g.Fields = make([]StructField, len(s.Fields))
425	copy(g.Fields, s.Fields)
426	for i := range s.Fields {
427		g.Fields[i].FieldElem = s.Fields[i].FieldElem.Copy()
428	}
429	return &g
430}
431
432func (s *Struct) Complexity() int {
433	c := 1
434	for i := range s.Fields {
435		c += s.Fields[i].FieldElem.Complexity()
436	}
437	return c
438}
439
440// ZeroExpr returns the zero/empty expression or empty string if not supported.
441func (s *Struct) ZeroExpr() string {
442	if s.alias == "" {
443		return "" // structs with no names not supported (for now)
444	}
445	return "(" + s.TypeName() + "{})"
446}
447
448// IfZeroExpr returns the expression to compare to zero/empty.
449func (s *Struct) IfZeroExpr() string {
450	if s.alias == "" {
451		return "" // structs with no names not supported (for now)
452	}
453	return s.Varname() + " == " + s.ZeroExpr()
454}
455
456// AnyHasTagPart returns true if HasTagPart(p) is true for any field.
457func (s *Struct) AnyHasTagPart(pname string) bool {
458	for _, sf := range s.Fields {
459		if sf.HasTagPart(pname) {
460			return true
461		}
462	}
463	return false
464}
465
466type StructField struct {
467	FieldTag      string   // the string inside the `msg:""` tag up to the first comma
468	FieldTagParts []string // the string inside the `msg:""` tag split by commas
469	RawTag        string   // the full struct tag
470	FieldName     string   // the name of the struct field
471	FieldElem     Elem     // the field type
472}
473
474// HasTagPart returns true if the specified tag part (option) is present.
475func (sf *StructField) HasTagPart(pname string) bool {
476	if len(sf.FieldTagParts) < 2 {
477		return false
478	}
479	for _, p := range sf.FieldTagParts[1:] {
480		if p == pname {
481			return true
482		}
483	}
484	return false
485}
486
// ShimMode selects the method a BaseElem uses for its shim.
type ShimMode int

// The ShimMode values: Cast is 0 (the zero value) and Convert is 1.
const (
	Cast ShimMode = iota
	Convert
)
493
494// BaseElem is an element that
495// can be represented by a primitive
496// MessagePack type.
497type BaseElem struct {
498	common
499	ShimMode     ShimMode  // Method used to shim
500	ShimToBase   string    // shim to base type, or empty
501	ShimFromBase string    // shim from base type, or empty
502	Value        Primitive // Type of element
503	Convert      bool      // should we do an explicit conversion?
504	mustinline   bool      // must inline; not printable
505	needsref     bool      // needs reference for shim
506}
507
508func (s *BaseElem) Printable() bool { return !s.mustinline }
509
510func (s *BaseElem) Alias(typ string) {
511	s.common.Alias(typ)
512	if s.Value != IDENT {
513		s.Convert = true
514	}
515	if strings.Contains(typ, ".") {
516		s.mustinline = true
517	}
518}
519
520func (s *BaseElem) SetVarname(a string) {
521	// extensions whose parents
522	// are not pointers need to
523	// be explicitly referenced
524	if s.Value == Ext || s.needsref {
525		if strings.HasPrefix(a, "*") {
526			s.common.SetVarname(a[1:])
527			return
528		}
529		s.common.SetVarname("&" + a)
530		return
531	}
532
533	s.common.SetVarname(a)
534}
535
536// TypeName returns the syntactically correct Go
537// type name for the base element.
538func (s *BaseElem) TypeName() string {
539	if s.common.alias != "" {
540		return s.common.alias
541	}
542	s.common.Alias(s.BaseType())
543	return s.common.alias
544}
545
546// ToBase, used if Convert==true, is used as tmp = {{ToBase}}({{Varname}})
547func (s *BaseElem) ToBase() string {
548	if s.ShimToBase != "" {
549		return s.ShimToBase
550	}
551	return s.BaseType()
552}
553
554// FromBase, used if Convert==true, is used as {{Varname}} = {{FromBase}}(tmp)
555func (s *BaseElem) FromBase() string {
556	if s.ShimFromBase != "" {
557		return s.ShimFromBase
558	}
559	return s.TypeName()
560}
561
562// BaseName returns the string form of the
563// base type (e.g. Float64, Ident, etc)
564func (s *BaseElem) BaseName() string {
565	// time is a special case;
566	// we strip the package prefix
567	if s.Value == Time {
568		return "Time"
569	}
570	return s.Value.String()
571}
572
573func (s *BaseElem) BaseType() string {
574	switch s.Value {
575	case IDENT:
576		return s.TypeName()
577
578	// exceptions to the naming/capitalization
579	// rule:
580	case Intf:
581		return "interface{}"
582	case Bytes:
583		return "[]byte"
584	case Time:
585		return "time.Time"
586	case Ext:
587		return "msgp.Extension"
588
589	// everything else is base.String() with
590	// the first letter as lowercase
591	default:
592		return strings.ToLower(s.BaseName())
593	}
594}
595
596func (s *BaseElem) Needsref(b bool) {
597	s.needsref = b
598}
599
600func (s *BaseElem) Copy() Elem {
601	g := *s
602	return &g
603}
604
605func (s *BaseElem) Complexity() int {
606	if s.Convert && !s.mustinline {
607		return 2
608	}
609	// we need to return 1 if !printable(),
610	// in order to make sure that stuff gets
611	// inlined appropriately
612	return 1
613}
614
615// Resolved returns whether or not
616// the type of the element is
617// a primitive or a builtin provided
618// by the package.
619func (s *BaseElem) Resolved() bool {
620	if s.Value == IDENT {
621		_, ok := builtins[s.TypeName()]
622		return ok
623	}
624	return true
625}
626
627// ZeroExpr returns the zero/empty expression or empty string if not supported.
628func (s *BaseElem) ZeroExpr() string {
629
630	switch s.Value {
631	case Bytes:
632		return "nil"
633	case String:
634		return "\"\""
635	case Complex64, Complex128:
636		return "complex(0,0)"
637	case Float32,
638		Float64,
639		Uint,
640		Uint8,
641		Uint16,
642		Uint32,
643		Uint64,
644		Byte,
645		Int,
646		Int8,
647		Int16,
648		Int32,
649		Int64:
650		return "0"
651	case Bool:
652		return "false"
653
654	case Time:
655		return "(time.Time{})"
656
657	}
658
659	return ""
660}
661
662// IfZeroExpr returns the expression to compare to zero/empty.
663func (s *BaseElem) IfZeroExpr() string {
664	z := s.ZeroExpr()
665	if z == "" {
666		return ""
667	}
668	return s.Varname() + " == " + z
669}
670
671func (k Primitive) String() string {
672	switch k {
673	case String:
674		return "String"
675	case Bytes:
676		return "Bytes"
677	case Float32:
678		return "Float32"
679	case Float64:
680		return "Float64"
681	case Complex64:
682		return "Complex64"
683	case Complex128:
684		return "Complex128"
685	case Uint:
686		return "Uint"
687	case Uint8:
688		return "Uint8"
689	case Uint16:
690		return "Uint16"
691	case Uint32:
692		return "Uint32"
693	case Uint64:
694		return "Uint64"
695	case Byte:
696		return "Byte"
697	case Int:
698		return "Int"
699	case Int8:
700		return "Int8"
701	case Int16:
702		return "Int16"
703	case Int32:
704		return "Int32"
705	case Int64:
706		return "Int64"
707	case Bool:
708		return "Bool"
709	case Intf:
710		return "Intf"
711	case Time:
712		return "time.Time"
713	case Ext:
714		return "Extension"
715	case IDENT:
716		return "Ident"
717	default:
718		return "INVALID"
719	}
720}
721
722// writeStructFields is a trampoline for writeBase for
723// all of the fields in a struct
724func writeStructFields(s []StructField, name string) {
725	for i := range s {
726		s[i].FieldElem.SetVarname(fmt.Sprintf("%s.%s", name, s[i].FieldName))
727	}
728}
729
// coerceArraySize wraps the array-size expression asz in a uint32
// conversion.
//
// Msgpack array headers are 32-bit unsigned, and this library's
// ArrayHeader uses uint32 to reflect that. Go array lengths, however, may
// be declared as constants of any integer width, and comparing one of
// those directly against a header's uint32 does not compile. Converting
// the length first makes the comparison well-typed.
func coerceArraySize(asz string) string {
	return fmt.Sprint("uint32(", asz, ")")
}
740