1package module 2 3import ( 4 "reflect" 5 "regexp" 6 "time" 7 "unicode/utf8" 8 9 "github.com/golang/protobuf/proto" 10 "github.com/golang/protobuf/ptypes" 11 "github.com/golang/protobuf/ptypes/duration" 12 "github.com/golang/protobuf/ptypes/timestamp" 13 "github.com/lyft/protoc-gen-star" 14 "github.com/envoyproxy/protoc-gen-validate/validate" 15) 16 17type FieldType interface { 18 ProtoType() pgs.ProtoType 19 Embed() pgs.Message 20} 21 22type Repeatable interface { 23 IsRepeated() bool 24} 25 26func (m *Module) CheckRules(msg pgs.Message) { 27 m.Push("msg: " + msg.Name().String()) 28 defer m.Pop() 29 30 var disabled bool 31 _, err := msg.Extension(validate.E_Disabled, &disabled) 32 m.CheckErr(err, "unable to read validation extension from message") 33 34 if disabled { 35 m.Debug("validation disabled, skipping checks") 36 return 37 } 38 39 for _, f := range msg.Fields() { 40 m.Push(f.Name().String()) 41 42 var rules validate.FieldRules 43 _, err = f.Extension(validate.E_Rules, &rules) 44 m.CheckErr(err, "unable to read validation rules from field") 45 46 m.CheckFieldRules(f.Type(), &rules) 47 48 m.Pop() 49 } 50} 51 52func (m *Module) CheckFieldRules(typ FieldType, rules *validate.FieldRules) { 53 if rules == nil { 54 return 55 } 56 57 switch r := rules.Type.(type) { 58 case *validate.FieldRules_Float: 59 m.MustType(typ, pgs.FloatT, pgs.FloatValueWKT) 60 m.CheckFloat(r.Float) 61 case *validate.FieldRules_Double: 62 m.MustType(typ, pgs.DoubleT, pgs.DoubleValueWKT) 63 m.CheckDouble(r.Double) 64 case *validate.FieldRules_Int32: 65 m.MustType(typ, pgs.Int32T, pgs.Int32ValueWKT) 66 m.CheckInt32(r.Int32) 67 case *validate.FieldRules_Int64: 68 m.MustType(typ, pgs.Int64T, pgs.Int64ValueWKT) 69 m.CheckInt64(r.Int64) 70 case *validate.FieldRules_Uint32: 71 m.MustType(typ, pgs.UInt32T, pgs.UInt32ValueWKT) 72 m.CheckUInt32(r.Uint32) 73 case *validate.FieldRules_Uint64: 74 m.MustType(typ, pgs.UInt64T, pgs.UInt64ValueWKT) 75 m.CheckUInt64(r.Uint64) 76 case *validate.FieldRules_Sint32: 77 m.MustType(typ, pgs.SInt32, pgs.UnknownWKT) 78 m.CheckSInt32(r.Sint32) 79 case *validate.FieldRules_Sint64: 80 m.MustType(typ, pgs.SInt64, pgs.UnknownWKT) 81 m.CheckSInt64(r.Sint64) 82 case *validate.FieldRules_Fixed32: 83 m.MustType(typ, pgs.Fixed32T, pgs.UnknownWKT) 84 m.CheckFixed32(r.Fixed32) 85 case *validate.FieldRules_Fixed64: 86 m.MustType(typ, pgs.Fixed64T, pgs.UnknownWKT) 87 m.CheckFixed64(r.Fixed64) 88 case *validate.FieldRules_Sfixed32: 89 m.MustType(typ, pgs.SFixed32, pgs.UnknownWKT) 90 m.CheckSFixed32(r.Sfixed32) 91 case *validate.FieldRules_Sfixed64: 92 m.MustType(typ, pgs.SFixed64, pgs.UnknownWKT) 93 m.CheckSFixed64(r.Sfixed64) 94 case *validate.FieldRules_Bool: 95 m.MustType(typ, pgs.BoolT, pgs.BoolValueWKT) 96 case *validate.FieldRules_String_: 97 m.MustType(typ, pgs.StringT, pgs.StringValueWKT) 98 m.CheckString(r.String_) 99 case *validate.FieldRules_Bytes: 100 m.MustType(typ, pgs.BytesT, pgs.BytesValueWKT) 101 m.CheckBytes(r.Bytes) 102 case *validate.FieldRules_Enum: 103 m.MustType(typ, pgs.EnumT, pgs.UnknownWKT) 104 m.CheckEnum(typ, r.Enum) 105 case *validate.FieldRules_Message: 106 m.MustType(typ, pgs.MessageT, pgs.UnknownWKT) 107 case *validate.FieldRules_Repeated: 108 m.CheckRepeated(typ, r.Repeated) 109 case *validate.FieldRules_Map: 110 m.CheckMap(typ, r.Map) 111 case *validate.FieldRules_Any: 112 m.CheckAny(typ, r.Any) 113 case *validate.FieldRules_Duration: 114 m.CheckDuration(typ, r.Duration) 115 case *validate.FieldRules_Timestamp: 116 m.CheckTimestamp(typ, r.Timestamp) 117 case nil: // noop 118 default: 119 m.Failf("unknown rule type (%T)", rules.Type) 120 } 121} 122 123func (m *Module) MustType(typ FieldType, pt pgs.ProtoType, wrapper pgs.WellKnownType) { 124 if emb := typ.Embed(); emb != nil && emb.IsWellKnown() && emb.WellKnownType() == wrapper { 125 m.MustType(emb.Fields()[0].Type(), pt, pgs.UnknownWKT) 126 return 127 } 128 129 if typ, ok := typ.(Repeatable); ok { 130 m.Assert(!typ.IsRepeated(), 131 "repeated rule should be used for repeated fields") 132 } 133 134 m.Assert(typ.ProtoType() == pt, 135 " expected rules for ", 136 typ.ProtoType().Proto(), 137 " but got ", 138 pt.Proto(), 139 ) 140} 141 142func (m *Module) CheckFloat(r *validate.FloatRules) { 143 m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte) 144} 145 146func (m *Module) CheckDouble(r *validate.DoubleRules) { 147 m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte) 148} 149 150func (m *Module) CheckInt32(r *validate.Int32Rules) { 151 m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte) 152} 153 154func (m *Module) CheckInt64(r *validate.Int64Rules) { 155 m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte) 156} 157 158func (m *Module) CheckUInt32(r *validate.UInt32Rules) { 159 m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte) 160} 161 162func (m *Module) CheckUInt64(r *validate.UInt64Rules) { 163 m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte) 164} 165 166func (m *Module) CheckSInt32(r *validate.SInt32Rules) { 167 m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte) 168} 169 170func (m *Module) CheckSInt64(r *validate.SInt64Rules) { 171 m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte) 172} 173 174func (m *Module) CheckFixed32(r *validate.Fixed32Rules) { 175 m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte) 176} 177 178func (m *Module) CheckFixed64(r *validate.Fixed64Rules) { 179 m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte) 180} 181 182func (m *Module) CheckSFixed32(r *validate.SFixed32Rules) { 183 m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte) 184} 185 186func (m *Module) CheckSFixed64(r *validate.SFixed64Rules) { 187 m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte) 188} 189 190func (m *Module) CheckString(r *validate.StringRules) { 191 m.checkLen(r.Len, r.MinLen, r.MaxLen) 192 m.checkLen(r.LenBytes, r.MinBytes, r.MaxBytes) 193 m.checkMinMax(r.MinLen, r.MaxLen) 194 m.checkMinMax(r.MinBytes, r.MaxBytes) 195 m.checkIns(len(r.In), len(r.NotIn)) 196 m.checkPattern(r.Pattern, len(r.In)) 197 198 if r.MaxLen != nil { 199 max := int(r.GetMaxLen()) 200 m.Assert(utf8.RuneCountInString(r.GetPrefix()) <= max, "`prefix` length exceeds the `max_len`") 201 m.Assert(utf8.RuneCountInString(r.GetSuffix()) <= max, "`suffix` length exceeds the `max_len`") 202 m.Assert(utf8.RuneCountInString(r.GetContains()) <= max, "`contains` length exceeds the `max_len`") 203 204 m.Assert( 205 r.MaxBytes == nil || r.GetMaxBytes() >= r.GetMaxLen(), 206 "`max_len` cannot exceed `max_bytes`") 207 } 208 209 if r.MaxBytes != nil { 210 max := int(r.GetMaxBytes()) 211 m.Assert(len(r.GetPrefix()) <= max, "`prefix` length exceeds the `max_bytes`") 212 m.Assert(len(r.GetSuffix()) <= max, "`suffix` length exceeds the `max_bytes`") 213 m.Assert(len(r.GetContains()) <= max, "`contains` length exceeds the `max_bytes`") 214 } 215} 216 217func (m *Module) CheckBytes(r *validate.BytesRules) { 218 m.checkMinMax(r.MinLen, r.MaxLen) 219 m.checkIns(len(r.In), len(r.NotIn)) 220 m.checkPattern(r.Pattern, len(r.In)) 221 222 if r.MaxLen != nil { 223 max := int(r.GetMaxLen()) 224 m.Assert(len(r.GetPrefix()) <= max, "`prefix` length exceeds the `max_len`") 225 m.Assert(len(r.GetSuffix()) <= max, "`suffix` length exceeds the `max_len`") 226 m.Assert(len(r.GetContains()) <= max, "`contains` length exceeds the `max_len`") 227 } 228} 229 230func (m *Module) CheckEnum(ft FieldType, r *validate.EnumRules) { 231 m.checkIns(len(r.In), len(r.NotIn)) 232 233 if r.GetDefinedOnly() && len(r.In) > 0 { 234 typ, ok := ft.(interface { 235 Enum() pgs.Enum 236 }) 237 238 if !ok { 239 m.Failf("unexpected field type (%T)", ft) 240 } 241 242 defined := typ.Enum().Values() 243 vals := make(map[int32]struct{}, len(defined)) 244 245 for _, val := range defined { 246 vals[val.Value()] = struct{}{} 247 } 248 249 for _, in := range r.In { 250 if _, ok = vals[in]; !ok { 251 m.Failf("undefined `in` value (%d) conflicts with `defined_only` rule") 252 } 253 } 254 } 255} 256 257func (m *Module) CheckMessage(ft FieldType, r *validate.MessageRules) { 258 if !r.GetSkip() { 259 m.CheckRules(m.mustFieldType(ft).Embed()) 260 } 261} 262 263func (m *Module) CheckRepeated(ft FieldType, r *validate.RepeatedRules) { 264 typ := m.mustFieldType(ft) 265 266 m.Assert(typ.IsRepeated(), "field is not repeated but got repeated rules") 267 268 m.checkMinMax(r.MinItems, r.MaxItems) 269 270 if r.GetUnique() { 271 m.Assert( 272 !typ.Element().IsEmbed(), 273 "unique rule is only applicable for scalar types") 274 } 275 276 m.Push("items") 277 m.CheckFieldRules(typ.Element(), r.Items) 278 m.Pop() 279} 280 281func (m *Module) CheckMap(ft FieldType, r *validate.MapRules) { 282 typ := m.mustFieldType(ft) 283 284 m.Assert(typ.IsMap(), "field is not a map but got map rules") 285 286 m.checkMinMax(r.MinPairs, r.MaxPairs) 287 288 if r.GetNoSparse() { 289 m.Assert( 290 typ.Element().IsEmbed(), 291 "no_sparse rule is only applicable for embedded message types", 292 ) 293 } 294 295 m.Push("keys") 296 m.CheckFieldRules(typ.Key(), r.Keys) 297 m.Pop() 298 299 m.Push("values") 300 m.CheckFieldRules(typ.Element(), r.Values) 301 m.Pop() 302} 303 304func (m *Module) CheckAny(ft FieldType, r *validate.AnyRules) { 305 m.checkIns(len(r.In), len(r.NotIn)) 306} 307 308func (m *Module) CheckDuration(ft FieldType, r *validate.DurationRules) { 309 m.checkNums( 310 len(r.GetIn()), 311 len(r.GetNotIn()), 312 m.checkDur(r.GetConst()), 313 m.checkDur(r.GetLt()), 314 m.checkDur(r.GetLte()), 315 m.checkDur(r.GetGt()), 316 m.checkDur(r.GetGte())) 317 318 for _, v := range r.GetIn() { 319 m.Assert(v != nil, "cannot have nil values in `in`") 320 m.checkDur(v) 321 } 322 323 for _, v := range r.GetNotIn() { 324 m.Assert(v != nil, "cannot have nil values in `not_in`") 325 m.checkDur(v) 326 } 327} 328 329func (m *Module) CheckTimestamp(ft FieldType, r *validate.TimestampRules) { 330 m.checkNums(0, 0, 331 m.checkTS(r.GetConst()), 332 m.checkTS(r.GetLt()), 333 m.checkTS(r.GetLte()), 334 m.checkTS(r.GetGt()), 335 m.checkTS(r.GetGte())) 336 337 m.Assert( 338 (r.LtNow == nil && r.GtNow == nil) || (r.Lt == nil && r.Lte == nil && r.Gt == nil && r.Gte == nil), 339 "`now` rules cannot be mixed with absolute `lt/gt` rules") 340 341 m.Assert( 342 r.Within == nil || (r.Lt == nil && r.Lte == nil && r.Gt == nil && r.Gte == nil), 343 "`within` rule cannot be used with absolute `lt/gt` rules") 344 345 m.Assert( 346 r.LtNow == nil || r.GtNow == nil, 347 "both `now` rules cannot be used together") 348 349 dur := m.checkDur(r.Within) 350 m.Assert( 351 dur == nil || *dur > 0, 352 "`within` rule must be positive and non-zero") 353} 354 355func (m *Module) mustFieldType(ft FieldType) pgs.FieldType { 356 typ, ok := ft.(pgs.FieldType) 357 if !ok { 358 m.Failf("unexpected field type (%T)", ft) 359 } 360 361 return typ 362} 363 364func (m *Module) checkNums(in, notIn int, ci, lti, ltei, gti, gtei interface{}) { 365 m.checkIns(in, notIn) 366 367 c := reflect.ValueOf(ci) 368 lt, lte := reflect.ValueOf(lti), reflect.ValueOf(ltei) 369 gt, gte := reflect.ValueOf(gti), reflect.ValueOf(gtei) 370 371 m.Assert( 372 c.IsNil() || 373 in == 0 && notIn == 0 && 374 lt.IsNil() && lte.IsNil() && 375 gt.IsNil() && gte.IsNil(), 376 "`const` can be the only rule on a field", 377 ) 378 379 m.Assert( 380 in == 0 || 381 lt.IsNil() && lte.IsNil() && 382 gt.IsNil() && gte.IsNil(), 383 "cannot have both `in` and range constraint rules on the same field", 384 ) 385 386 m.Assert( 387 lt.IsNil() || lte.IsNil(), 388 "cannot have both `lt` and `lte` rules on the same field", 389 ) 390 391 m.Assert( 392 gt.IsNil() || gte.IsNil(), 393 "cannot have both `gt` and `gte` rules on the same field", 394 ) 395 396 if !lt.IsNil() { 397 m.Assert(gt.IsNil() || !reflect.DeepEqual(lti, gti), 398 "cannot have equal `gt` and `lt` rules on the same field") 399 m.Assert(gte.IsNil() || !reflect.DeepEqual(lti, gtei), 400 "cannot have equal `gte` and `lt` rules on the same field") 401 } else if !lte.IsNil() { 402 m.Assert(gt.IsNil() || !reflect.DeepEqual(ltei, gti), 403 "cannot have equal `gt` and `lte` rules on the same field") 404 m.Assert(gte.IsNil() || !reflect.DeepEqual(ltei, gtei), 405 "use `const` instead of equal `lte` and `gte` rules") 406 } 407} 408 409func (m *Module) checkIns(in, notIn int) { 410 m.Assert( 411 in == 0 || notIn == 0, 412 "cannot have both `in` and `not_in` rules on the same field") 413} 414 415func (m *Module) checkMinMax(min, max *uint64) { 416 if min == nil || max == nil { 417 return 418 } 419 420 m.Assert( 421 *min <= *max, 422 "`min` value is greater than `max` value") 423} 424 425func (m *Module) checkLen(len, min, max *uint64) { 426 if len == nil { 427 return 428 } 429 430 m.Assert( 431 min == nil, 432 "cannot have both `len` and `min_len` rules on the same field") 433 434 m.Assert( 435 max == nil, 436 "cannot have both `len` and `max_len` rules on the same field") 437} 438 439func (m *Module) checkPattern(p *string, in int) { 440 if p != nil { 441 m.Assert(in == 0, "regex `pattern` and `in` rules are incompatible") 442 _, err := regexp.Compile(*p) 443 m.CheckErr(err, "unable to parse regex `pattern`") 444 } 445} 446 447func (m *Module) checkDur(d *duration.Duration) *time.Duration { 448 if d == nil { 449 return nil 450 } 451 452 dur, err := ptypes.Duration(d) 453 m.CheckErr(err, "could not resolve duration") 454 return &dur 455} 456 457func (m *Module) checkTS(ts *timestamp.Timestamp) *int64 { 458 if ts == nil { 459 return nil 460 } 461 462 t, err := ptypes.Timestamp(ts) 463 m.CheckErr(err, "could not resolve timestamp") 464 return proto.Int64(t.UnixNano()) 465} 466 467