1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package rpc
6
7import (
8	"errors"
9	"fmt"
10	"io"
11	"log"
12	"net"
13	"net/http/httptest"
14	"runtime"
15	"strings"
16	"sync"
17	"sync/atomic"
18	"testing"
19	"time"
20)
21
22var (
23	newServer                 *Server
24	serverAddr, newServerAddr string
25	httpServerAddr            string
26	once, newOnce, httpOnce   sync.Once
27)
28
29const (
30	newHttpPath = "/foo"
31)
32
// Args holds the two integer operands passed to Arith methods.
type Args struct {
	A int
	B int
}

// Reply carries the integer result of an Arith method.
type Reply struct {
	C int
}

// Arith is the RPC service exercised by these tests.
type Arith int

// Some of Arith's methods take value args and some take pointer args; the mix
// is deliberate so that both forms are covered.

// Add stores args.A + args.B in reply.C.
func (a *Arith) Add(args Args, reply *Reply) error {
	sum := args.A + args.B
	reply.C = sum
	return nil
}

// Mul stores args.A * args.B in reply.C.
func (a *Arith) Mul(args *Args, reply *Reply) error {
	product := args.A * args.B
	reply.C = product
	return nil
}

// Div stores args.A / args.B in reply.C. A zero divisor yields an error and
// leaves reply untouched.
func (a *Arith) Div(args Args, reply *Reply) error {
	if args.B != 0 {
		reply.C = args.A / args.B
		return nil
	}
	return errors.New("divide by zero")
}

// String writes the addition rendered as "A+B=sum" into *reply.
func (a *Arith) String(args *Args, reply *string) error {
	sum := args.A + args.B
	*reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, sum)
	return nil
}

// Scan parses the string argument as an integer into reply.C.
func (a *Arith) Scan(args string, reply *Reply) error {
	_, err := fmt.Sscan(args, &reply.C)
	return err
}

