1/*
2Copyright 2015 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package namer
18
19import (
20	"sort"
21
22	"k8s.io/gengo/types"
23)
24
25// ImportTracker may be passed to a namer.RawNamer, to track the imports needed
26// for the types it names.
27//
28// TODO: pay attention to the package name (instead of renaming every package).
29type DefaultImportTracker struct {
30	pathToName map[string]string
31	// forbidden names are in here. (e.g. "go" is a directory in which
32	// there is code, but "go" is not a legal name for a package, so we put
33	// it here to prevent us from naming any package "go")
34	nameToPath map[string]string
35	local      types.Name
36
37	// Returns true if a given types is an invalid type and should be ignored.
38	IsInvalidType func(*types.Type) bool
39	// Returns the final local name for the given name
40	LocalName func(types.Name) string
41	// Returns the "import" line for a given (path, name).
42	PrintImport func(string, string) string
43}
44
45func NewDefaultImportTracker(local types.Name) DefaultImportTracker {
46	return DefaultImportTracker{
47		pathToName: map[string]string{},
48		nameToPath: map[string]string{},
49		local:      local,
50	}
51}
52
53func (tracker *DefaultImportTracker) AddTypes(types ...*types.Type) {
54	for _, t := range types {
55		tracker.AddType(t)
56	}
57}
58func (tracker *DefaultImportTracker) AddType(t *types.Type) {
59	if tracker.local.Package == t.Name.Package {
60		return
61	}
62
63	if tracker.IsInvalidType(t) {
64		if t.Kind == types.Builtin {
65			return
66		}
67		if _, ok := tracker.nameToPath[t.Name.Package]; !ok {
68			tracker.nameToPath[t.Name.Package] = ""
69		}
70		return
71	}
72
73	if len(t.Name.Package) == 0 {
74		return
75	}
76	path := t.Name.Path
77	if len(path) == 0 {
78		path = t.Name.Package
79	}
80	if _, ok := tracker.pathToName[path]; ok {
81		return
82	}
83	name := tracker.LocalName(t.Name)
84	tracker.nameToPath[name] = path
85	tracker.pathToName[path] = name
86}
87
88func (tracker *DefaultImportTracker) ImportLines() []string {
89	importPaths := []string{}
90	for path := range tracker.pathToName {
91		importPaths = append(importPaths, path)
92	}
93	sort.Sort(sort.StringSlice(importPaths))
94	out := []string{}
95	for _, path := range importPaths {
96		out = append(out, tracker.PrintImport(path, tracker.pathToName[path]))
97	}
98	return out
99}
100
101// LocalNameOf returns the name you would use to refer to the package at the
102// specified path within the body of a file.
103func (tracker *DefaultImportTracker) LocalNameOf(path string) string {
104	return tracker.pathToName[path]
105}
106
107// PathOf returns the path that a given localName is referring to within the
108// body of a file.
109func (tracker *DefaultImportTracker) PathOf(localName string) (string, bool) {
110	name, ok := tracker.nameToPath[localName]
111	return name, ok
112}
113