package msgp

import (
	"fmt"
	"math"
)

const (
	// Complex64Extension is the extension number used for complex64
	Complex64Extension = 3

	// Complex128Extension is the extension number used for complex128
	Complex128Extension = 4

	// TimeExtension is the extension number used for time.Time
	TimeExtension = 5
)

// extensionReg maps an extension type number to the constructor
// that was registered for it through RegisterExtension.
var extensionReg = make(map[int8]func() Extension)

// RegisterExtension registers extensions so that they
// can be initialized and returned by methods that
// decode `interface{}` values. This should only
// be called during initialization. f() should return
// a newly-initialized zero value of the extension. Keep in
// mind that extensions 3, 4, and 5 are reserved for
// complex64, complex128, and time.Time, respectively,
// and that MessagePack reserves the negative extension types.
//
// For example, if you wanted to register a user-defined struct:
//
//	msgp.RegisterExtension(10, func() msgp.Extension { return &MyExtension{} })
//
// RegisterExtension will panic if you call it multiple times
// with the same 'typ' argument, or if you use a reserved
// type (3, 4, or 5).
38func RegisterExtension(typ int8, f func() Extension) { 39 switch typ { 40 case Complex64Extension, Complex128Extension, TimeExtension: 41 panic(fmt.Sprint("msgp: forbidden extension type:", typ)) 42 } 43 if _, ok := extensionReg[typ]; ok { 44 panic(fmt.Sprint("msgp: RegisterExtension() called with typ", typ, "more than once")) 45 } 46 extensionReg[typ] = f 47} 48 49// ExtensionTypeError is an error type returned 50// when there is a mis-match between an extension type 51// and the type encoded on the wire 52type ExtensionTypeError struct { 53 Got int8 54 Want int8 55} 56 57// Error implements the error interface 58func (e ExtensionTypeError) Error() string { 59 return fmt.Sprintf("msgp: error decoding extension: wanted type %d; got type %d", e.Want, e.Got) 60} 61 62// Resumable returns 'true' for ExtensionTypeErrors 63func (e ExtensionTypeError) Resumable() bool { return true } 64 65func errExt(got int8, wanted int8) error { 66 return ExtensionTypeError{Got: got, Want: wanted} 67} 68 69// Extension is the interface fulfilled 70// by types that want to define their 71// own binary encoding. 72type Extension interface { 73 // ExtensionType should return 74 // a int8 that identifies the concrete 75 // type of the extension. (Types <0 are 76 // officially reserved by the MessagePack 77 // specifications.) 
78 ExtensionType() int8 79 80 // Len should return the length 81 // of the data to be encoded 82 Len() int 83 84 // MarshalBinaryTo should copy 85 // the data into the supplied slice, 86 // assuming that the slice has length Len() 87 MarshalBinaryTo([]byte) error 88 89 UnmarshalBinary([]byte) error 90} 91 92// RawExtension implements the Extension interface 93type RawExtension struct { 94 Data []byte 95 Type int8 96} 97 98// ExtensionType implements Extension.ExtensionType, and returns r.Type 99func (r *RawExtension) ExtensionType() int8 { return r.Type } 100 101// Len implements Extension.Len, and returns len(r.Data) 102func (r *RawExtension) Len() int { return len(r.Data) } 103 104// MarshalBinaryTo implements Extension.MarshalBinaryTo, 105// and returns a copy of r.Data 106func (r *RawExtension) MarshalBinaryTo(d []byte) error { 107 copy(d, r.Data) 108 return nil 109} 110 111// UnmarshalBinary implements Extension.UnmarshalBinary, 112// and sets r.Data to the contents of the provided slice 113func (r *RawExtension) UnmarshalBinary(b []byte) error { 114 if cap(r.Data) >= len(b) { 115 r.Data = r.Data[0:len(b)] 116 } else { 117 r.Data = make([]byte, len(b)) 118 } 119 copy(r.Data, b) 120 return nil 121} 122 123// WriteExtension writes an extension type to the writer 124func (mw *Writer) WriteExtension(e Extension) error { 125 l := e.Len() 126 var err error 127 switch l { 128 case 0: 129 o, err := mw.require(3) 130 if err != nil { 131 return err 132 } 133 mw.buf[o] = mext8 134 mw.buf[o+1] = 0 135 mw.buf[o+2] = byte(e.ExtensionType()) 136 case 1: 137 o, err := mw.require(2) 138 if err != nil { 139 return err 140 } 141 mw.buf[o] = mfixext1 142 mw.buf[o+1] = byte(e.ExtensionType()) 143 case 2: 144 o, err := mw.require(2) 145 if err != nil { 146 return err 147 } 148 mw.buf[o] = mfixext2 149 mw.buf[o+1] = byte(e.ExtensionType()) 150 case 4: 151 o, err := mw.require(2) 152 if err != nil { 153 return err 154 } 155 mw.buf[o] = mfixext4 156 mw.buf[o+1] = 
byte(e.ExtensionType()) 157 case 8: 158 o, err := mw.require(2) 159 if err != nil { 160 return err 161 } 162 mw.buf[o] = mfixext8 163 mw.buf[o+1] = byte(e.ExtensionType()) 164 case 16: 165 o, err := mw.require(2) 166 if err != nil { 167 return err 168 } 169 mw.buf[o] = mfixext16 170 mw.buf[o+1] = byte(e.ExtensionType()) 171 default: 172 switch { 173 case l < math.MaxUint8: 174 o, err := mw.require(3) 175 if err != nil { 176 return err 177 } 178 mw.buf[o] = mext8 179 mw.buf[o+1] = byte(uint8(l)) 180 mw.buf[o+2] = byte(e.ExtensionType()) 181 case l < math.MaxUint16: 182 o, err := mw.require(4) 183 if err != nil { 184 return err 185 } 186 mw.buf[o] = mext16 187 big.PutUint16(mw.buf[o+1:], uint16(l)) 188 mw.buf[o+3] = byte(e.ExtensionType()) 189 default: 190 o, err := mw.require(6) 191 if err != nil { 192 return err 193 } 194 mw.buf[o] = mext32 195 big.PutUint32(mw.buf[o+1:], uint32(l)) 196 mw.buf[o+5] = byte(e.ExtensionType()) 197 } 198 } 199 // we can only write directly to the 200 // buffer if we're sure that it 201 // fits the object 202 if l <= mw.bufsize() { 203 o, err := mw.require(l) 204 if err != nil { 205 return err 206 } 207 return e.MarshalBinaryTo(mw.buf[o:]) 208 } 209 // here we create a new buffer 210 // just large enough for the body 211 // and save it as the write buffer 212 err = mw.flush() 213 if err != nil { 214 return err 215 } 216 buf := make([]byte, l) 217 err = e.MarshalBinaryTo(buf) 218 if err != nil { 219 return err 220 } 221 mw.buf = buf 222 mw.wloc = l 223 return nil 224} 225 226// peek at the extension type, assuming the next 227// kind to be read is Extension 228func (m *Reader) peekExtensionType() (int8, error) { 229 p, err := m.R.Peek(2) 230 if err != nil { 231 return 0, err 232 } 233 spec := sizes[p[0]] 234 if spec.typ != ExtensionType { 235 return 0, badPrefix(ExtensionType, p[0]) 236 } 237 if spec.extra == constsize { 238 return int8(p[1]), nil 239 } 240 size := spec.size 241 p, err = m.R.Peek(int(size)) 242 if err != nil { 243 return 
0, err 244 } 245 return int8(p[size-1]), nil 246} 247 248// peekExtension peeks at the extension encoding type 249// (must guarantee at least 1 byte in 'b') 250func peekExtension(b []byte) (int8, error) { 251 spec := sizes[b[0]] 252 size := spec.size 253 if spec.typ != ExtensionType { 254 return 0, badPrefix(ExtensionType, b[0]) 255 } 256 if len(b) < int(size) { 257 return 0, ErrShortBytes 258 } 259 // for fixed extensions, 260 // the type information is in 261 // the second byte 262 if spec.extra == constsize { 263 return int8(b[1]), nil 264 } 265 // otherwise, it's in the last 266 // part of the prefix 267 return int8(b[size-1]), nil 268} 269 270// ReadExtension reads the next object from the reader 271// as an extension. ReadExtension will fail if the next 272// object in the stream is not an extension, or if 273// e.Type() is not the same as the wire type. 274func (m *Reader) ReadExtension(e Extension) (err error) { 275 var p []byte 276 p, err = m.R.Peek(2) 277 if err != nil { 278 return 279 } 280 lead := p[0] 281 var read int 282 var off int 283 switch lead { 284 case mfixext1: 285 if int8(p[1]) != e.ExtensionType() { 286 err = errExt(int8(p[1]), e.ExtensionType()) 287 return 288 } 289 p, err = m.R.Peek(3) 290 if err != nil { 291 return 292 } 293 err = e.UnmarshalBinary(p[2:]) 294 if err == nil { 295 _, err = m.R.Skip(3) 296 } 297 return 298 299 case mfixext2: 300 if int8(p[1]) != e.ExtensionType() { 301 err = errExt(int8(p[1]), e.ExtensionType()) 302 return 303 } 304 p, err = m.R.Peek(4) 305 if err != nil { 306 return 307 } 308 err = e.UnmarshalBinary(p[2:]) 309 if err == nil { 310 _, err = m.R.Skip(4) 311 } 312 return 313 314 case mfixext4: 315 if int8(p[1]) != e.ExtensionType() { 316 err = errExt(int8(p[1]), e.ExtensionType()) 317 return 318 } 319 p, err = m.R.Peek(6) 320 if err != nil { 321 return 322 } 323 err = e.UnmarshalBinary(p[2:]) 324 if err == nil { 325 _, err = m.R.Skip(6) 326 } 327 return 328 329 case mfixext8: 330 if int8(p[1]) != 
e.ExtensionType() { 331 err = errExt(int8(p[1]), e.ExtensionType()) 332 return 333 } 334 p, err = m.R.Peek(10) 335 if err != nil { 336 return 337 } 338 err = e.UnmarshalBinary(p[2:]) 339 if err == nil { 340 _, err = m.R.Skip(10) 341 } 342 return 343 344 case mfixext16: 345 if int8(p[1]) != e.ExtensionType() { 346 err = errExt(int8(p[1]), e.ExtensionType()) 347 return 348 } 349 p, err = m.R.Peek(18) 350 if err != nil { 351 return 352 } 353 err = e.UnmarshalBinary(p[2:]) 354 if err == nil { 355 _, err = m.R.Skip(18) 356 } 357 return 358 359 case mext8: 360 p, err = m.R.Peek(3) 361 if err != nil { 362 return 363 } 364 if int8(p[2]) != e.ExtensionType() { 365 err = errExt(int8(p[2]), e.ExtensionType()) 366 return 367 } 368 read = int(uint8(p[1])) 369 off = 3 370 371 case mext16: 372 p, err = m.R.Peek(4) 373 if err != nil { 374 return 375 } 376 if int8(p[3]) != e.ExtensionType() { 377 err = errExt(int8(p[3]), e.ExtensionType()) 378 return 379 } 380 read = int(big.Uint16(p[1:])) 381 off = 4 382 383 case mext32: 384 p, err = m.R.Peek(6) 385 if err != nil { 386 return 387 } 388 if int8(p[5]) != e.ExtensionType() { 389 err = errExt(int8(p[5]), e.ExtensionType()) 390 return 391 } 392 read = int(big.Uint32(p[1:])) 393 off = 6 394 395 default: 396 err = badPrefix(ExtensionType, lead) 397 return 398 } 399 400 p, err = m.R.Peek(read + off) 401 if err != nil { 402 return 403 } 404 err = e.UnmarshalBinary(p[off:]) 405 if err == nil { 406 _, err = m.R.Skip(read + off) 407 } 408 return 409} 410 411// AppendExtension appends a MessagePack extension to the provided slice 412func AppendExtension(b []byte, e Extension) ([]byte, error) { 413 l := e.Len() 414 var o []byte 415 var n int 416 switch l { 417 case 0: 418 o, n = ensure(b, 3) 419 o[n] = mext8 420 o[n+1] = 0 421 o[n+2] = byte(e.ExtensionType()) 422 return o[:n+3], nil 423 case 1: 424 o, n = ensure(b, 3) 425 o[n] = mfixext1 426 o[n+1] = byte(e.ExtensionType()) 427 n += 2 428 case 2: 429 o, n = ensure(b, 4) 430 o[n] = mfixext2 431 
o[n+1] = byte(e.ExtensionType()) 432 n += 2 433 case 4: 434 o, n = ensure(b, 6) 435 o[n] = mfixext4 436 o[n+1] = byte(e.ExtensionType()) 437 n += 2 438 case 8: 439 o, n = ensure(b, 10) 440 o[n] = mfixext8 441 o[n+1] = byte(e.ExtensionType()) 442 n += 2 443 case 16: 444 o, n = ensure(b, 18) 445 o[n] = mfixext16 446 o[n+1] = byte(e.ExtensionType()) 447 n += 2 448 } 449 switch { 450 case l < math.MaxUint8: 451 o, n = ensure(b, l+3) 452 o[n] = mext8 453 o[n+1] = byte(uint8(l)) 454 o[n+2] = byte(e.ExtensionType()) 455 n += 3 456 case l < math.MaxUint16: 457 o, n = ensure(b, l+4) 458 o[n] = mext16 459 big.PutUint16(o[n+1:], uint16(l)) 460 o[n+3] = byte(e.ExtensionType()) 461 n += 4 462 default: 463 o, n = ensure(b, l+6) 464 o[n] = mext32 465 big.PutUint32(o[n+1:], uint32(l)) 466 o[n+5] = byte(e.ExtensionType()) 467 n += 6 468 } 469 return o, e.MarshalBinaryTo(o[n:]) 470} 471 472// ReadExtensionBytes reads an extension from 'b' into 'e' 473// and returns any remaining bytes. 474// Possible errors: 475// - ErrShortBytes ('b' not long enough) 476// - ExtensionTypeErorr{} (wire type not the same as e.Type()) 477// - TypeErorr{} (next object not an extension) 478// - InvalidPrefixError 479// - An umarshal error returned from e.UnmarshalBinary 480func ReadExtensionBytes(b []byte, e Extension) ([]byte, error) { 481 l := len(b) 482 if l < 3 { 483 return b, ErrShortBytes 484 } 485 lead := b[0] 486 var ( 487 sz int // size of 'data' 488 off int // offset of 'data' 489 typ int8 490 ) 491 switch lead { 492 case mfixext1: 493 typ = int8(b[1]) 494 sz = 1 495 off = 2 496 case mfixext2: 497 typ = int8(b[1]) 498 sz = 2 499 off = 2 500 case mfixext4: 501 typ = int8(b[1]) 502 sz = 4 503 off = 2 504 case mfixext8: 505 typ = int8(b[1]) 506 sz = 8 507 off = 2 508 case mfixext16: 509 typ = int8(b[1]) 510 sz = 16 511 off = 2 512 case mext8: 513 sz = int(uint8(b[1])) 514 typ = int8(b[2]) 515 off = 3 516 if sz == 0 { 517 return b[3:], e.UnmarshalBinary(b[3:3]) 518 } 519 case mext16: 520 if l < 4 
{ 521 return b, ErrShortBytes 522 } 523 sz = int(big.Uint16(b[1:])) 524 typ = int8(b[3]) 525 off = 4 526 case mext32: 527 if l < 6 { 528 return b, ErrShortBytes 529 } 530 sz = int(big.Uint32(b[1:])) 531 typ = int8(b[5]) 532 off = 6 533 default: 534 return b, badPrefix(ExtensionType, lead) 535 } 536 537 if typ != e.ExtensionType() { 538 return b, errExt(typ, e.ExtensionType()) 539 } 540 541 // the data of the extension starts 542 // at 'off' and is 'sz' bytes long 543 if len(b[off:]) < sz { 544 return b, ErrShortBytes 545 } 546 tot := off + sz 547 return b[tot:], e.UnmarshalBinary(b[off:tot]) 548} 549