1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package rpc
6
7import (
8	"bufio"
9	"encoding/gob"
10	"errors"
11	"io"
12	"log"
13	"net"
14	"net/http"
15	"sync"
16)
17
// ServerError is the error type used for failures reported by the
// remote end of an RPC connection. Its value is the error text itself.
type ServerError string

// Error returns the error text, satisfying the error interface.
func (e ServerError) Error() string { return string(e) }
25
// ErrShutdown is the error reported when a Client is used or closed
// after it has started closing or has already shut down.
var ErrShutdown = errors.New("connection is shut down")
27
// Call represents an active RPC.
// A *Call is returned by Client.Go and is sent on its own Done channel
// once the call has finished, successfully or not.
type Call struct {
	ServiceMethod string      // The name of the service and method to call.
	Args          interface{} // The argument to the function (*struct).
	Reply         interface{} // The reply from the function (*struct).
	Error         error       // After completion, the error status.
	Done          chan *Call  // Receives this Call when it is complete; must be buffered.
}
36
// Client represents an RPC Client.
// There may be multiple outstanding Calls associated
// with a single Client, and a Client may be used by
// multiple goroutines simultaneously.
type Client struct {
	mutex    sync.Mutex // protects seq, pending, closing, shutdown
	sending  sync.Mutex // serializes send; held while request is filled in and written
	request  Request    // reusable request header; written only while sending is held
	seq      uint64     // sequence number assigned to the next call
	codec    ClientCodec
	pending  map[uint64]*Call // calls awaiting a response, keyed by sequence number
	closing  bool             // set by Close: the user has asked to close the connection
	shutdown bool             // set by input when the read loop has terminated
}
51
// A ClientCodec implements writing of RPC requests and
// reading of RPC responses for the client side of an RPC session.
// The client calls WriteRequest to write a request to the connection
// and calls ReadResponseHeader and ReadResponseBody in pairs
// to read responses.  The client calls Close when finished with the
// connection. ReadResponseBody may be called with a nil
// argument to force the body of the response to be read and then
// discarded.
type ClientCodec interface {
	// WriteRequest writes the request header followed by the call's argument.
	WriteRequest(*Request, interface{}) error
	// ReadResponseHeader reads the header of the next response.
	ReadResponseHeader(*Response) error
	// ReadResponseBody reads the body belonging to the header just read,
	// storing it in the argument, or discarding it if the argument is nil.
	ReadResponseBody(interface{}) error

	// Close is called by Client.Close to shut the connection down.
	Close() error
}
67
// send registers call as pending under a fresh sequence number and writes
// the request through the codec. If the client is closing or shut down, or
// if writing fails, the call is completed immediately with the error.
// The sending mutex is held for the whole function, so only one request is
// encoded at a time and client.request can be reused safely.
func (client *Client) send(call *Call) {
	client.sending.Lock()
	defer client.sending.Unlock()

	// Register this call.
	client.mutex.Lock()
	if client.shutdown || client.closing {
		call.Error = ErrShutdown
		client.mutex.Unlock()
		call.done()
		return
	}
	seq := client.seq
	client.seq++
	client.pending[seq] = call
	client.mutex.Unlock()

	// Encode and send the request.
	// client.mutex is not held during the write; client.sending guards
	// the shared request header instead.
	client.request.Seq = seq
	client.request.ServiceMethod = call.ServiceMethod
	err := client.codec.WriteRequest(&client.request, call.Args)
	if err != nil {
		// Take the call back out of the pending map. It may already be
		// gone if the input loop handled a response for seq, in which
		// case the lookup yields nil and there is nothing left to complete.
		client.mutex.Lock()
		call = client.pending[seq]
		delete(client.pending, seq)
		client.mutex.Unlock()
		if call != nil {
			call.Error = err
			call.done()
		}
	}
}
100
// input is the client's read loop; NewClientWithCodec starts it in its own
// goroutine. It reads response headers until an error occurs, matches each
// response to its pending call by sequence number, reads the body and
// completes the call. When the loop ends it marks the client as shut down
// and completes every call still pending with the terminating error.
func (client *Client) input() {
	var err error
	var response Response
	for err == nil {
		// Start from a zeroed header for every response.
		response = Response{}
		err = client.codec.ReadResponseHeader(&response)
		if err != nil {
			break
		}
		// Claim the pending call for this sequence number, if any.
		seq := response.Seq
		client.mutex.Lock()
		call := client.pending[seq]
		delete(client.pending, seq)
		client.mutex.Unlock()

		switch {
		case call == nil:
			// We've got no pending call. That usually means that
			// WriteRequest partially failed, and call was already
			// removed; response is a server telling us about an
			// error reading request body. We should still attempt
			// to read error body, but there's no one to give it to.
			err = client.codec.ReadResponseBody(nil)
			if err != nil {
				err = errors.New("reading error body: " + err.Error())
			}
		case response.Error != "":
			// We've got an error response. Give this to the request;
			// any subsequent requests will get the ReadResponseBody
			// error if there is one.
			call.Error = ServerError(response.Error)
			err = client.codec.ReadResponseBody(nil)
			if err != nil {
				err = errors.New("reading error body: " + err.Error())
			}
			call.done()
		default:
			// Normal response: decode the body straight into the caller's
			// Reply. A decode failure is reported on this call and, because
			// err stays non-nil, also ends the loop.
			err = client.codec.ReadResponseBody(call.Reply)
			if err != nil {
				call.Error = errors.New("reading body " + err.Error())
			}
			call.done()
		}
	}
	// Terminate pending calls.
	// Both locks are taken, sending first, so that no send is in progress
	// while the remaining calls are failed.
	client.sending.Lock()
	client.mutex.Lock()
	client.shutdown = true
	closing := client.closing
	// A plain EOF is reported as ErrShutdown when Close was requested,
	// and as an unexpected EOF otherwise.
	if err == io.EOF {
		if closing {
			err = ErrShutdown
		} else {
			err = io.ErrUnexpectedEOF
		}
	}
	for _, call := range client.pending {
		call.Error = err
		call.done()
	}
	client.mutex.Unlock()
	client.sending.Unlock()
	// NOTE(review): err can no longer equal io.EOF here because the block
	// above rewrites it, so this condition reduces to !closing; confirm that
	// logging an unexpected EOF as a protocol error is intended.
	if err != io.EOF && !closing {
		log.Println("rpc: client protocol error:", err)
	}
}
167
168func (call *Call) done() {
169	select {
170	case call.Done <- call:
171		// ok
172	default:
173		// We don't want to block here.  It is the caller's responsibility to make
174		// sure the channel has enough buffer space. See comment in Go().
175		log.Println("rpc: discarding Call reply due to insufficient Done chan capacity")
176	}
177}
178
179// NewClient returns a new Client to handle requests to the
180// set of services at the other end of the connection.
181// It adds a buffer to the write side of the connection so
182// the header and payload are sent as a unit.
183func NewClient(conn io.ReadWriteCloser) *Client {
184	encBuf := bufio.NewWriter(conn)
185	client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf}
186	return NewClientWithCodec(client)
187}
188
189// NewClientWithCodec is like NewClient but uses the specified
190// codec to encode requests and decode responses.
191func NewClientWithCodec(codec ClientCodec) *Client {
192	client := &Client{
193		codec:   codec,
194		pending: make(map[uint64]*Call),
195	}
196	go client.input()
197	return client
198}
199
// gobClientCodec is the ClientCodec built by NewClient. It encodes
// requests and decodes responses with encoding/gob.
type gobClientCodec struct {
	rwc    io.ReadWriteCloser // underlying connection; closed by Close
	dec    *gob.Decoder       // decodes response headers and bodies
	enc    *gob.Encoder       // encodes request headers and bodies
	encBuf *bufio.Writer      // flushed at the end of each WriteRequest
}
206
207func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) (err error) {
208	if err = c.enc.Encode(r); err != nil {
209		return
210	}
211	if err = c.enc.Encode(body); err != nil {
212		return
213	}
214	return c.encBuf.Flush()
215}
216
217func (c *gobClientCodec) ReadResponseHeader(r *Response) error {
218	return c.dec.Decode(r)
219}
220
221func (c *gobClientCodec) ReadResponseBody(body interface{}) error {
222	return c.dec.Decode(body)
223}
224
225func (c *gobClientCodec) Close() error {
226	return c.rwc.Close()
227}
228
229// DialHTTP connects to an HTTP RPC server at the specified network address
230// listening on the default HTTP RPC path.
231func DialHTTP(network, address string) (*Client, error) {
232	return DialHTTPPath(network, address, DefaultRPCPath)
233}
234
235// DialHTTPPath connects to an HTTP RPC server
236// at the specified network address and path.
237func DialHTTPPath(network, address, path string) (*Client, error) {
238	var err error
239	conn, err := net.Dial(network, address)
240	if err != nil {
241		return nil, err
242	}
243	io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n")
244
245	// Require successful HTTP response
246	// before switching to RPC protocol.
247	resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
248	if err == nil && resp.Status == connected {
249		return NewClient(conn), nil
250	}
251	if err == nil {
252		err = errors.New("unexpected HTTP response: " + resp.Status)
253	}
254	conn.Close()
255	return nil, &net.OpError{
256		Op:   "dial-http",
257		Net:  network + " " + address,
258		Addr: nil,
259		Err:  err,
260	}
261}
262
263// Dial connects to an RPC server at the specified network address.
264func Dial(network, address string) (*Client, error) {
265	conn, err := net.Dial(network, address)
266	if err != nil {
267		return nil, err
268	}
269	return NewClient(conn), nil
270}
271
272func (client *Client) Close() error {
273	client.mutex.Lock()
274	if client.shutdown || client.closing {
275		client.mutex.Unlock()
276		return ErrShutdown
277	}
278	client.closing = true
279	client.mutex.Unlock()
280	return client.codec.Close()
281}
282
283// Go invokes the function asynchronously.  It returns the Call structure representing
284// the invocation.  The done channel will signal when the call is complete by returning
285// the same Call object.  If done is nil, Go will allocate a new channel.
286// If non-nil, done must be buffered or Go will deliberately crash.
287func (client *Client) Go(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call {
288	call := new(Call)
289	call.ServiceMethod = serviceMethod
290	call.Args = args
291	call.Reply = reply
292	if done == nil {
293		done = make(chan *Call, 10) // buffered.
294	} else {
295		// If caller passes done != nil, it must arrange that
296		// done has enough buffer for the number of simultaneous
297		// RPCs that will be using that channel.  If the channel
298		// is totally unbuffered, it's best not to run at all.
299		if cap(done) == 0 {
300			log.Panic("rpc: done channel is unbuffered")
301		}
302	}
303	call.Done = done
304	client.send(call)
305	return call
306}
307
308// Call invokes the named function, waits for it to complete, and returns its error status.
309func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) error {
310	call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
311	return call.Error
312}
313