1package java
2
3import (
4	"bytes"
5	"fmt"
6	"os"
7	"strings"
8	"text/template"
9	"unicode"
10
11	"github.com/envoyproxy/protoc-gen-validate/templates/shared"
12	"github.com/iancoleman/strcase"
13	pgs "github.com/lyft/protoc-gen-star"
14	pgsgo "github.com/lyft/protoc-gen-star/lang/go"
15	"google.golang.org/protobuf/types/known/durationpb"
16	"google.golang.org/protobuf/types/known/timestamppb"
17)
18
19func RegisterIndex(tpl *template.Template, params pgs.Parameters) {
20	fns := javaFuncs{pgsgo.InitContext(params)}
21
22	tpl.Funcs(map[string]interface{}{
23		"classNameFile": classNameFile,
24		"importsPvg":    importsPvg,
25		"javaPackage":   javaPackage,
26		"simpleName":    fns.Name,
27		"qualifiedName": fns.qualifiedName,
28	})
29}
30
31func Register(tpl *template.Template, params pgs.Parameters) {
32	fns := javaFuncs{pgsgo.InitContext(params)}
33
34	tpl.Funcs(map[string]interface{}{
35		"accessor":                 fns.accessor,
36		"byteArrayLit":             fns.byteArrayLit,
37		"camelCase":                fns.camelCase,
38		"classNameFile":            classNameFile,
39		"classNameMessage":         classNameMessage,
40		"durLit":                   fns.durLit,
41		"fieldName":                fns.fieldName,
42		"javaPackage":              javaPackage,
43		"javaStringEscape":         fns.javaStringEscape,
44		"javaTypeFor":              fns.javaTypeFor,
45		"javaTypeLiteralSuffixFor": fns.javaTypeLiteralSuffixFor,
46		"hasAccessor":              fns.hasAccessor,
47		"oneof":                    fns.oneofTypeName,
48		"sprintf":                  fmt.Sprintf,
49		"simpleName":               fns.Name,
50		"tsLit":                    fns.tsLit,
51		"qualifiedName":            fns.qualifiedName,
52		"isOfFileType":             fns.isOfFileType,
53		"isOfMessageType":          fns.isOfMessageType,
54		"isOfStringType":           fns.isOfStringType,
55		"unwrap":                   fns.unwrap,
56		"renderConstants":          fns.renderConstants(tpl),
57		"constantName":             fns.constantName,
58	})
59
60	template.Must(tpl.Parse(fileTpl))
61	template.Must(tpl.New("msg").Parse(msgTpl))
62	template.Must(tpl.New("msgInner").Parse(msgInnerTpl))
63
64	template.Must(tpl.New("none").Parse(noneTpl))
65
66	template.Must(tpl.New("float").Parse(numTpl))
67	template.Must(tpl.New("floatConst").Parse(numConstTpl))
68	template.Must(tpl.New("double").Parse(numTpl))
69	template.Must(tpl.New("doubleConst").Parse(numConstTpl))
70	template.Must(tpl.New("int32").Parse(numTpl))
71	template.Must(tpl.New("int32Const").Parse(numConstTpl))
72	template.Must(tpl.New("int64").Parse(numTpl))
73	template.Must(tpl.New("int64Const").Parse(numConstTpl))
74	template.Must(tpl.New("uint32").Parse(numTpl))
75	template.Must(tpl.New("uint32Const").Parse(numConstTpl))
76	template.Must(tpl.New("uint64").Parse(numTpl))
77	template.Must(tpl.New("uint64Const").Parse(numConstTpl))
78	template.Must(tpl.New("sint32").Parse(numTpl))
79	template.Must(tpl.New("sint32Const").Parse(numConstTpl))
80	template.Must(tpl.New("sint64").Parse(numTpl))
81	template.Must(tpl.New("sint64Const").Parse(numConstTpl))
82	template.Must(tpl.New("fixed32").Parse(numTpl))
83	template.Must(tpl.New("fixed32Const").Parse(numConstTpl))
84	template.Must(tpl.New("fixed64").Parse(numTpl))
85	template.Must(tpl.New("fixed64Const").Parse(numConstTpl))
86	template.Must(tpl.New("sfixed32").Parse(numTpl))
87	template.Must(tpl.New("sfixed32Const").Parse(numConstTpl))
88	template.Must(tpl.New("sfixed64").Parse(numTpl))
89	template.Must(tpl.New("sfixed64Const").Parse(numConstTpl))
90
91	template.Must(tpl.New("bool").Parse(boolTpl))
92	template.Must(tpl.New("string").Parse(stringTpl))
93	template.Must(tpl.New("stringConst").Parse(stringConstTpl))
94	template.Must(tpl.New("bytes").Parse(bytesTpl))
95	template.Must(tpl.New("bytesConst").Parse(bytesConstTpl))
96
97	template.Must(tpl.New("any").Parse(anyTpl))
98	template.Must(tpl.New("anyConst").Parse(anyConstTpl))
99	template.Must(tpl.New("enum").Parse(enumTpl))
100	template.Must(tpl.New("enumConst").Parse(enumConstTpl))
101	template.Must(tpl.New("message").Parse(messageTpl))
102	template.Must(tpl.New("repeated").Parse(repeatedTpl))
103	template.Must(tpl.New("repeatedConst").Parse(repeatedConstTpl))
104	template.Must(tpl.New("map").Parse(mapTpl))
105	template.Must(tpl.New("mapConst").Parse(mapConstTpl))
106	template.Must(tpl.New("oneOf").Parse(oneOfTpl))
107	template.Must(tpl.New("oneOfConst").Parse(oneOfConstTpl))
108
109	template.Must(tpl.New("required").Parse(requiredTpl))
110	template.Must(tpl.New("timestamp").Parse(timestampTpl))
111	template.Must(tpl.New("timestampConst").Parse(timestampConstTpl))
112	template.Must(tpl.New("duration").Parse(durationTpl))
113	template.Must(tpl.New("durationConst").Parse(durationConstTpl))
114	template.Must(tpl.New("wrapper").Parse(wrapperTpl))
115	template.Must(tpl.New("wrapperConst").Parse(wrapperConstTpl))
116}
117
118type javaFuncs struct{ pgsgo.Context }
119
120func JavaFilePath(f pgs.File, ctx pgsgo.Context, tpl *template.Template) *pgs.FilePath {
121	// Don't generate validators for files that don't import PGV
122	if !importsPvg(f) {
123		return nil
124	}
125
126	fullPath := strings.Replace(javaPackage(f), ".", string(os.PathSeparator), -1)
127	fileName := classNameFile(f) + "Validator.java"
128	filePath := pgs.JoinPaths(fullPath, fileName)
129	return &filePath
130}
131
132func JavaMultiFilePath(f pgs.File, m pgs.Message) pgs.FilePath {
133	fullPath := strings.Replace(javaPackage(f), ".", string(os.PathSeparator), -1)
134	fileName := classNameMessage(m) + "Validator.java"
135	filePath := pgs.JoinPaths(fullPath, fileName)
136	return filePath
137}
138
139func importsPvg(f pgs.File) bool {
140	for _, dep := range f.Descriptor().Dependency {
141		if strings.HasSuffix(dep, "validate.proto") {
142			return true
143		}
144	}
145	return false
146}
147
148func classNameFile(f pgs.File) string {
149	// Explicit outer class name overrides implicit name
150	options := f.Descriptor().GetOptions()
151	if options != nil && !options.GetJavaMultipleFiles() && options.JavaOuterClassname != nil {
152		return options.GetJavaOuterClassname()
153	}
154
155	protoName := pgs.FilePath(f.Name().String()).BaseName()
156
157	className := sanitizeClassName(protoName)
158	className = appendOuterClassName(className, f)
159
160	return className
161}
162
163func classNameMessage(m pgs.Message) string {
164	className := m.Name().String()
165	// This is really silly, but when the multiple files option is true, protoc puts underscores in file names.
166	// When multiple files is false, underscores are stripped. Short of rewriting all the name sanitization
167	// logic for java, using "UnderscoreUnderscoreUnderscore" is an escape sequence seems to work with an extremely
168	// small likelihood of name conflict.
169	className = strings.Replace(className, "_", "UnderscoreUnderscoreUnderscore", -1)
170	className = sanitizeClassName(className)
171	className = strings.Replace(className, "UnderscoreUnderscoreUnderscore", "_", -1)
172	return className
173}
174
175func sanitizeClassName(className string) string {
176	className = makeInvalidClassnameCharactersUnderscores(className)
177	className = underscoreBetweenConsecutiveUppercase(className)
178	className = strcase.ToCamel(strcase.ToSnake(className))
179	className = upperCaseAfterNumber(className)
180	return className
181}
182
183func javaPackage(file pgs.File) string {
184	// Explicit java package overrides implicit package
185	options := file.Descriptor().GetOptions()
186	if options != nil && options.JavaPackage != nil {
187		return options.GetJavaPackage()
188	}
189	return file.Package().ProtoName().String()
190}
191
192func (fns javaFuncs) qualifiedName(entity pgs.Entity) string {
193	file, isFile := entity.(pgs.File)
194	if isFile {
195		name := javaPackage(file)
196		if file.Descriptor().GetOptions() != nil {
197			if !file.Descriptor().GetOptions().GetJavaMultipleFiles() {
198				name += ("." + classNameFile(file))
199			}
200		} else {
201			name += ("." + classNameFile(file))
202		}
203		return name
204	}
205
206	message, isMessage := entity.(pgs.Message)
207	if isMessage && message.Parent() != nil {
208		// recurse
209		return fns.qualifiedName(message.Parent()) + "." + entity.Name().String()
210	}
211
212	enum, isEnum := entity.(pgs.Enum)
213	if isEnum && enum.Parent() != nil {
214		// recurse
215		return fns.qualifiedName(enum.Parent()) + "." + entity.Name().String()
216	}
217
218	return entity.Name().String()
219}
220
// makeInvalidClassnameCharactersUnderscores returns name with every rune
// other than an ASCII digit or ASCII letter replaced by a single underscore.
//
// strings.Map builds the result in one pass instead of re-allocating the
// string for each appended rune.
func makeInvalidClassnameCharactersUnderscores(name string) string {
	return strings.Map(func(r rune) rune {
		switch {
		case r >= '0' && r <= '9',
			r >= 'a' && r <= 'z',
			r >= 'A' && r <= 'Z':
			return r
		default:
			return '_'
		}
	}, name)
}
238
// upperCaseAfterNumber returns name with every rune that immediately
// follows a digit (as defined by unicode.IsDigit) converted to upper case.
// All other runes are copied unchanged.
//
// The result is accumulated in a strings.Builder so the cost is linear in
// the length of name.
func upperCaseAfterNumber(name string) string {
	var sb strings.Builder
	sb.Grow(len(name))

	var prev rune // zero value is not a digit, so the first rune is never upper-cased
	for _, c := range name {
		if unicode.IsDigit(prev) {
			sb.WriteRune(unicode.ToUpper(c))
		} else {
			sb.WriteRune(c)
		}
		prev = c
	}
	return sb.String()
}
253
// underscoreBetweenConsecutiveUppercase returns name with an underscore
// inserted before every uppercase rune that directly follows another
// uppercase rune, e.g. "HTTPServer" becomes "H_T_T_P_Server".
//
// The result is accumulated in a strings.Builder so the cost is linear in
// the length of name.
func underscoreBetweenConsecutiveUppercase(name string) string {
	var sb strings.Builder
	sb.Grow(len(name))

	var prev rune // zero value is not uppercase, so nothing precedes the first rune
	for _, c := range name {
		if unicode.IsUpper(prev) && unicode.IsUpper(c) {
			sb.WriteByte('_')
		}
		sb.WriteRune(c)
		prev = c
	}
	return sb.String()
}
268
269func appendOuterClassName(outerClassName string, file pgs.File) string {
270	conflict := false
271
272	for _, enum := range file.Enums() {
273		if enum.Name().String() == outerClassName {
274			conflict = true
275		}
276	}
277
278	for _, message := range file.Messages() {
279		if message.Name().String() == outerClassName {
280			conflict = true
281		}
282	}
283
284	for _, service := range file.Services() {
285		if service.Name().String() == outerClassName {
286			conflict = true
287		}
288	}
289
290	if conflict {
291		return outerClassName + "OuterClass"
292	} else {
293		return outerClassName
294	}
295}
296
297func (fns javaFuncs) accessor(ctx shared.RuleContext) string {
298	if ctx.AccessorOverride != "" {
299		return ctx.AccessorOverride
300	}
301	return fns.fieldAccessor(ctx.Field)
302}
303
304func (fns javaFuncs) fieldAccessor(f pgs.Field) string {
305	fieldName := strcase.ToCamel(f.Name().String())
306	if f.Type().IsMap() {
307		fieldName += "Map"
308	}
309	if f.Type().IsRepeated() {
310		fieldName += "List"
311	}
312
313	fieldName = upperCaseAfterNumber(fieldName)
314	return fmt.Sprintf("proto.get%s()", fieldName)
315}
316
317func (fns javaFuncs) hasAccessor(ctx shared.RuleContext) string {
318	if ctx.AccessorOverride != "" {
319		return "true"
320	}
321	fiedlName := strcase.ToCamel(ctx.Field.Name().String())
322	fiedlName = upperCaseAfterNumber(fiedlName)
323	return "proto.has" + fiedlName + "()"
324}
325
326func (fns javaFuncs) fieldName(ctx shared.RuleContext) string {
327	return ctx.Field.Name().String()
328}
329
330func (fns javaFuncs) javaTypeFor(ctx shared.RuleContext) string {
331	t := ctx.Field.Type()
332
333	// Map key and value types
334	if t.IsMap() {
335		switch ctx.AccessorOverride {
336		case "key":
337			return fns.javaTypeForProtoType(t.Key().ProtoType())
338		case "value":
339			return fns.javaTypeForProtoType(t.Element().ProtoType())
340		}
341	}
342
343	if t.IsEmbed() {
344		if embed := t.Embed(); embed.IsWellKnown() {
345			switch embed.WellKnownType() {
346			case pgs.AnyWKT:
347				return "String"
348			case pgs.DurationWKT:
349				return "com.google.protobuf.Duration"
350			case pgs.TimestampWKT:
351				return "com.google.protobuf.Timestamp"
352			case pgs.Int32ValueWKT, pgs.UInt32ValueWKT:
353				return "Integer"
354			case pgs.Int64ValueWKT, pgs.UInt64ValueWKT:
355				return "Long"
356			case pgs.DoubleValueWKT:
357				return "Double"
358			case pgs.FloatValueWKT:
359				return "Float"
360			}
361		}
362	}
363
364	if t.IsRepeated() {
365		if t.ProtoType() == pgs.MessageT {
366			return fns.qualifiedName(t.Element().Embed())
367		} else if t.ProtoType() == pgs.EnumT {
368			return fns.qualifiedName(t.Element().Enum())
369		}
370	}
371
372	if t.IsEnum() {
373		return fns.qualifiedName(t.Enum())
374	}
375
376	return fns.javaTypeForProtoType(t.ProtoType())
377}
378
379func (fns javaFuncs) javaTypeForProtoType(t pgs.ProtoType) string {
380
381	switch t {
382	case pgs.Int32T, pgs.UInt32T, pgs.SInt32, pgs.Fixed32T, pgs.SFixed32:
383		return "Integer"
384	case pgs.Int64T, pgs.UInt64T, pgs.SInt64, pgs.Fixed64T, pgs.SFixed64:
385		return "Long"
386	case pgs.DoubleT:
387		return "Double"
388	case pgs.FloatT:
389		return "Float"
390	case pgs.BoolT:
391		return "Boolean"
392	case pgs.StringT:
393		return "String"
394	case pgs.BytesT:
395		return "com.google.protobuf.ByteString"
396	default:
397		return "Object"
398	}
399}
400
401func (fns javaFuncs) javaTypeLiteralSuffixFor(ctx shared.RuleContext) string {
402	t := ctx.Field.Type()
403
404	if t.IsMap() {
405		switch ctx.AccessorOverride {
406		case "key":
407			return fns.javaTypeLiteralSuffixForPrototype(t.Key().ProtoType())
408		case "value":
409			return fns.javaTypeLiteralSuffixForPrototype(t.Element().ProtoType())
410		}
411	}
412
413	if t.IsEmbed() {
414		if embed := t.Embed(); embed.IsWellKnown() {
415			switch embed.WellKnownType() {
416			case pgs.Int64ValueWKT, pgs.UInt64ValueWKT:
417				return "L"
418			case pgs.FloatValueWKT:
419				return "F"
420			case pgs.DoubleValueWKT:
421				return "D"
422			}
423		}
424	}
425
426	return fns.javaTypeLiteralSuffixForPrototype(t.ProtoType())
427}
428
429func (fns javaFuncs) javaTypeLiteralSuffixForPrototype(t pgs.ProtoType) string {
430	switch t {
431	case pgs.Int64T, pgs.UInt64T, pgs.SInt64, pgs.Fixed64T, pgs.SFixed64:
432		return "L"
433	case pgs.FloatT:
434		return "F"
435	case pgs.DoubleT:
436		return "D"
437	default:
438		return ""
439	}
440}
441
442func (fns javaFuncs) javaStringEscape(s string) string {
443	s = fmt.Sprintf("%q", s)
444	s = s[1 : len(s)-1]
445	s = strings.Replace(s, `\u00`, `\x`, -1)
446	s = strings.Replace(s, `\x`, `\\x`, -1)
447	// s = strings.Replace(s, `\`, `\\`, -1)
448	s = strings.Replace(s, `"`, `\"`, -1)
449	return `"` + s + `"`
450}
451
452func (fns javaFuncs) camelCase(name pgs.Name) string {
453	return strcase.ToCamel(name.String())
454}
455
456func (fns javaFuncs) byteArrayLit(bytes []uint8) string {
457	var sb string
458	sb += "new byte[]{"
459	for _, b := range bytes {
460		sb += fmt.Sprintf("(byte)%#x,", b)
461	}
462	sb += "}"
463
464	return sb
465}
466
467func (fns javaFuncs) durLit(dur *durationpb.Duration) string {
468	return fmt.Sprintf(
469		"io.envoyproxy.pgv.TimestampValidation.toDuration(%d,%d)",
470		dur.GetSeconds(), dur.GetNanos())
471}
472
473func (fns javaFuncs) tsLit(ts *timestamppb.Timestamp) string {
474	return fmt.Sprintf(
475		"io.envoyproxy.pgv.TimestampValidation.toTimestamp(%d,%d)",
476		ts.GetSeconds(), ts.GetNanos())
477}
478
479func (fns javaFuncs) oneofTypeName(f pgs.Field) pgsgo.TypeName {
480	return pgsgo.TypeName(fmt.Sprintf("%s", strings.ToUpper(f.Name().String())))
481}
482
483func (fns javaFuncs) isOfFileType(o interface{}) bool {
484	switch o.(type) {
485	case pgs.File:
486		return true
487	default:
488		return false
489	}
490}
491
492func (fns javaFuncs) isOfMessageType(f pgs.Field) bool {
493	return f.Type().ProtoType() == pgs.MessageT
494}
495
496func (fns javaFuncs) isOfStringType(f pgs.Field) bool {
497	return f.Type().ProtoType() == pgs.StringT
498}
499
500func (fns javaFuncs) unwrap(ctx shared.RuleContext) (shared.RuleContext, error) {
501	ctx, err := ctx.Unwrap("wrapped")
502	if err != nil {
503		return ctx, err
504	}
505	ctx.AccessorOverride = fmt.Sprintf("%s.get%s()", fns.fieldAccessor(ctx.Field),
506		fns.camelCase(ctx.Field.Type().Embed().Fields()[0].Name()))
507	return ctx, nil
508}
509
510func (fns javaFuncs) renderConstants(tpl *template.Template) func(ctx shared.RuleContext) (string, error) {
511	return func(ctx shared.RuleContext) (string, error) {
512		var b bytes.Buffer
513		var err error
514
515		hasConstTemplate := false
516		for _, t := range tpl.Templates() {
517			if t.Name() == ctx.Typ+"Const" {
518				hasConstTemplate = true
519			}
520		}
521
522		if hasConstTemplate {
523			err = tpl.ExecuteTemplate(&b, ctx.Typ+"Const", ctx)
524		}
525
526		return b.String(), err
527	}
528}
529
530func (fns javaFuncs) constantName(ctx shared.RuleContext, rule string) string {
531	return strcase.ToScreamingSnake(ctx.Field.Name().String() + "_" + ctx.Index + "_" + rule)
532}
533