1// Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved.
2// Use of this source code is governed by a BSD-style license found in the LICENSE file.
3
4/*
5MSGPACK
6
7Msgpack-c implementation powers the c, c++, python, ruby, etc libraries.
8We need to maintain compatibility with it and how it encodes integer values
9without caring about the type.
10
11For compatibility with behaviour of msgpack-c reference implementation:
  - Go intX (>=0) and uintX
13       IS ENCODED AS
14    msgpack +ve fixnum, unsigned
15  - Go intX (<0)
16       IS ENCODED AS
17    msgpack -ve fixnum, signed
18
19*/
20package codec
21
22import (
23	"fmt"
24	"io"
25	"math"
26	"net/rpc"
27)
28
// Msgpack descriptor bytes. A single name identifies one descriptor value;
// a Min/Max pair delimits a contiguous range of descriptors.
// Only mpPosFixNumMin carries an explicit byte type; the remaining
// constants are untyped.
const (
	mpPosFixNumMin byte = 0x00
	mpPosFixNumMax      = 0x7f
	mpFixMapMin         = 0x80
	mpFixMapMax         = 0x8f
	mpFixArrayMin       = 0x90
	mpFixArrayMax       = 0x9f
	mpFixStrMin         = 0xa0
	mpFixStrMax         = 0xbf
	mpNil               = 0xc0
	_                   = 0xc1 // deliberately left unnamed
	mpFalse             = 0xc2
	mpTrue              = 0xc3
	mpFloat             = 0xca
	mpDouble            = 0xcb
	mpUint8             = 0xcc
	mpUint16            = 0xcd
	mpUint32            = 0xce
	mpUint64            = 0xcf
	mpInt8              = 0xd0
	mpInt16             = 0xd1
	mpInt32             = 0xd2
	mpInt64             = 0xd3

	// extensions below (bin, ext and fixext descriptors)
	mpBin8     = 0xc4
	mpBin16    = 0xc5
	mpBin32    = 0xc6
	mpExt8     = 0xc7
	mpExt16    = 0xc8
	mpExt32    = 0xc9
	mpFixExt1  = 0xd4
	mpFixExt2  = 0xd5
	mpFixExt4  = 0xd6
	mpFixExt8  = 0xd7
	mpFixExt16 = 0xd8

	// string descriptors with an explicit length
	mpStr8  = 0xd9 // new
	mpStr16 = 0xda
	mpStr32 = 0xdb

	// array descriptors with an explicit length
	mpArray16 = 0xdc
	mpArray32 = 0xdd

	// map descriptors with an explicit length
	mpMap16 = 0xde
	mpMap32 = 0xdf

	// negative fixnum range
	mpNegFixNumMin = 0xe0
	mpNegFixNumMax = 0xff
)
79
// MsgpackSpecRpcMultiArgs is a special type which signals to the
// MsgpackSpecRpcCodec that the backend RPC service takes multiple arguments,
// which have been arranged in sequence in the slice.
//
// The codec then passes the slice AS-IS to the rpc service, i.e. without
// wrapping it in an array of 1 element.
type MsgpackSpecRpcMultiArgs []interface{}
87
// msgpackContainerType describes how the length of one family of msgpack
// containers is encoded: fixCutoff is the exclusive upper bound on lengths
// that use the fix form based at bFixMin; b8, b16 and b32 are the descriptors
// used when the length is written as 1, 2 or 4 bytes; hasFixMin and has8 say
// whether the fix and 8-bit forms exist; has8Always allows the 8-bit form
// even when the handle's WriteExt flag is off.
//
// Values of this type are built with positional literals, so the field
// order must not change.
type msgpackContainerType struct {
	fixCutoff                   int
	bFixMin, b8, b16, b32       byte
	hasFixMin, has8, has8Always bool
}
94
95var (
96	msgpackContainerStr  = msgpackContainerType{32, mpFixStrMin, mpStr8, mpStr16, mpStr32, true, true, false}
97	msgpackContainerBin  = msgpackContainerType{0, 0, mpBin8, mpBin16, mpBin32, false, true, true}
98	msgpackContainerList = msgpackContainerType{16, mpFixArrayMin, 0, mpArray16, mpArray32, true, false, false}
99	msgpackContainerMap  = msgpackContainerType{16, mpFixMapMin, 0, mpMap16, mpMap32, true, false, false}
100)
101
102//---------------------------------------------
103
// msgpackEncDriver writes values in msgpack format to w, using the options
// held in h.
type msgpackEncDriver struct {
	w encWriter
	h *MsgpackHandle
}
108
// isBuiltinType always reports false for this driver.
func (e *msgpackEncDriver) isBuiltinType(rt uintptr) bool {
	//no builtin types. All encodings are based on kinds. Types supported as extensions.
	return false
}
113
// encodeBuiltin is a no-op: isBuiltinType never reports a builtin type.
func (e *msgpackEncDriver) encodeBuiltin(rt uintptr, v interface{}) {}
115
// encodeNil writes the single nil descriptor byte.
func (e *msgpackEncDriver) encodeNil() {
	e.w.writen1(mpNil)
}
119
120func (e *msgpackEncDriver) encodeInt(i int64) {
121
122	switch {
123	case i >= 0:
124		e.encodeUint(uint64(i))
125	case i >= -32:
126		e.w.writen1(byte(i))
127	case i >= math.MinInt8:
128		e.w.writen2(mpInt8, byte(i))
129	case i >= math.MinInt16:
130		e.w.writen1(mpInt16)
131		e.w.writeUint16(uint16(i))
132	case i >= math.MinInt32:
133		e.w.writen1(mpInt32)
134		e.w.writeUint32(uint32(i))
135	default:
136		e.w.writen1(mpInt64)
137		e.w.writeUint64(uint64(i))
138	}
139}
140
141func (e *msgpackEncDriver) encodeUint(i uint64) {
142	switch {
143	case i <= math.MaxInt8:
144		e.w.writen1(byte(i))
145	case i <= math.MaxUint8:
146		e.w.writen2(mpUint8, byte(i))
147	case i <= math.MaxUint16:
148		e.w.writen1(mpUint16)
149		e.w.writeUint16(uint16(i))
150	case i <= math.MaxUint32:
151		e.w.writen1(mpUint32)
152		e.w.writeUint32(uint32(i))
153	default:
154		e.w.writen1(mpUint64)
155		e.w.writeUint64(uint64(i))
156	}
157}
158
159func (e *msgpackEncDriver) encodeBool(b bool) {
160	if b {
161		e.w.writen1(mpTrue)
162	} else {
163		e.w.writen1(mpFalse)
164	}
165}
166
167func (e *msgpackEncDriver) encodeFloat32(f float32) {
168	e.w.writen1(mpFloat)
169	e.w.writeUint32(math.Float32bits(f))
170}
171
172func (e *msgpackEncDriver) encodeFloat64(f float64) {
173	e.w.writen1(mpDouble)
174	e.w.writeUint64(math.Float64bits(f))
175}
176
177func (e *msgpackEncDriver) encodeExtPreamble(xtag byte, l int) {
178	switch {
179	case l == 1:
180		e.w.writen2(mpFixExt1, xtag)
181	case l == 2:
182		e.w.writen2(mpFixExt2, xtag)
183	case l == 4:
184		e.w.writen2(mpFixExt4, xtag)
185	case l == 8:
186		e.w.writen2(mpFixExt8, xtag)
187	case l == 16:
188		e.w.writen2(mpFixExt16, xtag)
189	case l < 256:
190		e.w.writen2(mpExt8, byte(l))
191		e.w.writen1(xtag)
192	case l < 65536:
193		e.w.writen1(mpExt16)
194		e.w.writeUint16(uint16(l))
195		e.w.writen1(xtag)
196	default:
197		e.w.writen1(mpExt32)
198		e.w.writeUint32(uint32(l))
199		e.w.writen1(xtag)
200	}
201}
202
// encodeArrayPreamble writes the array header for an array of the given length.
func (e *msgpackEncDriver) encodeArrayPreamble(length int) {
	e.writeContainerLen(msgpackContainerList, length)
}
206
// encodeMapPreamble writes the map header for a map of the given length.
func (e *msgpackEncDriver) encodeMapPreamble(length int) {
	e.writeContainerLen(msgpackContainerMap, length)
}
210
211func (e *msgpackEncDriver) encodeString(c charEncoding, s string) {
212	if c == c_RAW && e.h.WriteExt {
213		e.writeContainerLen(msgpackContainerBin, len(s))
214	} else {
215		e.writeContainerLen(msgpackContainerStr, len(s))
216	}
217	if len(s) > 0 {
218		e.w.writestr(s)
219	}
220}
221
// encodeSymbol writes v exactly as encodeString does for a c_UTF8 string.
func (e *msgpackEncDriver) encodeSymbol(v string) {
	e.encodeString(c_UTF8, v)
}
225
226func (e *msgpackEncDriver) encodeStringBytes(c charEncoding, bs []byte) {
227	if c == c_RAW && e.h.WriteExt {
228		e.writeContainerLen(msgpackContainerBin, len(bs))
229	} else {
230		e.writeContainerLen(msgpackContainerStr, len(bs))
231	}
232	if len(bs) > 0 {
233		e.w.writeb(bs)
234	}
235}
236
237func (e *msgpackEncDriver) writeContainerLen(ct msgpackContainerType, l int) {
238	switch {
239	case ct.hasFixMin && l < ct.fixCutoff:
240		e.w.writen1(ct.bFixMin | byte(l))
241	case ct.has8 && l < 256 && (ct.has8Always || e.h.WriteExt):
242		e.w.writen2(ct.b8, uint8(l))
243	case l < 65536:
244		e.w.writen1(ct.b16)
245		e.w.writeUint16(uint16(l))
246	default:
247		e.w.writen1(ct.b32)
248		e.w.writeUint32(uint32(l))
249	}
250}
251
252//---------------------------------------------
253
// msgpackDecDriver reads msgpack-encoded values from r, using the options
// held in h.
type msgpackDecDriver struct {
	r      decReader
	h      *MsgpackHandle
	bd     byte      // descriptor byte most recently read by initReadNext
	bdRead bool      // true while bd has been read but its value not yet consumed
	bdType valueType // cached result of currentEncodedType; reset by initReadNext
}
261
// isBuiltinType always reports false for this driver.
func (d *msgpackDecDriver) isBuiltinType(rt uintptr) bool {
	//no builtin types. All encodings are based on kinds. Types supported as extensions.
	return false
}
266
// decodeBuiltin is a no-op: isBuiltinType never reports a builtin type.
func (d *msgpackDecDriver) decodeBuiltin(rt uintptr, v interface{}) {}
268
// decodeNaked inspects the next descriptor in the stream and decides how the
// value should be decoded when the caller has no target type (a nil
// interface{} was passed).
//
// For scalars (nil, bool, float, int, uint) and extensions, the value is
// read here and returned in v, with decodeFurther=false. Floats are always
// returned as float64, unsigned integers as uint64 and signed integers
// (including both fixnum ranges) as int64. Extensions are returned as *RawExt.
//
// For strings, bytes, arrays and maps, only vt is determined (and, for
// strings/bytes, v is set to a pointer to an empty value of the right type);
// decodeFurther=true tells the caller to decode the contents itself, so the
// descriptor is left marked as read.
func (d *msgpackDecDriver) decodeNaked() (v interface{}, vt valueType, decodeFurther bool) {
	d.initReadNext()
	bd := d.bd

	switch bd {
	case mpNil:
		vt = valueTypeNil
		d.bdRead = false
	case mpFalse:
		vt = valueTypeBool
		v = false
	case mpTrue:
		vt = valueTypeBool
		v = true

	case mpFloat:
		vt = valueTypeFloat
		v = float64(math.Float32frombits(d.r.readUint32()))
	case mpDouble:
		vt = valueTypeFloat
		v = math.Float64frombits(d.r.readUint64())

	case mpUint8:
		vt = valueTypeUint
		v = uint64(d.r.readn1())
	case mpUint16:
		vt = valueTypeUint
		v = uint64(d.r.readUint16())
	case mpUint32:
		vt = valueTypeUint
		v = uint64(d.r.readUint32())
	case mpUint64:
		vt = valueTypeUint
		v = uint64(d.r.readUint64())

	case mpInt8:
		vt = valueTypeInt
		v = int64(int8(d.r.readn1()))
	case mpInt16:
		vt = valueTypeInt
		v = int64(int16(d.r.readUint16()))
	case mpInt32:
		vt = valueTypeInt
		v = int64(int32(d.r.readUint32()))
	case mpInt64:
		vt = valueTypeInt
		v = int64(int64(d.r.readUint64()))

	default:
		// Remaining descriptors are ranges or groups, matched by condition.
		switch {
		case bd >= mpPosFixNumMin && bd <= mpPosFixNumMax:
			// positive fixnum (always signed)
			vt = valueTypeInt
			v = int64(int8(bd))
		case bd >= mpNegFixNumMin && bd <= mpNegFixNumMax:
			// negative fixnum
			vt = valueTypeInt
			v = int64(int8(bd))
		case bd == mpStr8, bd == mpStr16, bd == mpStr32, bd >= mpFixStrMin && bd <= mpFixStrMax:
			// str family: RawToString selects string vs []byte as the target.
			if d.h.RawToString {
				var rvm string
				vt = valueTypeString
				v = &rvm
			} else {
				var rvm = []byte{}
				vt = valueTypeBytes
				v = &rvm
			}
			decodeFurther = true
		case bd == mpBin8, bd == mpBin16, bd == mpBin32:
			// bin family is always decoded as []byte.
			var rvm = []byte{}
			vt = valueTypeBytes
			v = &rvm
			decodeFurther = true
		case bd == mpArray16, bd == mpArray32, bd >= mpFixArrayMin && bd <= mpFixArrayMax:
			vt = valueTypeArray
			decodeFurther = true
		case bd == mpMap16, bd == mpMap32, bd >= mpFixMapMin && bd <= mpFixMapMax:
			vt = valueTypeMap
			decodeFurther = true
		case bd >= mpFixExt1 && bd <= mpFixExt16, bd >= mpExt8 && bd <= mpExt32:
			// Extension: length first, then the tag byte, then the payload.
			clen := d.readExtLen()
			var re RawExt
			re.Tag = d.r.readn1()
			re.Data = d.r.readn(clen)
			v = &re
			vt = valueTypeExt
		default:
			decErr("Nil-Deciphered DecodeValue: %s: hex: %x, dec: %d", msgBadDesc, bd, bd)
		}
	}
	// The value was fully consumed here, so the next decode must read a
	// fresh descriptor.
	if !decodeFurther {
		d.bdRead = false
	}
	return
}
370
371// int can be decoded from msgpack type: intXXX or uintXXX
372func (d *msgpackDecDriver) decodeInt(bitsize uint8) (i int64) {
373	switch d.bd {
374	case mpUint8:
375		i = int64(uint64(d.r.readn1()))
376	case mpUint16:
377		i = int64(uint64(d.r.readUint16()))
378	case mpUint32:
379		i = int64(uint64(d.r.readUint32()))
380	case mpUint64:
381		i = int64(d.r.readUint64())
382	case mpInt8:
383		i = int64(int8(d.r.readn1()))
384	case mpInt16:
385		i = int64(int16(d.r.readUint16()))
386	case mpInt32:
387		i = int64(int32(d.r.readUint32()))
388	case mpInt64:
389		i = int64(d.r.readUint64())
390	default:
391		switch {
392		case d.bd >= mpPosFixNumMin && d.bd <= mpPosFixNumMax:
393			i = int64(int8(d.bd))
394		case d.bd >= mpNegFixNumMin && d.bd <= mpNegFixNumMax:
395			i = int64(int8(d.bd))
396		default:
397			decErr("Unhandled single-byte unsigned integer value: %s: %x", msgBadDesc, d.bd)
398		}
399	}
400	// check overflow (logic adapted from std pkg reflect/value.go OverflowUint()
401	if bitsize > 0 {
402		if trunc := (i << (64 - bitsize)) >> (64 - bitsize); i != trunc {
403			decErr("Overflow int value: %v", i)
404		}
405	}
406	d.bdRead = false
407	return
408}
409
410// uint can be decoded from msgpack type: intXXX or uintXXX
411func (d *msgpackDecDriver) decodeUint(bitsize uint8) (ui uint64) {
412	switch d.bd {
413	case mpUint8:
414		ui = uint64(d.r.readn1())
415	case mpUint16:
416		ui = uint64(d.r.readUint16())
417	case mpUint32:
418		ui = uint64(d.r.readUint32())
419	case mpUint64:
420		ui = d.r.readUint64()
421	case mpInt8:
422		if i := int64(int8(d.r.readn1())); i >= 0 {
423			ui = uint64(i)
424		} else {
425			decErr("Assigning negative signed value: %v, to unsigned type", i)
426		}
427	case mpInt16:
428		if i := int64(int16(d.r.readUint16())); i >= 0 {
429			ui = uint64(i)
430		} else {
431			decErr("Assigning negative signed value: %v, to unsigned type", i)
432		}
433	case mpInt32:
434		if i := int64(int32(d.r.readUint32())); i >= 0 {
435			ui = uint64(i)
436		} else {
437			decErr("Assigning negative signed value: %v, to unsigned type", i)
438		}
439	case mpInt64:
440		if i := int64(d.r.readUint64()); i >= 0 {
441			ui = uint64(i)
442		} else {
443			decErr("Assigning negative signed value: %v, to unsigned type", i)
444		}
445	default:
446		switch {
447		case d.bd >= mpPosFixNumMin && d.bd <= mpPosFixNumMax:
448			ui = uint64(d.bd)
449		case d.bd >= mpNegFixNumMin && d.bd <= mpNegFixNumMax:
450			decErr("Assigning negative signed value: %v, to unsigned type", int(d.bd))
451		default:
452			decErr("Unhandled single-byte unsigned integer value: %s: %x", msgBadDesc, d.bd)
453		}
454	}
455	// check overflow (logic adapted from std pkg reflect/value.go OverflowUint()
456	if bitsize > 0 {
457		if trunc := (ui << (64 - bitsize)) >> (64 - bitsize); ui != trunc {
458			decErr("Overflow uint value: %v", ui)
459		}
460	}
461	d.bdRead = false
462	return
463}
464
465// float can either be decoded from msgpack type: float, double or intX
466func (d *msgpackDecDriver) decodeFloat(chkOverflow32 bool) (f float64) {
467	switch d.bd {
468	case mpFloat:
469		f = float64(math.Float32frombits(d.r.readUint32()))
470	case mpDouble:
471		f = math.Float64frombits(d.r.readUint64())
472	default:
473		f = float64(d.decodeInt(0))
474	}
475	checkOverflowFloat32(f, chkOverflow32)
476	d.bdRead = false
477	return
478}
479
480// bool can be decoded from bool, fixnum 0 or 1.
481func (d *msgpackDecDriver) decodeBool() (b bool) {
482	switch d.bd {
483	case mpFalse, 0:
484		// b = false
485	case mpTrue, 1:
486		b = true
487	default:
488		decErr("Invalid single-byte value for bool: %s: %x", msgBadDesc, d.bd)
489	}
490	d.bdRead = false
491	return
492}
493
494func (d *msgpackDecDriver) decodeString() (s string) {
495	clen := d.readContainerLen(msgpackContainerStr)
496	if clen > 0 {
497		s = string(d.r.readn(clen))
498	}
499	d.bdRead = false
500	return
501}
502
503// Callers must check if changed=true (to decide whether to replace the one they have)
504func (d *msgpackDecDriver) decodeBytes(bs []byte) (bsOut []byte, changed bool) {
505	// bytes can be decoded from msgpackContainerStr or msgpackContainerBin
506	var clen int
507	switch d.bd {
508	case mpBin8, mpBin16, mpBin32:
509		clen = d.readContainerLen(msgpackContainerBin)
510	default:
511		clen = d.readContainerLen(msgpackContainerStr)
512	}
513	// if clen < 0 {
514	// 	changed = true
515	// 	panic("length cannot be zero. this cannot be nil.")
516	// }
517	if clen > 0 {
518		// if no contents in stream, don't update the passed byteslice
519		if len(bs) != clen {
520			// Return changed=true if length of passed slice diff from length of bytes in stream
521			if len(bs) > clen {
522				bs = bs[:clen]
523			} else {
524				bs = make([]byte, clen)
525			}
526			bsOut = bs
527			changed = true
528		}
529		d.r.readb(bs)
530	}
531	d.bdRead = false
532	return
533}
534
535// Every top-level decode funcs (i.e. decodeValue, decode) must call this first.
536func (d *msgpackDecDriver) initReadNext() {
537	if d.bdRead {
538		return
539	}
540	d.bd = d.r.readn1()
541	d.bdRead = true
542	d.bdType = valueTypeUnset
543}
544
545func (d *msgpackDecDriver) currentEncodedType() valueType {
546	if d.bdType == valueTypeUnset {
547		bd := d.bd
548		switch bd {
549		case mpNil:
550			d.bdType = valueTypeNil
551		case mpFalse, mpTrue:
552			d.bdType = valueTypeBool
553		case mpFloat, mpDouble:
554			d.bdType = valueTypeFloat
555		case mpUint8, mpUint16, mpUint32, mpUint64:
556			d.bdType = valueTypeUint
557		case mpInt8, mpInt16, mpInt32, mpInt64:
558			d.bdType = valueTypeInt
559		default:
560			switch {
561			case bd >= mpPosFixNumMin && bd <= mpPosFixNumMax:
562				d.bdType = valueTypeInt
563			case bd >= mpNegFixNumMin && bd <= mpNegFixNumMax:
564				d.bdType = valueTypeInt
565			case bd == mpStr8, bd == mpStr16, bd == mpStr32, bd >= mpFixStrMin && bd <= mpFixStrMax:
566				if d.h.RawToString {
567					d.bdType = valueTypeString
568				} else {
569					d.bdType = valueTypeBytes
570				}
571			case bd == mpBin8, bd == mpBin16, bd == mpBin32:
572				d.bdType = valueTypeBytes
573			case bd == mpArray16, bd == mpArray32, bd >= mpFixArrayMin && bd <= mpFixArrayMax:
574				d.bdType = valueTypeArray
575			case bd == mpMap16, bd == mpMap32, bd >= mpFixMapMin && bd <= mpFixMapMax:
576				d.bdType = valueTypeMap
577			case bd >= mpFixExt1 && bd <= mpFixExt16, bd >= mpExt8 && bd <= mpExt32:
578				d.bdType = valueTypeExt
579			default:
580				decErr("currentEncodedType: Undeciphered descriptor: %s: hex: %x, dec: %d", msgBadDesc, bd, bd)
581			}
582		}
583	}
584	return d.bdType
585}
586
587func (d *msgpackDecDriver) tryDecodeAsNil() bool {
588	if d.bd == mpNil {
589		d.bdRead = false
590		return true
591	}
592	return false
593}
594
595func (d *msgpackDecDriver) readContainerLen(ct msgpackContainerType) (clen int) {
596	bd := d.bd
597	switch {
598	case bd == mpNil:
599		clen = -1 // to represent nil
600	case bd == ct.b8:
601		clen = int(d.r.readn1())
602	case bd == ct.b16:
603		clen = int(d.r.readUint16())
604	case bd == ct.b32:
605		clen = int(d.r.readUint32())
606	case (ct.bFixMin & bd) == ct.bFixMin:
607		clen = int(ct.bFixMin ^ bd)
608	default:
609		decErr("readContainerLen: %s: hex: %x, dec: %d", msgBadDesc, bd, bd)
610	}
611	d.bdRead = false
612	return
613}
614
// readMapLen returns the length read from a map header.
func (d *msgpackDecDriver) readMapLen() int {
	return d.readContainerLen(msgpackContainerMap)
}
618
// readArrayLen returns the length read from an array header.
func (d *msgpackDecDriver) readArrayLen() int {
	return d.readContainerLen(msgpackContainerList)
}
622
623func (d *msgpackDecDriver) readExtLen() (clen int) {
624	switch d.bd {
625	case mpNil:
626		clen = -1 // to represent nil
627	case mpFixExt1:
628		clen = 1
629	case mpFixExt2:
630		clen = 2
631	case mpFixExt4:
632		clen = 4
633	case mpFixExt8:
634		clen = 8
635	case mpFixExt16:
636		clen = 16
637	case mpExt8:
638		clen = int(d.r.readn1())
639	case mpExt16:
640		clen = int(d.r.readUint16())
641	case mpExt32:
642		clen = int(d.r.readUint32())
643	default:
644		decErr("decoding ext bytes: found unexpected byte: %x", d.bd)
645	}
646	return
647}
648
649func (d *msgpackDecDriver) decodeExt(verifyTag bool, tag byte) (xtag byte, xbs []byte) {
650	xbd := d.bd
651	switch {
652	case xbd == mpBin8, xbd == mpBin16, xbd == mpBin32:
653		xbs, _ = d.decodeBytes(nil)
654	case xbd == mpStr8, xbd == mpStr16, xbd == mpStr32,
655		xbd >= mpFixStrMin && xbd <= mpFixStrMax:
656		xbs = []byte(d.decodeString())
657	default:
658		clen := d.readExtLen()
659		xtag = d.r.readn1()
660		if verifyTag && xtag != tag {
661			decErr("Wrong extension tag. Got %b. Expecting: %v", xtag, tag)
662		}
663		xbs = d.r.readn(clen)
664	}
665	d.bdRead = false
666	return
667}
668
669//--------------------------------------------------
670
// MsgpackHandle is a Handle for the Msgpack Schema-Free Encoding Format.
type MsgpackHandle struct {
	BasicHandle

	// RawToString controls how raw bytes are decoded into a nil interface{}:
	// as a string when true, as a []byte when false.
	RawToString bool
	// WriteExt flag supports encoding configured extensions with extension tags.
	// It also controls whether other elements of the new spec are encoded (ie Str8).
	//
	// With WriteExt=false, configured extensions are serialized as raw bytes
	// and Str8 is not encoded.
	//
	// A stream can still be decoded into a typed value, provided an appropriate value
	// is provided, but the type cannot be inferred from the stream. If no appropriate
	// type is provided (e.g. decoding into a nil interface{}), you get back
	// a []byte or string based on the setting of RawToString.
	WriteExt bool
}
689
// newEncDriver returns a msgpack encode driver that writes to w and is
// configured by h.
func (h *MsgpackHandle) newEncDriver(w encWriter) encDriver {
	return &msgpackEncDriver{w: w, h: h}
}
693
// newDecDriver returns a msgpack decode driver that reads from r and is
// configured by h.
func (h *MsgpackHandle) newDecDriver(r decReader) decDriver {
	return &msgpackDecDriver{r: r, h: h}
}
697
// writeExt reports the value of the WriteExt flag.
func (h *MsgpackHandle) writeExt() bool {
	return h.WriteExt
}
701
// getBasicHandle returns a pointer to the embedded BasicHandle.
func (h *MsgpackHandle) getBasicHandle() *BasicHandle {
	return &h.BasicHandle
}
705
706//--------------------------------------------------
707
// msgpackSpecRpcCodec layers the request/response framing used by the
// methods below (four-element arrays) on top of the embedded rpcCodec.
type msgpackSpecRpcCodec struct {
	rpcCodec
}
711
712// /////////////// Spec RPC Codec ///////////////////
713func (c *msgpackSpecRpcCodec) WriteRequest(r *rpc.Request, body interface{}) error {
714	// WriteRequest can write to both a Go service, and other services that do
715	// not abide by the 1 argument rule of a Go service.
716	// We discriminate based on if the body is a MsgpackSpecRpcMultiArgs
717	var bodyArr []interface{}
718	if m, ok := body.(MsgpackSpecRpcMultiArgs); ok {
719		bodyArr = ([]interface{})(m)
720	} else {
721		bodyArr = []interface{}{body}
722	}
723	r2 := []interface{}{0, uint32(r.Seq), r.ServiceMethod, bodyArr}
724	return c.write(r2, nil, false, true)
725}
726
727func (c *msgpackSpecRpcCodec) WriteResponse(r *rpc.Response, body interface{}) error {
728	var moe interface{}
729	if r.Error != "" {
730		moe = r.Error
731	}
732	if moe != nil && body != nil {
733		body = nil
734	}
735	r2 := []interface{}{1, uint32(r.Seq), moe, body}
736	return c.write(r2, nil, false, true)
737}
738
// ReadResponseHeader reads a response header (message type 1), filling in
// r.Seq and r.Error.
func (c *msgpackSpecRpcCodec) ReadResponseHeader(r *rpc.Response) error {
	return c.parseCustomHeader(1, &r.Seq, &r.Error)
}
742
// ReadRequestHeader reads a request header (message type 0), filling in
// r.Seq and r.ServiceMethod.
func (c *msgpackSpecRpcCodec) ReadRequestHeader(r *rpc.Request) error {
	return c.parseCustomHeader(0, &r.Seq, &r.ServiceMethod)
}
746
747func (c *msgpackSpecRpcCodec) ReadRequestBody(body interface{}) error {
748	if body == nil { // read and discard
749		return c.read(nil)
750	}
751	bodyArr := []interface{}{body}
752	return c.read(&bodyArr)
753}
754
755func (c *msgpackSpecRpcCodec) parseCustomHeader(expectTypeByte byte, msgid *uint64, methodOrError *string) (err error) {
756
757	if c.cls {
758		return io.EOF
759	}
760
761	// We read the response header by hand
762	// so that the body can be decoded on its own from the stream at a later time.
763
764	const fia byte = 0x94 //four item array descriptor value
765	// Not sure why the panic of EOF is swallowed above.
766	// if bs1 := c.dec.r.readn1(); bs1 != fia {
767	// 	err = fmt.Errorf("Unexpected value for array descriptor: Expecting %v. Received %v", fia, bs1)
768	// 	return
769	// }
770	var b byte
771	b, err = c.br.ReadByte()
772	if err != nil {
773		return
774	}
775	if b != fia {
776		err = fmt.Errorf("Unexpected value for array descriptor: Expecting %v. Received %v", fia, b)
777		return
778	}
779
780	if err = c.read(&b); err != nil {
781		return
782	}
783	if b != expectTypeByte {
784		err = fmt.Errorf("Unexpected byte descriptor in header. Expecting %v. Received %v", expectTypeByte, b)
785		return
786	}
787	if err = c.read(msgid); err != nil {
788		return
789	}
790	if err = c.read(methodOrError); err != nil {
791		return
792	}
793	return
794}
795
796//--------------------------------------------------
797
// msgpackSpecRpc is the implementation of Rpc that uses the custom
// communication protocol defined in the msgpack spec at
// https://github.com/msgpack-rpc/msgpack-rpc/blob/master/spec.md
type msgpackSpecRpc struct{}

// MsgpackSpecRpc implements Rpc using the communication protocol defined in
// the msgpack spec at https://github.com/msgpack-rpc/msgpack-rpc/blob/master/spec.md .
// Its methods (ServerCodec and ClientCodec) return values that implement RpcCodecBuffered.
var MsgpackSpecRpc msgpackSpecRpc
806
// ServerCodec returns an rpc.ServerCodec built from newRPCCodec(conn, h).
func (x msgpackSpecRpc) ServerCodec(conn io.ReadWriteCloser, h Handle) rpc.ServerCodec {
	return &msgpackSpecRpcCodec{newRPCCodec(conn, h)}
}
810
// ClientCodec returns an rpc.ClientCodec built from newRPCCodec(conn, h).
func (x msgpackSpecRpc) ClientCodec(conn io.ReadWriteCloser, h Handle) rpc.ClientCodec {
	return &msgpackSpecRpcCodec{newRPCCodec(conn, h)}
}
814
815var _ decDriver = (*msgpackDecDriver)(nil)
816var _ encDriver = (*msgpackEncDriver)(nil)
817