1// Copyright (c) Faye Amacker. All rights reserved.
2// Licensed under the MIT License. See LICENSE in the project root for license information.
3
4package cbor
5
6import (
7	"bytes"
8	"errors"
9	"reflect"
10	"sort"
11	"strconv"
12	"strings"
13	"sync"
14)
15
// encodeFuncs bundles the two functions resolved for a type by getEncodeFunc
// so that they can be stored as a single value in encodeFuncCache.
type encodeFuncs struct {
	// ef is the first result of getEncodeFuncInternal for the type.
	ef  encodeFunc
	// ief is the second result of getEncodeFuncInternal for the type.
	ief isEmptyFunc
}
20
// Package-level caches keyed by reflect.Type. Each getter in this file first
// tries Load and, on a miss, computes the value and Stores it.
var (
	decodingStructTypeCache sync.Map // map[reflect.Type]*decodingStructType
	encodingStructTypeCache sync.Map // map[reflect.Type]*encodingStructType
	encodeFuncCache         sync.Map // map[reflect.Type]encodeFuncs
	typeInfoCache           sync.Map // map[reflect.Type]*typeInfo
)
27
// specialType is the classification that newTypeInfo records in
// typeInfo.spclType. It is derived from the type reached after stripping all
// pointer indirections.
type specialType int