// Error always panics with the value "ERROR".
func (a *Arith) Error(args *Args, reply *Reply) error {
	panic("ERROR")
}
76
// listenTCP opens a TCP listener on an OS-chosen loopback port and returns it
// together with its address string. A listen failure aborts the process via
// log.Fatalf.
func listenTCP() (net.Listener, string) {
	ln, err := net.Listen("tcp", "127.0.0.1:0") // port 0: any available port
	if err != nil {
		log.Fatalf("net.Listen tcp :0: %v", err)
	}
	addr := ln.Addr().String()
	return ln, addr
}
84
85func startServer() {
86	Register(new(Arith))
87	RegisterName("net.rpc.Arith", new(Arith))
88
89	var l net.Listener
90	l, serverAddr = listenTCP()
91	log.Println("Test RPC server listening on", serverAddr)
92	go Accept(l)
93
94	HandleHTTP()
95	httpOnce.Do(startHttpServer)
96}
97
98func startNewServer() {
99	newServer = NewServer()
100	newServer.Register(new(Arith))
101	newServer.RegisterName("net.rpc.Arith", new(Arith))
102	newServer.RegisterName("newServer.Arith", new(Arith))
103
104	var l net.Listener
105	l, newServerAddr = listenTCP()
106	log.Println("NewServer test RPC server listening on", newServerAddr)
107	go newServer.Accept(l)
108
109	newServer.HandleHTTP(newHttpPath, "/bar")
110	httpOnce.Do(startHttpServer)
111}
112
113func startHttpServer() {
114	server := httptest.NewServer(nil)
115	httpServerAddr = server.Listener.Addr().String()
116	log.Println("Test HTTP RPC server listening on", httpServerAddr)
117}
118
119func TestRPC(t *testing.T) {
120	once.Do(startServer)
121	testRPC(t, serverAddr)
122	newOnce.Do(startNewServer)
123	testRPC(t, newServerAddr)
124	testNewServerRPC(t, newServerAddr)
125}
126
127func testRPC(t *testing.T, addr string) {
128	client, err := Dial("tcp", addr)
129	if err != nil {
130		t.Fatal("dialing", err)
131	}
132	defer client.Close()
133
134	// Synchronous calls
135	args := &Args{7, 8}
136	reply := new(Reply)
137	err = client.Call("Arith.Add", args, reply)
138	if err != nil {
139		t.Errorf("Add: expected no error but got string %q", err.Error())
140	}
141	if reply.C != args.A+args.B {
142		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
143	}
144
145	// Nonexistent method
146	args = &Args{7, 0}
147	reply = new(Reply)
148	err = client.Call("Arith.BadOperation", args, reply)
149	// expect an error
150	if err == nil {
151		t.Error("BadOperation: expected error")
152	} else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") {
153		t.Errorf("BadOperation: expected can't find method error; got %q", err)
154	}
155
156	// Unknown service
157	args = &Args{7, 8}
158	reply = new(Reply)
159	err = client.Call("Arith.Unknown", args, reply)
160	if err == nil {
161		t.Error("expected error calling unknown service")
162	} else if strings.Index(err.Error(), "method") < 0 {
163		t.Error("expected error about method; got", err)
164	}
165
166	// Out of order.
167	args = &Args{7, 8}
168	mulReply := new(Reply)
169	mulCall := client.Go("Arith.Mul", args, mulReply, nil)
170	addReply := new(Reply)
171	addCall := client.Go("Arith.Add", args, addReply, nil)
172
173	addCall = <-addCall.Done
174	if addCall.Error != nil {
175		t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
176	}
177	if addReply.C != args.A+args.B {
178		t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
179	}
180
181	mulCall = <-mulCall.Done
182	if mulCall.Error != nil {
183		t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
184	}
185	if mulReply.C != args.A*args.B {
186		t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
187	}
188
189	// Error test
190	args = &Args{7, 0}
191	reply = new(Reply)
192	err = client.Call("Arith.Div", args, reply)
193	// expect an error: zero divide
194	if err == nil {
195		t.Error("Div: expected error")
196	} else if err.Error() != "divide by zero" {
197		t.Error("Div: expected divide by zero error; got", err)
198	}
199
200	// Bad type.
201	reply = new(Reply)
202	err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
203	if err == nil {
204		t.Error("expected error calling Arith.Add with wrong arg type")
205	} else if strings.Index(err.Error(), "type") < 0 {
206		t.Error("expected error about type; got", err)
207	}
208
209	// Non-struct argument
210	const Val = 12345
211	str := fmt.Sprint(Val)
212	reply = new(Reply)
213	err = client.Call("Arith.Scan", &str, reply)
214	if err != nil {
215		t.Errorf("Scan: expected no error but got string %q", err.Error())
216	} else if reply.C != Val {
217		t.Errorf("Scan: expected %d got %d", Val, reply.C)
218	}
219
220	// Non-struct reply
221	args = &Args{27, 35}
222	str = ""
223	err = client.Call("Arith.String", args, &str)
224	if err != nil {
225		t.Errorf("String: expected no error but got string %q", err.Error())
226	}
227	expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
228	if str != expect {
229		t.Errorf("String: expected %s got %s", expect, str)
230	}
231
232	args = &Args{7, 8}
233	reply = new(Reply)
234	err = client.Call("Arith.Mul", args, reply)
235	if err != nil {
236		t.Errorf("Mul: expected no error but got string %q", err.Error())
237	}
238	if reply.C != args.A*args.B {
239		t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
240	}
241
242	// ServiceName contain "." character
243	args = &Args{7, 8}
244	reply = new(Reply)
245	err = client.Call("net.rpc.Arith.Add", args, reply)
246	if err != nil {
247		t.Errorf("Add: expected no error but got string %q", err.Error())
248	}
249	if reply.C != args.A+args.B {
250		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
251	}
252}
253
254func testNewServerRPC(t *testing.T, addr string) {
255	client, err := Dial("tcp", addr)
256	if err != nil {
257		t.Fatal("dialing", err)
258	}
259	defer client.Close()
260
261	// Synchronous calls
262	args := &Args{7, 8}
263	reply := new(Reply)
264	err = client.Call("newServer.Arith.Add", args, reply)
265	if err != nil {
266		t.Errorf("Add: expected no error but got string %q", err.Error())
267	}
268	if reply.C != args.A+args.B {
269		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
270	}
271}
272
273func TestHTTP(t *testing.T) {
274	once.Do(startServer)
275	testHTTPRPC(t, "")
276	newOnce.Do(startNewServer)
277	testHTTPRPC(t, newHttpPath)
278}
279
280func testHTTPRPC(t *testing.T, path string) {
281	var client *Client
282	var err error
283	if path == "" {
284		client, err = DialHTTP("tcp", httpServerAddr)
285	} else {
286		client, err = DialHTTPPath("tcp", httpServerAddr, path)
287	}
288	if err != nil {
289		t.Fatal("dialing", err)
290	}
291	defer client.Close()
292
293	// Synchronous calls
294	args := &Args{7, 8}
295	reply := new(Reply)
296	err = client.Call("Arith.Add", args, reply)
297	if err != nil {
298		t.Errorf("Add: expected no error but got string %q", err.Error())
299	}
300	if reply.C != args.A+args.B {
301		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
302	}
303}
304
305// CodecEmulator provides a client-like api and a ServerCodec interface.
306// Can be used to test ServeRequest.
307type CodecEmulator struct {
308	server        *Server
309	serviceMethod string
310	args          *Args
311	reply         *Reply
312	err           error
313}
314
315func (codec *CodecEmulator) Call(serviceMethod string, args *Args, reply *Reply) error {
316	codec.serviceMethod = serviceMethod
317	codec.args = args
318	codec.reply = reply
319	codec.err = nil
320	var serverError error
321	if codec.server == nil {
322		serverError = ServeRequest(codec)
323	} else {
324		serverError = codec.server.ServeRequest(codec)
325	}
326	if codec.err == nil && serverError != nil {
327		codec.err = serverError
328	}
329	return codec.err
330}
331
332func (codec *CodecEmulator) ReadRequestHeader(req *Request) error {
333	req.ServiceMethod = codec.serviceMethod
334	req.Seq = 0
335	return nil
336}
337
338func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error {
339	if codec.args == nil {
340		return io.ErrUnexpectedEOF
341	}
342	*(argv.(*Args)) = *codec.args
343	return nil
344}
345
346func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) error {
347	if resp.Error != "" {
348		codec.err = errors.New(resp.Error)
349	} else {
350		*codec.reply = *(reply.(*Reply))
351	}
352	return nil
353}
354
355func (codec *CodecEmulator) Close() error {
356	return nil
357}
358
359func TestServeRequest(t *testing.T) {
360	once.Do(startServer)
361	testServeRequest(t, nil)
362	newOnce.Do(startNewServer)
363	testServeRequest(t, newServer)
364}
365
366func testServeRequest(t *testing.T, server *Server) {
367	client := CodecEmulator{server: server}
368	defer client.Close()
369
370	args := &Args{7, 8}
371	reply := new(Reply)
372	err := client.Call("Arith.Add", args, reply)
373	if err != nil {
374		t.Errorf("Add: expected no error but got string %q", err.Error())
375	}
376	if reply.C != args.A+args.B {
377		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
378	}
379
380	err = client.Call("Arith.Add", nil, reply)
381	if err == nil {
382		t.Errorf("expected error calling Arith.Add with nil arg")
383	}
384}
385
386type ReplyNotPointer int
387type ArgNotPublic int
388type ReplyNotPublic int
389type NeedsPtrType int
390type local struct{}
391
392func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error {
393	return nil
394}
395
396func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) error {
397	return nil
398}
399
400func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error {
401	return nil
402}
403
404func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error {
405	return nil
406}
407
408// Check that registration handles lots of bad methods and a type with no suitable methods.
409func TestRegistrationError(t *testing.T) {
410	err := Register(new(ReplyNotPointer))
411	if err == nil {
412		t.Error("expected error registering ReplyNotPointer")
413	}
414	err = Register(new(ArgNotPublic))
415	if err == nil {
416		t.Error("expected error registering ArgNotPublic")
417	}
418	err = Register(new(ReplyNotPublic))
419	if err == nil {
420		t.Error("expected error registering ReplyNotPublic")
421	}
422	err = Register(NeedsPtrType(0))
423	if err == nil {
424		t.Error("expected error registering NeedsPtrType")
425	} else if !strings.Contains(err.Error(), "pointer") {
426		t.Error("expected hint when registering NeedsPtrType")
427	}
428}
429
430type WriteFailCodec int
431
432func (WriteFailCodec) WriteRequest(*Request, interface{}) error {
433	// the panic caused by this error used to not unlock a lock.
434	return errors.New("fail")
435}
436
437func (WriteFailCodec) ReadResponseHeader(*Response) error {
438	select {}
439}
440
441func (WriteFailCodec) ReadResponseBody(interface{}) error {
442	select {}
443}
444
445func (WriteFailCodec) Close() error {
446	return nil
447}
448
449func TestSendDeadlock(t *testing.T) {
450	client := NewClientWithCodec(WriteFailCodec(0))
451	defer client.Close()
452
453	done := make(chan bool)
454	go func() {
455		testSendDeadlock(client)
456		testSendDeadlock(client)
457		done <- true
458	}()
459	select {
460	case <-done:
461		return
462	case <-time.After(5 * time.Second):
463		t.Fatal("deadlock")
464	}
465}
466
467func testSendDeadlock(client *Client) {
468	defer func() {
469		recover()
470	}()
471	args := &Args{7, 8}
472	reply := new(Reply)
473	client.Call("Arith.Add", args, reply)
474}
475
476func dialDirect() (*Client, error) {
477	return Dial("tcp", serverAddr)
478}
479
480func dialHTTP() (*Client, error) {
481	return DialHTTP("tcp", httpServerAddr)
482}
483
484func countMallocs(dial func() (*Client, error), t *testing.T) float64 {
485	once.Do(startServer)
486	client, err := dial()
487	if err != nil {
488		t.Fatal("error dialing", err)
489	}
490	defer client.Close()
491
492	args := &Args{7, 8}
493	reply := new(Reply)
494	return testing.AllocsPerRun(100, func() {
495		err := client.Call("Arith.Add", args, reply)
496		if err != nil {
497			t.Errorf("Add: expected no error but got string %q", err.Error())
498		}
499		if reply.C != args.A+args.B {
500			t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
501		}
502	})
503}
504
505func TestCountMallocs(t *testing.T) {
506	if testing.Short() {
507		t.Skip("skipping malloc count in short mode")
508	}
509	if runtime.GOMAXPROCS(0) > 1 {
510		t.Skip("skipping; GOMAXPROCS>1")
511	}
512	fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t))
513}
514
515func TestCountMallocsOverHTTP(t *testing.T) {
516	if testing.Short() {
517		t.Skip("skipping malloc count in short mode")
518	}
519	if runtime.GOMAXPROCS(0) > 1 {
520		t.Skip("skipping; GOMAXPROCS>1")
521	}
522	fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t))
523}
524
// writeCrasher is an io.ReadWriteCloser that simulates a broken connection.
// Every Write fails. Every Read blocks until done yields a value (or is
// closed) and then reports io.EOF. All methods use pointer receivers, matching
// how the type is used (&writeCrasher{...}) and keeping the method set
// consistent.
type writeCrasher struct {
	done chan bool
}

