1package module
2
3import (
4	"reflect"
5	"regexp"
6	"time"
7	"unicode/utf8"
8
9	"github.com/golang/protobuf/proto"
10	"github.com/golang/protobuf/ptypes"
11	"github.com/golang/protobuf/ptypes/duration"
12	"github.com/golang/protobuf/ptypes/timestamp"
13	"github.com/lyft/protoc-gen-star"
14	"github.com/envoyproxy/protoc-gen-validate/validate"
15)
16
17type FieldType interface {
18	ProtoType() pgs.ProtoType
19	Embed() pgs.Message
20}
21
22type Repeatable interface {
23	IsRepeated() bool
24}
25
26func (m *Module) CheckRules(msg pgs.Message) {
27	m.Push("msg: " + msg.Name().String())
28	defer m.Pop()
29
30	var disabled bool
31	_, err := msg.Extension(validate.E_Disabled, &disabled)
32	m.CheckErr(err, "unable to read validation extension from message")
33
34	if disabled {
35		m.Debug("validation disabled, skipping checks")
36		return
37	}
38
39	for _, f := range msg.Fields() {
40		m.Push(f.Name().String())
41
42		var rules validate.FieldRules
43		_, err = f.Extension(validate.E_Rules, &rules)
44		m.CheckErr(err, "unable to read validation rules from field")
45
46		m.CheckFieldRules(f.Type(), &rules)
47
48		m.Pop()
49	}
50}
51
52func (m *Module) CheckFieldRules(typ FieldType, rules *validate.FieldRules) {
53	if rules == nil {
54		return
55	}
56
57	switch r := rules.Type.(type) {
58	case *validate.FieldRules_Float:
59		m.MustType(typ, pgs.FloatT, pgs.FloatValueWKT)
60		m.CheckFloat(r.Float)
61	case *validate.FieldRules_Double:
62		m.MustType(typ, pgs.DoubleT, pgs.DoubleValueWKT)
63		m.CheckDouble(r.Double)
64	case *validate.FieldRules_Int32:
65		m.MustType(typ, pgs.Int32T, pgs.Int32ValueWKT)
66		m.CheckInt32(r.Int32)
67	case *validate.FieldRules_Int64:
68		m.MustType(typ, pgs.Int64T, pgs.Int64ValueWKT)
69		m.CheckInt64(r.Int64)
70	case *validate.FieldRules_Uint32:
71		m.MustType(typ, pgs.UInt32T, pgs.UInt32ValueWKT)
72		m.CheckUInt32(r.Uint32)
73	case *validate.FieldRules_Uint64:
74		m.MustType(typ, pgs.UInt64T, pgs.UInt64ValueWKT)
75		m.CheckUInt64(r.Uint64)
76	case *validate.FieldRules_Sint32:
77		m.MustType(typ, pgs.SInt32, pgs.UnknownWKT)
78		m.CheckSInt32(r.Sint32)
79	case *validate.FieldRules_Sint64:
80		m.MustType(typ, pgs.SInt64, pgs.UnknownWKT)
81		m.CheckSInt64(r.Sint64)
82	case *validate.FieldRules_Fixed32:
83		m.MustType(typ, pgs.Fixed32T, pgs.UnknownWKT)
84		m.CheckFixed32(r.Fixed32)
85	case *validate.FieldRules_Fixed64:
86		m.MustType(typ, pgs.Fixed64T, pgs.UnknownWKT)
87		m.CheckFixed64(r.Fixed64)
88	case *validate.FieldRules_Sfixed32:
89		m.MustType(typ, pgs.SFixed32, pgs.UnknownWKT)
90		m.CheckSFixed32(r.Sfixed32)
91	case *validate.FieldRules_Sfixed64:
92		m.MustType(typ, pgs.SFixed64, pgs.UnknownWKT)
93		m.CheckSFixed64(r.Sfixed64)
94	case *validate.FieldRules_Bool:
95		m.MustType(typ, pgs.BoolT, pgs.BoolValueWKT)
96	case *validate.FieldRules_String_:
97		m.MustType(typ, pgs.StringT, pgs.StringValueWKT)
98		m.CheckString(r.String_)
99	case *validate.FieldRules_Bytes:
100		m.MustType(typ, pgs.BytesT, pgs.BytesValueWKT)
101		m.CheckBytes(r.Bytes)
102	case *validate.FieldRules_Enum:
103		m.MustType(typ, pgs.EnumT, pgs.UnknownWKT)
104		m.CheckEnum(typ, r.Enum)
105	case *validate.FieldRules_Message:
106		m.MustType(typ, pgs.MessageT, pgs.UnknownWKT)
107	case *validate.FieldRules_Repeated:
108		m.CheckRepeated(typ, r.Repeated)
109	case *validate.FieldRules_Map:
110		m.CheckMap(typ, r.Map)
111	case *validate.FieldRules_Any:
112		m.CheckAny(typ, r.Any)
113	case *validate.FieldRules_Duration:
114		m.CheckDuration(typ, r.Duration)
115	case *validate.FieldRules_Timestamp:
116		m.CheckTimestamp(typ, r.Timestamp)
117	case nil: // noop
118	default:
119		m.Failf("unknown rule type (%T)", rules.Type)
120	}
121}
122
123func (m *Module) MustType(typ FieldType, pt pgs.ProtoType, wrapper pgs.WellKnownType) {
124	if emb := typ.Embed(); emb != nil && emb.IsWellKnown() && emb.WellKnownType() == wrapper {
125		m.MustType(emb.Fields()[0].Type(), pt, pgs.UnknownWKT)
126		return
127	}
128
129	if typ, ok := typ.(Repeatable); ok {
130		m.Assert(!typ.IsRepeated(),
131			"repeated rule should be used for repeated fields")
132	}
133
134	m.Assert(typ.ProtoType() == pt,
135		" expected rules for ",
136		typ.ProtoType().Proto(),
137		" but got ",
138		pt.Proto(),
139	)
140}
141
142func (m *Module) CheckFloat(r *validate.FloatRules) {
143	m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
144}
145
146func (m *Module) CheckDouble(r *validate.DoubleRules) {
147	m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
148}
149
150func (m *Module) CheckInt32(r *validate.Int32Rules) {
151	m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
152}
153
154func (m *Module) CheckInt64(r *validate.Int64Rules) {
155	m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
156}
157
158func (m *Module) CheckUInt32(r *validate.UInt32Rules) {
159	m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
160}
161
162func (m *Module) CheckUInt64(r *validate.UInt64Rules) {
163	m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
164}
165
166func (m *Module) CheckSInt32(r *validate.SInt32Rules) {
167	m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
168}
169
170func (m *Module) CheckSInt64(r *validate.SInt64Rules) {
171	m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
172}
173
174func (m *Module) CheckFixed32(r *validate.Fixed32Rules) {
175	m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
176}
177
178func (m *Module) CheckFixed64(r *validate.Fixed64Rules) {
179	m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
180}
181
182func (m *Module) CheckSFixed32(r *validate.SFixed32Rules) {
183	m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
184}
185
186func (m *Module) CheckSFixed64(r *validate.SFixed64Rules) {
187	m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
188}
189
190func (m *Module) CheckString(r *validate.StringRules) {
191	m.checkLen(r.Len, r.MinLen, r.MaxLen)
192	m.checkLen(r.LenBytes, r.MinBytes, r.MaxBytes)
193	m.checkMinMax(r.MinLen, r.MaxLen)
194	m.checkMinMax(r.MinBytes, r.MaxBytes)
195	m.checkIns(len(r.In), len(r.NotIn))
196	m.checkPattern(r.Pattern, len(r.In))
197
198	if r.MaxLen != nil {
199		max := int(r.GetMaxLen())
200		m.Assert(utf8.RuneCountInString(r.GetPrefix()) <= max, "`prefix` length exceeds the `max_len`")
201		m.Assert(utf8.RuneCountInString(r.GetSuffix()) <= max, "`suffix` length exceeds the `max_len`")
202		m.Assert(utf8.RuneCountInString(r.GetContains()) <= max, "`contains` length exceeds the `max_len`")
203
204		m.Assert(
205			r.MaxBytes == nil || r.GetMaxBytes() >= r.GetMaxLen(),
206			"`max_len` cannot exceed `max_bytes`")
207	}
208
209	if r.MaxBytes != nil {
210		max := int(r.GetMaxBytes())
211		m.Assert(len(r.GetPrefix()) <= max, "`prefix` length exceeds the `max_bytes`")
212		m.Assert(len(r.GetSuffix()) <= max, "`suffix` length exceeds the `max_bytes`")
213		m.Assert(len(r.GetContains()) <= max, "`contains` length exceeds the `max_bytes`")
214	}
215}
216
217func (m *Module) CheckBytes(r *validate.BytesRules) {
218	m.checkMinMax(r.MinLen, r.MaxLen)
219	m.checkIns(len(r.In), len(r.NotIn))
220	m.checkPattern(r.Pattern, len(r.In))
221
222	if r.MaxLen != nil {
223		max := int(r.GetMaxLen())
224		m.Assert(len(r.GetPrefix()) <= max, "`prefix` length exceeds the `max_len`")
225		m.Assert(len(r.GetSuffix()) <= max, "`suffix` length exceeds the `max_len`")
226		m.Assert(len(r.GetContains()) <= max, "`contains` length exceeds the `max_len`")
227	}
228}
229
230func (m *Module) CheckEnum(ft FieldType, r *validate.EnumRules) {
231	m.checkIns(len(r.In), len(r.NotIn))
232
233	if r.GetDefinedOnly() && len(r.In) > 0 {
234		typ, ok := ft.(interface {
235			Enum() pgs.Enum
236		})
237
238		if !ok {
239			m.Failf("unexpected field type (%T)", ft)
240		}
241
242		defined := typ.Enum().Values()
243		vals := make(map[int32]struct{}, len(defined))
244
245		for _, val := range defined {
246			vals[val.Value()] = struct{}{}
247		}
248
249		for _, in := range r.In {
250			if _, ok = vals[in]; !ok {
251				m.Failf("undefined `in` value (%d) conflicts with `defined_only` rule")
252			}
253		}
254	}
255}
256
257func (m *Module) CheckMessage(ft FieldType, r *validate.MessageRules) {
258	if !r.GetSkip() {
259		m.CheckRules(m.mustFieldType(ft).Embed())
260	}
261}
262
263func (m *Module) CheckRepeated(ft FieldType, r *validate.RepeatedRules) {
264	typ := m.mustFieldType(ft)
265
266	m.Assert(typ.IsRepeated(), "field is not repeated but got repeated rules")
267
268	m.checkMinMax(r.MinItems, r.MaxItems)
269
270	if r.GetUnique() {
271		m.Assert(
272			!typ.Element().IsEmbed(),
273			"unique rule is only applicable for scalar types")
274	}
275
276	m.Push("items")
277	m.CheckFieldRules(typ.Element(), r.Items)
278	m.Pop()
279}
280
281func (m *Module) CheckMap(ft FieldType, r *validate.MapRules) {
282	typ := m.mustFieldType(ft)
283
284	m.Assert(typ.IsMap(), "field is not a map but got map rules")
285
286	m.checkMinMax(r.MinPairs, r.MaxPairs)
287
288	if r.GetNoSparse() {
289		m.Assert(
290			typ.Element().IsEmbed(),
291			"no_sparse rule is only applicable for embedded message types",
292		)
293	}
294
295	m.Push("keys")
296	m.CheckFieldRules(typ.Key(), r.Keys)
297	m.Pop()
298
299	m.Push("values")
300	m.CheckFieldRules(typ.Element(), r.Values)
301	m.Pop()
302}
303
304func (m *Module) CheckAny(ft FieldType, r *validate.AnyRules) {
305	m.checkIns(len(r.In), len(r.NotIn))
306}
307
308func (m *Module) CheckDuration(ft FieldType, r *validate.DurationRules) {
309	m.checkNums(
310		len(r.GetIn()),
311		len(r.GetNotIn()),
312		m.checkDur(r.GetConst()),
313		m.checkDur(r.GetLt()),
314		m.checkDur(r.GetLte()),
315		m.checkDur(r.GetGt()),
316		m.checkDur(r.GetGte()))
317
318	for _, v := range r.GetIn() {
319		m.Assert(v != nil, "cannot have nil values in `in`")
320		m.checkDur(v)
321	}
322
323	for _, v := range r.GetNotIn() {
324		m.Assert(v != nil, "cannot have nil values in `not_in`")
325		m.checkDur(v)
326	}
327}
328
329func (m *Module) CheckTimestamp(ft FieldType, r *validate.TimestampRules) {
330	m.checkNums(0, 0,
331		m.checkTS(r.GetConst()),
332		m.checkTS(r.GetLt()),
333		m.checkTS(r.GetLte()),
334		m.checkTS(r.GetGt()),
335		m.checkTS(r.GetGte()))
336
337	m.Assert(
338		(r.LtNow == nil && r.GtNow == nil) || (r.Lt == nil && r.Lte == nil && r.Gt == nil && r.Gte == nil),
339		"`now` rules cannot be mixed with absolute `lt/gt` rules")
340
341	m.Assert(
342		r.Within == nil || (r.Lt == nil && r.Lte == nil && r.Gt == nil && r.Gte == nil),
343		"`within` rule cannot be used with absolute `lt/gt` rules")
344
345	m.Assert(
346		r.LtNow == nil || r.GtNow == nil,
347		"both `now` rules cannot be used together")
348
349	dur := m.checkDur(r.Within)
350	m.Assert(
351		dur == nil || *dur > 0,
352		"`within` rule must be positive and non-zero")
353}
354
355func (m *Module) mustFieldType(ft FieldType) pgs.FieldType {
356	typ, ok := ft.(pgs.FieldType)
357	if !ok {
358		m.Failf("unexpected field type (%T)", ft)
359	}
360
361	return typ
362}
363
364func (m *Module) checkNums(in, notIn int, ci, lti, ltei, gti, gtei interface{}) {
365	m.checkIns(in, notIn)
366
367	c := reflect.ValueOf(ci)
368	lt, lte := reflect.ValueOf(lti), reflect.ValueOf(ltei)
369	gt, gte := reflect.ValueOf(gti), reflect.ValueOf(gtei)
370
371	m.Assert(
372		c.IsNil() ||
373			in == 0 && notIn == 0 &&
374				lt.IsNil() && lte.IsNil() &&
375				gt.IsNil() && gte.IsNil(),
376		"`const` can be the only rule on a field",
377	)
378
379	m.Assert(
380		in == 0 ||
381			lt.IsNil() && lte.IsNil() &&
382				gt.IsNil() && gte.IsNil(),
383		"cannot have both `in` and range constraint rules on the same field",
384	)
385
386	m.Assert(
387		lt.IsNil() || lte.IsNil(),
388		"cannot have both `lt` and `lte` rules on the same field",
389	)
390
391	m.Assert(
392		gt.IsNil() || gte.IsNil(),
393		"cannot have both `gt` and `gte` rules on the same field",
394	)
395
396	if !lt.IsNil() {
397		m.Assert(gt.IsNil() || !reflect.DeepEqual(lti, gti),
398			"cannot have equal `gt` and `lt` rules on the same field")
399		m.Assert(gte.IsNil() || !reflect.DeepEqual(lti, gtei),
400			"cannot have equal `gte` and `lt` rules on the same field")
401	} else if !lte.IsNil() {
402		m.Assert(gt.IsNil() || !reflect.DeepEqual(ltei, gti),
403			"cannot have equal `gt` and `lte` rules on the same field")
404		m.Assert(gte.IsNil() || !reflect.DeepEqual(ltei, gtei),
405			"use `const` instead of equal `lte` and `gte` rules")
406	}
407}
408
409func (m *Module) checkIns(in, notIn int) {
410	m.Assert(
411		in == 0 || notIn == 0,
412		"cannot have both `in` and `not_in` rules on the same field")
413}
414
415func (m *Module) checkMinMax(min, max *uint64) {
416	if min == nil || max == nil {
417		return
418	}
419
420	m.Assert(
421		*min <= *max,
422		"`min` value is greater than `max` value")
423}
424
425func (m *Module) checkLen(len, min, max *uint64) {
426	if len == nil {
427		return
428	}
429
430	m.Assert(
431		min == nil,
432		"cannot have both `len` and `min_len` rules on the same field")
433
434	m.Assert(
435		max == nil,
436		"cannot have both `len` and `max_len` rules on the same field")
437}
438
439func (m *Module) checkPattern(p *string, in int) {
440	if p != nil {
441		m.Assert(in == 0, "regex `pattern` and `in` rules are incompatible")
442		_, err := regexp.Compile(*p)
443		m.CheckErr(err, "unable to parse regex `pattern`")
444	}
445}
446
447func (m *Module) checkDur(d *duration.Duration) *time.Duration {
448	if d == nil {
449		return nil
450	}
451
452	dur, err := ptypes.Duration(d)
453	m.CheckErr(err, "could not resolve duration")
454	return &dur
455}
456
457func (m *Module) checkTS(ts *timestamp.Timestamp) *int64 {
458	if ts == nil {
459		return nil
460	}
461
462	t, err := ptypes.Timestamp(ts)
463	m.CheckErr(err, "could not resolve timestamp")
464	return proto.Int64(t.UnixNano())
465}
466
467