1/*
2Copyright 2016 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package generators
18
19import (
20	"bytes"
21	"fmt"
22	"io"
23	"path/filepath"
24	"reflect"
25	"strings"
26
27	"k8s.io/gengo/args"
28	"k8s.io/gengo/generator"
29	"k8s.io/gengo/namer"
30	"k8s.io/gengo/types"
31
32	"k8s.io/klog"
33)
34
// CustomArgs is used by the go2idl framework to pass args specific to this
// generator.
type CustomArgs struct {
	// ExtraPeerDirs lists additional packages that are always scanned for
	// hand-written defaulting functions, as a last-ditch source of defaulters.
	// A path containing "/vendor/" is reduced to the import path after it.
	ExtraPeerDirs []string
}
40
// Comment tags that carry parameters for defaulter generation.
const (
	// tagName is read from package comments, type comments and the comments of
	// hand-written base defaulters.
	tagName = "k8s:defaulter-gen"

	// intputTagName, read from package comments, names a directory (relative to
	// the input package) holding the types to default when they live elsewhere.
	intputTagName = "k8s:defaulter-gen-input"
)
44
45func extractTag(comments []string) []string {
46	return types.ExtractCommentTags("+", comments)[tagName]
47}
48
49func extractInputTag(comments []string) []string {
50	return types.ExtractCommentTags("+", comments)[intputTagName]
51}
52
53func checkTag(comments []string, require ...string) bool {
54	values := types.ExtractCommentTags("+", comments)[tagName]
55	if len(require) == 0 {
56		return len(values) == 1 && values[0] == ""
57	}
58	return reflect.DeepEqual(values, require)
59}
60
61func defaultFnNamer() *namer.NameStrategy {
62	return &namer.NameStrategy{
63		Prefix: "SetDefaults_",
64		Join: func(pre string, in []string, post string) string {
65			return pre + strings.Join(in, "_") + post
66		},
67	}
68}
69
70func objectDefaultFnNamer() *namer.NameStrategy {
71	return &namer.NameStrategy{
72		Prefix: "SetObjectDefaults_",
73		Join: func(pre string, in []string, post string) string {
74			return pre + strings.Join(in, "_") + post
75		},
76	}
77}
78
79// NameSystems returns the name system used by the generators in this package.
80func NameSystems() namer.NameSystems {
81	return namer.NameSystems{
82		"public":          namer.NewPublicNamer(1),
83		"raw":             namer.NewRawNamer("", nil),
84		"defaultfn":       defaultFnNamer(),
85		"objectdefaultfn": objectDefaultFnNamer(),
86	}
87}
88
// DefaultNameSystem returns the name of the name system ("public") used to order
// the types processed by the generators in this package.
func DefaultNameSystem() string {
	const orderingSystem = "public"
	return orderingSystem
}
94
95// defaults holds the declared defaulting functions for a given type (all defaulting functions
96// are expected to be func(1))
97type defaults struct {
98	// object is the defaulter function for a top level type (typically one with TypeMeta) that
99	// invokes all child defaulters. May be nil if the object defaulter has not yet been generated.
100	object *types.Type
101	// base is a defaulter function defined for a type SetDefaults_Pod which does not invoke all
102	// child defaults - the base defaulter alone is insufficient to default a type
103	base *types.Type
104	// additional is zero or more defaulter functions of the form SetDefaults_Pod_XXXX that can be
105	// included in the Object defaulter.
106	additional []*types.Type
107}
108
109// All of the types in conversions map are of type "DeclarationOf" with
110// the underlying type being "Func".
111type defaulterFuncMap map[*types.Type]defaults
112
113// Returns all manually-defined defaulting functions in the package.
114func getManualDefaultingFunctions(context *generator.Context, pkg *types.Package, manualMap defaulterFuncMap) {
115	buffer := &bytes.Buffer{}
116	sw := generator.NewSnippetWriter(buffer, context, "$", "$")
117
118	for _, f := range pkg.Functions {
119		if f.Underlying == nil || f.Underlying.Kind != types.Func {
120			klog.Errorf("Malformed function: %#v", f)
121			continue
122		}
123		if f.Underlying.Signature == nil {
124			klog.Errorf("Function without signature: %#v", f)
125			continue
126		}
127		signature := f.Underlying.Signature
128		// Check whether the function is defaulting function.
129		// Note that all of them have signature:
130		// object: func SetObjectDefaults_inType(*inType)
131		// base: func SetDefaults_inType(*inType)
132		// additional: func SetDefaults_inType_Qualifier(*inType)
133		if signature.Receiver != nil {
134			continue
135		}
136		if len(signature.Parameters) != 1 {
137			continue
138		}
139		if len(signature.Results) != 0 {
140			continue
141		}
142		inType := signature.Parameters[0]
143		if inType.Kind != types.Pointer {
144			continue
145		}
146		// Check if this is the primary defaulter.
147		args := defaultingArgsFromType(inType.Elem)
148		sw.Do("$.inType|defaultfn$", args)
149		switch {
150		case f.Name.Name == buffer.String():
151			key := inType.Elem
152			// We might scan the same package twice, and that's OK.
153			v, ok := manualMap[key]
154			if ok && v.base != nil && v.base.Name.Package != pkg.Path {
155				panic(fmt.Sprintf("duplicate static defaulter defined: %#v", key))
156			}
157			v.base = f
158			manualMap[key] = v
159			klog.V(6).Infof("found base defaulter function for %s from %s", key.Name, f.Name)
160		// Is one of the additional defaulters - a top level defaulter on a type that is
161		// also invoked.
162		case strings.HasPrefix(f.Name.Name, buffer.String()+"_"):
163			key := inType.Elem
164			v, ok := manualMap[key]
165			if ok {
166				exists := false
167				for _, existing := range v.additional {
168					if existing.Name == f.Name {
169						exists = true
170						break
171					}
172				}
173				if exists {
174					continue
175				}
176			}
177			v.additional = append(v.additional, f)
178			manualMap[key] = v
179			klog.V(6).Infof("found additional defaulter function for %s from %s", key.Name, f.Name)
180		}
181		buffer.Reset()
182		sw.Do("$.inType|objectdefaultfn$", args)
183		if f.Name.Name == buffer.String() {
184			key := inType.Elem
185			// We might scan the same package twice, and that's OK.
186			v, ok := manualMap[key]
187			if ok && v.base != nil && v.base.Name.Package != pkg.Path {
188				panic(fmt.Sprintf("duplicate static defaulter defined: %#v", key))
189			}
190			v.object = f
191			manualMap[key] = v
192			klog.V(6).Infof("found object defaulter function for %s from %s", key.Name, f.Name)
193		}
194		buffer.Reset()
195	}
196}
197
198func Packages(context *generator.Context, arguments *args.GeneratorArgs) generator.Packages {
199	boilerplate, err := arguments.LoadGoBoilerplate()
200	if err != nil {
201		klog.Fatalf("Failed loading boilerplate: %v", err)
202	}
203
204	packages := generator.Packages{}
205	header := append([]byte(fmt.Sprintf("// +build !%s\n\n", arguments.GeneratedBuildTag)), boilerplate...)
206
207	// Accumulate pre-existing default functions.
208	// TODO: This is too ad-hoc.  We need a better way.
209	existingDefaulters := defaulterFuncMap{}
210
211	buffer := &bytes.Buffer{}
212	sw := generator.NewSnippetWriter(buffer, context, "$", "$")
213
214	// We are generating defaults only for packages that are explicitly
215	// passed as InputDir.
216	for _, i := range context.Inputs {
217		klog.V(5).Infof("considering pkg %q", i)
218		pkg := context.Universe[i]
219		if pkg == nil {
220			// If the input had no Go files, for example.
221			continue
222		}
223		// typesPkg is where the types that needs defaulter are defined.
224		// Sometimes it is different from pkg. For example, kubernetes core/v1
225		// types are defined in vendor/k8s.io/api/core/v1, while pkg is at
226		// pkg/api/v1.
227		typesPkg := pkg
228
229		// Add defaulting functions.
230		getManualDefaultingFunctions(context, pkg, existingDefaulters)
231
232		var peerPkgs []string
233		if customArgs, ok := arguments.CustomArgs.(*CustomArgs); ok {
234			for _, pkg := range customArgs.ExtraPeerDirs {
235				if i := strings.Index(pkg, "/vendor/"); i != -1 {
236					pkg = pkg[i+len("/vendor/"):]
237				}
238				peerPkgs = append(peerPkgs, pkg)
239			}
240		}
241		// Make sure our peer-packages are added and fully parsed.
242		for _, pp := range peerPkgs {
243			context.AddDir(pp)
244			getManualDefaultingFunctions(context, context.Universe[pp], existingDefaulters)
245		}
246
247		typesWith := extractTag(pkg.Comments)
248		shouldCreateObjectDefaulterFn := func(t *types.Type) bool {
249			if defaults, ok := existingDefaulters[t]; ok && defaults.object != nil {
250				// A default generator is defined
251				klog.V(5).Infof("  an object defaulter already exists as %s", defaults.base.Name)
252				return false
253			}
254			// opt-out
255			if checkTag(t.SecondClosestCommentLines, "false") {
256				return false
257			}
258			// opt-in
259			if checkTag(t.SecondClosestCommentLines, "true") {
260				return true
261			}
262			// For every k8s:defaulter-gen tag at the package level, interpret the value as a
263			// field name (like TypeMeta, ListMeta, ObjectMeta) and trigger defaulter generation
264			// for any type with any of the matching field names. Provides a more useful package
265			// level defaulting than global (because we only need defaulters on a subset of objects -
266			// usually those with TypeMeta).
267			if t.Kind == types.Struct && len(typesWith) > 0 {
268				for _, field := range t.Members {
269					for _, s := range typesWith {
270						if field.Name == s {
271							return true
272						}
273					}
274				}
275			}
276			return false
277		}
278
279		// if the types are not in the same package where the defaulter functions to be generated
280		inputTags := extractInputTag(pkg.Comments)
281		if len(inputTags) > 1 {
282			panic(fmt.Sprintf("there could only be one input tag, got %#v", inputTags))
283		}
284		if len(inputTags) == 1 {
285			var err error
286			typesPkg, err = context.AddDirectory(filepath.Join(pkg.Path, inputTags[0]))
287			if err != nil {
288				klog.Fatalf("cannot import package %s", inputTags[0])
289			}
290			// update context.Order to the latest context.Universe
291			orderer := namer.Orderer{Namer: namer.NewPublicNamer(1)}
292			context.Order = orderer.OrderUniverse(context.Universe)
293		}
294
295		newDefaulters := defaulterFuncMap{}
296		for _, t := range typesPkg.Types {
297			if !shouldCreateObjectDefaulterFn(t) {
298				continue
299			}
300			if namer.IsPrivateGoName(t.Name.Name) {
301				// We won't be able to convert to a private type.
302				klog.V(5).Infof("  found a type %v, but it is a private name", t)
303				continue
304			}
305
306			// create a synthetic type we can use during generation
307			newDefaulters[t] = defaults{}
308		}
309
310		// only generate defaulters for objects that actually have defined defaulters
311		// prevents empty defaulters from being registered
312		for {
313			promoted := 0
314			for t, d := range newDefaulters {
315				if d.object != nil {
316					continue
317				}
318				if newCallTreeForType(existingDefaulters, newDefaulters).build(t, true) != nil {
319					args := defaultingArgsFromType(t)
320					sw.Do("$.inType|objectdefaultfn$", args)
321					newDefaulters[t] = defaults{
322						object: &types.Type{
323							Name: types.Name{
324								Package: pkg.Path,
325								Name:    buffer.String(),
326							},
327							Kind: types.Func,
328						},
329					}
330					buffer.Reset()
331					promoted++
332				}
333			}
334			if promoted != 0 {
335				continue
336			}
337
338			// prune any types that were not used
339			for t, d := range newDefaulters {
340				if d.object == nil {
341					klog.V(6).Infof("did not generate defaulter for %s because no child defaulters were registered", t.Name)
342					delete(newDefaulters, t)
343				}
344			}
345			break
346		}
347
348		if len(newDefaulters) == 0 {
349			klog.V(5).Infof("no defaulters in package %s", pkg.Name)
350		}
351
352		path := pkg.Path
353		// if the source path is within a /vendor/ directory (for example,
354		// k8s.io/kubernetes/vendor/k8s.io/apimachinery/pkg/apis/meta/v1), allow
355		// generation to output to the proper relative path (under vendor).
356		// Otherwise, the generator will create the file in the wrong location
357		// in the output directory.
358		// TODO: build a more fundamental concept in gengo for dealing with modifications
359		// to vendored packages.
360		if strings.HasPrefix(pkg.SourcePath, arguments.OutputBase) {
361			expandedPath := strings.TrimPrefix(pkg.SourcePath, arguments.OutputBase)
362			if strings.Contains(expandedPath, "/vendor/") {
363				path = expandedPath
364			}
365		}
366
367		packages = append(packages,
368			&generator.DefaultPackage{
369				PackageName: filepath.Base(pkg.Path),
370				PackagePath: path,
371				HeaderText:  header,
372				GeneratorFunc: func(c *generator.Context) (generators []generator.Generator) {
373					return []generator.Generator{
374						NewGenDefaulter(arguments.OutputFileBaseName, typesPkg.Path, pkg.Path, existingDefaulters, newDefaulters, peerPkgs),
375					}
376				},
377				FilterFunc: func(c *generator.Context, t *types.Type) bool {
378					return t.Name.Package == typesPkg.Path
379				},
380			})
381	}
382	return packages
383}
384
385// callTreeForType contains fields necessary to build a tree for types.
386type callTreeForType struct {
387	existingDefaulters     defaulterFuncMap
388	newDefaulters          defaulterFuncMap
389	currentlyBuildingTypes map[*types.Type]bool
390}
391
392func newCallTreeForType(existingDefaulters, newDefaulters defaulterFuncMap) *callTreeForType {
393	return &callTreeForType{
394		existingDefaulters:     existingDefaulters,
395		newDefaulters:          newDefaulters,
396		currentlyBuildingTypes: make(map[*types.Type]bool),
397	}
398}
399
// build creates a tree of paths to fields (based on how they would be accessed in Go - pointer, elem,
// slice, or key) and the functions that should be invoked on each field. An in-order traversal of the resulting tree
// can be used to generate a Go function that invokes each nested function on the appropriate type. The return
// value may be nil if there are no functions to call on type or the type is a primitive (Defaulters can only be
// invoked on structs today). When root is true this function will not use a newDefaulter. existingDefaulters should
// contain all defaulting functions by type defined in code - newDefaulters should contain all object defaulters
// that could be or will be generated. If newDefaulters has an entry for a type, but the 'object' field is nil,
// this function skips adding that defaulter - this allows us to avoid generating object defaulter functions for
// list types that call empty defaulters.
//
// The calls chosen for t, in priority order:
//   - a non-root type with a scheduled generated object defaulter: call it and stop recursing;
//   - an existing object defaulter: call it and stop recursing;
//   - an existing base defaulter: call it, and stop recursing if its comments carry
//     +k8s:defaulter-gen=covers.
//
// Any additional defaulters are then appended, and t's elements, fields or underlying type are walked.
//
// NOTE(review): when t is already being built higher up (a recursive type), nil is returned, which also
// drops any base/additional calls collected above for this nested occurrence.
func (c *callTreeForType) build(t *types.Type, root bool) *callNode {
	parent := &callNode{}

	if root {
		// the root node is always a pointer
		parent.elem = true
	}

	defaults, _ := c.existingDefaulters[t]
	newDefaults, generated := c.newDefaulters[t]
	switch {
	case !root && generated && newDefaults.object != nil:
		parent.call = append(parent.call, newDefaults.object)
		// if we will be generating the defaulter, it by definition is a covering
		// defaulter, so we halt recursion
		klog.V(6).Infof("the defaulter %s will be generated as an object defaulter", t.Name)
		return parent

	case defaults.object != nil:
		// object defaulters are always covering
		parent.call = append(parent.call, defaults.object)
		return parent

	case defaults.base != nil:
		parent.call = append(parent.call, defaults.base)
		// if the base function indicates it "covers" (it already includes defaulters)
		// we can halt recursion
		if checkTag(defaults.base.CommentLines, "covers") {
			klog.V(6).Infof("the defaulter %s indicates it covers all sub generators", t.Name)
			return parent
		}
	}

	// base has been added already, now add any additional defaulters defined for this object
	parent.call = append(parent.call, defaults.additional...)

	// if the type already exists, don't build the tree for it and don't generate anything.
	// This is used to avoid recursion for nested recursive types.
	if c.currentlyBuildingTypes[t] {
		return nil
	}
	// if type doesn't exist, mark it as existing
	c.currentlyBuildingTypes[t] = true

	defer func() {
		// The type will now acts as a parent, not a nested recursive type.
		// We can now build the tree for it safely.
		c.currentlyBuildingTypes[t] = false
	}()

	switch t.Kind {
	case types.Pointer:
		// The pointee's node is marked elem so the writer nil-checks the pointer.
		if child := c.build(t.Elem, false); child != nil {
			child.elem = true
			parent.children = append(parent.children, *child)
		}
	case types.Slice, types.Array:
		// The element node is ranged over; for pointer elements it is also marked elem.
		if child := c.build(t.Elem, false); child != nil {
			child.index = true
			if t.Elem.Kind == types.Pointer {
				child.elem = true
			}
			parent.children = append(parent.children, *child)
		}
	case types.Map:
		// NOTE(review): map value nodes are marked key, but WriteMethod currently emits no code
		// for key nodes, so defaulters reachable only through map values are not invoked.
		if child := c.build(t.Elem, false); child != nil {
			child.key = true
			parent.children = append(parent.children, *child)
		}
	case types.Struct:
		for _, field := range t.Members {
			name := field.Name
			// Embedded fields are accessed by their (pointee) type name.
			if len(name) == 0 {
				if field.Type.Kind == types.Pointer {
					name = field.Type.Elem.Name.Name
				} else {
					name = field.Type.Name.Name
				}
			}
			if child := c.build(field.Type, false); child != nil {
				child.field = name
				parent.children = append(parent.children, *child)
			}
		}
	case types.Alias:
		// Aliases are accessed exactly like their underlying type.
		if child := c.build(t.Underlying, false); child != nil {
			parent.children = append(parent.children, *child)
		}
	}
	if len(parent.children) == 0 && len(parent.call) == 0 {
		//klog.V(6).Infof("decided type %s needs no generation", t.Name)
		return nil
	}
	return parent
}
504
// runtimePackagePath is the import path used to locate runtime.Scheme for the
// generated RegisterDefaults function.
const runtimePackagePath = "k8s.io/apimachinery/pkg/runtime"

// conversionPackagePath is the import path of the apimachinery conversion package.
const conversionPackagePath = "k8s.io/apimachinery/pkg/conversion"
509
510// genDefaulter produces a file with a autogenerated conversions.
511type genDefaulter struct {
512	generator.DefaultGen
513	typesPackage       string
514	outputPackage      string
515	peerPackages       []string
516	newDefaulters      defaulterFuncMap
517	existingDefaulters defaulterFuncMap
518	imports            namer.ImportTracker
519	typesForInit       []*types.Type
520}
521
522func NewGenDefaulter(sanitizedName, typesPackage, outputPackage string, existingDefaulters, newDefaulters defaulterFuncMap, peerPkgs []string) generator.Generator {
523	return &genDefaulter{
524		DefaultGen: generator.DefaultGen{
525			OptionalName: sanitizedName,
526		},
527		typesPackage:       typesPackage,
528		outputPackage:      outputPackage,
529		peerPackages:       peerPkgs,
530		newDefaulters:      newDefaulters,
531		existingDefaulters: existingDefaulters,
532		imports:            generator.NewImportTracker(),
533		typesForInit:       make([]*types.Type, 0),
534	}
535}
536
537func (g *genDefaulter) Namers(c *generator.Context) namer.NameSystems {
538	// Have the raw namer for this file track what it imports.
539	return namer.NameSystems{
540		"raw": namer.NewRawNamer(g.outputPackage, g.imports),
541	}
542}
543
544func (g *genDefaulter) isOtherPackage(pkg string) bool {
545	if pkg == g.outputPackage {
546		return false
547	}
548	if strings.HasSuffix(pkg, `"`+g.outputPackage+`"`) {
549		return false
550	}
551	return true
552}
553
554func (g *genDefaulter) Filter(c *generator.Context, t *types.Type) bool {
555	defaults, ok := g.newDefaulters[t]
556	if !ok || defaults.object == nil {
557		return false
558	}
559	g.typesForInit = append(g.typesForInit, t)
560	return true
561}
562
563func (g *genDefaulter) Imports(c *generator.Context) (imports []string) {
564	var importLines []string
565	for _, singleImport := range g.imports.ImportLines() {
566		if g.isOtherPackage(singleImport) {
567			importLines = append(importLines, singleImport)
568		}
569	}
570	return importLines
571}
572
573func (g *genDefaulter) Init(c *generator.Context, w io.Writer) error {
574	sw := generator.NewSnippetWriter(w, c, "$", "$")
575
576	scheme := c.Universe.Type(types.Name{Package: runtimePackagePath, Name: "Scheme"})
577	schemePtr := &types.Type{
578		Kind: types.Pointer,
579		Elem: scheme,
580	}
581	sw.Do("// RegisterDefaults adds defaulters functions to the given scheme.\n", nil)
582	sw.Do("// Public to allow building arbitrary schemes.\n", nil)
583	sw.Do("// All generated defaulters are covering - they call all nested defaulters.\n", nil)
584	sw.Do("func RegisterDefaults(scheme $.|raw$) error {\n", schemePtr)
585	for _, t := range g.typesForInit {
586		args := defaultingArgsFromType(t)
587		sw.Do("scheme.AddTypeDefaultingFunc(&$.inType|raw${}, func(obj interface{}) { $.inType|objectdefaultfn$(obj.(*$.inType|raw$)) })\n", args)
588	}
589	sw.Do("return nil\n", nil)
590	sw.Do("}\n\n", nil)
591	return sw.Error()
592}
593
594func (g *genDefaulter) GenerateType(c *generator.Context, t *types.Type, w io.Writer) error {
595	if _, ok := g.newDefaulters[t]; !ok {
596		return nil
597	}
598
599	klog.V(5).Infof("generating for type %v", t)
600
601	callTree := newCallTreeForType(g.existingDefaulters, g.newDefaulters).build(t, true)
602	if callTree == nil {
603		klog.V(5).Infof("  no defaulters defined")
604		return nil
605	}
606	i := 0
607	callTree.VisitInOrder(func(ancestors []*callNode, current *callNode) {
608		if len(current.call) == 0 {
609			return
610		}
611		path := callPath(append(ancestors, current))
612		klog.V(5).Infof("  %d: %s", i, path)
613		i++
614	})
615
616	sw := generator.NewSnippetWriter(w, c, "$", "$")
617	g.generateDefaulter(t, callTree, sw)
618	return sw.Error()
619}
620
621func defaultingArgsFromType(inType *types.Type) generator.Args {
622	return generator.Args{
623		"inType": inType,
624	}
625}
626
627func (g *genDefaulter) generateDefaulter(inType *types.Type, callTree *callNode, sw *generator.SnippetWriter) {
628	sw.Do("func $.inType|objectdefaultfn$(in *$.inType|raw$) {\n", defaultingArgsFromType(inType))
629	callTree.WriteMethod("in", 0, nil, sw)
630	sw.Do("}\n\n", nil)
631}
632
633// callNode represents an entry in a tree of Go type accessors - the path from the root to a leaf represents
634// how in Go code an access would be performed. For example, if a defaulting function exists on a container
635// lifecycle hook, to invoke that defaulter correctly would require this Go code:
636//
637//     for i := range pod.Spec.Containers {
638//       o := &pod.Spec.Containers[i]
639//       if o.LifecycleHook != nil {
640//         SetDefaults_LifecycleHook(o.LifecycleHook)
641//       }
642//     }
643//
644// That would be represented by a call tree like:
645//
646//   callNode
647//     field: "Spec"
648//     children:
649//     - field: "Containers"
650//       children:
651//       - index: true
652//         children:
653//         - field: "LifecycleHook"
654//           elem: true
655//           call:
656//           - SetDefaults_LifecycleHook
657//
658// which we can traverse to build that Go struct (you must call the field Spec, then Containers, then range over
659// that field, then check whether the LifecycleHook field is nil, before calling SetDefaults_LifecycleHook on
660// the pointer to that field).
661type callNode struct {
662	// field is the name of the Go member to access
663	field string
664	// key is true if this is a map and we must range over the key and values
665	key bool
666	// index is true if this is a slice and we must range over the slice values
667	index bool
668	// elem is true if the previous elements refer to a pointer (typically just field)
669	elem bool
670
671	// call is all of the functions that must be invoked on this particular node, in order
672	call []*types.Type
673	// children is the child call nodes that must also be traversed
674	children []callNode
675}
676
677// CallNodeVisitorFunc is a function for visiting a call tree. ancestors is the list of all parents
678// of this node to the root of the tree - will be empty at the root.
679type CallNodeVisitorFunc func(ancestors []*callNode, node *callNode)
680
681func (n *callNode) VisitInOrder(fn CallNodeVisitorFunc) {
682	n.visitInOrder(nil, fn)
683}
684
685func (n *callNode) visitInOrder(ancestors []*callNode, fn CallNodeVisitorFunc) {
686	fn(ancestors, n)
687	ancestors = append(ancestors, n)
688	for i := range n.children {
689		n.children[i].visitInOrder(ancestors, fn)
690	}
691}
692
// Candidate single-letter names for the loop index and local element variables
// that WriteMethod declares, one letter per nesting depth.
var (
	indexVariables = "ijklmnop"
	localVariables = "abcdefgh"
)

// varsForDepth creates temporary variables guaranteed to be unique within lexical Go scopes
// of this depth in a function. It uses canonical single-letter Go loop variables for the
// first len(indexVariables) (8) levels and then resorts to uglier numbered names.
func varsForDepth(depth int) (index, local string) {
	// depth is used as a slice index, so fall back to numbered names as soon as it
	// reaches the length of the letter set; otherwise depth 8 would slice past the end.
	if depth < len(indexVariables) {
		index = indexVariables[depth : depth+1]
	} else {
		index = fmt.Sprintf("i%d", depth)
	}
	if depth < len(localVariables) {
		local = localVariables[depth : depth+1]
	} else {
		local = fmt.Sprintf("local%d", depth)
	}
	return index, local
}
714
715// writeCalls generates a list of function calls based on the calls field for the provided variable
716// name and pointer.
717func (n *callNode) writeCalls(varName string, isVarPointer bool, sw *generator.SnippetWriter) {
718	accessor := varName
719	if !isVarPointer {
720		accessor = "&" + accessor
721	}
722	for _, fn := range n.call {
723		sw.Do("$.fn|raw$($.var$)\n", generator.Args{
724			"fn":  fn,
725			"var": accessor,
726		})
727	}
728}
729
// WriteMethod performs an in-order traversal of the calltree, generating loops and if blocks as necessary
// to correctly turn the call tree into a method body that invokes all calls on all child nodes of the call tree.
// Depth is used to generate local variables at the proper depth.
//
// varName is the Go expression for the value at the parent node (generateDefaulter starts with "in").
// ancestors holds the nodes above n and is empty for the root, so the root is never nil-checked.
func (n *callNode) WriteMethod(varName string, depth int, ancestors []*callNode, sw *generator.SnippetWriter) {
	// if len(n.call) > 0 {
	// 	sw.Do(fmt.Sprintf("// %s\n", callPath(append(ancestors, n)).String()), nil)
	// }

	// Field nodes extend the accessor expression, e.g. "in" becomes "in.Spec".
	if len(n.field) > 0 {
		varName = varName + "." + n.field
	}

	index, local := varsForDepth(depth)
	vars := generator.Args{
		"index": index,
		"local": local,
		"var":   varName,
	}

	// A non-index pointer node is wrapped in a nil check (except at the root).
	isPointer := n.elem && !n.index
	if isPointer && len(ancestors) > 0 {
		sw.Do("if $.var$ != nil {\n", vars)
	}

	switch {
	case n.index:
		// Range over the slice/array. The local variable is always a pointer: the
		// element itself when elements are pointers (elem), otherwise its address.
		sw.Do("for $.index$ := range $.var$ {\n", vars)
		if n.elem {
			sw.Do("$.local$ := $.var$[$.index$]\n", vars)
		} else {
			sw.Do("$.local$ := &$.var$[$.index$]\n", vars)
		}

		n.writeCalls(local, true, sw)
		// Children sit inside the loop, one depth deeper, so they get fresh variable names.
		for i := range n.children {
			n.children[i].WriteMethod(local, depth+1, append(ancestors, n), sw)
		}
		sw.Do("}\n", nil)
	case n.key:
		// NOTE(review): map (key) nodes emit no code at all, so neither their calls nor their
		// children are written; defaulters reachable only through map values are never invoked.
	default:
		// Calls and children operate directly on varName at the same depth.
		n.writeCalls(varName, isPointer, sw)
		for i := range n.children {
			n.children[i].WriteMethod(varName, depth, append(ancestors, n), sw)
		}
	}

	if isPointer && len(ancestors) > 0 {
		sw.Do("}\n", nil)
	}
}
780
781type callPath []*callNode
782
783// String prints a representation of a callPath that roughly approximates what a Go accessor
784// would look like. Used for debugging only.
785func (path callPath) String() string {
786	if len(path) == 0 {
787		return "<none>"
788	}
789	var parts []string
790	for _, p := range path {
791		last := len(parts) - 1
792		switch {
793		case p.elem:
794			if len(parts) > 0 {
795				parts[last] = "*" + parts[last]
796			} else {
797				parts = append(parts, "*")
798			}
799		case p.index:
800			if len(parts) > 0 {
801				parts[last] = parts[last] + "[i]"
802			} else {
803				parts = append(parts, "[i]")
804			}
805		case p.key:
806			if len(parts) > 0 {
807				parts[last] = parts[last] + "[key]"
808			} else {
809				parts = append(parts, "[key]")
810			}
811		default:
812			if len(p.field) > 0 {
813				parts = append(parts, p.field)
814			} else {
815				parts = append(parts, "<root>")
816			}
817		}
818	}
819	var calls []string
820	for _, fn := range path[len(path)-1].call {
821		calls = append(calls, fn.Name.String())
822	}
823	if len(calls) == 0 {
824		calls = append(calls, "<none>")
825	}
826
827	return strings.Join(parts, ".") + " calls " + strings.Join(calls, ", ")
828}
829