1package java 2 3import ( 4 "bytes" 5 "fmt" 6 "os" 7 "strings" 8 "text/template" 9 "unicode" 10 11 "github.com/envoyproxy/protoc-gen-validate/templates/shared" 12 "github.com/iancoleman/strcase" 13 pgs "github.com/lyft/protoc-gen-star" 14 pgsgo "github.com/lyft/protoc-gen-star/lang/go" 15 "google.golang.org/protobuf/types/known/durationpb" 16 "google.golang.org/protobuf/types/known/timestamppb" 17) 18 19func RegisterIndex(tpl *template.Template, params pgs.Parameters) { 20 fns := javaFuncs{pgsgo.InitContext(params)} 21 22 tpl.Funcs(map[string]interface{}{ 23 "classNameFile": classNameFile, 24 "importsPvg": importsPvg, 25 "javaPackage": javaPackage, 26 "simpleName": fns.Name, 27 "qualifiedName": fns.qualifiedName, 28 }) 29} 30 31func Register(tpl *template.Template, params pgs.Parameters) { 32 fns := javaFuncs{pgsgo.InitContext(params)} 33 34 tpl.Funcs(map[string]interface{}{ 35 "accessor": fns.accessor, 36 "byteArrayLit": fns.byteArrayLit, 37 "camelCase": fns.camelCase, 38 "classNameFile": classNameFile, 39 "classNameMessage": classNameMessage, 40 "durLit": fns.durLit, 41 "fieldName": fns.fieldName, 42 "javaPackage": javaPackage, 43 "javaStringEscape": fns.javaStringEscape, 44 "javaTypeFor": fns.javaTypeFor, 45 "javaTypeLiteralSuffixFor": fns.javaTypeLiteralSuffixFor, 46 "hasAccessor": fns.hasAccessor, 47 "oneof": fns.oneofTypeName, 48 "sprintf": fmt.Sprintf, 49 "simpleName": fns.Name, 50 "tsLit": fns.tsLit, 51 "qualifiedName": fns.qualifiedName, 52 "isOfFileType": fns.isOfFileType, 53 "isOfMessageType": fns.isOfMessageType, 54 "isOfStringType": fns.isOfStringType, 55 "unwrap": fns.unwrap, 56 "renderConstants": fns.renderConstants(tpl), 57 "constantName": fns.constantName, 58 }) 59 60 template.Must(tpl.Parse(fileTpl)) 61 template.Must(tpl.New("msg").Parse(msgTpl)) 62 template.Must(tpl.New("msgInner").Parse(msgInnerTpl)) 63 64 template.Must(tpl.New("none").Parse(noneTpl)) 65 66 template.Must(tpl.New("float").Parse(numTpl)) 67 template.Must(tpl.New("floatConst").Parse(numConstTpl)) 68 template.Must(tpl.New("double").Parse(numTpl)) 69 template.Must(tpl.New("doubleConst").Parse(numConstTpl)) 70 template.Must(tpl.New("int32").Parse(numTpl)) 71 template.Must(tpl.New("int32Const").Parse(numConstTpl)) 72 template.Must(tpl.New("int64").Parse(numTpl)) 73 template.Must(tpl.New("int64Const").Parse(numConstTpl)) 74 template.Must(tpl.New("uint32").Parse(numTpl)) 75 template.Must(tpl.New("uint32Const").Parse(numConstTpl)) 76 template.Must(tpl.New("uint64").Parse(numTpl)) 77 template.Must(tpl.New("uint64Const").Parse(numConstTpl)) 78 template.Must(tpl.New("sint32").Parse(numTpl)) 79 template.Must(tpl.New("sint32Const").Parse(numConstTpl)) 80 template.Must(tpl.New("sint64").Parse(numTpl)) 81 template.Must(tpl.New("sint64Const").Parse(numConstTpl)) 82 template.Must(tpl.New("fixed32").Parse(numTpl)) 83 template.Must(tpl.New("fixed32Const").Parse(numConstTpl)) 84 template.Must(tpl.New("fixed64").Parse(numTpl)) 85 template.Must(tpl.New("fixed64Const").Parse(numConstTpl)) 86 template.Must(tpl.New("sfixed32").Parse(numTpl)) 87 template.Must(tpl.New("sfixed32Const").Parse(numConstTpl)) 88 template.Must(tpl.New("sfixed64").Parse(numTpl)) 89 template.Must(tpl.New("sfixed64Const").Parse(numConstTpl)) 90 91 template.Must(tpl.New("bool").Parse(boolTpl)) 92 template.Must(tpl.New("string").Parse(stringTpl)) 93 template.Must(tpl.New("stringConst").Parse(stringConstTpl)) 94 template.Must(tpl.New("bytes").Parse(bytesTpl)) 95 template.Must(tpl.New("bytesConst").Parse(bytesConstTpl)) 96 97 template.Must(tpl.New("any").Parse(anyTpl)) 98 template.Must(tpl.New("anyConst").Parse(anyConstTpl)) 99 template.Must(tpl.New("enum").Parse(enumTpl)) 100 template.Must(tpl.New("enumConst").Parse(enumConstTpl)) 101 template.Must(tpl.New("message").Parse(messageTpl)) 102 template.Must(tpl.New("repeated").Parse(repeatedTpl)) 103 template.Must(tpl.New("repeatedConst").Parse(repeatedConstTpl)) 104 template.Must(tpl.New("map").Parse(mapTpl)) 105 template.Must(tpl.New("mapConst").Parse(mapConstTpl)) 106 template.Must(tpl.New("oneOf").Parse(oneOfTpl)) 107 template.Must(tpl.New("oneOfConst").Parse(oneOfConstTpl)) 108 109 template.Must(tpl.New("required").Parse(requiredTpl)) 110 template.Must(tpl.New("timestamp").Parse(timestampTpl)) 111 template.Must(tpl.New("timestampConst").Parse(timestampConstTpl)) 112 template.Must(tpl.New("duration").Parse(durationTpl)) 113 template.Must(tpl.New("durationConst").Parse(durationConstTpl)) 114 template.Must(tpl.New("wrapper").Parse(wrapperTpl)) 115 template.Must(tpl.New("wrapperConst").Parse(wrapperConstTpl)) 116} 117 118type javaFuncs struct{ pgsgo.Context } 119 120func JavaFilePath(f pgs.File, ctx pgsgo.Context, tpl *template.Template) *pgs.FilePath { 121 // Don't generate validators for files that don't import PGV 122 if !importsPvg(f) { 123 return nil 124 } 125 126 fullPath := strings.Replace(javaPackage(f), ".", string(os.PathSeparator), -1) 127 fileName := classNameFile(f) + "Validator.java" 128 filePath := pgs.JoinPaths(fullPath, fileName) 129 return &filePath 130} 131 132func JavaMultiFilePath(f pgs.File, m pgs.Message) pgs.FilePath { 133 fullPath := strings.Replace(javaPackage(f), ".", string(os.PathSeparator), -1) 134 fileName := classNameMessage(m) + "Validator.java" 135 filePath := pgs.JoinPaths(fullPath, fileName) 136 return filePath 137} 138 139func importsPvg(f pgs.File) bool { 140 for _, dep := range f.Descriptor().Dependency { 141 if strings.HasSuffix(dep, "validate.proto") { 142 return true 143 } 144 } 145 return false 146} 147 148func classNameFile(f pgs.File) string { 149 // Explicit outer class name overrides implicit name 150 options := f.Descriptor().GetOptions() 151 if options != nil && !options.GetJavaMultipleFiles() && options.JavaOuterClassname != nil { 152 return options.GetJavaOuterClassname() 153 } 154 155 protoName := pgs.FilePath(f.Name().String()).BaseName() 156 157 className := sanitizeClassName(protoName) 158 className = appendOuterClassName(className, f) 159 160 return className 161} 162 163func classNameMessage(m pgs.Message) string { 164 className := m.Name().String() 165 // This is really silly, but when the multiple files option is true, protoc puts underscores in file names. 166 // When multiple files is false, underscores are stripped. Short of rewriting all the name sanitization 167 // logic for java, using "UnderscoreUnderscoreUnderscore" is an escape sequence seems to work with an extremely 168 // small likelihood of name conflict. 169 className = strings.Replace(className, "_", "UnderscoreUnderscoreUnderscore", -1) 170 className = sanitizeClassName(className) 171 className = strings.Replace(className, "UnderscoreUnderscoreUnderscore", "_", -1) 172 return className 173} 174 175func sanitizeClassName(className string) string { 176 className = makeInvalidClassnameCharactersUnderscores(className) 177 className = underscoreBetweenConsecutiveUppercase(className) 178 className = strcase.ToCamel(strcase.ToSnake(className)) 179 className = upperCaseAfterNumber(className) 180 return className 181} 182 183func javaPackage(file pgs.File) string { 184 // Explicit java package overrides implicit package 185 options := file.Descriptor().GetOptions() 186 if options != nil && options.JavaPackage != nil { 187 return options.GetJavaPackage() 188 } 189 return file.Package().ProtoName().String() 190} 191 192func (fns javaFuncs) qualifiedName(entity pgs.Entity) string { 193 file, isFile := entity.(pgs.File) 194 if isFile { 195 name := javaPackage(file) 196 if file.Descriptor().GetOptions() != nil { 197 if !file.Descriptor().GetOptions().GetJavaMultipleFiles() { 198 name += ("." + classNameFile(file)) 199 } 200 } else { 201 name += ("." + classNameFile(file)) 202 } 203 return name 204 } 205 206 message, isMessage := entity.(pgs.Message) 207 if isMessage && message.Parent() != nil { 208 // recurse 209 return fns.qualifiedName(message.Parent()) + "." + entity.Name().String() 210 } 211 212 enum, isEnum := entity.(pgs.Enum) 213 if isEnum && enum.Parent() != nil { 214 // recurse 215 return fns.qualifiedName(enum.Parent()) + "." + entity.Name().String() 216 } 217 218 return entity.Name().String() 219} 220 221// Replace invalid identifier characters with an underscore 222func makeInvalidClassnameCharactersUnderscores(name string) string { 223 var sb string 224 for _, c := range name { 225 switch { 226 case c >= '0' && c <= '9': 227 sb += string(c) 228 case c >= 'a' && c <= 'z': 229 sb += string(c) 230 case c >= 'A' && c <= 'Z': 231 sb += string(c) 232 default: 233 sb += "_" 234 } 235 } 236 return sb 237} 238 239func upperCaseAfterNumber(name string) string { 240 var sb string 241 var p rune 242 243 for _, c := range name { 244 if unicode.IsDigit(p) { 245 sb += string(unicode.ToUpper(c)) 246 } else { 247 sb += string(c) 248 } 249 p = c 250 } 251 return sb 252} 253 254func underscoreBetweenConsecutiveUppercase(name string) string { 255 var sb string 256 var p rune 257 258 for _, c := range name { 259 if unicode.IsUpper(p) && unicode.IsUpper(c) { 260 sb += "_" + string(c) 261 } else { 262 sb += string(c) 263 } 264 p = c 265 } 266 return sb 267} 268 269func appendOuterClassName(outerClassName string, file pgs.File) string { 270 conflict := false 271 272 for _, enum := range file.Enums() { 273 if enum.Name().String() == outerClassName { 274 conflict = true 275 } 276 } 277 278 for _, message := range file.Messages() { 279 if message.Name().String() == outerClassName { 280 conflict = true 281 } 282 } 283 284 for _, service := range file.Services() { 285 if service.Name().String() == outerClassName { 286 conflict = true 287 } 288 } 289 290 if conflict { 291 return outerClassName + "OuterClass" 292 } else { 293 return outerClassName 294 } 295} 296 297func (fns javaFuncs) accessor(ctx shared.RuleContext) string { 298 if ctx.AccessorOverride != "" { 299 return ctx.AccessorOverride 300 } 301 return fns.fieldAccessor(ctx.Field) 302} 303 304func (fns javaFuncs) fieldAccessor(f pgs.Field) string { 305 fieldName := strcase.ToCamel(f.Name().String()) 306 if f.Type().IsMap() { 307 fieldName += "Map" 308 } 309 if f.Type().IsRepeated() { 310 fieldName += "List" 311 } 312 313 fieldName = upperCaseAfterNumber(fieldName) 314 return fmt.Sprintf("proto.get%s()", fieldName) 315} 316 317func (fns javaFuncs) hasAccessor(ctx shared.RuleContext) string { 318 if ctx.AccessorOverride != "" { 319 return "true" 320 } 321 fiedlName := strcase.ToCamel(ctx.Field.Name().String()) 322 fiedlName = upperCaseAfterNumber(fiedlName) 323 return "proto.has" + fiedlName + "()" 324} 325 326func (fns javaFuncs) fieldName(ctx shared.RuleContext) string { 327 return ctx.Field.Name().String() 328} 329 330func (fns javaFuncs) javaTypeFor(ctx shared.RuleContext) string { 331 t := ctx.Field.Type() 332 333 // Map key and value types 334 if t.IsMap() { 335 switch ctx.AccessorOverride { 336 case "key": 337 return fns.javaTypeForProtoType(t.Key().ProtoType()) 338 case "value": 339 return fns.javaTypeForProtoType(t.Element().ProtoType()) 340 } 341 } 342 343 if t.IsEmbed() { 344 if embed := t.Embed(); embed.IsWellKnown() { 345 switch embed.WellKnownType() { 346 case pgs.AnyWKT: 347 return "String" 348 case pgs.DurationWKT: 349 return "com.google.protobuf.Duration" 350 case pgs.TimestampWKT: 351 return "com.google.protobuf.Timestamp" 352 case pgs.Int32ValueWKT, pgs.UInt32ValueWKT: 353 return "Integer" 354 case pgs.Int64ValueWKT, pgs.UInt64ValueWKT: 355 return "Long" 356 case pgs.DoubleValueWKT: 357 return "Double" 358 case pgs.FloatValueWKT: 359 return "Float" 360 } 361 } 362 } 363 364 if t.IsRepeated() { 365 if t.ProtoType() == pgs.MessageT { 366 return fns.qualifiedName(t.Element().Embed()) 367 } else if t.ProtoType() == pgs.EnumT { 368 return fns.qualifiedName(t.Element().Enum()) 369 } 370 } 371 372 if t.IsEnum() { 373 return fns.qualifiedName(t.Enum()) 374 } 375 376 return fns.javaTypeForProtoType(t.ProtoType()) 377} 378 379func (fns javaFuncs) javaTypeForProtoType(t pgs.ProtoType) string { 380 381 switch t { 382 case pgs.Int32T, pgs.UInt32T, pgs.SInt32, pgs.Fixed32T, pgs.SFixed32: 383 return "Integer" 384 case pgs.Int64T, pgs.UInt64T, pgs.SInt64, pgs.Fixed64T, pgs.SFixed64: 385 return "Long" 386 case pgs.DoubleT: 387 return "Double" 388 case pgs.FloatT: 389 return "Float" 390 case pgs.BoolT: 391 return "Boolean" 392 case pgs.StringT: 393 return "String" 394 case pgs.BytesT: 395 return "com.google.protobuf.ByteString" 396 default: 397 return "Object" 398 } 399} 400 401func (fns javaFuncs) javaTypeLiteralSuffixFor(ctx shared.RuleContext) string { 402 t := ctx.Field.Type() 403 404 if t.IsMap() { 405 switch ctx.AccessorOverride { 406 case "key": 407 return fns.javaTypeLiteralSuffixForPrototype(t.Key().ProtoType()) 408 case "value": 409 return fns.javaTypeLiteralSuffixForPrototype(t.Element().ProtoType()) 410 } 411 } 412 413 if t.IsEmbed() { 414 if embed := t.Embed(); embed.IsWellKnown() { 415 switch embed.WellKnownType() { 416 case pgs.Int64ValueWKT, pgs.UInt64ValueWKT: 417 return "L" 418 case pgs.FloatValueWKT: 419 return "F" 420 case pgs.DoubleValueWKT: 421 return "D" 422 } 423 } 424 } 425 426 return fns.javaTypeLiteralSuffixForPrototype(t.ProtoType()) 427} 428 429func (fns javaFuncs) javaTypeLiteralSuffixForPrototype(t pgs.ProtoType) string { 430 switch t { 431 case pgs.Int64T, pgs.UInt64T, pgs.SInt64, pgs.Fixed64T, pgs.SFixed64: 432 return "L" 433 case pgs.FloatT: 434 return "F" 435 case pgs.DoubleT: 436 return "D" 437 default: 438 return "" 439 } 440} 441 442func (fns javaFuncs) javaStringEscape(s string) string { 443 s = fmt.Sprintf("%q", s) 444 s = s[1 : len(s)-1] 445 s = strings.Replace(s, `\u00`, `\x`, -1) 446 s = strings.Replace(s, `\x`, `\\x`, -1) 447 // s = strings.Replace(s, `\`, `\\`, -1) 448 s = strings.Replace(s, `"`, `\"`, -1) 449 return `"` + s + `"` 450} 451 452func (fns javaFuncs) camelCase(name pgs.Name) string { 453 return strcase.ToCamel(name.String()) 454} 455 456func (fns javaFuncs) byteArrayLit(bytes []uint8) string { 457 var sb string 458 sb += "new byte[]{" 459 for _, b := range bytes { 460 sb += fmt.Sprintf("(byte)%#x,", b) 461 } 462 sb += "}" 463 464 return sb 465} 466 467func (fns javaFuncs) durLit(dur *durationpb.Duration) string { 468 return fmt.Sprintf( 469 "io.envoyproxy.pgv.TimestampValidation.toDuration(%d,%d)", 470 dur.GetSeconds(), dur.GetNanos()) 471} 472 473func (fns javaFuncs) tsLit(ts *timestamppb.Timestamp) string { 474 return fmt.Sprintf( 475 "io.envoyproxy.pgv.TimestampValidation.toTimestamp(%d,%d)", 476 ts.GetSeconds(), ts.GetNanos()) 477} 478 479func (fns javaFuncs) oneofTypeName(f pgs.Field) pgsgo.TypeName { 480 return pgsgo.TypeName(fmt.Sprintf("%s", strings.ToUpper(f.Name().String()))) 481} 482 483func (fns javaFuncs) isOfFileType(o interface{}) bool { 484 switch o.(type) { 485 case pgs.File: 486 return true 487 default: 488 return false 489 } 490} 491 492func (fns javaFuncs) isOfMessageType(f pgs.Field) bool { 493 return f.Type().ProtoType() == pgs.MessageT 494} 495 496func (fns javaFuncs) isOfStringType(f pgs.Field) bool { 497 return f.Type().ProtoType() == pgs.StringT 498} 499 500func (fns javaFuncs) unwrap(ctx shared.RuleContext) (shared.RuleContext, error) { 501 ctx, err := ctx.Unwrap("wrapped") 502 if err != nil { 503 return ctx, err 504 } 505 ctx.AccessorOverride = fmt.Sprintf("%s.get%s()", fns.fieldAccessor(ctx.Field), 506 fns.camelCase(ctx.Field.Type().Embed().Fields()[0].Name())) 507 return ctx, nil 508} 509 510func (fns javaFuncs) renderConstants(tpl *template.Template) func(ctx shared.RuleContext) (string, error) { 511 return func(ctx shared.RuleContext) (string, error) { 512 var b bytes.Buffer 513 var err error 514 515 hasConstTemplate := false 516 for _, t := range tpl.Templates() { 517 if t.Name() == ctx.Typ+"Const" { 518 hasConstTemplate = true 519 } 520 } 521 522 if hasConstTemplate { 523 err = tpl.ExecuteTemplate(&b, ctx.Typ+"Const", ctx) 524 } 525 526 return b.String(), err 527 } 528} 529 530func (fns javaFuncs) constantName(ctx shared.RuleContext, rule string) string { 531 return strcase.ToScreamingSnake(ctx.Field.Name().String() + "_" + ctx.Index + "_" + rule) 532} 533