1package validator
2
3import (
4	"fmt"
5	"reflect"
6	"strings"
7	"sync"
8	"sync/atomic"
9)
10
// tagType classifies a parsed cTag node. parseFieldTagsRecursive assigns it
// based on the tag element that produced the node.
type tagType uint8

// Node kinds. The values come from iota, so their numeric values depend on
// the declaration order below.
const (
	// typeDefault is an ordinary validation node (also the zero value).
	typeDefault tagType = iota
	// typeOmitEmpty marks the omitempty tag element.
	typeOmitEmpty
	// typeIsDefault marks the isdefault validation.
	typeIsDefault
	// typeNoStructLevel marks the noStructLevelTag element.
	typeNoStructLevel
	// typeStructOnly marks the structOnlyTag element.
	typeStructOnly
	// typeDive marks the diveTag element.
	typeDive
	// typeOr marks validations joined by orSeparator (e.g. "a|b").
	typeOr
	// typeKeys marks the keysTag element; its key validations hang off cTag.keys.
	typeKeys
	// typeEndKeys marks the endKeysTag element that closes a keys section.
	typeEndKeys
)
24
25const (
26	invalidValidation   = "Invalid validation tag on field '%s'"
27	undefinedValidation = "Undefined validation function '%s' on field '%s'"
28	keysTagNotDefined   = "'" + endKeysTag + "' tag encountered without a corresponding '" + keysTag + "' tag"
29)
30
// structCache maps struct types to their parsed *cStruct descriptions.
//
// Readers call Get, which loads m atomically without taking lock. Writers take
// lock (see extractStructCache) and call Set, which publishes a fresh copy of
// the map rather than mutating the one readers may be using.
type structCache struct {
	// lock serializes writers; it is not needed for Get.
	lock sync.Mutex
	m    atomic.Value // map[reflect.Type]*cStruct
}
35
36func (sc *structCache) Get(key reflect.Type) (c *cStruct, found bool) {
37	c, found = sc.m.Load().(map[reflect.Type]*cStruct)[key]
38	return
39}
40
41func (sc *structCache) Set(key reflect.Type, value *cStruct) {
42	m := sc.m.Load().(map[reflect.Type]*cStruct)
43	nm := make(map[reflect.Type]*cStruct, len(m)+1)
44	for k, v := range m {
45		nm[k] = v
46	}
47	nm[key] = value
48	sc.m.Store(nm)
49}
50
// tagCache maps raw tag strings to their parsed *cTag lists.
//
// Readers call Get, which loads m atomically without taking lock. Writers take
// lock (see fetchCacheTag) and call Set, which publishes a fresh copy of the
// map rather than mutating the one readers may be using.
type tagCache struct {
	// lock serializes writers; it is not needed for Get.
	lock sync.Mutex
	m    atomic.Value // map[string]*cTag
}
55
56func (tc *tagCache) Get(key string) (c *cTag, found bool) {
57	c, found = tc.m.Load().(map[string]*cTag)[key]
58	return
59}
60
61func (tc *tagCache) Set(key string, value *cTag) {
62	m := tc.m.Load().(map[string]*cTag)
63	nm := make(map[string]*cTag, len(m)+1)
64	for k, v := range m {
65		nm[k] = v
66	}
67	nm[key] = value
68	tc.m.Store(nm)
69}
70
// cStruct is the cached description of a struct type, built by extractStructCache.
type cStruct struct {
	// name is the struct name passed to extractStructCache.
	name   string
	// fields lists the fields that take part in validation, in declaration order.
	fields []*cField
	// fn is the struct-level function registered for the type (nil if none).
	fn     StructLevelFuncCtx
}
76
// cField is the cached description of one struct field, built by extractStructCache.
type cField struct {
	// idx is the field's index in the struct type (as used by reflect's Field(i)).
	idx        int
	// name is the Go field name.
	name       string
	// altName is the name returned by the registered tag-name function, or name
	// when there is no such function or it returned "".
	altName    string
	// namesEqual reports whether name == altName.
	namesEqual bool
	// cTags is the parsed tag list; a bare placeholder node when the field has no tag.
	cTags      *cTag
}
84
// cTag is one node of a parsed validation tag. Nodes are chained through next.
// parseFieldTagsRecursive builds the chains; extractStructCache uses a bare
// new(cTag) for fields that have no tag at all.
type cTag struct {
	// tag is the validation name (the text before tagKeySeparator); it is the key
	// used to look up fn in the registered validations.
	tag                  string
	// aliasTag is the tag name recorded for this node; when hasAlias is set it is
	// the name of the alias that was expanded.
	aliasTag             string
	// actualAliasTag is set only for validation nodes produced while expanding an
	// alias: the raw element of the alias expansion the node came from.
	actualAliasTag       string
	// param is the text after tagKeySeparator, with utf8HexComma and utf8Pipe
	// replaced by ',' and '|'.
	param                string
	keys                 *cTag // only populated when using tag's 'keys' and 'endkeys' for map key validation
	// next is the following node in the chain, nil at the end.
	next                 *cTag
	// fn is the registered validation function for tag.
	fn                   FuncCtx
	// typeof is the kind of node (see tagType).
	typeof               tagType
	// hasTag is true for parsed nodes and false on the untagged-field placeholder.
	hasTag               bool
	// hasAlias is true when the node came from expanding an alias.
	hasAlias             bool
	hasParam             bool // true if parameter used eg. eq= where the equal sign has been set
	isBlockEnd           bool // indicates the current tag represents the last validation in the block
	// runValidationWhenNil is copied from the registered validation; presumably it
	// means fn runs even for nil values — NOTE(review): confirm in the validator loop.
	runValidationWhenNil bool
}
100
101func (v *Validate) extractStructCache(current reflect.Value, sName string) *cStruct {
102	v.structCache.lock.Lock()
103	defer v.structCache.lock.Unlock() // leave as defer! because if inner panics, it will never get unlocked otherwise!
104
105	typ := current.Type()
106
107	// could have been multiple trying to access, but once first is done this ensures struct
108	// isn't parsed again.
109	cs, ok := v.structCache.Get(typ)
110	if ok {
111		return cs
112	}
113
114	cs = &cStruct{name: sName, fields: make([]*cField, 0), fn: v.structLevelFuncs[typ]}
115
116	numFields := current.NumField()
117
118	var ctag *cTag
119	var fld reflect.StructField
120	var tag string
121	var customName string
122
123	for i := 0; i < numFields; i++ {
124
125		fld = typ.Field(i)
126
127		if !fld.Anonymous && len(fld.PkgPath) > 0 {
128			continue
129		}
130
131		tag = fld.Tag.Get(v.tagName)
132
133		if tag == skipValidationTag {
134			continue
135		}
136
137		customName = fld.Name
138
139		if v.hasTagNameFunc {
140			name := v.tagNameFunc(fld)
141			if len(name) > 0 {
142				customName = name
143			}
144		}
145
146		// NOTE: cannot use shared tag cache, because tags may be equal, but things like alias may be different
147		// and so only struct level caching can be used instead of combined with Field tag caching
148
149		if len(tag) > 0 {
150			ctag, _ = v.parseFieldTagsRecursive(tag, fld.Name, "", false)
151		} else {
152			// even if field doesn't have validations need cTag for traversing to potential inner/nested
153			// elements of the field.
154			ctag = new(cTag)
155		}
156
157		cs.fields = append(cs.fields, &cField{
158			idx:        i,
159			name:       fld.Name,
160			altName:    customName,
161			cTags:      ctag,
162			namesEqual: fld.Name == customName,
163		})
164	}
165	v.structCache.Set(typ, cs)
166	return cs
167}
168
169func (v *Validate) parseFieldTagsRecursive(tag string, fieldName string, alias string, hasAlias bool) (firstCtag *cTag, current *cTag) {
170	var t string
171	noAlias := len(alias) == 0
172	tags := strings.Split(tag, tagSeparator)
173
174	for i := 0; i < len(tags); i++ {
175		t = tags[i]
176		if noAlias {
177			alias = t
178		}
179
180		// check map for alias and process new tags, otherwise process as usual
181		if tagsVal, found := v.aliases[t]; found {
182			if i == 0 {
183				firstCtag, current = v.parseFieldTagsRecursive(tagsVal, fieldName, t, true)
184			} else {
185				next, curr := v.parseFieldTagsRecursive(tagsVal, fieldName, t, true)
186				current.next, current = next, curr
187
188			}
189			continue
190		}
191
192		var prevTag tagType
193
194		if i == 0 {
195			current = &cTag{aliasTag: alias, hasAlias: hasAlias, hasTag: true, typeof: typeDefault}
196			firstCtag = current
197		} else {
198			prevTag = current.typeof
199			current.next = &cTag{aliasTag: alias, hasAlias: hasAlias, hasTag: true}
200			current = current.next
201		}
202
203		switch t {
204		case diveTag:
205			current.typeof = typeDive
206			continue
207
208		case keysTag:
209			current.typeof = typeKeys
210
211			if i == 0 || prevTag != typeDive {
212				panic(fmt.Sprintf("'%s' tag must be immediately preceded by the '%s' tag", keysTag, diveTag))
213			}
214
215			current.typeof = typeKeys
216
217			// need to pass along only keys tag
218			// need to increment i to skip over the keys tags
219			b := make([]byte, 0, 64)
220
221			i++
222
223			for ; i < len(tags); i++ {
224
225				b = append(b, tags[i]...)
226				b = append(b, ',')
227
228				if tags[i] == endKeysTag {
229					break
230				}
231			}
232
233			current.keys, _ = v.parseFieldTagsRecursive(string(b[:len(b)-1]), fieldName, "", false)
234			continue
235
236		case endKeysTag:
237			current.typeof = typeEndKeys
238
239			// if there are more in tags then there was no keysTag defined
240			// and an error should be thrown
241			if i != len(tags)-1 {
242				panic(keysTagNotDefined)
243			}
244			return
245
246		case omitempty:
247			current.typeof = typeOmitEmpty
248			continue
249
250		case structOnlyTag:
251			current.typeof = typeStructOnly
252			continue
253
254		case noStructLevelTag:
255			current.typeof = typeNoStructLevel
256			continue
257
258		default:
259			if t == isdefault {
260				current.typeof = typeIsDefault
261			}
262			// if a pipe character is needed within the param you must use the utf8Pipe representation "0x7C"
263			orVals := strings.Split(t, orSeparator)
264
265			for j := 0; j < len(orVals); j++ {
266				vals := strings.SplitN(orVals[j], tagKeySeparator, 2)
267				if noAlias {
268					alias = vals[0]
269					current.aliasTag = alias
270				} else {
271					current.actualAliasTag = t
272				}
273
274				if j > 0 {
275					current.next = &cTag{aliasTag: alias, actualAliasTag: current.actualAliasTag, hasAlias: hasAlias, hasTag: true}
276					current = current.next
277				}
278				current.hasParam = len(vals) > 1
279
280				current.tag = vals[0]
281				if len(current.tag) == 0 {
282					panic(strings.TrimSpace(fmt.Sprintf(invalidValidation, fieldName)))
283				}
284
285				if wrapper, ok := v.validations[current.tag]; ok {
286					current.fn = wrapper.fn
287					current.runValidationWhenNil = wrapper.runValidatinOnNil
288				} else {
289					panic(strings.TrimSpace(fmt.Sprintf(undefinedValidation, current.tag, fieldName)))
290				}
291
292				if len(orVals) > 1 {
293					current.typeof = typeOr
294				}
295
296				if len(vals) > 1 {
297					current.param = strings.Replace(strings.Replace(vals[1], utf8HexComma, ",", -1), utf8Pipe, "|", -1)
298				}
299			}
300			current.isBlockEnd = true
301		}
302	}
303	return
304}
305
306func (v *Validate) fetchCacheTag(tag string) *cTag {
307	// find cached tag
308	ctag, found := v.tagCache.Get(tag)
309	if !found {
310		v.tagCache.lock.Lock()
311		defer v.tagCache.lock.Unlock()
312
313		// could have been multiple trying to access, but once first is done this ensures tag
314		// isn't parsed again.
315		ctag, found = v.tagCache.Get(tag)
316		if !found {
317			ctag, _ = v.parseFieldTagsRecursive(tag, "", "", false)
318			v.tagCache.Set(tag, ctag)
319		}
320	}
321	return ctag
322}
323