// Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved.
// Use of this source code is governed by a BSD-style license found in the LICENSE file.
3
4package codec
5
// Contains code shared by both encode and decode.
7
8import (
9	"encoding/binary"
10	"fmt"
11	"math"
12	"reflect"
13	"sort"
14	"strings"
15	"sync"
16	"time"
17	"unicode"
18	"unicode/utf8"
19)
20
const (
	// structTagName is the struct tag key read for per-field options
	// (encoded name, "omitempty", "toarray").
	structTagName = "codec"

	// supportBinaryMarshal enables or disables support for
	//    encoding.BinaryMarshaler: MarshalBinary() (data []byte, err error)
	//    encoding.BinaryUnmarshaler: UnmarshalBinary(data []byte) error
	supportBinaryMarshal = true

	// useMapForCodecCache selects the structure backing the function cache.
	//
	// Each Encoder or Decoder uses a cache of functions based on conditionals,
	// so that the conditionals are not run every time.
	// Either a map or a slice is used to keep track of the functions.
	// The map is more natural, but has a higher cost than a slice/array.
	useMapForCodecCache = false

	// shortCircuitReflectToFastPath enables direct encode/decode calls for
	// some common container types, short-circuiting an elaborate reflection dance.
	// The currently supported types are:
	//    - slices of strings, or id's (int64,uint64) or interfaces.
	//    - maps of str->str, str->intf, id(int64,uint64)->intf, intf->intf
	shortCircuitReflectToFastPath = true

	// recoverPanicToErr controls whether panicToErr recovers a panic and
	// reports it as an error.
	// For debugging, set this to false, to catch panic traces.
	// Note that this will always cause rpc tests to fail, since they need io.EOF sent via panic.
	recoverPanicToErr = true

	// checkStructForEmptyValue, if set, makes the empty-value check inspect
	// the fields of a struct. This could be an expensive call, so it is disabled.
	checkStructForEmptyValue = false

	// derefForIsEmptyValue, if set, makes the empty-value check dereference
	// pointers and interfaces.
	derefForIsEmptyValue = false
)
56
// charEncoding identifies a character encoding. The constants below name
// raw (unspecified) bytes and the UTF-8, UTF-16 and UTF-32 encodings in
// little-endian (LE) and big-endian (BE) variants.
type charEncoding uint8

// The numeric values follow declaration order, starting at 0 for c_RAW.
const (
	c_RAW charEncoding = iota
	c_UTF8
	c_UTF16LE
	c_UTF16BE
	c_UTF32LE
	c_UTF32BE
)
67
// valueType is the stream type
type valueType uint8

