1// +build ignore
2
3//go:generate go run private-gen.go
4//go:generate gofmt -w ./private
5
6package main
7
8import (
9	"bytes"
10	"fmt"
11	"go/ast"
12	"go/parser"
13	"go/printer"
14	"go/token"
15	"io"
16	"io/ioutil"
17	"log"
18	"os"
19	"reflect"
20	"strings"
21	"unicode"
22	"unicode/utf8"
23)
24
25var inFiles = []string{"cpuid.go", "cpuid_test.go", "detect_arm64.go", "detect_ref.go", "detect_intel.go"}
26var copyFiles = []string{"cpuid_amd64.s", "cpuid_386.s", "cpuid_arm64.s"}
27var fileSet = token.NewFileSet()
28var reWrites = []rewrite{
29	initRewrite("CPUInfo -> cpuInfo"),
30	initRewrite("Vendor -> vendor"),
31	initRewrite("Flags -> flags"),
32	initRewrite("Detect -> detect"),
33	initRewrite("CPU -> cpu"),
34}
35var excludeNames = map[string]bool{"string": true, "join": true, "trim": true,
36	// cpuid_test.go
37	"t": true, "println": true, "logf": true, "log": true, "fatalf": true, "fatal": true,
38	"maxuint32": true, "lastindex": true,
39}
40
41var excludePrefixes = []string{"test", "benchmark"}
42
43func main() {
44	Package := "private"
45	parserMode := parser.ParseComments
46	exported := make(map[string]rewrite)
47	for _, file := range inFiles {
48		in, err := os.Open(file)
49		if err != nil {
50			log.Fatalf("opening input", err)
51		}
52
53		src, err := ioutil.ReadAll(in)
54		if err != nil {
55			log.Fatalf("reading input", err)
56		}
57
58		astfile, err := parser.ParseFile(fileSet, file, src, parserMode)
59		if err != nil {
60			log.Fatalf("parsing input", err)
61		}
62
63		for _, rw := range reWrites {
64			astfile = rw(astfile)
65		}
66
67		// Inspect the AST and print all identifiers and literals.
68		var startDecl token.Pos
69		var endDecl token.Pos
70		ast.Inspect(astfile, func(n ast.Node) bool {
71			var s string
72			switch x := n.(type) {
73			case *ast.Ident:
74				if x.IsExported() {
75					t := strings.ToLower(x.Name)
76					for _, pre := range excludePrefixes {
77						if strings.HasPrefix(t, pre) {
78							return true
79						}
80					}
81					if excludeNames[t] != true {
82						//if x.Pos() > startDecl && x.Pos() < endDecl {
83						exported[x.Name] = initRewrite(x.Name + " -> " + t)
84					}
85				}
86
87			case *ast.GenDecl:
88				if x.Tok == token.CONST && x.Lparen > 0 {
89					startDecl = x.Lparen
90					endDecl = x.Rparen
91					// fmt.Printf("Decl:%s -> %s\n", fileSet.Position(startDecl), fileSet.Position(endDecl))
92				}
93			}
94			if s != "" {
95				fmt.Printf("%s:\t%s\n", fileSet.Position(n.Pos()), s)
96			}
97			return true
98		})
99
100		for _, rw := range exported {
101			astfile = rw(astfile)
102		}
103
104		var buf bytes.Buffer
105
106		printer.Fprint(&buf, fileSet, astfile)
107
108		// Remove package documentation and insert information
109		s := buf.String()
110		ind := strings.Index(buf.String(), "\npackage cpuid")
111		if i := strings.Index(buf.String(), "\n//+build "); i > 0 {
112			ind = i
113		}
114		s = s[ind:]
115		s = "// Generated, DO NOT EDIT,\n" +
116			"// but copy it to your own project and rename the package.\n" +
117			"// See more at http://github.com/klauspost/cpuid\n" +
118			s
119		if !strings.HasPrefix(file, "cpuid") {
120			file = "cpuid_" + file
121		}
122		outputName := Package + string(os.PathSeparator) + file
123
124		err = ioutil.WriteFile(outputName, []byte(s), 0644)
125		if err != nil {
126			log.Fatalf("writing output: %s", err)
127		}
128		log.Println("Generated", outputName)
129	}
130
131	for _, file := range copyFiles {
132		dst := ""
133		if strings.HasPrefix(file, "cpuid") {
134			dst = Package + string(os.PathSeparator) + file
135		} else {
136			dst = Package + string(os.PathSeparator) + "cpuid_" + file
137		}
138		err := copyFile(file, dst)
139		if err != nil {
140			log.Fatalf("copying file: %s", err)
141		}
142		log.Println("Copied", dst)
143	}
144}
145
146// CopyFile copies a file from src to dst. If src and dst files exist, and are
147// the same, then return success. Copy the file contents from src to dst.
148func copyFile(src, dst string) (err error) {
149	sfi, err := os.Stat(src)
150	if err != nil {
151		return
152	}
153	if !sfi.Mode().IsRegular() {
154		// cannot copy non-regular files (e.g., directories,
155		// symlinks, devices, etc.)
156		return fmt.Errorf("CopyFile: non-regular source file %s (%q)", sfi.Name(), sfi.Mode().String())
157	}
158	dfi, err := os.Stat(dst)
159	if err != nil {
160		if !os.IsNotExist(err) {
161			return
162		}
163	} else {
164		if !(dfi.Mode().IsRegular()) {
165			return fmt.Errorf("CopyFile: non-regular destination file %s (%q)", dfi.Name(), dfi.Mode().String())
166		}
167		if os.SameFile(sfi, dfi) {
168			return
169		}
170	}
171	err = copyFileContents(src, dst)
172	return
173}
174
// copyFileContents writes the contents of the file named src into the file
// named dst. dst is created when it does not exist yet; when it does exist,
// all of its contents are replaced by the contents of src. The destination is
// synced after a successful copy, and an error from closing it is reported
// unless an earlier error is already being returned.
func copyFileContents(src, dst string) (err error) {
	source, err := os.Open(src)
	if err != nil {
		return err
	}
	defer source.Close()

	target, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer func() {
		// Always close; surface the close error only if nothing failed before.
		if closeErr := target.Close(); err == nil {
			err = closeErr
		}
	}()

	if _, err = io.Copy(target, source); err == nil {
		err = target.Sync()
	}
	return err
}
201
202type rewrite func(*ast.File) *ast.File
203
204// Mostly copied from gofmt
205func initRewrite(rewriteRule string) rewrite {
206	f := strings.Split(rewriteRule, "->")
207	if len(f) != 2 {
208		fmt.Fprintf(os.Stderr, "rewrite rule must be of the form 'pattern -> replacement'\n")
209		os.Exit(2)
210	}
211	pattern := parseExpr(f[0], "pattern")
212	replace := parseExpr(f[1], "replacement")
213	return func(p *ast.File) *ast.File { return rewriteFile(pattern, replace, p) }
214}
215
// parseExpr parses s as a Go expression and returns it. what names the role
// of s ("pattern" or "replacement") for the error message. On a parse error
// the message is written to standard error and the program exits with
// status 2.
//
// It might make sense to expand this to allow statement patterns,
// but there are problems with preserving formatting and also
// with what a wildcard for a statement looks like.
func parseExpr(s, what string) ast.Expr {
	expr, err := parser.ParseExpr(s)
	if err == nil {
		return expr
	}
	fmt.Fprintf(os.Stderr, "parsing %s %s at %s\n", what, s, err)
	os.Exit(2)
	return expr // not reached: os.Exit does not return
}
228
229// Keep this function for debugging.
230/*
231func dump(msg string, val reflect.Value) {
232	fmt.Printf("%s:\n", msg)
233	ast.Print(fileSet, val.Interface())
234	fmt.Println()
235}
236*/
237
238// rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file.
239func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File {
240	cmap := ast.NewCommentMap(fileSet, p, p.Comments)
241	m := make(map[string]reflect.Value)
242	pat := reflect.ValueOf(pattern)
243	repl := reflect.ValueOf(replace)
244
245	var rewriteVal func(val reflect.Value) reflect.Value
246	rewriteVal = func(val reflect.Value) reflect.Value {
247		// don't bother if val is invalid to start with
248		if !val.IsValid() {
249			return reflect.Value{}
250		}
251		for k := range m {
252			delete(m, k)
253		}
254		val = apply(rewriteVal, val)
255		if match(m, pat, val) {
256			val = subst(m, repl, reflect.ValueOf(val.Interface().(ast.Node).Pos()))
257		}
258		return val
259	}
260
261	r := apply(rewriteVal, reflect.ValueOf(p)).Interface().(*ast.File)
262	r.Comments = cmap.Filter(r).Comments() // recreate comments list
263	return r
264}
265
// set assigns y to x via x.Set(y) without letting an impossible assignment
// crash the caller. Nothing happens when x is not settable or y is invalid.
// A panic whose message contains "type mismatch" or "not assignable" is
// swallowed, which skips the assignment; any other panic is raised again.
func set(x, y reflect.Value) {
	if !x.CanSet() {
		return
	}
	if !y.IsValid() {
		return
	}
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		msg, isString := r.(string)
		if isString && (strings.Contains(msg, "type mismatch") || strings.Contains(msg, "not assignable")) {
			// x cannot be set to y - ignore this rewrite
			return
		}
		panic(r)
	}()
	x.Set(y)
}
284
// Values and types that the reflection-based AST walk treats specially.
var (
	// Typed nil pointers that apply returns in place of *ast.Object and
	// *ast.Scope values.
	objectPtrNil = reflect.ValueOf((*ast.Object)(nil))
	scopePtrNil  = reflect.ValueOf((*ast.Scope)(nil))

	// Types compared against while walking and matching.
	identType     = reflect.TypeOf(new(ast.Ident))
	objectPtrType = objectPtrNil.Type()
	positionType  = reflect.TypeOf(token.Pos(0))
	callExprType  = reflect.TypeOf(new(ast.CallExpr))
	scopePtrType  = scopePtrNil.Type()
)
296
297// apply replaces each AST field x in val with f(x), returning val.
298// To avoid extra conversions, f operates on the reflect.Value form.
299func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
300	if !val.IsValid() {
301		return reflect.Value{}
302	}
303
304	// *ast.Objects introduce cycles and are likely incorrect after
305	// rewrite; don't follow them but replace with nil instead
306	if val.Type() == objectPtrType {
307		return objectPtrNil
308	}
309
310	// similarly for scopes: they are likely incorrect after a rewrite;
311	// replace them with nil
312	if val.Type() == scopePtrType {
313		return scopePtrNil
314	}
315
316	switch v := reflect.Indirect(val); v.Kind() {
317	case reflect.Slice:
318		for i := 0; i < v.Len(); i++ {
319			e := v.Index(i)
320			set(e, f(e))
321		}
322	case reflect.Struct:
323		for i := 0; i < v.NumField(); i++ {
324			e := v.Field(i)
325			set(e, f(e))
326		}
327	case reflect.Interface:
328		e := v.Elem()
329		set(v, f(e))
330	}
331	return val
332}
333
// isWildcard reports whether s consists of exactly one rune and that rune
// is a lower-case letter.
func isWildcard(s string) bool {
	r, width := utf8.DecodeRuneInString(s)
	if width != len(s) {
		// More than one rune.
		return false
	}
	return unicode.IsLower(r)
}
338
339// match returns true if pattern matches val,
340// recording wildcard submatches in m.
341// If m == nil, match checks whether pattern == val.
342func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
343	// Wildcard matches any expression.  If it appears multiple
344	// times in the pattern, it must match the same expression
345	// each time.
346	if m != nil && pattern.IsValid() && pattern.Type() == identType {
347		name := pattern.Interface().(*ast.Ident).Name
348		if isWildcard(name) && val.IsValid() {
349			// wildcards only match valid (non-nil) expressions.
350			if _, ok := val.Interface().(ast.Expr); ok && !val.IsNil() {
351				if old, ok := m[name]; ok {
352					return match(nil, old, val)
353				}
354				m[name] = val
355				return true
356			}
357		}
358	}
359
360	// Otherwise, pattern and val must match recursively.
361	if !pattern.IsValid() || !val.IsValid() {
362		return !pattern.IsValid() && !val.IsValid()
363	}
364	if pattern.Type() != val.Type() {
365		return false
366	}
367
368	// Special cases.
369	switch pattern.Type() {
370	case identType:
371		// For identifiers, only the names need to match
372		// (and none of the other *ast.Object information).
373		// This is a common case, handle it all here instead
374		// of recursing down any further via reflection.
375		p := pattern.Interface().(*ast.Ident)
376		v := val.Interface().(*ast.Ident)
377		return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name
378	case objectPtrType, positionType:
379		// object pointers and token positions always match
380		return true
381	case callExprType:
382		// For calls, the Ellipsis fields (token.Position) must
383		// match since that is how f(x) and f(x...) are different.
384		// Check them here but fall through for the remaining fields.
385		p := pattern.Interface().(*ast.CallExpr)
386		v := val.Interface().(*ast.CallExpr)
387		if p.Ellipsis.IsValid() != v.Ellipsis.IsValid() {
388			return false
389		}
390	}
391
392	p := reflect.Indirect(pattern)
393	v := reflect.Indirect(val)
394	if !p.IsValid() || !v.IsValid() {
395		return !p.IsValid() && !v.IsValid()
396	}
397
398	switch p.Kind() {
399	case reflect.Slice:
400		if p.Len() != v.Len() {
401			return false
402		}
403		for i := 0; i < p.Len(); i++ {
404			if !match(m, p.Index(i), v.Index(i)) {
405				return false
406			}
407		}
408		return true
409
410	case reflect.Struct:
411		for i := 0; i < p.NumField(); i++ {
412			if !match(m, p.Field(i), v.Field(i)) {
413				return false
414			}
415		}
416		return true
417
418	case reflect.Interface:
419		return match(m, p.Elem(), v.Elem())
420	}
421
422	// Handle token integers, etc.
423	return p.Interface() == v.Interface()
424}
425
426// subst returns a copy of pattern with values from m substituted in place
427// of wildcards and pos used as the position of tokens from the pattern.
428// if m == nil, subst returns a copy of pattern and doesn't change the line
429// number information.
430func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value {
431	if !pattern.IsValid() {
432		return reflect.Value{}
433	}
434
435	// Wildcard gets replaced with map value.
436	if m != nil && pattern.Type() == identType {
437		name := pattern.Interface().(*ast.Ident).Name
438		if isWildcard(name) {
439			if old, ok := m[name]; ok {
440				return subst(nil, old, reflect.Value{})
441			}
442		}
443	}
444
445	if pos.IsValid() && pattern.Type() == positionType {
446		// use new position only if old position was valid in the first place
447		if old := pattern.Interface().(token.Pos); !old.IsValid() {
448			return pattern
449		}
450		return pos
451	}
452
453	// Otherwise copy.
454	switch p := pattern; p.Kind() {
455	case reflect.Slice:
456		v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
457		for i := 0; i < p.Len(); i++ {
458			v.Index(i).Set(subst(m, p.Index(i), pos))
459		}
460		return v
461
462	case reflect.Struct:
463		v := reflect.New(p.Type()).Elem()
464		for i := 0; i < p.NumField(); i++ {
465			v.Field(i).Set(subst(m, p.Field(i), pos))
466		}
467		return v
468
469	case reflect.Ptr:
470		v := reflect.New(p.Type()).Elem()
471		if elem := p.Elem(); elem.IsValid() {
472			v.Set(subst(m, elem, pos).Addr())
473		}
474		return v
475
476	case reflect.Interface:
477		v := reflect.New(p.Type()).Elem()
478		if elem := p.Elem(); elem.IsValid() {
479			v.Set(subst(m, elem, pos))
480		}
481		return v
482	}
483
484	return pattern
485}
486