1// Copyright (c) 2012-2015 Ugorji Nwoke. All rights reserved.
2// Use of this source code is governed by a MIT license found in the LICENSE file.
3
4package codec
5
6import (
7	"bufio"
8	"io"
9	"net/rpc"
10	"sync"
11)
12
// rpcEncodeTerminator lets a Handle specify a []byte terminator that the rpc
// codec writes after each value it encodes. rpcEncodeTerminate returns the
// bytes to write.
//
// Some codecs, like json, need a delimiter such as a space after each encoded
// value. Without one, a value such as a trailing number has no end marker and
// the json decoder would keep reading until EOF.
type rpcEncodeTerminator interface {
	rpcEncodeTerminate() []byte
}
20
21// Rpc provides a rpc Server or Client Codec for rpc communication.
22type Rpc interface {
23	ServerCodec(conn io.ReadWriteCloser, h Handle) rpc.ServerCodec
24	ClientCodec(conn io.ReadWriteCloser, h Handle) rpc.ClientCodec
25}
26
// RpcCodecBuffered allows access to the underlying bufio.Reader/Writer
// used by the rpc connection. It accommodates use-cases where the connection
// should be used by both rpc and non-rpc functions, e.g. streaming a file
// after sending an rpc response.
//
// Bytes written to BufferedWriter stay buffered until the writer is flushed
// (or its buffer fills). NOTE(review): for the codecs in this package the
// reader/writer are the same buffers the rpc Decoder/Encoder use, so direct
// use must be coordinated with any in-flight rpc traffic.
type RpcCodecBuffered interface {
	BufferedReader() *bufio.Reader
	BufferedWriter() *bufio.Writer
}
35
36// -------------------------------------
37
38// rpcCodec defines the struct members and common methods.
39type rpcCodec struct {
40	rwc io.ReadWriteCloser
41	dec *Decoder
42	enc *Encoder
43	bw  *bufio.Writer
44	br  *bufio.Reader
45	mu  sync.Mutex
46	h   Handle
47
48	cls   bool
49	clsmu sync.RWMutex
50}
51
52func newRPCCodec(conn io.ReadWriteCloser, h Handle) rpcCodec {
53	bw := bufio.NewWriter(conn)
54	br := bufio.NewReader(conn)
55	return rpcCodec{
56		rwc: conn,
57		bw:  bw,
58		br:  br,
59		enc: NewEncoder(bw, h),
60		dec: NewDecoder(br, h),
61		h:   h,
62	}
63}
64
65func (c *rpcCodec) BufferedReader() *bufio.Reader {
66	return c.br
67}
68
69func (c *rpcCodec) BufferedWriter() *bufio.Writer {
70	return c.bw
71}
72
73func (c *rpcCodec) write(obj1, obj2 interface{}, writeObj2, doFlush bool) (err error) {
74	if c.isClosed() {
75		return io.EOF
76	}
77	if err = c.enc.Encode(obj1); err != nil {
78		return
79	}
80	t, tOk := c.h.(rpcEncodeTerminator)
81	if tOk {
82		c.bw.Write(t.rpcEncodeTerminate())
83	}
84	if writeObj2 {
85		if err = c.enc.Encode(obj2); err != nil {
86			return
87		}
88		if tOk {
89			c.bw.Write(t.rpcEncodeTerminate())
90		}
91	}
92	if doFlush {
93		return c.bw.Flush()
94	}
95	return
96}
97
98func (c *rpcCodec) read(obj interface{}) (err error) {
99	if c.isClosed() {
100		return io.EOF
101	}
102	//If nil is passed in, we should still attempt to read content to nowhere.
103	if obj == nil {
104		var obj2 interface{}
105		return c.dec.Decode(&obj2)
106	}
107	return c.dec.Decode(obj)
108}
109
110func (c *rpcCodec) isClosed() bool {
111	c.clsmu.RLock()
112	x := c.cls
113	c.clsmu.RUnlock()
114	return x
115}
116
117func (c *rpcCodec) Close() error {
118	if c.isClosed() {
119		return io.EOF
120	}
121	c.clsmu.Lock()
122	c.cls = true
123	c.clsmu.Unlock()
124	return c.rwc.Close()
125}
126
127func (c *rpcCodec) ReadResponseBody(body interface{}) error {
128	return c.read(body)
129}
130
131// -------------------------------------
132
133type goRpcCodec struct {
134	rpcCodec
135}
136
137func (c *goRpcCodec) WriteRequest(r *rpc.Request, body interface{}) error {
138	// Must protect for concurrent access as per API
139	c.mu.Lock()
140	defer c.mu.Unlock()
141	return c.write(r, body, true, true)
142}
143
144func (c *goRpcCodec) WriteResponse(r *rpc.Response, body interface{}) error {
145	c.mu.Lock()
146	defer c.mu.Unlock()
147	return c.write(r, body, true, true)
148}
149
150func (c *goRpcCodec) ReadResponseHeader(r *rpc.Response) error {
151	return c.read(r)
152}
153
154func (c *goRpcCodec) ReadRequestHeader(r *rpc.Request) error {
155	return c.read(r)
156}
157
158func (c *goRpcCodec) ReadRequestBody(body interface{}) error {
159	return c.read(body)
160}
161
162// -------------------------------------
163
// goRpc implements Rpc with the message layout of the standard net/rpc
// package. Each message is a header (rpc.Request or rpc.Response) followed
// by a body, both encoded with the Handle given to the codec.
type goRpc struct{}

// GoRpc implements Rpc using the communication protocol defined in the
// net/rpc package. The codecs returned by its ServerCodec and ClientCodec
// methods also implement RpcCodecBuffered.
var GoRpc = goRpc{}
171
172func (x goRpc) ServerCodec(conn io.ReadWriteCloser, h Handle) rpc.ServerCodec {
173	return &goRpcCodec{newRPCCodec(conn, h)}
174}
175
176func (x goRpc) ClientCodec(conn io.ReadWriteCloser, h Handle) rpc.ClientCodec {
177	return &goRpcCodec{newRPCCodec(conn, h)}
178}
179
180var _ RpcCodecBuffered = (*rpcCodec)(nil) // ensure *rpcCodec implements RpcCodecBuffered
181