1// Copyright 2017 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package rpcreplay
16
17import (
18	"bufio"
19	"context"
20	"encoding/binary"
21	"errors"
22	"fmt"
23	"io"
24	"log"
25	"net"
26	"os"
27	"sync"
28
29	pb "cloud.google.com/go/rpcreplay/proto/rpcreplay"
30	"github.com/golang/protobuf/proto"
31	"github.com/golang/protobuf/ptypes"
32	"github.com/golang/protobuf/ptypes/any"
33	spb "google.golang.org/genproto/googleapis/rpc/status"
34	"google.golang.org/grpc"
35	"google.golang.org/grpc/metadata"
36	"google.golang.org/grpc/status"
37)
38
39// A Recorder records RPCs for later playback.
40type Recorder struct {
41	mu   sync.Mutex
42	w    *bufio.Writer
43	f    *os.File
44	next int
45	err  error
46	// BeforeFunc defines a function that can inspect and modify requests and responses
47	// written to the replay file. It does not modify messages sent to the service.
48	// It is run once before a request is written to the replay file, and once before a response
49	// is written to the replay file.
50	// The function is called with the method name and the message that triggered the callback.
51	// If the function returns an error, the error will be returned to the client.
52	// This is only executed for unary RPCs; streaming RPCs are not supported.
53	BeforeFunc func(string, proto.Message) error
54}
55
56// NewRecorder creates a recorder that writes to filename. The file will
57// also store the initial bytes for retrieval during replay.
58//
59// You must call Close on the Recorder to ensure that all data is written.
60func NewRecorder(filename string, initial []byte) (*Recorder, error) {
61	f, err := os.Create(filename)
62	if err != nil {
63		return nil, err
64	}
65	rec, err := NewRecorderWriter(f, initial)
66	if err != nil {
67		_ = f.Close()
68		return nil, err
69	}
70	rec.f = f
71	return rec, nil
72}
73
74// NewRecorderWriter creates a recorder that writes to w. The initial
75// bytes will also be written to w for retrieval during replay.
76//
77// You must call Close on the Recorder to ensure that all data is written.
78func NewRecorderWriter(w io.Writer, initial []byte) (*Recorder, error) {
79	bw := bufio.NewWriter(w)
80	if err := writeHeader(bw, initial); err != nil {
81		return nil, err
82	}
83	return &Recorder{w: bw, next: 1}, nil
84}
85
86// DialOptions returns the options that must be passed to grpc.Dial
87// to enable recording.
88func (r *Recorder) DialOptions() []grpc.DialOption {
89	return []grpc.DialOption{
90		grpc.WithUnaryInterceptor(r.interceptUnary),
91		grpc.WithStreamInterceptor(r.interceptStream),
92	}
93}
94
95// Close saves any unwritten information.
96func (r *Recorder) Close() error {
97	r.mu.Lock()
98	defer r.mu.Unlock()
99	if r.err != nil {
100		return r.err
101	}
102	err := r.w.Flush()
103	if r.f != nil {
104		if err2 := r.f.Close(); err == nil {
105			err = err2
106		}
107	}
108	return err
109}
110
111// Intercepts all unary (non-stream) RPCs.
112func (r *Recorder) interceptUnary(ctx context.Context, method string, req, res interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
113	ereq := &entry{
114		kind:   pb.Entry_REQUEST,
115		method: method,
116		msg:    message{msg: proto.Clone(req.(proto.Message))},
117	}
118
119	if r.BeforeFunc != nil {
120		if err := r.BeforeFunc(method, ereq.msg.msg); err != nil {
121			return err
122		}
123	}
124	refIndex, err := r.writeEntry(ereq)
125	if err != nil {
126		return err
127	}
128	ierr := invoker(ctx, method, req, res, cc, opts...)
129	eres := &entry{
130		kind:     pb.Entry_RESPONSE,
131		refIndex: refIndex,
132	}
133	// If the error is not a gRPC status, then something more
134	// serious is wrong. More significantly, we have no way
135	// of serializing an arbitrary error. So just return it
136	// without recording the response.
137	if _, ok := status.FromError(ierr); !ok {
138		r.mu.Lock()
139		r.err = fmt.Errorf("saw non-status error in %s response: %v (%T)", method, ierr, ierr)
140		r.mu.Unlock()
141		return ierr
142	}
143	eres.msg.set(proto.Clone(res.(proto.Message)), ierr)
144	if r.BeforeFunc != nil {
145		if err := r.BeforeFunc(method, eres.msg.msg); err != nil {
146			return err
147		}
148	}
149	if _, err := r.writeEntry(eres); err != nil {
150		return err
151	}
152	return ierr
153}
154
155func (r *Recorder) writeEntry(e *entry) (int, error) {
156	r.mu.Lock()
157	defer r.mu.Unlock()
158	if r.err != nil {
159		return 0, r.err
160	}
161	err := writeEntry(r.w, e)
162	if err != nil {
163		r.err = err
164		return 0, err
165	}
166	n := r.next
167	r.next++
168	return n, nil
169}
170
171func (r *Recorder) interceptStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
172	cstream, serr := streamer(ctx, desc, cc, method, opts...)
173	e := &entry{
174		kind:   pb.Entry_CREATE_STREAM,
175		method: method,
176	}
177	e.msg.set(nil, serr)
178	refIndex, err := r.writeEntry(e)
179	if err != nil {
180		return nil, err
181	}
182	return &recClientStream{
183		ctx:      ctx,
184		rec:      r,
185		cstream:  cstream,
186		refIndex: refIndex,
187	}, serr
188}
189
190// A recClientStream implements the gprc.ClientStream interface.
191// It behaves exactly like the default ClientStream, but also
192// records all messages sent and received.
193type recClientStream struct {
194	ctx      context.Context
195	rec      *Recorder
196	cstream  grpc.ClientStream
197	refIndex int
198}
199
200func (rcs *recClientStream) Context() context.Context { return rcs.ctx }
201
202func (rcs *recClientStream) SendMsg(m interface{}) error {
203	serr := rcs.cstream.SendMsg(m)
204	e := &entry{
205		kind:     pb.Entry_SEND,
206		refIndex: rcs.refIndex,
207	}
208	e.msg.set(m, serr)
209	if _, err := rcs.rec.writeEntry(e); err != nil {
210		return err
211	}
212	return serr
213}
214
215func (rcs *recClientStream) RecvMsg(m interface{}) error {
216	serr := rcs.cstream.RecvMsg(m)
217	e := &entry{
218		kind:     pb.Entry_RECV,
219		refIndex: rcs.refIndex,
220	}
221	e.msg.set(m, serr)
222	if _, err := rcs.rec.writeEntry(e); err != nil {
223		return err
224	}
225	return serr
226}
227
228func (rcs *recClientStream) Header() (metadata.MD, error) {
229	// TODO(jba): record.
230	return rcs.cstream.Header()
231}
232
233func (rcs *recClientStream) Trailer() metadata.MD {
234	// TODO(jba): record.
235	return rcs.cstream.Trailer()
236}
237
238func (rcs *recClientStream) CloseSend() error {
239	// TODO(jba): record.
240	return rcs.cstream.CloseSend()
241}
242
243// A Replayer replays a set of RPCs saved by a Recorder.
244type Replayer struct {
245	initial []byte                                // initial state
246	log     func(format string, v ...interface{}) // for debugging
247
248	mu      sync.Mutex
249	calls   []*call
250	streams []*stream
251	// BeforeFunc defines a function that can inspect and modify requests before they
252	// are matched for responses from the replay file.
253	// The function is called with the method name and the message that triggered the callback.
254	// If the function returns an error, the error will be returned to the client.
255	// This is only executed for unary RPCs; streaming RPCs are not supported.
256	BeforeFunc func(string, proto.Message) error
257}
258
259// A call represents a unary RPC, with a request and response (or error).
260type call struct {
261	method   string
262	request  proto.Message
263	response message
264}
265
266// A stream represents a gRPC stream, with an initial create-stream call, followed by
267// zero or more sends and/or receives.
268type stream struct {
269	method      string
270	createIndex int
271	createErr   error // error from create call
272	sends       []message
273	recvs       []message
274}
275
276// NewReplayer creates a Replayer that reads from filename.
277func NewReplayer(filename string) (*Replayer, error) {
278	f, err := os.Open(filename)
279	if err != nil {
280		return nil, err
281	}
282	defer f.Close()
283	return NewReplayerReader(f)
284}
285
286// NewReplayerReader creates a Replayer that reads from r.
287func NewReplayerReader(r io.Reader) (*Replayer, error) {
288	rep := &Replayer{
289		log: func(string, ...interface{}) {},
290	}
291	if err := rep.read(r); err != nil {
292		return nil, err
293	}
294	return rep, nil
295}
296
297// read reads the stream of recorded entries.
298// It matches requests with responses, with each pair grouped
299// into a call struct.
300func (rep *Replayer) read(r io.Reader) error {
301	r = bufio.NewReader(r)
302	bytes, err := readHeader(r)
303	if err != nil {
304		return err
305	}
306	rep.initial = bytes
307
308	callsByIndex := map[int]*call{}
309	streamsByIndex := map[int]*stream{}
310	for i := 1; ; i++ {
311		e, err := readEntry(r)
312		if err != nil {
313			return err
314		}
315		if e == nil {
316			break
317		}
318		switch e.kind {
319		case pb.Entry_REQUEST:
320			callsByIndex[i] = &call{
321				method:  e.method,
322				request: e.msg.msg,
323			}
324
325		case pb.Entry_RESPONSE:
326			call := callsByIndex[e.refIndex]
327			if call == nil {
328				return fmt.Errorf("replayer: no request for response #%d", i)
329			}
330			delete(callsByIndex, e.refIndex)
331			call.response = e.msg
332			rep.calls = append(rep.calls, call)
333
334		case pb.Entry_CREATE_STREAM:
335			s := &stream{method: e.method, createIndex: i}
336			s.createErr = e.msg.err
337			streamsByIndex[i] = s
338			rep.streams = append(rep.streams, s)
339
340		case pb.Entry_SEND:
341			s := streamsByIndex[e.refIndex]
342			if s == nil {
343				return fmt.Errorf("replayer: no stream for send #%d", i)
344			}
345			s.sends = append(s.sends, e.msg)
346
347		case pb.Entry_RECV:
348			s := streamsByIndex[e.refIndex]
349			if s == nil {
350				return fmt.Errorf("replayer: no stream for recv #%d", i)
351			}
352			s.recvs = append(s.recvs, e.msg)
353
354		default:
355			return fmt.Errorf("replayer: unknown kind %s", e.kind)
356		}
357	}
358	if len(callsByIndex) > 0 {
359		return fmt.Errorf("replayer: %d unmatched requests", len(callsByIndex))
360	}
361	return nil
362}
363
364// DialOptions returns the options that must be passed to grpc.Dial
365// to enable replaying.
366func (rep *Replayer) DialOptions() []grpc.DialOption {
367	return []grpc.DialOption{
368		// On replay, we make no RPCs, which means the connection may be closed
369		// before the normally async Dial completes. Making the Dial synchronous
370		// fixes that.
371		grpc.WithBlock(),
372		grpc.WithUnaryInterceptor(rep.interceptUnary),
373		grpc.WithStreamInterceptor(rep.interceptStream),
374	}
375}
376
377// Connection returns a fake gRPC connection suitable for replaying.
378func (rep *Replayer) Connection() (*grpc.ClientConn, error) {
379	// We don't need an actual connection, not even a loopback one.
380	// But we do need something to attach gRPC interceptors to.
381	// So we start a local server and connect to it, then close it down.
382	srv := grpc.NewServer()
383	l, err := net.Listen("tcp", "localhost:0")
384	if err != nil {
385		return nil, err
386	}
387	go func() {
388		if err := srv.Serve(l); err != nil {
389			panic(err) // we should never get an error because we just connect and stop
390		}
391	}()
392	conn, err := grpc.Dial(l.Addr().String(),
393		append([]grpc.DialOption{grpc.WithInsecure()}, rep.DialOptions()...)...)
394	if err != nil {
395		return nil, err
396	}
397	conn.Close()
398	srv.Stop()
399	return conn, nil
400}
401
402// Initial returns the initial state saved by the Recorder.
403func (rep *Replayer) Initial() []byte { return rep.initial }
404
405// SetLogFunc sets a function to be used for debug logging. The function
406// should be safe to be called from multiple goroutines.
407func (rep *Replayer) SetLogFunc(f func(format string, v ...interface{})) {
408	rep.log = f
409}
410
411// Close closes the Replayer.
412func (rep *Replayer) Close() error {
413	return nil
414}
415
416func (rep *Replayer) interceptUnary(_ context.Context, method string, req, res interface{}, _ *grpc.ClientConn, _ grpc.UnaryInvoker, _ ...grpc.CallOption) error {
417	mreq := req.(proto.Message)
418	if rep.BeforeFunc != nil {
419		if err := rep.BeforeFunc(method, mreq); err != nil {
420			return err
421		}
422	}
423	rep.log("request %s (%s)", method, req)
424	call := rep.extractCall(method, mreq)
425	if call == nil {
426		return fmt.Errorf("replayer: request not found: %s", mreq)
427	}
428	rep.log("returning %v", call.response)
429	if call.response.err != nil {
430		return call.response.err
431	}
432	proto.Merge(res.(proto.Message), call.response.msg) // copy msg into res
433	return nil
434}
435
436func (rep *Replayer) interceptStream(ctx context.Context, _ *grpc.StreamDesc, _ *grpc.ClientConn, method string, _ grpc.Streamer, _ ...grpc.CallOption) (grpc.ClientStream, error) {
437	rep.log("create-stream %s", method)
438	return &repClientStream{ctx: ctx, rep: rep, method: method}, nil
439}
440
441type repClientStream struct {
442	ctx    context.Context
443	rep    *Replayer
444	method string
445	str    *stream
446}
447
448func (rcs *repClientStream) Context() context.Context { return rcs.ctx }
449
450func (rcs *repClientStream) SendMsg(req interface{}) error {
451	if rcs.str == nil {
452		if err := rcs.setStream(rcs.method, req.(proto.Message)); err != nil {
453			return err
454		}
455	}
456	if len(rcs.str.sends) == 0 {
457		return fmt.Errorf("replayer: no more sends for stream %s, created at index %d",
458			rcs.str.method, rcs.str.createIndex)
459	}
460	// TODO(jba): Do not assume that the sends happen in the same order on replay.
461	msg := rcs.str.sends[0]
462	rcs.str.sends = rcs.str.sends[1:]
463	return msg.err
464}
465
466func (rcs *repClientStream) setStream(method string, req proto.Message) error {
467	str := rcs.rep.extractStream(method, req)
468	if str == nil {
469		return fmt.Errorf("replayer: stream not found for method %s and request %v", method, req)
470	}
471	if str.createErr != nil {
472		return str.createErr
473	}
474	rcs.str = str
475	return nil
476}
477
478func (rcs *repClientStream) RecvMsg(m interface{}) error {
479	if rcs.str == nil {
480		// Receive before send; fall back to matching stream by method only.
481		if err := rcs.setStream(rcs.method, nil); err != nil {
482			return err
483		}
484	}
485	if len(rcs.str.recvs) == 0 {
486		return fmt.Errorf("replayer: no more receives for stream %s, created at index %d",
487			rcs.str.method, rcs.str.createIndex)
488	}
489	msg := rcs.str.recvs[0]
490	rcs.str.recvs = rcs.str.recvs[1:]
491	if msg.err != nil {
492		return msg.err
493	}
494	proto.Merge(m.(proto.Message), msg.msg) // copy msg into m
495	return nil
496}
497
498func (rcs *repClientStream) Header() (metadata.MD, error) {
499	log.Printf("replay: stream metadata not supported")
500	return nil, nil
501}
502
503func (rcs *repClientStream) Trailer() metadata.MD {
504	log.Printf("replay: stream metadata not supported")
505	return nil
506}
507
508func (rcs *repClientStream) CloseSend() error {
509	return nil
510}
511
512// extractCall finds the first call in the list with the same method
513// and request. It returns nil if it can't find such a call.
514func (rep *Replayer) extractCall(method string, req proto.Message) *call {
515	rep.mu.Lock()
516	defer rep.mu.Unlock()
517	for i, call := range rep.calls {
518		if call == nil {
519			continue
520		}
521		if method == call.method && proto.Equal(req, call.request) {
522			rep.calls[i] = nil // nil out this call so we don't reuse it
523			return call
524		}
525	}
526	return nil
527}
528
529// extractStream find the first stream in the list with the same method and the same
530// first request sent. If req is nil, that means a receive occurred before a send, so
531// it matches only on method.
532func (rep *Replayer) extractStream(method string, req proto.Message) *stream {
533	rep.mu.Lock()
534	defer rep.mu.Unlock()
535	for i, stream := range rep.streams {
536		// Skip stream if it is nil (already extracted) or its method doesn't match.
537		if stream == nil || stream.method != method {
538			continue
539		}
540		// If there is a first request, skip stream if it has no requests or its first
541		// request doesn't match.
542		if req != nil && len(stream.sends) > 0 && !proto.Equal(req, stream.sends[0].msg) {
543			continue
544		}
545		rep.streams[i] = nil // nil out this stream so we don't reuse it
546		return stream
547	}
548	return nil
549}
550
551// Fprint reads the entries from filename and writes them to w in human-readable form.
552// It is intended for debugging.
553func Fprint(w io.Writer, filename string) error {
554	f, err := os.Open(filename)
555	if err != nil {
556		return err
557	}
558	defer f.Close()
559	return FprintReader(w, f)
560}
561
562// FprintReader reads the entries from r and writes them to w in human-readable form.
563// It is intended for debugging.
564func FprintReader(w io.Writer, r io.Reader) error {
565	initial, err := readHeader(r)
566	if err != nil {
567		return err
568	}
569	fmt.Fprintf(w, "initial state: %q\n", string(initial))
570	for i := 1; ; i++ {
571		e, err := readEntry(r)
572		if err != nil {
573			return err
574		}
575		if e == nil {
576			return nil
577		}
578
579		fmt.Fprintf(w, "#%d: kind: %s, method: %s, ref index: %d", i, e.kind, e.method, e.refIndex)
580		switch {
581		case e.msg.msg != nil:
582			fmt.Fprintf(w, ", message:\n")
583			if err := proto.MarshalText(w, e.msg.msg); err != nil {
584				return err
585			}
586		case e.msg.err != nil:
587			fmt.Fprintf(w, ", error: %v\n", e.msg.err)
588		default:
589			fmt.Fprintln(w)
590		}
591	}
592}
593
594// An entry holds one gRPC action (request, response, etc.).
595type entry struct {
596	kind     pb.Entry_Kind
597	method   string
598	msg      message
599	refIndex int // index of corresponding request or create-stream
600}
601
602func (e1 *entry) equal(e2 *entry) bool {
603	if e1 == nil && e2 == nil {
604		return true
605	}
606	if e1 == nil || e2 == nil {
607		return false
608	}
609	return e1.kind == e2.kind &&
610		e1.method == e2.method &&
611		proto.Equal(e1.msg.msg, e2.msg.msg) &&
612		errEqual(e1.msg.err, e2.msg.err) &&
613		e1.refIndex == e2.refIndex
614}
615
616func errEqual(e1, e2 error) bool {
617	if e1 == e2 {
618		return true
619	}
620	s1, ok1 := status.FromError(e1)
621	s2, ok2 := status.FromError(e2)
622	if !ok1 || !ok2 {
623		return false
624	}
625	return proto.Equal(s1.Proto(), s2.Proto())
626}
627
628// message holds either a single proto.Message or an error.
629type message struct {
630	msg proto.Message
631	err error
632}
633
634func (m *message) set(msg interface{}, err error) {
635	m.err = err
636	if err != io.EOF && msg != nil {
637		m.msg = msg.(proto.Message)
638	}
639}
640
641// File format:
642//   header
643//   sequence of Entry protos
644//
645// Header format:
646//   magic string
647//   a record containing the bytes of the initial state
648
649const magic = "RPCReplay"
650
651func writeHeader(w io.Writer, initial []byte) error {
652	if _, err := io.WriteString(w, magic); err != nil {
653		return err
654	}
655	return writeRecord(w, initial)
656}
657
658func readHeader(r io.Reader) ([]byte, error) {
659	var buf [len(magic)]byte
660	if _, err := io.ReadFull(r, buf[:]); err != nil {
661		if err == io.EOF {
662			err = errors.New("rpcreplay: empty replay file")
663		}
664		return nil, err
665	}
666	if string(buf[:]) != magic {
667		return nil, errors.New("rpcreplay: not a replay file (does not begin with magic string)")
668	}
669	bytes, err := readRecord(r)
670	if err == io.EOF {
671		err = errors.New("rpcreplay: missing initial state")
672	}
673	return bytes, err
674}
675
676func writeEntry(w io.Writer, e *entry) error {
677	var m proto.Message
678	if e.msg.err != nil && e.msg.err != io.EOF {
679		s, ok := status.FromError(e.msg.err)
680		if !ok {
681			return fmt.Errorf("rpcreplay: error %v is not a Status", e.msg.err)
682		}
683		m = s.Proto()
684	} else {
685		m = e.msg.msg
686	}
687	var a *any.Any
688	var err error
689	if m != nil {
690		a, err = ptypes.MarshalAny(m)
691		if err != nil {
692			return err
693		}
694	}
695	pe := &pb.Entry{
696		Kind:     e.kind,
697		Method:   e.method,
698		Message:  a,
699		IsError:  e.msg.err != nil,
700		RefIndex: int32(e.refIndex),
701	}
702	bytes, err := proto.Marshal(pe)
703	if err != nil {
704		return err
705	}
706	return writeRecord(w, bytes)
707}
708
709func readEntry(r io.Reader) (*entry, error) {
710	buf, err := readRecord(r)
711	if err == io.EOF {
712		return nil, nil
713	}
714	if err != nil {
715		return nil, err
716	}
717	var pe pb.Entry
718	if err := proto.Unmarshal(buf, &pe); err != nil {
719		return nil, err
720	}
721	var msg message
722	if pe.Message != nil {
723		var any ptypes.DynamicAny
724		if err := ptypes.UnmarshalAny(pe.Message, &any); err != nil {
725			return nil, err
726		}
727		if pe.IsError {
728			msg.err = status.ErrorProto(any.Message.(*spb.Status))
729		} else {
730			msg.msg = any.Message
731		}
732	} else if pe.IsError {
733		msg.err = io.EOF
734	} else if pe.Kind != pb.Entry_CREATE_STREAM {
735		return nil, errors.New("rpcreplay: entry with nil message and false is_error")
736	}
737	return &entry{
738		kind:     pe.Kind,
739		method:   pe.Method,
740		msg:      msg,
741		refIndex: int(pe.RefIndex),
742	}, nil
743}
744
// A record consists of an unsigned 32-bit little-endian length L followed by L
// bytes.