const (
	// specialTypeNone is the zero value: none of the classifications below matched.
	specialTypeNone specialType = iota
	// specialTypeUnmarshalerIface: a pointer to the type implements typeUnmarshaler.
	specialTypeUnmarshalerIface
	// specialTypeEmptyIface: an interface type with no methods.
	specialTypeEmptyIface
	// specialTypeIface: an interface type with at least one method.
	specialTypeIface
	// specialTypeTag: the type is typeTag.
	specialTypeTag
	// specialTypeTime: the type is typeTime.
	specialTypeTime
)
38
// typeInfo holds reflection facts about a type. It is populated by
// newTypeInfo and cached per reflect.Type by getTypeInfo.
type typeInfo struct {
	// elemTypeInfo describes the element type when nonPtrKind is Array, Slice,
	// or Map; it is nil otherwise.
	elemTypeInfo *typeInfo
	// keyTypeInfo describes the key type when nonPtrKind is Map; it is nil
	// otherwise.
	keyTypeInfo  *typeInfo
	// typ and kind are the type exactly as given, and its Kind.
	typ          reflect.Type
	kind         reflect.Kind
	// nonPtrType and nonPtrKind are the type, and its Kind, reached after
	// stripping every level of pointer indirection from typ.
	nonPtrType   reflect.Type
	nonPtrKind   reflect.Kind
	// spclType is the classification of nonPtrType (see specialType).
	spclType     specialType
}
48
49func newTypeInfo(t reflect.Type) *typeInfo {
50	tInfo := typeInfo{typ: t, kind: t.Kind()}
51
52	for t.Kind() == reflect.Ptr {
53		t = t.Elem()
54	}
55
56	k := t.Kind()
57
58	tInfo.nonPtrType = t
59	tInfo.nonPtrKind = k
60
61	if k == reflect.Interface {
62		if t.NumMethod() == 0 {
63			tInfo.spclType = specialTypeEmptyIface
64		} else {
65			tInfo.spclType = specialTypeIface
66		}
67	} else if t == typeTag {
68		tInfo.spclType = specialTypeTag
69	} else if t == typeTime {
70		tInfo.spclType = specialTypeTime
71	} else if reflect.PtrTo(t).Implements(typeUnmarshaler) {
72		tInfo.spclType = specialTypeUnmarshalerIface
73	}
74
75	switch k {
76	case reflect.Array, reflect.Slice:
77		tInfo.elemTypeInfo = getTypeInfo(t.Elem())
78	case reflect.Map:
79		tInfo.keyTypeInfo = getTypeInfo(t.Key())
80		tInfo.elemTypeInfo = getTypeInfo(t.Elem())
81	}
82
83	return &tInfo
84}
85
// decodingStructType is the per-struct-type decoding metadata built and cached
// by getDecodingStructType.
type decodingStructType struct {
	// fields is the result of getFields for the struct type, with typInfo (and
	// nameAsInt for keyasint fields) filled in.
	fields  fields
	// err is non-nil when a keyasint field name could not be parsed as an int.
	err     error
	// toArray reports whether the struct options contain "toarray".
	toArray bool
}
91
92func getDecodingStructType(t reflect.Type) *decodingStructType {
93	if v, _ := decodingStructTypeCache.Load(t); v != nil {
94		return v.(*decodingStructType)
95	}
96
97	flds, structOptions := getFields(t)
98
99	toArray := hasToArrayOption(structOptions)
100
101	var err error
102	for i := 0; i < len(flds); i++ {
103		if flds[i].keyAsInt {
104			nameAsInt, numErr := strconv.Atoi(flds[i].name)
105			if numErr != nil {
106				err = errors.New("cbor: failed to parse field name \"" + flds[i].name + "\" to int (" + numErr.Error() + ")")
107				break
108			}
109			flds[i].nameAsInt = int64(nameAsInt)
110		}
111
112		flds[i].typInfo = getTypeInfo(flds[i].typ)
113	}
114
115	structType := &decodingStructType{fields: flds, err: err, toArray: toArray}
116	decodingStructTypeCache.Store(t, structType)
117	return structType
118}
119
// encodingStructType is the per-struct-type encoding metadata built and cached
// by getEncodingStructType and getEncodingStructToArrayType.
type encodingStructType struct {
	// fields holds the fields in the order returned by getFields.
	fields             fields
	// bytewiseFields is a copy of fields sorted with bytewiseFieldSorter.
	// Not set for toarray structs or when err is non-nil.
	bytewiseFields     fields
	// lengthFirstFields is sorted with lengthFirstFieldSorter when the struct
	// has both integer and string keys; otherwise it is the same slice as
	// bytewiseFields.
	lengthFirstFields  fields
	// omitEmptyFieldsIdx lists the indexes into fields of omitempty fields.
	omitEmptyFieldsIdx []int
	// err is the error hit while building this metadata, if any; when set,
	// the other members are left at their zero values.
	err                error
	// toArray is true for structs carrying the "toarray" option.
	toArray            bool
	fixedLength        bool // Struct type doesn't have any omitempty or anonymous fields.
}
129
130func (st *encodingStructType) getFields(em *encMode) fields {
131	if em.sort == SortNone {
132		return st.fields
133	}
134	if em.sort == SortLengthFirst {
135		return st.lengthFirstFields
136	}
137	return st.bytewiseFields
138}
139
140type bytewiseFieldSorter struct {
141	fields fields
142}
143
144func (x *bytewiseFieldSorter) Len() int {
145	return len(x.fields)
146}
147
148func (x *bytewiseFieldSorter) Swap(i, j int) {
149	x.fields[i], x.fields[j] = x.fields[j], x.fields[i]
150}
151
152func (x *bytewiseFieldSorter) Less(i, j int) bool {
153	return bytes.Compare(x.fields[i].cborName, x.fields[j].cborName) <= 0
154}
155
156type lengthFirstFieldSorter struct {
157	fields fields
158}
159
160func (x *lengthFirstFieldSorter) Len() int {
161	return len(x.fields)
162}
163
164func (x *lengthFirstFieldSorter) Swap(i, j int) {
165	x.fields[i], x.fields[j] = x.fields[j], x.fields[i]
166}
167
168func (x *lengthFirstFieldSorter) Less(i, j int) bool {
169	if len(x.fields[i].cborName) != len(x.fields[j].cborName) {
170		return len(x.fields[i].cborName) < len(x.fields[j].cborName)
171	}
172	return bytes.Compare(x.fields[i].cborName, x.fields[j].cborName) <= 0
173}
174
175func getEncodingStructType(t reflect.Type) (*encodingStructType, error) {
176	if v, _ := encodingStructTypeCache.Load(t); v != nil {
177		structType := v.(*encodingStructType)
178		return structType, structType.err
179	}
180
181	flds, structOptions := getFields(t)
182
183	if hasToArrayOption(structOptions) {
184		return getEncodingStructToArrayType(t, flds)
185	}
186
187	var err error
188	var hasKeyAsInt bool
189	var hasKeyAsStr bool
190	var omitEmptyIdx []int
191	fixedLength := true
192	e := getEncoderBuffer()
193	for i := 0; i < len(flds); i++ {
194		// Get field's encodeFunc
195		flds[i].ef, flds[i].ief = getEncodeFunc(flds[i].typ)
196		if flds[i].ef == nil {
197			err = &UnsupportedTypeError{t}
198			break
199		}
200
201		// Encode field name
202		if flds[i].keyAsInt {
203			nameAsInt, numErr := strconv.Atoi(flds[i].name)
204			if numErr != nil {
205				err = errors.New("cbor: failed to parse field name \"" + flds[i].name + "\" to int (" + numErr.Error() + ")")
206				break
207			}
208			flds[i].nameAsInt = int64(nameAsInt)
209			if nameAsInt >= 0 {
210				encodeHead(e, byte(cborTypePositiveInt), uint64(nameAsInt))
211			} else {
212				n := nameAsInt*(-1) - 1
213				encodeHead(e, byte(cborTypeNegativeInt), uint64(n))
214			}
215			flds[i].cborName = make([]byte, e.Len())
216			copy(flds[i].cborName, e.Bytes())
217			e.Reset()
218
219			hasKeyAsInt = true
220		} else {
221			encodeHead(e, byte(cborTypeTextString), uint64(len(flds[i].name)))
222			flds[i].cborName = make([]byte, e.Len()+len(flds[i].name))
223			n := copy(flds[i].cborName, e.Bytes())
224			copy(flds[i].cborName[n:], flds[i].name)
225			e.Reset()
226
227			hasKeyAsStr = true
228		}
229
230		// Check if field is from embedded struct
231		if len(flds[i].idx) > 1 {
232			fixedLength = false
233		}
234
235		// Check if field can be omitted when empty
236		if flds[i].omitEmpty {
237			fixedLength = false
238			omitEmptyIdx = append(omitEmptyIdx, i)
239		}
240	}
241	putEncoderBuffer(e)
242
243	if err != nil {
244		structType := &encodingStructType{err: err}
245		encodingStructTypeCache.Store(t, structType)
246		return structType, structType.err
247	}
248
249	// Sort fields by canonical order
250	bytewiseFields := make(fields, len(flds))
251	copy(bytewiseFields, flds)
252	sort.Sort(&bytewiseFieldSorter{bytewiseFields})
253
254	lengthFirstFields := bytewiseFields
255	if hasKeyAsInt && hasKeyAsStr {
256		lengthFirstFields = make(fields, len(flds))
257		copy(lengthFirstFields, flds)
258		sort.Sort(&lengthFirstFieldSorter{lengthFirstFields})
259	}
260
261	structType := &encodingStructType{
262		fields:             flds,
263		bytewiseFields:     bytewiseFields,
264		lengthFirstFields:  lengthFirstFields,
265		omitEmptyFieldsIdx: omitEmptyIdx,
266		fixedLength:        fixedLength,
267	}
268	encodingStructTypeCache.Store(t, structType)
269	return structType, structType.err
270}
271
272func getEncodingStructToArrayType(t reflect.Type, flds fields) (*encodingStructType, error) {
273	for i := 0; i < len(flds); i++ {
274		// Get field's encodeFunc
275		flds[i].ef, flds[i].ief = getEncodeFunc(flds[i].typ)
276		if flds[i].ef == nil {
277			structType := &encodingStructType{err: &UnsupportedTypeError{t}}
278			encodingStructTypeCache.Store(t, structType)
279			return structType, structType.err
280		}
281	}
282
283	structType := &encodingStructType{
284		fields:      flds,
285		toArray:     true,
286		fixedLength: true,
287	}
288	encodingStructTypeCache.Store(t, structType)
289	return structType, structType.err
290}
291
292func getEncodeFunc(t reflect.Type) (encodeFunc, isEmptyFunc) {
293	if v, _ := encodeFuncCache.Load(t); v != nil {
294		fs := v.(encodeFuncs)
295		return fs.ef, fs.ief
296	}
297	ef, ief := getEncodeFuncInternal(t)
298	encodeFuncCache.Store(t, encodeFuncs{ef, ief})
299	return ef, ief
300}
301
302func getTypeInfo(t reflect.Type) *typeInfo {
303	if v, _ := typeInfoCache.Load(t); v != nil {
304		return v.(*typeInfo)
305	}
306	tInfo := newTypeInfo(t)
307	typeInfoCache.Store(t, tInfo)
308	return tInfo
309}
310
// hasToArrayOption reports whether tag contains the option "toarray", i.e.
// whether ",toarray" occurs in tag and is followed either by the end of the
// string or by another ','. Text before the first comma is never treated as
// an option, so the bare tag "toarray" does not match.
//
// Every occurrence of ",toarray" is examined, so a longer option sharing the
// prefix (e.g. ",toarrayx,toarray") does not hide a later exact match.
func hasToArrayOption(tag string) bool {
	const opt = ",toarray"
	for {
		idx := strings.Index(tag, opt)
		if idx < 0 {
			return false
		}
		// Continue with whatever follows this occurrence.
		tag = tag[idx+len(opt):]
		if tag == "" || tag[0] == ',' {
			return true
		}
	}
}
316