1package goshared
2
3import (
4	"fmt"
5	"reflect"
6	"strings"
7	"text/template"
8
9	"github.com/golang/protobuf/ptypes"
10	"github.com/golang/protobuf/ptypes/duration"
11	"github.com/golang/protobuf/ptypes/timestamp"
12	"github.com/lyft/protoc-gen-star"
13	"github.com/lyft/protoc-gen-star/lang/go"
14	"github.com/envoyproxy/protoc-gen-validate/templates/shared"
15)
16
17func Register(tpl *template.Template, params pgs.Parameters) {
18	fns := goSharedFuncs{pgsgo.InitContext(params)}
19
20	tpl.Funcs(map[string]interface{}{
21		"accessor":      fns.accessor,
22		"byteStr":       fns.byteStr,
23		"cmt":           pgs.C80,
24		"durGt":         fns.durGt,
25		"durLit":        fns.durLit,
26		"durStr":        fns.durStr,
27		"err":           fns.err,
28		"errCause":      fns.errCause,
29		"errIdx":        fns.errIdx,
30		"errIdxCause":   fns.errIdxCause,
31		"errname":       fns.errName,
32		"inKey":         fns.inKey,
33		"inType":        fns.inType,
34		"isBytes":       fns.isBytes,
35		"lit":           fns.lit,
36		"lookup":        fns.lookup,
37		"msgTyp":        fns.msgTyp,
38		"name":          fns.Name,
39		"oneof":         fns.oneofTypeName,
40		"pkg":           fns.PackageName,
41		"tsGt":          fns.tsGt,
42		"tsLit":         fns.tsLit,
43		"tsStr":         fns.tsStr,
44		"typ":           fns.Type,
45		"unwrap":        fns.unwrap,
46		"externalEnums": fns.externalEnums,
47		"enumPackages":  fns.enumPackages,
48	})
49
50	template.Must(tpl.New("msg").Parse(msgTpl))
51	template.Must(tpl.New("const").Parse(constTpl))
52	template.Must(tpl.New("ltgt").Parse(ltgtTpl))
53	template.Must(tpl.New("in").Parse(inTpl))
54
55	template.Must(tpl.New("none").Parse(noneTpl))
56	template.Must(tpl.New("float").Parse(numTpl))
57	template.Must(tpl.New("double").Parse(numTpl))
58	template.Must(tpl.New("int32").Parse(numTpl))
59	template.Must(tpl.New("int64").Parse(numTpl))
60	template.Must(tpl.New("uint32").Parse(numTpl))
61	template.Must(tpl.New("uint64").Parse(numTpl))
62	template.Must(tpl.New("sint32").Parse(numTpl))
63	template.Must(tpl.New("sint64").Parse(numTpl))
64	template.Must(tpl.New("fixed32").Parse(numTpl))
65	template.Must(tpl.New("fixed64").Parse(numTpl))
66	template.Must(tpl.New("sfixed32").Parse(numTpl))
67	template.Must(tpl.New("sfixed64").Parse(numTpl))
68
69	template.Must(tpl.New("bool").Parse(constTpl))
70	template.Must(tpl.New("string").Parse(strTpl))
71	template.Must(tpl.New("bytes").Parse(bytesTpl))
72
73	template.Must(tpl.New("email").Parse(emailTpl))
74	template.Must(tpl.New("hostname").Parse(hostTpl))
75	template.Must(tpl.New("address").Parse(hostTpl))
76
77	template.Must(tpl.New("enum").Parse(enumTpl))
78	template.Must(tpl.New("repeated").Parse(repTpl))
79	template.Must(tpl.New("map").Parse(mapTpl))
80
81	template.Must(tpl.New("any").Parse(anyTpl))
82	template.Must(tpl.New("timestampcmp").Parse(timestampcmpTpl))
83	template.Must(tpl.New("durationcmp").Parse(durationcmpTpl))
84
85	template.Must(tpl.New("wrapper").Parse(wrapperTpl))
86}
87
88type goSharedFuncs struct{ pgsgo.Context }
89
90func (fns goSharedFuncs) accessor(ctx shared.RuleContext) string {
91	if ctx.AccessorOverride != "" {
92		return ctx.AccessorOverride
93	}
94
95	return fmt.Sprintf("m.Get%s()", fns.Name(ctx.Field))
96}
97
98func (fns goSharedFuncs) errName(m pgs.Message) pgs.Name {
99	return fns.Name(m) + "ValidationError"
100}
101
102func (fns goSharedFuncs) errIdxCause(ctx shared.RuleContext, idx, cause string, reason ...interface{}) string {
103	f := ctx.Field
104	n := fns.Name(f)
105
106	var fld string
107	if idx != "" {
108		fld = fmt.Sprintf(`fmt.Sprintf("%s[%%v]", %s)`, n, idx)
109	} else if ctx.Index != "" {
110		fld = fmt.Sprintf(`fmt.Sprintf("%s[%%v]", %s)`, n, ctx.Index)
111	} else {
112		fld = fmt.Sprintf("%q", n)
113	}
114
115	causeFld := ""
116	if cause != "nil" && cause != "" {
117		causeFld = fmt.Sprintf("cause: %s,", cause)
118	}
119
120	keyFld := ""
121	if ctx.OnKey {
122		keyFld = "key: true,"
123	}
124
125	return fmt.Sprintf(`%s{
126		field: %s,
127		reason: %q,
128		%s%s
129	}`,
130		fns.errName(f.Message()),
131		fld,
132		fmt.Sprint(reason...),
133		causeFld,
134		keyFld)
135}
136
137func (fns goSharedFuncs) err(ctx shared.RuleContext, reason ...interface{}) string {
138	return fns.errIdxCause(ctx, "", "nil", reason...)
139}
140
141func (fns goSharedFuncs) errCause(ctx shared.RuleContext, cause string, reason ...interface{}) string {
142	return fns.errIdxCause(ctx, "", cause, reason...)
143}
144
145func (fns goSharedFuncs) errIdx(ctx shared.RuleContext, idx string, reason ...interface{}) string {
146	return fns.errIdxCause(ctx, idx, "nil", reason...)
147}
148
149func (fns goSharedFuncs) lookup(f pgs.Field, name string) string {
150	return fmt.Sprintf(
151		"_%s_%s_%s",
152		fns.Name(f.Message()),
153		fns.Name(f),
154		name,
155	)
156}
157
158func (fns goSharedFuncs) lit(x interface{}) string {
159	val := reflect.ValueOf(x)
160
161	if val.Kind() == reflect.Interface {
162		val = val.Elem()
163	}
164
165	if val.Kind() == reflect.Ptr {
166		val = val.Elem()
167	}
168
169	switch val.Kind() {
170	case reflect.String:
171		return fmt.Sprintf("%q", x)
172	case reflect.Uint8:
173		return fmt.Sprintf("0x%X", x)
174	case reflect.Slice:
175		els := make([]string, val.Len())
176		for i, l := 0, val.Len(); i < l; i++ {
177			els[i] = fns.lit(val.Index(i).Interface())
178		}
179		return fmt.Sprintf("%T{%s}", val.Interface(), strings.Join(els, ", "))
180	default:
181		return fmt.Sprint(x)
182	}
183}
184
185func (fns goSharedFuncs) isBytes(f interface {
186	ProtoType() pgs.ProtoType
187}) bool {
188	return f.ProtoType() == pgs.BytesT
189}
190
191func (fns goSharedFuncs) byteStr(x []byte) string {
192	elms := make([]string, len(x))
193	for i, b := range x {
194		elms[i] = fmt.Sprintf(`\x%X`, b)
195	}
196
197	return fmt.Sprintf(`"%s"`, strings.Join(elms, ""))
198}
199
200func (fns goSharedFuncs) oneofTypeName(f pgs.Field) pgsgo.TypeName {
201	return pgsgo.TypeName(fns.OneofOption(f)).Pointer()
202}
203
204func (fns goSharedFuncs) inType(f pgs.Field, x interface{}) string {
205	switch f.Type().ProtoType() {
206	case pgs.BytesT:
207		return "string"
208	case pgs.MessageT:
209		switch x.(type) {
210		case []*duration.Duration:
211			return "time.Duration"
212		default:
213			return pgsgo.TypeName(fmt.Sprintf("%T", x)).Element().String()
214		}
215	default:
216		return fns.Type(f).String()
217	}
218}
219
220func (fns goSharedFuncs) inKey(f pgs.Field, x interface{}) string {
221	switch f.Type().ProtoType() {
222	case pgs.BytesT:
223		return fns.byteStr(x.([]byte))
224	case pgs.MessageT:
225		switch x := x.(type) {
226		case *duration.Duration:
227			dur, _ := ptypes.Duration(x)
228			return fns.lit(int64(dur))
229		default:
230			return fns.lit(x)
231		}
232	default:
233		return fns.lit(x)
234	}
235}
236
237func (fns goSharedFuncs) durLit(dur *duration.Duration) string {
238	return fmt.Sprintf(
239		"time.Duration(%d * time.Second + %d * time.Nanosecond)",
240		dur.GetSeconds(), dur.GetNanos())
241}
242
243func (fns goSharedFuncs) durStr(dur *duration.Duration) string {
244	d, _ := ptypes.Duration(dur)
245	return d.String()
246}
247
248func (fns goSharedFuncs) durGt(a, b *duration.Duration) bool {
249	ad, _ := ptypes.Duration(a)
250	bd, _ := ptypes.Duration(b)
251
252	return ad > bd
253}
254
255func (fns goSharedFuncs) tsLit(ts *timestamp.Timestamp) string {
256	return fmt.Sprintf(
257		"time.Unix(%d, %d)",
258		ts.GetSeconds(), ts.GetNanos(),
259	)
260}
261
262func (fns goSharedFuncs) tsGt(a, b *timestamp.Timestamp) bool {
263	at, _ := ptypes.Timestamp(a)
264	bt, _ := ptypes.Timestamp(b)
265
266	return bt.Before(at)
267}
268
269func (fns goSharedFuncs) tsStr(ts *timestamp.Timestamp) string {
270	t, _ := ptypes.Timestamp(ts)
271	return t.String()
272}
273
274func (fns goSharedFuncs) unwrap(ctx shared.RuleContext, name string) (shared.RuleContext, error) {
275	ctx, err := ctx.Unwrap("wrapper")
276	if err != nil {
277		return ctx, err
278	}
279
280	ctx.AccessorOverride = fmt.Sprintf("%s.Get%s()", name,
281		pgsgo.PGGUpperCamelCase(ctx.Field.Type().Embed().Fields()[0].Name()))
282
283	return ctx, nil
284}
285
286func (fns goSharedFuncs) msgTyp(message pgs.Message) pgsgo.TypeName {
287	return pgsgo.TypeName(fns.Name(message))
288}
289
290func (fns goSharedFuncs) externalEnums(file pgs.File) []pgs.Enum {
291	var out []pgs.Enum
292
293	for _, msg := range file.AllMessages() {
294		for _, fld := range msg.Fields() {
295			if en := fld.Type().Enum(); fld.Type().IsEnum() && en.Package().ProtoName() != fld.Package().ProtoName() && fns.PackageName(en) != fns.PackageName(fld) {
296				out = append(out, en)
297			}
298		}
299	}
300
301	return out
302}
303
304func (fns goSharedFuncs) enumPackages(enums []pgs.Enum) map[pgs.FilePath]pgs.Name {
305	out := make(map[pgs.FilePath]pgs.Name, len(enums))
306
307	for _, en := range enums {
308		out[fns.ImportPath(en)] = fns.PackageName(en)
309	}
310
311	return out
312}
313