1// Copyright 2016 Keybase Inc. All rights reserved.
2// Use of this source code is governed by a BSD
3// license that can be found in the LICENSE file.
4
5package kbfscodec
6
7import (
8	"fmt"
9	"reflect"
10
11	"github.com/keybase/go-codec/codec"
12	"github.com/pkg/errors"
13)
14
// ext is a no-op extension that's useful for tagging interfaces with
// a type.  Its WriteExt and ReadExt methods simply hand the value to
// the wrapped codec, while ConvertExt and UpdateExt are unsupported
// and panic.  Note that it cannot be used for anything that has
// nested extensions.
type ext struct {
	// codec should NOT encode extension types.  It is the codec
	// that WriteExt and ReadExt delegate to.
	codec Codec
}
22
23// ConvertExt implements the codec.Ext interface for ext.
24func (e ext) ConvertExt(v interface{}) interface{} {
25	panic("ConvertExt not supported")
26}
27
28// UpdateExt implements the codec.Ext interface for ext.
29func (e ext) UpdateExt(dest interface{}, v interface{}) {
30	panic("UpdateExt not supported")
31}
32
33// WriteExt implements the codec.Ext interface for ext.
34func (e ext) WriteExt(v interface{}) (buf []byte) {
35	buf, err := e.codec.Encode(v)
36	if err != nil {
37		panic(fmt.Sprintf("Couldn't encode data in %v", v))
38	}
39	return buf
40}
41
42// ReadExt implements the codec.Ext interface for ext.
43func (e ext) ReadExt(v interface{}, buf []byte) {
44	err := e.codec.Decode(buf, v)
45	if err != nil {
46		panic(fmt.Sprintf("Couldn't decode data into %v", v))
47	}
48}
49
// extSlice is an extension that's useful for slices that contain
// extension types as elements.  The contained extension types cannot
// themselves contain nested extension types.
//
// WriteExt copies the slice's elements into a []interface{} and
// encodes that; ReadExt decodes a []interface{} and stores each
// element back into the destination slice.
type extSlice struct {
	// codec SHOULD encode extension types
	codec Codec
	// typer, if non-nil, is called by ReadExt on every decoded
	// element, and its result is what gets stored in the destination
	// slice.  If nil, the decoded element is stored as is.
	typer func(interface{}) reflect.Value
}
58
59// ConvertExt implements the codec.Ext interface for extSlice.
60func (es extSlice) ConvertExt(v interface{}) interface{} {
61	panic("ConvertExt not supported")
62}
63
64// UpdateExt implements the codec.Ext interface for extSlice.
65func (es extSlice) UpdateExt(dest interface{}, v interface{}) {
66	panic("UpdateExt not supported")
67}
68
69// WriteExt implements the codec.Ext interface for extSlice.
70func (es extSlice) WriteExt(v interface{}) (buf []byte) {
71	val := reflect.ValueOf(v)
72	if val.Kind() != reflect.Slice {
73		panic(fmt.Sprintf("Non-slice passed to extSlice.WriteExt %v",
74			val.Kind()))
75	}
76
77	ifaceArray := make([]interface{}, val.Len())
78	for i := 0; i < val.Len(); i++ {
79		ifaceArray[i] = val.Index(i).Interface()
80	}
81
82	buf, err := es.codec.Encode(ifaceArray)
83	if err != nil {
84		panic(fmt.Sprintf("Couldn't encode data in %v", v))
85	}
86	return buf
87}
88
89// ReadExt implements the codec.Ext interface for extSlice.
90func (es extSlice) ReadExt(v interface{}, buf []byte) {
91	// ReadExt actually receives a pointer to the list
92	val := reflect.ValueOf(v)
93	if val.Kind() != reflect.Ptr {
94		panic(fmt.Sprintf("Non-pointer passed to extSlice.ReadExt: %v",
95			val.Kind()))
96	}
97
98	val = val.Elem()
99	if val.Kind() != reflect.Slice {
100		panic(fmt.Sprintf("Non-slice passed to extSlice.ReadExt %v",
101			val.Kind()))
102	}
103
104	var ifaceArray []interface{}
105	err := es.codec.Decode(buf, &ifaceArray)
106	if err != nil {
107		panic(fmt.Sprintf("Couldn't decode data into %v", v))
108	}
109
110	if len(ifaceArray) > 0 {
111		val.Set(reflect.MakeSlice(val.Type(), len(ifaceArray),
112			len(ifaceArray)))
113	}
114
115	for i, v := range ifaceArray {
116		if es.typer != nil {
117			val.Index(i).Set(es.typer(v))
118		} else {
119			val.Index(i).Set(reflect.ValueOf(v))
120		}
121	}
122}
123
// CodecMsgpack implements the Codec interface using msgpack
// marshaling and unmarshaling.
//
// h is the go-codec handle that Encode and Decode build their
// encoders and decoders from.  ExtCodec is the codec that RegisterType
// hands to the extensions it registers; newCodecMsgpackHelper sets it
// to a second CodecMsgpack whose handle has WriteExt turned off, and
// leaves it nil on that inner codec.
type CodecMsgpack struct {
	h        codec.Handle
	ExtCodec *CodecMsgpack
}
130
131// newCodecMsgpackHelper constructs a new CodecMsgpack that may or may
132// not handle unknown fields.
133func newCodecMsgpackHelper(handleUnknownFields bool) *CodecMsgpack {
134	handle := codec.MsgpackHandle{}
135	handle.Canonical = true
136	handle.WriteExt = true
137	handle.DecodeUnknownFields = handleUnknownFields
138	handle.EncodeUnknownFields = handleUnknownFields
139
140	// save a codec that doesn't write extensions, so that we can just
141	// call Encode/Decode when we want to (de)serialize extension
142	// types.
143	handleNoExt := handle
144	handleNoExt.WriteExt = false
145	ExtCodec := &CodecMsgpack{&handleNoExt, nil}
146	return &CodecMsgpack{&handle, ExtCodec}
147}
148
149// NewMsgpack constructs a new CodecMsgpack.
150func NewMsgpack() *CodecMsgpack {
151	return newCodecMsgpackHelper(true)
152}
153
154// NewMsgpackNoUnknownFields constructs a new CodecMsgpack that
155// doesn't handle unknown fields.
156func NewMsgpackNoUnknownFields() *CodecMsgpack {
157	return newCodecMsgpackHelper(false)
158}
159
160// Decode implements the Codec interface for CodecMsgpack
161func (c *CodecMsgpack) Decode(buf []byte, obj interface{}) error {
162	err := codec.NewDecoderBytes(buf, c.h).Decode(obj)
163	if err != nil {
164		return errors.Wrap(err, "failed to decode")
165	}
166	return nil
167}
168
169// Encode implements the Codec interface for CodecMsgpack
170func (c *CodecMsgpack) Encode(obj interface{}) (buf []byte, err error) {
171	err = codec.NewEncoderBytes(&buf, c.h).Encode(obj)
172	if err != nil {
173		return nil, errors.Wrap(err, "failed to encode")
174	}
175	return buf, nil
176}
177
178// RegisterType implements the Codec interface for CodecMsgpack
179func (c *CodecMsgpack) RegisterType(rt reflect.Type, code ExtCode) {
180	err := c.h.(*codec.MsgpackHandle).SetBytesExt(
181		rt, uint64(code), ext{c.ExtCodec})
182	if err != nil {
183		panic(err)
184	}
185}
186
187// RegisterIfaceSliceType implements the Codec interface for CodecMsgpack
188func (c *CodecMsgpack) RegisterIfaceSliceType(
189	rt reflect.Type, code ExtCode, typer func(interface{}) reflect.Value) {
190	err := c.h.(*codec.MsgpackHandle).SetBytesExt(
191		rt, uint64(code), extSlice{c, typer})
192	if err != nil {
193		panic(err)
194	}
195}
196