1/* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20package thrift 21 22import ( 23 "context" 24) 25 26// THeaderProtocol is a thrift protocol that implements THeader: 27// https://github.com/apache/thrift/blob/master/doc/specs/HeaderFormat.md 28// 29// It supports either binary or compact protocol as the wrapped protocol. 30// 31// Most of the THeader handlings are happening inside THeaderTransport. 32type THeaderProtocol struct { 33 transport *THeaderTransport 34 35 // Will be initialized on first read/write. 36 protocol TProtocol 37} 38 39// NewTHeaderProtocol creates a new THeaderProtocol from the underlying 40// transport. The passed in transport will be wrapped with THeaderTransport. 41// 42// Note that THeaderTransport handles frame and zlib by itself, 43// so the underlying transport should be a raw socket transports (TSocket or TSSLSocket), 44// instead of rich transports like TZlibTransport or TFramedTransport. 45func NewTHeaderProtocol(trans TTransport) *THeaderProtocol { 46 t := NewTHeaderTransport(trans) 47 p, _ := THeaderProtocolDefault.GetProtocol(t) 48 return &THeaderProtocol{ 49 transport: t, 50 protocol: p, 51 } 52} 53 54type tHeaderProtocolFactory struct{} 55 56func (tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol { 57 return NewTHeaderProtocol(trans) 58} 59 60// NewTHeaderProtocolFactory creates a factory for THeader. 61// 62// It's a wrapper for NewTHeaderProtocol 63func NewTHeaderProtocolFactory() TProtocolFactory { 64 return tHeaderProtocolFactory{} 65} 66 67// Transport returns the underlying transport. 68// 69// It's guaranteed to be of type *THeaderTransport. 70func (p *THeaderProtocol) Transport() TTransport { 71 return p.transport 72} 73 74// GetReadHeaders returns the THeaderMap read from transport. 75func (p *THeaderProtocol) GetReadHeaders() THeaderMap { 76 return p.transport.GetReadHeaders() 77} 78 79// SetWriteHeader sets a header for write. 80func (p *THeaderProtocol) SetWriteHeader(key, value string) { 81 p.transport.SetWriteHeader(key, value) 82} 83 84// ClearWriteHeaders clears all write headers previously set. 85func (p *THeaderProtocol) ClearWriteHeaders() { 86 p.transport.ClearWriteHeaders() 87} 88 89// AddTransform add a transform for writing. 90func (p *THeaderProtocol) AddTransform(transform THeaderTransformID) error { 91 return p.transport.AddTransform(transform) 92} 93 94func (p *THeaderProtocol) Flush(ctx context.Context) error { 95 return p.transport.Flush(ctx) 96} 97 98func (p *THeaderProtocol) WriteMessageBegin(name string, typeID TMessageType, seqID int32) error { 99 newProto, err := p.transport.Protocol().GetProtocol(p.transport) 100 if err != nil { 101 return err 102 } 103 p.protocol = newProto 104 p.transport.SequenceID = seqID 105 return p.protocol.WriteMessageBegin(name, typeID, seqID) 106} 107 108func (p *THeaderProtocol) WriteMessageEnd() error { 109 if err := p.protocol.WriteMessageEnd(); err != nil { 110 return err 111 } 112 return p.transport.Flush(context.Background()) 113} 114 115func (p *THeaderProtocol) WriteStructBegin(name string) error { 116 return p.protocol.WriteStructBegin(name) 117} 118 119func (p *THeaderProtocol) WriteStructEnd() error { 120 return p.protocol.WriteStructEnd() 121} 122 123func (p *THeaderProtocol) WriteFieldBegin(name string, typeID TType, id int16) error { 124 return p.protocol.WriteFieldBegin(name, typeID, id) 125} 126 127func (p *THeaderProtocol) WriteFieldEnd() error { 128 return p.protocol.WriteFieldEnd() 129} 130 131func (p *THeaderProtocol) WriteFieldStop() error { 132 return p.protocol.WriteFieldStop() 133} 134 135func (p *THeaderProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { 136 return p.protocol.WriteMapBegin(keyType, valueType, size) 137} 138 139func (p *THeaderProtocol) WriteMapEnd() error { 140 return p.protocol.WriteMapEnd() 141} 142 143func (p *THeaderProtocol) WriteListBegin(elemType TType, size int) error { 144 return p.protocol.WriteListBegin(elemType, size) 145} 146 147func (p *THeaderProtocol) WriteListEnd() error { 148 return p.protocol.WriteListEnd() 149} 150 151func (p *THeaderProtocol) WriteSetBegin(elemType TType, size int) error { 152 return p.protocol.WriteSetBegin(elemType, size) 153} 154 155func (p *THeaderProtocol) WriteSetEnd() error { 156 return p.protocol.WriteSetEnd() 157} 158 159func (p *THeaderProtocol) WriteBool(value bool) error { 160 return p.protocol.WriteBool(value) 161} 162 163func (p *THeaderProtocol) WriteByte(value int8) error { 164 return p.protocol.WriteByte(value) 165} 166 167func (p *THeaderProtocol) WriteI16(value int16) error { 168 return p.protocol.WriteI16(value) 169} 170 171func (p *THeaderProtocol) WriteI32(value int32) error { 172 return p.protocol.WriteI32(value) 173} 174 175func (p *THeaderProtocol) WriteI64(value int64) error { 176 return p.protocol.WriteI64(value) 177} 178 179func (p *THeaderProtocol) WriteDouble(value float64) error { 180 return p.protocol.WriteDouble(value) 181} 182 183func (p *THeaderProtocol) WriteString(value string) error { 184 return p.protocol.WriteString(value) 185} 186 187func (p *THeaderProtocol) WriteBinary(value []byte) error { 188 return p.protocol.WriteBinary(value) 189} 190 191// ReadFrame calls underlying THeaderTransport's ReadFrame function. 192func (p *THeaderProtocol) ReadFrame() error { 193 return p.transport.ReadFrame() 194} 195 196func (p *THeaderProtocol) ReadMessageBegin() (name string, typeID TMessageType, seqID int32, err error) { 197 if err = p.transport.ReadFrame(); err != nil { 198 return 199 } 200 201 var newProto TProtocol 202 newProto, err = p.transport.Protocol().GetProtocol(p.transport) 203 if err != nil { 204 tAppExc, ok := err.(TApplicationException) 205 if !ok { 206 return 207 } 208 if e := p.protocol.WriteMessageBegin("", EXCEPTION, seqID); e != nil { 209 return 210 } 211 if e := tAppExc.Write(p.protocol); e != nil { 212 return 213 } 214 if e := p.protocol.WriteMessageEnd(); e != nil { 215 return 216 } 217 if e := p.transport.Flush(context.Background()); e != nil { 218 return 219 } 220 return 221 } 222 p.protocol = newProto 223 224 return p.protocol.ReadMessageBegin() 225} 226 227func (p *THeaderProtocol) ReadMessageEnd() error { 228 return p.protocol.ReadMessageEnd() 229} 230 231func (p *THeaderProtocol) ReadStructBegin() (name string, err error) { 232 return p.protocol.ReadStructBegin() 233} 234 235func (p *THeaderProtocol) ReadStructEnd() error { 236 return p.protocol.ReadStructEnd() 237} 238 239func (p *THeaderProtocol) ReadFieldBegin() (name string, typeID TType, id int16, err error) { 240 return p.protocol.ReadFieldBegin() 241} 242 243func (p *THeaderProtocol) ReadFieldEnd() error { 244 return p.protocol.ReadFieldEnd() 245} 246 247func (p *THeaderProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) { 248 return p.protocol.ReadMapBegin() 249} 250 251func (p *THeaderProtocol) ReadMapEnd() error { 252 return p.protocol.ReadMapEnd() 253} 254 255func (p *THeaderProtocol) ReadListBegin() (elemType TType, size int, err error) { 256 return p.protocol.ReadListBegin() 257} 258 259func (p *THeaderProtocol) ReadListEnd() error { 260 return p.protocol.ReadListEnd() 261} 262 263func (p *THeaderProtocol) ReadSetBegin() (elemType TType, size int, err error) { 264 return p.protocol.ReadSetBegin() 265} 266 267func (p *THeaderProtocol) ReadSetEnd() error { 268 return p.protocol.ReadSetEnd() 269} 270 271func (p *THeaderProtocol) ReadBool() (value bool, err error) { 272 return p.protocol.ReadBool() 273} 274 275func (p *THeaderProtocol) ReadByte() (value int8, err error) { 276 return p.protocol.ReadByte() 277} 278 279func (p *THeaderProtocol) ReadI16() (value int16, err error) { 280 return p.protocol.ReadI16() 281} 282 283func (p *THeaderProtocol) ReadI32() (value int32, err error) { 284 return p.protocol.ReadI32() 285} 286 287func (p *THeaderProtocol) ReadI64() (value int64, err error) { 288 return p.protocol.ReadI64() 289} 290 291func (p *THeaderProtocol) ReadDouble() (value float64, err error) { 292 return p.protocol.ReadDouble() 293} 294 295func (p *THeaderProtocol) ReadString() (value string, err error) { 296 return p.protocol.ReadString() 297} 298 299func (p *THeaderProtocol) ReadBinary() (value []byte, err error) { 300 return p.protocol.ReadBinary() 301} 302 303func (p *THeaderProtocol) Skip(fieldType TType) error { 304 return p.protocol.Skip(fieldType) 305} 306