1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20package thrift
21
22import (
23	"bytes"
24	"context"
25	"encoding/binary"
26	"errors"
27	"fmt"
28	"io"
29	"math"
30)
31
// TBinaryProtocol encodes and decodes Thrift values using the binary
// protocol on top of a transport.
type TBinaryProtocol struct {
	// trans is what every read and write goes through: origTransport itself
	// when it already implements TRichTransport, otherwise a
	// NewTRichTransport wrapper around it (see NewTBinaryProtocolConf).
	trans         TRichTransport
	// origTransport is the transport passed to the constructor; Transport()
	// returns it unchanged.
	origTransport TTransport
	// cfg carries protocol options such as the strict read/write flags and
	// whatever limits checkSizeForProtocol consults.
	cfg           *TConfiguration
	// buffer is scratch space reused by the fixed-width integer readers and
	// writers and by small string reads to avoid per-call allocations. It is
	// shared across calls without any synchronization, so one
	// TBinaryProtocol must not be used by several goroutines at once.
	buffer        [64]byte
}
38
// TBinaryProtocolFactory builds TBinaryProtocol instances that all share the
// same *TConfiguration.
type TBinaryProtocolFactory struct {
	// cfg is passed (by pointer) to every protocol created by GetProtocol.
	cfg *TConfiguration
}
42
43// Deprecated: Use NewTBinaryProtocolConf instead.
44func NewTBinaryProtocolTransport(t TTransport) *TBinaryProtocol {
45	return NewTBinaryProtocolConf(t, &TConfiguration{
46		noPropagation: true,
47	})
48}
49
50// Deprecated: Use NewTBinaryProtocolConf instead.
51func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProtocol {
52	return NewTBinaryProtocolConf(t, &TConfiguration{
53		TBinaryStrictRead:  &strictRead,
54		TBinaryStrictWrite: &strictWrite,
55
56		noPropagation: true,
57	})
58}
59
60func NewTBinaryProtocolConf(t TTransport, conf *TConfiguration) *TBinaryProtocol {
61	PropagateTConfiguration(t, conf)
62	p := &TBinaryProtocol{
63		origTransport: t,
64		cfg:           conf,
65	}
66	if et, ok := t.(TRichTransport); ok {
67		p.trans = et
68	} else {
69		p.trans = NewTRichTransport(t)
70	}
71	return p
72}
73
74// Deprecated: Use NewTBinaryProtocolFactoryConf instead.
75func NewTBinaryProtocolFactoryDefault() *TBinaryProtocolFactory {
76	return NewTBinaryProtocolFactoryConf(&TConfiguration{
77		noPropagation: true,
78	})
79}
80
81// Deprecated: Use NewTBinaryProtocolFactoryConf instead.
82func NewTBinaryProtocolFactory(strictRead, strictWrite bool) *TBinaryProtocolFactory {
83	return NewTBinaryProtocolFactoryConf(&TConfiguration{
84		TBinaryStrictRead:  &strictRead,
85		TBinaryStrictWrite: &strictWrite,
86
87		noPropagation: true,
88	})
89}
90
91func NewTBinaryProtocolFactoryConf(conf *TConfiguration) *TBinaryProtocolFactory {
92	return &TBinaryProtocolFactory{
93		cfg: conf,
94	}
95}
96
97func (p *TBinaryProtocolFactory) GetProtocol(t TTransport) TProtocol {
98	return NewTBinaryProtocolConf(t, p.cfg)
99}
100
// SetTConfiguration replaces the configuration used for protocols created
// by later GetProtocol calls. Protocols already created keep whatever
// configuration pointer they were built with.
func (p *TBinaryProtocolFactory) SetTConfiguration(conf *TConfiguration) {
	p.cfg = conf
}
104
105/**
106 * Writing Methods
107 */
108
109func (p *TBinaryProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqId int32) error {
110	if p.cfg.GetTBinaryStrictWrite() {
111		version := uint32(VERSION_1) | uint32(typeId)
112		e := p.WriteI32(ctx, int32(version))
113		if e != nil {
114			return e
115		}
116		e = p.WriteString(ctx, name)
117		if e != nil {
118			return e
119		}
120		e = p.WriteI32(ctx, seqId)
121		return e
122	} else {
123		e := p.WriteString(ctx, name)
124		if e != nil {
125			return e
126		}
127		e = p.WriteByte(ctx, int8(typeId))
128		if e != nil {
129			return e
130		}
131		e = p.WriteI32(ctx, seqId)
132		return e
133	}
134	return nil
135}
136
// WriteMessageEnd writes nothing; the binary protocol emits no message
// trailer. It always returns nil.
func (p *TBinaryProtocol) WriteMessageEnd(ctx context.Context) error {
	return nil
}