// writeRecord writes data to w as a single record: the 4-byte length prefix
// in one Write, then the payload in a second Write.
func writeRecord(w io.Writer, data []byte) error {
	var prefix [4]byte
	binary.LittleEndian.PutUint32(prefix[:], uint32(len(data)))
	if _, err := w.Write(prefix[:]); err != nil {
		return err
	}
	_, err := w.Write(data)
	return err
}
755
// readRecord reads one record, as written by writeRecord, from r.
//
// It returns io.EOF only if r is exhausted before the length prefix starts,
// i.e. at a clean record boundary. If the input ends partway through a
// record, it returns io.ErrUnexpectedEOF. That includes the case where the
// input ends right after the length prefix, before any payload byte. This
// keeps a truncated recording from being mistaken for a complete one by
// callers that treat io.EOF as a normal end of input.
func readRecord(r io.Reader) ([]byte, error) {
	var size uint32
	if err := binary.Read(r, binary.LittleEndian, &size); err != nil {
		return nil, err
	}
	buf := make([]byte, size)
	if _, err := io.ReadFull(r, buf); err != nil {
		if err == io.EOF {
			// ReadFull reports io.EOF when no payload bytes were read, but
			// the prefix promised size bytes, so this is a truncation.
			err = io.ErrUnexpectedEOF
		}
		return nil, err
	}
	return buf, nil
}
767