1// Copyright 2019 The Hugo Authors. All rights reserved.
// Some functions in this file (see comments) are based on the Go source code,
// copyright The Go Authors and governed by a BSD-style license.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16// Package codegen contains helpers for code generation.
17package codegen
18
19import (
20	"fmt"
21	"go/ast"
22	"go/parser"
23	"go/token"
24	"os"
25	"path"
26	"path/filepath"
27	"reflect"
28	"regexp"
29	"sort"
30	"strings"
31	"sync"
32)
33
// weightWidth is the distance between the weights assigned to consecutive
// methods of an interface. The gap makes room for insertions: the methods of
// an embedded interface are later slotted in right after the weight of the
// entry that embeds them.
const weightWidth = 1000
36
37// NewInspector creates a new Inspector given a source root.
38func NewInspector(root string) *Inspector {
39	return &Inspector{ProjectRootDir: root}
40}
41
// Inspector provides methods to help code generation. It uses a combination
// of reflection and source code AST to do the heavy lifting.
type Inspector struct {
	// ProjectRootDir is the directory that is walked for Go source files.
	// parseSource panics unless this path contains "hugo".
	ProjectRootDir string

	// init makes sure the source tree is parsed at most once.
	init sync.Once

	// Determines method order. Go's reflect sorts methods lexicographically,
	// so the source is parsed to recover the order in which they are declared.
	// The outer key is the interface on the form PACKAGE.NAME, the inner key
	// is the method name and the value is its weight (lower sorts first).
	methodWeight map[string]map[string]int
}
53
54// MethodsFromTypes create a method set from the include slice, excluding any
55// method in exclude.
56func (c *Inspector) MethodsFromTypes(include []reflect.Type, exclude []reflect.Type) Methods {
57	c.parseSource()
58
59	var methods Methods
60
61	excludes := make(map[string]bool)
62
63	if len(exclude) > 0 {
64		for _, m := range c.MethodsFromTypes(exclude, nil) {
65			excludes[m.Name] = true
66		}
67	}
68
69	// There may be overlapping interfaces in types. Do a simple check for now.
70	seen := make(map[string]bool)
71
72	nameAndPackage := func(t reflect.Type) (string, string) {
73		var name, pkg string
74
75		isPointer := t.Kind() == reflect.Ptr
76
77		if isPointer {
78			t = t.Elem()
79		}
80
81		pkgPrefix := ""
82		if pkgPath := t.PkgPath(); pkgPath != "" {
83			pkgPath = strings.TrimSuffix(pkgPath, "/")
84			_, shortPath := path.Split(pkgPath)
85			pkgPrefix = shortPath + "."
86			pkg = pkgPath
87		}
88
89		name = t.Name()
90		if name == "" {
91			// interface{}
92			name = t.String()
93		}
94
95		if isPointer {
96			pkgPrefix = "*" + pkgPrefix
97		}
98
99		name = pkgPrefix + name
100
101		return name, pkg
102	}
103
104	for _, t := range include {
105		for i := 0; i < t.NumMethod(); i++ {
106
107			m := t.Method(i)
108			if excludes[m.Name] || seen[m.Name] {
109				continue
110			}
111
112			seen[m.Name] = true
113
114			if m.PkgPath != "" {
115				// Not exported
116				continue
117			}
118
119			numIn := m.Type.NumIn()
120
121			ownerName, _ := nameAndPackage(t)
122
123			method := Method{Owner: t, OwnerName: ownerName, Name: m.Name}
124
125			for i := 0; i < numIn; i++ {
126				in := m.Type.In(i)
127
128				name, pkg := nameAndPackage(in)
129
130				if pkg != "" {
131					method.Imports = append(method.Imports, pkg)
132				}
133
134				method.In = append(method.In, name)
135			}
136
137			numOut := m.Type.NumOut()
138
139			if numOut > 0 {
140				for i := 0; i < numOut; i++ {
141					out := m.Type.Out(i)
142					name, pkg := nameAndPackage(out)
143
144					if pkg != "" {
145						method.Imports = append(method.Imports, pkg)
146					}
147
148					method.Out = append(method.Out, name)
149				}
150			}
151
152			methods = append(methods, method)
153		}
154	}
155
156	sort.SliceStable(methods, func(i, j int) bool {
157		mi, mj := methods[i], methods[j]
158
159		wi := c.methodWeight[mi.OwnerName][mi.Name]
160		wj := c.methodWeight[mj.OwnerName][mj.Name]
161
162		if wi == wj {
163			return mi.Name < mj.Name
164		}
165
166		return wi < wj
167	})
168
169	return methods
170}
171
172func (c *Inspector) parseSource() {
173	c.init.Do(func() {
174		if !strings.Contains(c.ProjectRootDir, "hugo") {
175			panic("dir must be set to the Hugo root")
176		}
177
178		c.methodWeight = make(map[string]map[string]int)
179		dirExcludes := regexp.MustCompile("docs|examples")
180		fileExcludes := regexp.MustCompile("autogen")
181		var filenames []string
182
183		filepath.Walk(c.ProjectRootDir, func(path string, info os.FileInfo, err error) error {
184			if info.IsDir() {
185				if dirExcludes.MatchString(info.Name()) {
186					return filepath.SkipDir
187				}
188			}
189
190			if !strings.HasSuffix(path, ".go") || fileExcludes.MatchString(path) {
191				return nil
192			}
193
194			filenames = append(filenames, path)
195
196			return nil
197		})
198
199		for _, filename := range filenames {
200
201			pkg := c.packageFromPath(filename)
202
203			fset := token.NewFileSet()
204			node, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
205			if err != nil {
206				panic(err)
207			}
208
209			ast.Inspect(node, func(n ast.Node) bool {
210				switch t := n.(type) {
211				case *ast.TypeSpec:
212					if t.Name.IsExported() {
213						switch it := t.Type.(type) {
214						case *ast.InterfaceType:
215							iface := pkg + "." + t.Name.Name
216							methodNames := collectMethodsRecursive(pkg, it.Methods.List)
217							weights := make(map[string]int)
218							weight := weightWidth
219							for _, name := range methodNames {
220								weights[name] = weight
221								weight += weightWidth
222							}
223							c.methodWeight[iface] = weights
224						}
225					}
226				}
227				return true
228			})
229
230		}
231
232		// Complement
233		for _, v1 := range c.methodWeight {
234			for k2, w := range v1 {
235				if v, found := c.methodWeight[k2]; found {
236					for k3, v3 := range v {
237						v1[k3] = (v3 / weightWidth) + w
238					}
239				}
240			}
241		}
242	})
243}
244
245func (c *Inspector) packageFromPath(p string) string {
246	p = filepath.ToSlash(p)
247	base := path.Base(p)
248	if !strings.Contains(base, ".") {
249		return base
250	}
251	return path.Base(strings.TrimSuffix(p, base))
252}
253
// Method holds enough information about a method to recreate its declaration
// in generated source code.
type Method struct {
	// The interface we extracted this method from.
	Owner reflect.Type

	// String version of the above, on the form PACKAGE.NAME, e.g.
	// page.Page
	OwnerName string

	// Method name.
	Name string

	// Imports needed to satisfy the method signature.
	Imports []string

	// Argument types, including any package prefix, e.g. string, int, interface{},
	// url.URL
	In []string

	// Return types, written the same way as the argument types.
	Out []string
}
276
277// Declaration creates a method declaration (without any body) for the given receiver.
278func (m Method) Declaration(receiver string) string {
279	return fmt.Sprintf("func (%s %s) %s%s %s", receiverShort(receiver), receiver, m.Name, m.inStr(), m.outStr())
280}
281
282// DeclarationNamed creates a method declaration (without any body) for the given receiver
283// with named return values.
284func (m Method) DeclarationNamed(receiver string) string {
285	return fmt.Sprintf("func (%s %s) %s%s %s", receiverShort(receiver), receiver, m.Name, m.inStr(), m.outStrNamed())
286}
287
288// Delegate creates a delegate call string.
289func (m Method) Delegate(receiver, delegate string) string {
290	ret := ""
291	if len(m.Out) > 0 {
292		ret = "return "
293	}
294	return fmt.Sprintf("%s%s.%s.%s%s", ret, receiverShort(receiver), delegate, m.Name, m.inOutStr())
295}
296
297func (m Method) String() string {
298	return m.Name + m.inStr() + " " + m.outStr() + "\n"
299}
300
301func (m Method) inOutStr() string {
302	if len(m.In) == 0 {
303		return "()"
304	}
305
306	args := make([]string, len(m.In))
307	for i := 0; i < len(args); i++ {
308		args[i] = fmt.Sprintf("arg%d", i)
309	}
310	return "(" + strings.Join(args, ", ") + ")"
311}
312
313func (m Method) inStr() string {
314	if len(m.In) == 0 {
315		return "()"
316	}
317
318	args := make([]string, len(m.In))
319	for i := 0; i < len(args); i++ {
320		args[i] = fmt.Sprintf("arg%d %s", i, m.In[i])
321	}
322	return "(" + strings.Join(args, ", ") + ")"
323}
324
325func (m Method) outStr() string {
326	if len(m.Out) == 0 {
327		return ""
328	}
329	if len(m.Out) == 1 {
330		return m.Out[0]
331	}
332
333	return "(" + strings.Join(m.Out, ", ") + ")"
334}
335
336func (m Method) outStrNamed() string {
337	if len(m.Out) == 0 {
338		return ""
339	}
340
341	outs := make([]string, len(m.Out))
342	for i := 0; i < len(outs); i++ {
343		outs[i] = fmt.Sprintf("o%d %s", i, m.Out[i])
344	}
345
346	return "(" + strings.Join(outs, ", ") + ")"
347}
348
// Methods represents a list of methods for one or more interfaces.
// When built by Inspector.MethodsFromTypes, the list is sorted by the
// declaration order recorded from the source file(s), falling back to the
// method name when that order is unknown.
type Methods []Method
352
353// Imports returns a sorted list of package imports needed to satisfy the
354// signatures of all methods.
355func (m Methods) Imports() []string {
356	var pkgImports []string
357	for _, method := range m {
358		pkgImports = append(pkgImports, method.Imports...)
359	}
360	if len(pkgImports) > 0 {
361		pkgImports = uniqueNonEmptyStrings(pkgImports)
362		sort.Strings(pkgImports)
363	}
364	return pkgImports
365}
366
367// ToMarshalJSON creates a MarshalJSON method for these methods. Any method name
368// matching any of the regexps in excludes will be ignored.
369func (m Methods) ToMarshalJSON(receiver, pkgPath string, excludes ...string) (string, []string) {
370	var sb strings.Builder
371
372	r := receiverShort(receiver)
373	what := firstToUpper(trimAsterisk(receiver))
374	pgkName := path.Base(pkgPath)
375
376	fmt.Fprintf(&sb, "func Marshal%sToJSON(%s %s) ([]byte, error) {\n", what, r, receiver)
377
378	var methods Methods
379	excludeRes := make([]*regexp.Regexp, len(excludes))
380
381	for i, exclude := range excludes {
382		excludeRes[i] = regexp.MustCompile(exclude)
383	}
384
385	for _, method := range m {
386		// Exclude methods with arguments and incompatible return values
387		if len(method.In) > 0 || len(method.Out) == 0 || len(method.Out) > 2 {
388			continue
389		}
390
391		if len(method.Out) == 2 {
392			if method.Out[1] != "error" {
393				continue
394			}
395		}
396
397		for _, re := range excludeRes {
398			if re.MatchString(method.Name) {
399				continue
400			}
401		}
402
403		methods = append(methods, method)
404	}
405
406	for _, method := range methods {
407		varn := varName(method.Name)
408		if len(method.Out) == 1 {
409			fmt.Fprintf(&sb, "\t%s := %s.%s()\n", varn, r, method.Name)
410		} else {
411			fmt.Fprintf(&sb, "\t%s, err := %s.%s()\n", varn, r, method.Name)
412			fmt.Fprint(&sb, "\tif err != nil {\n\t\treturn nil, err\n\t}\n")
413		}
414	}
415
416	fmt.Fprint(&sb, "\n\ts := struct {\n")
417
418	for _, method := range methods {
419		fmt.Fprintf(&sb, "\t\t%s %s\n", method.Name, typeName(method.Out[0], pgkName))
420	}
421
422	fmt.Fprint(&sb, "\n\t}{\n")
423
424	for _, method := range methods {
425		varn := varName(method.Name)
426		fmt.Fprintf(&sb, "\t\t%s: %s,\n", method.Name, varn)
427	}
428
429	fmt.Fprint(&sb, "\n\t}\n\n")
430	fmt.Fprint(&sb, "\treturn json.Marshal(&s)\n}")
431
432	pkgImports := append(methods.Imports(), "encoding/json")
433
434	if pkgPath != "" {
435		// Exclude self
436		for i, pkgImp := range pkgImports {
437			if pkgImp == pkgPath {
438				pkgImports = append(pkgImports[:i], pkgImports[i+1:]...)
439			}
440		}
441	}
442
443	return sb.String(), pkgImports
444}
445
446func collectMethodsRecursive(pkg string, f []*ast.Field) []string {
447	var methodNames []string
448	for _, m := range f {
449		if m.Names != nil {
450			methodNames = append(methodNames, m.Names[0].Name)
451			continue
452		}
453
454		if ident, ok := m.Type.(*ast.Ident); ok && ident.Obj != nil {
455			// Embedded interface
456			methodNames = append(
457				methodNames,
458				collectMethodsRecursive(
459					pkg,
460					ident.Obj.Decl.(*ast.TypeSpec).Type.(*ast.InterfaceType).Methods.List)...)
461		} else {
462			// Embedded, but in a different file/package. Return the
463			// package.Name and deal with that later.
464			name := packageName(m.Type)
465			if !strings.Contains(name, ".") {
466				// Assume current package
467				name = pkg + "." + name
468			}
469			methodNames = append(methodNames, name)
470		}
471	}
472
473	return methodNames
474}
475
// firstToLower returns name with its first character converted to lower case.
// An empty name is returned unchanged.
func firstToLower(name string) string {
	if name == "" {
		return name
	}

	// Find the width in bytes of the first character: it ends where the
	// second one starts, or at the end of the string if there is only one.
	// Slicing off a single byte would split a multi-byte character.
	width := len(name)
	for i := range name {
		if i > 0 {
			width = i
			break
		}
	}

	return strings.ToLower(name[:width]) + name[width:]
}
479
// firstToUpper returns name with its first character converted to upper case.
// An empty name is returned unchanged.
func firstToUpper(name string) string {
	if name == "" {
		return name
	}

	// Find the width in bytes of the first character: it ends where the
	// second one starts, or at the end of the string if there is only one.
	// Slicing off a single byte would split a multi-byte character.
	width := len(name)
	for i := range name {
		if i > 0 {
			width = i
			break
		}
	}

	return strings.ToUpper(name[:width]) + name[width:]
}
483
// packageName returns the dotted name of an identifier or selector
// expression, e.g. "Page" or "page.Page". Any other kind of expression yields
// the empty string.
func packageName(e ast.Expr) string {
	if ident, ok := e.(*ast.Ident); ok {
		return ident.Name
	}

	if sel, ok := e.(*ast.SelectorExpr); ok {
		return fmt.Sprintf("%s.%s", packageName(sel.X), packageName(sel.Sel))
	}

	return ""
}
493
494func receiverShort(receiver string) string {
495	return strings.ToLower(trimAsterisk(receiver))[:1]
496}
497
// trimAsterisk removes a single leading "*" from name, if there is one.
func trimAsterisk(name string) string {
	if strings.HasPrefix(name, "*") {
		return name[1:]
	}
	return name
}
501
// typeName returns the type name to use for name inside the package pkg: the
// "pkg." qualifier is removed, as a type is not qualified within its own
// package. Leading pointer asterisks are kept, so "*page.Page" becomes "*Page"
// for pkg "page". Names from other packages are returned unchanged.
func typeName(name, pkg string) string {
	// Set any leading asterisks aside so that the qualifier of a pointer
	// type is found as well.
	stars := len(name) - len(strings.TrimLeft(name, "*"))

	return name[:stars] + strings.TrimPrefix(name[stars:], pkg+".")
}
505
// uniqueNonEmptyStrings returns the strings in s with empty strings and
// duplicates removed. The first occurrence of every value is kept and the
// original order is preserved. The result is nil when nothing is left.
func uniqueNonEmptyStrings(s []string) []string {
	seen := make(map[string]bool, len(s))

	var result []string
	for i := range s {
		candidate := s[i]
		if candidate == "" || seen[candidate] {
			continue
		}
		seen[candidate] = true
		result = append(result, candidate)
	}

	return result
}
520
521func varName(name string) string {
522	name = firstToLower(name)
523
524	// Adjust some reserved keywords, see https://golang.org/ref/spec#Keywords
525	switch name {
526	case "type":
527		name = "typ"
528	case "package":
529		name = "pkg"
530		// Not reserved, but syntax highlighters has it as a keyword.
531	case "len":
532		name = "length"
533	}
534
535	return name
536}
537