1/*
2 *
3 * Copyright 2014, Google Inc.
4 * All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions are
8 * met:
9 *
10 *     * Redistributions of source code must retain the above copyright
11 * notice, this list of conditions and the following disclaimer.
12 *     * Redistributions in binary form must reproduce the above
13 * copyright notice, this list of conditions and the following disclaimer
14 * in the documentation and/or other materials provided with the
15 * distribution.
16 *     * Neither the name of Google Inc. nor the names of its
17 * contributors may be used to endorse or promote products derived from
18 * this software without specific prior written permission.
19 *
20 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 *
32 */
33
34package grpc
35
36import (
37	"bytes"
38	"compress/gzip"
39	"encoding/binary"
40	"fmt"
41	"io"
42	"io/ioutil"
43	"math"
44	"math/rand"
45	"os"
46	"time"
47
48	"github.com/coreos/etcd/Godeps/_workspace/src/github.com/golang/protobuf/proto"
49	"github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/net/context"
50	"github.com/coreos/etcd/Godeps/_workspace/src/google.golang.org/grpc/codes"
51	"github.com/coreos/etcd/Godeps/_workspace/src/google.golang.org/grpc/metadata"
52	"github.com/coreos/etcd/Godeps/_workspace/src/google.golang.org/grpc/transport"
53)
54
// Codec is the pluggable (de)serialization layer gRPC uses to turn
// messages into bytes on the wire and back again.
type Codec interface {
	// Marshal encodes v and returns the resulting wire-format bytes.
	Marshal(v interface{}) ([]byte, error)
	// Unmarshal decodes the wire-format bytes in data and stores the
	// result in v.
	Unmarshal(data []byte, v interface{}) error
	// String names this Codec implementation; the name becomes part of
	// the content type sent over the wire.
	String() string
}
65
66// protoCodec is a Codec implemetation with protobuf. It is the default codec for gRPC.
67type protoCodec struct{}
68
69func (protoCodec) Marshal(v interface{}) ([]byte, error) {
70	return proto.Marshal(v.(proto.Message))
71}
72
73func (protoCodec) Unmarshal(data []byte, v interface{}) error {
74	return proto.Unmarshal(data, v.(proto.Message))
75}
76
77func (protoCodec) String() string {
78	return "proto"
79}
80
81// Compressor defines the interface gRPC uses to compress a message.
82type Compressor interface {
83	// Do compresses p into w.
84	Do(w io.Writer, p []byte) error
85	// Type returns the compression algorithm the Compressor uses.
86	Type() string
87}
88
89// NewGZIPCompressor creates a Compressor based on GZIP.
90func NewGZIPCompressor() Compressor {
91	return &gzipCompressor{}
92}
93
94type gzipCompressor struct {
95}
96
97func (c *gzipCompressor) Do(w io.Writer, p []byte) error {
98	z := gzip.NewWriter(w)
99	if _, err := z.Write(p); err != nil {
100		return err
101	}
102	return z.Close()
103}
104
105func (c *gzipCompressor) Type() string {
106	return "gzip"
107}
108
109// Decompressor defines the interface gRPC uses to decompress a message.
110type Decompressor interface {
111	// Do reads the data from r and uncompress them.
112	Do(r io.Reader) ([]byte, error)
113	// Type returns the compression algorithm the Decompressor uses.
114	Type() string
115}
116
117type gzipDecompressor struct {
118}
119
120// NewGZIPDecompressor creates a Decompressor based on GZIP.
121func NewGZIPDecompressor() Decompressor {
122	return &gzipDecompressor{}
123}
124
125func (d *gzipDecompressor) Do(r io.Reader) ([]byte, error) {
126	z, err := gzip.NewReader(r)
127	if err != nil {
128		return nil, err
129	}
130	defer z.Close()
131	return ioutil.ReadAll(z)
132}
133
134func (d *gzipDecompressor) Type() string {
135	return "gzip"
136}
137
// callInfo contains all related configuration and information about an RPC.
// CallOption values receive a pointer to it in their before and after hooks.
type callInfo struct {
	// failFast is a per-call flag. NOTE(review): how it is honored is decided
	// outside this file — confirm with the call/stream code.
	failFast  bool
	// headerMD is the header metadata for the call; the Header option copies
	// it into the caller's variable in its after hook.
	headerMD  metadata.MD
	// trailerMD is the trailer metadata for the call; the Trailer option
	// copies it into the caller's variable in its after hook.
	trailerMD metadata.MD
	traceInfo traceInfo // in trace.go
}
145
146// CallOption configures a Call before it starts or extracts information from
147// a Call after it completes.
148type CallOption interface {
149	// before is called before the call is sent to any server.  If before
150	// returns a non-nil error, the RPC fails with that error.
151	before(*callInfo) error
152
153	// after is called after the call has completed.  after cannot return an
154	// error, so any failures should be reported via output parameters.
155	after(*callInfo)
156}
157
158type beforeCall func(c *callInfo) error
159
160func (o beforeCall) before(c *callInfo) error { return o(c) }
161func (o beforeCall) after(c *callInfo)        {}
162
163type afterCall func(c *callInfo)
164
165func (o afterCall) before(c *callInfo) error { return nil }
166func (o afterCall) after(c *callInfo)        { o(c) }
167
168// Header returns a CallOptions that retrieves the header metadata
169// for a unary RPC.
170func Header(md *metadata.MD) CallOption {
171	return afterCall(func(c *callInfo) {
172		*md = c.headerMD
173	})
174}
175
176// Trailer returns a CallOptions that retrieves the trailer metadata
177// for a unary RPC.
178func Trailer(md *metadata.MD) CallOption {
179	return afterCall(func(c *callInfo) {
180		*md = c.trailerMD
181	})
182}
183
// payloadFormat is the first byte of a gRPC message frame. It records
// whether the message bytes that follow are compressed.
type payloadFormat uint8

