1package thrift
2
3import (
4	"context"
5	"fmt"
6)
7
// ResponseMeta represents the metadata attached to the response.
//
// It is the first value returned by TClient.Call, alongside the call's error.
type ResponseMeta struct {
	// The headers in the response, if any.
	// If the underlying transport/protocol is not THeader, this will always be nil.
	Headers THeaderMap
}
14
15type TClient interface {
16	Call(ctx context.Context, method string, args, result TStruct) (ResponseMeta, error)
17}
18
19type TStandardClient struct {
20	seqId        int32
21	iprot, oprot TProtocol
22}
23
24// TStandardClient implements TClient, and uses the standard message format for Thrift.
25// It is not safe for concurrent use.
26func NewTStandardClient(inputProtocol, outputProtocol TProtocol) *TStandardClient {
27	return &TStandardClient{
28		iprot: inputProtocol,
29		oprot: outputProtocol,
30	}
31}
32
33func (p *TStandardClient) Send(ctx context.Context, oprot TProtocol, seqId int32, method string, args TStruct) error {
34	// Set headers from context object on THeaderProtocol
35	if headerProt, ok := oprot.(*THeaderProtocol); ok {
36		headerProt.ClearWriteHeaders()
37		for _, key := range GetWriteHeaderList(ctx) {
38			if value, ok := GetHeader(ctx, key); ok {
39				headerProt.SetWriteHeader(key, value)
40			}
41		}
42	}
43
44	if err := oprot.WriteMessageBegin(ctx, method, CALL, seqId); err != nil {
45		return err
46	}
47	if err := args.Write(ctx, oprot); err != nil {
48		return err
49	}
50	if err := oprot.WriteMessageEnd(ctx); err != nil {
51		return err
52	}
53	return oprot.Flush(ctx)
54}
55
56func (p *TStandardClient) Recv(ctx context.Context, iprot TProtocol, seqId int32, method string, result TStruct) error {
57	rMethod, rTypeId, rSeqId, err := iprot.ReadMessageBegin(ctx)
58	if err != nil {
59		return err
60	}
61
62	if method != rMethod {
63		return NewTApplicationException(WRONG_METHOD_NAME, fmt.Sprintf("%s: wrong method name", method))
64	} else if seqId != rSeqId {
65		return NewTApplicationException(BAD_SEQUENCE_ID, fmt.Sprintf("%s: out of order sequence response", method))
66	} else if rTypeId == EXCEPTION {
67		var exception tApplicationException
68		if err := exception.Read(ctx, iprot); err != nil {
69			return err
70		}
71
72		if err := iprot.ReadMessageEnd(ctx); err != nil {
73			return err
74		}
75
76		return &exception
77	} else if rTypeId != REPLY {
78		return NewTApplicationException(INVALID_MESSAGE_TYPE_EXCEPTION, fmt.Sprintf("%s: invalid message type", method))
79	}
80
81	if err := result.Read(ctx, iprot); err != nil {
82		return err
83	}
84
85	return iprot.ReadMessageEnd(ctx)
86}
87
88func (p *TStandardClient) Call(ctx context.Context, method string, args, result TStruct) (ResponseMeta, error) {
89	p.seqId++
90	seqId := p.seqId
91
92	if err := p.Send(ctx, p.oprot, seqId, method, args); err != nil {
93		return ResponseMeta{}, err
94	}
95
96	// method is oneway
97	if result == nil {
98		return ResponseMeta{}, nil
99	}
100
101	err := p.Recv(ctx, p.iprot, seqId, method, result)
102	var headers THeaderMap
103	if hp, ok := p.iprot.(*THeaderProtocol); ok {
104		headers = hp.transport.readHeaders
105	}
106	return ResponseMeta{
107		Headers: headers,
108	}, err
109}
110