1package msgp
2
3import (
4	"fmt"
5	"math"
6)
7
// Extension type numbers used by msgp's own built-in encodings.
// RegisterExtension panics if asked to register any of these.
const (
	// Complex64Extension is the extension number used for complex64
	Complex64Extension = 3

	// Complex128Extension is the extension number used for complex128
	Complex128Extension = 4

	// TimeExtension is the extension number used for time.Time
	TimeExtension = 5
)

// extensionReg maps a user extension type number to a constructor that
// returns a fresh Extension value of that type. It is filled in by
// RegisterExtension. The map is written without any locking, so it
// should only be modified during program initialization.
var extensionReg = make(map[int8]func() Extension)
21
22// RegisterExtension registers extensions so that they
23// can be initialized and returned by methods that
24// decode `interface{}` values. This should only
25// be called during initialization. f() should return
26// a newly-initialized zero value of the extension. Keep in
27// mind that extensions 3, 4, and 5 are reserved for
28// complex64, complex128, and time.Time, respectively,
29// and that MessagePack reserves extension types from -127 to -1.
30//
31// For example, if you wanted to register a user-defined struct:
32//
33//  msgp.RegisterExtension(10, func() msgp.Extension { &MyExtension{} })
34//
35// RegisterExtension will panic if you call it multiple times
36// with the same 'typ' argument, or if you use a reserved
37// type (3, 4, or 5).
38func RegisterExtension(typ int8, f func() Extension) {
39	switch typ {
40	case Complex64Extension, Complex128Extension, TimeExtension:
41		panic(fmt.Sprint("msgp: forbidden extension type:", typ))
42	}
43	if _, ok := extensionReg[typ]; ok {
44		panic(fmt.Sprint("msgp: RegisterExtension() called with typ", typ, "more than once"))
45	}
46	extensionReg[typ] = f
47}
48
// ExtensionTypeError reports that an extension found in the encoded
// data carries a different type number than the destination Extension
// expects.
type ExtensionTypeError struct {
	Got  int8 // type number found in the encoded data
	Want int8 // type number reported by the destination's ExtensionType()
}

// Error implements the error interface.
func (e ExtensionTypeError) Error() string {
	const format = "msgp: error decoding extension: wanted type %d; got type %d"
	return fmt.Sprintf(format, e.Want, e.Got)
}

// Resumable always reports true for an ExtensionTypeError.
// NOTE(review): presumably this tells callers that decoding may continue;
// the read paths in this file report the mismatch before consuming the
// extension's payload.
func (e ExtensionTypeError) Resumable() bool {
	return true
}

// errExt builds the error returned when the wire type (got) differs
// from the expected extension type (wanted).
func errExt(got, wanted int8) error {
	var te ExtensionTypeError
	te.Want = wanted
	te.Got = got
	return te
}
68
// Extension is the interface fulfilled
// by types that want to define their
// own binary encoding.
//
// The writers (WriteExtension, AppendExtension) use ExtensionType and
// Len to build the MessagePack ext header and MarshalBinaryTo to produce
// the payload. The readers (ReadExtension, ReadExtensionBytes) compare
// the wire type against ExtensionType and hand the payload to
// UnmarshalBinary.
type Extension interface {
	// ExtensionType should return
	// an int8 that identifies the concrete
	// type of the extension. (Types <0 are
	// officially reserved by the MessagePack
	// specifications.)
	ExtensionType() int8

	// Len should return the length
	// of the data to be encoded
	Len() int

	// MarshalBinaryTo should copy
	// the data into the supplied slice,
	// assuming that the slice has length Len()
	MarshalBinaryTo([]byte) error

	// UnmarshalBinary receives the extension payload, without the
	// MessagePack ext header, and should decode it into the receiver.
	// NOTE(review): the readers pass a slice that aliases either the
	// caller's input or the stream reader's peeked buffer, so
	// implementations should copy any bytes they keep (as RawExtension
	// does) rather than retaining the slice. Confirm against the
	// reader's buffer-reuse rules.
	UnmarshalBinary([]byte) error
}
91
// RawExtension is an Extension that keeps its payload as uninterpreted
// bytes alongside its extension type number.
type RawExtension struct {
	Data []byte
	Type int8
}

// ExtensionType implements Extension.ExtensionType by reporting r.Type.
func (r *RawExtension) ExtensionType() int8 {
	return r.Type
}