// The typed constants are numbered in declaration order, starting at 0
// for valueTypeUnset.
const (
	valueTypeUnset valueType = iota
	valueTypeNil
	valueTypeInt
	valueTypeUint
	valueTypeFloat
	valueTypeBool
	valueTypeString
	valueTypeSymbol
	valueTypeBytes
	valueTypeMap
	valueTypeArray
	valueTypeTimestamp
	valueTypeExt

	// valueTypeInvalid is declared without an explicit type, so it is an
	// untyped constant that converts to valueType where one is required.
	valueTypeInvalid = 0xff
)
88
89var (
90	bigen               = binary.BigEndian
91	structInfoFieldName = "_struct"
92
93	cachedTypeInfo      = make(map[uintptr]*typeInfo, 4)
94	cachedTypeInfoMutex sync.RWMutex
95
96	intfSliceTyp = reflect.TypeOf([]interface{}(nil))
97	intfTyp      = intfSliceTyp.Elem()
98
99	strSliceTyp     = reflect.TypeOf([]string(nil))
100	boolSliceTyp    = reflect.TypeOf([]bool(nil))
101	uintSliceTyp    = reflect.TypeOf([]uint(nil))
102	uint8SliceTyp   = reflect.TypeOf([]uint8(nil))
103	uint16SliceTyp  = reflect.TypeOf([]uint16(nil))
104	uint32SliceTyp  = reflect.TypeOf([]uint32(nil))
105	uint64SliceTyp  = reflect.TypeOf([]uint64(nil))
106	intSliceTyp     = reflect.TypeOf([]int(nil))
107	int8SliceTyp    = reflect.TypeOf([]int8(nil))
108	int16SliceTyp   = reflect.TypeOf([]int16(nil))
109	int32SliceTyp   = reflect.TypeOf([]int32(nil))
110	int64SliceTyp   = reflect.TypeOf([]int64(nil))
111	float32SliceTyp = reflect.TypeOf([]float32(nil))
112	float64SliceTyp = reflect.TypeOf([]float64(nil))
113
114	mapIntfIntfTyp = reflect.TypeOf(map[interface{}]interface{}(nil))
115	mapStrIntfTyp  = reflect.TypeOf(map[string]interface{}(nil))
116	mapStrStrTyp   = reflect.TypeOf(map[string]string(nil))
117
118	mapIntIntfTyp    = reflect.TypeOf(map[int]interface{}(nil))
119	mapInt64IntfTyp  = reflect.TypeOf(map[int64]interface{}(nil))
120	mapUintIntfTyp   = reflect.TypeOf(map[uint]interface{}(nil))
121	mapUint64IntfTyp = reflect.TypeOf(map[uint64]interface{}(nil))
122
123	stringTyp = reflect.TypeOf("")
124	timeTyp   = reflect.TypeOf(time.Time{})
125	rawExtTyp = reflect.TypeOf(RawExt{})
126
127	mapBySliceTyp        = reflect.TypeOf((*MapBySlice)(nil)).Elem()
128	binaryMarshalerTyp   = reflect.TypeOf((*binaryMarshaler)(nil)).Elem()
129	binaryUnmarshalerTyp = reflect.TypeOf((*binaryUnmarshaler)(nil)).Elem()
130
131	rawExtTypId = reflect.ValueOf(rawExtTyp).Pointer()
132	intfTypId   = reflect.ValueOf(intfTyp).Pointer()
133	timeTypId   = reflect.ValueOf(timeTyp).Pointer()
134
135	intfSliceTypId = reflect.ValueOf(intfSliceTyp).Pointer()
136	strSliceTypId  = reflect.ValueOf(strSliceTyp).Pointer()
137
138	boolSliceTypId    = reflect.ValueOf(boolSliceTyp).Pointer()
139	uintSliceTypId    = reflect.ValueOf(uintSliceTyp).Pointer()
140	uint8SliceTypId   = reflect.ValueOf(uint8SliceTyp).Pointer()
141	uint16SliceTypId  = reflect.ValueOf(uint16SliceTyp).Pointer()
142	uint32SliceTypId  = reflect.ValueOf(uint32SliceTyp).Pointer()
143	uint64SliceTypId  = reflect.ValueOf(uint64SliceTyp).Pointer()
144	intSliceTypId     = reflect.ValueOf(intSliceTyp).Pointer()
145	int8SliceTypId    = reflect.ValueOf(int8SliceTyp).Pointer()
146	int16SliceTypId   = reflect.ValueOf(int16SliceTyp).Pointer()
147	int32SliceTypId   = reflect.ValueOf(int32SliceTyp).Pointer()
148	int64SliceTypId   = reflect.ValueOf(int64SliceTyp).Pointer()
149	float32SliceTypId = reflect.ValueOf(float32SliceTyp).Pointer()
150	float64SliceTypId = reflect.ValueOf(float64SliceTyp).Pointer()
151
152	mapStrStrTypId     = reflect.ValueOf(mapStrStrTyp).Pointer()
153	mapIntfIntfTypId   = reflect.ValueOf(mapIntfIntfTyp).Pointer()
154	mapStrIntfTypId    = reflect.ValueOf(mapStrIntfTyp).Pointer()
155	mapIntIntfTypId    = reflect.ValueOf(mapIntIntfTyp).Pointer()
156	mapInt64IntfTypId  = reflect.ValueOf(mapInt64IntfTyp).Pointer()
157	mapUintIntfTypId   = reflect.ValueOf(mapUintIntfTyp).Pointer()
158	mapUint64IntfTypId = reflect.ValueOf(mapUint64IntfTyp).Pointer()
159	// Id = reflect.ValueOf().Pointer()
160	// mapBySliceTypId  = reflect.ValueOf(mapBySliceTyp).Pointer()
161
162	binaryMarshalerTypId   = reflect.ValueOf(binaryMarshalerTyp).Pointer()
163	binaryUnmarshalerTypId = reflect.ValueOf(binaryUnmarshalerTyp).Pointer()
164
165	intBitsize  uint8 = uint8(reflect.TypeOf(int(0)).Bits())
166	uintBitsize uint8 = uint8(reflect.TypeOf(uint(0)).Bits())
167
168	bsAll0x00 = []byte{0, 0, 0, 0, 0, 0, 0, 0}
169	bsAll0xff = []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
170)
171
// binaryUnmarshaler has the same method set as encoding.BinaryUnmarshaler.
// It is declared locally so the interface can be referenced without
// importing the encoding package.
type binaryUnmarshaler interface {
	UnmarshalBinary(data []byte) error
}

