1// Copyright 2017 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package rpcreplay 16 17import ( 18 "bufio" 19 "context" 20 "encoding/binary" 21 "errors" 22 "fmt" 23 "io" 24 "log" 25 "net" 26 "os" 27 "sync" 28 29 pb "cloud.google.com/go/rpcreplay/proto/rpcreplay" 30 "github.com/golang/protobuf/proto" 31 "github.com/golang/protobuf/ptypes" 32 "github.com/golang/protobuf/ptypes/any" 33 spb "google.golang.org/genproto/googleapis/rpc/status" 34 "google.golang.org/grpc" 35 "google.golang.org/grpc/metadata" 36 "google.golang.org/grpc/status" 37) 38 39// A Recorder records RPCs for later playback. 40type Recorder struct { 41 mu sync.Mutex 42 w *bufio.Writer 43 f *os.File 44 next int 45 err error 46 // BeforeFunc defines a function that can inspect and modify requests and responses 47 // written to the replay file. It does not modify messages sent to the service. 48 // It is run once before a request is written to the replay file, and once before a response 49 // is written to the replay file. 50 // The function is called with the method name and the message that triggered the callback. 51 // If the function returns an error, the error will be returned to the client. 52 // This is only executed for unary RPCs; streaming RPCs are not supported. 53 BeforeFunc func(string, proto.Message) error 54} 55 56// NewRecorder creates a recorder that writes to filename. The file will 57// also store the initial bytes for retrieval during replay. 58// 59// You must call Close on the Recorder to ensure that all data is written. 60func NewRecorder(filename string, initial []byte) (*Recorder, error) { 61 f, err := os.Create(filename) 62 if err != nil { 63 return nil, err 64 } 65 rec, err := NewRecorderWriter(f, initial) 66 if err != nil { 67 _ = f.Close() 68 return nil, err 69 } 70 rec.f = f 71 return rec, nil 72} 73 74// NewRecorderWriter creates a recorder that writes to w. The initial 75// bytes will also be written to w for retrieval during replay. 76// 77// You must call Close on the Recorder to ensure that all data is written. 78func NewRecorderWriter(w io.Writer, initial []byte) (*Recorder, error) { 79 bw := bufio.NewWriter(w) 80 if err := writeHeader(bw, initial); err != nil { 81 return nil, err 82 } 83 return &Recorder{w: bw, next: 1}, nil 84} 85 86// DialOptions returns the options that must be passed to grpc.Dial 87// to enable recording. 88func (r *Recorder) DialOptions() []grpc.DialOption { 89 return []grpc.DialOption{ 90 grpc.WithUnaryInterceptor(r.interceptUnary), 91 grpc.WithStreamInterceptor(r.interceptStream), 92 } 93} 94 95// Close saves any unwritten information. 96func (r *Recorder) Close() error { 97 r.mu.Lock() 98 defer r.mu.Unlock() 99 if r.err != nil { 100 return r.err 101 } 102 err := r.w.Flush() 103 if r.f != nil { 104 if err2 := r.f.Close(); err == nil { 105 err = err2 106 } 107 } 108 return err 109} 110 111// Intercepts all unary (non-stream) RPCs. 112func (r *Recorder) interceptUnary(ctx context.Context, method string, req, res interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 113 ereq := &entry{ 114 kind: pb.Entry_REQUEST, 115 method: method, 116 msg: message{msg: proto.Clone(req.(proto.Message))}, 117 } 118 119 if r.BeforeFunc != nil { 120 if err := r.BeforeFunc(method, ereq.msg.msg); err != nil { 121 return err 122 } 123 } 124 refIndex, err := r.writeEntry(ereq) 125 if err != nil { 126 return err 127 } 128 ierr := invoker(ctx, method, req, res, cc, opts...) 129 eres := &entry{ 130 kind: pb.Entry_RESPONSE, 131 refIndex: refIndex, 132 } 133 // If the error is not a gRPC status, then something more 134 // serious is wrong. More significantly, we have no way 135 // of serializing an arbitrary error. So just return it 136 // without recording the response. 137 if _, ok := status.FromError(ierr); !ok { 138 r.mu.Lock() 139 r.err = fmt.Errorf("saw non-status error in %s response: %v (%T)", method, ierr, ierr) 140 r.mu.Unlock() 141 return ierr 142 } 143 eres.msg.set(proto.Clone(res.(proto.Message)), ierr) 144 if r.BeforeFunc != nil { 145 if err := r.BeforeFunc(method, eres.msg.msg); err != nil { 146 return err 147 } 148 } 149 if _, err := r.writeEntry(eres); err != nil { 150 return err 151 } 152 return ierr 153} 154 155func (r *Recorder) writeEntry(e *entry) (int, error) { 156 r.mu.Lock() 157 defer r.mu.Unlock() 158 if r.err != nil { 159 return 0, r.err 160 } 161 err := writeEntry(r.w, e) 162 if err != nil { 163 r.err = err 164 return 0, err 165 } 166 n := r.next 167 r.next++ 168 return n, nil 169} 170 171func (r *Recorder) interceptStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { 172 cstream, serr := streamer(ctx, desc, cc, method, opts...) 173 e := &entry{ 174 kind: pb.Entry_CREATE_STREAM, 175 method: method, 176 } 177 e.msg.set(nil, serr) 178 refIndex, err := r.writeEntry(e) 179 if err != nil { 180 return nil, err 181 } 182 return &recClientStream{ 183 ctx: ctx, 184 rec: r, 185 cstream: cstream, 186 refIndex: refIndex, 187 }, serr 188} 189 190// A recClientStream implements the gprc.ClientStream interface. 191// It behaves exactly like the default ClientStream, but also 192// records all messages sent and received. 193type recClientStream struct { 194 ctx context.Context 195 rec *Recorder 196 cstream grpc.ClientStream 197 refIndex int 198} 199 200func (rcs *recClientStream) Context() context.Context { return rcs.ctx } 201 202func (rcs *recClientStream) SendMsg(m interface{}) error { 203 serr := rcs.cstream.SendMsg(m) 204 e := &entry{ 205 kind: pb.Entry_SEND, 206 refIndex: rcs.refIndex, 207 } 208 e.msg.set(m, serr) 209 if _, err := rcs.rec.writeEntry(e); err != nil { 210 return err 211 } 212 return serr 213} 214 215func (rcs *recClientStream) RecvMsg(m interface{}) error { 216 serr := rcs.cstream.RecvMsg(m) 217 e := &entry{ 218 kind: pb.Entry_RECV, 219 refIndex: rcs.refIndex, 220 } 221 e.msg.set(m, serr) 222 if _, err := rcs.rec.writeEntry(e); err != nil { 223 return err 224 } 225 return serr 226} 227 228func (rcs *recClientStream) Header() (metadata.MD, error) { 229 // TODO(jba): record. 230 return rcs.cstream.Header() 231} 232 233func (rcs *recClientStream) Trailer() metadata.MD { 234 // TODO(jba): record. 235 return rcs.cstream.Trailer() 236} 237 238func (rcs *recClientStream) CloseSend() error { 239 // TODO(jba): record. 240 return rcs.cstream.CloseSend() 241} 242 243// A Replayer replays a set of RPCs saved by a Recorder. 244type Replayer struct { 245 initial []byte // initial state 246 log func(format string, v ...interface{}) // for debugging 247 248 mu sync.Mutex 249 calls []*call 250 streams []*stream 251 // BeforeFunc defines a function that can inspect and modify requests before they 252 // are matched for responses from the replay file. 253 // The function is called with the method name and the message that triggered the callback. 254 // If the function returns an error, the error will be returned to the client. 255 // This is only executed for unary RPCs; streaming RPCs are not supported. 256 BeforeFunc func(string, proto.Message) error 257} 258 259// A call represents a unary RPC, with a request and response (or error). 260type call struct { 261 method string 262 request proto.Message 263 response message 264} 265 266// A stream represents a gRPC stream, with an initial create-stream call, followed by 267// zero or more sends and/or receives. 268type stream struct { 269 method string 270 createIndex int 271 createErr error // error from create call 272 sends []message 273 recvs []message 274} 275 276// NewReplayer creates a Replayer that reads from filename. 277func NewReplayer(filename string) (*Replayer, error) { 278 f, err := os.Open(filename) 279 if err != nil { 280 return nil, err 281 } 282 defer f.Close() 283 return NewReplayerReader(f) 284} 285 286// NewReplayerReader creates a Replayer that reads from r. 287func NewReplayerReader(r io.Reader) (*Replayer, error) { 288 rep := &Replayer{ 289 log: func(string, ...interface{}) {}, 290 } 291 if err := rep.read(r); err != nil { 292 return nil, err 293 } 294 return rep, nil 295} 296 297// read reads the stream of recorded entries. 298// It matches requests with responses, with each pair grouped 299// into a call struct. 300func (rep *Replayer) read(r io.Reader) error { 301 r = bufio.NewReader(r) 302 bytes, err := readHeader(r) 303 if err != nil { 304 return err 305 } 306 rep.initial = bytes 307 308 callsByIndex := map[int]*call{} 309 streamsByIndex := map[int]*stream{} 310 for i := 1; ; i++ { 311 e, err := readEntry(r) 312 if err != nil { 313 return err 314 } 315 if e == nil { 316 break 317 } 318 switch e.kind { 319 case pb.Entry_REQUEST: 320 callsByIndex[i] = &call{ 321 method: e.method, 322 request: e.msg.msg, 323 } 324 325 case pb.Entry_RESPONSE: 326 call := callsByIndex[e.refIndex] 327 if call == nil { 328 return fmt.Errorf("replayer: no request for response #%d", i) 329 } 330 delete(callsByIndex, e.refIndex) 331 call.response = e.msg 332 rep.calls = append(rep.calls, call) 333 334 case pb.Entry_CREATE_STREAM: 335 s := &stream{method: e.method, createIndex: i} 336 s.createErr = e.msg.err 337 streamsByIndex[i] = s 338 rep.streams = append(rep.streams, s) 339 340 case pb.Entry_SEND: 341 s := streamsByIndex[e.refIndex] 342 if s == nil { 343 return fmt.Errorf("replayer: no stream for send #%d", i) 344 } 345 s.sends = append(s.sends, e.msg) 346 347 case pb.Entry_RECV: 348 s := streamsByIndex[e.refIndex] 349 if s == nil { 350 return fmt.Errorf("replayer: no stream for recv #%d", i) 351 } 352 s.recvs = append(s.recvs, e.msg) 353 354 default: 355 return fmt.Errorf("replayer: unknown kind %s", e.kind) 356 } 357 } 358 if len(callsByIndex) > 0 { 359 return fmt.Errorf("replayer: %d unmatched requests", len(callsByIndex)) 360 } 361 return nil 362} 363 364// DialOptions returns the options that must be passed to grpc.Dial 365// to enable replaying. 366func (rep *Replayer) DialOptions() []grpc.DialOption { 367 return []grpc.DialOption{ 368 // On replay, we make no RPCs, which means the connection may be closed 369 // before the normally async Dial completes. Making the Dial synchronous 370 // fixes that. 371 grpc.WithBlock(), 372 grpc.WithUnaryInterceptor(rep.interceptUnary), 373 grpc.WithStreamInterceptor(rep.interceptStream), 374 } 375} 376 377// Connection returns a fake gRPC connection suitable for replaying. 378func (rep *Replayer) Connection() (*grpc.ClientConn, error) { 379 // We don't need an actual connection, not even a loopback one. 380 // But we do need something to attach gRPC interceptors to. 381 // So we start a local server and connect to it, then close it down. 382 srv := grpc.NewServer() 383 l, err := net.Listen("tcp", "localhost:0") 384 if err != nil { 385 return nil, err 386 } 387 go func() { 388 if err := srv.Serve(l); err != nil { 389 panic(err) // we should never get an error because we just connect and stop 390 } 391 }() 392 conn, err := grpc.Dial(l.Addr().String(), 393 append([]grpc.DialOption{grpc.WithInsecure()}, rep.DialOptions()...)...) 394 if err != nil { 395 return nil, err 396 } 397 conn.Close() 398 srv.Stop() 399 return conn, nil 400} 401 402// Initial returns the initial state saved by the Recorder. 403func (rep *Replayer) Initial() []byte { return rep.initial } 404 405// SetLogFunc sets a function to be used for debug logging. The function 406// should be safe to be called from multiple goroutines. 407func (rep *Replayer) SetLogFunc(f func(format string, v ...interface{})) { 408 rep.log = f 409} 410 411// Close closes the Replayer. 412func (rep *Replayer) Close() error { 413 return nil 414} 415 416func (rep *Replayer) interceptUnary(_ context.Context, method string, req, res interface{}, _ *grpc.ClientConn, _ grpc.UnaryInvoker, _ ...grpc.CallOption) error { 417 mreq := req.(proto.Message) 418 if rep.BeforeFunc != nil { 419 if err := rep.BeforeFunc(method, mreq); err != nil { 420 return err 421 } 422 } 423 rep.log("request %s (%s)", method, req) 424 call := rep.extractCall(method, mreq) 425 if call == nil { 426 return fmt.Errorf("replayer: request not found: %s", mreq) 427 } 428 rep.log("returning %v", call.response) 429 if call.response.err != nil { 430 return call.response.err 431 } 432 proto.Merge(res.(proto.Message), call.response.msg) // copy msg into res 433 return nil 434} 435 436func (rep *Replayer) interceptStream(ctx context.Context, _ *grpc.StreamDesc, _ *grpc.ClientConn, method string, _ grpc.Streamer, _ ...grpc.CallOption) (grpc.ClientStream, error) { 437 rep.log("create-stream %s", method) 438 return &repClientStream{ctx: ctx, rep: rep, method: method}, nil 439} 440 441type repClientStream struct { 442 ctx context.Context 443 rep *Replayer 444 method string 445 str *stream 446} 447 448func (rcs *repClientStream) Context() context.Context { return rcs.ctx } 449 450func (rcs *repClientStream) SendMsg(req interface{}) error { 451 if rcs.str == nil { 452 if err := rcs.setStream(rcs.method, req.(proto.Message)); err != nil { 453 return err 454 } 455 } 456 if len(rcs.str.sends) == 0 { 457 return fmt.Errorf("replayer: no more sends for stream %s, created at index %d", 458 rcs.str.method, rcs.str.createIndex) 459 } 460 // TODO(jba): Do not assume that the sends happen in the same order on replay. 461 msg := rcs.str.sends[0] 462 rcs.str.sends = rcs.str.sends[1:] 463 return msg.err 464} 465 466func (rcs *repClientStream) setStream(method string, req proto.Message) error { 467 str := rcs.rep.extractStream(method, req) 468 if str == nil { 469 return fmt.Errorf("replayer: stream not found for method %s and request %v", method, req) 470 } 471 if str.createErr != nil { 472 return str.createErr 473 } 474 rcs.str = str 475 return nil 476} 477 478func (rcs *repClientStream) RecvMsg(m interface{}) error { 479 if rcs.str == nil { 480 // Receive before send; fall back to matching stream by method only. 481 if err := rcs.setStream(rcs.method, nil); err != nil { 482 return err 483 } 484 } 485 if len(rcs.str.recvs) == 0 { 486 return fmt.Errorf("replayer: no more receives for stream %s, created at index %d", 487 rcs.str.method, rcs.str.createIndex) 488 } 489 msg := rcs.str.recvs[0] 490 rcs.str.recvs = rcs.str.recvs[1:] 491 if msg.err != nil { 492 return msg.err 493 } 494 proto.Merge(m.(proto.Message), msg.msg) // copy msg into m 495 return nil 496} 497 498func (rcs *repClientStream) Header() (metadata.MD, error) { 499 log.Printf("replay: stream metadata not supported") 500 return nil, nil 501} 502 503func (rcs *repClientStream) Trailer() metadata.MD { 504 log.Printf("replay: stream metadata not supported") 505 return nil 506} 507 508func (rcs *repClientStream) CloseSend() error { 509 return nil 510} 511 512// extractCall finds the first call in the list with the same method 513// and request. It returns nil if it can't find such a call. 514func (rep *Replayer) extractCall(method string, req proto.Message) *call { 515 rep.mu.Lock() 516 defer rep.mu.Unlock() 517 for i, call := range rep.calls { 518 if call == nil { 519 continue 520 } 521 if method == call.method && proto.Equal(req, call.request) { 522 rep.calls[i] = nil // nil out this call so we don't reuse it 523 return call 524 } 525 } 526 return nil 527} 528 529// extractStream find the first stream in the list with the same method and the same 530// first request sent. If req is nil, that means a receive occurred before a send, so 531// it matches only on method. 532func (rep *Replayer) extractStream(method string, req proto.Message) *stream { 533 rep.mu.Lock() 534 defer rep.mu.Unlock() 535 for i, stream := range rep.streams { 536 // Skip stream if it is nil (already extracted) or its method doesn't match. 537 if stream == nil || stream.method != method { 538 continue 539 } 540 // If there is a first request, skip stream if it has no requests or its first 541 // request doesn't match. 542 if req != nil && len(stream.sends) > 0 && !proto.Equal(req, stream.sends[0].msg) { 543 continue 544 } 545 rep.streams[i] = nil // nil out this stream so we don't reuse it 546 return stream 547 } 548 return nil 549} 550 551// Fprint reads the entries from filename and writes them to w in human-readable form. 552// It is intended for debugging. 553func Fprint(w io.Writer, filename string) error { 554 f, err := os.Open(filename) 555 if err != nil { 556 return err 557 } 558 defer f.Close() 559 return FprintReader(w, f) 560} 561 562// FprintReader reads the entries from r and writes them to w in human-readable form. 563// It is intended for debugging. 564func FprintReader(w io.Writer, r io.Reader) error { 565 initial, err := readHeader(r) 566 if err != nil { 567 return err 568 } 569 fmt.Fprintf(w, "initial state: %q\n", string(initial)) 570 for i := 1; ; i++ { 571 e, err := readEntry(r) 572 if err != nil { 573 return err 574 } 575 if e == nil { 576 return nil 577 } 578 579 fmt.Fprintf(w, "#%d: kind: %s, method: %s, ref index: %d", i, e.kind, e.method, e.refIndex) 580 switch { 581 case e.msg.msg != nil: 582 fmt.Fprintf(w, ", message:\n") 583 if err := proto.MarshalText(w, e.msg.msg); err != nil { 584 return err 585 } 586 case e.msg.err != nil: 587 fmt.Fprintf(w, ", error: %v\n", e.msg.err) 588 default: 589 fmt.Fprintln(w) 590 } 591 } 592} 593 594// An entry holds one gRPC action (request, response, etc.). 595type entry struct { 596 kind pb.Entry_Kind 597 method string 598 msg message 599 refIndex int // index of corresponding request or create-stream 600} 601 602func (e1 *entry) equal(e2 *entry) bool { 603 if e1 == nil && e2 == nil { 604 return true 605 } 606 if e1 == nil || e2 == nil { 607 return false 608 } 609 return e1.kind == e2.kind && 610 e1.method == e2.method && 611 proto.Equal(e1.msg.msg, e2.msg.msg) && 612 errEqual(e1.msg.err, e2.msg.err) && 613 e1.refIndex == e2.refIndex 614} 615 616func errEqual(e1, e2 error) bool { 617 if e1 == e2 { 618 return true 619 } 620 s1, ok1 := status.FromError(e1) 621 s2, ok2 := status.FromError(e2) 622 if !ok1 || !ok2 { 623 return false 624 } 625 return proto.Equal(s1.Proto(), s2.Proto()) 626} 627 628// message holds either a single proto.Message or an error. 629type message struct { 630 msg proto.Message 631 err error 632} 633 634func (m *message) set(msg interface{}, err error) { 635 m.err = err 636 if err != io.EOF && msg != nil { 637 m.msg = msg.(proto.Message) 638 } 639} 640 641// File format: 642// header 643// sequence of Entry protos 644// 645// Header format: 646// magic string 647// a record containing the bytes of the initial state 648 649const magic = "RPCReplay" 650 651func writeHeader(w io.Writer, initial []byte) error { 652 if _, err := io.WriteString(w, magic); err != nil { 653 return err 654 } 655 return writeRecord(w, initial) 656} 657 658func readHeader(r io.Reader) ([]byte, error) { 659 var buf [len(magic)]byte 660 if _, err := io.ReadFull(r, buf[:]); err != nil { 661 if err == io.EOF { 662 err = errors.New("rpcreplay: empty replay file") 663 } 664 return nil, err 665 } 666 if string(buf[:]) != magic { 667 return nil, errors.New("rpcreplay: not a replay file (does not begin with magic string)") 668 } 669 bytes, err := readRecord(r) 670 if err == io.EOF { 671 err = errors.New("rpcreplay: missing initial state") 672 } 673 return bytes, err 674} 675 676func writeEntry(w io.Writer, e *entry) error { 677 var m proto.Message 678 if e.msg.err != nil && e.msg.err != io.EOF { 679 s, ok := status.FromError(e.msg.err) 680 if !ok { 681 return fmt.Errorf("rpcreplay: error %v is not a Status", e.msg.err) 682 } 683 m = s.Proto() 684 } else { 685 m = e.msg.msg 686 } 687 var a *any.Any 688 var err error 689 if m != nil { 690 a, err = ptypes.MarshalAny(m) 691 if err != nil { 692 return err 693 } 694 } 695 pe := &pb.Entry{ 696 Kind: e.kind, 697 Method: e.method, 698 Message: a, 699 IsError: e.msg.err != nil, 700 RefIndex: int32(e.refIndex), 701 } 702 bytes, err := proto.Marshal(pe) 703 if err != nil { 704 return err 705 } 706 return writeRecord(w, bytes) 707} 708 709func readEntry(r io.Reader) (*entry, error) { 710 buf, err := readRecord(r) 711 if err == io.EOF { 712 return nil, nil 713 } 714 if err != nil { 715 return nil, err 716 } 717 var pe pb.Entry 718 if err := proto.Unmarshal(buf, &pe); err != nil { 719 return nil, err 720 } 721 var msg message 722 if pe.Message != nil { 723 var any ptypes.DynamicAny 724 if err := ptypes.UnmarshalAny(pe.Message, &any); err != nil { 725 return nil, err 726 } 727 if pe.IsError { 728 msg.err = status.ErrorProto(any.Message.(*spb.Status)) 729 } else { 730 msg.msg = any.Message 731 } 732 } else if pe.IsError { 733 msg.err = io.EOF 734 } else if pe.Kind != pb.Entry_CREATE_STREAM { 735 return nil, errors.New("rpcreplay: entry with nil message and false is_error") 736 } 737 return &entry{ 738 kind: pe.Kind, 739 method: pe.Method, 740 msg: msg, 741 refIndex: int(pe.RefIndex), 742 }, nil 743} 744 745// A record consists of an unsigned 32-bit little-endian length L followed by L 746// bytes. 747 748func writeRecord(w io.Writer, data []byte) error { 749 if err := binary.Write(w, binary.LittleEndian, uint32(len(data))); err != nil { 750 return err 751 } 752 _, err := w.Write(data) 753 return err 754} 755 756func readRecord(r io.Reader) ([]byte, error) { 757 var size uint32 758 if err := binary.Read(r, binary.LittleEndian, &size); err != nil { 759 return nil, err 760 } 761 buf := make([]byte, size) 762 if _, err := io.ReadFull(r, buf); err != nil { 763 return nil, err 764 } 765 return buf, nil 766} 767