1package msgpack 2 3import ( 4 "bytes" 5 "fmt" 6 "reflect" 7 "sync" 8 9 "github.com/vmihailenco/msgpack/v4/codes" 10) 11 12type extInfo struct { 13 Type reflect.Type 14 Decoder decoderFunc 15} 16 17var extTypes = make(map[int8]*extInfo) 18 19var bufferPool = &sync.Pool{ 20 New: func() interface{} { 21 return new(bytes.Buffer) 22 }, 23} 24 25// RegisterExt records a type, identified by a value for that type, 26// under the provided id. That id will identify the concrete type of a value 27// sent or received as an interface variable. Only types that will be 28// transferred as implementations of interface values need to be registered. 29// Expecting to be used only during initialization, it panics if the mapping 30// between types and ids is not a bijection. 31func RegisterExt(id int8, value interface{}) { 32 typ := reflect.TypeOf(value) 33 if typ.Kind() == reflect.Ptr { 34 typ = typ.Elem() 35 } 36 ptr := reflect.PtrTo(typ) 37 38 if _, ok := extTypes[id]; ok { 39 panic(fmt.Errorf("msgpack: ext with id=%d is already registered", id)) 40 } 41 42 registerExt(id, ptr, getEncoder(ptr), getDecoder(ptr)) 43 registerExt(id, typ, getEncoder(typ), getDecoder(typ)) 44} 45 46func registerExt(id int8, typ reflect.Type, enc encoderFunc, dec decoderFunc) { 47 if enc != nil { 48 typeEncMap.Store(typ, makeExtEncoder(id, enc)) 49 } 50 if dec != nil { 51 extTypes[id] = &extInfo{ 52 Type: typ, 53 Decoder: dec, 54 } 55 typeDecMap.Store(typ, makeExtDecoder(id, dec)) 56 } 57} 58 59func (e *Encoder) EncodeExtHeader(typeID int8, length int) error { 60 if err := e.encodeExtLen(length); err != nil { 61 return err 62 } 63 if err := e.w.WriteByte(byte(typeID)); err != nil { 64 return err 65 } 66 return nil 67} 68 69func makeExtEncoder(typeID int8, enc encoderFunc) encoderFunc { 70 return func(e *Encoder, v reflect.Value) error { 71 buf := bufferPool.Get().(*bytes.Buffer) 72 defer bufferPool.Put(buf) 73 buf.Reset() 74 75 oldw := e.w 76 e.w = buf 77 err := enc(e, v) 78 e.w = oldw 79 80 if err != nil { 81 return err 82 } 83 84 err = e.EncodeExtHeader(typeID, buf.Len()) 85 if err != nil { 86 return err 87 } 88 return e.write(buf.Bytes()) 89 } 90} 91 92func makeExtDecoder(typeID int8, dec decoderFunc) decoderFunc { 93 return func(d *Decoder, v reflect.Value) error { 94 c, err := d.PeekCode() 95 if err != nil { 96 return err 97 } 98 99 if !codes.IsExt(c) { 100 return dec(d, v) 101 } 102 103 id, extLen, err := d.DecodeExtHeader() 104 if err != nil { 105 return err 106 } 107 108 if id != typeID { 109 return fmt.Errorf("msgpack: got ext type=%d, wanted %d", id, typeID) 110 } 111 112 d.extLen = extLen 113 return dec(d, v) 114 } 115} 116 117func (e *Encoder) encodeExtLen(l int) error { 118 switch l { 119 case 1: 120 return e.writeCode(codes.FixExt1) 121 case 2: 122 return e.writeCode(codes.FixExt2) 123 case 4: 124 return e.writeCode(codes.FixExt4) 125 case 8: 126 return e.writeCode(codes.FixExt8) 127 case 16: 128 return e.writeCode(codes.FixExt16) 129 } 130 if l < 256 { 131 return e.write1(codes.Ext8, uint8(l)) 132 } 133 if l < 65536 { 134 return e.write2(codes.Ext16, uint16(l)) 135 } 136 return e.write4(codes.Ext32, uint32(l)) 137} 138 139func (d *Decoder) parseExtLen(c codes.Code) (int, error) { 140 switch c { 141 case codes.FixExt1: 142 return 1, nil 143 case codes.FixExt2: 144 return 2, nil 145 case codes.FixExt4: 146 return 4, nil 147 case codes.FixExt8: 148 return 8, nil 149 case codes.FixExt16: 150 return 16, nil 151 case codes.Ext8: 152 n, err := d.uint8() 153 return int(n), err 154 case codes.Ext16: 155 n, err := d.uint16() 156 return int(n), err 157 case codes.Ext32: 158 n, err := d.uint32() 159 return int(n), err 160 default: 161 return 0, fmt.Errorf("msgpack: invalid code=%x decoding ext length", c) 162 } 163} 164 165func (d *Decoder) extHeader(c codes.Code) (int8, int, error) { 166 length, err := d.parseExtLen(c) 167 if err != nil { 168 return 0, 0, err 169 } 170 171 typeID, err := d.readCode() 172 if err != nil { 173 return 0, 0, err 174 } 175 176 return int8(typeID), length, nil 177} 178 179func (d *Decoder) DecodeExtHeader() (typeID int8, length int, err error) { 180 c, err := d.readCode() 181 if err != nil { 182 return 183 } 184 return d.extHeader(c) 185} 186 187func (d *Decoder) extInterface(c codes.Code) (interface{}, error) { 188 extID, extLen, err := d.extHeader(c) 189 if err != nil { 190 return nil, err 191 } 192 193 info, ok := extTypes[extID] 194 if !ok { 195 return nil, fmt.Errorf("msgpack: unknown ext id=%d", extID) 196 } 197 198 v := reflect.New(info.Type) 199 200 d.extLen = extLen 201 err = info.Decoder(d, v.Elem()) 202 d.extLen = 0 203 if err != nil { 204 return nil, err 205 } 206 207 return v.Interface(), nil 208} 209 210func (d *Decoder) skipExt(c codes.Code) error { 211 n, err := d.parseExtLen(c) 212 if err != nil { 213 return err 214 } 215 return d.skipN(n + 1) 216} 217 218func (d *Decoder) skipExtHeader(c codes.Code) error { 219 // Read ext type. 220 _, err := d.readCode() 221 if err != nil { 222 return err 223 } 224 // Read ext body len. 225 for i := 0; i < extHeaderLen(c); i++ { 226 _, err := d.readCode() 227 if err != nil { 228 return err 229 } 230 } 231 return nil 232} 233 234func extHeaderLen(c codes.Code) int { 235 switch c { 236 case codes.Ext8: 237 return 1 238 case codes.Ext16: 239 return 2 240 case codes.Ext32: 241 return 4 242 } 243 return 0 244} 245