1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20package thrift
21
22import (
23	"context"
24)
25
// THeaderProtocol is a thrift protocol that implements THeader:
// https://github.com/apache/thrift/blob/master/doc/specs/HeaderFormat.md
//
// It supports either binary or compact protocol as the wrapped protocol.
//
// Most of the THeader handling happens inside THeaderTransport; this type
// picks the wrapped protocol for each message and forwards every TProtocol
// call to it.
type THeaderProtocol struct {
	// transport is the THeaderTransport all data flows through. It also
	// carries the read/write headers, the write transforms and the sequence ID.
	transport *THeaderTransport

	// protocol is the wrapped protocol that read/write calls are forwarded to.
	// NewTHeaderProtocol sets an initial value; WriteMessageBegin replaces it,
	// and ReadMessageBegin replaces it when the transport's protocol can be
	// resolved.
	protocol TProtocol
}
38
39// NewTHeaderProtocol creates a new THeaderProtocol from the underlying
40// transport. The passed in transport will be wrapped with THeaderTransport.
41//
42// Note that THeaderTransport handles frame and zlib by itself,
43// so the underlying transport should be a raw socket transports (TSocket or TSSLSocket),
44// instead of rich transports like TZlibTransport or TFramedTransport.
45func NewTHeaderProtocol(trans TTransport) *THeaderProtocol {
46	t := NewTHeaderTransport(trans)
47	p, _ := THeaderProtocolDefault.GetProtocol(t)
48	return &THeaderProtocol{
49		transport: t,
50		protocol:  p,
51	}
52}
53
54type tHeaderProtocolFactory struct{}
55
56func (tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol {
57	return NewTHeaderProtocol(trans)
58}
59
60// NewTHeaderProtocolFactory creates a factory for THeader.
61//
62// It's a wrapper for NewTHeaderProtocol
63func NewTHeaderProtocolFactory() TProtocolFactory {
64	return tHeaderProtocolFactory{}
65}
66
67// Transport returns the underlying transport.
68//
69// It's guaranteed to be of type *THeaderTransport.
70func (p *THeaderProtocol) Transport() TTransport {
71	return p.transport
72}
73
74// GetReadHeaders returns the THeaderMap read from transport.
75func (p *THeaderProtocol) GetReadHeaders() THeaderMap {
76	return p.transport.GetReadHeaders()
77}
78
79// SetWriteHeader sets a header for write.
80func (p *THeaderProtocol) SetWriteHeader(key, value string) {
81	p.transport.SetWriteHeader(key, value)
82}
83
84// ClearWriteHeaders clears all write headers previously set.
85func (p *THeaderProtocol) ClearWriteHeaders() {
86	p.transport.ClearWriteHeaders()
87}
88
89// AddTransform add a transform for writing.
90func (p *THeaderProtocol) AddTransform(transform THeaderTransformID) error {
91	return p.transport.AddTransform(transform)
92}
93
94func (p *THeaderProtocol) Flush(ctx context.Context) error {
95	return p.transport.Flush(ctx)
96}
97
98func (p *THeaderProtocol) WriteMessageBegin(name string, typeID TMessageType, seqID int32) error {
99	newProto, err := p.transport.Protocol().GetProtocol(p.transport)
100	if err != nil {
101		return err
102	}
103	p.protocol = newProto
104	p.transport.SequenceID = seqID
105	return p.protocol.WriteMessageBegin(name, typeID, seqID)
106}
107
108func (p *THeaderProtocol) WriteMessageEnd() error {
109	if err := p.protocol.WriteMessageEnd(); err != nil {
110		return err
111	}
112	return p.transport.Flush(context.Background())
113}
114
115func (p *THeaderProtocol) WriteStructBegin(name string) error {
116	return p.protocol.WriteStructBegin(name)
117}
118
119func (p *THeaderProtocol) WriteStructEnd() error {
120	return p.protocol.WriteStructEnd()
121}
122
123func (p *THeaderProtocol) WriteFieldBegin(name string, typeID TType, id int16) error {
124	return p.protocol.WriteFieldBegin(name, typeID, id)
125}
126
127func (p *THeaderProtocol) WriteFieldEnd() error {
128	return p.protocol.WriteFieldEnd()
129}
130
131func (p *THeaderProtocol) WriteFieldStop() error {
132	return p.protocol.WriteFieldStop()
133}
134
135func (p *THeaderProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
136	return p.protocol.WriteMapBegin(keyType, valueType, size)
137}
138
139func (p *THeaderProtocol) WriteMapEnd() error {
140	return p.protocol.WriteMapEnd()
141}
142
143func (p *THeaderProtocol) WriteListBegin(elemType TType, size int) error {
144	return p.protocol.WriteListBegin(elemType, size)
145}
146
147func (p *THeaderProtocol) WriteListEnd() error {
148	return p.protocol.WriteListEnd()
149}
150
151func (p *THeaderProtocol) WriteSetBegin(elemType TType, size int) error {
152	return p.protocol.WriteSetBegin(elemType, size)
153}
154
155func (p *THeaderProtocol) WriteSetEnd() error {
156	return p.protocol.WriteSetEnd()
157}
158
159func (p *THeaderProtocol) WriteBool(value bool) error {
160	return p.protocol.WriteBool(value)
161}
162
163func (p *THeaderProtocol) WriteByte(value int8) error {
164	return p.protocol.WriteByte(value)
165}
166
167func (p *THeaderProtocol) WriteI16(value int16) error {
168	return p.protocol.WriteI16(value)
169}
170
171func (p *THeaderProtocol) WriteI32(value int32) error {
172	return p.protocol.WriteI32(value)
173}
174
175func (p *THeaderProtocol) WriteI64(value int64) error {
176	return p.protocol.WriteI64(value)
177}
178
179func (p *THeaderProtocol) WriteDouble(value float64) error {
180	return p.protocol.WriteDouble(value)
181}
182
183func (p *THeaderProtocol) WriteString(value string) error {
184	return p.protocol.WriteString(value)
185}
186
187func (p *THeaderProtocol) WriteBinary(value []byte) error {
188	return p.protocol.WriteBinary(value)
189}
190
191// ReadFrame calls underlying THeaderTransport's ReadFrame function.
192func (p *THeaderProtocol) ReadFrame() error {
193	return p.transport.ReadFrame()
194}
195
196func (p *THeaderProtocol) ReadMessageBegin() (name string, typeID TMessageType, seqID int32, err error) {
197	if err = p.transport.ReadFrame(); err != nil {
198		return
199	}
200
201	var newProto TProtocol
202	newProto, err = p.transport.Protocol().GetProtocol(p.transport)
203	if err != nil {
204		tAppExc, ok := err.(TApplicationException)
205		if !ok {
206			return
207		}
208		if e := p.protocol.WriteMessageBegin("", EXCEPTION, seqID); e != nil {
209			return
210		}
211		if e := tAppExc.Write(p.protocol); e != nil {
212			return
213		}
214		if e := p.protocol.WriteMessageEnd(); e != nil {
215			return
216		}
217		if e := p.transport.Flush(context.Background()); e != nil {
218			return
219		}
220		return
221	}
222	p.protocol = newProto
223
224	return p.protocol.ReadMessageBegin()
225}
226
227func (p *THeaderProtocol) ReadMessageEnd() error {
228	return p.protocol.ReadMessageEnd()
229}
230
231func (p *THeaderProtocol) ReadStructBegin() (name string, err error) {
232	return p.protocol.ReadStructBegin()
233}
234
235func (p *THeaderProtocol) ReadStructEnd() error {
236	return p.protocol.ReadStructEnd()
237}
238
239func (p *THeaderProtocol) ReadFieldBegin() (name string, typeID TType, id int16, err error) {
240	return p.protocol.ReadFieldBegin()
241}
242
243func (p *THeaderProtocol) ReadFieldEnd() error {
244	return p.protocol.ReadFieldEnd()
245}
246
247func (p *THeaderProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) {
248	return p.protocol.ReadMapBegin()
249}
250
251func (p *THeaderProtocol) ReadMapEnd() error {
252	return p.protocol.ReadMapEnd()
253}
254
255func (p *THeaderProtocol) ReadListBegin() (elemType TType, size int, err error) {
256	return p.protocol.ReadListBegin()
257}
258
259func (p *THeaderProtocol) ReadListEnd() error {
260	return p.protocol.ReadListEnd()
261}
262
263func (p *THeaderProtocol) ReadSetBegin() (elemType TType, size int, err error) {
264	return p.protocol.ReadSetBegin()
265}
266
267func (p *THeaderProtocol) ReadSetEnd() error {
268	return p.protocol.ReadSetEnd()
269}
270
271func (p *THeaderProtocol) ReadBool() (value bool, err error) {
272	return p.protocol.ReadBool()
273}
274
275func (p *THeaderProtocol) ReadByte() (value int8, err error) {
276	return p.protocol.ReadByte()
277}
278
279func (p *THeaderProtocol) ReadI16() (value int16, err error) {
280	return p.protocol.ReadI16()
281}
282
283func (p *THeaderProtocol) ReadI32() (value int32, err error) {
284	return p.protocol.ReadI32()
285}
286
287func (p *THeaderProtocol) ReadI64() (value int64, err error) {
288	return p.protocol.ReadI64()
289}
290
291func (p *THeaderProtocol) ReadDouble() (value float64, err error) {
292	return p.protocol.ReadDouble()
293}
294
295func (p *THeaderProtocol) ReadString() (value string, err error) {
296	return p.protocol.ReadString()
297}
298
299func (p *THeaderProtocol) ReadBinary() (value []byte, err error) {
300	return p.protocol.ReadBinary()
301}
302
303func (p *THeaderProtocol) Skip(fieldType TType) error {
304	return p.protocol.Skip(fieldType)
305}
306