// binaryMarshaler has the same method set as encoding.BinaryMarshaler.
// It is declared locally so the interface can be referenced without
// importing the encoding package.
type binaryMarshaler interface {
	MarshalBinary() (data []byte, err error)
}
179
// MapBySlice represents a slice which should be encoded as a map in the stream.
// The slice contains a sequence of key-value pairs.
//
// The MapBySlice method takes no arguments and returns nothing: it only
// marks the type. getTypeInfo records whether a type implements this
// interface in typeInfo.mbs.
type MapBySlice interface {
	MapBySlice()
}
185
186// WARNING: DO NOT USE DIRECTLY. EXPORTED FOR GODOC BENEFIT. WILL BE REMOVED.
187//
188// BasicHandle encapsulates the common options and extension functions.
189type BasicHandle struct {
190	extHandle
191	EncodeOptions
192	DecodeOptions
193}
194
195// Handle is the interface for a specific encoding format.
196//
197// Typically, a Handle is pre-configured before first time use,
198// and not modified while in use. Such a pre-configured Handle
199// is safe for concurrent access.
200type Handle interface {
201	writeExt() bool
202	getBasicHandle() *BasicHandle
203	newEncDriver(w encWriter) encDriver
204	newDecDriver(r decReader) decDriver
205}
206
// RawExt represents raw unprocessed extension data.
type RawExt struct {
	// Tag is the extension tag.
	Tag byte
	// Data is the extension payload, kept as raw bytes.
	Data []byte
}
212
// extTypeTagFn is one extension registration: a type, the tag assigned to
// it, and the functions that encode/decode values of that type.
type extTypeTagFn struct {
	// rtid identifies rt; AddExt sets it to reflect.ValueOf(rt).Pointer().
	rtid uintptr
	// rt is the registered type.
	rt reflect.Type
	// tag is the tag assigned to rt.
	tag byte
	// encFn converts a value of type rt to bytes.
	encFn func(reflect.Value) ([]byte, error)
	// decFn populates a value of type rt from bytes.
	decFn func(reflect.Value, []byte) error
}

// extHandle is the list of registered extensions, in registration order.
// Lookups are linear scans and return the first match.
type extHandle []*extTypeTagFn

// AddExt registers an encode and decode function for a reflect.Type.
// Note that the type must be a named type, and specifically not
// a pointer or Interface. An error is returned if that is not honored
// (including when rt is nil).
//
// Registering a type that is already registered replaces its tag and
// functions in place.
//
// To Deregister an ext, call AddExt with 0 tag, nil encfn and nil decfn.
// Deregistering removes the entry altogether, so that it cannot shadow
// another type registered later under the same tag. Deregistering a type
// which was never registered is a no-op.
func (o *extHandle) AddExt(
	rt reflect.Type,
	tag byte,
	encfn func(reflect.Value) ([]byte, error),
	decfn func(reflect.Value, []byte) error,
) (err error) {
	// o is a pointer, because the slice may need to grow or shrink.
	// o cannot be nil, since it is always embedded in a Handle.
	// If nil, let it panic.

	// Report the type itself: formatting a zero value with %T prints
	// "<nil>" for interface types, which hides the offending type.
	if rt == nil || rt.PkgPath() == "" || rt.Kind() == reflect.Interface {
		return fmt.Errorf("codec.Handle.AddExt: Takes named type, especially not a pointer or interface: %v", rt)
	}

	rtid := reflect.ValueOf(rt).Pointer()
	deregister := encfn == nil && decfn == nil
	for i, v := range *o {
		if v.rtid != rtid {
			continue
		}
		if !deregister {
			v.tag, v.encFn, v.decFn = tag, encfn, decfn
			return nil
		}
		// Build a fresh slice instead of shifting elements in place, so
		// that a copy of the handle sharing the old backing array is not
		// corrupted.
		kept := make(extHandle, 0, len(*o)-1)
		kept = append(kept, (*o)[:i]...)
		kept = append(kept, (*o)[i+1:]...)
		*o = kept
		return nil
	}
	if deregister {
		// nothing registered for this type: nothing to remove.
		return nil
	}

	*o = append(*o, &extTypeTagFn{rtid, rt, tag, encfn, decfn})
	return nil
}

