1// Copyright 2009 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package rpc 6 7import ( 8 "errors" 9 "fmt" 10 "io" 11 "log" 12 "net" 13 "net/http/httptest" 14 "runtime" 15 "strings" 16 "sync" 17 "sync/atomic" 18 "testing" 19 "time" 20) 21 22var ( 23 newServer *Server 24 serverAddr, newServerAddr string 25 httpServerAddr string 26 once, newOnce, httpOnce sync.Once 27) 28 29const ( 30 newHttpPath = "/foo" 31) 32 33type Args struct { 34 A, B int 35} 36 37type Reply struct { 38 C int 39} 40 41type Arith int 42 43// Some of Arith's methods have value args, some have pointer args. That's deliberate. 44 45func (t *Arith) Add(args Args, reply *Reply) error { 46 reply.C = args.A + args.B 47 return nil 48} 49 50func (t *Arith) Mul(args *Args, reply *Reply) error { 51 reply.C = args.A * args.B 52 return nil 53} 54 55func (t *Arith) Div(args Args, reply *Reply) error { 56 if args.B == 0 { 57 return errors.New("divide by zero") 58 } 59 reply.C = args.A / args.B 60 return nil 61} 62 63func (t *Arith) String(args *Args, reply *string) error { 64 *reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) 65 return nil 66} 67 68func (t *Arith) Scan(args string, reply *Reply) (err error) { 69 _, err = fmt.Sscan(args, &reply.C) 70 return 71} 72 73func (t *Arith) Error(args *Args, reply *Reply) error { 74 panic("ERROR") 75} 76 77func listenTCP() (net.Listener, string) { 78 l, e := net.Listen("tcp", "127.0.0.1:0") // any available address 79 if e != nil { 80 log.Fatalf("net.Listen tcp :0: %v", e) 81 } 82 return l, l.Addr().String() 83} 84 85func startServer() { 86 Register(new(Arith)) 87 88 var l net.Listener 89 l, serverAddr = listenTCP() 90 log.Println("Test RPC server listening on", serverAddr) 91 go Accept(l) 92 93 HandleHTTP() 94 httpOnce.Do(startHttpServer) 95} 96 97func startNewServer() { 98 newServer = NewServer() 99 newServer.Register(new(Arith)) 100 101 var l net.Listener 102 l, newServerAddr = listenTCP() 103 log.Println("NewServer test RPC server listening on", newServerAddr) 104 go Accept(l) 105 106 newServer.HandleHTTP(newHttpPath, "/bar") 107 httpOnce.Do(startHttpServer) 108} 109 110func startHttpServer() { 111 server := httptest.NewServer(nil) 112 httpServerAddr = server.Listener.Addr().String() 113 log.Println("Test HTTP RPC server listening on", httpServerAddr) 114} 115 116func TestRPC(t *testing.T) { 117 once.Do(startServer) 118 testRPC(t, serverAddr) 119 newOnce.Do(startNewServer) 120 testRPC(t, newServerAddr) 121} 122 123func testRPC(t *testing.T, addr string) { 124 client, err := Dial("tcp", addr) 125 if err != nil { 126 t.Fatal("dialing", err) 127 } 128 129 // Synchronous calls 130 args := &Args{7, 8} 131 reply := new(Reply) 132 err = client.Call("Arith.Add", args, reply) 133 if err != nil { 134 t.Errorf("Add: expected no error but got string %q", err.Error()) 135 } 136 if reply.C != args.A+args.B { 137 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 138 } 139 140 // Nonexistent method 141 args = &Args{7, 0} 142 reply = new(Reply) 143 err = client.Call("Arith.BadOperation", args, reply) 144 // expect an error 145 if err == nil { 146 t.Error("BadOperation: expected error") 147 } else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") { 148 t.Errorf("BadOperation: expected can't find method error; got %q", err) 149 } 150 151 // Unknown service 152 args = &Args{7, 8} 153 reply = new(Reply) 154 err = client.Call("Arith.Unknown", args, reply) 155 if err == nil { 156 t.Error("expected error calling unknown service") 157 } else if strings.Index(err.Error(), "method") < 0 { 158 t.Error("expected error about method; got", err) 159 } 160 161 // Out of order. 162 args = &Args{7, 8} 163 mulReply := new(Reply) 164 mulCall := client.Go("Arith.Mul", args, mulReply, nil) 165 addReply := new(Reply) 166 addCall := client.Go("Arith.Add", args, addReply, nil) 167 168 addCall = <-addCall.Done 169 if addCall.Error != nil { 170 t.Errorf("Add: expected no error but got string %q", addCall.Error.Error()) 171 } 172 if addReply.C != args.A+args.B { 173 t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B) 174 } 175 176 mulCall = <-mulCall.Done 177 if mulCall.Error != nil { 178 t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error()) 179 } 180 if mulReply.C != args.A*args.B { 181 t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B) 182 } 183 184 // Error test 185 args = &Args{7, 0} 186 reply = new(Reply) 187 err = client.Call("Arith.Div", args, reply) 188 // expect an error: zero divide 189 if err == nil { 190 t.Error("Div: expected error") 191 } else if err.Error() != "divide by zero" { 192 t.Error("Div: expected divide by zero error; got", err) 193 } 194 195 // Bad type. 196 reply = new(Reply) 197 err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use 198 if err == nil { 199 t.Error("expected error calling Arith.Add with wrong arg type") 200 } else if strings.Index(err.Error(), "type") < 0 { 201 t.Error("expected error about type; got", err) 202 } 203 204 // Non-struct argument 205 const Val = 12345 206 str := fmt.Sprint(Val) 207 reply = new(Reply) 208 err = client.Call("Arith.Scan", &str, reply) 209 if err != nil { 210 t.Errorf("Scan: expected no error but got string %q", err.Error()) 211 } else if reply.C != Val { 212 t.Errorf("Scan: expected %d got %d", Val, reply.C) 213 } 214 215 // Non-struct reply 216 args = &Args{27, 35} 217 str = "" 218 err = client.Call("Arith.String", args, &str) 219 if err != nil { 220 t.Errorf("String: expected no error but got string %q", err.Error()) 221 } 222 expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) 223 if str != expect { 224 t.Errorf("String: expected %s got %s", expect, str) 225 } 226 227 args = &Args{7, 8} 228 reply = new(Reply) 229 err = client.Call("Arith.Mul", args, reply) 230 if err != nil { 231 t.Errorf("Mul: expected no error but got string %q", err.Error()) 232 } 233 if reply.C != args.A*args.B { 234 t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B) 235 } 236} 237 238func TestHTTP(t *testing.T) { 239 once.Do(startServer) 240 testHTTPRPC(t, "") 241 newOnce.Do(startNewServer) 242 testHTTPRPC(t, newHttpPath) 243} 244 245func testHTTPRPC(t *testing.T, path string) { 246 var client *Client 247 var err error 248 if path == "" { 249 client, err = DialHTTP("tcp", httpServerAddr) 250 } else { 251 client, err = DialHTTPPath("tcp", httpServerAddr, path) 252 } 253 if err != nil { 254 t.Fatal("dialing", err) 255 } 256 257 // Synchronous calls 258 args := &Args{7, 8} 259 reply := new(Reply) 260 err = client.Call("Arith.Add", args, reply) 261 if err != nil { 262 t.Errorf("Add: expected no error but got string %q", err.Error()) 263 } 264 if reply.C != args.A+args.B { 265 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 266 } 267} 268 269// CodecEmulator provides a client-like api and a ServerCodec interface. 270// Can be used to test ServeRequest. 271type CodecEmulator struct { 272 server *Server 273 serviceMethod string 274 args *Args 275 reply *Reply 276 err error 277} 278 279func (codec *CodecEmulator) Call(serviceMethod string, args *Args, reply *Reply) error { 280 codec.serviceMethod = serviceMethod 281 codec.args = args 282 codec.reply = reply 283 codec.err = nil 284 var serverError error 285 if codec.server == nil { 286 serverError = ServeRequest(codec) 287 } else { 288 serverError = codec.server.ServeRequest(codec) 289 } 290 if codec.err == nil && serverError != nil { 291 codec.err = serverError 292 } 293 return codec.err 294} 295 296func (codec *CodecEmulator) ReadRequestHeader(req *Request) error { 297 req.ServiceMethod = codec.serviceMethod 298 req.Seq = 0 299 return nil 300} 301 302func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error { 303 if codec.args == nil { 304 return io.ErrUnexpectedEOF 305 } 306 *(argv.(*Args)) = *codec.args 307 return nil 308} 309 310func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) error { 311 if resp.Error != "" { 312 codec.err = errors.New(resp.Error) 313 } else { 314 *codec.reply = *(reply.(*Reply)) 315 } 316 return nil 317} 318 319func (codec *CodecEmulator) Close() error { 320 return nil 321} 322 323func TestServeRequest(t *testing.T) { 324 once.Do(startServer) 325 testServeRequest(t, nil) 326 newOnce.Do(startNewServer) 327 testServeRequest(t, newServer) 328} 329 330func testServeRequest(t *testing.T, server *Server) { 331 client := CodecEmulator{server: server} 332 333 args := &Args{7, 8} 334 reply := new(Reply) 335 err := client.Call("Arith.Add", args, reply) 336 if err != nil { 337 t.Errorf("Add: expected no error but got string %q", err.Error()) 338 } 339 if reply.C != args.A+args.B { 340 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 341 } 342 343 err = client.Call("Arith.Add", nil, reply) 344 if err == nil { 345 t.Errorf("expected error calling Arith.Add with nil arg") 346 } 347} 348 349type ReplyNotPointer int 350type ArgNotPublic int 351type ReplyNotPublic int 352type NeedsPtrType int 353type local struct{} 354 355func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error { 356 return nil 357} 358 359func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) error { 360 return nil 361} 362 363func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error { 364 return nil 365} 366 367func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error { 368 return nil 369} 370 371// Check that registration handles lots of bad methods and a type with no suitable methods. 372func TestRegistrationError(t *testing.T) { 373 err := Register(new(ReplyNotPointer)) 374 if err == nil { 375 t.Error("expected error registering ReplyNotPointer") 376 } 377 err = Register(new(ArgNotPublic)) 378 if err == nil { 379 t.Error("expected error registering ArgNotPublic") 380 } 381 err = Register(new(ReplyNotPublic)) 382 if err == nil { 383 t.Error("expected error registering ReplyNotPublic") 384 } 385 err = Register(NeedsPtrType(0)) 386 if err == nil { 387 t.Error("expected error registering NeedsPtrType") 388 } else if !strings.Contains(err.Error(), "pointer") { 389 t.Error("expected hint when registering NeedsPtrType") 390 } 391} 392 393type WriteFailCodec int 394 395func (WriteFailCodec) WriteRequest(*Request, interface{}) error { 396 // the panic caused by this error used to not unlock a lock. 397 return errors.New("fail") 398} 399 400func (WriteFailCodec) ReadResponseHeader(*Response) error { 401 select {} 402 panic("unreachable") 403} 404 405func (WriteFailCodec) ReadResponseBody(interface{}) error { 406 select {} 407 panic("unreachable") 408} 409 410func (WriteFailCodec) Close() error { 411 return nil 412} 413 414func TestSendDeadlock(t *testing.T) { 415 client := NewClientWithCodec(WriteFailCodec(0)) 416 417 done := make(chan bool) 418 go func() { 419 testSendDeadlock(client) 420 testSendDeadlock(client) 421 done <- true 422 }() 423 select { 424 case <-done: 425 return 426 case <-time.After(5 * time.Second): 427 t.Fatal("deadlock") 428 } 429} 430 431func testSendDeadlock(client *Client) { 432 defer func() { 433 recover() 434 }() 435 args := &Args{7, 8} 436 reply := new(Reply) 437 client.Call("Arith.Add", args, reply) 438} 439 440func dialDirect() (*Client, error) { 441 return Dial("tcp", serverAddr) 442} 443 444func dialHTTP() (*Client, error) { 445 return DialHTTP("tcp", httpServerAddr) 446} 447 448func countMallocs(dial func() (*Client, error), t *testing.T) uint64 { 449 defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) 450 once.Do(startServer) 451 client, err := dial() 452 if err != nil { 453 t.Fatal("error dialing", err) 454 } 455 args := &Args{7, 8} 456 reply := new(Reply) 457 memstats := new(runtime.MemStats) 458 runtime.ReadMemStats(memstats) 459 mallocs := 0 - memstats.Mallocs 460 const count = 100 461 for i := 0; i < count; i++ { 462 err := client.Call("Arith.Add", args, reply) 463 if err != nil { 464 t.Errorf("Add: expected no error but got string %q", err.Error()) 465 } 466 if reply.C != args.A+args.B { 467 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 468 } 469 } 470 runtime.ReadMemStats(memstats) 471 mallocs += memstats.Mallocs 472 return mallocs / count 473} 474 475func TestCountMallocs(t *testing.T) { 476 fmt.Printf("mallocs per rpc round trip: %d\n", countMallocs(dialDirect, t)) 477} 478 479func TestCountMallocsOverHTTP(t *testing.T) { 480 fmt.Printf("mallocs per HTTP rpc round trip: %d\n", countMallocs(dialHTTP, t)) 481} 482 483type writeCrasher struct { 484 done chan bool 485} 486 487func (writeCrasher) Close() error { 488 return nil 489} 490 491func (w *writeCrasher) Read(p []byte) (int, error) { 492 <-w.done 493 return 0, io.EOF 494} 495 496func (writeCrasher) Write(p []byte) (int, error) { 497 return 0, errors.New("fake write failure") 498} 499 500func TestClientWriteError(t *testing.T) { 501 w := &writeCrasher{done: make(chan bool)} 502 c := NewClient(w) 503 res := false 504 err := c.Call("foo", 1, &res) 505 if err == nil { 506 t.Fatal("expected error") 507 } 508 if err.Error() != "fake write failure" { 509 t.Error("unexpected value of error:", err) 510 } 511 w.done <- true 512} 513 514func TestTCPClose(t *testing.T) { 515 once.Do(startServer) 516 517 client, err := dialHTTP() 518 if err != nil { 519 t.Fatalf("dialing: %v", err) 520 } 521 defer client.Close() 522 523 args := Args{17, 8} 524 var reply Reply 525 err = client.Call("Arith.Mul", args, &reply) 526 if err != nil { 527 t.Fatal("arith error:", err) 528 } 529 t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply) 530 if reply.C != args.A*args.B { 531 t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B) 532 } 533} 534 535func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) { 536 b.StopTimer() 537 once.Do(startServer) 538 client, err := dial() 539 if err != nil { 540 b.Fatal("error dialing:", err) 541 } 542 543 // Synchronous calls 544 args := &Args{7, 8} 545 procs := runtime.GOMAXPROCS(-1) 546 N := int32(b.N) 547 var wg sync.WaitGroup 548 wg.Add(procs) 549 b.StartTimer() 550 551 for p := 0; p < procs; p++ { 552 go func() { 553 reply := new(Reply) 554 for atomic.AddInt32(&N, -1) >= 0 { 555 err := client.Call("Arith.Add", args, reply) 556 if err != nil { 557 b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error()) 558 } 559 if reply.C != args.A+args.B { 560 b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B) 561 } 562 } 563 wg.Done() 564 }() 565 } 566 wg.Wait() 567} 568 569func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) { 570 const MaxConcurrentCalls = 100 571 b.StopTimer() 572 once.Do(startServer) 573 client, err := dial() 574 if err != nil { 575 b.Fatal("error dialing:", err) 576 } 577 578 // Asynchronous calls 579 args := &Args{7, 8} 580 procs := 4 * runtime.GOMAXPROCS(-1) 581 send := int32(b.N) 582 recv := int32(b.N) 583 var wg sync.WaitGroup 584 wg.Add(procs) 585 gate := make(chan bool, MaxConcurrentCalls) 586 res := make(chan *Call, MaxConcurrentCalls) 587 b.StartTimer() 588 589 for p := 0; p < procs; p++ { 590 go func() { 591 for atomic.AddInt32(&send, -1) >= 0 { 592 gate <- true 593 reply := new(Reply) 594 client.Go("Arith.Add", args, reply, res) 595 } 596 }() 597 go func() { 598 for call := range res { 599 A := call.Args.(*Args).A 600 B := call.Args.(*Args).B 601 C := call.Reply.(*Reply).C 602 if A+B != C { 603 b.Fatalf("incorrect reply: Add: expected %d got %d", A+B, C) 604 } 605 <-gate 606 if atomic.AddInt32(&recv, -1) == 0 { 607 close(res) 608 } 609 } 610 wg.Done() 611 }() 612 } 613 wg.Wait() 614} 615 616func BenchmarkEndToEnd(b *testing.B) { 617 benchmarkEndToEnd(dialDirect, b) 618} 619 620func BenchmarkEndToEndHTTP(b *testing.B) { 621 benchmarkEndToEnd(dialHTTP, b) 622} 623 624func BenchmarkEndToEndAsync(b *testing.B) { 625 benchmarkEndToEndAsync(dialDirect, b) 626} 627 628func BenchmarkEndToEndAsyncHTTP(b *testing.B) { 629 benchmarkEndToEndAsync(dialHTTP, b) 630} 631