1package msgpack
2
3import (
4	"bufio"
5	"bytes"
6	"errors"
7	"fmt"
8	"io"
9	"reflect"
10	"sync"
11	"time"
12
13	"github.com/vmihailenco/msgpack/v4/codes"
14)
15
// Bit flags stored in Decoder.flags. looseIfaceFlag makes decodeInterfaceCond
// use DecodeInterfaceLoose, decodeUsingJSONFlag enables the json-tag fallback
// (UseJSONTag), and disallowUnknownFieldsFlag is set by DisallowUnknownFields.
const (
	looseIfaceFlag            uint32 = 1 << 0
	decodeUsingJSONFlag       uint32 = 1 << 1
	disallowUnknownFieldsFlag uint32 = 1 << 2
)
21
// Limits that bound allocations driven by lengths read from the input.
const (
	// bytesAllocLimit caps the initial capacity readN allocates for a fresh
	// buffer and the size of each growth step when n exceeds it.
	bytesAllocLimit = 1e6 // 1mb
	// sliceAllocLimit and maxMapSize presumably bound preallocation/size of
	// decoded slices and maps — TODO confirm at their use sites.
	sliceAllocLimit = 1e4
	maxMapSize      = 1e6
)
27
// bufReader is the reader shape the Decoder can use directly, without adding
// its own bufio layer: bulk reads plus single-byte read and unread.
type bufReader interface {
	io.Reader
	io.ByteReader
	UnreadByte() error
}
32
33//------------------------------------------------------------------------------
34
35var decPool = sync.Pool{
36	New: func() interface{} {
37		return NewDecoder(nil)
38	},
39}
40
41// Unmarshal decodes the MessagePack-encoded data and stores the result
42// in the value pointed to by v.
43func Unmarshal(data []byte, v interface{}) error {
44	dec := decPool.Get().(*Decoder)
45
46	if r, ok := dec.r.(*bytes.Reader); ok {
47		r.Reset(data)
48	} else {
49		dec.Reset(bytes.NewReader(data))
50	}
51	err := dec.Decode(v)
52
53	decPool.Put(dec)
54
55	return err
56}
57
// A Decoder reads and decodes MessagePack values from an input stream.
type Decoder struct {
	// r is used for bulk reads (readFull, readN); s is used for single-byte
	// reads and unreads. Reset assigns both from the same reader.
	r   io.Reader
	s   io.ByteScanner
	// buf is the scratch buffer reused by readN; slices returned from it are
	// only valid until the next such read.
	buf []byte

	// extLen is zeroed by every readCode; presumably it tracks the length of
	// the extension value being decoded — TODO confirm.
	extLen int
	rec    []byte // accumulates read data if not nil

	// intern is truncated by Reset; presumably a positional table of
	// interned strings — TODO confirm.
	intern        []string
	// flags is a bit set of looseIfaceFlag, decodeUsingJSONFlag and
	// disallowUnknownFieldsFlag.
	flags         uint32
	// decodeMapFunc is set by SetDecodeMapFunc; presumably consulted when
	// decoding maps — TODO confirm.
	decodeMapFunc func(*Decoder) (interface{}, error)
}
71
72// NewDecoder returns a new decoder that reads from r.
73//
74// The decoder introduces its own buffering and may read data from r
75// beyond the MessagePack values requested. Buffering can be disabled
76// by passing a reader that implements io.ByteScanner interface.
77func NewDecoder(r io.Reader) *Decoder {
78	d := new(Decoder)
79	d.Reset(r)
80	return d
81}
82
83// Reset discards any buffered data, resets all state, and switches the buffered
84// reader to read from r.
85func (d *Decoder) Reset(r io.Reader) {
86	if br, ok := r.(bufReader); ok {
87		d.r = br
88		d.s = br
89	} else if br, ok := d.r.(*bufio.Reader); ok {
90		br.Reset(r)
91	} else {
92		br := bufio.NewReader(r)
93		d.r = br
94		d.s = br
95	}
96
97	if d.intern != nil {
98		d.intern = d.intern[:0]
99	}
100
101	//TODO:
102	//d.useLoose = false
103	//d.useJSONTag = false
104	//d.disallowUnknownFields = false
105	//d.decodeMapFunc = nil
106}
107
108func (d *Decoder) SetDecodeMapFunc(fn func(*Decoder) (interface{}, error)) {
109	d.decodeMapFunc = fn
110}
111
112// UseDecodeInterfaceLoose causes decoder to use DecodeInterfaceLoose
113// to decode msgpack value into Go interface{}.
114func (d *Decoder) UseDecodeInterfaceLoose(on bool) *Decoder {
115	if on {
116		d.flags |= looseIfaceFlag
117	} else {
118		d.flags &= ^looseIfaceFlag
119	}
120	return d
121}
122
123// UseJSONTag causes the Decoder to use json struct tag as fallback option
124// if there is no msgpack tag.
125func (d *Decoder) UseJSONTag(on bool) *Decoder {
126	if on {
127		d.flags |= decodeUsingJSONFlag
128	} else {
129		d.flags &= ^decodeUsingJSONFlag
130	}
131	return d
132}
133
134// DisallowUnknownFields causes the Decoder to return an error when the destination
135// is a struct and the input contains object keys which do not match any
136// non-ignored, exported fields in the destination.
137func (d *Decoder) DisallowUnknownFields() {
138	if true {
139		d.flags |= disallowUnknownFieldsFlag
140	} else {
141		d.flags &= ^disallowUnknownFieldsFlag
142	}
143}
144
145// Buffered returns a reader of the data remaining in the Decoder's buffer.
146// The reader is valid until the next call to Decode.
147func (d *Decoder) Buffered() io.Reader {
148	return d.r
149}
150
151//nolint:gocyclo
152func (d *Decoder) Decode(v interface{}) error {
153	var err error
154	switch v := v.(type) {
155	case *string:
156		if v != nil {
157			*v, err = d.DecodeString()
158			return err
159		}
160	case *[]byte:
161		if v != nil {
162			return d.decodeBytesPtr(v)
163		}
164	case *int:
165		if v != nil {
166			*v, err = d.DecodeInt()
167			return err
168		}
169	case *int8:
170		if v != nil {
171			*v, err = d.DecodeInt8()
172			return err
173		}
174	case *int16:
175		if v != nil {
176			*v, err = d.DecodeInt16()
177			return err
178		}
179	case *int32:
180		if v != nil {
181			*v, err = d.DecodeInt32()
182			return err
183		}
184	case *int64:
185		if v != nil {
186			*v, err = d.DecodeInt64()
187			return err
188		}
189	case *uint:
190		if v != nil {
191			*v, err = d.DecodeUint()
192			return err
193		}
194	case *uint8:
195		if v != nil {
196			*v, err = d.DecodeUint8()
197			return err
198		}
199	case *uint16:
200		if v != nil {
201			*v, err = d.DecodeUint16()
202			return err
203		}
204	case *uint32:
205		if v != nil {
206			*v, err = d.DecodeUint32()
207			return err
208		}
209	case *uint64:
210		if v != nil {
211			*v, err = d.DecodeUint64()
212			return err
213		}
214	case *bool:
215		if v != nil {
216			*v, err = d.DecodeBool()
217			return err
218		}
219	case *float32:
220		if v != nil {
221			*v, err = d.DecodeFloat32()
222			return err
223		}
224	case *float64:
225		if v != nil {
226			*v, err = d.DecodeFloat64()
227			return err
228		}
229	case *[]string:
230		return d.decodeStringSlicePtr(v)
231	case *map[string]string:
232		return d.decodeMapStringStringPtr(v)
233	case *map[string]interface{}:
234		return d.decodeMapStringInterfacePtr(v)
235	case *time.Duration:
236		if v != nil {
237			vv, err := d.DecodeInt64()
238			*v = time.Duration(vv)
239			return err
240		}
241	case *time.Time:
242		if v != nil {
243			*v, err = d.DecodeTime()
244			return err
245		}
246	}
247
248	vv := reflect.ValueOf(v)
249	if !vv.IsValid() {
250		return errors.New("msgpack: Decode(nil)")
251	}
252	if vv.Kind() != reflect.Ptr {
253		return fmt.Errorf("msgpack: Decode(nonsettable %T)", v)
254	}
255	vv = vv.Elem()
256	if !vv.IsValid() {
257		return fmt.Errorf("msgpack: Decode(nonsettable %T)", v)
258	}
259	return d.DecodeValue(vv)
260}
261
262func (d *Decoder) DecodeMulti(v ...interface{}) error {
263	for _, vv := range v {
264		if err := d.Decode(vv); err != nil {
265			return err
266		}
267	}
268	return nil
269}
270
271func (d *Decoder) decodeInterfaceCond() (interface{}, error) {
272	if d.flags&looseIfaceFlag != 0 {
273		return d.DecodeInterfaceLoose()
274	}
275	return d.DecodeInterface()
276}
277
278func (d *Decoder) DecodeValue(v reflect.Value) error {
279	decode := getDecoder(v.Type())
280	return decode(d, v)
281}
282
283func (d *Decoder) DecodeNil() error {
284	c, err := d.readCode()
285	if err != nil {
286		return err
287	}
288	if c != codes.Nil {
289		return fmt.Errorf("msgpack: invalid code=%x decoding nil", c)
290	}
291	return nil
292}
293
294func (d *Decoder) decodeNilValue(v reflect.Value) error {
295	err := d.DecodeNil()
296	if v.IsNil() {
297		return err
298	}
299	if v.Kind() == reflect.Ptr {
300		v = v.Elem()
301	}
302	v.Set(reflect.Zero(v.Type()))
303	return err
304}
305
306func (d *Decoder) DecodeBool() (bool, error) {
307	c, err := d.readCode()
308	if err != nil {
309		return false, err
310	}
311	return d.bool(c)
312}
313
314func (d *Decoder) bool(c codes.Code) (bool, error) {
315	if c == codes.False {
316		return false, nil
317	}
318	if c == codes.True {
319		return true, nil
320	}
321	return false, fmt.Errorf("msgpack: invalid code=%x decoding bool", c)
322}
323
324func (d *Decoder) DecodeDuration() (time.Duration, error) {
325	n, err := d.DecodeInt64()
326	if err != nil {
327		return 0, err
328	}
329	return time.Duration(n), nil
330}
331
332// DecodeInterface decodes value into interface. It returns following types:
333//   - nil,
334//   - bool,
335//   - int8, int16, int32, int64,
336//   - uint8, uint16, uint32, uint64,
337//   - float32 and float64,
338//   - string,
339//   - []byte,
340//   - slices of any of the above,
341//   - maps of any of the above.
342//
343// DecodeInterface should be used only when you don't know the type of value
344// you are decoding. For example, if you are decoding number it is better to use
345// DecodeInt64 for negative numbers and DecodeUint64 for positive numbers.
346func (d *Decoder) DecodeInterface() (interface{}, error) {
347	c, err := d.readCode()
348	if err != nil {
349		return nil, err
350	}
351
352	if codes.IsFixedNum(c) {
353		return int8(c), nil
354	}
355	if codes.IsFixedMap(c) {
356		err = d.s.UnreadByte()
357		if err != nil {
358			return nil, err
359		}
360		return d.DecodeMap()
361	}
362	if codes.IsFixedArray(c) {
363		return d.decodeSlice(c)
364	}
365	if codes.IsFixedString(c) {
366		return d.string(c)
367	}
368
369	switch c {
370	case codes.Nil:
371		return nil, nil
372	case codes.False, codes.True:
373		return d.bool(c)
374	case codes.Float:
375		return d.float32(c)
376	case codes.Double:
377		return d.float64(c)
378	case codes.Uint8:
379		return d.uint8()
380	case codes.Uint16:
381		return d.uint16()
382	case codes.Uint32:
383		return d.uint32()
384	case codes.Uint64:
385		return d.uint64()
386	case codes.Int8:
387		return d.int8()
388	case codes.Int16:
389		return d.int16()
390	case codes.Int32:
391		return d.int32()
392	case codes.Int64:
393		return d.int64()
394	case codes.Bin8, codes.Bin16, codes.Bin32:
395		return d.bytes(c, nil)
396	case codes.Str8, codes.Str16, codes.Str32:
397		return d.string(c)
398	case codes.Array16, codes.Array32:
399		return d.decodeSlice(c)
400	case codes.Map16, codes.Map32:
401		err = d.s.UnreadByte()
402		if err != nil {
403			return nil, err
404		}
405		return d.DecodeMap()
406	case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16,
407		codes.Ext8, codes.Ext16, codes.Ext32:
408		return d.extInterface(c)
409	}
410
411	return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c)
412}
413
414// DecodeInterfaceLoose is like DecodeInterface except that:
415//   - int8, int16, and int32 are converted to int64,
416//   - uint8, uint16, and uint32 are converted to uint64,
417//   - float32 is converted to float64.
418func (d *Decoder) DecodeInterfaceLoose() (interface{}, error) {
419	c, err := d.readCode()
420	if err != nil {
421		return nil, err
422	}
423
424	if codes.IsFixedNum(c) {
425		return int64(int8(c)), nil
426	}
427	if codes.IsFixedMap(c) {
428		err = d.s.UnreadByte()
429		if err != nil {
430			return nil, err
431		}
432		return d.DecodeMap()
433	}
434	if codes.IsFixedArray(c) {
435		return d.decodeSlice(c)
436	}
437	if codes.IsFixedString(c) {
438		return d.string(c)
439	}
440
441	switch c {
442	case codes.Nil:
443		return nil, nil
444	case codes.False, codes.True:
445		return d.bool(c)
446	case codes.Float, codes.Double:
447		return d.float64(c)
448	case codes.Uint8, codes.Uint16, codes.Uint32, codes.Uint64:
449		return d.uint(c)
450	case codes.Int8, codes.Int16, codes.Int32, codes.Int64:
451		return d.int(c)
452	case codes.Bin8, codes.Bin16, codes.Bin32:
453		return d.bytes(c, nil)
454	case codes.Str8, codes.Str16, codes.Str32:
455		return d.string(c)
456	case codes.Array16, codes.Array32:
457		return d.decodeSlice(c)
458	case codes.Map16, codes.Map32:
459		err = d.s.UnreadByte()
460		if err != nil {
461			return nil, err
462		}
463		return d.DecodeMap()
464	case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16,
465		codes.Ext8, codes.Ext16, codes.Ext32:
466		return d.extInterface(c)
467	}
468
469	return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c)
470}
471
472// Skip skips next value.
473func (d *Decoder) Skip() error {
474	c, err := d.readCode()
475	if err != nil {
476		return err
477	}
478
479	if codes.IsFixedNum(c) {
480		return nil
481	}
482	if codes.IsFixedMap(c) {
483		return d.skipMap(c)
484	}
485	if codes.IsFixedArray(c) {
486		return d.skipSlice(c)
487	}
488	if codes.IsFixedString(c) {
489		return d.skipBytes(c)
490	}
491
492	switch c {
493	case codes.Nil, codes.False, codes.True:
494		return nil
495	case codes.Uint8, codes.Int8:
496		return d.skipN(1)
497	case codes.Uint16, codes.Int16:
498		return d.skipN(2)
499	case codes.Uint32, codes.Int32, codes.Float:
500		return d.skipN(4)
501	case codes.Uint64, codes.Int64, codes.Double:
502		return d.skipN(8)
503	case codes.Bin8, codes.Bin16, codes.Bin32:
504		return d.skipBytes(c)
505	case codes.Str8, codes.Str16, codes.Str32:
506		return d.skipBytes(c)
507	case codes.Array16, codes.Array32:
508		return d.skipSlice(c)
509	case codes.Map16, codes.Map32:
510		return d.skipMap(c)
511	case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16,
512		codes.Ext8, codes.Ext16, codes.Ext32:
513		return d.skipExt(c)
514	}
515
516	return fmt.Errorf("msgpack: unknown code %x", c)
517}
518
519// PeekCode returns the next MessagePack code without advancing the reader.
520// Subpackage msgpack/codes contains list of available codes.
521func (d *Decoder) PeekCode() (codes.Code, error) {
522	c, err := d.s.ReadByte()
523	if err != nil {
524		return 0, err
525	}
526	return codes.Code(c), d.s.UnreadByte()
527}
528
529func (d *Decoder) hasNilCode() bool {
530	code, err := d.PeekCode()
531	return err == nil && code == codes.Nil
532}
533
534func (d *Decoder) readCode() (codes.Code, error) {
535	d.extLen = 0
536	c, err := d.s.ReadByte()
537	if err != nil {
538		return 0, err
539	}
540	if d.rec != nil {
541		d.rec = append(d.rec, c)
542	}
543	return codes.Code(c), nil
544}
545
546func (d *Decoder) readFull(b []byte) error {
547	_, err := io.ReadFull(d.r, b)
548	if err != nil {
549		return err
550	}
551	if d.rec != nil {
552		//TODO: read directly into d.rec?
553		d.rec = append(d.rec, b...)
554	}
555	return nil
556}
557
558func (d *Decoder) readN(n int) ([]byte, error) {
559	var err error
560	d.buf, err = readN(d.r, d.buf, n)
561	if err != nil {
562		return nil, err
563	}
564	if d.rec != nil {
565		//TODO: read directly into d.rec?
566		d.rec = append(d.rec, d.buf...)
567	}
568	return d.buf, nil
569}
570
571func readN(r io.Reader, b []byte, n int) ([]byte, error) {
572	if b == nil {
573		if n == 0 {
574			return make([]byte, 0), nil
575		}
576		switch {
577		case n < 64:
578			b = make([]byte, 0, 64)
579		case n <= bytesAllocLimit:
580			b = make([]byte, 0, n)
581		default:
582			b = make([]byte, 0, bytesAllocLimit)
583		}
584	}
585
586	if n <= cap(b) {
587		b = b[:n]
588		_, err := io.ReadFull(r, b)
589		return b, err
590	}
591	b = b[:cap(b)]
592
593	var pos int
594	for {
595		alloc := min(n-len(b), bytesAllocLimit)
596		b = append(b, make([]byte, alloc)...)
597
598		_, err := io.ReadFull(r, b[pos:])
599		if err != nil {
600			return b, err
601		}
602
603		if len(b) == n {
604			break
605		}
606		pos = len(b)
607	}
608
609	return b, nil
610}
611
// min returns the smaller of a and b.
func min(a, b int) int { //nolint:unparam
	if b < a {
		return b
	}
	return a
}
618