const (
	compressionNone payloadFormat = 0 // message bytes are sent as-is
	compressionMade payloadFormat = 1 // message bytes are compressed
)
191
// parser reads complete gRPC messages from the underlying reader.
type parser struct {
	// r is the underlying reader.
	// See the comment on recvMsg for the permissible
	// error types.
	r io.Reader

	// The header of a gRPC message. Find more detail
	// at http://www.grpc.io/docs/guides/wire.html.
	// recvMsg reads it as: byte 0 is the payload format flag, bytes 1-4 are
	// the big-endian message length. The array is reused across messages.
	header [5]byte
}
203
// recvMsg reads a complete gRPC message from the stream.
//
// It returns the message and its payload (compression/encoding)
// format. The caller owns the returned msg memory. A zero-length message
// is returned as a nil msg.
//
// If there is an error, possible values are:
//   * io.EOF, when no messages remain
//   * io.ErrUnexpectedEOF
//   * of type transport.ConnectionError
//   * of type transport.StreamError
// No other error values or types must be returned, which also means
// that the underlying io.Reader must not return an incompatible
// error.
func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
	if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
		return 0, nil, err
	}

	pf = payloadFormat(p.header[0])
	length := binary.BigEndian.Uint32(p.header[1:])

	if length == 0 {
		return pf, nil, nil
	}
	// The length prefix is chosen by the peer. Where int is 32 bits wide, a
	// length >= 2^31 converts to a negative int and make() would panic and
	// take down the process; reject it as a stream error instead.
	if int(length) < 0 {
		return 0, nil, transport.StreamErrorf(codes.ResourceExhausted, "grpc: received message length %d exceeds the maximum supported on this platform", length)
	}
	// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
	// of making it for each message:
	msg = make([]byte, int(length))
	if _, err := io.ReadFull(p.r, msg); err != nil {
		// The header promised a body; running out of input before any of it
		// arrived is still a truncated message, not a clean end of stream.
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return 0, nil, err
	}
	return pf, msg, nil
}
239
// encode serializes msg with c and frames it for the wire: a 1-byte payload
// format flag, a 4-byte big-endian length, then the message bytes.
//
// If cp is non-nil, the marshaled bytes are compressed into cbuf before
// framing and the flag is set to compressionMade. cp.Do writes after any
// data already in cbuf, so callers should pass an empty buffer. A nil cbuf
// is tolerated: a fresh buffer is used instead of panicking. If msg is nil,
// only a header declaring a 0-byte message is produced; the flag still
// reflects whether cp is set.
//
// Errors from c.Marshal or cp.Do are returned unchanged. A (possibly
// compressed) message longer than math.MaxUint32 bytes yields an error
// built with Errorf(codes.InvalidArgument, ...).
func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer) ([]byte, error) {
	var b []byte
	var length uint
	if msg != nil {
		var err error
		// TODO(zhaoq): optimize to reduce memory alloc and copying.
		b, err = c.Marshal(msg)
		if err != nil {
			return nil, err
		}
		if cp != nil {
			if cbuf == nil {
				cbuf = new(bytes.Buffer)
			}
			if err = cp.Do(cbuf, b); err != nil {
				return nil, err
			}
			b = cbuf.Bytes()
		}
		length = uint(len(b))
	}
	// The wire format only has 32 bits for the length.
	if length > math.MaxUint32 {
		return nil, Errorf(codes.InvalidArgument, "grpc: message too large (%d bytes)", length)
	}

	const (
		payloadLen = 1
		sizeLen    = 4
	)

	var buf = make([]byte, payloadLen+sizeLen+len(b))

	// Write payload format
	if cp == nil {
		buf[0] = byte(compressionNone)
	} else {
		buf[0] = byte(compressionMade)
	}
	// Write length of b into buf
	binary.BigEndian.PutUint32(buf[payloadLen:], uint32(length))
	// Copy encoded msg to buf
	copy(buf[payloadLen+sizeLen:], b)

	return buf, nil
}
284
285func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) error {
286	switch pf {
287	case compressionNone:
288	case compressionMade:
289		if recvCompress == "" {
290			return transport.StreamErrorf(codes.InvalidArgument, "grpc: invalid grpc-encoding %q with compression enabled", recvCompress)
291		}
292		if dc == nil || recvCompress != dc.Type() {
293			return transport.StreamErrorf(codes.InvalidArgument, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
294		}
295	default:
296		return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf)
297	}
298	return nil
299}
300
// recv reads the next message from p, checks its payload format against the
// stream s, decompresses it with dc if needed, and unmarshals it into m with
// c. Errors from p.recvMsg and checkRecvPayload are returned unchanged.
// Decompression and unmarshal failures become StreamErrors with
// codes.Internal.
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}) error {
	pf, data, err := p.recvMsg()
	if err != nil {
		return err
	}
	if err = checkRecvPayload(pf, s.RecvCompress(), dc); err != nil {
		return err
	}
	// checkRecvPayload guarantees dc is non-nil for compressed payloads.
	if pf == compressionMade {
		if data, err = dc.Do(bytes.NewReader(data)); err != nil {
			return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
		}
	}
	if err = c.Unmarshal(data, m); err != nil {
		return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
	}
	return nil
}
320
321// rpcError defines the status from an RPC.
322type rpcError struct {
323	code codes.Code
324	desc string
325}
326
327func (e rpcError) Error() string {
328	return fmt.Sprintf("rpc error: code = %d desc = %q", e.code, e.desc)
329}
330
331// Code returns the error code for err if it was produced by the rpc system.
332// Otherwise, it returns codes.Unknown.
333func Code(err error) codes.Code {
334	if err == nil {
335		return codes.OK
336	}
337	if e, ok := err.(rpcError); ok {
338		return e.code
339	}
340	return codes.Unknown
341}
342
343// ErrorDesc returns the error description of err if it was produced by the rpc system.
344// Otherwise, it returns err.Error() or empty string when err is nil.
345func ErrorDesc(err error) string {
346	if err == nil {
347		return ""
348	}
349	if e, ok := err.(rpcError); ok {
350		return e.desc
351	}
352	return err.Error()
353}
354
// Errorf builds an rpc error with code c and a description formatted from
// format and a, in the manner of fmt.Sprintf. Since codes.OK is not an error,
// Errorf returns nil when c is codes.OK.
func Errorf(c codes.Code, format string, a ...interface{}) error {
	if c != codes.OK {
		desc := fmt.Sprintf(format, a...)
		return rpcError{code: c, desc: desc}
	}
	return nil
}
366
367// toRPCErr converts an error into a rpcError.
368func toRPCErr(err error) error {
369	switch e := err.(type) {
370	case rpcError:
371		return err
372	case transport.StreamError:
373		return rpcError{
374			code: e.Code,
375			desc: e.Desc,
376		}
377	case transport.ConnectionError:
378		return rpcError{
379			code: codes.Internal,
380			desc: e.Desc,
381		}
382	}
383	return Errorf(codes.Unknown, "%v", err)
384}
385
// convertCode maps a standard Go error to its canonical gRPC code. It is
// only meant for translating errors returned by server applications.
//
// Well-known sentinel errors are matched by identity first. Remaining errors
// are classified with os.IsExist, os.IsNotExist and os.IsPermission.
// Anything else maps to codes.Unknown.
func convertCode(err error) codes.Code {
	switch {
	case err == nil:
		return codes.OK
	case err == io.EOF:
		return codes.OutOfRange
	case err == io.ErrClosedPipe,
		err == io.ErrNoProgress,
		err == io.ErrShortBuffer,
		err == io.ErrShortWrite,
		err == io.ErrUnexpectedEOF:
		return codes.FailedPrecondition
	case err == os.ErrInvalid:
		return codes.InvalidArgument
	case err == context.Canceled:
		return codes.Canceled
	case err == context.DeadlineExceeded:
		return codes.DeadlineExceeded
	case os.IsExist(err):
		return codes.AlreadyExists
	case os.IsNotExist(err):
		return codes.NotFound
	case os.IsPermission(err):
		return codes.PermissionDenied
	default:
		return codes.Unknown
	}
}
413
414const (
415	// how long to wait after the first failure before retrying
416	baseDelay = 1.0 * time.Second
417	// upper bound of backoff delay
418	maxDelay = 120 * time.Second
419	// backoff increases by this factor on each retry
420	backoffFactor = 1.6
421	// backoff is randomized downwards by this factor
422	backoffJitter = 0.2
423)
424
425func backoff(retries int) (t time.Duration) {
426	if retries == 0 {
427		return baseDelay
428	}
429	backoff, max := float64(baseDelay), float64(maxDelay)
430	for backoff < max && retries > 0 {
431		backoff *= backoffFactor
432		retries--
433	}
434	if backoff > max {
435		backoff = max
436	}
437	// Randomize backoff delays so that if a cluster of requests start at
438	// the same time, they won't operate in lockstep.
439	backoff *= 1 + backoffJitter*(rand.Float64()*2-1)
440	if backoff < 0 {
441		return 0
442	}
443	return time.Duration(backoff)
444}
445