// getExt returns the registration for the type with id rtid,
// or nil if there is none.
func (o extHandle) getExt(rtid uintptr) *extTypeTagFn {
	for _, v := range o {
		if v.rtid == rtid {
			return v
		}
	}
	return nil
}

// getExtForTag returns the first registration using the given tag,
// or nil if there is none.
func (o extHandle) getExtForTag(tag byte) *extTypeTagFn {
	for _, v := range o {
		if v.tag == tag {
			return v
		}
	}
	return nil
}

// getDecodeExtForTag returns a new settable zero value of the type
// registered for tag, together with its decode function.
// If no type is registered for tag, rv is the zero reflect.Value and fn is nil.
func (o extHandle) getDecodeExtForTag(tag byte) (
	rv reflect.Value, fn func(reflect.Value, []byte) error) {
	if x := o.getExtForTag(tag); x != nil {
		// ext is only registered for base
		rv = reflect.New(x.rt).Elem()
		fn = x.decFn
	}
	return
}

// getDecodeExt returns the tag and decode function registered for the type
// with id rtid. If the type is not registered, it returns tag 0 and a nil fn.
func (o extHandle) getDecodeExt(rtid uintptr) (tag byte, fn func(reflect.Value, []byte) error) {
	if x := o.getExt(rtid); x != nil {
		tag = x.tag
		fn = x.decFn
	}
	return
}

// getEncodeExt returns the tag and encode function registered for the type
// with id rtid. If the type is not registered, it returns tag 0 and a nil fn.
func (o extHandle) getEncodeExt(rtid uintptr) (tag byte, fn func(reflect.Value) ([]byte, error)) {
	if x := o.getExt(rtid); x != nil {
		tag = x.tag
		fn = x.encFn
	}
	return
}
303
// structFieldInfo describes how one struct field is encoded/decoded.
type structFieldInfo struct {
	encName string // encode name

	// only one of 'i' or 'is' can be set. If 'i' is -1, then 'is' has been set.

	is        []int // (recursive/embedded) field index in struct
	i         int16 // field index in struct
	omitEmpty bool
	toArray   bool // if field is _struct, is the toArray set?

	// tag       string   // tag
	// name      string   // field name
	// encNameBs []byte   // encoded name as byte stream
	// ikind     int      // kind of the field as an int i.e. int(reflect.Kind)
}

