/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "bytes" "context" "encoding/binary" "errors" "fmt" "io" "math" ) type TBinaryProtocol struct { trans TRichTransport origTransport TTransport cfg *TConfiguration buffer [64]byte } type TBinaryProtocolFactory struct { cfg *TConfiguration } // Deprecated: Use NewTBinaryProtocolConf instead. func NewTBinaryProtocolTransport(t TTransport) *TBinaryProtocol { return NewTBinaryProtocolConf(t, &TConfiguration{ noPropagation: true, }) } // Deprecated: Use NewTBinaryProtocolConf instead. func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProtocol { return NewTBinaryProtocolConf(t, &TConfiguration{ TBinaryStrictRead: &strictRead, TBinaryStrictWrite: &strictWrite, noPropagation: true, }) } func NewTBinaryProtocolConf(t TTransport, conf *TConfiguration) *TBinaryProtocol { PropagateTConfiguration(t, conf) p := &TBinaryProtocol{ origTransport: t, cfg: conf, } if et, ok := t.(TRichTransport); ok { p.trans = et } else { p.trans = NewTRichTransport(t) } return p } // Deprecated: Use NewTBinaryProtocolFactoryConf instead. func NewTBinaryProtocolFactoryDefault() *TBinaryProtocolFactory { return NewTBinaryProtocolFactoryConf(&TConfiguration{ noPropagation: true, }) } // Deprecated: Use NewTBinaryProtocolFactoryConf instead. func NewTBinaryProtocolFactory(strictRead, strictWrite bool) *TBinaryProtocolFactory { return NewTBinaryProtocolFactoryConf(&TConfiguration{ TBinaryStrictRead: &strictRead, TBinaryStrictWrite: &strictWrite, noPropagation: true, }) } func NewTBinaryProtocolFactoryConf(conf *TConfiguration) *TBinaryProtocolFactory { return &TBinaryProtocolFactory{ cfg: conf, } } func (p *TBinaryProtocolFactory) GetProtocol(t TTransport) TProtocol { return NewTBinaryProtocolConf(t, p.cfg) } func (p *TBinaryProtocolFactory) SetTConfiguration(conf *TConfiguration) { p.cfg = conf } /** * Writing Methods */ func (p *TBinaryProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqId int32) error { if p.cfg.GetTBinaryStrictWrite() { version := uint32(VERSION_1) | uint32(typeId) e := p.WriteI32(ctx, int32(version)) if e != nil { return e } e = p.WriteString(ctx, name) if e != nil { return e } e = p.WriteI32(ctx, seqId) return e } else { e := p.WriteString(ctx, name) if e != nil { return e } e = p.WriteByte(ctx, int8(typeId)) if e != nil { return e } e = p.WriteI32(ctx, seqId) return e } return nil } func (p *TBinaryProtocol) WriteMessageEnd(ctx context.Context) error { return nil } func (p *TBinaryProtocol) WriteStructBegin(ctx context.Context, name string) error { return nil } func (p *TBinaryProtocol) WriteStructEnd(ctx context.Context) error { return nil } func (p *TBinaryProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error { e := p.WriteByte(ctx, int8(typeId)) if e != nil { return e } e = p.WriteI16(ctx, id) return e } func (p *TBinaryProtocol) WriteFieldEnd(ctx context.Context) error { return nil } func (p *TBinaryProtocol) WriteFieldStop(ctx context.Context) error { e := p.WriteByte(ctx, STOP) return e } func (p *TBinaryProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error { e := p.WriteByte(ctx, int8(keyType)) if e != nil { return e } e = p.WriteByte(ctx, int8(valueType)) if e != nil { return e } e = p.WriteI32(ctx, int32(size)) return e } func (p *TBinaryProtocol) WriteMapEnd(ctx context.Context) error { return nil } func (p *TBinaryProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error { e := p.WriteByte(ctx, int8(elemType)) if e != nil { return e } e = p.WriteI32(ctx, int32(size)) return e } func (p *TBinaryProtocol) WriteListEnd(ctx context.Context) error { return nil } func (p *TBinaryProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error { e := p.WriteByte(ctx, int8(elemType)) if e != nil { return e } e = p.WriteI32(ctx, int32(size)) return e } func (p *TBinaryProtocol) WriteSetEnd(ctx context.Context) error { return nil } func (p *TBinaryProtocol) WriteBool(ctx context.Context, value bool) error { if value { return p.WriteByte(ctx, 1) } return p.WriteByte(ctx, 0) } func (p *TBinaryProtocol) WriteByte(ctx context.Context, value int8) error { e := p.trans.WriteByte(byte(value)) return NewTProtocolException(e) } func (p *TBinaryProtocol) WriteI16(ctx context.Context, value int16) error { v := p.buffer[0:2] binary.BigEndian.PutUint16(v, uint16(value)) _, e := p.trans.Write(v) return NewTProtocolException(e) } func (p *TBinaryProtocol) WriteI32(ctx context.Context, value int32) error { v := p.buffer[0:4] binary.BigEndian.PutUint32(v, uint32(value)) _, e := p.trans.Write(v) return NewTProtocolException(e) } func (p *TBinaryProtocol) WriteI64(ctx context.Context, value int64) error { v := p.buffer[0:8] binary.BigEndian.PutUint64(v, uint64(value)) _, err := p.trans.Write(v) return NewTProtocolException(err) } func (p *TBinaryProtocol) WriteDouble(ctx context.Context, value float64) error { return p.WriteI64(ctx, int64(math.Float64bits(value))) } func (p *TBinaryProtocol) WriteString(ctx context.Context, value string) error { e := p.WriteI32(ctx, int32(len(value))) if e != nil { return e } _, err := p.trans.WriteString(value) return NewTProtocolException(err) } func (p *TBinaryProtocol) WriteBinary(ctx context.Context, value []byte) error { e := p.WriteI32(ctx, int32(len(value))) if e != nil { return e } _, err := p.trans.Write(value) return NewTProtocolException(err) } /** * Reading methods */ func (p *TBinaryProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqId int32, err error) { size, e := p.ReadI32(ctx) if e != nil { return "", typeId, 0, NewTProtocolException(e) } if size < 0 { typeId = TMessageType(size & 0x0ff) version := int64(int64(size) & VERSION_MASK) if version != VERSION_1 { return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Bad version in ReadMessageBegin")) } name, e = p.ReadString(ctx) if e != nil { return name, typeId, seqId, NewTProtocolException(e) } seqId, e = p.ReadI32(ctx) if e != nil { return name, typeId, seqId, NewTProtocolException(e) } return name, typeId, seqId, nil } if p.cfg.GetTBinaryStrictRead() { return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Missing version in ReadMessageBegin")) } name, e2 := p.readStringBody(size) if e2 != nil { return name, typeId, seqId, e2 } b, e3 := p.ReadByte(ctx) if e3 != nil { return name, typeId, seqId, e3 } typeId = TMessageType(b) seqId, e4 := p.ReadI32(ctx) if e4 != nil { return name, typeId, seqId, e4 } return name, typeId, seqId, nil } func (p *TBinaryProtocol) ReadMessageEnd(ctx context.Context) error { return nil } func (p *TBinaryProtocol) ReadStructBegin(ctx context.Context) (name string, err error) { return } func (p *TBinaryProtocol) ReadStructEnd(ctx context.Context) error { return nil } func (p *TBinaryProtocol) ReadFieldBegin(ctx context.Context) (name string, typeId TType, seqId int16, err error) { t, err := p.ReadByte(ctx) typeId = TType(t) if err != nil { return name, typeId, seqId, err } if t != STOP { seqId, err = p.ReadI16(ctx) } return name, typeId, seqId, err } func (p *TBinaryProtocol) ReadFieldEnd(ctx context.Context) error { return nil } var invalidDataLength = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Invalid data length")) func (p *TBinaryProtocol) ReadMapBegin(ctx context.Context) (kType, vType TType, size int, err error) { k, e := p.ReadByte(ctx) if e != nil { err = NewTProtocolException(e) return } kType = TType(k) v, e := p.ReadByte(ctx) if e != nil { err = NewTProtocolException(e) return } vType = TType(v) size32, e := p.ReadI32(ctx) if e != nil { err = NewTProtocolException(e) return } if size32 < 0 { err = invalidDataLength return } size = int(size32) return kType, vType, size, nil } func (p *TBinaryProtocol) ReadMapEnd(ctx context.Context) error { return nil } func (p *TBinaryProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, err error) { b, e := p.ReadByte(ctx) if e != nil { err = NewTProtocolException(e) return } elemType = TType(b) size32, e := p.ReadI32(ctx) if e != nil { err = NewTProtocolException(e) return } if size32 < 0 { err = invalidDataLength return } size = int(size32) return } func (p *TBinaryProtocol) ReadListEnd(ctx context.Context) error { return nil } func (p *TBinaryProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) { b, e := p.ReadByte(ctx) if e != nil { err = NewTProtocolException(e) return } elemType = TType(b) size32, e := p.ReadI32(ctx) if e != nil { err = NewTProtocolException(e) return } if size32 < 0 { err = invalidDataLength return } size = int(size32) return elemType, size, nil } func (p *TBinaryProtocol) ReadSetEnd(ctx context.Context) error { return nil } func (p *TBinaryProtocol) ReadBool(ctx context.Context) (bool, error) { b, e := p.ReadByte(ctx) v := true if b != 1 { v = false } return v, e } func (p *TBinaryProtocol) ReadByte(ctx context.Context) (int8, error) { v, err := p.trans.ReadByte() return int8(v), err } func (p *TBinaryProtocol) ReadI16(ctx context.Context) (value int16, err error) { buf := p.buffer[0:2] err = p.readAll(ctx, buf) value = int16(binary.BigEndian.Uint16(buf)) return value, err } func (p *TBinaryProtocol) ReadI32(ctx context.Context) (value int32, err error) { buf := p.buffer[0:4] err = p.readAll(ctx, buf) value = int32(binary.BigEndian.Uint32(buf)) return value, err } func (p *TBinaryProtocol) ReadI64(ctx context.Context) (value int64, err error) { buf := p.buffer[0:8] err = p.readAll(ctx, buf) value = int64(binary.BigEndian.Uint64(buf)) return value, err } func (p *TBinaryProtocol) ReadDouble(ctx context.Context) (value float64, err error) { buf := p.buffer[0:8] err = p.readAll(ctx, buf) value = math.Float64frombits(binary.BigEndian.Uint64(buf)) return value, err } func (p *TBinaryProtocol) ReadString(ctx context.Context) (value string, err error) { size, e := p.ReadI32(ctx) if e != nil { return "", e } err = checkSizeForProtocol(size, p.cfg) if err != nil { return } if size < 0 { err = invalidDataLength return } if size == 0 { return "", nil } if size < int32(len(p.buffer)) { // Avoid allocation on small reads buf := p.buffer[:size] read, e := io.ReadFull(p.trans, buf) return string(buf[:read]), NewTProtocolException(e) } return p.readStringBody(size) } func (p *TBinaryProtocol) ReadBinary(ctx context.Context) ([]byte, error) { size, e := p.ReadI32(ctx) if e != nil { return nil, e } if err := checkSizeForProtocol(size, p.cfg); err != nil { return nil, err } buf, err := safeReadBytes(size, p.trans) return buf, NewTProtocolException(err) } func (p *TBinaryProtocol) Flush(ctx context.Context) (err error) { return NewTProtocolException(p.trans.Flush(ctx)) } func (p *TBinaryProtocol) Skip(ctx context.Context, fieldType TType) (err error) { return SkipDefaultDepth(ctx, p, fieldType) } func (p *TBinaryProtocol) Transport() TTransport { return p.origTransport } func (p *TBinaryProtocol) readAll(ctx context.Context, buf []byte) (err error) { var read int _, deadlineSet := ctx.Deadline() for { read, err = io.ReadFull(p.trans, buf) if deadlineSet && read == 0 && isTimeoutError(err) && ctx.Err() == nil { // This is I/O timeout without anything read, // and we still have time left, keep retrying. continue } // For anything else, don't retry break } return NewTProtocolException(err) } func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) { buf, err := safeReadBytes(size, p.trans) return string(buf), NewTProtocolException(err) } func (p *TBinaryProtocol) SetTConfiguration(conf *TConfiguration) { PropagateTConfiguration(p.trans, conf) PropagateTConfiguration(p.origTransport, conf) p.cfg = conf } var ( _ TConfigurationSetter = (*TBinaryProtocolFactory)(nil) _ TConfigurationSetter = (*TBinaryProtocol)(nil) ) // This function is shared between TBinaryProtocol and TCompactProtocol. // // It tries to read size bytes from trans, in a way that prevents large // allocations when size is insanely large (mostly caused by malformed message). func safeReadBytes(size int32, trans io.Reader) ([]byte, error) { if size < 0 { return nil, nil } buf := new(bytes.Buffer) _, err := io.CopyN(buf, trans, int64(size)) return buf.Bytes(), err }