1/*
2Copyright 2016 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package generators
18
19import (
20	"bytes"
21	"encoding/json"
22	"fmt"
23	"io"
24	"path/filepath"
25	"reflect"
26	"strings"
27
28	"k8s.io/gengo/args"
29	"k8s.io/gengo/generator"
30	"k8s.io/gengo/namer"
31	"k8s.io/gengo/types"
32
33	"k8s.io/klog/v2"
34)
35
// CustomArgs holds the command-line arguments that are specific to this
// generator. The go2idl framework hands them over via GeneratorArgs.CustomArgs.
type CustomArgs struct {
	// ExtraPeerDirs lists packages that are always scanned, as last-ditch
	// possibilities, for existing defaulting functions.
	ExtraPeerDirs []string
}
41
// typeZeroValue maps a builtin type name to the JSON-decoded form of its zero
// value. populateDefaultValue decodes a +default tag with json.Unmarshal into an
// interface{} and compares the result with these entries via reflect.DeepEqual.
// Every numeric entry must therefore be a float64 (written 0.), because that is
// what encoding/json produces for JSON numbers. An int 0 would never compare
// equal.
var typeZeroValue = map[string]interface{}{
	"uint":        0.,
	"uint8":       0.,
	"uint16":      0.,
	"uint32":      0.,
	"uint64":      0.,
	"int":         0.,
	"int8":        0.,
	"int16":       0.,
	"int32":       0.,
	"int64":       0.,
	"byte":        0., // float64 like the other numbers, to match JSON decoding
	"float64":     0.,
	"float32":     0.,
	"bool":        false,
	"time.Time":   "",
	"string":      "",
	"integer":     0.,
	"number":      0.,
	"boolean":     false,
	"[]byte":      "", // base64 encoded characters
	"interface{}": interface{}(nil),
}
65
// Comment tags that carry parameters for defaulter generation: the
// generator tag, the tag naming the package that declares the input types,
// and the per-field default value tag.
const (
	tagName        = "k8s:defaulter-gen"
	inputTagName   = "k8s:defaulter-gen-input"
	defaultTagName = "default"
)
70
71func extractDefaultTag(comments []string) []string {
72	return types.ExtractCommentTags("+", comments)[defaultTagName]
73}
74
75func extractTag(comments []string) []string {
76	return types.ExtractCommentTags("+", comments)[tagName]
77}
78
79func extractInputTag(comments []string) []string {
80	return types.ExtractCommentTags("+", comments)[inputTagName]
81}
82
83func checkTag(comments []string, require ...string) bool {
84	values := types.ExtractCommentTags("+", comments)[tagName]
85	if len(require) == 0 {
86		return len(values) == 1 && values[0] == ""
87	}
88	return reflect.DeepEqual(values, require)
89}
90
91func defaultFnNamer() *namer.NameStrategy {
92	return &namer.NameStrategy{
93		Prefix: "SetDefaults_",
94		Join: func(pre string, in []string, post string) string {
95			return pre + strings.Join(in, "_") + post
96		},
97	}
98}
99
100func objectDefaultFnNamer() *namer.NameStrategy {
101	return &namer.NameStrategy{
102		Prefix: "SetObjectDefaults_",
103		Join: func(pre string, in []string, post string) string {
104			return pre + strings.Join(in, "_") + post
105		},
106	}
107}
108
109// NameSystems returns the name system used by the generators in this package.
110func NameSystems() namer.NameSystems {
111	return namer.NameSystems{
112		"public":          namer.NewPublicNamer(1),
113		"raw":             namer.NewRawNamer("", nil),
114		"defaultfn":       defaultFnNamer(),
115		"objectdefaultfn": objectDefaultFnNamer(),
116	}
117}
118
// DefaultNameSystem returns the name of the name system used to order the
// types processed by the generators in this package.
func DefaultNameSystem() string {
	const orderingSystem = "public"
	return orderingSystem
}
124
125// defaults holds the declared defaulting functions for a given type (all defaulting functions
126// are expected to be func(1))
127type defaults struct {
128	// object is the defaulter function for a top level type (typically one with TypeMeta) that
129	// invokes all child defaulters. May be nil if the object defaulter has not yet been generated.
130	object *types.Type
131	// base is a defaulter function defined for a type SetDefaults_Pod which does not invoke all
132	// child defaults - the base defaulter alone is insufficient to default a type
133	base *types.Type
134	// additional is zero or more defaulter functions of the form SetDefaults_Pod_XXXX that can be
135	// included in the Object defaulter.
136	additional []*types.Type
137}
138
139// All of the types in conversions map are of type "DeclarationOf" with
140// the underlying type being "Func".
141type defaulterFuncMap map[*types.Type]defaults
142
143// Returns all manually-defined defaulting functions in the package.
144func getManualDefaultingFunctions(context *generator.Context, pkg *types.Package, manualMap defaulterFuncMap) {
145	buffer := &bytes.Buffer{}
146	sw := generator.NewSnippetWriter(buffer, context, "$", "$")
147
148	for _, f := range pkg.Functions {
149		if f.Underlying == nil || f.Underlying.Kind != types.Func {
150			klog.Errorf("Malformed function: %#v", f)
151			continue
152		}
153		if f.Underlying.Signature == nil {
154			klog.Errorf("Function without signature: %#v", f)
155			continue
156		}
157		signature := f.Underlying.Signature
158		// Check whether the function is defaulting function.
159		// Note that all of them have signature:
160		// object: func SetObjectDefaults_inType(*inType)
161		// base: func SetDefaults_inType(*inType)
162		// additional: func SetDefaults_inType_Qualifier(*inType)
163		if signature.Receiver != nil {
164			continue
165		}
166		if len(signature.Parameters) != 1 {
167			continue
168		}
169		if len(signature.Results) != 0 {
170			continue
171		}
172		inType := signature.Parameters[0]
173		if inType.Kind != types.Pointer {
174			continue
175		}
176		// Check if this is the primary defaulter.
177		args := defaultingArgsFromType(inType.Elem)
178		sw.Do("$.inType|defaultfn$", args)
179		switch {
180		case f.Name.Name == buffer.String():
181			key := inType.Elem
182			// We might scan the same package twice, and that's OK.
183			v, ok := manualMap[key]
184			if ok && v.base != nil && v.base.Name.Package != pkg.Path {
185				panic(fmt.Sprintf("duplicate static defaulter defined: %#v", key))
186			}
187			v.base = f
188			manualMap[key] = v
189			klog.V(6).Infof("found base defaulter function for %s from %s", key.Name, f.Name)
190		// Is one of the additional defaulters - a top level defaulter on a type that is
191		// also invoked.
192		case strings.HasPrefix(f.Name.Name, buffer.String()+"_"):
193			key := inType.Elem
194			v, ok := manualMap[key]
195			if ok {
196				exists := false
197				for _, existing := range v.additional {
198					if existing.Name == f.Name {
199						exists = true
200						break
201					}
202				}
203				if exists {
204					continue
205				}
206			}
207			v.additional = append(v.additional, f)
208			manualMap[key] = v
209			klog.V(6).Infof("found additional defaulter function for %s from %s", key.Name, f.Name)
210		}
211		buffer.Reset()
212		sw.Do("$.inType|objectdefaultfn$", args)
213		if f.Name.Name == buffer.String() {
214			key := inType.Elem
215			// We might scan the same package twice, and that's OK.
216			v, ok := manualMap[key]
217			if ok && v.base != nil && v.base.Name.Package != pkg.Path {
218				panic(fmt.Sprintf("duplicate static defaulter defined: %#v", key))
219			}
220			v.object = f
221			manualMap[key] = v
222			klog.V(6).Infof("found object defaulter function for %s from %s", key.Name, f.Name)
223		}
224		buffer.Reset()
225	}
226}
227
228func Packages(context *generator.Context, arguments *args.GeneratorArgs) generator.Packages {
229	boilerplate, err := arguments.LoadGoBoilerplate()
230	if err != nil {
231		klog.Fatalf("Failed loading boilerplate: %v", err)
232	}
233
234	packages := generator.Packages{}
235	header := append([]byte(fmt.Sprintf("// +build !%s\n\n", arguments.GeneratedBuildTag)), boilerplate...)
236
237	// Accumulate pre-existing default functions.
238	// TODO: This is too ad-hoc.  We need a better way.
239	existingDefaulters := defaulterFuncMap{}
240
241	buffer := &bytes.Buffer{}
242	sw := generator.NewSnippetWriter(buffer, context, "$", "$")
243
244	// We are generating defaults only for packages that are explicitly
245	// passed as InputDir.
246	for _, i := range context.Inputs {
247		klog.V(5).Infof("considering pkg %q", i)
248		pkg := context.Universe[i]
249		if pkg == nil {
250			// If the input had no Go files, for example.
251			continue
252		}
253		// typesPkg is where the types that needs defaulter are defined.
254		// Sometimes it is different from pkg. For example, kubernetes core/v1
255		// types are defined in vendor/k8s.io/api/core/v1, while pkg is at
256		// pkg/api/v1.
257		typesPkg := pkg
258
259		// Add defaulting functions.
260		getManualDefaultingFunctions(context, pkg, existingDefaulters)
261
262		var peerPkgs []string
263		if customArgs, ok := arguments.CustomArgs.(*CustomArgs); ok {
264			for _, pkg := range customArgs.ExtraPeerDirs {
265				if i := strings.Index(pkg, "/vendor/"); i != -1 {
266					pkg = pkg[i+len("/vendor/"):]
267				}
268				peerPkgs = append(peerPkgs, pkg)
269			}
270		}
271		// Make sure our peer-packages are added and fully parsed.
272		for _, pp := range peerPkgs {
273			context.AddDir(pp)
274			getManualDefaultingFunctions(context, context.Universe[pp], existingDefaulters)
275		}
276
277		typesWith := extractTag(pkg.Comments)
278		shouldCreateObjectDefaulterFn := func(t *types.Type) bool {
279			if defaults, ok := existingDefaulters[t]; ok && defaults.object != nil {
280				// A default generator is defined
281				baseTypeName := "<unknown>"
282				if defaults.base != nil {
283					baseTypeName = defaults.base.Name.String()
284				}
285				klog.V(5).Infof("  an object defaulter already exists as %s", baseTypeName)
286				return false
287			}
288			// opt-out
289			if checkTag(t.SecondClosestCommentLines, "false") {
290				return false
291			}
292			// opt-in
293			if checkTag(t.SecondClosestCommentLines, "true") {
294				return true
295			}
296			// For every k8s:defaulter-gen tag at the package level, interpret the value as a
297			// field name (like TypeMeta, ListMeta, ObjectMeta) and trigger defaulter generation
298			// for any type with any of the matching field names. Provides a more useful package
299			// level defaulting than global (because we only need defaulters on a subset of objects -
300			// usually those with TypeMeta).
301			if t.Kind == types.Struct && len(typesWith) > 0 {
302				for _, field := range t.Members {
303					for _, s := range typesWith {
304						if field.Name == s {
305							return true
306						}
307					}
308				}
309			}
310			return false
311		}
312
313		// if the types are not in the same package where the defaulter functions to be generated
314		inputTags := extractInputTag(pkg.Comments)
315		if len(inputTags) > 1 {
316			panic(fmt.Sprintf("there could only be one input tag, got %#v", inputTags))
317		}
318		if len(inputTags) == 1 {
319			var err error
320			typesPkg, err = context.AddDirectory(filepath.Join(pkg.Path, inputTags[0]))
321			if err != nil {
322				klog.Fatalf("cannot import package %s", inputTags[0])
323			}
324			// update context.Order to the latest context.Universe
325			orderer := namer.Orderer{Namer: namer.NewPublicNamer(1)}
326			context.Order = orderer.OrderUniverse(context.Universe)
327		}
328
329		newDefaulters := defaulterFuncMap{}
330		for _, t := range typesPkg.Types {
331			if !shouldCreateObjectDefaulterFn(t) {
332				continue
333			}
334			if namer.IsPrivateGoName(t.Name.Name) {
335				// We won't be able to convert to a private type.
336				klog.V(5).Infof("  found a type %v, but it is a private name", t)
337				continue
338			}
339
340			// create a synthetic type we can use during generation
341			newDefaulters[t] = defaults{}
342		}
343
344		// only generate defaulters for objects that actually have defined defaulters
345		// prevents empty defaulters from being registered
346		for {
347			promoted := 0
348			for t, d := range newDefaulters {
349				if d.object != nil {
350					continue
351				}
352				if newCallTreeForType(existingDefaulters, newDefaulters).build(t, true) != nil {
353					args := defaultingArgsFromType(t)
354					sw.Do("$.inType|objectdefaultfn$", args)
355					newDefaulters[t] = defaults{
356						object: &types.Type{
357							Name: types.Name{
358								Package: pkg.Path,
359								Name:    buffer.String(),
360							},
361							Kind: types.Func,
362						},
363					}
364					buffer.Reset()
365					promoted++
366				}
367			}
368			if promoted != 0 {
369				continue
370			}
371
372			// prune any types that were not used
373			for t, d := range newDefaulters {
374				if d.object == nil {
375					klog.V(6).Infof("did not generate defaulter for %s because no child defaulters were registered", t.Name)
376					delete(newDefaulters, t)
377				}
378			}
379			break
380		}
381
382		if len(newDefaulters) == 0 {
383			klog.V(5).Infof("no defaulters in package %s", pkg.Name)
384		}
385
386		path := pkg.Path
387		// if the source path is within a /vendor/ directory (for example,
388		// k8s.io/kubernetes/vendor/k8s.io/apimachinery/pkg/apis/meta/v1), allow
389		// generation to output to the proper relative path (under vendor).
390		// Otherwise, the generator will create the file in the wrong location
391		// in the output directory.
392		// TODO: build a more fundamental concept in gengo for dealing with modifications
393		// to vendored packages.
394		if strings.HasPrefix(pkg.SourcePath, arguments.OutputBase) {
395			expandedPath := strings.TrimPrefix(pkg.SourcePath, arguments.OutputBase)
396			if strings.Contains(expandedPath, "/vendor/") {
397				path = expandedPath
398			}
399		}
400
401		packages = append(packages,
402			&generator.DefaultPackage{
403				PackageName: filepath.Base(pkg.Path),
404				PackagePath: path,
405				HeaderText:  header,
406				GeneratorFunc: func(c *generator.Context) (generators []generator.Generator) {
407					return []generator.Generator{
408						NewGenDefaulter(arguments.OutputFileBaseName, typesPkg.Path, pkg.Path, existingDefaulters, newDefaulters, peerPkgs),
409					}
410				},
411				FilterFunc: func(c *generator.Context, t *types.Type) bool {
412					return t.Name.Package == typesPkg.Path
413				},
414			})
415	}
416	return packages
417}
418
419// callTreeForType contains fields necessary to build a tree for types.
420type callTreeForType struct {
421	existingDefaulters     defaulterFuncMap
422	newDefaulters          defaulterFuncMap
423	currentlyBuildingTypes map[*types.Type]bool
424}
425
426func newCallTreeForType(existingDefaulters, newDefaulters defaulterFuncMap) *callTreeForType {
427	return &callTreeForType{
428		existingDefaulters:     existingDefaulters,
429		newDefaulters:          newDefaulters,
430		currentlyBuildingTypes: make(map[*types.Type]bool),
431	}
432}
433
434func resolveTypeAndDepth(t *types.Type) (*types.Type, int) {
435	var prev *types.Type
436	depth := 0
437	for prev != t {
438		prev = t
439		if t.Kind == types.Alias {
440			t = t.Underlying
441		} else if t.Kind == types.Pointer {
442			t = t.Elem
443			depth += 1
444		}
445	}
446	return t, depth
447}
448
449// getNestedDefault returns the first default value when resolving alias types
450func getNestedDefault(t *types.Type) string {
451	var prev *types.Type
452	for prev != t {
453		prev = t
454		defaultMap := extractDefaultTag(t.CommentLines)
455		if len(defaultMap) == 1 && defaultMap[0] != "" {
456			return defaultMap[0]
457		}
458		if t.Kind == types.Alias {
459			t = t.Underlying
460		} else if t.Kind == types.Pointer {
461			t = t.Elem
462		}
463	}
464	return ""
465}
466
467func mustEnforceDefault(t *types.Type, depth int, omitEmpty bool) (interface{}, error) {
468	if depth > 0 {
469		return nil, nil
470	}
471	switch t.Kind {
472	case types.Pointer, types.Map, types.Slice, types.Array, types.Interface:
473		return nil, nil
474	case types.Struct:
475		return map[string]interface{}{}, nil
476	case types.Builtin:
477		if !omitEmpty {
478			if zero, ok := typeZeroValue[t.String()]; ok {
479				return zero, nil
480			} else {
481				return nil, fmt.Errorf("please add type %v to typeZeroValue struct", t)
482			}
483		}
484		return nil, nil
485	default:
486		return nil, fmt.Errorf("not sure how to enforce default for %v", t.Kind)
487	}
488}
489
// populateDefaultValue reads the +default tag from commentLines and records it
// on node. When node is nil, it allocates a marker-only node to carry the
// default.
//
// It returns node unchanged in two cases:
//   - there is no +default tag;
//   - the declared default equals the value that mustEnforceDefault requires
//     for the field, so no defaulting code is needed.
//
// It exits through klog.Fatal when:
//   - there is more than one +default tag;
//   - the default is not valid JSON;
//   - mustEnforceDefault fails;
//   - a non-pointer, non-omitempty field declares a default other than its
//     enforced value.
//
// tags is the raw Go struct tag of the field; only its json omitempty option
// is consulted.
func populateDefaultValue(node *callNode, t *types.Type, tags string, commentLines []string) *callNode {
	defaultMap := extractDefaultTag(commentLines)
	var defaultString string
	if len(defaultMap) == 1 {
		defaultString = defaultMap[0]
	}

	// From here on t is the innermost non-alias, non-pointer type; depth is the
	// number of pointer levels that were stripped.
	t, depth := resolveTypeAndDepth(t)
	// NOTE(review): getNestedDefault receives the already-resolved t here, so it
	// only inspects the innermost type's comments, not those of intermediate
	// aliases. Confirm whether the pre-resolution type was intended.
	if depth > 0 && defaultString == "" {
		defaultString = getNestedDefault(t)
	}
	if len(defaultMap) > 1 {
		klog.Fatalf("Found more than one default tag for %v", t.Kind)
	} else if len(defaultMap) == 0 {
		return node
	}
	// The default is JSON text; numbers decode to float64, which is why
	// typeZeroValue stores its numeric zeros as float64.
	var defaultValue interface{}
	if err := json.Unmarshal([]byte(defaultString), &defaultValue); err != nil {
		klog.Fatalf("Failed to unmarshal default: %v", err)
	}

	omitEmpty := strings.Contains(reflect.StructTag(tags).Get("json"), "omitempty")
	if enforced, err := mustEnforceDefault(t, depth, omitEmpty); err != nil {
		klog.Fatal(err)
	} else if enforced != nil {
		if defaultValue != nil {
			if reflect.DeepEqual(defaultValue, enforced) {
				// If the default value annotation matches the default value for the type,
				// do not generate any defaulting function
				return node
			} else {
				enforcedJSON, _ := json.Marshal(enforced)
				klog.Fatalf("Invalid default value (%#v) for non-pointer/non-omitempty. If specified, must be: %v", defaultValue, string(enforcedJSON))
			}
		}
	}

	// callNodes are not automatically generated for primitive types. Generate one if the callNode does not exist
	if node == nil {
		node = &callNode{}
		node.markerOnly = true
	}

	// Record everything the writer needs to emit the assignment: whether the
	// value can be assigned directly, its type name, the raw default text and
	// how many pointer levels to build around it.
	node.defaultIsPrimitive = t.IsPrimitive()
	node.defaultType = t.String()
	node.defaultValue = defaultString
	node.defaultDepth = depth
	return node
}
539
540// build creates a tree of paths to fields (based on how they would be accessed in Go - pointer, elem,
541// slice, or key) and the functions that should be invoked on each field. An in-order traversal of the resulting tree
542// can be used to generate a Go function that invokes each nested function on the appropriate type. The return
543// value may be nil if there are no functions to call on type or the type is a primitive (Defaulters can only be
544// invoked on structs today). When root is true this function will not use a newDefaulter. existingDefaulters should
545// contain all defaulting functions by type defined in code - newDefaulters should contain all object defaulters
546// that could be or will be generated. If newDefaulters has an entry for a type, but the 'object' field is nil,
547// this function skips adding that defaulter - this allows us to avoid generating object defaulter functions for
548// list types that call empty defaulters.
549func (c *callTreeForType) build(t *types.Type, root bool) *callNode {
550	parent := &callNode{}
551
552	if root {
553		// the root node is always a pointer
554		parent.elem = true
555	}
556
557	defaults, _ := c.existingDefaulters[t]
558	newDefaults, generated := c.newDefaulters[t]
559	switch {
560	case !root && generated && newDefaults.object != nil:
561		parent.call = append(parent.call, newDefaults.object)
562		// if we will be generating the defaulter, it by definition is a covering
563		// defaulter, so we halt recursion
564		klog.V(6).Infof("the defaulter %s will be generated as an object defaulter", t.Name)
565		return parent
566
567	case defaults.object != nil:
568		// object defaulters are always covering
569		parent.call = append(parent.call, defaults.object)
570		return parent
571
572	case defaults.base != nil:
573		parent.call = append(parent.call, defaults.base)
574		// if the base function indicates it "covers" (it already includes defaulters)
575		// we can halt recursion
576		if checkTag(defaults.base.CommentLines, "covers") {
577			klog.V(6).Infof("the defaulter %s indicates it covers all sub generators", t.Name)
578			return parent
579		}
580	}
581
582	// base has been added already, now add any additional defaulters defined for this object
583	parent.call = append(parent.call, defaults.additional...)
584
585	// if the type already exists, don't build the tree for it and don't generate anything.
586	// This is used to avoid recursion for nested recursive types.
587	if c.currentlyBuildingTypes[t] {
588		return nil
589	}
590	// if type doesn't exist, mark it as existing
591	c.currentlyBuildingTypes[t] = true
592
593	defer func() {
594		// The type will now acts as a parent, not a nested recursive type.
595		// We can now build the tree for it safely.
596		c.currentlyBuildingTypes[t] = false
597	}()
598
599	switch t.Kind {
600	case types.Pointer:
601		if child := c.build(t.Elem, false); child != nil {
602			child.elem = true
603			parent.children = append(parent.children, *child)
604		}
605	case types.Slice, types.Array:
606		if child := c.build(t.Elem, false); child != nil {
607			child.index = true
608			if t.Elem.Kind == types.Pointer {
609				child.elem = true
610			}
611			parent.children = append(parent.children, *child)
612		} else if member := populateDefaultValue(nil, t.Elem, "", t.Elem.CommentLines); member != nil {
613			member.index = true
614			parent.children = append(parent.children, *member)
615		}
616	case types.Map:
617		if child := c.build(t.Elem, false); child != nil {
618			child.key = true
619			parent.children = append(parent.children, *child)
620		} else if member := populateDefaultValue(nil, t.Elem, "", t.Elem.CommentLines); member != nil {
621			member.key = true
622			parent.children = append(parent.children, *member)
623		}
624
625	case types.Struct:
626		for _, field := range t.Members {
627			name := field.Name
628			if len(name) == 0 {
629				if field.Type.Kind == types.Pointer {
630					name = field.Type.Elem.Name.Name
631				} else {
632					name = field.Type.Name.Name
633				}
634			}
635			if child := c.build(field.Type, false); child != nil {
636				child.field = name
637				populateDefaultValue(child, field.Type, field.Tags, field.CommentLines)
638				parent.children = append(parent.children, *child)
639			} else if member := populateDefaultValue(nil, field.Type, field.Tags, field.CommentLines); member != nil {
640				member.field = name
641				parent.children = append(parent.children, *member)
642			}
643		}
644	case types.Alias:
645		if child := c.build(t.Underlying, false); child != nil {
646			parent.children = append(parent.children, *child)
647		}
648	}
649	if len(parent.children) == 0 && len(parent.call) == 0 {
650		//klog.V(6).Infof("decided type %s needs no generation", t.Name)
651		return nil
652	}
653	return parent
654}
655
// Well-known apimachinery import paths.
const (
	// runtimePackagePath is where runtime.Scheme lives; Init uses it for the
	// RegisterDefaults parameter type.
	runtimePackagePath = "k8s.io/apimachinery/pkg/runtime"
	// conversionPackagePath is the apimachinery conversion package.
	conversionPackagePath = "k8s.io/apimachinery/pkg/conversion"
)
660
661// genDefaulter produces a file with a autogenerated conversions.
662type genDefaulter struct {
663	generator.DefaultGen
664	typesPackage       string
665	outputPackage      string
666	peerPackages       []string
667	newDefaulters      defaulterFuncMap
668	existingDefaulters defaulterFuncMap
669	imports            namer.ImportTracker
670	typesForInit       []*types.Type
671}
672
673func NewGenDefaulter(sanitizedName, typesPackage, outputPackage string, existingDefaulters, newDefaulters defaulterFuncMap, peerPkgs []string) generator.Generator {
674	return &genDefaulter{
675		DefaultGen: generator.DefaultGen{
676			OptionalName: sanitizedName,
677		},
678		typesPackage:       typesPackage,
679		outputPackage:      outputPackage,
680		peerPackages:       peerPkgs,
681		newDefaulters:      newDefaulters,
682		existingDefaulters: existingDefaulters,
683		imports:            generator.NewImportTracker(),
684		typesForInit:       make([]*types.Type, 0),
685	}
686}
687
688func (g *genDefaulter) Namers(c *generator.Context) namer.NameSystems {
689	// Have the raw namer for this file track what it imports.
690	return namer.NameSystems{
691		"raw": namer.NewRawNamer(g.outputPackage, g.imports),
692	}
693}
694
695func (g *genDefaulter) isOtherPackage(pkg string) bool {
696	if pkg == g.outputPackage {
697		return false
698	}
699	if strings.HasSuffix(pkg, `"`+g.outputPackage+`"`) {
700		return false
701	}
702	return true
703}
704
705func (g *genDefaulter) Filter(c *generator.Context, t *types.Type) bool {
706	defaults, ok := g.newDefaulters[t]
707	if !ok || defaults.object == nil {
708		return false
709	}
710	g.typesForInit = append(g.typesForInit, t)
711	return true
712}
713
714func (g *genDefaulter) Imports(c *generator.Context) (imports []string) {
715	var importLines []string
716	for _, singleImport := range g.imports.ImportLines() {
717		if g.isOtherPackage(singleImport) {
718			importLines = append(importLines, singleImport)
719		}
720	}
721	return importLines
722}
723
724func (g *genDefaulter) Init(c *generator.Context, w io.Writer) error {
725	sw := generator.NewSnippetWriter(w, c, "$", "$")
726
727	scheme := c.Universe.Type(types.Name{Package: runtimePackagePath, Name: "Scheme"})
728	schemePtr := &types.Type{
729		Kind: types.Pointer,
730		Elem: scheme,
731	}
732	sw.Do("// RegisterDefaults adds defaulters functions to the given scheme.\n", nil)
733	sw.Do("// Public to allow building arbitrary schemes.\n", nil)
734	sw.Do("// All generated defaulters are covering - they call all nested defaulters.\n", nil)
735	sw.Do("func RegisterDefaults(scheme $.|raw$) error {\n", schemePtr)
736	for _, t := range g.typesForInit {
737		args := defaultingArgsFromType(t)
738		sw.Do("scheme.AddTypeDefaultingFunc(&$.inType|raw${}, func(obj interface{}) { $.inType|objectdefaultfn$(obj.(*$.inType|raw$)) })\n", args)
739	}
740	sw.Do("return nil\n", nil)
741	sw.Do("}\n\n", nil)
742	return sw.Error()
743}
744
745func (g *genDefaulter) GenerateType(c *generator.Context, t *types.Type, w io.Writer) error {
746	if _, ok := g.newDefaulters[t]; !ok {
747		return nil
748	}
749
750	klog.V(5).Infof("generating for type %v", t)
751
752	callTree := newCallTreeForType(g.existingDefaulters, g.newDefaulters).build(t, true)
753	if callTree == nil {
754		klog.V(5).Infof("  no defaulters defined")
755		return nil
756	}
757	i := 0
758	callTree.VisitInOrder(func(ancestors []*callNode, current *callNode) {
759		if len(current.call) == 0 {
760			return
761		}
762		path := callPath(append(ancestors, current))
763		klog.V(5).Infof("  %d: %s", i, path)
764		i++
765	})
766
767	sw := generator.NewSnippetWriter(w, c, "$", "$")
768	g.generateDefaulter(t, callTree, sw)
769	return sw.Error()
770}
771
772func defaultingArgsFromType(inType *types.Type) generator.Args {
773	return generator.Args{
774		"inType": inType,
775	}
776}
777
778func (g *genDefaulter) generateDefaulter(inType *types.Type, callTree *callNode, sw *generator.SnippetWriter) {
779	sw.Do("func $.inType|objectdefaultfn$(in *$.inType|raw$) {\n", defaultingArgsFromType(inType))
780	callTree.WriteMethod("in", 0, nil, sw)
781	sw.Do("}\n\n", nil)
782}
783
784// callNode represents an entry in a tree of Go type accessors - the path from the root to a leaf represents
785// how in Go code an access would be performed. For example, if a defaulting function exists on a container
786// lifecycle hook, to invoke that defaulter correctly would require this Go code:
787//
788//     for i := range pod.Spec.Containers {
789//       o := &pod.Spec.Containers[i]
790//       if o.LifecycleHook != nil {
791//         SetDefaults_LifecycleHook(o.LifecycleHook)
792//       }
793//     }
794//
795// That would be represented by a call tree like:
796//
797//   callNode
798//     field: "Spec"
799//     children:
800//     - field: "Containers"
801//       children:
802//       - index: true
803//         children:
804//         - field: "LifecycleHook"
805//           elem: true
806//           call:
807//           - SetDefaults_LifecycleHook
808//
809// which we can traverse to build that Go struct (you must call the field Spec, then Containers, then range over
810// that field, then check whether the LifecycleHook field is nil, before calling SetDefaults_LifecycleHook on
811// the pointer to that field).
812type callNode struct {
813	// field is the name of the Go member to access
814	field string
815	// key is true if this is a map and we must range over the key and values
816	key bool
817	// index is true if this is a slice and we must range over the slice values
818	index bool
819	// elem is true if the previous elements refer to a pointer (typically just field)
820	elem bool
821
822	// call is all of the functions that must be invoked on this particular node, in order
823	call []*types.Type
824	// children is the child call nodes that must also be traversed
825	children []callNode
826
827	// defaultValue is the defaultValue of a callNode struct
828	// Only primitive types and pointer types are eligible to have a default value
829	defaultValue string
830
831	// defaultIsPrimitive is used to determine how to assign the default value.
832	// Primitive types will be directly assigned while complex types will use JSON unmarshalling
833	defaultIsPrimitive bool
834
835	// markerOnly is true if the callNode exists solely to fill in a default value
836	markerOnly bool
837
838	// defaultDepth is used to determine pointer level of the default value
839	// For example 1 corresponds to setting a default value and taking its pointer while
840	// 2 corresponds to setting a default value and taking its pointer's pointer
841	// 0 implies that no pointers are used
842	defaultDepth int
843
844	// defaultType is the type of the default value.
845	// Only populated if defaultIsPrimitive is true
846	defaultType string
847}
848
849// CallNodeVisitorFunc is a function for visiting a call tree. ancestors is the list of all parents
850// of this node to the root of the tree - will be empty at the root.
851type CallNodeVisitorFunc func(ancestors []*callNode, node *callNode)
852
853func (n *callNode) VisitInOrder(fn CallNodeVisitorFunc) {
854	n.visitInOrder(nil, fn)
855}
856
857func (n *callNode) visitInOrder(ancestors []*callNode, fn CallNodeVisitorFunc) {
858	fn(ancestors, n)
859	ancestors = append(ancestors, n)
860	for i := range n.children {
861		n.children[i].visitInOrder(ancestors, fn)
862	}
863}
864
// Candidate names for temporary variables in generated code, one letter per
// nesting depth: indexVariables for loop indices, localVariables for locals.
var (
	indexVariables = "ijklmnop"
	localVariables = "abcdefgh"
)

// varsForDepth returns an index variable name and a local variable name that
// are unique within lexical Go scopes of the given nesting depth in a function.
// The first len(indexVariables) (8) depths use the canonical single letters
// ("i".."p" and "a".."h"). Deeper levels fall back to numbered names such as
// "i8" and "local8".
//
// The bounds are checked with "<" so that depth == len(...) falls back to a
// numbered name. Previously ">" was used, which sliced out of range and
// panicked at depth 8.
func varsForDepth(depth int) (index, local string) {
	if depth < len(indexVariables) {
		index = indexVariables[depth : depth+1]
	} else {
		index = fmt.Sprintf("i%d", depth)
	}
	if depth < len(localVariables) {
		local = localVariables[depth : depth+1]
	} else {
		local = fmt.Sprintf("local%d", depth)
	}
	return index, local
}
886
887// writeCalls generates a list of function calls based on the calls field for the provided variable
888// name and pointer.
889func (n *callNode) writeCalls(varName string, isVarPointer bool, sw *generator.SnippetWriter) {
890	accessor := varName
891	if !isVarPointer {
892		accessor = "&" + accessor
893	}
894	for _, fn := range n.call {
895		sw.Do("$.fn|raw$($.var$)\n", generator.Args{
896			"fn":  fn,
897			"var": accessor,
898		})
899	}
900}
901
902func (n *callNode) writeDefaulter(varName string, index string, isVarPointer bool, sw *generator.SnippetWriter) {
903	if n.defaultValue == "" {
904		return
905	}
906	varPointer := varName
907	if !isVarPointer {
908		varPointer = "&" + varPointer
909	}
910
911	args := generator.Args{
912		"defaultValue": n.defaultValue,
913		"varPointer":   varPointer,
914		"varName":      varName,
915		"index":        index,
916		"varDepth":     n.defaultDepth,
917		"varType":      n.defaultType,
918	}
919
920	if n.index {
921		sw.Do("if reflect.ValueOf($.var$[$.index$]).IsZero() {\n", generator.Args{"var": varName, "index": index})
922		if n.defaultIsPrimitive {
923			if n.defaultDepth > 0 {
924				sw.Do("var ptrVar$.varDepth$ $.varType$ = $.defaultValue$\n", args)
925				for i := n.defaultDepth; i > 0; i-- {
926					sw.Do("ptrVar$.ptri$ := &ptrVar$.i$\n", generator.Args{"i": fmt.Sprintf("%d", i), "ptri": fmt.Sprintf("%d", (i - 1))})
927				}
928				sw.Do("$.varName$[$.index$] = ptrVar0", args)
929			} else {
930				sw.Do("$.varName$[$.index$] = $.defaultValue$", args)
931			}
932		} else {
933			sw.Do("if err := json.Unmarshal([]byte(`$.defaultValue$`), $.varPointer$[$.index$]); err != nil {\n", args)
934			sw.Do("panic(err)\n", nil)
935			sw.Do("}\n", nil)
936		}
937	} else if n.key {
938		mapDefaultVar := index + "_default"
939		args["mapDefaultVar"] = mapDefaultVar
940		sw.Do("if reflect.ValueOf($.var$[$.index$]).IsZero() {\n", generator.Args{"var": varName, "index": index})
941
942		if n.defaultIsPrimitive {
943			if n.defaultDepth > 0 {
944				sw.Do("var ptrVar$.varDepth$ $.varType$ = $.defaultValue$\n", args)
945				for i := n.defaultDepth; i > 0; i-- {
946					sw.Do("ptrVar$.ptri$ := &ptrVar$.i$\n", generator.Args{"i": fmt.Sprintf("%d", i), "ptri": fmt.Sprintf("%d", (i - 1))})
947				}
948				sw.Do("$.varName$[$.index$] = ptrVar0", args)
949			} else {
950				sw.Do("$.varName$[$.index$] = $.defaultValue$", args)
951			}
952		} else {
953			sw.Do("$.mapDefaultVar$ := $.varName$[$.index$]\n", args)
954			sw.Do("if err := json.Unmarshal([]byte(`$.defaultValue$`), &$.mapDefaultVar$); err != nil {\n", args)
955			sw.Do("panic(err)\n", nil)
956			sw.Do("}\n", nil)
957			sw.Do("$.varName$[$.index$] = $.mapDefaultVar$\n", args)
958		}
959	} else {
960		sw.Do("if reflect.ValueOf($.var$).IsZero() {\n", generator.Args{"var": varName})
961
962		if n.defaultIsPrimitive {
963			if n.defaultDepth > 0 {
964				sw.Do("var ptrVar$.varDepth$ $.varType$ = $.defaultValue$\n", args)
965				for i := n.defaultDepth; i > 0; i-- {
966					sw.Do("ptrVar$.ptri$ := &ptrVar$.i$\n", generator.Args{"i": fmt.Sprintf("%d", i), "ptri": fmt.Sprintf("%d", (i - 1))})
967				}
968				sw.Do("$.varName$ = ptrVar0", args)
969			} else {
970				sw.Do("$.varName$ = $.defaultValue$", args)
971			}
972		} else {
973			sw.Do("if err := json.Unmarshal([]byte(`$.defaultValue$`), $.varPointer$); err != nil {\n", args)
974			sw.Do("panic(err)\n", nil)
975			sw.Do("}\n", nil)
976		}
977	}
978	sw.Do("}\n", nil)
979}
980
981// WriteMethod performs an in-order traversal of the calltree, generating loops and if blocks as necessary
982// to correctly turn the call tree into a method body that invokes all calls on all child nodes of the call tree.
983// Depth is used to generate local variables at the proper depth.
984func (n *callNode) WriteMethod(varName string, depth int, ancestors []*callNode, sw *generator.SnippetWriter) {
985	// if len(n.call) > 0 {
986	// 	sw.Do(fmt.Sprintf("// %s\n", callPath(append(ancestors, n)).String()), nil)
987	// }
988
989	if len(n.field) > 0 {
990		varName = varName + "." + n.field
991	}
992
993	index, local := varsForDepth(depth)
994	vars := generator.Args{
995		"index": index,
996		"local": local,
997		"var":   varName,
998	}
999
1000	isPointer := n.elem && !n.index
1001	if isPointer && len(ancestors) > 0 {
1002		sw.Do("if $.var$ != nil {\n", vars)
1003	}
1004
1005	switch {
1006	case n.index:
1007		sw.Do("for $.index$ := range $.var$ {\n", vars)
1008		if !n.markerOnly {
1009			if n.elem {
1010				sw.Do("$.local$ := $.var$[$.index$]\n", vars)
1011			} else {
1012				sw.Do("$.local$ := &$.var$[$.index$]\n", vars)
1013			}
1014		}
1015
1016		n.writeDefaulter(varName, index, isPointer, sw)
1017		n.writeCalls(local, true, sw)
1018		for i := range n.children {
1019			n.children[i].WriteMethod(local, depth+1, append(ancestors, n), sw)
1020		}
1021		sw.Do("}\n", nil)
1022	case n.key:
1023		if n.defaultValue != "" {
1024			// Map keys are typed and cannot share the same index variable as arrays and other maps
1025			index = index + "_" + ancestors[len(ancestors)-1].field
1026			vars["index"] = index
1027			sw.Do("for $.index$ := range $.var$ {\n", vars)
1028			n.writeDefaulter(varName, index, isPointer, sw)
1029			sw.Do("}\n", nil)
1030		}
1031	default:
1032		n.writeDefaulter(varName, index, isPointer, sw)
1033		n.writeCalls(varName, isPointer, sw)
1034		for i := range n.children {
1035			n.children[i].WriteMethod(varName, depth, append(ancestors, n), sw)
1036		}
1037	}
1038
1039	if isPointer && len(ancestors) > 0 {
1040		sw.Do("}\n", nil)
1041	}
1042}
1043
1044type callPath []*callNode
1045
1046// String prints a representation of a callPath that roughly approximates what a Go accessor
1047// would look like. Used for debugging only.
1048func (path callPath) String() string {
1049	if len(path) == 0 {
1050		return "<none>"
1051	}
1052	var parts []string
1053	for _, p := range path {
1054		last := len(parts) - 1
1055		switch {
1056		case p.elem:
1057			if len(parts) > 0 {
1058				parts[last] = "*" + parts[last]
1059			} else {
1060				parts = append(parts, "*")
1061			}
1062		case p.index:
1063			if len(parts) > 0 {
1064				parts[last] = parts[last] + "[i]"
1065			} else {
1066				parts = append(parts, "[i]")
1067			}
1068		case p.key:
1069			if len(parts) > 0 {
1070				parts[last] = parts[last] + "[key]"
1071			} else {
1072				parts = append(parts, "[key]")
1073			}
1074		default:
1075			if len(p.field) > 0 {
1076				parts = append(parts, p.field)
1077			} else {
1078				parts = append(parts, "<root>")
1079			}
1080		}
1081	}
1082	var calls []string
1083	for _, fn := range path[len(path)-1].call {
1084		calls = append(calls, fn.Name.String())
1085	}
1086	if len(calls) == 0 {
1087		calls = append(calls, "<none>")
1088	}
1089
1090	return strings.Join(parts, ".") + " calls " + strings.Join(calls, ", ")
1091}
1092