1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package rpc
6
7import (
8	"bufio"
9	"encoding/gob"
10	"errors"
11	"io"
12	"log"
13	"net"
14	"net/http"
15	"sync"
16)
17
// ServerError is an error reported by the remote side of an RPC connection.
// Its value is the error text carried in the server's response.
type ServerError string

// Error implements the error interface by returning the server's message
// unchanged.
func (e ServerError) Error() string {
	msg := string(e)
	return msg
}
25
var (
	// ErrShutdown is returned by Client.Close when Close was already called,
	// and is stored in Call.Error for calls issued after the client started
	// closing or shut down, or for pending calls cut off by a clean EOF
	// after Close.
	ErrShutdown = errors.New("connection is shut down")
)
27
// Call describes a single RPC, either in flight or finished. Client.Go
// creates it, and the same *Call is delivered on Done once it completes.
type Call struct {
	// ServiceMethod is the name of the service and method to call.
	ServiceMethod string

	// Args is the argument passed to the remote function (*struct).
	Args interface{}

	// Reply receives the result of the remote function (*struct).
	Reply interface{}

	// Error holds the completion status once the call has finished.
	Error error

	// Done receives this *Call when the call is complete.
	Done chan *Call
}
36
37// Client represents an RPC Client.
38// There may be multiple outstanding Calls associated
39// with a single Client, and a Client may be used by
40// multiple goroutines simultaneously.
41type Client struct {
42	codec ClientCodec
43
44	reqMutex sync.Mutex // protects following
45	request  Request
46
47	mutex    sync.Mutex // protects following
48	seq      uint64
49	pending  map[uint64]*Call
50	closing  bool // user has called Close
51	shutdown bool // server has told us to stop
52}
53
54// A ClientCodec implements writing of RPC requests and
55// reading of RPC responses for the client side of an RPC session.
56// The client calls WriteRequest to write a request to the connection
57// and calls ReadResponseHeader and ReadResponseBody in pairs
58// to read responses. The client calls Close when finished with the
59// connection. ReadResponseBody may be called with a nil
60// argument to force the body of the response to be read and then
61// discarded.
62// See NewClient's comment for information about concurrent access.
63type ClientCodec interface {
64	WriteRequest(*Request, interface{}) error
65	ReadResponseHeader(*Response) error
66	ReadResponseBody(interface{}) error
67
68	Close() error
69}
70
71func (client *Client) send(call *Call) {
72	client.reqMutex.Lock()
73	defer client.reqMutex.Unlock()
74
75	// Register this call.
76	client.mutex.Lock()
77	if client.shutdown || client.closing {
78		client.mutex.Unlock()
79		call.Error = ErrShutdown
80		call.done()
81		return
82	}
83	seq := client.seq
84	client.seq++
85	client.pending[seq] = call
86	client.mutex.Unlock()
87
88	// Encode and send the request.
89	client.request.Seq = seq
90	client.request.ServiceMethod = call.ServiceMethod
91	err := client.codec.WriteRequest(&client.request, call.Args)
92	if err != nil {
93		client.mutex.Lock()
94		call = client.pending[seq]
95		delete(client.pending, seq)
96		client.mutex.Unlock()
97		if call != nil {
98			call.Error = err
99			call.done()
100		}
101	}
102}
103
104func (client *Client) input() {
105	var err error
106	var response Response
107	for err == nil {
108		response = Response{}
109		err = client.codec.ReadResponseHeader(&response)
110		if err != nil {
111			break
112		}
113		seq := response.Seq
114		client.mutex.Lock()
115		call := client.pending[seq]
116		delete(client.pending, seq)
117		client.mutex.Unlock()
118
119		switch {
120		case call == nil:
121			// We've got no pending call. That usually means that
122			// WriteRequest partially failed, and call was already
123			// removed; response is a server telling us about an
124			// error reading request body. We should still attempt
125			// to read error body, but there's no one to give it to.
126			err = client.codec.ReadResponseBody(nil)
127			if err != nil {
128				err = errors.New("reading error body: " + err.Error())
129			}
130		case response.Error != "":
131			// We've got an error response. Give this to the request;
132			// any subsequent requests will get the ReadResponseBody
133			// error if there is one.
134			call.Error = ServerError(response.Error)
135			err = client.codec.ReadResponseBody(nil)
136			if err != nil {
137				err = errors.New("reading error body: " + err.Error())
138			}
139			call.done()
140		default:
141			err = client.codec.ReadResponseBody(call.Reply)
142			if err != nil {
143				call.Error = errors.New("reading body " + err.Error())
144			}
145			call.done()
146		}
147	}
148	// Terminate pending calls.
149	client.reqMutex.Lock()
150	client.mutex.Lock()
151	client.shutdown = true
152	closing := client.closing
153	if err == io.EOF {
154		if closing {
155			err = ErrShutdown
156		} else {
157			err = io.ErrUnexpectedEOF
158		}
159	}
160	for _, call := range client.pending {
161		call.Error = err
162		call.done()
163	}
164	client.mutex.Unlock()
165	client.reqMutex.Unlock()
166	if debugLog && err != io.EOF && !closing {
167		log.Println("rpc: client protocol error:", err)
168	}
169}
170
171func (call *Call) done() {
172	select {
173	case call.Done <- call:
174		// ok
175	default:
176		// We don't want to block here. It is the caller's responsibility to make
177		// sure the channel has enough buffer space. See comment in Go().
178		if debugLog {
179			log.Println("rpc: discarding Call reply due to insufficient Done chan capacity")
180		}
181	}
182}
183
184// NewClient returns a new Client to handle requests to the
185// set of services at the other end of the connection.
186// It adds a buffer to the write side of the connection so
187// the header and payload are sent as a unit.
188//
189// The read and write halves of the connection are serialized independently,
190// so no interlocking is required. However each half may be accessed
191// concurrently so the implementation of conn should protect against
192// concurrent reads or concurrent writes.
193func NewClient(conn io.ReadWriteCloser) *Client {
194	encBuf := bufio.NewWriter(conn)
195	client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf}
196	return NewClientWithCodec(client)
197}
198
199// NewClientWithCodec is like NewClient but uses the specified
200// codec to encode requests and decode responses.
201func NewClientWithCodec(codec ClientCodec) *Client {
202	client := &Client{
203		codec:   codec,
204		pending: make(map[uint64]*Call),
205	}
206	go client.input()
207	return client
208}
209
210type gobClientCodec struct {
211	rwc    io.ReadWriteCloser
212	dec    *gob.Decoder
213	enc    *gob.Encoder
214	encBuf *bufio.Writer
215}
216
217func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) (err error) {
218	if err = c.enc.Encode(r); err != nil {
219		return
220	}
221	if err = c.enc.Encode(body); err != nil {
222		return
223	}
224	return c.encBuf.Flush()
225}
226
227func (c *gobClientCodec) ReadResponseHeader(r *Response) error {
228	return c.dec.Decode(r)
229}
230
231func (c *gobClientCodec) ReadResponseBody(body interface{}) error {
232	return c.dec.Decode(body)
233}
234
235func (c *gobClientCodec) Close() error {
236	return c.rwc.Close()
237}
238
239// DialHTTP connects to an HTTP RPC server at the specified network address
240// listening on the default HTTP RPC path.
241func DialHTTP(network, address string) (*Client, error) {
242	return DialHTTPPath(network, address, DefaultRPCPath)
243}
244
245// DialHTTPPath connects to an HTTP RPC server
246// at the specified network address and path.
247func DialHTTPPath(network, address, path string) (*Client, error) {
248	conn, err := net.Dial(network, address)
249	if err != nil {
250		return nil, err
251	}
252	io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n")
253
254	// Require successful HTTP response
255	// before switching to RPC protocol.
256	resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
257	if err == nil && resp.Status == connected {
258		return NewClient(conn), nil
259	}
260	if err == nil {
261		err = errors.New("unexpected HTTP response: " + resp.Status)
262	}
263	conn.Close()
264	return nil, &net.OpError{
265		Op:   "dial-http",
266		Net:  network + " " + address,
267		Addr: nil,
268		Err:  err,
269	}
270}
271
272// Dial connects to an RPC server at the specified network address.
273func Dial(network, address string) (*Client, error) {
274	conn, err := net.Dial(network, address)
275	if err != nil {
276		return nil, err
277	}
278	return NewClient(conn), nil
279}
280
281// Close calls the underlying codec's Close method. If the connection is already
282// shutting down, ErrShutdown is returned.
283func (client *Client) Close() error {
284	client.mutex.Lock()
285	if client.closing {
286		client.mutex.Unlock()
287		return ErrShutdown
288	}
289	client.closing = true
290	client.mutex.Unlock()
291	return client.codec.Close()
292}
293
294// Go invokes the function asynchronously. It returns the Call structure representing
295// the invocation. The done channel will signal when the call is complete by returning
296// the same Call object. If done is nil, Go will allocate a new channel.
297// If non-nil, done must be buffered or Go will deliberately crash.
298func (client *Client) Go(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call {
299	call := new(Call)
300	call.ServiceMethod = serviceMethod
301	call.Args = args
302	call.Reply = reply
303	if done == nil {
304		done = make(chan *Call, 10) // buffered.
305	} else {
306		// If caller passes done != nil, it must arrange that
307		// done has enough buffer for the number of simultaneous
308		// RPCs that will be using that channel. If the channel
309		// is totally unbuffered, it's best not to run at all.
310		if cap(done) == 0 {
311			log.Panic("rpc: done channel is unbuffered")
312		}
313	}
314	call.Done = done
315	client.send(call)
316	return call
317}
318
319// Call invokes the named function, waits for it to complete, and returns its error status.
320func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) error {
321	call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
322	return call.Error
323}
324