1/*
2 *
3 * Copyright 2014, Google Inc.
4 * All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions are
8 * met:
9 *
10 *     * Redistributions of source code must retain the above copyright
11 * notice, this list of conditions and the following disclaimer.
12 *     * Redistributions in binary form must reproduce the above
13 * copyright notice, this list of conditions and the following disclaimer
14 * in the documentation and/or other materials provided with the
15 * distribution.
16 *     * Neither the name of Google Inc. nor the names of its
17 * contributors may be used to endorse or promote products derived from
18 * this software without specific prior written permission.
19 *
20 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 *
32 */
33
34package grpc
35
36import (
37	"bytes"
38	"errors"
39	"io"
40	"sync"
41	"time"
42
43	"golang.org/x/net/context"
44	"golang.org/x/net/trace"
45	"google.golang.org/grpc/codes"
46	"google.golang.org/grpc/metadata"
47	"google.golang.org/grpc/transport"
48)
49
// streamHandler is the function type stored in StreamDesc.Handler to serve a
// streaming RPC method. srv is passed through as an opaque value (presumably
// the registered service implementation — confirm at the call site), stream
// is the ServerStream for the RPC, and the returned error is the handler's
// result.
type streamHandler func(srv interface{}, stream ServerStream) error
51
// StreamDesc represents a streaming RPC service's method specification.
// It names the stream, supplies the handler that serves it, and records
// which side(s) send a stream of messages.
type StreamDesc struct {
	StreamName string
	Handler    streamHandler

	// At least one of these is true.
	// They are consulted by NewClientStream (to decide whether the call
	// header requests a flush) and by clientStream.RecvMsg (to apply the
	// special handling for client-streaming RPCs).
	ServerStreams bool
	ClientStreams bool
}
61
62// Stream defines the common interface a client or server stream has to satisfy.
63type Stream interface {
64	// Context returns the context for this stream.
65	Context() context.Context
66	// SendMsg blocks until it sends m, the stream is done or the stream
67	// breaks.
68	// On error, it aborts the stream and returns an RPC status on client
69	// side. On server side, it simply returns the error to the caller.
70	// SendMsg is called by generated code.
71	SendMsg(m interface{}) error
72	// RecvMsg blocks until it receives a message or the stream is
73	// done. On client side, it returns io.EOF when the stream is done. On
74	// any other error, it aborts the stream and returns an RPC status. On
75	// server side, it simply returns the error to the caller.
76	RecvMsg(m interface{}) error
77}
78
79// ClientStream defines the interface a client stream has to satify.
80type ClientStream interface {
81	// Header returns the header metedata received from the server if there
82	// is any. It blocks if the metadata is not ready to read.
83	Header() (metadata.MD, error)
84	// Trailer returns the trailer metadata from the server. It must be called
85	// after stream.Recv() returns non-nil error (including io.EOF) for
86	// bi-directional streaming and server streaming or stream.CloseAndRecv()
87	// returns for client streaming in order to receive trailer metadata if
88	// present. Otherwise, it could returns an empty MD even though trailer
89	// is present.
90	Trailer() metadata.MD
91	// CloseSend closes the send direction of the stream. It closes the stream
92	// when non-nil error is met.
93	CloseSend() error
94	Stream
95}
96
// NewClientStream creates a new Stream for the client side. This is called
// by generated code.
//
// It picks a transport through cc's picker and opens a transport stream for
// method on it. When EnableTracing is set, it also starts a trace for the RPC;
// if the transport stream cannot be opened, that trace is finished here with
// the error. Errors from picking a transport or opening the stream are
// converted with toRPCErr.
//
// A goroutine is started that waits for either a transport error or the
// stream's context being done; in the latter case it finishes the trace and
// closes the transport stream, so cancellation is observed even when no I/O
// is pending on the stream.
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
	var (
		t   transport.ClientTransport
		err error
	)
	t, err = cc.dopts.picker.Pick(ctx)
	if err != nil {
		return nil, toRPCErr(err)
	}
	// TODO(zhaoq): CallOption is omitted. Add support when it is needed.
	// A flush is requested only for bidirectional streams.
	callHdr := &transport.CallHdr{
		Host:   cc.authority,
		Method: method,
		Flush:  desc.ServerStreams && desc.ClientStreams,
	}
	cs := &clientStream{
		desc:    desc,
		codec:   cc.dopts.codec,
		cp:      cc.dopts.cp,
		dc:      cc.dopts.dc,
		tracing: EnableTracing,
	}
	// With a compressor configured, advertise its type in the call header and
	// give the stream a scratch buffer for compressed payloads (SendMsg resets
	// it after every message).
	if cc.dopts.cp != nil {
		callHdr.SendCompress = cc.dopts.cp.Type()
		cs.cbuf = new(bytes.Buffer)
	}
	if cs.tracing {
		cs.trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
		cs.trInfo.firstLine.client = true
		if deadline, ok := ctx.Deadline(); ok {
			cs.trInfo.firstLine.deadline = deadline.Sub(time.Now())
		}
		cs.trInfo.tr.LazyLog(&cs.trInfo.firstLine, false)
		ctx = trace.NewContext(ctx, cs.trInfo.tr)
	}
	s, err := t.NewStream(ctx, callHdr)
	if err != nil {
		cs.finish(err)
		return nil, toRPCErr(err)
	}
	cs.t = t
	cs.s = s
	cs.p = &parser{r: s}
	// Listen on ctx.Done() to detect cancellation when there is no pending
	// I/O operations on this stream.
	go func() {
		select {
		case <-t.Error():
			// Incur transport error, simply exit.
		case <-s.Context().Done():
			ctxErr := s.Context().Err()
			cs.finish(ctxErr)
			cs.closeTransportStream(transport.ContextErr(ctxErr))
		}
	}()
	return cs, nil
}
159
// clientStream implements a client side Stream.
//
// t and s are the transport and the transport stream the RPC runs on; p is
// the parser that reads incoming messages from s. codec, cp and dc are handed
// to encode/recv; cp, when set, is also advertised in the call header. cbuf
// is a scratch buffer for encode that is only allocated when cp is set and is
// Reset after every SendMsg. mu guards closed (set once by
// closeTransportStream) and trInfo.tr.
type clientStream struct {
	t     transport.ClientTransport
	s     *transport.Stream
	p     *parser
	desc  *StreamDesc
	codec Codec
	cp    Compressor
	cbuf  *bytes.Buffer
	dc    Decompressor

	tracing bool // set to EnableTracing when the clientStream is created.

	mu     sync.Mutex
	closed bool
	// trInfo.tr is set when the clientStream is created (if EnableTracing is true),
	// and is set to nil when the clientStream's finish method is called.
	trInfo traceInfo
}
179
180func (cs *clientStream) Context() context.Context {
181	return cs.s.Context()
182}
183
184func (cs *clientStream) Header() (metadata.MD, error) {
185	m, err := cs.s.Header()
186	if err != nil {
187		if _, ok := err.(transport.ConnectionError); !ok {
188			cs.closeTransportStream(err)
189		}
190	}
191	return m, err
192}
193
194func (cs *clientStream) Trailer() metadata.MD {
195	return cs.s.Trailer()
196}
197
// SendMsg encodes m and writes it to the transport stream as a non-final
// message. When tracing is enabled, the message is logged to the RPC trace
// first.
//
// Any non-nil error finishes the RPC trace. io.EOF is returned unchanged.
// For other errors, the transport stream is closed (unless the error is a
// transport.ConnectionError) and the error is converted with toRPCErr.
// Encoding failures surface as codes.Internal stream errors.
func (cs *clientStream) SendMsg(m interface{}) (err error) {
	if cs.tracing {
		cs.mu.Lock()
		if tr := cs.trInfo.tr; tr != nil {
			tr.LazyLog(&payload{sent: true, msg: m}, true)
		}
		cs.mu.Unlock()
	}
	defer func() {
		if err == nil {
			return
		}
		cs.finish(err)
		if err == io.EOF {
			return
		}
		if _, isConnErr := err.(transport.ConnectionError); !isConnErr {
			cs.closeTransportStream(err)
		}
		err = toRPCErr(err)
	}()
	out, encErr := encode(cs.codec, m, cs.cp, cs.cbuf)
	if cs.cbuf != nil {
		// The scratch buffer is reused across calls; clear it when done.
		defer cs.cbuf.Reset()
	}
	if encErr != nil {
		return transport.StreamErrorf(codes.Internal, "grpc: %v", encErr)
	}
	return cs.t.Write(cs.s, out, &transport.Options{Last: false})
}
229
// RecvMsg reads the next message from the stream into m.
//
// On a successful read, the message is logged to the RPC trace (when
// tracing). For a client-streaming RPC (ClientStreams set, ServerStreams not
// set), a second read is then performed that must end the stream:
//   - a second message is a protocol violation;
//   - io.EOF with an OK status yields nil and finishes the trace;
//   - io.EOF with a non-OK status yields an RPC error built from the
//     stream's status.
//
// On a failed read, the transport stream is closed unless the error is a
// transport.ConnectionError. io.EOF with an OK status is returned as io.EOF
// to mark the end of the stream. io.EOF with a non-OK status becomes an RPC
// error from the stream's status. Any other error goes through toRPCErr.
//
// Because err is a named result, the deferred cs.finish sees the final return
// value, so every non-nil result (including io.EOF) ends the RPC trace.
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
	err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
	defer func() {
		// err != nil indicates the termination of the stream.
		if err != nil {
			cs.finish(err)
		}
	}()
	if err == nil {
		if cs.tracing {
			cs.mu.Lock()
			if cs.trInfo.tr != nil {
				cs.trInfo.tr.LazyLog(&payload{sent: false, msg: m}, true)
			}
			cs.mu.Unlock()
		}
		// Only client-streaming RPCs need the extra read below.
		if !cs.desc.ClientStreams || cs.desc.ServerStreams {
			return
		}
		// Special handling for client streaming rpc.
		err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
		// The single response has been read, so the transport stream is
		// closed whatever the second read returned.
		cs.closeTransportStream(err)
		if err == nil {
			return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
		}
		if err == io.EOF {
			if cs.s.StatusCode() == codes.OK {
				// Returning nil skips the deferred finish, so finish the
				// trace explicitly here.
				cs.finish(err)
				return nil
			}
			return Errorf(cs.s.StatusCode(), cs.s.StatusDesc())
		}
		return toRPCErr(err)
	}
	if _, ok := err.(transport.ConnectionError); !ok {
		cs.closeTransportStream(err)
	}
	if err == io.EOF {
		if cs.s.StatusCode() == codes.OK {
			// Returns io.EOF to indicate the end of the stream.
			return
		}
		return Errorf(cs.s.StatusCode(), cs.s.StatusDesc())
	}
	return toRPCErr(err)
}
276
// CloseSend closes the send direction of the stream by writing no data with
// the Last option set on the transport stream.
//
// It returns nil on success, and also when the write reports io.EOF. io.EOF
// from the transport write is taken to mean the stream has already ended.
// The RPC's real status is delivered by RecvMsg, so io.EOF is not an error
// for the caller of CloseSend. In that case the trace is left open so
// RecvMsg can record the actual outcome instead of a premature "[OK]".
//
// Any other error finishes the RPC trace and closes the transport stream
// (unless the error is a transport.ConnectionError). The error is returned
// after conversion with toRPCErr.
func (cs *clientStream) CloseSend() (err error) {
	err = cs.t.Write(cs.s, nil, &transport.Options{Last: true})
	defer func() {
		if err != nil {
			cs.finish(err)
		}
	}()
	if err == nil || err == io.EOF {
		// Returning nil also keeps the deferred finish from running.
		return nil
	}
	if _, ok := err.(transport.ConnectionError); !ok {
		cs.closeTransportStream(err)
	}
	err = toRPCErr(err)
	return err
}
293
294func (cs *clientStream) closeTransportStream(err error) {
295	cs.mu.Lock()
296	if cs.closed {
297		cs.mu.Unlock()
298		return
299	}
300	cs.closed = true
301	cs.mu.Unlock()
302	cs.t.CloseStream(cs.s, err)
303}
304
// finish records the outcome of the RPC in its trace and ends the trace.
// It does nothing when tracing is disabled or the trace was already
// finished. Both nil and io.EOF are logged as success. Any other error is
// logged and marks the trace as failed.
func (cs *clientStream) finish(err error) {
	if !cs.tracing {
		return
	}
	cs.mu.Lock()
	defer cs.mu.Unlock()
	tr := cs.trInfo.tr
	if tr == nil {
		return
	}
	switch err {
	case nil, io.EOF:
		tr.LazyPrintf("RPC: [OK]")
	default:
		tr.LazyPrintf("RPC: [%v]", err)
		tr.SetError()
	}
	tr.Finish()
	cs.trInfo.tr = nil
}
322
323// ServerStream defines the interface a server stream has to satisfy.
324type ServerStream interface {
325	// SendHeader sends the header metadata. It should not be called
326	// after SendProto. It fails if called multiple times or if
327	// called after SendProto.
328	SendHeader(metadata.MD) error
329	// SetTrailer sets the trailer metadata which will be sent with the
330	// RPC status.
331	SetTrailer(metadata.MD)
332	Stream
333}
334
// serverStream implements a server side Stream.
//
// t and s are the server transport and the transport stream for the RPC. p is
// the parser handed to recv. codec, cp and dc are handed to encode/recv. cbuf
// is a scratch buffer for encode that SendMsg Resets after every call when it
// is non-nil. statusCode and statusDesc are not read anywhere in this part of
// the file — presumably they hold the RPC status to report; confirm against
// the server code. SendMsg and RecvMsg skip tracing entirely when trInfo is
// nil.
type serverStream struct {
	t          transport.ServerTransport
	s          *transport.Stream
	p          *parser
	codec      Codec
	cp         Compressor
	dc         Decompressor
	cbuf       *bytes.Buffer
	statusCode codes.Code
	statusDesc string
	trInfo     *traceInfo

	mu sync.Mutex // protects trInfo.tr after the service handler runs.
}
350
351func (ss *serverStream) Context() context.Context {
352	return ss.s.Context()
353}
354
355func (ss *serverStream) SendHeader(md metadata.MD) error {
356	return ss.t.WriteHeader(ss.s, md)
357}
358
359func (ss *serverStream) SetTrailer(md metadata.MD) {
360	if md.Len() == 0 {
361		return
362	}
363	ss.s.SetTrailer(md)
364	return
365}
366
367func (ss *serverStream) SendMsg(m interface{}) (err error) {
368	defer func() {
369		if ss.trInfo != nil {
370			ss.mu.Lock()
371			if ss.trInfo.tr != nil {
372				if err == nil {
373					ss.trInfo.tr.LazyLog(&payload{sent: true, msg: m}, true)
374				} else {
375					ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
376					ss.trInfo.tr.SetError()
377				}
378			}
379			ss.mu.Unlock()
380		}
381	}()
382	out, err := encode(ss.codec, m, ss.cp, ss.cbuf)
383	defer func() {
384		if ss.cbuf != nil {
385			ss.cbuf.Reset()
386		}
387	}()
388	if err != nil {
389		err = transport.StreamErrorf(codes.Internal, "grpc: %v", err)
390		return err
391	}
392	return ss.t.Write(ss.s, out, &transport.Options{Last: false})
393}
394
// RecvMsg reads the next message from the transport stream into m and
// returns recv's error unchanged. When the RPC is traced:
//   - a received message is logged;
//   - any error other than io.EOF is logged and marks the trace as failed;
//   - io.EOF is not logged.
func (ss *serverStream) RecvMsg(m interface{}) (err error) {
	defer func() {
		if ss.trInfo == nil {
			return
		}
		ss.mu.Lock()
		if tr := ss.trInfo.tr; tr != nil {
			switch {
			case err == nil:
				tr.LazyLog(&payload{sent: false, msg: m}, true)
			case err != io.EOF:
				tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
				tr.SetError()
			}
		}
		ss.mu.Unlock()
	}()
	return recv(ss.p, ss.codec, ss.s, ss.dc, m)
}
412