1package java
2
3import (
4	"bytes"
5	"fmt"
6	"os"
7	"strings"
8	"text/template"
9	"unicode"
10
11	"github.com/envoyproxy/protoc-gen-validate/templates/shared"
12	"github.com/golang/protobuf/ptypes/duration"
13	"github.com/golang/protobuf/ptypes/timestamp"
14	"github.com/iancoleman/strcase"
15	pgs "github.com/lyft/protoc-gen-star"
16	pgsgo "github.com/lyft/protoc-gen-star/lang/go"
17)
18
19func RegisterIndex(tpl *template.Template, params pgs.Parameters) {
20	fns := javaFuncs{pgsgo.InitContext(params)}
21
22	tpl.Funcs(map[string]interface{}{
23		"classNameFile": classNameFile,
24		"importsPvg":    importsPvg,
25		"javaPackage":   javaPackage,
26		"simpleName":    fns.Name,
27		"qualifiedName": fns.qualifiedName,
28	})
29}
30
31func Register(tpl *template.Template, params pgs.Parameters) {
32	fns := javaFuncs{pgsgo.InitContext(params)}
33
34	tpl.Funcs(map[string]interface{}{
35		"accessor":                 fns.accessor,
36		"byteArrayLit":             fns.byteArrayLit,
37		"camelCase":                fns.camelCase,
38		"classNameFile":            classNameFile,
39		"classNameMessage":         classNameMessage,
40		"durLit":                   fns.durLit,
41		"fieldName":                fns.fieldName,
42		"javaPackage":              javaPackage,
43		"javaStringEscape":         fns.javaStringEscape,
44		"javaTypeFor":              fns.javaTypeFor,
45		"javaTypeLiteralSuffixFor": fns.javaTypeLiteralSuffixFor,
46		"hasAccessor":              fns.hasAccessor,
47		"oneof":                    fns.oneofTypeName,
48		"sprintf":                  fmt.Sprintf,
49		"simpleName":               fns.Name,
50		"tsLit":                    fns.tsLit,
51		"qualifiedName":            fns.qualifiedName,
52		"isOfFileType":             fns.isOfFileType,
53		"isOfMessageType":          fns.isOfMessageType,
54		"isOfStringType":           fns.isOfStringType,
55		"unwrap":                   fns.unwrap,
56		"renderConstants":          fns.renderConstants(tpl),
57		"constantName":             fns.constantName,
58	})
59
60	template.Must(tpl.Parse(fileTpl))
61	template.Must(tpl.New("msg").Parse(msgTpl))
62	template.Must(tpl.New("msgInner").Parse(msgInnerTpl))
63
64	template.Must(tpl.New("none").Parse(noneTpl))
65
66	template.Must(tpl.New("float").Parse(numTpl))
67	template.Must(tpl.New("floatConst").Parse(numConstTpl))
68	template.Must(tpl.New("double").Parse(numTpl))
69	template.Must(tpl.New("doubleConst").Parse(numConstTpl))
70	template.Must(tpl.New("int32").Parse(numTpl))
71	template.Must(tpl.New("int32Const").Parse(numConstTpl))
72	template.Must(tpl.New("int64").Parse(numTpl))
73	template.Must(tpl.New("int64Const").Parse(numConstTpl))
74	template.Must(tpl.New("uint32").Parse(numTpl))
75	template.Must(tpl.New("uint32Const").Parse(numConstTpl))
76	template.Must(tpl.New("uint64").Parse(numTpl))
77	template.Must(tpl.New("uint64Const").Parse(numConstTpl))
78	template.Must(tpl.New("sint32").Parse(numTpl))
79	template.Must(tpl.New("sint32Const").Parse(numConstTpl))
80	template.Must(tpl.New("sint64").Parse(numTpl))
81	template.Must(tpl.New("sint64Const").Parse(numConstTpl))
82	template.Must(tpl.New("fixed32").Parse(numTpl))
83	template.Must(tpl.New("fixed32Const").Parse(numConstTpl))
84	template.Must(tpl.New("fixed64").Parse(numTpl))
85	template.Must(tpl.New("fixed64Const").Parse(numConstTpl))
86	template.Must(tpl.New("sfixed32").Parse(numTpl))
87	template.Must(tpl.New("sfixed32Const").Parse(numConstTpl))
88	template.Must(tpl.New("sfixed64").Parse(numTpl))
89	template.Must(tpl.New("sfixed64Const").Parse(numConstTpl))
90
91	template.Must(tpl.New("bool").Parse(boolTpl))
92	template.Must(tpl.New("string").Parse(stringTpl))
93	template.Must(tpl.New("stringConst").Parse(stringConstTpl))
94	template.Must(tpl.New("bytes").Parse(bytesTpl))
95	template.Must(tpl.New("bytesConst").Parse(bytesConstTpl))
96
97	template.Must(tpl.New("any").Parse(anyTpl))
98	template.Must(tpl.New("anyConst").Parse(anyConstTpl))
99	template.Must(tpl.New("enum").Parse(enumTpl))
100	template.Must(tpl.New("enumConst").Parse(enumConstTpl))
101	template.Must(tpl.New("message").Parse(messageTpl))
102	template.Must(tpl.New("repeated").Parse(repeatedTpl))
103	template.Must(tpl.New("repeatedConst").Parse(repeatedConstTpl))
104	template.Must(tpl.New("map").Parse(mapTpl))
105	template.Must(tpl.New("mapConst").Parse(mapConstTpl))
106	template.Must(tpl.New("oneOf").Parse(oneOfTpl))
107	template.Must(tpl.New("oneOfConst").Parse(oneOfConstTpl))
108
109	template.Must(tpl.New("required").Parse(requiredTpl))
110	template.Must(tpl.New("timestamp").Parse(timestampTpl))
111	template.Must(tpl.New("timestampConst").Parse(timestampConstTpl))
112	template.Must(tpl.New("duration").Parse(durationTpl))
113	template.Must(tpl.New("durationConst").Parse(durationConstTpl))
114	template.Must(tpl.New("wrapper").Parse(wrapperTpl))
115	template.Must(tpl.New("wrapperConst").Parse(wrapperConstTpl))
116}
117
118type javaFuncs struct{ pgsgo.Context }
119
120func JavaFilePath(f pgs.File, ctx pgsgo.Context, tpl *template.Template) *pgs.FilePath {
121	// Don't generate validators for files that don't import PGV
122	if !importsPvg(f) {
123		return nil
124	}
125
126	fullPath := strings.Replace(javaPackage(f), ".", string(os.PathSeparator), -1)
127	fileName := classNameFile(f) + "Validator.java"
128	filePath := pgs.JoinPaths(fullPath, fileName)
129	return &filePath
130}
131
132func JavaMultiFilePath(f pgs.File, m pgs.Message) pgs.FilePath {
133	fullPath := strings.Replace(javaPackage(f), ".", string(os.PathSeparator), -1)
134	fileName := classNameMessage(m) + "Validator.java"
135	filePath := pgs.JoinPaths(fullPath, fileName)
136	return filePath
137}
138
139func importsPvg(f pgs.File) bool {
140	for _, dep := range f.Descriptor().Dependency {
141		if strings.HasSuffix(dep, "validate.proto") {
142			return true
143		}
144	}
145	return false
146}
147
148func classNameFile(f pgs.File) string {
149	// Explicit outer class name overrides implicit name
150	options := f.Descriptor().GetOptions()
151	if options != nil && !options.GetJavaMultipleFiles() && options.JavaOuterClassname != nil {
152		return options.GetJavaOuterClassname()
153	}
154
155	protoName := pgs.FilePath(f.Name().String()).BaseName()
156
157	className := sanitizeClassName(protoName)
158	className = appendOuterClassName(className, f)
159
160	return className
161}
162
163func classNameMessage(m pgs.Message) string {
164	return sanitizeClassName(m.Name().String())
165}
166
167func sanitizeClassName(className string) string {
168	className = makeInvalidClassnameCharactersUnderscores(className)
169	className = underscoreBetweenConsecutiveUppercase(className)
170	className = strcase.ToCamel(strcase.ToSnake(className))
171	className = upperCaseAfterNumber(className)
172	return className
173}
174
175func javaPackage(file pgs.File) string {
176	// Explicit java package overrides implicit package
177	options := file.Descriptor().GetOptions()
178	if options != nil && options.JavaPackage != nil {
179		return options.GetJavaPackage()
180	}
181	return file.Package().ProtoName().String()
182}
183
184func (fns javaFuncs) qualifiedName(entity pgs.Entity) string {
185	file, isFile := entity.(pgs.File)
186	if isFile {
187		name := javaPackage(file)
188		if file.Descriptor().GetOptions() != nil {
189			if !file.Descriptor().GetOptions().GetJavaMultipleFiles() {
190				name += ("." + classNameFile(file))
191			}
192		} else {
193			name += ("." + classNameFile(file))
194		}
195		return name
196	}
197
198	message, isMessage := entity.(pgs.Message)
199	if isMessage && message.Parent() != nil {
200		// recurse
201		return fns.qualifiedName(message.Parent()) + "." + entity.Name().String()
202	}
203
204	enum, isEnum := entity.(pgs.Enum)
205	if isEnum && enum.Parent() != nil {
206		// recurse
207		return fns.qualifiedName(enum.Parent()) + "." + entity.Name().String()
208	}
209
210	return entity.Name().String()
211}
212
// makeInvalidClassnameCharactersUnderscores returns name with every rune that
// is not an ASCII letter or ASCII digit replaced by a single underscore.
// Letters and digits are kept unchanged; the empty string maps to itself.
//
// The mapping runs in one pass through strings.Map rather than growing a
// string with += on every rune.
func makeInvalidClassnameCharactersUnderscores(name string) string {
	return strings.Map(func(c rune) rune {
		switch {
		case c >= '0' && c <= '9',
			c >= 'a' && c <= 'z',
			c >= 'A' && c <= 'Z':
			return c
		default:
			return '_'
		}
	}, name)
}
230
// upperCaseAfterNumber returns name with every rune that directly follows a
// digit (as defined by unicode.IsDigit) converted to upper case. All other
// runes, including the first one, are copied unchanged.
//
// The result is assembled in a strings.Builder instead of by repeated string
// concatenation, which avoids re-copying the prefix for every rune.
func upperCaseAfterNumber(name string) string {
	var sb strings.Builder
	sb.Grow(len(name))

	var prev rune // zero value is not a digit, so the first rune is kept as is
	for _, c := range name {
		if unicode.IsDigit(prev) {
			sb.WriteRune(unicode.ToUpper(c))
		} else {
			sb.WriteRune(c)
		}
		prev = c
	}
	return sb.String()
}
245
// underscoreBetweenConsecutiveUppercase returns name with an underscore
// inserted before every uppercase rune that directly follows another
// uppercase rune (as defined by unicode.IsUpper), e.g. "AB" becomes "A_B".
//
// The result is assembled in a strings.Builder instead of by repeated string
// concatenation, which avoids re-copying the prefix for every rune.
func underscoreBetweenConsecutiveUppercase(name string) string {
	var sb strings.Builder
	sb.Grow(len(name))

	var prev rune // zero value is not uppercase, so the first rune is kept as is
	for _, c := range name {
		if unicode.IsUpper(prev) && unicode.IsUpper(c) {
			sb.WriteByte('_')
		}
		sb.WriteRune(c)
		prev = c
	}
	return sb.String()
}
260
261func appendOuterClassName(outerClassName string, file pgs.File) string {
262	conflict := false
263
264	for _, enum := range file.Enums() {
265		if enum.Name().String() == outerClassName {
266			conflict = true
267		}
268	}
269
270	for _, message := range file.Messages() {
271		if message.Name().String() == outerClassName {
272			conflict = true
273		}
274	}
275
276	for _, service := range file.Services() {
277		if service.Name().String() == outerClassName {
278			conflict = true
279		}
280	}
281
282	if conflict {
283		return outerClassName + "OuterClass"
284	} else {
285		return outerClassName
286	}
287}
288
289func (fns javaFuncs) accessor(ctx shared.RuleContext) string {
290	if ctx.AccessorOverride != "" {
291		return ctx.AccessorOverride
292	}
293	return fns.fieldAccessor(ctx.Field)
294}
295
296func (fns javaFuncs) fieldAccessor(f pgs.Field) string {
297	fieldName := strcase.ToCamel(f.Name().String())
298	if f.Type().IsMap() {
299		fieldName += "Map"
300	}
301	if f.Type().IsRepeated() {
302		fieldName += "List"
303	}
304
305	fieldName = upperCaseAfterNumber(fieldName)
306	return fmt.Sprintf("proto.get%s()", fieldName)
307}
308
309func (fns javaFuncs) hasAccessor(ctx shared.RuleContext) string {
310	if ctx.AccessorOverride != "" {
311		return "true"
312	}
313	fiedlName := strcase.ToCamel(ctx.Field.Name().String())
314	fiedlName = upperCaseAfterNumber(fiedlName)
315	return "proto.has" + fiedlName + "()"
316}
317
318func (fns javaFuncs) fieldName(ctx shared.RuleContext) string {
319	return ctx.Field.Name().String()
320}
321
322func (fns javaFuncs) javaTypeFor(ctx shared.RuleContext) string {
323	t := ctx.Field.Type()
324
325	// Map key and value types
326	if t.IsMap() {
327		switch ctx.AccessorOverride {
328		case "key":
329			return fns.javaTypeForProtoType(t.Key().ProtoType())
330		case "value":
331			return fns.javaTypeForProtoType(t.Element().ProtoType())
332		}
333	}
334
335	if t.IsEmbed() {
336		if embed := t.Embed(); embed.IsWellKnown() {
337			switch embed.WellKnownType() {
338			case pgs.AnyWKT:
339				return "String"
340			case pgs.DurationWKT:
341				return "com.google.protobuf.Duration"
342			case pgs.TimestampWKT:
343				return "com.google.protobuf.Timestamp"
344			case pgs.Int32ValueWKT, pgs.UInt32ValueWKT:
345				return "Integer"
346			case pgs.Int64ValueWKT, pgs.UInt64ValueWKT:
347				return "Long"
348			case pgs.DoubleValueWKT:
349				return "Double"
350			case pgs.FloatValueWKT:
351				return "Float"
352			}
353		}
354	}
355
356	if t.IsRepeated() {
357		if t.ProtoType() == pgs.MessageT {
358			return fns.qualifiedName(t.Element().Embed())
359		} else if t.ProtoType() == pgs.EnumT {
360			return fns.qualifiedName(t.Element().Enum())
361		}
362	}
363
364	if t.IsEnum() {
365		return fns.qualifiedName(t.Enum())
366	}
367
368	return fns.javaTypeForProtoType(t.ProtoType())
369}
370
371func (fns javaFuncs) javaTypeForProtoType(t pgs.ProtoType) string {
372
373	switch t {
374	case pgs.Int32T, pgs.UInt32T, pgs.SInt32, pgs.Fixed32T, pgs.SFixed32:
375		return "Integer"
376	case pgs.Int64T, pgs.UInt64T, pgs.SInt64, pgs.Fixed64T, pgs.SFixed64:
377		return "Long"
378	case pgs.DoubleT:
379		return "Double"
380	case pgs.FloatT:
381		return "Float"
382	case pgs.BoolT:
383		return "Boolean"
384	case pgs.StringT:
385		return "String"
386	case pgs.BytesT:
387		return "com.google.protobuf.ByteString"
388	default:
389		return "Object"
390	}
391}
392
393func (fns javaFuncs) javaTypeLiteralSuffixFor(f pgs.Field) string {
394	switch f.Type().ProtoType() {
395	case pgs.Int64T, pgs.UInt64T, pgs.SInt64, pgs.Fixed64T, pgs.SFixed64:
396		return "L"
397	case pgs.FloatT:
398		return "F"
399	case pgs.DoubleT:
400		return "D"
401	}
402
403	emb := f.Type().Embed()
404	if emb != nil && emb.IsWellKnown() {
405		switch emb.WellKnownType() {
406		case pgs.Int64ValueWKT, pgs.UInt64ValueWKT:
407			return "L"
408		case pgs.FloatValueWKT:
409			return "F"
410		case pgs.DoubleValueWKT:
411			return "D"
412		}
413	}
414
415	return ""
416}
417
418func (fns javaFuncs) javaStringEscape(s string) string {
419	s = fmt.Sprintf("%q", s)
420	s = s[1 : len(s)-1]
421	s = strings.Replace(s, `\u00`, `\x`, -1)
422	s = strings.Replace(s, `\x`, `\\x`, -1)
423	// s = strings.Replace(s, `\`, `\\`, -1)
424	s = strings.Replace(s, `"`, `\"`, -1)
425	return `"` + s + `"`
426}
427
428func (fns javaFuncs) camelCase(name pgs.Name) string {
429	return strcase.ToCamel(name.String())
430}
431
432func (fns javaFuncs) byteArrayLit(bytes []uint8) string {
433	var sb string
434	sb += "new byte[]{"
435	for _, b := range bytes {
436		sb += fmt.Sprintf("(byte)%#x,", b)
437	}
438	sb += "}"
439
440	return sb
441}
442
443func (fns javaFuncs) durLit(dur *duration.Duration) string {
444	return fmt.Sprintf(
445		"io.envoyproxy.pgv.TimestampValidation.toDuration(%d,%d)",
446		dur.GetSeconds(), dur.GetNanos())
447}
448
449func (fns javaFuncs) tsLit(ts *timestamp.Timestamp) string {
450	return fmt.Sprintf(
451		"io.envoyproxy.pgv.TimestampValidation.toTimestamp(%d,%d)",
452		ts.GetSeconds(), ts.GetNanos())
453}
454
455func (fns javaFuncs) oneofTypeName(f pgs.Field) pgsgo.TypeName {
456	return pgsgo.TypeName(fmt.Sprintf("%s", strings.ToUpper(f.Name().String())))
457}
458
459func (fns javaFuncs) isOfFileType(o interface{}) bool {
460	switch o.(type) {
461	case pgs.File:
462		return true
463	default:
464		return false
465	}
466}
467
468func (fns javaFuncs) isOfMessageType(f pgs.Field) bool {
469	return f.Type().ProtoType() == pgs.MessageT
470}
471
472func (fns javaFuncs) isOfStringType(f pgs.Field) bool {
473	return f.Type().ProtoType() == pgs.StringT
474}
475
476func (fns javaFuncs) unwrap(ctx shared.RuleContext) (shared.RuleContext, error) {
477	ctx, err := ctx.Unwrap("wrapped")
478	if err != nil {
479		return ctx, err
480	}
481	ctx.AccessorOverride = fmt.Sprintf("%s.get%s()", fns.fieldAccessor(ctx.Field),
482		fns.camelCase(ctx.Field.Type().Embed().Fields()[0].Name()))
483	return ctx, nil
484}
485
486func (fns javaFuncs) renderConstants(tpl *template.Template) func(ctx shared.RuleContext) (string, error) {
487	return func(ctx shared.RuleContext) (string, error) {
488		var b bytes.Buffer
489		var err error
490
491		hasConstTemplate := false
492		for _, t := range tpl.Templates() {
493			if t.Name() == ctx.Typ+"Const" {
494				hasConstTemplate = true
495			}
496		}
497
498		if hasConstTemplate {
499			err = tpl.ExecuteTemplate(&b, ctx.Typ+"Const", ctx)
500		}
501
502		return b.String(), err
503	}
504}
505
506func (fns javaFuncs) constantName(ctx shared.RuleContext, rule string) string {
507	return strcase.ToScreamingSnake(ctx.Field.Name().String() + "_" + ctx.Index + "_" + rule)
508}
509