1package rpc
2
3import (
4	"fmt"
5	"io"
6	"net"
7	"sync"
8)
9
// WrapErrorFunc converts an error into an arbitrary value. NewTransport
// hands it to the protocol handler; presumably it is used to turn errors
// returned by protocol methods into a wire-encodable value — confirm
// against protocolHandler.
type WrapErrorFunc func(error) interface{}
11
// Transporter is the connection-level abstraction used by the RPC
// layer: it runs the loop that reads incoming frames, exposes the
// dispatcher and receiver built on top of the connection, and reports
// when (done) and why (err) the connection stopped.
type Transporter interface {
	// IsConnected returns false once the transport has stopped, i.e.
	// once the channel returned by done() is closed. That happens
	// when incoming packets have finished processing or when Close
	// has been called.
	//
	// TODO: Use a better name.
	IsConnected() bool

	// registerProtocol adds p to the protocols served by this
	// transport (presumably so incoming calls can be routed to it).
	registerProtocol(p Protocol) error

	// getDispatcher and getReceiver return the transport's
	// dispatcher/receiver, or an error (the concrete transport uses
	// io.EOF) once the transport is no longer connected.
	getDispatcher() (dispatcher, error)
	getReceiver() (receiver, error)

	// receiveFrames starts processing incoming frames in a
	// background goroutine, if it's not already happening.
	// Returns the result of done(), for convenience.
	receiveFrames() <-chan struct{}

	// Returns a channel that's closed when incoming frames have
	// finished processing, either due to an error or the
	// underlying connection being closed. Successive calls to
	// done() return the same value.
	done() <-chan struct{}

	// err returns a non-nil error value after done() is closed.
	// After done() is closed, successive calls to err return the
	// same value.
	err() error

	// Close closes the transport and releases resources. It may be
	// called more than once; only the first call has an effect.
	Close()
}
43
// Compile-time assertion that *transport satisfies Transporter.
var _ Transporter = (*transport)(nil)

// transport is the Transporter implementation returned by
// NewTransport. It wraps a single net.Conn.
type transport struct {
	// The connection and the framing, dispatch and receive machinery
	// NewTransport builds on top of it.
	c          net.Conn
	enc        *framedMsgpackEncoder
	dispatcher dispatcher
	receiver   receiver
	packetizer *packetizer
	protocols  *protocolHandler
	calls      *callContainer
	log        LogInterface
	// closeOnce guards the shutdown sequence; startOnce ensures the
	// receive-loop goroutine is started at most once.
	closeOnce  sync.Once
	startOnce  sync.Once
	// stopCh is closed when the transport stops; done() returns it.
	stopCh     chan struct{}

	// stopErr is the value reported by err(). It must be written
	// before stopCh is closed and only read after stopCh has been
	// observed closed: the channel close is what publishes the
	// write to readers.
	stopErr error
}
62
// DefaultMaxFrameLength (100 MiB) is a reasonable default value for
// the maxFrameLength parameter of NewTransport. Both ends of a
// connection should agree on the value.
const DefaultMaxFrameLength = 100 * 1024 * 1024
66
67// NewTransport creates a new Transporter from the given connection
68// and parameters. Both sides of a connection should use the same
69// number for maxFrameLength.
70func NewTransport(c net.Conn, l LogFactory, instrumenterStorage NetworkInstrumenterStorage, wef WrapErrorFunc, maxFrameLength int32) Transporter {
71	if maxFrameLength <= 0 {
72		panic(fmt.Sprintf("maxFrameLength must be positive: got %d", maxFrameLength))
73	}
74
75	if l == nil {
76		l = NewSimpleLogFactory(nil, nil)
77	}
78	log := l.NewLog(c.RemoteAddr())
79	if instrumenterStorage == nil {
80		instrumenterStorage = NewDummyInstrumentationStorage()
81	}
82
83	ret := &transport{
84		c:         c,
85		log:       log,
86		stopCh:    make(chan struct{}),
87		protocols: newProtocolHandler(wef),
88		calls:     newCallContainer(),
89	}
90	enc := newFramedMsgpackEncoder(maxFrameLength, c)
91	ret.enc = enc
92	ret.dispatcher = newDispatch(enc, ret.calls, log, instrumenterStorage)
93	ret.receiver = newReceiveHandler(enc, ret.protocols, log)
94	ret.packetizer = newPacketizer(maxFrameLength, c, ret.protocols, ret.calls, log, instrumenterStorage)
95	return ret
96}
97
98func (t *transport) Close() {
99	t.closeOnce.Do(func() {
100		// Since the receiver might require the transport, we have to
101		// close it before terminating our loops
102		close(t.stopCh)
103		t.dispatcher.Close()
104		<-t.receiver.Close()
105
106		// First inform the encoder that it should close
107		encoderClosed := t.enc.Close()
108		// Unblock any remaining writes
109		t.c.Close()
110		// Wait for the encoder to finish handling the now unblocked writes
111		<-encoderClosed
112	})
113}
114
115func (t *transport) IsConnected() bool {
116	select {
117	case <-t.stopCh:
118		return false
119	default:
120		return true
121	}
122}
123
124func (t *transport) receiveFrames() <-chan struct{} {
125	t.startOnce.Do(func() {
126		go t.receiveFramesLoop()
127	})
128	return t.stopCh
129}
130
131func (t *transport) done() <-chan struct{} {
132	return t.stopCh
133}
134
135func (t *transport) err() error {
136	select {
137	case <-t.stopCh:
138		return t.stopErr
139	default:
140		return nil
141	}
142}
143
144func (t *transport) receiveFramesLoop() {
145	// Packetize: do work
146	var err error
147	for shouldContinue(err) {
148		var rpc rpcMessage
149		if rpc, err = t.packetizer.NextFrame(); shouldReceive(rpc) {
150			if rerr := t.receiver.Receive(rpc); rerr != nil {
151				t.log.Info("error on Receive: %v", rerr)
152			}
153		}
154	}
155
156	// Log packetizer completion
157	t.log.TransportError(err)
158
159	// This must happen before stopCh is closed to have a correct
160	// ordering.
161	t.stopErr = err
162
163	t.Close()
164}
165
166func (t *transport) getDispatcher() (dispatcher, error) {
167	if !t.IsConnected() {
168		return nil, io.EOF
169	}
170	return t.dispatcher, nil
171}
172
173func (t *transport) getReceiver() (receiver, error) {
174	if !t.IsConnected() {
175		return nil, io.EOF
176	}
177	return t.receiver, nil
178}
179
180func (t *transport) registerProtocol(p Protocol) error {
181	return t.protocols.registerProtocol(p)
182}
183
184func shouldContinue(err error) bool {
185	err = unboxRPCError(err)
186	switch err.(type) {
187	case nil:
188		return true
189	case CallNotFoundError:
190		return true
191	case MethodNotFoundError:
192		return true
193	case ProtocolNotFoundError:
194		return true
195	default:
196		return false
197	}
198}
199
200func shouldReceive(rpc rpcMessage) bool {
201	if rpc == nil {
202		return false
203	}
204	switch rpc.Err().(type) {
205	case nil:
206		return true
207	case MethodNotFoundError:
208		return true
209	case ProtocolNotFoundError:
210		return true
211	default:
212		return false
213	}
214}
215