1/*
2Copyright 2016 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package generators
18
19import (
20	"bytes"
21	"encoding/json"
22	"fmt"
23	"io"
24	"path/filepath"
25	"reflect"
26	"strconv"
27	"strings"
28
29	"k8s.io/gengo/args"
30	"k8s.io/gengo/generator"
31	"k8s.io/gengo/namer"
32	"k8s.io/gengo/types"
33
34	"k8s.io/klog/v2"
35)
36
// CustomArgs is used by the go2idl framework to pass args specific to this
// generator.
type CustomArgs struct {
	// ExtraPeerDirs lists additional package directories that Packages adds to
	// the context and scans for manually written defaulting functions. Any
	// prefix up to and including "/vendor/" is stripped before use.
	ExtraPeerDirs []string // Always consider these as last-ditch possibilities for conversions.
}
42
// typeZeroValue maps a type name to the zero value of that type in the form
// it takes after decoding JSON into an interface{} value: every numeric type
// maps to a float64 zero, booleans to false and string-like types to "".
// mustEnforceDefault compares decoded "+default" tag values against these
// entries, and getTypeZeroValue looks entries up by type name.
var typeZeroValue = map[string]interface{}{
	"uint":        0.,
	"uint8":       0.,
	"uint16":      0.,
	"uint32":      0.,
	"uint64":      0.,
	"int":         0.,
	"int8":        0.,
	"int16":       0.,
	"int32":       0.,
	"int64":       0.,
	"byte":        0.,
	"float64":     0.,
	"float32":     0.,
	"bool":        false,
	"time.Time":   "",
	"string":      "",
	"integer":     0.,
	"number":      0.,
	"boolean":     false,
	"[]byte":      "", // base64 encoded characters
	"interface{}": interface{}(nil),
}
66
// Comment tag names recognised by the defaulter generator. Each name is looked
// up among the "+"-prefixed tags extracted from comment lines.
const (
	// tagName is the tag read by extractTag and checkTag.
	tagName = "k8s:defaulter-gen"
	// inputTagName is the tag read by extractInputTag.
	inputTagName = "k8s:defaulter-gen-input"
	// defaultTagName is the tag read by extractDefaultTag.
	defaultTagName = "default"
)
71
72func extractDefaultTag(comments []string) []string {
73	return types.ExtractCommentTags("+", comments)[defaultTagName]
74}
75
76func extractTag(comments []string) []string {
77	return types.ExtractCommentTags("+", comments)[tagName]
78}
79
80func extractInputTag(comments []string) []string {
81	return types.ExtractCommentTags("+", comments)[inputTagName]
82}
83
84func checkTag(comments []string, require ...string) bool {
85	values := types.ExtractCommentTags("+", comments)[tagName]
86	if len(require) == 0 {
87		return len(values) == 1 && values[0] == ""
88	}
89	return reflect.DeepEqual(values, require)
90}
91
92func defaultFnNamer() *namer.NameStrategy {
93	return &namer.NameStrategy{
94		Prefix: "SetDefaults_",
95		Join: func(pre string, in []string, post string) string {
96			return pre + strings.Join(in, "_") + post
97		},
98	}
99}
100
101func objectDefaultFnNamer() *namer.NameStrategy {
102	return &namer.NameStrategy{
103		Prefix: "SetObjectDefaults_",
104		Join: func(pre string, in []string, post string) string {
105			return pre + strings.Join(in, "_") + post
106		},
107	}
108}
109
110// NameSystems returns the name system used by the generators in this package.
111func NameSystems() namer.NameSystems {
112	return namer.NameSystems{
113		"public":          namer.NewPublicNamer(1),
114		"raw":             namer.NewRawNamer("", nil),
115		"defaultfn":       defaultFnNamer(),
116		"objectdefaultfn": objectDefaultFnNamer(),
117	}
118}
119
// DefaultNameSystem returns the key of the name system used for ordering the
// types to be processed by the generators in this package.
func DefaultNameSystem() string {
	const orderingNameSystem = "public"
	return orderingNameSystem
}
125
// defaults holds the declared defaulting functions for a given type (all defaulting functions
// are expected to be func(1)). getManualDefaultingFunctions fills it from hand-written
// functions and Packages fills the object field for defaulters that will be generated.
type defaults struct {
	// object is the defaulter function for a top level type (typically one with TypeMeta) that
	// invokes all child defaulters. May be nil if the object defaulter has not yet been generated.
	object *types.Type
	// base is a defaulter function defined for a type SetDefaults_Pod which does not invoke all
	// child defaults - the base defaulter alone is insufficient to default a type
	base *types.Type
	// additional is zero or more defaulter functions of the form SetDefaults_Pod_XXXX that can be
	// included in the Object defaulter.
	additional []*types.Type
}
139
// defaulterFuncMap maps a type to the defaulting functions known for it.
// The functions recorded by getManualDefaultingFunctions are declarations
// whose underlying type is of kind "Func".
type defaulterFuncMap map[*types.Type]defaults
143
144// Returns all manually-defined defaulting functions in the package.
145func getManualDefaultingFunctions(context *generator.Context, pkg *types.Package, manualMap defaulterFuncMap) {
146	buffer := &bytes.Buffer{}
147	sw := generator.NewSnippetWriter(buffer, context, "$", "$")
148
149	for _, f := range pkg.Functions {
150		if f.Underlying == nil || f.Underlying.Kind != types.Func {
151			klog.Errorf("Malformed function: %#v", f)
152			continue
153		}
154		if f.Underlying.Signature == nil {
155			klog.Errorf("Function without signature: %#v", f)
156			continue
157		}
158		signature := f.Underlying.Signature
159		// Check whether the function is defaulting function.
160		// Note that all of them have signature:
161		// object: func SetObjectDefaults_inType(*inType)
162		// base: func SetDefaults_inType(*inType)
163		// additional: func SetDefaults_inType_Qualifier(*inType)
164		if signature.Receiver != nil {
165			continue
166		}
167		if len(signature.Parameters) != 1 {
168			continue
169		}
170		if len(signature.Results) != 0 {
171			continue
172		}
173		inType := signature.Parameters[0]
174		if inType.Kind != types.Pointer {
175			continue
176		}
177		// Check if this is the primary defaulter.
178		args := defaultingArgsFromType(inType.Elem)
179		sw.Do("$.inType|defaultfn$", args)
180		switch {
181		case f.Name.Name == buffer.String():
182			key := inType.Elem
183			// We might scan the same package twice, and that's OK.
184			v, ok := manualMap[key]
185			if ok && v.base != nil && v.base.Name.Package != pkg.Path {
186				panic(fmt.Sprintf("duplicate static defaulter defined: %#v", key))
187			}
188			v.base = f
189			manualMap[key] = v
190			klog.V(6).Infof("found base defaulter function for %s from %s", key.Name, f.Name)
191		// Is one of the additional defaulters - a top level defaulter on a type that is
192		// also invoked.
193		case strings.HasPrefix(f.Name.Name, buffer.String()+"_"):
194			key := inType.Elem
195			v, ok := manualMap[key]
196			if ok {
197				exists := false
198				for _, existing := range v.additional {
199					if existing.Name == f.Name {
200						exists = true
201						break
202					}
203				}
204				if exists {
205					continue
206				}
207			}
208			v.additional = append(v.additional, f)
209			manualMap[key] = v
210			klog.V(6).Infof("found additional defaulter function for %s from %s", key.Name, f.Name)
211		}
212		buffer.Reset()
213		sw.Do("$.inType|objectdefaultfn$", args)
214		if f.Name.Name == buffer.String() {
215			key := inType.Elem
216			// We might scan the same package twice, and that's OK.
217			v, ok := manualMap[key]
218			if ok && v.base != nil && v.base.Name.Package != pkg.Path {
219				panic(fmt.Sprintf("duplicate static defaulter defined: %#v", key))
220			}
221			v.object = f
222			manualMap[key] = v
223			klog.V(6).Infof("found object defaulter function for %s from %s", key.Name, f.Name)
224		}
225		buffer.Reset()
226	}
227}
228
229func Packages(context *generator.Context, arguments *args.GeneratorArgs) generator.Packages {
230	boilerplate, err := arguments.LoadGoBoilerplate()
231	if err != nil {
232		klog.Fatalf("Failed loading boilerplate: %v", err)
233	}
234
235	packages := generator.Packages{}
236	header := append([]byte(fmt.Sprintf("// +build !%s\n\n", arguments.GeneratedBuildTag)), boilerplate...)
237
238	// Accumulate pre-existing default functions.
239	// TODO: This is too ad-hoc.  We need a better way.
240	existingDefaulters := defaulterFuncMap{}
241
242	buffer := &bytes.Buffer{}
243	sw := generator.NewSnippetWriter(buffer, context, "$", "$")
244
245	// We are generating defaults only for packages that are explicitly
246	// passed as InputDir.
247	for _, i := range context.Inputs {
248		klog.V(5).Infof("considering pkg %q", i)
249		pkg := context.Universe[i]
250		if pkg == nil {
251			// If the input had no Go files, for example.
252			continue
253		}
254		// typesPkg is where the types that needs defaulter are defined.
255		// Sometimes it is different from pkg. For example, kubernetes core/v1
256		// types are defined in vendor/k8s.io/api/core/v1, while pkg is at
257		// pkg/api/v1.
258		typesPkg := pkg
259
260		// Add defaulting functions.
261		getManualDefaultingFunctions(context, pkg, existingDefaulters)
262
263		var peerPkgs []string
264		if customArgs, ok := arguments.CustomArgs.(*CustomArgs); ok {
265			for _, pkg := range customArgs.ExtraPeerDirs {
266				if i := strings.Index(pkg, "/vendor/"); i != -1 {
267					pkg = pkg[i+len("/vendor/"):]
268				}
269				peerPkgs = append(peerPkgs, pkg)
270			}
271		}
272		// Make sure our peer-packages are added and fully parsed.
273		for _, pp := range peerPkgs {
274			context.AddDir(pp)
275			getManualDefaultingFunctions(context, context.Universe[pp], existingDefaulters)
276		}
277
278		typesWith := extractTag(pkg.Comments)
279		shouldCreateObjectDefaulterFn := func(t *types.Type) bool {
280			if defaults, ok := existingDefaulters[t]; ok && defaults.object != nil {
281				// A default generator is defined
282				baseTypeName := "<unknown>"
283				if defaults.base != nil {
284					baseTypeName = defaults.base.Name.String()
285				}
286				klog.V(5).Infof("  an object defaulter already exists as %s", baseTypeName)
287				return false
288			}
289			// opt-out
290			if checkTag(t.SecondClosestCommentLines, "false") {
291				return false
292			}
293			// opt-in
294			if checkTag(t.SecondClosestCommentLines, "true") {
295				return true
296			}
297			// For every k8s:defaulter-gen tag at the package level, interpret the value as a
298			// field name (like TypeMeta, ListMeta, ObjectMeta) and trigger defaulter generation
299			// for any type with any of the matching field names. Provides a more useful package
300			// level defaulting than global (because we only need defaulters on a subset of objects -
301			// usually those with TypeMeta).
302			if t.Kind == types.Struct && len(typesWith) > 0 {
303				for _, field := range t.Members {
304					for _, s := range typesWith {
305						if field.Name == s {
306							return true
307						}
308					}
309				}
310			}
311			return false
312		}
313
314		// if the types are not in the same package where the defaulter functions to be generated
315		inputTags := extractInputTag(pkg.Comments)
316		if len(inputTags) > 1 {
317			panic(fmt.Sprintf("there could only be one input tag, got %#v", inputTags))
318		}
319		if len(inputTags) == 1 {
320			var err error
321			typesPkg, err = context.AddDirectory(filepath.Join(pkg.Path, inputTags[0]))
322			if err != nil {
323				klog.Fatalf("cannot import package %s", inputTags[0])
324			}
325			// update context.Order to the latest context.Universe
326			orderer := namer.Orderer{Namer: namer.NewPublicNamer(1)}
327			context.Order = orderer.OrderUniverse(context.Universe)
328		}
329
330		newDefaulters := defaulterFuncMap{}
331		for _, t := range typesPkg.Types {
332			if !shouldCreateObjectDefaulterFn(t) {
333				continue
334			}
335			if namer.IsPrivateGoName(t.Name.Name) {
336				// We won't be able to convert to a private type.
337				klog.V(5).Infof("  found a type %v, but it is a private name", t)
338				continue
339			}
340
341			// create a synthetic type we can use during generation
342			newDefaulters[t] = defaults{}
343		}
344
345		// only generate defaulters for objects that actually have defined defaulters
346		// prevents empty defaulters from being registered
347		for {
348			promoted := 0
349			for t, d := range newDefaulters {
350				if d.object != nil {
351					continue
352				}
353				if newCallTreeForType(existingDefaulters, newDefaulters).build(t, true) != nil {
354					args := defaultingArgsFromType(t)
355					sw.Do("$.inType|objectdefaultfn$", args)
356					newDefaulters[t] = defaults{
357						object: &types.Type{
358							Name: types.Name{
359								Package: pkg.Path,
360								Name:    buffer.String(),
361							},
362							Kind: types.Func,
363						},
364					}
365					buffer.Reset()
366					promoted++
367				}
368			}
369			if promoted != 0 {
370				continue
371			}
372
373			// prune any types that were not used
374			for t, d := range newDefaulters {
375				if d.object == nil {
376					klog.V(6).Infof("did not generate defaulter for %s because no child defaulters were registered", t.Name)
377					delete(newDefaulters, t)
378				}
379			}
380			break
381		}
382
383		if len(newDefaulters) == 0 {
384			klog.V(5).Infof("no defaulters in package %s", pkg.Name)
385		}
386
387		path := pkg.Path
388		// if the source path is within a /vendor/ directory (for example,
389		// k8s.io/kubernetes/vendor/k8s.io/apimachinery/pkg/apis/meta/v1), allow
390		// generation to output to the proper relative path (under vendor).
391		// Otherwise, the generator will create the file in the wrong location
392		// in the output directory.
393		// TODO: build a more fundamental concept in gengo for dealing with modifications
394		// to vendored packages.
395		if strings.HasPrefix(pkg.SourcePath, arguments.OutputBase) {
396			expandedPath := strings.TrimPrefix(pkg.SourcePath, arguments.OutputBase)
397			if strings.Contains(expandedPath, "/vendor/") {
398				path = expandedPath
399			}
400		}
401
402		packages = append(packages,
403			&generator.DefaultPackage{
404				PackageName: filepath.Base(pkg.Path),
405				PackagePath: path,
406				HeaderText:  header,
407				GeneratorFunc: func(c *generator.Context) (generators []generator.Generator) {
408					return []generator.Generator{
409						NewGenDefaulter(arguments.OutputFileBaseName, typesPkg.Path, pkg.Path, existingDefaulters, newDefaulters, peerPkgs),
410					}
411				},
412				FilterFunc: func(c *generator.Context, t *types.Type) bool {
413					return t.Name.Package == typesPkg.Path
414				},
415			})
416	}
417	return packages
418}
419
// callTreeForType contains fields necessary to build a tree for types.
type callTreeForType struct {
	// existingDefaulters holds the defaulting functions already known for each type.
	existingDefaulters defaulterFuncMap
	// newDefaulters holds the object defaulters that could be or will be generated.
	newDefaulters defaulterFuncMap
	// currentlyBuildingTypes marks the types whose tree is being built right
	// now; build consults it to stop recursing into nested recursive types.
	currentlyBuildingTypes map[*types.Type]bool
}
426
427func newCallTreeForType(existingDefaulters, newDefaulters defaulterFuncMap) *callTreeForType {
428	return &callTreeForType{
429		existingDefaulters:     existingDefaulters,
430		newDefaulters:          newDefaulters,
431		currentlyBuildingTypes: make(map[*types.Type]bool),
432	}
433}
434
435func resolveTypeAndDepth(t *types.Type) (*types.Type, int) {
436	var prev *types.Type
437	depth := 0
438	for prev != t {
439		prev = t
440		if t.Kind == types.Alias {
441			t = t.Underlying
442		} else if t.Kind == types.Pointer {
443			t = t.Elem
444			depth += 1
445		}
446	}
447	return t, depth
448}
449
450// getNestedDefault returns the first default value when resolving alias types
451func getNestedDefault(t *types.Type) string {
452	var prev *types.Type
453	for prev != t {
454		prev = t
455		defaultMap := extractDefaultTag(t.CommentLines)
456		if len(defaultMap) == 1 && defaultMap[0] != "" {
457			return defaultMap[0]
458		}
459		if t.Kind == types.Alias {
460			t = t.Underlying
461		} else if t.Kind == types.Pointer {
462			t = t.Elem
463		}
464	}
465	return ""
466}
467
468func mustEnforceDefault(t *types.Type, depth int, omitEmpty bool) (interface{}, error) {
469	if depth > 0 {
470		return nil, nil
471	}
472	switch t.Kind {
473	case types.Pointer, types.Map, types.Slice, types.Array, types.Interface:
474		return nil, nil
475	case types.Struct:
476		return map[string]interface{}{}, nil
477	case types.Builtin:
478		if !omitEmpty {
479			if zero, ok := typeZeroValue[t.String()]; ok {
480				return zero, nil
481			} else {
482				return nil, fmt.Errorf("please add type %v to typeZeroValue struct", t)
483			}
484		}
485		return nil, nil
486	default:
487		return nil, fmt.Errorf("not sure how to enforce default for %v", t.Kind)
488	}
489}
490
491func populateDefaultValue(node *callNode, t *types.Type, tags string, commentLines []string) *callNode {
492	defaultMap := extractDefaultTag(commentLines)
493	var defaultString string
494	if len(defaultMap) == 1 {
495		defaultString = defaultMap[0]
496	}
497
498	t, depth := resolveTypeAndDepth(t)
499	if depth > 0 && defaultString == "" {
500		defaultString = getNestedDefault(t)
501	}
502	if len(defaultMap) > 1 {
503		klog.Fatalf("Found more than one default tag for %v", t.Kind)
504	} else if len(defaultMap) == 0 {
505		return node
506	}
507	var defaultValue interface{}
508	if err := json.Unmarshal([]byte(defaultString), &defaultValue); err != nil {
509		klog.Fatalf("Failed to unmarshal default: %v", err)
510	}
511
512	omitEmpty := strings.Contains(reflect.StructTag(tags).Get("json"), "omitempty")
513	if enforced, err := mustEnforceDefault(t, depth, omitEmpty); err != nil {
514		klog.Fatal(err)
515	} else if enforced != nil {
516		if defaultValue != nil {
517			if reflect.DeepEqual(defaultValue, enforced) {
518				// If the default value annotation matches the default value for the type,
519				// do not generate any defaulting function
520				return node
521			} else {
522				enforcedJSON, _ := json.Marshal(enforced)
523				klog.Fatalf("Invalid default value (%#v) for non-pointer/non-omitempty. If specified, must be: %v", defaultValue, string(enforcedJSON))
524			}
525		}
526	}
527
528	// callNodes are not automatically generated for primitive types. Generate one if the callNode does not exist
529	if node == nil {
530		node = &callNode{}
531		node.markerOnly = true
532	}
533
534	node.defaultIsPrimitive = t.IsPrimitive()
535	node.defaultType = t.String()
536	node.defaultValue = defaultString
537	node.defaultDepth = depth
538	return node
539}
540
541// build creates a tree of paths to fields (based on how they would be accessed in Go - pointer, elem,
542// slice, or key) and the functions that should be invoked on each field. An in-order traversal of the resulting tree
543// can be used to generate a Go function that invokes each nested function on the appropriate type. The return
544// value may be nil if there are no functions to call on type or the type is a primitive (Defaulters can only be
545// invoked on structs today). When root is true this function will not use a newDefaulter. existingDefaulters should
546// contain all defaulting functions by type defined in code - newDefaulters should contain all object defaulters
547// that could be or will be generated. If newDefaulters has an entry for a type, but the 'object' field is nil,
548// this function skips adding that defaulter - this allows us to avoid generating object defaulter functions for
549// list types that call empty defaulters.
550func (c *callTreeForType) build(t *types.Type, root bool) *callNode {
551	parent := &callNode{}
552
553	if root {
554		// the root node is always a pointer
555		parent.elem = true
556	}
557
558	defaults, _ := c.existingDefaulters[t]
559	newDefaults, generated := c.newDefaulters[t]
560	switch {
561	case !root && generated && newDefaults.object != nil:
562		parent.call = append(parent.call, newDefaults.object)
563		// if we will be generating the defaulter, it by definition is a covering
564		// defaulter, so we halt recursion
565		klog.V(6).Infof("the defaulter %s will be generated as an object defaulter", t.Name)
566		return parent
567
568	case defaults.object != nil:
569		// object defaulters are always covering
570		parent.call = append(parent.call, defaults.object)
571		return parent
572
573	case defaults.base != nil:
574		parent.call = append(parent.call, defaults.base)
575		// if the base function indicates it "covers" (it already includes defaulters)
576		// we can halt recursion
577		if checkTag(defaults.base.CommentLines, "covers") {
578			klog.V(6).Infof("the defaulter %s indicates it covers all sub generators", t.Name)
579			return parent
580		}
581	}
582
583	// base has been added already, now add any additional defaulters defined for this object
584	parent.call = append(parent.call, defaults.additional...)
585
586	// if the type already exists, don't build the tree for it and don't generate anything.
587	// This is used to avoid recursion for nested recursive types.
588	if c.currentlyBuildingTypes[t] {
589		return nil
590	}
591	// if type doesn't exist, mark it as existing
592	c.currentlyBuildingTypes[t] = true
593
594	defer func() {
595		// The type will now acts as a parent, not a nested recursive type.
596		// We can now build the tree for it safely.
597		c.currentlyBuildingTypes[t] = false
598	}()
599
600	switch t.Kind {
601	case types.Pointer:
602		if child := c.build(t.Elem, false); child != nil {
603			child.elem = true
604			parent.children = append(parent.children, *child)
605		}
606	case types.Slice, types.Array:
607		if child := c.build(t.Elem, false); child != nil {
608			child.index = true
609			if t.Elem.Kind == types.Pointer {
610				child.elem = true
611			}
612			parent.children = append(parent.children, *child)
613		} else if member := populateDefaultValue(nil, t.Elem, "", t.Elem.CommentLines); member != nil {
614			member.index = true
615			parent.children = append(parent.children, *member)
616		}
617	case types.Map:
618		if child := c.build(t.Elem, false); child != nil {
619			child.key = true
620			parent.children = append(parent.children, *child)
621		} else if member := populateDefaultValue(nil, t.Elem, "", t.Elem.CommentLines); member != nil {
622			member.key = true
623			parent.children = append(parent.children, *member)
624		}
625
626	case types.Struct:
627		for _, field := range t.Members {
628			name := field.Name
629			if len(name) == 0 {
630				if field.Type.Kind == types.Pointer {
631					name = field.Type.Elem.Name.Name
632				} else {
633					name = field.Type.Name.Name
634				}
635			}
636			if child := c.build(field.Type, false); child != nil {
637				child.field = name
638				populateDefaultValue(child, field.Type, field.Tags, field.CommentLines)
639				parent.children = append(parent.children, *child)
640			} else if member := populateDefaultValue(nil, field.Type, field.Tags, field.CommentLines); member != nil {
641				member.field = name
642				parent.children = append(parent.children, *member)
643			}
644		}
645	case types.Alias:
646		if child := c.build(t.Underlying, false); child != nil {
647			parent.children = append(parent.children, *child)
648		}
649	}
650	if len(parent.children) == 0 && len(parent.call) == 0 {
651		//klog.V(6).Infof("decided type %s needs no generation", t.Name)
652		return nil
653	}
654	return parent
655}
656
// Import paths of the apimachinery packages referenced by this generator.
const (
	// runtimePackagePath is the package in which Init looks up the Scheme type.
	runtimePackagePath    = "k8s.io/apimachinery/pkg/runtime"
	conversionPackagePath = "k8s.io/apimachinery/pkg/conversion"
)
661
// genDefaulter produces a file with autogenerated defaulters.
type genDefaulter struct {
	generator.DefaultGen
	// typesPackage is the path of the package in which the defaulted types are defined.
	typesPackage string
	// outputPackage is the path of the package the generated code belongs to;
	// it is the local package of the raw namer and is left out of the imports.
	outputPackage string
	// peerPackages is the list of peer packages passed to NewGenDefaulter.
	peerPackages []string
	// newDefaulters holds the object defaulters that are to be generated.
	newDefaulters defaulterFuncMap
	// existingDefaulters holds the defaulting functions that already exist.
	existingDefaulters defaulterFuncMap
	// imports tracks the imports used through the raw namer.
	imports namer.ImportTracker
	// typesForInit collects the types accepted by Filter; Init registers a
	// defaulting function for each of them.
	typesForInit []*types.Type
}
673
674func NewGenDefaulter(sanitizedName, typesPackage, outputPackage string, existingDefaulters, newDefaulters defaulterFuncMap, peerPkgs []string) generator.Generator {
675	return &genDefaulter{
676		DefaultGen: generator.DefaultGen{
677			OptionalName: sanitizedName,
678		},
679		typesPackage:       typesPackage,
680		outputPackage:      outputPackage,
681		peerPackages:       peerPkgs,
682		newDefaulters:      newDefaulters,
683		existingDefaulters: existingDefaulters,
684		imports:            generator.NewImportTracker(),
685		typesForInit:       make([]*types.Type, 0),
686	}
687}
688
689func (g *genDefaulter) Namers(c *generator.Context) namer.NameSystems {
690	// Have the raw namer for this file track what it imports.
691	return namer.NameSystems{
692		"raw": namer.NewRawNamer(g.outputPackage, g.imports),
693	}
694}
695
696func (g *genDefaulter) isOtherPackage(pkg string) bool {
697	if pkg == g.outputPackage {
698		return false
699	}
700	if strings.HasSuffix(pkg, `"`+g.outputPackage+`"`) {
701		return false
702	}
703	return true
704}
705
706func (g *genDefaulter) Filter(c *generator.Context, t *types.Type) bool {
707	defaults, ok := g.newDefaulters[t]
708	if !ok || defaults.object == nil {
709		return false
710	}
711	g.typesForInit = append(g.typesForInit, t)
712	return true
713}
714
715func (g *genDefaulter) Imports(c *generator.Context) (imports []string) {
716	var importLines []string
717	for _, singleImport := range g.imports.ImportLines() {
718		if g.isOtherPackage(singleImport) {
719			importLines = append(importLines, singleImport)
720		}
721	}
722	return importLines
723}
724
725func (g *genDefaulter) Init(c *generator.Context, w io.Writer) error {
726	sw := generator.NewSnippetWriter(w, c, "$", "$")
727
728	scheme := c.Universe.Type(types.Name{Package: runtimePackagePath, Name: "Scheme"})
729	schemePtr := &types.Type{
730		Kind: types.Pointer,
731		Elem: scheme,
732	}
733	sw.Do("// RegisterDefaults adds defaulters functions to the given scheme.\n", nil)
734	sw.Do("// Public to allow building arbitrary schemes.\n", nil)
735	sw.Do("// All generated defaulters are covering - they call all nested defaulters.\n", nil)
736	sw.Do("func RegisterDefaults(scheme $.|raw$) error {\n", schemePtr)
737	for _, t := range g.typesForInit {
738		args := defaultingArgsFromType(t)
739		sw.Do("scheme.AddTypeDefaultingFunc(&$.inType|raw${}, func(obj interface{}) { $.inType|objectdefaultfn$(obj.(*$.inType|raw$)) })\n", args)
740	}
741	sw.Do("return nil\n", nil)
742	sw.Do("}\n\n", nil)
743	return sw.Error()
744}
745
746func (g *genDefaulter) GenerateType(c *generator.Context, t *types.Type, w io.Writer) error {
747	if _, ok := g.newDefaulters[t]; !ok {
748		return nil
749	}
750
751	klog.V(5).Infof("generating for type %v", t)
752
753	callTree := newCallTreeForType(g.existingDefaulters, g.newDefaulters).build(t, true)
754	if callTree == nil {
755		klog.V(5).Infof("  no defaulters defined")
756		return nil
757	}
758	i := 0
759	callTree.VisitInOrder(func(ancestors []*callNode, current *callNode) {
760		if len(current.call) == 0 {
761			return
762		}
763		path := callPath(append(ancestors, current))
764		klog.V(5).Infof("  %d: %s", i, path)
765		i++
766	})
767
768	sw := generator.NewSnippetWriter(w, c, "$", "$")
769	g.generateDefaulter(t, callTree, sw)
770	return sw.Error()
771}
772
773func defaultingArgsFromType(inType *types.Type) generator.Args {
774	return generator.Args{
775		"inType": inType,
776	}
777}
778
779func (g *genDefaulter) generateDefaulter(inType *types.Type, callTree *callNode, sw *generator.SnippetWriter) {
780	sw.Do("func $.inType|objectdefaultfn$(in *$.inType|raw$) {\n", defaultingArgsFromType(inType))
781	callTree.WriteMethod("in", 0, nil, sw)
782	sw.Do("}\n\n", nil)
783}
784
// callNode represents an entry in a tree of Go type accessors - the path from the root to a leaf represents
// how in Go code an access would be performed. For example, if a defaulting function exists on a container
// lifecycle hook, to invoke that defaulter correctly would require this Go code:
//
//     for i := range pod.Spec.Containers {
//       o := &pod.Spec.Containers[i]
//       if o.LifecycleHook != nil {
//         SetDefaults_LifecycleHook(o.LifecycleHook)
//       }
//     }
//
// That would be represented by a call tree like:
//
//   callNode
//     field: "Spec"
//     children:
//     - field: "Containers"
//       children:
//       - index: true
//         children:
//         - field: "LifecycleHook"
//           elem: true
//           call:
//           - SetDefaults_LifecycleHook
//
// which we can traverse to build that Go code (you must access the field Spec, then Containers, then range over
// that field, then check whether the LifecycleHook field is nil, before calling SetDefaults_LifecycleHook on
// the pointer to that field).
type callNode struct {
	// field is the name of the Go member to access
	field string
	// key is true if this is a map and we must range over the key and values
	key bool
	// index is true if this is a slice and we must range over the slice values
	index bool
	// elem is true if the previous elements refer to a pointer (typically just field)
	elem bool

	// call is all of the functions that must be invoked on this particular node, in order
	call []*types.Type
	// children is the child call nodes that must also be traversed
	children []callNode

	// defaultValue is the default value of this node, kept as the raw text
	// taken from the default tag.
	// Only primitive types and pointer types are eligible to have a default value
	defaultValue string

	// defaultIsPrimitive is used to determine how to assign the default value.
	// Primitive types will be directly assigned while complex types will use JSON unmarshalling
	defaultIsPrimitive bool

	// markerOnly is true if the callNode exists solely to fill in a default value
	markerOnly bool

	// defaultDepth is used to determine pointer level of the default value
	// For example 1 corresponds to setting a default value and taking its pointer while
	// 2 corresponds to setting a default value and taking its pointer's pointer
	// 0 implies that no pointers are used
	// This is used in situations where a field is a pointer to a primitive value rather than a primitive value itself.
	//
	//     type A {
	//       +default="foo"
	//       Field *string
	//     }
	defaultDepth int

	// defaultType is the name of the type of the default value, as produced
	// by the resolved type's String method. populateDefaultValue sets it
	// whenever it records a default value.
	defaultType string
}
855
// CallNodeVisitorFunc is a function for visiting a call tree. ancestors is the list of all parents
// of this node, ordered from the root of the tree down to the direct parent - it will be empty at
// the root. node is the node being visited.
type CallNodeVisitorFunc func(ancestors []*callNode, node *callNode)
859
860func (n *callNode) VisitInOrder(fn CallNodeVisitorFunc) {
861	n.visitInOrder(nil, fn)
862}
863
864func (n *callNode) visitInOrder(ancestors []*callNode, fn CallNodeVisitorFunc) {
865	fn(ancestors, n)
866	ancestors = append(ancestors, n)
867	for i := range n.children {
868		n.children[i].visitInOrder(ancestors, fn)
869	}
870}
871
// Single-letter variable names handed out by varsForDepth, one letter per
// nesting depth.
var (
	indexVariables = "ijklmnop"
	localVariables = "abcdefgh"
)

// varsForDepth creates temporary variables guaranteed to be unique within lexical Go scopes
// of this depth in a function. It uses canonical single-letter Go variables while letters are
// available (depths 0 through 7) and then resorts to uglier, numbered names such as "i8" and
// "local8".
//
// depth is expected to be non-negative.
func varsForDepth(depth int) (index, local string) {
	// The comparison must be >=: depth == len(...) is already one past the
	// last letter and slicing there would panic.
	if depth >= len(indexVariables) {
		index = fmt.Sprintf("i%d", depth)
	} else {
		index = indexVariables[depth : depth+1]
	}
	if depth >= len(localVariables) {
		local = fmt.Sprintf("local%d", depth)
	} else {
		local = localVariables[depth : depth+1]
	}
	return index, local
}
893
894// writeCalls generates a list of function calls based on the calls field for the provided variable
895// name and pointer.
896func (n *callNode) writeCalls(varName string, isVarPointer bool, sw *generator.SnippetWriter) {
897	accessor := varName
898	if !isVarPointer {
899		accessor = "&" + accessor
900	}
901	for _, fn := range n.call {
902		sw.Do("$.fn|raw$($.var$)\n", generator.Args{
903			"fn":  fn,
904			"var": accessor,
905		})
906	}
907}
908
909func getTypeZeroValue(t string) (interface{}, error) {
910	defaultZero, ok := typeZeroValue[t]
911	if !ok {
912		return nil, fmt.Errorf("Cannot find zero value for type %v in typeZeroValue", t)
913	}
914
915	// To generate the code for empty string, they must be quoted
916	if defaultZero == "" {
917		defaultZero = strconv.Quote(defaultZero.(string))
918	}
919	return defaultZero, nil
920}
921
922func (n *callNode) writeDefaulter(varName string, index string, isVarPointer bool, sw *generator.SnippetWriter) {
923	if n.defaultValue == "" {
924		return
925	}
926	args := generator.Args{
927		"defaultValue": n.defaultValue,
928		"varName":      varName,
929		"index":        index,
930		"varDepth":     n.defaultDepth,
931		"varType":      n.defaultType,
932	}
933
934	variablePlaceholder := ""
935
936	if n.index {
937		// Defaulting for array
938		variablePlaceholder = "$.varName$[$.index$]"
939	} else if n.key {
940		// Defaulting for map
941		variablePlaceholder = "$.varName$[$.index$]"
942		mapDefaultVar := args["index"].(string) + "_default"
943		args["mapDefaultVar"] = mapDefaultVar
944	} else {
945		// Defaulting for primitive type
946		variablePlaceholder = "$.varName$"
947	}
948
949	// defaultIsPrimitive is true if the type or underlying type (in an array/map) is primitive
950	// or is a pointer to a primitive type
951	// (Eg: int, map[string]*string, []int)
952	if n.defaultIsPrimitive {
953		// If the default value is a primitive when the assigned type is a pointer
954		// keep using the address-of operator on the primitive value until the types match
955		if n.defaultDepth > 0 {
956			sw.Do(fmt.Sprintf("if %s == nil {\n", variablePlaceholder), args)
957			sw.Do("var ptrVar$.varDepth$ $.varType$ = $.defaultValue$\n", args)
958			// We iterate until a depth of 1 instead of 0 because the following line
959			// `if $.varName$ == &ptrVar1` accounts for 1 level already
960			for i := n.defaultDepth; i > 1; i-- {
961				sw.Do("ptrVar$.ptri$ := &ptrVar$.i$\n", generator.Args{"i": fmt.Sprintf("%d", i), "ptri": fmt.Sprintf("%d", (i - 1))})
962			}
963			sw.Do(fmt.Sprintf("%s = &ptrVar1", variablePlaceholder), args)
964		} else {
965			// For primitive types, nil checks cannot be used and the zero value must be determined
966			defaultZero, err := getTypeZeroValue(n.defaultType)
967			if err != nil {
968				klog.Error(err)
969			}
970			args["defaultZero"] = defaultZero
971
972			sw.Do(fmt.Sprintf("if %s == $.defaultZero$ {\n", variablePlaceholder), args)
973			sw.Do(fmt.Sprintf("%s = $.defaultValue$", variablePlaceholder), args)
974		}
975	} else {
976		sw.Do(fmt.Sprintf("if %s == nil {\n", variablePlaceholder), args)
977		// Map values are not directly addressable and we need a temporary variable to do json unmarshalling
978		// This applies to maps with non-primitive values (eg: map[string]SubStruct)
979		if n.key {
980			sw.Do("$.mapDefaultVar$ := $.varName$[$.index$]\n", args)
981			sw.Do("if err := json.Unmarshal([]byte(`$.defaultValue$`), &$.mapDefaultVar$); err != nil {\n", args)
982		} else {
983			variablePointer := variablePlaceholder
984			if !isVarPointer {
985				variablePointer = "&" + variablePointer
986			}
987			sw.Do(fmt.Sprintf("if err := json.Unmarshal([]byte(`$.defaultValue$`), %s); err != nil {\n", variablePointer), args)
988		}
989		sw.Do("panic(err)\n", nil)
990		sw.Do("}\n", nil)
991		if n.key {
992			sw.Do("$.varName$[$.index$] = $.mapDefaultVar$\n", args)
993		}
994	}
995	sw.Do("}\n", nil)
996}
997
998// WriteMethod performs an in-order traversal of the calltree, generating loops and if blocks as necessary
999// to correctly turn the call tree into a method body that invokes all calls on all child nodes of the call tree.
1000// Depth is used to generate local variables at the proper depth.
1001func (n *callNode) WriteMethod(varName string, depth int, ancestors []*callNode, sw *generator.SnippetWriter) {
1002	// if len(n.call) > 0 {
1003	// 	sw.Do(fmt.Sprintf("// %s\n", callPath(append(ancestors, n)).String()), nil)
1004	// }
1005
1006	if len(n.field) > 0 {
1007		varName = varName + "." + n.field
1008	}
1009
1010	index, local := varsForDepth(depth)
1011	vars := generator.Args{
1012		"index": index,
1013		"local": local,
1014		"var":   varName,
1015	}
1016
1017	isPointer := n.elem && !n.index
1018	if isPointer && len(ancestors) > 0 {
1019		sw.Do("if $.var$ != nil {\n", vars)
1020	}
1021
1022	switch {
1023	case n.index:
1024		sw.Do("for $.index$ := range $.var$ {\n", vars)
1025		if !n.markerOnly {
1026			if n.elem {
1027				sw.Do("$.local$ := $.var$[$.index$]\n", vars)
1028			} else {
1029				sw.Do("$.local$ := &$.var$[$.index$]\n", vars)
1030			}
1031		}
1032
1033		n.writeDefaulter(varName, index, isPointer, sw)
1034		n.writeCalls(local, true, sw)
1035		for i := range n.children {
1036			n.children[i].WriteMethod(local, depth+1, append(ancestors, n), sw)
1037		}
1038		sw.Do("}\n", nil)
1039	case n.key:
1040		if n.defaultValue != "" {
1041			// Map keys are typed and cannot share the same index variable as arrays and other maps
1042			index = index + "_" + ancestors[len(ancestors)-1].field
1043			vars["index"] = index
1044			sw.Do("for $.index$ := range $.var$ {\n", vars)
1045			n.writeDefaulter(varName, index, isPointer, sw)
1046			sw.Do("}\n", nil)
1047		}
1048	default:
1049		n.writeDefaulter(varName, index, isPointer, sw)
1050		n.writeCalls(varName, isPointer, sw)
1051		for i := range n.children {
1052			n.children[i].WriteMethod(varName, depth, append(ancestors, n), sw)
1053		}
1054	}
1055
1056	if isPointer && len(ancestors) > 0 {
1057		sw.Do("}\n", nil)
1058	}
1059}
1060
1061type callPath []*callNode
1062
1063// String prints a representation of a callPath that roughly approximates what a Go accessor
1064// would look like. Used for debugging only.
1065func (path callPath) String() string {
1066	if len(path) == 0 {
1067		return "<none>"
1068	}
1069	var parts []string
1070	for _, p := range path {
1071		last := len(parts) - 1
1072		switch {
1073		case p.elem:
1074			if len(parts) > 0 {
1075				parts[last] = "*" + parts[last]
1076			} else {
1077				parts = append(parts, "*")
1078			}
1079		case p.index:
1080			if len(parts) > 0 {
1081				parts[last] = parts[last] + "[i]"
1082			} else {
1083				parts = append(parts, "[i]")
1084			}
1085		case p.key:
1086			if len(parts) > 0 {
1087				parts[last] = parts[last] + "[key]"
1088			} else {
1089				parts = append(parts, "[key]")
1090			}
1091		default:
1092			if len(p.field) > 0 {
1093				parts = append(parts, p.field)
1094			} else {
1095				parts = append(parts, "<root>")
1096			}
1097		}
1098	}
1099	var calls []string
1100	for _, fn := range path[len(path)-1].call {
1101		calls = append(calls, fn.Name.String())
1102	}
1103	if len(calls) == 0 {
1104		calls = append(calls, "<none>")
1105	}
1106
1107	return strings.Join(parts, ".") + " calls " + strings.Join(calls, ", ")
1108}
1109