// parseStructFieldInfo builds the structFieldInfo for the field named fname
// from its struct tag value stag.
//
// stag is a comma-separated list. The first element, if not empty, replaces
// fname as the encode name. The remaining elements are options: "omitempty"
// and "toarray" set the matching flags, anything else is ignored.
//
// The index fields (i, is) are left at their zero values for the caller to set.
// It panics if fname is empty.
func parseStructFieldInfo(fname string, stag string) *structFieldInfo {
	if fname == "" {
		panic("parseStructFieldInfo: No Field Name")
	}
	si := &structFieldInfo{encName: fname}
	if stag == "" {
		return si
	}

	// Split on a non-empty string always yields at least one element.
	parts := strings.Split(stag, ",")
	if parts[0] != "" {
		si.encName = parts[0]
	}
	for _, opt := range parts[1:] {
		switch opt {
		case "omitempty":
			si.omitEmpty = true
		case "toarray":
			si.toArray = true
		}
	}
	return si
}
349
350type sfiSortedByEncName []*structFieldInfo
351
352func (p sfiSortedByEncName) Len() int {
353	return len(p)
354}
355
356func (p sfiSortedByEncName) Less(i, j int) bool {
357	return p[i].encName < p[j].encName
358}
359
360func (p sfiSortedByEncName) Swap(i, j int) {
361	p[i], p[j] = p[j], p[i]
362}
363
364// typeInfo keeps information about each type referenced in the encode/decode sequence.
365//
366// During an encode/decode sequence, we work as below:
367//   - If base is a built in type, en/decode base value
368//   - If base is registered as an extension, en/decode base value
369//   - If type is binary(M/Unm)arshaler, call Binary(M/Unm)arshal method
370//   - Else decode appropriately based on the reflect.Kind
371type typeInfo struct {
372	sfi  []*structFieldInfo // sorted. Used when enc/dec struct to map.
373	sfip []*structFieldInfo // unsorted. Used when enc/dec struct to array.
374
375	rt   reflect.Type
376	rtid uintptr
377
378	// baseId gives pointer to the base reflect.Type, after deferencing
379	// the pointers. E.g. base type of ***time.Time is time.Time.
380	base      reflect.Type
381	baseId    uintptr
382	baseIndir int8 // number of indirections to get to base
383
384	mbs bool // base type (T or *T) is a MapBySlice
385
386	m        bool // base type (T or *T) is a binaryMarshaler
387	unm      bool // base type (T or *T) is a binaryUnmarshaler
388	mIndir   int8 // number of indirections to get to binaryMarshaler type
389	unmIndir int8 // number of indirections to get to binaryUnmarshaler type
390	toArray  bool // whether this (struct) type should be encoded as an array
391}
392
393func (ti *typeInfo) indexForEncName(name string) int {
394	//tisfi := ti.sfi
395	const binarySearchThreshold = 16
396	if sfilen := len(ti.sfi); sfilen < binarySearchThreshold {
397		// linear search. faster than binary search in my testing up to 16-field structs.
398		for i, si := range ti.sfi {
399			if si.encName == name {
400				return i
401			}
402		}
403	} else {
404		// binary search. adapted from sort/search.go.
405		h, i, j := 0, 0, sfilen
406		for i < j {
407			h = i + (j-i)/2
408			if ti.sfi[h].encName < name {
409				i = h + 1
410			} else {
411				j = h
412			}
413		}
414		if i < sfilen && ti.sfi[i].encName == name {
415			return i
416		}
417	}
418	return -1
419}
420
421func getTypeInfo(rtid uintptr, rt reflect.Type) (pti *typeInfo) {
422	var ok bool
423	cachedTypeInfoMutex.RLock()
424	pti, ok = cachedTypeInfo[rtid]
425	cachedTypeInfoMutex.RUnlock()
426	if ok {
427		return
428	}
429
430	cachedTypeInfoMutex.Lock()
431	defer cachedTypeInfoMutex.Unlock()
432	if pti, ok = cachedTypeInfo[rtid]; ok {
433		return
434	}
435
436	ti := typeInfo{rt: rt, rtid: rtid}
437	pti = &ti
438
439	var indir int8
440	if ok, indir = implementsIntf(rt, binaryMarshalerTyp); ok {
441		ti.m, ti.mIndir = true, indir
442	}
443	if ok, indir = implementsIntf(rt, binaryUnmarshalerTyp); ok {
444		ti.unm, ti.unmIndir = true, indir
445	}
446	if ok, _ = implementsIntf(rt, mapBySliceTyp); ok {
447		ti.mbs = true
448	}
449
450	pt := rt
451	var ptIndir int8
452	// for ; pt.Kind() == reflect.Ptr; pt, ptIndir = pt.Elem(), ptIndir+1 { }
453	for pt.Kind() == reflect.Ptr {
454		pt = pt.Elem()
455		ptIndir++
456	}
457	if ptIndir == 0 {
458		ti.base = rt
459		ti.baseId = rtid
460	} else {
461		ti.base = pt
462		ti.baseId = reflect.ValueOf(pt).Pointer()
463		ti.baseIndir = ptIndir
464	}
465
466	if rt.Kind() == reflect.Struct {
467		var siInfo *structFieldInfo
468		if f, ok := rt.FieldByName(structInfoFieldName); ok {
469			siInfo = parseStructFieldInfo(structInfoFieldName, f.Tag.Get(structTagName))
470			ti.toArray = siInfo.toArray
471		}
472		sfip := make([]*structFieldInfo, 0, rt.NumField())
473		rgetTypeInfo(rt, nil, make(map[string]bool), &sfip, siInfo)
474
475		// // try to put all si close together
476		// const tryToPutAllStructFieldInfoTogether = true
477		// if tryToPutAllStructFieldInfoTogether {
478		// 	sfip2 := make([]structFieldInfo, len(sfip))
479		// 	for i, si := range sfip {
480		// 		sfip2[i] = *si
481		// 	}
482		// 	for i := range sfip {
483		// 		sfip[i] = &sfip2[i]
484		// 	}
485		// }
486
487		ti.sfip = make([]*structFieldInfo, len(sfip))
488		ti.sfi = make([]*structFieldInfo, len(sfip))
489		copy(ti.sfip, sfip)
490		sort.Sort(sfiSortedByEncName(sfip))
491		copy(ti.sfi, sfip)
492	}
493	// sfi = sfip
494	cachedTypeInfo[rtid] = pti
495	return
496}
497
498func rgetTypeInfo(rt reflect.Type, indexstack []int, fnameToHastag map[string]bool,
499	sfi *[]*structFieldInfo, siInfo *structFieldInfo,
500) {
501	// for rt.Kind() == reflect.Ptr {
502	// 	// indexstack = append(indexstack, 0)
503	// 	rt = rt.Elem()
504	// }
505	for j := 0; j < rt.NumField(); j++ {
506		f := rt.Field(j)
507		stag := f.Tag.Get(structTagName)
508		if stag == "-" {
509			continue
510		}
511		if r1, _ := utf8.DecodeRuneInString(f.Name); r1 == utf8.RuneError || !unicode.IsUpper(r1) {
512			continue
513		}
514		// if anonymous and there is no struct tag and its a struct (or pointer to struct), inline it.
515		if f.Anonymous && stag == "" {
516			ft := f.Type
517			for ft.Kind() == reflect.Ptr {
518				ft = ft.Elem()
519			}
520			if ft.Kind() == reflect.Struct {
521				indexstack2 := append(append(make([]int, 0, len(indexstack)+4), indexstack...), j)
522				rgetTypeInfo(ft, indexstack2, fnameToHastag, sfi, siInfo)
523				continue
524			}
525		}
526		// do not let fields with same name in embedded structs override field at higher level.
527		// this must be done after anonymous check, to allow anonymous field
528		// still include their child fields
529		if _, ok := fnameToHastag[f.Name]; ok {
530			continue
531		}
532		si := parseStructFieldInfo(f.Name, stag)
533		// si.ikind = int(f.Type.Kind())
534		if len(indexstack) == 0 {
535			si.i = int16(j)
536		} else {
537			si.i = -1
538			si.is = append(append(make([]int, 0, len(indexstack)+4), indexstack...), j)
539		}
540
541		if siInfo != nil {
542			if siInfo.omitEmpty {
543				si.omitEmpty = true
544			}
545		}
546		*sfi = append(*sfi, si)
547		fnameToHastag[f.Name] = stag != ""
548	}
549}
550
551func panicToErr(err *error) {
552	if recoverPanicToErr {
553		if x := recover(); x != nil {
554			//debug.PrintStack()
555			panicValToErr(x, err)
556		}
557	}
558}
559
// doPanic panics with an error whose message is tag, followed by ": " and
// by format expanded with params (as in fmt.Errorf).
func doPanic(tag string, format string, params ...interface{}) {
	// tag fills the leading %s, so it is never itself parsed as a format.
	args := make([]interface{}, 0, len(params)+1)
	args = append(args, tag)
	args = append(args, params...)
	panic(fmt.Errorf("%s: "+format, args...))
}
566
567func checkOverflowFloat32(f float64, doCheck bool) {
568	if !doCheck {
569		return
570	}
571	// check overflow (logic adapted from std pkg reflect/value.go OverflowFloat()
572	f2 := f
573	if f2 < 0 {
574		f2 = -f
575	}
576	if math.MaxFloat32 < f2 && f2 <= math.MaxFloat64 {
577		decErr("Overflow float32 value: %v", f2)
578	}
579}
580
581func checkOverflow(ui uint64, i int64, bitsize uint8) {
582	// check overflow (logic adapted from std pkg reflect/value.go OverflowUint()
583	if bitsize == 0 {
584		return
585	}
586	if i != 0 {
587		if trunc := (i << (64 - bitsize)) >> (64 - bitsize); i != trunc {
588			decErr("Overflow int value: %v", i)
589		}
590	}
591	if ui != 0 {
592		if trunc := (ui << (64 - bitsize)) >> (64 - bitsize); ui != trunc {
593			decErr("Overflow uint value: %v", ui)
594		}
595	}
596}
597