1package pgs
2
3import (
4	"os"
5	"path/filepath"
6	"strings"
7
8	"github.com/golang/protobuf/proto"
9	plugin_go "github.com/golang/protobuf/protoc-gen-go/plugin"
10	"github.com/spf13/afero"
11)
12
13type persister interface {
14	SetDebugger(d Debugger)
15	SetFS(fs afero.Fs)
16	AddPostProcessor(proc ...PostProcessor)
17	Persist(a ...Artifact) *plugin_go.CodeGeneratorResponse
18}
19
20type stdPersister struct {
21	Debugger
22
23	fs    afero.Fs
24	procs []PostProcessor
25}
26
27func newPersister() *stdPersister { return &stdPersister{fs: afero.NewOsFs()} }
28
29func (p *stdPersister) SetDebugger(d Debugger)                 { p.Debugger = d }
30func (p *stdPersister) SetFS(fs afero.Fs)                      { p.fs = fs }
31func (p *stdPersister) AddPostProcessor(proc ...PostProcessor) { p.procs = append(p.procs, proc...) }
32
33func (p *stdPersister) Persist(arts ...Artifact) *plugin_go.CodeGeneratorResponse {
34	resp := new(plugin_go.CodeGeneratorResponse)
35
36	for _, a := range arts {
37		switch a := a.(type) {
38		case GeneratorFile:
39			f, err := a.ProtoFile()
40			p.CheckErr(err, "unable to convert ", a.Name, " to proto")
41			f.Content = proto.String(p.postProcess(a, f.GetContent()))
42			p.insertFile(resp, f, a.Overwrite)
43		case GeneratorTemplateFile:
44			f, err := a.ProtoFile()
45			p.CheckErr(err, "unable to convert ", a.Name, " to proto")
46			f.Content = proto.String(p.postProcess(a, f.GetContent()))
47			p.insertFile(resp, f, a.Overwrite)
48		case GeneratorAppend:
49			f, err := a.ProtoFile()
50			p.CheckErr(err, "unable to convert append for ", a.FileName, " to proto")
51			f.Content = proto.String(p.postProcess(a, f.GetContent()))
52			n, _ := cleanGeneratorFileName(a.FileName)
53			p.insertAppend(resp, n, f)
54		case GeneratorTemplateAppend:
55			f, err := a.ProtoFile()
56			p.CheckErr(err, "unable to convert append for ", a.FileName, " to proto")
57			f.Content = proto.String(p.postProcess(a, f.GetContent()))
58			n, _ := cleanGeneratorFileName(a.FileName)
59			p.insertAppend(resp, n, f)
60		case GeneratorInjection:
61			f, err := a.ProtoFile()
62			p.CheckErr(err, "unable to convert injection ", a.InsertionPoint, " for ", a.FileName, " to proto")
63			f.Content = proto.String(p.postProcess(a, f.GetContent()))
64			p.insertFile(resp, f, false)
65		case GeneratorTemplateInjection:
66			f, err := a.ProtoFile()
67			p.CheckErr(err, "unable to convert injection ", a.InsertionPoint, " for ", a.FileName, " to proto")
68			f.Content = proto.String(p.postProcess(a, f.GetContent()))
69			p.insertFile(resp, f, false)
70		case CustomFile:
71			p.writeFile(
72				a.Name,
73				[]byte(p.postProcess(a, a.Contents)),
74				a.Overwrite,
75				a.Perms,
76			)
77		case CustomTemplateFile:
78			content, err := a.render()
79			p.CheckErr(err, "unable to render CustomTemplateFile: ", a.Name)
80			content = p.postProcess(a, content)
81			p.writeFile(
82				a.Name,
83				[]byte(content),
84				a.Overwrite,
85				a.Perms,
86			)
87		case GeneratorError:
88			if resp.Error == nil {
89				resp.Error = proto.String(a.Message)
90				continue
91			}
92			resp.Error = proto.String(strings.Join([]string{resp.GetError(), a.Message}, "; "))
93		default:
94			p.Failf("unrecognized artifact type: %T", a)
95		}
96	}
97
98	return resp
99}
100
101func (p *stdPersister) indexOfFile(resp *plugin_go.CodeGeneratorResponse, name string) int {
102	for i, f := range resp.GetFile() {
103		if f.GetName() == name && f.InsertionPoint == nil {
104			return i
105		}
106	}
107
108	return -1
109}
110
111func (p *stdPersister) insertFile(resp *plugin_go.CodeGeneratorResponse,
112	f *plugin_go.CodeGeneratorResponse_File, overwrite bool) {
113	if overwrite {
114		if i := p.indexOfFile(resp, f.GetName()); i >= 0 {
115			resp.File[i] = f
116			return
117		}
118	}
119
120	resp.File = append(resp.File, f)
121}
122
123func (p *stdPersister) insertAppend(resp *plugin_go.CodeGeneratorResponse,
124	name string, f *plugin_go.CodeGeneratorResponse_File) {
125	i := p.indexOfFile(resp, name)
126	p.Assert(i > -1, "append target ", name, " missing")
127
128	resp.File = append(
129		resp.File[:i+1],
130		append(
131			[]*plugin_go.CodeGeneratorResponse_File{f},
132			resp.File[i+1:]...,
133		)...,
134	)
135}
136
137func (p *stdPersister) writeFile(name string, content []byte, overwrite bool, perms os.FileMode) {
138	dir := filepath.Dir(name)
139	p.CheckErr(
140		p.fs.MkdirAll(dir, 0755),
141		"unable to create directory:", dir)
142
143	exists, err := afero.Exists(p.fs, name)
144	p.CheckErr(err, "unable to check file exists:", name)
145
146	if exists {
147		if !overwrite {
148			p.Debug("file", name, "exists, skipping")
149			return
150		}
151		p.Debug("file", name, "exists, overwriting")
152	}
153
154	p.CheckErr(
155		afero.WriteFile(p.fs, name, content, perms),
156		"unable to write file:", name)
157}
158
159func (p *stdPersister) postProcess(a Artifact, in string) string {
160	var err error
161	b := []byte(in)
162	for _, pp := range p.procs {
163		if pp.Match(a) {
164			b, err = pp.Process(b)
165			p.CheckErr(err, "failed post-processing")
166		}
167	}
168
169	return string(b)
170}
171