1package printer
2
3import (
4	"bytes"
5	"fmt"
6	"github.com/tinylib/msgp/gen"
7	"github.com/tinylib/msgp/parse"
8	"github.com/ttacon/chalk"
9	"golang.org/x/tools/imports"
10	"io"
11	"io/ioutil"
12	"strings"
13)
14
15func infof(s string, v ...interface{}) {
16	fmt.Printf(chalk.Magenta.Color(s), v...)
17}
18
19// PrintFile prints the methods for the provided list
20// of elements to the given file name and canonical
21// package path.
22func PrintFile(file string, f *parse.FileSet, mode gen.Method) error {
23	out, tests, err := generate(f, mode)
24	if err != nil {
25		return err
26	}
27
28	// we'll run goimports on the main file
29	// in another goroutine, and run it here
30	// for the test file. empirically, this
31	// takes about the same amount of time as
32	// doing them in serial when GOMAXPROCS=1,
33	// and faster otherwise.
34	res := goformat(file, out.Bytes())
35	if tests != nil {
36		testfile := strings.TrimSuffix(file, ".go") + "_test.go"
37		err = format(testfile, tests.Bytes())
38		if err != nil {
39			return err
40		}
41		infof(">>> Wrote and formatted \"%s\"\n", testfile)
42	}
43	err = <-res
44	if err != nil {
45		return err
46	}
47	return nil
48}
49
50func format(file string, data []byte) error {
51	out, err := imports.Process(file, data, nil)
52	if err != nil {
53		return err
54	}
55	return ioutil.WriteFile(file, out, 0600)
56}
57
58func goformat(file string, data []byte) <-chan error {
59	out := make(chan error, 1)
60	go func(file string, data []byte, end chan error) {
61		end <- format(file, data)
62		infof(">>> Wrote and formatted \"%s\"\n", file)
63	}(file, data, out)
64	return out
65}
66
// dedupImports returns the entries of imp with duplicates removed.
// The first occurrence of each entry is kept and input order is
// preserved, so the generated import block is identical on every run
// (unlike iterating a map, whose order is randomized).
// The result is never nil; imp is not modified.
func dedupImports(imp []string) []string {
	seen := make(map[string]struct{}, len(imp))
	r := make([]string, 0, len(imp))
	for _, s := range imp {
		if _, dup := seen[s]; dup {
			continue
		}
		seen[s] = struct{}{}
		r = append(r, s)
	}
	return r
}
78
79func generate(f *parse.FileSet, mode gen.Method) (*bytes.Buffer, *bytes.Buffer, error) {
80	outbuf := bytes.NewBuffer(make([]byte, 0, 4096))
81	writePkgHeader(outbuf, f.Package)
82
83	myImports := []string{"github.com/tinylib/msgp/msgp"}
84	for _, imp := range f.Imports {
85		if imp.Name != nil {
86			// have an alias, include it.
87			myImports = append(myImports, imp.Name.Name+` `+imp.Path.Value)
88		} else {
89			myImports = append(myImports, imp.Path.Value)
90		}
91	}
92	dedup := dedupImports(myImports)
93	writeImportHeader(outbuf, dedup...)
94
95	var testbuf *bytes.Buffer
96	var testwr io.Writer
97	if mode&gen.Test == gen.Test {
98		testbuf = bytes.NewBuffer(make([]byte, 0, 4096))
99		writePkgHeader(testbuf, f.Package)
100		if mode&(gen.Encode|gen.Decode) != 0 {
101			writeImportHeader(testbuf, "bytes", "github.com/tinylib/msgp/msgp", "testing")
102		} else {
103			writeImportHeader(testbuf, "github.com/tinylib/msgp/msgp", "testing")
104		}
105		testwr = testbuf
106	}
107	return outbuf, testbuf, f.PrintTo(gen.NewPrinter(mode, outbuf, testwr))
108}
109
// writePkgHeader appends the package clause for name to b, followed by
// the "produced by msgp / DO NOT EDIT" banner and a blank line.
func writePkgHeader(b *bytes.Buffer, name string) {
	const banner = "// NOTE: THIS FILE WAS PRODUCED BY THE\n" +
		"// MSGP CODE GENERATION TOOL (github.com/tinylib/msgp)\n" +
		"// DO NOT EDIT\n\n"
	b.WriteString("package " + name + "\n" + banner)
}
116
// writeImportHeader appends a parenthesized import declaration listing
// imports to b, followed by a blank line.
//
// Each entry is either a bare import path such as "fmt", which is
// quoted with %q, or an entry that already ends in a Go string literal.
// That literal may be interpreted ("...") or raw (`...`), and may be
// preceded by an alias, e.g. `m "github.com/tinylib/msgp/msgp"`. Such
// entries are written verbatim. Empty entries are skipped.
func writeImportHeader(b *bytes.Buffer, imports ...string) {
	b.WriteString("import (\n")
	for _, im := range imports {
		switch {
		case im == "":
			// No path at all: indexing the last byte would panic, and
			// quoting it would emit the invalid import "".
			continue
		case strings.HasSuffix(im, `"`), strings.HasSuffix(im, "`"):
			// Already a string literal (possibly aliased), e.g. copied
			// from ast.ImportSpec.Path.Value: emit as-is.
			fmt.Fprintf(b, "\t%s\n", im)
		default:
			fmt.Fprintf(b, "\t%q\n", im)
		}
	}
	b.WriteString(")\n\n")
}
129