1package java 2 3import ( 4 "bytes" 5 "fmt" 6 "os" 7 "strings" 8 "text/template" 9 "unicode" 10 11 "github.com/envoyproxy/protoc-gen-validate/templates/shared" 12 "github.com/golang/protobuf/ptypes/duration" 13 "github.com/golang/protobuf/ptypes/timestamp" 14 "github.com/iancoleman/strcase" 15 pgs "github.com/lyft/protoc-gen-star" 16 pgsgo "github.com/lyft/protoc-gen-star/lang/go" 17) 18 19func RegisterIndex(tpl *template.Template, params pgs.Parameters) { 20 fns := javaFuncs{pgsgo.InitContext(params)} 21 22 tpl.Funcs(map[string]interface{}{ 23 "classNameFile": classNameFile, 24 "importsPvg": importsPvg, 25 "javaPackage": javaPackage, 26 "simpleName": fns.Name, 27 "qualifiedName": fns.qualifiedName, 28 }) 29} 30 31func Register(tpl *template.Template, params pgs.Parameters) { 32 fns := javaFuncs{pgsgo.InitContext(params)} 33 34 tpl.Funcs(map[string]interface{}{ 35 "accessor": fns.accessor, 36 "byteArrayLit": fns.byteArrayLit, 37 "camelCase": fns.camelCase, 38 "classNameFile": classNameFile, 39 "classNameMessage": classNameMessage, 40 "durLit": fns.durLit, 41 "fieldName": fns.fieldName, 42 "javaPackage": javaPackage, 43 "javaStringEscape": fns.javaStringEscape, 44 "javaTypeFor": fns.javaTypeFor, 45 "javaTypeLiteralSuffixFor": fns.javaTypeLiteralSuffixFor, 46 "hasAccessor": fns.hasAccessor, 47 "oneof": fns.oneofTypeName, 48 "sprintf": fmt.Sprintf, 49 "simpleName": fns.Name, 50 "tsLit": fns.tsLit, 51 "qualifiedName": fns.qualifiedName, 52 "isOfFileType": fns.isOfFileType, 53 "isOfMessageType": fns.isOfMessageType, 54 "isOfStringType": fns.isOfStringType, 55 "unwrap": fns.unwrap, 56 "renderConstants": fns.renderConstants(tpl), 57 "constantName": fns.constantName, 58 }) 59 60 template.Must(tpl.Parse(fileTpl)) 61 template.Must(tpl.New("msg").Parse(msgTpl)) 62 template.Must(tpl.New("msgInner").Parse(msgInnerTpl)) 63 64 template.Must(tpl.New("none").Parse(noneTpl)) 65 66 template.Must(tpl.New("float").Parse(numTpl)) 67 template.Must(tpl.New("floatConst").Parse(numConstTpl)) 68 template.Must(tpl.New("double").Parse(numTpl)) 69 template.Must(tpl.New("doubleConst").Parse(numConstTpl)) 70 template.Must(tpl.New("int32").Parse(numTpl)) 71 template.Must(tpl.New("int32Const").Parse(numConstTpl)) 72 template.Must(tpl.New("int64").Parse(numTpl)) 73 template.Must(tpl.New("int64Const").Parse(numConstTpl)) 74 template.Must(tpl.New("uint32").Parse(numTpl)) 75 template.Must(tpl.New("uint32Const").Parse(numConstTpl)) 76 template.Must(tpl.New("uint64").Parse(numTpl)) 77 template.Must(tpl.New("uint64Const").Parse(numConstTpl)) 78 template.Must(tpl.New("sint32").Parse(numTpl)) 79 template.Must(tpl.New("sint32Const").Parse(numConstTpl)) 80 template.Must(tpl.New("sint64").Parse(numTpl)) 81 template.Must(tpl.New("sint64Const").Parse(numConstTpl)) 82 template.Must(tpl.New("fixed32").Parse(numTpl)) 83 template.Must(tpl.New("fixed32Const").Parse(numConstTpl)) 84 template.Must(tpl.New("fixed64").Parse(numTpl)) 85 template.Must(tpl.New("fixed64Const").Parse(numConstTpl)) 86 template.Must(tpl.New("sfixed32").Parse(numTpl)) 87 template.Must(tpl.New("sfixed32Const").Parse(numConstTpl)) 88 template.Must(tpl.New("sfixed64").Parse(numTpl)) 89 template.Must(tpl.New("sfixed64Const").Parse(numConstTpl)) 90 91 template.Must(tpl.New("bool").Parse(boolTpl)) 92 template.Must(tpl.New("string").Parse(stringTpl)) 93 template.Must(tpl.New("stringConst").Parse(stringConstTpl)) 94 template.Must(tpl.New("bytes").Parse(bytesTpl)) 95 template.Must(tpl.New("bytesConst").Parse(bytesConstTpl)) 96 97 template.Must(tpl.New("any").Parse(anyTpl)) 98 template.Must(tpl.New("anyConst").Parse(anyConstTpl)) 99 template.Must(tpl.New("enum").Parse(enumTpl)) 100 template.Must(tpl.New("enumConst").Parse(enumConstTpl)) 101 template.Must(tpl.New("message").Parse(messageTpl)) 102 template.Must(tpl.New("repeated").Parse(repeatedTpl)) 103 template.Must(tpl.New("repeatedConst").Parse(repeatedConstTpl)) 104 template.Must(tpl.New("map").Parse(mapTpl)) 105 template.Must(tpl.New("mapConst").Parse(mapConstTpl)) 106 template.Must(tpl.New("oneOf").Parse(oneOfTpl)) 107 template.Must(tpl.New("oneOfConst").Parse(oneOfConstTpl)) 108 109 template.Must(tpl.New("required").Parse(requiredTpl)) 110 template.Must(tpl.New("timestamp").Parse(timestampTpl)) 111 template.Must(tpl.New("timestampConst").Parse(timestampConstTpl)) 112 template.Must(tpl.New("duration").Parse(durationTpl)) 113 template.Must(tpl.New("durationConst").Parse(durationConstTpl)) 114 template.Must(tpl.New("wrapper").Parse(wrapperTpl)) 115 template.Must(tpl.New("wrapperConst").Parse(wrapperConstTpl)) 116} 117 118type javaFuncs struct{ pgsgo.Context } 119 120func JavaFilePath(f pgs.File, ctx pgsgo.Context, tpl *template.Template) *pgs.FilePath { 121 // Don't generate validators for files that don't import PGV 122 if !importsPvg(f) { 123 return nil 124 } 125 126 fullPath := strings.Replace(javaPackage(f), ".", string(os.PathSeparator), -1) 127 fileName := classNameFile(f) + "Validator.java" 128 filePath := pgs.JoinPaths(fullPath, fileName) 129 return &filePath 130} 131 132func JavaMultiFilePath(f pgs.File, m pgs.Message) pgs.FilePath { 133 fullPath := strings.Replace(javaPackage(f), ".", string(os.PathSeparator), -1) 134 fileName := classNameMessage(m) + "Validator.java" 135 filePath := pgs.JoinPaths(fullPath, fileName) 136 return filePath 137} 138 139func importsPvg(f pgs.File) bool { 140 for _, dep := range f.Descriptor().Dependency { 141 if strings.HasSuffix(dep, "validate.proto") { 142 return true 143 } 144 } 145 return false 146} 147 148func classNameFile(f pgs.File) string { 149 // Explicit outer class name overrides implicit name 150 options := f.Descriptor().GetOptions() 151 if options != nil && !options.GetJavaMultipleFiles() && options.JavaOuterClassname != nil { 152 return options.GetJavaOuterClassname() 153 } 154 155 protoName := pgs.FilePath(f.Name().String()).BaseName() 156 157 className := sanitizeClassName(protoName) 158 className = appendOuterClassName(className, f) 159 160 return className 161} 162 163func classNameMessage(m pgs.Message) string { 164 return sanitizeClassName(m.Name().String()) 165} 166 167func sanitizeClassName(className string) string { 168 className = makeInvalidClassnameCharactersUnderscores(className) 169 className = underscoreBetweenConsecutiveUppercase(className) 170 className = strcase.ToCamel(strcase.ToSnake(className)) 171 className = upperCaseAfterNumber(className) 172 return className 173} 174 175func javaPackage(file pgs.File) string { 176 // Explicit java package overrides implicit package 177 options := file.Descriptor().GetOptions() 178 if options != nil && options.JavaPackage != nil { 179 return options.GetJavaPackage() 180 } 181 return file.Package().ProtoName().String() 182} 183 184func (fns javaFuncs) qualifiedName(entity pgs.Entity) string { 185 file, isFile := entity.(pgs.File) 186 if isFile { 187 name := javaPackage(file) 188 if file.Descriptor().GetOptions() != nil { 189 if !file.Descriptor().GetOptions().GetJavaMultipleFiles() { 190 name += ("." + classNameFile(file)) 191 } 192 } else { 193 name += ("." + classNameFile(file)) 194 } 195 return name 196 } 197 198 message, isMessage := entity.(pgs.Message) 199 if isMessage && message.Parent() != nil { 200 // recurse 201 return fns.qualifiedName(message.Parent()) + "." + entity.Name().String() 202 } 203 204 enum, isEnum := entity.(pgs.Enum) 205 if isEnum && enum.Parent() != nil { 206 // recurse 207 return fns.qualifiedName(enum.Parent()) + "." + entity.Name().String() 208 } 209 210 return entity.Name().String() 211} 212 213// Replace invalid identifier characters with an underscore 214func makeInvalidClassnameCharactersUnderscores(name string) string { 215 var sb string 216 for _, c := range name { 217 switch { 218 case c >= '0' && c <= '9': 219 sb += string(c) 220 case c >= 'a' && c <= 'z': 221 sb += string(c) 222 case c >= 'A' && c <= 'Z': 223 sb += string(c) 224 default: 225 sb += "_" 226 } 227 } 228 return sb 229} 230 231func upperCaseAfterNumber(name string) string { 232 var sb string 233 var p rune 234 235 for _, c := range name { 236 if unicode.IsDigit(p) { 237 sb += string(unicode.ToUpper(c)) 238 } else { 239 sb += string(c) 240 } 241 p = c 242 } 243 return sb 244} 245 246func underscoreBetweenConsecutiveUppercase(name string) string { 247 var sb string 248 var p rune 249 250 for _, c := range name { 251 if unicode.IsUpper(p) && unicode.IsUpper(c) { 252 sb += "_" + string(c) 253 } else { 254 sb += string(c) 255 } 256 p = c 257 } 258 return sb 259} 260 261func appendOuterClassName(outerClassName string, file pgs.File) string { 262 conflict := false 263 264 for _, enum := range file.Enums() { 265 if enum.Name().String() == outerClassName { 266 conflict = true 267 } 268 } 269 270 for _, message := range file.Messages() { 271 if message.Name().String() == outerClassName { 272 conflict = true 273 } 274 } 275 276 for _, service := range file.Services() { 277 if service.Name().String() == outerClassName { 278 conflict = true 279 } 280 } 281 282 if conflict { 283 return outerClassName + "OuterClass" 284 } else { 285 return outerClassName 286 } 287} 288 289func (fns javaFuncs) accessor(ctx shared.RuleContext) string { 290 if ctx.AccessorOverride != "" { 291 return ctx.AccessorOverride 292 } 293 return fns.fieldAccessor(ctx.Field) 294} 295 296func (fns javaFuncs) fieldAccessor(f pgs.Field) string { 297 fieldName := strcase.ToCamel(f.Name().String()) 298 if f.Type().IsMap() { 299 fieldName += "Map" 300 } 301 if f.Type().IsRepeated() { 302 fieldName += "List" 303 } 304 305 fieldName = upperCaseAfterNumber(fieldName) 306 return fmt.Sprintf("proto.get%s()", fieldName) 307} 308 309func (fns javaFuncs) hasAccessor(ctx shared.RuleContext) string { 310 if ctx.AccessorOverride != "" { 311 return "true" 312 } 313 fiedlName := strcase.ToCamel(ctx.Field.Name().String()) 314 fiedlName = upperCaseAfterNumber(fiedlName) 315 return "proto.has" + fiedlName + "()" 316} 317 318func (fns javaFuncs) fieldName(ctx shared.RuleContext) string { 319 return ctx.Field.Name().String() 320} 321 322func (fns javaFuncs) javaTypeFor(ctx shared.RuleContext) string { 323 t := ctx.Field.Type() 324 325 // Map key and value types 326 if t.IsMap() { 327 switch ctx.AccessorOverride { 328 case "key": 329 return fns.javaTypeForProtoType(t.Key().ProtoType()) 330 case "value": 331 return fns.javaTypeForProtoType(t.Element().ProtoType()) 332 } 333 } 334 335 if t.IsEmbed() { 336 if embed := t.Embed(); embed.IsWellKnown() { 337 switch embed.WellKnownType() { 338 case pgs.AnyWKT: 339 return "String" 340 case pgs.DurationWKT: 341 return "com.google.protobuf.Duration" 342 case pgs.TimestampWKT: 343 return "com.google.protobuf.Timestamp" 344 case pgs.Int32ValueWKT, pgs.UInt32ValueWKT: 345 return "Integer" 346 case pgs.Int64ValueWKT, pgs.UInt64ValueWKT: 347 return "Long" 348 case pgs.DoubleValueWKT: 349 return "Double" 350 case pgs.FloatValueWKT: 351 return "Float" 352 } 353 } 354 } 355 356 if t.IsRepeated() { 357 if t.ProtoType() == pgs.MessageT { 358 return fns.qualifiedName(t.Element().Embed()) 359 } else if t.ProtoType() == pgs.EnumT { 360 return fns.qualifiedName(t.Element().Enum()) 361 } 362 } 363 364 if t.IsEnum() { 365 return fns.qualifiedName(t.Enum()) 366 } 367 368 return fns.javaTypeForProtoType(t.ProtoType()) 369} 370 371func (fns javaFuncs) javaTypeForProtoType(t pgs.ProtoType) string { 372 373 switch t { 374 case pgs.Int32T, pgs.UInt32T, pgs.SInt32, pgs.Fixed32T, pgs.SFixed32: 375 return "Integer" 376 case pgs.Int64T, pgs.UInt64T, pgs.SInt64, pgs.Fixed64T, pgs.SFixed64: 377 return "Long" 378 case pgs.DoubleT: 379 return "Double" 380 case pgs.FloatT: 381 return "Float" 382 case pgs.BoolT: 383 return "Boolean" 384 case pgs.StringT: 385 return "String" 386 case pgs.BytesT: 387 return "com.google.protobuf.ByteString" 388 default: 389 return "Object" 390 } 391} 392 393func (fns javaFuncs) javaTypeLiteralSuffixFor(f pgs.Field) string { 394 switch f.Type().ProtoType() { 395 case pgs.Int64T, pgs.UInt64T, pgs.SInt64, pgs.Fixed64T, pgs.SFixed64: 396 return "L" 397 case pgs.FloatT: 398 return "F" 399 case pgs.DoubleT: 400 return "D" 401 } 402 403 emb := f.Type().Embed() 404 if emb != nil && emb.IsWellKnown() { 405 switch emb.WellKnownType() { 406 case pgs.Int64ValueWKT, pgs.UInt64ValueWKT: 407 return "L" 408 case pgs.FloatValueWKT: 409 return "F" 410 case pgs.DoubleValueWKT: 411 return "D" 412 } 413 } 414 415 return "" 416} 417 418func (fns javaFuncs) javaStringEscape(s string) string { 419 s = fmt.Sprintf("%q", s) 420 s = s[1 : len(s)-1] 421 s = strings.Replace(s, `\u00`, `\x`, -1) 422 s = strings.Replace(s, `\x`, `\\x`, -1) 423 // s = strings.Replace(s, `\`, `\\`, -1) 424 s = strings.Replace(s, `"`, `\"`, -1) 425 return `"` + s + `"` 426} 427 428func (fns javaFuncs) camelCase(name pgs.Name) string { 429 return strcase.ToCamel(name.String()) 430} 431 432func (fns javaFuncs) byteArrayLit(bytes []uint8) string { 433 var sb string 434 sb += "new byte[]{" 435 for _, b := range bytes { 436 sb += fmt.Sprintf("(byte)%#x,", b) 437 } 438 sb += "}" 439 440 return sb 441} 442 443func (fns javaFuncs) durLit(dur *duration.Duration) string { 444 return fmt.Sprintf( 445 "io.envoyproxy.pgv.TimestampValidation.toDuration(%d,%d)", 446 dur.GetSeconds(), dur.GetNanos()) 447} 448 449func (fns javaFuncs) tsLit(ts *timestamp.Timestamp) string { 450 return fmt.Sprintf( 451 "io.envoyproxy.pgv.TimestampValidation.toTimestamp(%d,%d)", 452 ts.GetSeconds(), ts.GetNanos()) 453} 454 455func (fns javaFuncs) oneofTypeName(f pgs.Field) pgsgo.TypeName { 456 return pgsgo.TypeName(fmt.Sprintf("%s", strings.ToUpper(f.Name().String()))) 457} 458 459func (fns javaFuncs) isOfFileType(o interface{}) bool { 460 switch o.(type) { 461 case pgs.File: 462 return true 463 default: 464 return false 465 } 466} 467 468func (fns javaFuncs) isOfMessageType(f pgs.Field) bool { 469 return f.Type().ProtoType() == pgs.MessageT 470} 471 472func (fns javaFuncs) isOfStringType(f pgs.Field) bool { 473 return f.Type().ProtoType() == pgs.StringT 474} 475 476func (fns javaFuncs) unwrap(ctx shared.RuleContext) (shared.RuleContext, error) { 477 ctx, err := ctx.Unwrap("wrapped") 478 if err != nil { 479 return ctx, err 480 } 481 ctx.AccessorOverride = fmt.Sprintf("%s.get%s()", fns.fieldAccessor(ctx.Field), 482 fns.camelCase(ctx.Field.Type().Embed().Fields()[0].Name())) 483 return ctx, nil 484} 485 486func (fns javaFuncs) renderConstants(tpl *template.Template) func(ctx shared.RuleContext) (string, error) { 487 return func(ctx shared.RuleContext) (string, error) { 488 var b bytes.Buffer 489 var err error 490 491 hasConstTemplate := false 492 for _, t := range tpl.Templates() { 493 if t.Name() == ctx.Typ+"Const" { 494 hasConstTemplate = true 495 } 496 } 497 498 if hasConstTemplate { 499 err = tpl.ExecuteTemplate(&b, ctx.Typ+"Const", ctx) 500 } 501 502 return b.String(), err 503 } 504} 505 506func (fns javaFuncs) constantName(ctx shared.RuleContext, rule string) string { 507 return strcase.ToScreamingSnake(ctx.Field.Name().String() + "_" + ctx.Index + "_" + rule) 508} 509