1/*
2Copyright 2015 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package generator
18
import (
	"bytes"
	"fmt"
	"io"
	"io/ioutil"
	"os"
	"path/filepath"
	"sort"
	"strings"

	"golang.org/x/tools/imports"
	"k8s.io/gengo/namer"
	"k8s.io/gengo/types"

	"k8s.io/klog"
)
34
// errs2strings returns the Error() text of every element of errors, preserving
// order. The result is always a non-nil slice, even for empty input. A nil
// element causes a panic, since Error() is called on it.
func errs2strings(errors []error) []string {
	out := make([]string, 0, len(errors))
	for _, e := range errors {
		out = append(out, e.Error())
	}
	return out
}
42
43// ExecutePackages runs the generators for every package in 'packages'. 'outDir'
44// is the base directory in which to place all the generated packages; it
45// should be a physical path on disk, not an import path. e.g.:
46// /path/to/home/path/to/gopath/src/
47// Each package has its import path already, this will be appended to 'outDir'.
48func (c *Context) ExecutePackages(outDir string, packages Packages) error {
49	var errors []error
50	for _, p := range packages {
51		if err := c.ExecutePackage(outDir, p); err != nil {
52			errors = append(errors, err)
53		}
54	}
55	if len(errors) > 0 {
56		return fmt.Errorf("some packages had errors:\n%v\n", strings.Join(errs2strings(errors), "\n"))
57	}
58	return nil
59}
60
// DefaultFileType is a file type assembled from two pluggable functions: one
// that writes the contents of a *File to a writer, and one that post-processes
// (formats) the resulting bytes. NewGolangFile returns the Go-source variant.
type DefaultFileType struct {
	// Format transforms the assembled bytes into their final form. A non-nil
	// error means the output could not be formatted.
	Format   func([]byte) ([]byte, error)
	// Assemble writes the raw, unformatted contents of a *File to the writer.
	Assemble func(io.Writer, *File)
}
65
66func (ft DefaultFileType) AssembleFile(f *File, pathname string) error {
67	klog.V(2).Infof("Assembling file %q", pathname)
68	destFile, err := os.Create(pathname)
69	if err != nil {
70		return err
71	}
72	defer destFile.Close()
73
74	b := &bytes.Buffer{}
75	et := NewErrorTracker(b)
76	ft.Assemble(et, f)
77	if et.Error() != nil {
78		return et.Error()
79	}
80	if formatted, err := ft.Format(b.Bytes()); err != nil {
81		err = fmt.Errorf("unable to format file %q (%v).", pathname, err)
82		// Write the file anyway, so they can see what's going wrong and fix the generator.
83		if _, err2 := destFile.Write(b.Bytes()); err2 != nil {
84			return err2
85		}
86		return err
87	} else {
88		_, err = destFile.Write(formatted)
89		return err
90	}
91}
92
93func (ft DefaultFileType) VerifyFile(f *File, pathname string) error {
94	klog.V(2).Infof("Verifying file %q", pathname)
95	friendlyName := filepath.Join(f.PackageName, f.Name)
96	b := &bytes.Buffer{}
97	et := NewErrorTracker(b)
98	ft.Assemble(et, f)
99	if et.Error() != nil {
100		return et.Error()
101	}
102	formatted, err := ft.Format(b.Bytes())
103	if err != nil {
104		return fmt.Errorf("unable to format the output for %q: %v", friendlyName, err)
105	}
106	existing, err := ioutil.ReadFile(pathname)
107	if err != nil {
108		return fmt.Errorf("unable to read file %q for comparison: %v", friendlyName, err)
109	}
110	if bytes.Compare(formatted, existing) == 0 {
111		return nil
112	}
113	// Be nice and find the first place where they differ
114	i := 0
115	for i < len(formatted) && i < len(existing) && formatted[i] == existing[i] {
116		i++
117	}
118	eDiff, fDiff := existing[i:], formatted[i:]
119	if len(eDiff) > 100 {
120		eDiff = eDiff[:100]
121	}
122	if len(fDiff) > 100 {
123		fDiff = fDiff[:100]
124	}
125	return fmt.Errorf("output for %q differs; first existing/expected diff: \n  %q\n  %q", friendlyName, string(eDiff), string(fDiff))
126}
127
128func assembleGolangFile(w io.Writer, f *File) {
129	w.Write(f.Header)
130	fmt.Fprintf(w, "package %v\n\n", f.PackageName)
131
132	if len(f.Imports) > 0 {
133		fmt.Fprint(w, "import (\n")
134		for i := range f.Imports {
135			if strings.Contains(i, "\"") {
136				// they included quotes, or are using the
137				// `name "path/to/pkg"` format.
138				fmt.Fprintf(w, "\t%s\n", i)
139			} else {
140				fmt.Fprintf(w, "\t%q\n", i)
141			}
142		}
143		fmt.Fprint(w, ")\n\n")
144	}
145
146	if f.Vars.Len() > 0 {
147		fmt.Fprint(w, "var (\n")
148		w.Write(f.Vars.Bytes())
149		fmt.Fprint(w, ")\n\n")
150	}
151
152	if f.Consts.Len() > 0 {
153		fmt.Fprint(w, "const (\n")
154		w.Write(f.Consts.Bytes())
155		fmt.Fprint(w, ")\n\n")
156	}
157
158	w.Write(f.Body.Bytes())
159}
160
161func importsWrapper(src []byte) ([]byte, error) {
162	return imports.Process("", src, nil)
163}
164
165func NewGolangFile() *DefaultFileType {
166	return &DefaultFileType{
167		Format:   importsWrapper,
168		Assemble: assembleGolangFile,
169	}
170}
171
// addIndentHeaderComment appends a single-line "// ..." comment, rendered from
// format and args, to b. If b already has content, a blank line is emitted
// first so the comment is separated from what precedes it.
//
// format should be one line only, and not end with \n.
func addIndentHeaderComment(b *bytes.Buffer, format string, args ...interface{}) {
	prefix := "// "
	if b.Len() != 0 {
		prefix = "\n// "
	}
	fmt.Fprintf(b, prefix+format+"\n", args...)
}
180
181func (c *Context) filteredBy(f func(*Context, *types.Type) bool) *Context {
182	c2 := *c
183	c2.Order = []*types.Type{}
184	for _, t := range c.Order {
185		if f(c, t) {
186			c2.Order = append(c2.Order, t)
187		}
188	}
189	return &c2
190}
191
192// make a new context; inheret c.Namers, but add on 'namers'. In case of a name
193// collision, the namer in 'namers' wins.
194func (c *Context) addNameSystems(namers namer.NameSystems) *Context {
195	if namers == nil {
196		return c
197	}
198	c2 := *c
199	// Copy the existing name systems so we don't corrupt a parent context
200	c2.Namers = namer.NameSystems{}
201	for k, v := range c.Namers {
202		c2.Namers[k] = v
203	}
204
205	for name, namer := range namers {
206		c2.Namers[name] = namer
207	}
208	return &c2
209}
210
211// ExecutePackage executes a single package. 'outDir' is the base directory in
212// which to place the package; it should be a physical path on disk, not an
213// import path. e.g.: '/path/to/home/path/to/gopath/src/' The package knows its
214// import path already, this will be appended to 'outDir'.
215func (c *Context) ExecutePackage(outDir string, p Package) error {
216	path := filepath.Join(outDir, p.Path())
217	klog.V(2).Infof("Processing package %q, disk location %q", p.Name(), path)
218	// Filter out any types the *package* doesn't care about.
219	packageContext := c.filteredBy(p.Filter)
220	os.MkdirAll(path, 0755)
221	files := map[string]*File{}
222	for _, g := range p.Generators(packageContext) {
223		// Filter out types the *generator* doesn't care about.
224		genContext := packageContext.filteredBy(g.Filter)
225		// Now add any extra name systems defined by this generator
226		genContext = genContext.addNameSystems(g.Namers(genContext))
227
228		fileType := g.FileType()
229		if len(fileType) == 0 {
230			return fmt.Errorf("generator %q must specify a file type", g.Name())
231		}
232		f := files[g.Filename()]
233		if f == nil {
234			// This is the first generator to reference this file, so start it.
235			f = &File{
236				Name:        g.Filename(),
237				FileType:    fileType,
238				PackageName: p.Name(),
239				Header:      p.Header(g.Filename()),
240				Imports:     map[string]struct{}{},
241			}
242			files[f.Name] = f
243		} else {
244			if f.FileType != g.FileType() {
245				return fmt.Errorf("file %q already has type %q, but generator %q wants to use type %q", f.Name, f.FileType, g.Name(), g.FileType())
246			}
247		}
248
249		if vars := g.PackageVars(genContext); len(vars) > 0 {
250			addIndentHeaderComment(&f.Vars, "Package-wide variables from generator %q.", g.Name())
251			for _, v := range vars {
252				if _, err := fmt.Fprintf(&f.Vars, "%s\n", v); err != nil {
253					return err
254				}
255			}
256		}
257		if consts := g.PackageConsts(genContext); len(consts) > 0 {
258			addIndentHeaderComment(&f.Consts, "Package-wide consts from generator %q.", g.Name())
259			for _, v := range consts {
260				if _, err := fmt.Fprintf(&f.Consts, "%s\n", v); err != nil {
261					return err
262				}
263			}
264		}
265		if err := genContext.executeBody(&f.Body, g); err != nil {
266			return err
267		}
268		if imports := g.Imports(genContext); len(imports) > 0 {
269			for _, i := range imports {
270				f.Imports[i] = struct{}{}
271			}
272		}
273	}
274
275	var errors []error
276	for _, f := range files {
277		finalPath := filepath.Join(path, f.Name)
278		assembler, ok := c.FileTypes[f.FileType]
279		if !ok {
280			return fmt.Errorf("the file type %q registered for file %q does not exist in the context", f.FileType, f.Name)
281		}
282		var err error
283		if c.Verify {
284			err = assembler.VerifyFile(f, finalPath)
285		} else {
286			err = assembler.AssembleFile(f, finalPath)
287		}
288		if err != nil {
289			errors = append(errors, err)
290		}
291	}
292	if len(errors) > 0 {
293		return fmt.Errorf("errors in package %q:\n%v\n", p.Path(), strings.Join(errs2strings(errors), "\n"))
294	}
295	return nil
296}
297
298func (c *Context) executeBody(w io.Writer, generator Generator) error {
299	et := NewErrorTracker(w)
300	if err := generator.Init(c, et); err != nil {
301		return err
302	}
303	for _, t := range c.Order {
304		if err := generator.GenerateType(c, t, et); err != nil {
305			return err
306		}
307	}
308	if err := generator.Finalize(c, et); err != nil {
309		return err
310	}
311	return et.Error()
312}
313