1/* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20package thrift 21 22import ( 23 "bytes" 24 "encoding/binary" 25 "errors" 26 "fmt" 27 "io" 28 "math" 29) 30 31type TBinaryProtocol struct { 32 trans TRichTransport 33 origTransport TTransport 34 reader io.Reader 35 writer io.Writer 36 strictRead bool 37 strictWrite bool 38 buffer [64]byte 39} 40 41type TBinaryProtocolFactory struct { 42 strictRead bool 43 strictWrite bool 44} 45 46func NewTBinaryProtocolTransport(t TTransport) *TBinaryProtocol { 47 return NewTBinaryProtocol(t, false, true) 48} 49 50func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProtocol { 51 p := &TBinaryProtocol{origTransport: t, strictRead: strictRead, strictWrite: strictWrite} 52 if et, ok := t.(TRichTransport); ok { 53 p.trans = et 54 } else { 55 p.trans = NewTRichTransport(t) 56 } 57 p.reader = p.trans 58 p.writer = p.trans 59 return p 60} 61 62func NewTBinaryProtocolFactoryDefault() *TBinaryProtocolFactory { 63 return NewTBinaryProtocolFactory(false, true) 64} 65 66func NewTBinaryProtocolFactory(strictRead, strictWrite bool) *TBinaryProtocolFactory { 67 return &TBinaryProtocolFactory{strictRead: strictRead, strictWrite: strictWrite} 68} 69 70func (p *TBinaryProtocolFactory) GetProtocol(t TTransport) TProtocol { 71 return NewTBinaryProtocol(t, p.strictRead, p.strictWrite) 72} 73 74/** 75 * Writing Methods 76 */ 77 78func (p *TBinaryProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error { 79 if p.strictWrite { 80 version := uint32(VERSION_1) | uint32(typeId) 81 e := p.WriteI32(int32(version)) 82 if e != nil { 83 return e 84 } 85 e = p.WriteString(name) 86 if e != nil { 87 return e 88 } 89 e = p.WriteI32(seqId) 90 return e 91 } else { 92 e := p.WriteString(name) 93 if e != nil { 94 return e 95 } 96 e = p.WriteByte(int8(typeId)) 97 if e != nil { 98 return e 99 } 100 e = p.WriteI32(seqId) 101 return e 102 } 103 return nil 104} 105 106func (p *TBinaryProtocol) WriteMessageEnd() error { 107 return nil 108} 109 110func (p *TBinaryProtocol) WriteStructBegin(name string) error { 111 return nil 112} 113 114func (p *TBinaryProtocol) WriteStructEnd() error { 115 return nil 116} 117 118func (p *TBinaryProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { 119 e := p.WriteByte(int8(typeId)) 120 if e != nil { 121 return e 122 } 123 e = p.WriteI16(id) 124 return e 125} 126 127func (p *TBinaryProtocol) WriteFieldEnd() error { 128 return nil 129} 130 131func (p *TBinaryProtocol) WriteFieldStop() error { 132 e := p.WriteByte(STOP) 133 return e 134} 135 136func (p *TBinaryProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { 137 e := p.WriteByte(int8(keyType)) 138 if e != nil { 139 return e 140 } 141 e = p.WriteByte(int8(valueType)) 142 if e != nil { 143 return e 144 } 145 e = p.WriteI32(int32(size)) 146 return e 147} 148 149func (p *TBinaryProtocol) WriteMapEnd() error { 150 return nil 151} 152 153func (p *TBinaryProtocol) WriteListBegin(elemType TType, size int) error { 154 e := p.WriteByte(int8(elemType)) 155 if e != nil { 156 return e 157 } 158 e = p.WriteI32(int32(size)) 159 return e 160} 161 162func (p *TBinaryProtocol) WriteListEnd() error { 163 return nil 164} 165 166func (p *TBinaryProtocol) WriteSetBegin(elemType TType, size int) error { 167 e := p.WriteByte(int8(elemType)) 168 if e != nil { 169 return e 170 } 171 e = p.WriteI32(int32(size)) 172 return e 173} 174 175func (p *TBinaryProtocol) WriteSetEnd() error { 176 return nil 177} 178 179func (p *TBinaryProtocol) WriteBool(value bool) error { 180 if value { 181 return p.WriteByte(1) 182 } 183 return p.WriteByte(0) 184} 185 186func (p *TBinaryProtocol) WriteByte(value int8) error { 187 e := p.trans.WriteByte(byte(value)) 188 return NewTProtocolException(e) 189} 190 191func (p *TBinaryProtocol) WriteI16(value int16) error { 192 v := p.buffer[0:2] 193 binary.BigEndian.PutUint16(v, uint16(value)) 194 _, e := p.writer.Write(v) 195 return NewTProtocolException(e) 196} 197 198func (p *TBinaryProtocol) WriteI32(value int32) error { 199 v := p.buffer[0:4] 200 binary.BigEndian.PutUint32(v, uint32(value)) 201 _, e := p.writer.Write(v) 202 return NewTProtocolException(e) 203} 204 205func (p *TBinaryProtocol) WriteI64(value int64) error { 206 v := p.buffer[0:8] 207 binary.BigEndian.PutUint64(v, uint64(value)) 208 _, err := p.writer.Write(v) 209 return NewTProtocolException(err) 210} 211 212func (p *TBinaryProtocol) WriteDouble(value float64) error { 213 return p.WriteI64(int64(math.Float64bits(value))) 214} 215 216func (p *TBinaryProtocol) WriteString(value string) error { 217 e := p.WriteI32(int32(len(value))) 218 if e != nil { 219 return e 220 } 221 _, err := p.trans.WriteString(value) 222 return NewTProtocolException(err) 223} 224 225func (p *TBinaryProtocol) WriteBinary(value []byte) error { 226 e := p.WriteI32(int32(len(value))) 227 if e != nil { 228 return e 229 } 230 _, err := p.writer.Write(value) 231 return NewTProtocolException(err) 232} 233 234/** 235 * Reading methods 236 */ 237 238func (p *TBinaryProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) { 239 size, e := p.ReadI32() 240 if e != nil { 241 return "", typeId, 0, NewTProtocolException(e) 242 } 243 if size < 0 { 244 typeId = TMessageType(size & 0x0ff) 245 version := int64(int64(size) & VERSION_MASK) 246 if version != VERSION_1 { 247 return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Bad version in ReadMessageBegin")) 248 } 249 name, e = p.ReadString() 250 if e != nil { 251 return name, typeId, seqId, NewTProtocolException(e) 252 } 253 seqId, e = p.ReadI32() 254 if e != nil { 255 return name, typeId, seqId, NewTProtocolException(e) 256 } 257 return name, typeId, seqId, nil 258 } 259 if p.strictRead { 260 return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Missing version in ReadMessageBegin")) 261 } 262 name, e2 := p.readStringBody(size) 263 if e2 != nil { 264 return name, typeId, seqId, e2 265 } 266 b, e3 := p.ReadByte() 267 if e3 != nil { 268 return name, typeId, seqId, e3 269 } 270 typeId = TMessageType(b) 271 seqId, e4 := p.ReadI32() 272 if e4 != nil { 273 return name, typeId, seqId, e4 274 } 275 return name, typeId, seqId, nil 276} 277 278func (p *TBinaryProtocol) ReadMessageEnd() error { 279 return nil 280} 281 282func (p *TBinaryProtocol) ReadStructBegin() (name string, err error) { 283 return 284} 285 286func (p *TBinaryProtocol) ReadStructEnd() error { 287 return nil 288} 289 290func (p *TBinaryProtocol) ReadFieldBegin() (name string, typeId TType, seqId int16, err error) { 291 t, err := p.ReadByte() 292 typeId = TType(t) 293 if err != nil { 294 return name, typeId, seqId, err 295 } 296 if t != STOP { 297 seqId, err = p.ReadI16() 298 } 299 return name, typeId, seqId, err 300} 301 302func (p *TBinaryProtocol) ReadFieldEnd() error { 303 return nil 304} 305 306var invalidDataLength = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Invalid data length")) 307 308func (p *TBinaryProtocol) ReadMapBegin() (kType, vType TType, size int, err error) { 309 k, e := p.ReadByte() 310 if e != nil { 311 err = NewTProtocolException(e) 312 return 313 } 314 kType = TType(k) 315 v, e := p.ReadByte() 316 if e != nil { 317 err = NewTProtocolException(e) 318 return 319 } 320 vType = TType(v) 321 size32, e := p.ReadI32() 322 if e != nil { 323 err = NewTProtocolException(e) 324 return 325 } 326 if size32 < 0 { 327 err = invalidDataLength 328 return 329 } 330 size = int(size32) 331 return kType, vType, size, nil 332} 333 334func (p *TBinaryProtocol) ReadMapEnd() error { 335 return nil 336} 337 338func (p *TBinaryProtocol) ReadListBegin() (elemType TType, size int, err error) { 339 b, e := p.ReadByte() 340 if e != nil { 341 err = NewTProtocolException(e) 342 return 343 } 344 elemType = TType(b) 345 size32, e := p.ReadI32() 346 if e != nil { 347 err = NewTProtocolException(e) 348 return 349 } 350 if size32 < 0 { 351 err = invalidDataLength 352 return 353 } 354 size = int(size32) 355 356 return 357} 358 359func (p *TBinaryProtocol) ReadListEnd() error { 360 return nil 361} 362 363func (p *TBinaryProtocol) ReadSetBegin() (elemType TType, size int, err error) { 364 b, e := p.ReadByte() 365 if e != nil { 366 err = NewTProtocolException(e) 367 return 368 } 369 elemType = TType(b) 370 size32, e := p.ReadI32() 371 if e != nil { 372 err = NewTProtocolException(e) 373 return 374 } 375 if size32 < 0 { 376 err = invalidDataLength 377 return 378 } 379 size = int(size32) 380 return elemType, size, nil 381} 382 383func (p *TBinaryProtocol) ReadSetEnd() error { 384 return nil 385} 386 387func (p *TBinaryProtocol) ReadBool() (bool, error) { 388 b, e := p.ReadByte() 389 v := true 390 if b != 1 { 391 v = false 392 } 393 return v, e 394} 395 396func (p *TBinaryProtocol) ReadByte() (int8, error) { 397 v, err := p.trans.ReadByte() 398 return int8(v), err 399} 400 401func (p *TBinaryProtocol) ReadI16() (value int16, err error) { 402 buf := p.buffer[0:2] 403 err = p.readAll(buf) 404 value = int16(binary.BigEndian.Uint16(buf)) 405 return value, err 406} 407 408func (p *TBinaryProtocol) ReadI32() (value int32, err error) { 409 buf := p.buffer[0:4] 410 err = p.readAll(buf) 411 value = int32(binary.BigEndian.Uint32(buf)) 412 return value, err 413} 414 415func (p *TBinaryProtocol) ReadI64() (value int64, err error) { 416 buf := p.buffer[0:8] 417 err = p.readAll(buf) 418 value = int64(binary.BigEndian.Uint64(buf)) 419 return value, err 420} 421 422func (p *TBinaryProtocol) ReadDouble() (value float64, err error) { 423 buf := p.buffer[0:8] 424 err = p.readAll(buf) 425 value = math.Float64frombits(binary.BigEndian.Uint64(buf)) 426 return value, err 427} 428 429func (p *TBinaryProtocol) ReadString() (value string, err error) { 430 size, e := p.ReadI32() 431 if e != nil { 432 return "", e 433 } 434 if size < 0 { 435 err = invalidDataLength 436 return 437 } 438 439 return p.readStringBody(size) 440} 441 442func (p *TBinaryProtocol) ReadBinary() ([]byte, error) { 443 size, e := p.ReadI32() 444 if e != nil { 445 return nil, e 446 } 447 if size < 0 { 448 return nil, invalidDataLength 449 } 450 if uint64(size) > p.trans.RemainingBytes() { 451 return nil, invalidDataLength 452 } 453 454 isize := int(size) 455 buf := make([]byte, isize) 456 _, err := io.ReadFull(p.trans, buf) 457 return buf, NewTProtocolException(err) 458} 459 460func (p *TBinaryProtocol) Flush() (err error) { 461 return NewTProtocolException(p.trans.Flush()) 462} 463 464func (p *TBinaryProtocol) Skip(fieldType TType) (err error) { 465 return SkipDefaultDepth(p, fieldType) 466} 467 468func (p *TBinaryProtocol) Transport() TTransport { 469 return p.origTransport 470} 471 472func (p *TBinaryProtocol) readAll(buf []byte) error { 473 _, err := io.ReadFull(p.reader, buf) 474 return NewTProtocolException(err) 475} 476 477const readLimit = 32768 478 479func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) { 480 if size < 0 { 481 return "", nil 482 } 483 if uint64(size) > p.trans.RemainingBytes() { 484 return "", invalidDataLength 485 } 486 487 var ( 488 buf bytes.Buffer 489 e error 490 b []byte 491 ) 492 493 switch { 494 case int(size) <= len(p.buffer): 495 b = p.buffer[:size] // avoids allocation for small reads 496 case int(size) < readLimit: 497 b = make([]byte, size) 498 default: 499 b = make([]byte, readLimit) 500 } 501 502 for size > 0 { 503 _, e = io.ReadFull(p.trans, b) 504 buf.Write(b) 505 if e != nil { 506 break 507 } 508 size -= readLimit 509 if size < readLimit && size > 0 { 510 b = b[:size] 511 } 512 } 513 return buf.String(), NewTProtocolException(e) 514} 515