// Close is a no-op.
func (w *writeCrasher) Close() error {
	return nil
}

// Read waits on done and then reports end of stream without filling p.
func (w *writeCrasher) Read(p []byte) (int, error) {
	<-w.done
	return 0, io.EOF
}

// Write always fails without consuming any of p.
func (w *writeCrasher) Write(p []byte) (int, error) {
	return 0, errors.New("fake write failure")
}
541
542func TestClientWriteError(t *testing.T) {
543	w := &writeCrasher{done: make(chan bool)}
544	c := NewClient(w)
545	defer c.Close()
546
547	res := false
548	err := c.Call("foo", 1, &res)
549	if err == nil {
550		t.Fatal("expected error")
551	}
552	if err.Error() != "fake write failure" {
553		t.Error("unexpected value of error:", err)
554	}
555	w.done <- true
556}
557
558func TestTCPClose(t *testing.T) {
559	once.Do(startServer)
560
561	client, err := dialHTTP()
562	if err != nil {
563		t.Fatalf("dialing: %v", err)
564	}
565	defer client.Close()
566
567	args := Args{17, 8}
568	var reply Reply
569	err = client.Call("Arith.Mul", args, &reply)
570	if err != nil {
571		t.Fatal("arith error:", err)
572	}
573	t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply)
574	if reply.C != args.A*args.B {
575		t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B)
576	}
577}
578
579func TestErrorAfterClientClose(t *testing.T) {
580	once.Do(startServer)
581
582	client, err := dialHTTP()
583	if err != nil {
584		t.Fatalf("dialing: %v", err)
585	}
586	err = client.Close()
587	if err != nil {
588		t.Fatal("close error:", err)
589	}
590	err = client.Call("Arith.Add", &Args{7, 9}, new(Reply))
591	if err != ErrShutdown {
592		t.Errorf("Forever: expected ErrShutdown got %v", err)
593	}
594}
595
596func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
597	b.StopTimer()
598	once.Do(startServer)
599	client, err := dial()
600	if err != nil {
601		b.Fatal("error dialing:", err)
602	}
603	defer client.Close()
604
605	// Synchronous calls
606	args := &Args{7, 8}
607	procs := runtime.GOMAXPROCS(-1)
608	N := int32(b.N)
609	var wg sync.WaitGroup
610	wg.Add(procs)
611	b.StartTimer()
612
613	for p := 0; p < procs; p++ {
614		go func() {
615			reply := new(Reply)
616			for atomic.AddInt32(&N, -1) >= 0 {
617				err := client.Call("Arith.Add", args, reply)
618				if err != nil {
619					b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error())
620				}
621				if reply.C != args.A+args.B {
622					b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B)
623				}
624			}
625			wg.Done()
626		}()
627	}
628	wg.Wait()
629}
630
631func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) {
632	const MaxConcurrentCalls = 100
633	b.StopTimer()
634	once.Do(startServer)
635	client, err := dial()
636	if err != nil {
637		b.Fatal("error dialing:", err)
638	}
639	defer client.Close()
640
641	// Asynchronous calls
642	args := &Args{7, 8}
643	procs := 4 * runtime.GOMAXPROCS(-1)
644	send := int32(b.N)
645	recv := int32(b.N)
646	var wg sync.WaitGroup
647	wg.Add(procs)
648	gate := make(chan bool, MaxConcurrentCalls)
649	res := make(chan *Call, MaxConcurrentCalls)
650	b.StartTimer()
651
652	for p := 0; p < procs; p++ {
653		go func() {
654			for atomic.AddInt32(&send, -1) >= 0 {
655				gate <- true
656				reply := new(Reply)
657				client.Go("Arith.Add", args, reply, res)
658			}
659		}()
660		go func() {
661			for call := range res {
662				A := call.Args.(*Args).A
663				B := call.Args.(*Args).B
664				C := call.Reply.(*Reply).C
665				if A+B != C {
666					b.Fatalf("incorrect reply: Add: expected %d got %d", A+B, C)
667				}
668				<-gate
669				if atomic.AddInt32(&recv, -1) == 0 {
670					close(res)
671				}
672			}
673			wg.Done()
674		}()
675	}
676	wg.Wait()
677}
678
679func BenchmarkEndToEnd(b *testing.B) {
680	benchmarkEndToEnd(dialDirect, b)
681}
682
683func BenchmarkEndToEndHTTP(b *testing.B) {
684	benchmarkEndToEnd(dialHTTP, b)
685}
686
687func BenchmarkEndToEndAsync(b *testing.B) {
688	benchmarkEndToEndAsync(dialDirect, b)
689}
690
691func BenchmarkEndToEndAsyncHTTP(b *testing.B) {
692	benchmarkEndToEndAsync(dialHTTP, b)
693}
694