1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20package thrift
21
22import (
23	"bytes"
24	"encoding/binary"
25	"errors"
26	"fmt"
27	"io"
28	"math"
29)
30
31type TBinaryProtocol struct {
32	trans         TRichTransport
33	origTransport TTransport
34	reader        io.Reader
35	writer        io.Writer
36	strictRead    bool
37	strictWrite   bool
38	buffer        [64]byte
39}
40
// TBinaryProtocolFactory creates TBinaryProtocol instances that all share one
// strictRead/strictWrite configuration.
type TBinaryProtocolFactory struct {
	strictRead, strictWrite bool
}
45
46func NewTBinaryProtocolTransport(t TTransport) *TBinaryProtocol {
47	return NewTBinaryProtocol(t, false, true)
48}
49
50func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProtocol {
51	p := &TBinaryProtocol{origTransport: t, strictRead: strictRead, strictWrite: strictWrite}
52	if et, ok := t.(TRichTransport); ok {
53		p.trans = et
54	} else {
55		p.trans = NewTRichTransport(t)
56	}
57	p.reader = p.trans
58	p.writer = p.trans
59	return p
60}
61
62func NewTBinaryProtocolFactoryDefault() *TBinaryProtocolFactory {
63	return NewTBinaryProtocolFactory(false, true)
64}
65
66func NewTBinaryProtocolFactory(strictRead, strictWrite bool) *TBinaryProtocolFactory {
67	return &TBinaryProtocolFactory{strictRead: strictRead, strictWrite: strictWrite}
68}
69
70func (p *TBinaryProtocolFactory) GetProtocol(t TTransport) TProtocol {
71	return NewTBinaryProtocol(t, p.strictRead, p.strictWrite)
72}
73
74/**
75 * Writing Methods
76 */
77
78func (p *TBinaryProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error {
79	if p.strictWrite {
80		version := uint32(VERSION_1) | uint32(typeId)
81		e := p.WriteI32(int32(version))
82		if e != nil {
83			return e
84		}
85		e = p.WriteString(name)
86		if e != nil {
87			return e
88		}
89		e = p.WriteI32(seqId)
90		return e
91	} else {
92		e := p.WriteString(name)
93		if e != nil {
94			return e
95		}
96		e = p.WriteByte(int8(typeId))
97		if e != nil {
98			return e
99		}
100		e = p.WriteI32(seqId)
101		return e
102	}
103	return nil
104}
105
106func (p *TBinaryProtocol) WriteMessageEnd() error {
107	return nil
108}
109
110func (p *TBinaryProtocol) WriteStructBegin(name string) error {
111	return nil
112}
113
114func (p *TBinaryProtocol) WriteStructEnd() error {
115	return nil
116}
117
118func (p *TBinaryProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
119	e := p.WriteByte(int8(typeId))
120	if e != nil {
121		return e
122	}
123	e = p.WriteI16(id)
124	return e
125}
126
127func (p *TBinaryProtocol) WriteFieldEnd() error {
128	return nil
129}
130
131func (p *TBinaryProtocol) WriteFieldStop() error {
132	e := p.WriteByte(STOP)
133	return e
134}
135
136func (p *TBinaryProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
137	e := p.WriteByte(int8(keyType))
138	if e != nil {
139		return e
140	}
141	e = p.WriteByte(int8(valueType))
142	if e != nil {
143		return e
144	}
145	e = p.WriteI32(int32(size))
146	return e
147}
148
149func (p *TBinaryProtocol) WriteMapEnd() error {
150	return nil
151}
152
153func (p *TBinaryProtocol) WriteListBegin(elemType TType, size int) error {
154	e := p.WriteByte(int8(elemType))
155	if e != nil {
156		return e
157	}
158	e = p.WriteI32(int32(size))
159	return e
160}
161
162func (p *TBinaryProtocol) WriteListEnd() error {
163	return nil
164}
165
166func (p *TBinaryProtocol) WriteSetBegin(elemType TType, size int) error {
167	e := p.WriteByte(int8(elemType))
168	if e != nil {
169		return e
170	}
171	e = p.WriteI32(int32(size))
172	return e
173}
174
175func (p *TBinaryProtocol) WriteSetEnd() error {
176	return nil
177}
178
179func (p *TBinaryProtocol) WriteBool(value bool) error {
180	if value {
181		return p.WriteByte(1)
182	}
183	return p.WriteByte(0)
184}
185
186func (p *TBinaryProtocol) WriteByte(value int8) error {
187	e := p.trans.WriteByte(byte(value))
188	return NewTProtocolException(e)
189}
190
191func (p *TBinaryProtocol) WriteI16(value int16) error {
192	v := p.buffer[0:2]
193	binary.BigEndian.PutUint16(v, uint16(value))
194	_, e := p.writer.Write(v)
195	return NewTProtocolException(e)
196}
197
198func (p *TBinaryProtocol) WriteI32(value int32) error {
199	v := p.buffer[0:4]
200	binary.BigEndian.PutUint32(v, uint32(value))
201	_, e := p.writer.Write(v)
202	return NewTProtocolException(e)
203}
204
205func (p *TBinaryProtocol) WriteI64(value int64) error {
206	v := p.buffer[0:8]
207	binary.BigEndian.PutUint64(v, uint64(value))
208	_, err := p.writer.Write(v)
209	return NewTProtocolException(err)
210}
211
212func (p *TBinaryProtocol) WriteDouble(value float64) error {
213	return p.WriteI64(int64(math.Float64bits(value)))
214}
215
216func (p *TBinaryProtocol) WriteString(value string) error {
217	e := p.WriteI32(int32(len(value)))
218	if e != nil {
219		return e
220	}
221	_, err := p.trans.WriteString(value)
222	return NewTProtocolException(err)
223}
224
225func (p *TBinaryProtocol) WriteBinary(value []byte) error {
226	e := p.WriteI32(int32(len(value)))
227	if e != nil {
228		return e
229	}
230	_, err := p.writer.Write(value)
231	return NewTProtocolException(err)
232}
233
234/**
235 * Reading methods
236 */
237
238func (p *TBinaryProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
239	size, e := p.ReadI32()
240	if e != nil {
241		return "", typeId, 0, NewTProtocolException(e)
242	}
243	if size < 0 {
244		typeId = TMessageType(size & 0x0ff)
245		version := int64(int64(size) & VERSION_MASK)
246		if version != VERSION_1 {
247			return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Bad version in ReadMessageBegin"))
248		}
249		name, e = p.ReadString()
250		if e != nil {
251			return name, typeId, seqId, NewTProtocolException(e)
252		}
253		seqId, e = p.ReadI32()
254		if e != nil {
255			return name, typeId, seqId, NewTProtocolException(e)
256		}
257		return name, typeId, seqId, nil
258	}
259	if p.strictRead {
260		return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Missing version in ReadMessageBegin"))
261	}
262	name, e2 := p.readStringBody(size)
263	if e2 != nil {
264		return name, typeId, seqId, e2
265	}
266	b, e3 := p.ReadByte()
267	if e3 != nil {
268		return name, typeId, seqId, e3
269	}
270	typeId = TMessageType(b)
271	seqId, e4 := p.ReadI32()
272	if e4 != nil {
273		return name, typeId, seqId, e4
274	}
275	return name, typeId, seqId, nil
276}
277
278func (p *TBinaryProtocol) ReadMessageEnd() error {
279	return nil
280}
281
282func (p *TBinaryProtocol) ReadStructBegin() (name string, err error) {
283	return
284}
285
286func (p *TBinaryProtocol) ReadStructEnd() error {
287	return nil
288}
289
290func (p *TBinaryProtocol) ReadFieldBegin() (name string, typeId TType, seqId int16, err error) {
291	t, err := p.ReadByte()
292	typeId = TType(t)
293	if err != nil {
294		return name, typeId, seqId, err
295	}
296	if t != STOP {
297		seqId, err = p.ReadI16()
298	}
299	return name, typeId, seqId, err
300}
301
302func (p *TBinaryProtocol) ReadFieldEnd() error {
303	return nil
304}
305
// invalidDataLength is the shared INVALID_DATA protocol error for an unusable
// length or size value (for example, a negative decoded length). It is
// created once at package initialization and returned as-is.
var invalidDataLength = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Invalid data length"))
307
308func (p *TBinaryProtocol) ReadMapBegin() (kType, vType TType, size int, err error) {
309	k, e := p.ReadByte()
310	if e != nil {
311		err = NewTProtocolException(e)
312		return
313	}
314	kType = TType(k)
315	v, e := p.ReadByte()
316	if e != nil {
317		err = NewTProtocolException(e)
318		return
319	}
320	vType = TType(v)
321	size32, e := p.ReadI32()
322	if e != nil {
323		err = NewTProtocolException(e)
324		return
325	}
326	if size32 < 0 {
327		err = invalidDataLength
328		return
329	}
330	size = int(size32)
331	return kType, vType, size, nil
332}
333
334func (p *TBinaryProtocol) ReadMapEnd() error {
335	return nil
336}
337
338func (p *TBinaryProtocol) ReadListBegin() (elemType TType, size int, err error) {
339	b, e := p.ReadByte()
340	if e != nil {
341		err = NewTProtocolException(e)
342		return
343	}
344	elemType = TType(b)
345	size32, e := p.ReadI32()
346	if e != nil {
347		err = NewTProtocolException(e)
348		return
349	}
350	if size32 < 0 {
351		err = invalidDataLength
352		return
353	}
354	size = int(size32)
355
356	return
357}
358
359func (p *TBinaryProtocol) ReadListEnd() error {
360	return nil
361}
362
363func (p *TBinaryProtocol) ReadSetBegin() (elemType TType, size int, err error) {
364	b, e := p.ReadByte()
365	if e != nil {
366		err = NewTProtocolException(e)
367		return
368	}
369	elemType = TType(b)
370	size32, e := p.ReadI32()
371	if e != nil {
372		err = NewTProtocolException(e)
373		return
374	}
375	if size32 < 0 {
376		err = invalidDataLength
377		return
378	}
379	size = int(size32)
380	return elemType, size, nil
381}
382
383func (p *TBinaryProtocol) ReadSetEnd() error {
384	return nil
385}
386
387func (p *TBinaryProtocol) ReadBool() (bool, error) {
388	b, e := p.ReadByte()
389	v := true
390	if b != 1 {
391		v = false
392	}
393	return v, e
394}
395
396func (p *TBinaryProtocol) ReadByte() (int8, error) {
397	v, err := p.trans.ReadByte()
398	return int8(v), err
399}
400
401func (p *TBinaryProtocol) ReadI16() (value int16, err error) {
402	buf := p.buffer[0:2]
403	err = p.readAll(buf)
404	value = int16(binary.BigEndian.Uint16(buf))
405	return value, err
406}
407
408func (p *TBinaryProtocol) ReadI32() (value int32, err error) {
409	buf := p.buffer[0:4]
410	err = p.readAll(buf)
411	value = int32(binary.BigEndian.Uint32(buf))
412	return value, err
413}
414
415func (p *TBinaryProtocol) ReadI64() (value int64, err error) {
416	buf := p.buffer[0:8]
417	err = p.readAll(buf)
418	value = int64(binary.BigEndian.Uint64(buf))
419	return value, err
420}
421
422func (p *TBinaryProtocol) ReadDouble() (value float64, err error) {
423	buf := p.buffer[0:8]
424	err = p.readAll(buf)
425	value = math.Float64frombits(binary.BigEndian.Uint64(buf))
426	return value, err
427}
428
429func (p *TBinaryProtocol) ReadString() (value string, err error) {
430	size, e := p.ReadI32()
431	if e != nil {
432		return "", e
433	}
434	if size < 0 {
435		err = invalidDataLength
436		return
437	}
438
439	return p.readStringBody(size)
440}
441
442func (p *TBinaryProtocol) ReadBinary() ([]byte, error) {
443	size, e := p.ReadI32()
444	if e != nil {
445		return nil, e
446	}
447	if size < 0 {
448		return nil, invalidDataLength
449	}
450	if uint64(size) > p.trans.RemainingBytes() {
451		return nil, invalidDataLength
452	}
453
454	isize := int(size)
455	buf := make([]byte, isize)
456	_, err := io.ReadFull(p.trans, buf)
457	return buf, NewTProtocolException(err)
458}
459
460func (p *TBinaryProtocol) Flush() (err error) {
461	return NewTProtocolException(p.trans.Flush())
462}
463
464func (p *TBinaryProtocol) Skip(fieldType TType) (err error) {
465	return SkipDefaultDepth(p, fieldType)
466}
467
468func (p *TBinaryProtocol) Transport() TTransport {
469	return p.origTransport
470}
471
472func (p *TBinaryProtocol) readAll(buf []byte) error {
473	_, err := io.ReadFull(p.reader, buf)
474	return NewTProtocolException(err)
475}
476
// readLimit caps, in bytes (32 KiB), a single read buffer allocated while
// reading a length-prefixed body; longer bodies are read in pieces of this size.
const readLimit = 32 * 1024
478
479func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) {
480	if size < 0 {
481		return "", nil
482	}
483	if uint64(size) > p.trans.RemainingBytes() {
484		return "", invalidDataLength
485	}
486
487	var (
488		buf bytes.Buffer
489		e   error
490		b   []byte
491	)
492
493	switch {
494	case int(size) <= len(p.buffer):
495		b = p.buffer[:size] // avoids allocation for small reads
496	case int(size) < readLimit:
497		b = make([]byte, size)
498	default:
499		b = make([]byte, readLimit)
500	}
501
502	for size > 0 {
503		_, e = io.ReadFull(p.trans, b)
504		buf.Write(b)
505		if e != nil {
506			break
507		}
508		size -= readLimit
509		if size < readLimit && size > 0 {
510			b = b[:size]
511		}
512	}
513	return buf.String(), NewTProtocolException(e)
514}
515