1// Copyright (c) Faye Amacker. All rights reserved. 2// Licensed under the MIT License. See LICENSE in the project root for license information. 3 4package cbor 5 6import ( 7 "bytes" 8 "errors" 9 "reflect" 10 "sort" 11 "strconv" 12 "strings" 13 "sync" 14) 15 16type encodeFuncs struct { 17 ef encodeFunc 18 ief isEmptyFunc 19} 20 21var ( 22 decodingStructTypeCache sync.Map // map[reflect.Type]*decodingStructType 23 encodingStructTypeCache sync.Map // map[reflect.Type]*encodingStructType 24 encodeFuncCache sync.Map // map[reflect.Type]encodeFuncs 25 typeInfoCache sync.Map // map[reflect.Type]*typeInfo 26) 27 28type specialType int 29 30const ( 31 specialTypeNone specialType = iota 32 specialTypeUnmarshalerIface 33 specialTypeEmptyIface 34 specialTypeIface 35 specialTypeTag 36 specialTypeTime 37) 38 39type typeInfo struct { 40 elemTypeInfo *typeInfo 41 keyTypeInfo *typeInfo 42 typ reflect.Type 43 kind reflect.Kind 44 nonPtrType reflect.Type 45 nonPtrKind reflect.Kind 46 spclType specialType 47} 48 49func newTypeInfo(t reflect.Type) *typeInfo { 50 tInfo := typeInfo{typ: t, kind: t.Kind()} 51 52 for t.Kind() == reflect.Ptr { 53 t = t.Elem() 54 } 55 56 k := t.Kind() 57 58 tInfo.nonPtrType = t 59 tInfo.nonPtrKind = k 60 61 if k == reflect.Interface { 62 if t.NumMethod() == 0 { 63 tInfo.spclType = specialTypeEmptyIface 64 } else { 65 tInfo.spclType = specialTypeIface 66 } 67 } else if t == typeTag { 68 tInfo.spclType = specialTypeTag 69 } else if t == typeTime { 70 tInfo.spclType = specialTypeTime 71 } else if reflect.PtrTo(t).Implements(typeUnmarshaler) { 72 tInfo.spclType = specialTypeUnmarshalerIface 73 } 74 75 switch k { 76 case reflect.Array, reflect.Slice: 77 tInfo.elemTypeInfo = getTypeInfo(t.Elem()) 78 case reflect.Map: 79 tInfo.keyTypeInfo = getTypeInfo(t.Key()) 80 tInfo.elemTypeInfo = getTypeInfo(t.Elem()) 81 } 82 83 return &tInfo 84} 85 86type decodingStructType struct { 87 fields fields 88 err error 89 toArray bool 90} 91 92func getDecodingStructType(t reflect.Type) *decodingStructType { 93 if v, _ := decodingStructTypeCache.Load(t); v != nil { 94 return v.(*decodingStructType) 95 } 96 97 flds, structOptions := getFields(t) 98 99 toArray := hasToArrayOption(structOptions) 100 101 var err error 102 for i := 0; i < len(flds); i++ { 103 if flds[i].keyAsInt { 104 nameAsInt, numErr := strconv.Atoi(flds[i].name) 105 if numErr != nil { 106 err = errors.New("cbor: failed to parse field name \"" + flds[i].name + "\" to int (" + numErr.Error() + ")") 107 break 108 } 109 flds[i].nameAsInt = int64(nameAsInt) 110 } 111 112 flds[i].typInfo = getTypeInfo(flds[i].typ) 113 } 114 115 structType := &decodingStructType{fields: flds, err: err, toArray: toArray} 116 decodingStructTypeCache.Store(t, structType) 117 return structType 118} 119 120type encodingStructType struct { 121 fields fields 122 bytewiseFields fields 123 lengthFirstFields fields 124 omitEmptyFieldsIdx []int 125 err error 126 toArray bool 127 fixedLength bool // Struct type doesn't have any omitempty or anonymous fields. 128} 129 130func (st *encodingStructType) getFields(em *encMode) fields { 131 if em.sort == SortNone { 132 return st.fields 133 } 134 if em.sort == SortLengthFirst { 135 return st.lengthFirstFields 136 } 137 return st.bytewiseFields 138} 139 140type bytewiseFieldSorter struct { 141 fields fields 142} 143 144func (x *bytewiseFieldSorter) Len() int { 145 return len(x.fields) 146} 147 148func (x *bytewiseFieldSorter) Swap(i, j int) { 149 x.fields[i], x.fields[j] = x.fields[j], x.fields[i] 150} 151 152func (x *bytewiseFieldSorter) Less(i, j int) bool { 153 return bytes.Compare(x.fields[i].cborName, x.fields[j].cborName) <= 0 154} 155 156type lengthFirstFieldSorter struct { 157 fields fields 158} 159 160func (x *lengthFirstFieldSorter) Len() int { 161 return len(x.fields) 162} 163 164func (x *lengthFirstFieldSorter) Swap(i, j int) { 165 x.fields[i], x.fields[j] = x.fields[j], x.fields[i] 166} 167 168func (x *lengthFirstFieldSorter) Less(i, j int) bool { 169 if len(x.fields[i].cborName) != len(x.fields[j].cborName) { 170 return len(x.fields[i].cborName) < len(x.fields[j].cborName) 171 } 172 return bytes.Compare(x.fields[i].cborName, x.fields[j].cborName) <= 0 173} 174 175func getEncodingStructType(t reflect.Type) (*encodingStructType, error) { 176 if v, _ := encodingStructTypeCache.Load(t); v != nil { 177 structType := v.(*encodingStructType) 178 return structType, structType.err 179 } 180 181 flds, structOptions := getFields(t) 182 183 if hasToArrayOption(structOptions) { 184 return getEncodingStructToArrayType(t, flds) 185 } 186 187 var err error 188 var hasKeyAsInt bool 189 var hasKeyAsStr bool 190 var omitEmptyIdx []int 191 fixedLength := true 192 e := getEncoderBuffer() 193 for i := 0; i < len(flds); i++ { 194 // Get field's encodeFunc 195 flds[i].ef, flds[i].ief = getEncodeFunc(flds[i].typ) 196 if flds[i].ef == nil { 197 err = &UnsupportedTypeError{t} 198 break 199 } 200 201 // Encode field name 202 if flds[i].keyAsInt { 203 nameAsInt, numErr := strconv.Atoi(flds[i].name) 204 if numErr != nil { 205 err = errors.New("cbor: failed to parse field name \"" + flds[i].name + "\" to int (" + numErr.Error() + ")") 206 break 207 } 208 flds[i].nameAsInt = int64(nameAsInt) 209 if nameAsInt >= 0 { 210 encodeHead(e, byte(cborTypePositiveInt), uint64(nameAsInt)) 211 } else { 212 n := nameAsInt*(-1) - 1 213 encodeHead(e, byte(cborTypeNegativeInt), uint64(n)) 214 } 215 flds[i].cborName = make([]byte, e.Len()) 216 copy(flds[i].cborName, e.Bytes()) 217 e.Reset() 218 219 hasKeyAsInt = true 220 } else { 221 encodeHead(e, byte(cborTypeTextString), uint64(len(flds[i].name))) 222 flds[i].cborName = make([]byte, e.Len()+len(flds[i].name)) 223 n := copy(flds[i].cborName, e.Bytes()) 224 copy(flds[i].cborName[n:], flds[i].name) 225 e.Reset() 226 227 hasKeyAsStr = true 228 } 229 230 // Check if field is from embedded struct 231 if len(flds[i].idx) > 1 { 232 fixedLength = false 233 } 234 235 // Check if field can be omitted when empty 236 if flds[i].omitEmpty { 237 fixedLength = false 238 omitEmptyIdx = append(omitEmptyIdx, i) 239 } 240 } 241 putEncoderBuffer(e) 242 243 if err != nil { 244 structType := &encodingStructType{err: err} 245 encodingStructTypeCache.Store(t, structType) 246 return structType, structType.err 247 } 248 249 // Sort fields by canonical order 250 bytewiseFields := make(fields, len(flds)) 251 copy(bytewiseFields, flds) 252 sort.Sort(&bytewiseFieldSorter{bytewiseFields}) 253 254 lengthFirstFields := bytewiseFields 255 if hasKeyAsInt && hasKeyAsStr { 256 lengthFirstFields = make(fields, len(flds)) 257 copy(lengthFirstFields, flds) 258 sort.Sort(&lengthFirstFieldSorter{lengthFirstFields}) 259 } 260 261 structType := &encodingStructType{ 262 fields: flds, 263 bytewiseFields: bytewiseFields, 264 lengthFirstFields: lengthFirstFields, 265 omitEmptyFieldsIdx: omitEmptyIdx, 266 fixedLength: fixedLength, 267 } 268 encodingStructTypeCache.Store(t, structType) 269 return structType, structType.err 270} 271 272func getEncodingStructToArrayType(t reflect.Type, flds fields) (*encodingStructType, error) { 273 for i := 0; i < len(flds); i++ { 274 // Get field's encodeFunc 275 flds[i].ef, flds[i].ief = getEncodeFunc(flds[i].typ) 276 if flds[i].ef == nil { 277 structType := &encodingStructType{err: &UnsupportedTypeError{t}} 278 encodingStructTypeCache.Store(t, structType) 279 return structType, structType.err 280 } 281 } 282 283 structType := &encodingStructType{ 284 fields: flds, 285 toArray: true, 286 fixedLength: true, 287 } 288 encodingStructTypeCache.Store(t, structType) 289 return structType, structType.err 290} 291 292func getEncodeFunc(t reflect.Type) (encodeFunc, isEmptyFunc) { 293 if v, _ := encodeFuncCache.Load(t); v != nil { 294 fs := v.(encodeFuncs) 295 return fs.ef, fs.ief 296 } 297 ef, ief := getEncodeFuncInternal(t) 298 encodeFuncCache.Store(t, encodeFuncs{ef, ief}) 299 return ef, ief 300} 301 302func getTypeInfo(t reflect.Type) *typeInfo { 303 if v, _ := typeInfoCache.Load(t); v != nil { 304 return v.(*typeInfo) 305 } 306 tInfo := newTypeInfo(t) 307 typeInfoCache.Store(t, tInfo) 308 return tInfo 309} 310 311func hasToArrayOption(tag string) bool { 312 s := ",toarray" 313 idx := strings.Index(tag, s) 314 return idx >= 0 && (len(tag) == idx+len(s) || tag[idx+len(s)] == ',') 315} 316