1// Copyright (C) 2019 Storj Labs, Inc.
2// See LICENSE for copying information.
3
4package drpcconn
5
6import (
7	"context"
8	"sync"
9
10	"github.com/zeebo/errs"
11
12	"storj.io/drpc"
13	"storj.io/drpc/drpcenc"
14	"storj.io/drpc/drpcmanager"
15	"storj.io/drpc/drpcmetadata"
16	"storj.io/drpc/drpcstream"
17	"storj.io/drpc/drpcwire"
18)
19
// Options controls configuration settings for a conn. The zero value is what
// New uses.
type Options struct {
	// Manager controls the options we pass to the manager of this conn. It is
	// forwarded unchanged to drpcmanager.NewWithOptions when the conn is built.
	Manager drpcmanager.Options
}
25
// Conn is a drpc client connection. Rpcs are issued over client streams
// obtained from a drpcmanager.Manager that is built over the transport.
type Conn struct {
	// tr is the transport supplied at construction; it is also what the
	// manager was created with, and what Transport returns.
	tr   drpc.Transport
	// man hands out the client streams used by Invoke and NewStream; Close and
	// Closed delegate to it.
	man  *drpcmanager.Manager
	// mu guards wbuf. A stream may close asynchronously, which lets a
	// concurrent Invoke proceed even though only one stream is open at a time.
	mu   sync.Mutex
	// wbuf is a scratch buffer reused by Invoke to marshal request messages.
	wbuf []byte
}

// Compile-time check that *Conn satisfies drpc.Conn.
var _ drpc.Conn = (*Conn)(nil)
35
36// New returns a conn that uses the transport for reads and writes.
37func New(tr drpc.Transport) *Conn {
38	return NewWithOptions(tr, Options{})
39}
40
41// NewWithOptions returns a conn that uses the transport for reads and writes.
42// The Options control details of how the conn operates.
43func NewWithOptions(tr drpc.Transport, opts Options) *Conn {
44	return &Conn{
45		tr:  tr,
46		man: drpcmanager.NewWithOptions(tr, opts.Manager),
47	}
48}
49
50// Transport returns the transport the conn is using.
51func (c *Conn) Transport() drpc.Transport {
52	return c.tr
53}
54
55// Closed returns a channel that is closed once the connection is closed.
56func (c *Conn) Closed() <-chan struct{} {
57	return c.man.Closed()
58}
59
60// Close closes the connection.
61func (c *Conn) Close() (err error) {
62	return c.man.Close()
63}
64
65// Invoke issues the rpc on the transport serializing in, waits for a response, and
66// deserializes it into out. Only one Invoke or Stream may be open at a time.
67func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) (err error) {
68	var metadata []byte
69	if md, ok := drpcmetadata.Get(ctx); ok {
70		metadata, err = drpcmetadata.Encode(metadata, md)
71		if err != nil {
72			return err
73		}
74	}
75
76	stream, err := c.man.NewClientStream(ctx)
77	if err != nil {
78		return err
79	}
80	defer func() { err = errs.Combine(err, stream.Close()) }()
81
82	// we have to protect c.wbuf here even though the manager only allows one
83	// stream at a time because the stream may async close allowing another
84	// concurrent call to Invoke to proceed.
85	c.mu.Lock()
86	defer c.mu.Unlock()
87
88	c.wbuf, err = drpcenc.MarshalAppend(in, enc, c.wbuf[:0])
89	if err != nil {
90		return err
91	}
92
93	if err := c.doInvoke(stream, enc, rpc, c.wbuf, metadata, out); err != nil {
94		return err
95	}
96	return nil
97}
98
99func (c *Conn) doInvoke(stream *drpcstream.Stream, enc drpc.Encoding, rpc string, data []byte, metadata []byte, out drpc.Message) (err error) {
100	if len(metadata) > 0 {
101		if err := stream.RawWrite(drpcwire.KindInvokeMetadata, metadata); err != nil {
102			return err
103		}
104	}
105	if err := stream.RawWrite(drpcwire.KindInvoke, []byte(rpc)); err != nil {
106		return err
107	}
108	if err := stream.RawWrite(drpcwire.KindMessage, data); err != nil {
109		return err
110	}
111	if err := stream.CloseSend(); err != nil {
112		return err
113	}
114	if err := stream.MsgRecv(out, enc); err != nil {
115		return err
116	}
117	return nil
118}
119
120// NewStream begins a streaming rpc on the connection. Only one Invoke or Stream may
121// be open at a time.
122func (c *Conn) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (_ drpc.Stream, err error) {
123	var metadata []byte
124	if md, ok := drpcmetadata.Get(ctx); ok {
125		metadata, err = drpcmetadata.Encode(metadata, md)
126		if err != nil {
127			return nil, err
128		}
129	}
130
131	stream, err := c.man.NewClientStream(ctx)
132	if err != nil {
133		return nil, err
134	}
135
136	if err := c.doNewStream(stream, rpc, metadata); err != nil {
137		return nil, errs.Combine(err, stream.Close())
138	}
139
140	return stream, nil
141}
142
143func (c *Conn) doNewStream(stream *drpcstream.Stream, rpc string, metadata []byte) error {
144	if len(metadata) > 0 {
145		if err := stream.RawWrite(drpcwire.KindInvokeMetadata, metadata); err != nil {
146			return err
147		}
148	}
149	if err := stream.RawWrite(drpcwire.KindInvoke, []byte(rpc)); err != nil {
150		return err
151	}
152	if err := stream.RawFlush(); err != nil {
153		return err
154	}
155	return nil
156}
157