// WriteStructBegin writes nothing; struct names are not encoded by the
// binary protocol. It always returns nil.
func (p *TBinaryProtocol) WriteStructBegin(ctx context.Context, name string) error {
	return nil
}

// WriteStructEnd writes nothing and always returns nil. The end of a
// struct's fields is marked separately by WriteFieldStop.
func (p *TBinaryProtocol) WriteStructEnd(ctx context.Context) error {
	return nil
}
148
149func (p *TBinaryProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error {
150	e := p.WriteByte(ctx, int8(typeId))
151	if e != nil {
152		return e
153	}
154	e = p.WriteI16(ctx, id)
155	return e
156}
157
// WriteFieldEnd writes nothing; fields have no trailer in the binary
// protocol. It always returns nil.
func (p *TBinaryProtocol) WriteFieldEnd(ctx context.Context) error {
	return nil
}
161
162func (p *TBinaryProtocol) WriteFieldStop(ctx context.Context) error {
163	e := p.WriteByte(ctx, STOP)
164	return e
165}
166
167func (p *TBinaryProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error {
168	e := p.WriteByte(ctx, int8(keyType))
169	if e != nil {
170		return e
171	}
172	e = p.WriteByte(ctx, int8(valueType))
173	if e != nil {
174		return e
175	}
176	e = p.WriteI32(ctx, int32(size))
177	return e
178}
179
// WriteMapEnd writes nothing; maps are length-prefixed, so no terminator is
// needed. It always returns nil.
func (p *TBinaryProtocol) WriteMapEnd(ctx context.Context) error {
	return nil
}
183
184func (p *TBinaryProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error {
185	e := p.WriteByte(ctx, int8(elemType))
186	if e != nil {
187		return e
188	}
189	e = p.WriteI32(ctx, int32(size))
190	return e
191}
192
// WriteListEnd writes nothing; lists are length-prefixed, so no terminator
// is needed. It always returns nil.
func (p *TBinaryProtocol) WriteListEnd(ctx context.Context) error {
	return nil
}
196
197func (p *TBinaryProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error {
198	e := p.WriteByte(ctx, int8(elemType))
199	if e != nil {
200		return e
201	}
202	e = p.WriteI32(ctx, int32(size))
203	return e
204}
205
// WriteSetEnd writes nothing; sets are length-prefixed, so no terminator is
// needed. It always returns nil.
func (p *TBinaryProtocol) WriteSetEnd(ctx context.Context) error {
	return nil
}
209
210func (p *TBinaryProtocol) WriteBool(ctx context.Context, value bool) error {
211	if value {
212		return p.WriteByte(ctx, 1)
213	}
214	return p.WriteByte(ctx, 0)
215}
216
217func (p *TBinaryProtocol) WriteByte(ctx context.Context, value int8) error {
218	e := p.trans.WriteByte(byte(value))
219	return NewTProtocolException(e)
220}
221
222func (p *TBinaryProtocol) WriteI16(ctx context.Context, value int16) error {
223	v := p.buffer[0:2]
224	binary.BigEndian.PutUint16(v, uint16(value))
225	_, e := p.trans.Write(v)
226	return NewTProtocolException(e)
227}
228
229func (p *TBinaryProtocol) WriteI32(ctx context.Context, value int32) error {
230	v := p.buffer[0:4]
231	binary.BigEndian.PutUint32(v, uint32(value))
232	_, e := p.trans.Write(v)
233	return NewTProtocolException(e)
234}
235
236func (p *TBinaryProtocol) WriteI64(ctx context.Context, value int64) error {
237	v := p.buffer[0:8]
238	binary.BigEndian.PutUint64(v, uint64(value))
239	_, err := p.trans.Write(v)
240	return NewTProtocolException(err)
241}
242
243func (p *TBinaryProtocol) WriteDouble(ctx context.Context, value float64) error {
244	return p.WriteI64(ctx, int64(math.Float64bits(value)))
245}
246
247func (p *TBinaryProtocol) WriteString(ctx context.Context, value string) error {
248	e := p.WriteI32(ctx, int32(len(value)))
249	if e != nil {
250		return e
251	}
252	_, err := p.trans.WriteString(value)
253	return NewTProtocolException(err)
254}
255
256func (p *TBinaryProtocol) WriteBinary(ctx context.Context, value []byte) error {
257	e := p.WriteI32(ctx, int32(len(value)))
258	if e != nil {
259		return e
260	}
261	_, err := p.trans.Write(value)
262	return NewTProtocolException(err)
263}
264
265/**
266 * Reading methods
267 */
268
269func (p *TBinaryProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqId int32, err error) {
270	size, e := p.ReadI32(ctx)
271	if e != nil {
272		return "", typeId, 0, NewTProtocolException(e)
273	}
274	if size < 0 {
275		typeId = TMessageType(size & 0x0ff)
276		version := int64(int64(size) & VERSION_MASK)
277		if version != VERSION_1 {
278			return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Bad version in ReadMessageBegin"))
279		}
280		name, e = p.ReadString(ctx)
281		if e != nil {
282			return name, typeId, seqId, NewTProtocolException(e)
283		}
284		seqId, e = p.ReadI32(ctx)
285		if e != nil {
286			return name, typeId, seqId, NewTProtocolException(e)
287		}
288		return name, typeId, seqId, nil
289	}
290	if p.cfg.GetTBinaryStrictRead() {
291		return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Missing version in ReadMessageBegin"))
292	}
293	name, e2 := p.readStringBody(size)
294	if e2 != nil {
295		return name, typeId, seqId, e2
296	}
297	b, e3 := p.ReadByte(ctx)
298	if e3 != nil {
299		return name, typeId, seqId, e3
300	}
301	typeId = TMessageType(b)
302	seqId, e4 := p.ReadI32(ctx)
303	if e4 != nil {
304		return name, typeId, seqId, e4
305	}
306	return name, typeId, seqId, nil
307}
308
309func (p *TBinaryProtocol) ReadMessageEnd(ctx context.Context) error {
310	return nil
311}
312
313func (p *TBinaryProtocol) ReadStructBegin(ctx context.Context) (name string, err error) {
314	return
315}
316
317func (p *TBinaryProtocol) ReadStructEnd(ctx context.Context) error {
318	return nil
319}
320
321func (p *TBinaryProtocol) ReadFieldBegin(ctx context.Context) (name string, typeId TType, seqId int16, err error) {
322	t, err := p.ReadByte(ctx)
323	typeId = TType(t)
324	if err != nil {
325		return name, typeId, seqId, err
326	}
327	if t != STOP {
328		seqId, err = p.ReadI16(ctx)
329	}
330	return name, typeId, seqId, err
331}
332
// ReadFieldEnd consumes nothing from the transport and always returns nil.
func (p *TBinaryProtocol) ReadFieldEnd(ctx context.Context) error {
	return nil
}
336
// invalidDataLength is the shared INVALID_DATA error returned when a length
// or element count decoded from the wire is negative.
var invalidDataLength = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Invalid data length"))
338
339func (p *TBinaryProtocol) ReadMapBegin(ctx context.Context) (kType, vType TType, size int, err error) {
340	k, e := p.ReadByte(ctx)
341	if e != nil {
342		err = NewTProtocolException(e)
343		return
344	}
345	kType = TType(k)
346	v, e := p.ReadByte(ctx)
347	if e != nil {
348		err = NewTProtocolException(e)
349		return
350	}
351	vType = TType(v)
352	size32, e := p.ReadI32(ctx)
353	if e != nil {
354		err = NewTProtocolException(e)
355		return
356	}
357	if size32 < 0 {
358		err = invalidDataLength
359		return
360	}
361	size = int(size32)
362	return kType, vType, size, nil
363}
364
// ReadMapEnd consumes nothing from the transport and always returns nil.
func (p *TBinaryProtocol) ReadMapEnd(ctx context.Context) error {
	return nil
}
368
369func (p *TBinaryProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, err error) {
370	b, e := p.ReadByte(ctx)
371	if e != nil {
372		err = NewTProtocolException(e)
373		return
374	}
375	elemType = TType(b)
376	size32, e := p.ReadI32(ctx)
377	if e != nil {
378		err = NewTProtocolException(e)
379		return
380	}
381	if size32 < 0 {
382		err = invalidDataLength
383		return
384	}
385	size = int(size32)
386
387	return
388}
389
// ReadListEnd consumes nothing from the transport and always returns nil.
func (p *TBinaryProtocol) ReadListEnd(ctx context.Context) error {
	return nil
}
393
394func (p *TBinaryProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) {
395	b, e := p.ReadByte(ctx)
396	if e != nil {
397		err = NewTProtocolException(e)
398		return
399	}
400	elemType = TType(b)
401	size32, e := p.ReadI32(ctx)
402	if e != nil {
403		err = NewTProtocolException(e)
404		return
405	}
406	if size32 < 0 {
407		err = invalidDataLength
408		return
409	}
410	size = int(size32)
411	return elemType, size, nil
412}
413
// ReadSetEnd consumes nothing from the transport and always returns nil.
func (p *TBinaryProtocol) ReadSetEnd(ctx context.Context) error {
	return nil
}
417
418func (p *TBinaryProtocol) ReadBool(ctx context.Context) (bool, error) {
419	b, e := p.ReadByte(ctx)
420	v := true
421	if b != 1 {
422		v = false
423	}
424	return v, e
425}
426
427func (p *TBinaryProtocol) ReadByte(ctx context.Context) (int8, error) {
428	v, err := p.trans.ReadByte()
429	return int8(v), err
430}
431
432func (p *TBinaryProtocol) ReadI16(ctx context.Context) (value int16, err error) {
433	buf := p.buffer[0:2]
434	err = p.readAll(ctx, buf)
435	value = int16(binary.BigEndian.Uint16(buf))
436	return value, err
437}
438
439func (p *TBinaryProtocol) ReadI32(ctx context.Context) (value int32, err error) {
440	buf := p.buffer[0:4]
441	err = p.readAll(ctx, buf)
442	value = int32(binary.BigEndian.Uint32(buf))
443	return value, err
444}
445
446func (p *TBinaryProtocol) ReadI64(ctx context.Context) (value int64, err error) {
447	buf := p.buffer[0:8]
448	err = p.readAll(ctx, buf)
449	value = int64(binary.BigEndian.Uint64(buf))
450	return value, err
451}
452
453func (p *TBinaryProtocol) ReadDouble(ctx context.Context) (value float64, err error) {
454	buf := p.buffer[0:8]
455	err = p.readAll(ctx, buf)
456	value = math.Float64frombits(binary.BigEndian.Uint64(buf))
457	return value, err
458}
459
460func (p *TBinaryProtocol) ReadString(ctx context.Context) (value string, err error) {
461	size, e := p.ReadI32(ctx)
462	if e != nil {
463		return "", e
464	}
465	err = checkSizeForProtocol(size, p.cfg)
466	if err != nil {
467		return
468	}
469	if size < 0 {
470		err = invalidDataLength
471		return
472	}
473	if size == 0 {
474		return "", nil
475	}
476	if size < int32(len(p.buffer)) {
477		// Avoid allocation on small reads
478		buf := p.buffer[:size]
479		read, e := io.ReadFull(p.trans, buf)
480		return string(buf[:read]), NewTProtocolException(e)
481	}
482
483	return p.readStringBody(size)
484}
485
486func (p *TBinaryProtocol) ReadBinary(ctx context.Context) ([]byte, error) {
487	size, e := p.ReadI32(ctx)
488	if e != nil {
489		return nil, e
490	}
491	if err := checkSizeForProtocol(size, p.cfg); err != nil {
492		return nil, err
493	}
494
495	buf, err := safeReadBytes(size, p.trans)
496	return buf, NewTProtocolException(err)
497}
498
499func (p *TBinaryProtocol) Flush(ctx context.Context) (err error) {
500	return NewTProtocolException(p.trans.Flush(ctx))
501}
502
503func (p *TBinaryProtocol) Skip(ctx context.Context, fieldType TType) (err error) {
504	return SkipDefaultDepth(ctx, p, fieldType)
505}
506
// Transport returns the transport that was passed to the constructor, not
// the TRichTransport wrapper that may be used internally for I/O.
func (p *TBinaryProtocol) Transport() TTransport {
	return p.origTransport
}
510
511func (p *TBinaryProtocol) readAll(ctx context.Context, buf []byte) (err error) {
512	var read int
513	_, deadlineSet := ctx.Deadline()
514	for {
515		read, err = io.ReadFull(p.trans, buf)
516		if deadlineSet && read == 0 && isTimeoutError(err) && ctx.Err() == nil {
517			// This is I/O timeout without anything read,
518			// and we still have time left, keep retrying.
519			continue
520		}
521		// For anything else, don't retry
522		break
523	}
524	return NewTProtocolException(err)
525}
526
527func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) {
528	buf, err := safeReadBytes(size, p.trans)
529	return string(buf), NewTProtocolException(err)
530}
531
// SetTConfiguration installs conf on this protocol. It first offers conf to
// both the internal rich transport and the original transport (they may be
// the same object), then stores it for later reads and writes.
func (p *TBinaryProtocol) SetTConfiguration(conf *TConfiguration) {
	PropagateTConfiguration(p.trans, conf)
	PropagateTConfiguration(p.origTransport, conf)
	p.cfg = conf
}
537
// Compile-time assertions that both the factory and the protocol implement
// TConfigurationSetter.
var (
	_ TConfigurationSetter = (*TBinaryProtocolFactory)(nil)
	_ TConfigurationSetter = (*TBinaryProtocol)(nil)
)
542
// safeReadBytes reads exactly size bytes from trans. It is shared between
// TBinaryProtocol and TCompactProtocol.
//
// Nothing is pre-allocated from size. The bytes.Buffer grows only as data
// actually arrives, so an absurd length (typically from a malformed message)
// cannot trigger a huge up-front allocation.
//
// A negative size returns (nil, nil). If trans ends early, the bytes read so
// far are returned together with io.EOF, as reported by io.CopyN.
func safeReadBytes(size int32, trans io.Reader) ([]byte, error) {
	if size < 0 {
		return nil, nil
	}

	var collected bytes.Buffer
	_, copyErr := io.CopyN(&collected, trans, int64(size))
	return collected.Bytes(), copyErr
}
556