// Len implements Extension.Len by reporting the payload size, len(r.Data).
func (r *RawExtension) Len() int {
	return len(r.Data)
}

// MarshalBinaryTo implements Extension.MarshalBinaryTo by copying r.Data
// into dst. At most len(dst) bytes are copied.
func (r *RawExtension) MarshalBinaryTo(dst []byte) error {
	copy(dst, r.Data)
	return nil
}

// UnmarshalBinary implements Extension.UnmarshalBinary. It copies src
// into r.Data, reusing r.Data's backing array when it is large enough,
// so r never aliases src.
func (r *RawExtension) UnmarshalBinary(src []byte) error {
	n := len(src)
	if cap(r.Data) < n {
		r.Data = make([]byte, n)
	} else {
		r.Data = r.Data[:n]
	}
	copy(r.Data, src)
	return nil
}
122
123// WriteExtension writes an extension type to the writer
124func (mw *Writer) WriteExtension(e Extension) error {
125	l := e.Len()
126	var err error
127	switch l {
128	case 0:
129		o, err := mw.require(3)
130		if err != nil {
131			return err
132		}
133		mw.buf[o] = mext8
134		mw.buf[o+1] = 0
135		mw.buf[o+2] = byte(e.ExtensionType())
136	case 1:
137		o, err := mw.require(2)
138		if err != nil {
139			return err
140		}
141		mw.buf[o] = mfixext1
142		mw.buf[o+1] = byte(e.ExtensionType())
143	case 2:
144		o, err := mw.require(2)
145		if err != nil {
146			return err
147		}
148		mw.buf[o] = mfixext2
149		mw.buf[o+1] = byte(e.ExtensionType())
150	case 4:
151		o, err := mw.require(2)
152		if err != nil {
153			return err
154		}
155		mw.buf[o] = mfixext4
156		mw.buf[o+1] = byte(e.ExtensionType())
157	case 8:
158		o, err := mw.require(2)
159		if err != nil {
160			return err
161		}
162		mw.buf[o] = mfixext8
163		mw.buf[o+1] = byte(e.ExtensionType())
164	case 16:
165		o, err := mw.require(2)
166		if err != nil {
167			return err
168		}
169		mw.buf[o] = mfixext16
170		mw.buf[o+1] = byte(e.ExtensionType())
171	default:
172		switch {
173		case l < math.MaxUint8:
174			o, err := mw.require(3)
175			if err != nil {
176				return err
177			}
178			mw.buf[o] = mext8
179			mw.buf[o+1] = byte(uint8(l))
180			mw.buf[o+2] = byte(e.ExtensionType())
181		case l < math.MaxUint16:
182			o, err := mw.require(4)
183			if err != nil {
184				return err
185			}
186			mw.buf[o] = mext16
187			big.PutUint16(mw.buf[o+1:], uint16(l))
188			mw.buf[o+3] = byte(e.ExtensionType())
189		default:
190			o, err := mw.require(6)
191			if err != nil {
192				return err
193			}
194			mw.buf[o] = mext32
195			big.PutUint32(mw.buf[o+1:], uint32(l))
196			mw.buf[o+5] = byte(e.ExtensionType())
197		}
198	}
199	// we can only write directly to the
200	// buffer if we're sure that it
201	// fits the object
202	if l <= mw.bufsize() {
203		o, err := mw.require(l)
204		if err != nil {
205			return err
206		}
207		return e.MarshalBinaryTo(mw.buf[o:])
208	}
209	// here we create a new buffer
210	// just large enough for the body
211	// and save it as the write buffer
212	err = mw.flush()
213	if err != nil {
214		return err
215	}
216	buf := make([]byte, l)
217	err = e.MarshalBinaryTo(buf)
218	if err != nil {
219		return err
220	}
221	mw.buf = buf
222	mw.wloc = l
223	return nil
224}
225
226// peek at the extension type, assuming the next
227// kind to be read is Extension
228func (m *Reader) peekExtensionType() (int8, error) {
229	p, err := m.R.Peek(2)
230	if err != nil {
231		return 0, err
232	}
233	spec := sizes[p[0]]
234	if spec.typ != ExtensionType {
235		return 0, badPrefix(ExtensionType, p[0])
236	}
237	if spec.extra == constsize {
238		return int8(p[1]), nil
239	}
240	size := spec.size
241	p, err = m.R.Peek(int(size))
242	if err != nil {
243		return 0, err
244	}
245	return int8(p[size-1]), nil
246}
247
248// peekExtension peeks at the extension encoding type
249// (must guarantee at least 1 byte in 'b')
250func peekExtension(b []byte) (int8, error) {
251	spec := sizes[b[0]]
252	size := spec.size
253	if spec.typ != ExtensionType {
254		return 0, badPrefix(ExtensionType, b[0])
255	}
256	if len(b) < int(size) {
257		return 0, ErrShortBytes
258	}
259	// for fixed extensions,
260	// the type information is in
261	// the second byte
262	if spec.extra == constsize {
263		return int8(b[1]), nil
264	}
265	// otherwise, it's in the last
266	// part of the prefix
267	return int8(b[size-1]), nil
268}
269
270// ReadExtension reads the next object from the reader
271// as an extension. ReadExtension will fail if the next
272// object in the stream is not an extension, or if
273// e.Type() is not the same as the wire type.
274func (m *Reader) ReadExtension(e Extension) (err error) {
275	var p []byte
276	p, err = m.R.Peek(2)
277	if err != nil {
278		return
279	}
280	lead := p[0]
281	var read int
282	var off int
283	switch lead {
284	case mfixext1:
285		if int8(p[1]) != e.ExtensionType() {
286			err = errExt(int8(p[1]), e.ExtensionType())
287			return
288		}
289		p, err = m.R.Peek(3)
290		if err != nil {
291			return
292		}
293		err = e.UnmarshalBinary(p[2:])
294		if err == nil {
295			_, err = m.R.Skip(3)
296		}
297		return
298
299	case mfixext2:
300		if int8(p[1]) != e.ExtensionType() {
301			err = errExt(int8(p[1]), e.ExtensionType())
302			return
303		}
304		p, err = m.R.Peek(4)
305		if err != nil {
306			return
307		}
308		err = e.UnmarshalBinary(p[2:])
309		if err == nil {
310			_, err = m.R.Skip(4)
311		}
312		return
313
314	case mfixext4:
315		if int8(p[1]) != e.ExtensionType() {
316			err = errExt(int8(p[1]), e.ExtensionType())
317			return
318		}
319		p, err = m.R.Peek(6)
320		if err != nil {
321			return
322		}
323		err = e.UnmarshalBinary(p[2:])
324		if err == nil {
325			_, err = m.R.Skip(6)
326		}
327		return
328
329	case mfixext8:
330		if int8(p[1]) != e.ExtensionType() {
331			err = errExt(int8(p[1]), e.ExtensionType())
332			return
333		}
334		p, err = m.R.Peek(10)
335		if err != nil {
336			return
337		}
338		err = e.UnmarshalBinary(p[2:])
339		if err == nil {
340			_, err = m.R.Skip(10)
341		}
342		return
343
344	case mfixext16:
345		if int8(p[1]) != e.ExtensionType() {
346			err = errExt(int8(p[1]), e.ExtensionType())
347			return
348		}
349		p, err = m.R.Peek(18)
350		if err != nil {
351			return
352		}
353		err = e.UnmarshalBinary(p[2:])
354		if err == nil {
355			_, err = m.R.Skip(18)
356		}
357		return
358
359	case mext8:
360		p, err = m.R.Peek(3)
361		if err != nil {
362			return
363		}
364		if int8(p[2]) != e.ExtensionType() {
365			err = errExt(int8(p[2]), e.ExtensionType())
366			return
367		}
368		read = int(uint8(p[1]))
369		off = 3
370
371	case mext16:
372		p, err = m.R.Peek(4)
373		if err != nil {
374			return
375		}
376		if int8(p[3]) != e.ExtensionType() {
377			err = errExt(int8(p[3]), e.ExtensionType())
378			return
379		}
380		read = int(big.Uint16(p[1:]))
381		off = 4
382
383	case mext32:
384		p, err = m.R.Peek(6)
385		if err != nil {
386			return
387		}
388		if int8(p[5]) != e.ExtensionType() {
389			err = errExt(int8(p[5]), e.ExtensionType())
390			return
391		}
392		read = int(big.Uint32(p[1:]))
393		off = 6
394
395	default:
396		err = badPrefix(ExtensionType, lead)
397		return
398	}
399
400	p, err = m.R.Peek(read + off)
401	if err != nil {
402		return
403	}
404	err = e.UnmarshalBinary(p[off:])
405	if err == nil {
406		_, err = m.R.Skip(read + off)
407	}
408	return
409}
410
411// AppendExtension appends a MessagePack extension to the provided slice
412func AppendExtension(b []byte, e Extension) ([]byte, error) {
413	l := e.Len()
414	var o []byte
415	var n int
416	switch l {
417	case 0:
418		o, n = ensure(b, 3)
419		o[n] = mext8
420		o[n+1] = 0
421		o[n+2] = byte(e.ExtensionType())
422		return o[:n+3], nil
423	case 1:
424		o, n = ensure(b, 3)
425		o[n] = mfixext1
426		o[n+1] = byte(e.ExtensionType())
427		n += 2
428	case 2:
429		o, n = ensure(b, 4)
430		o[n] = mfixext2
431		o[n+1] = byte(e.ExtensionType())
432		n += 2
433	case 4:
434		o, n = ensure(b, 6)
435		o[n] = mfixext4
436		o[n+1] = byte(e.ExtensionType())
437		n += 2
438	case 8:
439		o, n = ensure(b, 10)
440		o[n] = mfixext8
441		o[n+1] = byte(e.ExtensionType())
442		n += 2
443	case 16:
444		o, n = ensure(b, 18)
445		o[n] = mfixext16
446		o[n+1] = byte(e.ExtensionType())
447		n += 2
448	}
449	switch {
450	case l < math.MaxUint8:
451		o, n = ensure(b, l+3)
452		o[n] = mext8
453		o[n+1] = byte(uint8(l))
454		o[n+2] = byte(e.ExtensionType())
455		n += 3
456	case l < math.MaxUint16:
457		o, n = ensure(b, l+4)
458		o[n] = mext16
459		big.PutUint16(o[n+1:], uint16(l))
460		o[n+3] = byte(e.ExtensionType())
461		n += 4
462	default:
463		o, n = ensure(b, l+6)
464		o[n] = mext32
465		big.PutUint32(o[n+1:], uint32(l))
466		o[n+5] = byte(e.ExtensionType())
467		n += 6
468	}
469	return o, e.MarshalBinaryTo(o[n:])
470}
471
472// ReadExtensionBytes reads an extension from 'b' into 'e'
473// and returns any remaining bytes.
474// Possible errors:
475// - ErrShortBytes ('b' not long enough)
476// - ExtensionTypeErorr{} (wire type not the same as e.Type())
477// - TypeErorr{} (next object not an extension)
478// - InvalidPrefixError
479// - An umarshal error returned from e.UnmarshalBinary
480func ReadExtensionBytes(b []byte, e Extension) ([]byte, error) {
481	l := len(b)
482	if l < 3 {
483		return b, ErrShortBytes
484	}
485	lead := b[0]
486	var (
487		sz  int // size of 'data'
488		off int // offset of 'data'
489		typ int8
490	)
491	switch lead {
492	case mfixext1:
493		typ = int8(b[1])
494		sz = 1
495		off = 2
496	case mfixext2:
497		typ = int8(b[1])
498		sz = 2
499		off = 2
500	case mfixext4:
501		typ = int8(b[1])
502		sz = 4
503		off = 2
504	case mfixext8:
505		typ = int8(b[1])
506		sz = 8
507		off = 2
508	case mfixext16:
509		typ = int8(b[1])
510		sz = 16
511		off = 2
512	case mext8:
513		sz = int(uint8(b[1]))
514		typ = int8(b[2])
515		off = 3
516		if sz == 0 {
517			return b[3:], e.UnmarshalBinary(b[3:3])
518		}
519	case mext16:
520		if l < 4 {
521			return b, ErrShortBytes
522		}
523		sz = int(big.Uint16(b[1:]))
524		typ = int8(b[3])
525		off = 4
526	case mext32:
527		if l < 6 {
528			return b, ErrShortBytes
529		}
530		sz = int(big.Uint32(b[1:]))
531		typ = int8(b[5])
532		off = 6
533	default:
534		return b, badPrefix(ExtensionType, lead)
535	}
536
537	if typ != e.ExtensionType() {
538		return b, errExt(typ, e.ExtensionType())
539	}
540
541	// the data of the extension starts
542	// at 'off' and is 'sz' bytes long
543	if len(b[off:]) < sz {
544		return b, ErrShortBytes
545	}
546	tot := off + sz
547	return b[tot:], e.UnmarshalBinary(b[off:tot])
548}
549