1/*
2 * Copyright 2016, Google Inc.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are
7 * met:
8 *
9 *     * Redistributions of source code must retain the above copyright
10 * notice, this list of conditions and the following disclaimer.
11 *     * Redistributions in binary form must reproduce the above
12 * copyright notice, this list of conditions and the following disclaimer
13 * in the documentation and/or other materials provided with the
14 * distribution.
15 *     * Neither the name of Google Inc. nor the names of its
16 * contributors may be used to endorse or promote products derived from
17 * this software without specific prior written permission.
18 *
19 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 *
31 */
32
33// This file is the implementation of a gRPC server using HTTP/2 which
34// uses the standard Go http2 Server implementation (via the
35// http.Handler interface), rather than speaking low-level HTTP/2
36// frames itself. It is the implementation of *grpc.Server.ServeHTTP.
37
38package transport
39
40import (
41	"errors"
42	"fmt"
43	"io"
44	"net"
45	"net/http"
46	"strings"
47	"sync"
48	"time"
49
50	"golang.org/x/net/context"
51	"golang.org/x/net/http2"
52	"google.golang.org/grpc/codes"
53	"google.golang.org/grpc/credentials"
54	"google.golang.org/grpc/metadata"
55	"google.golang.org/grpc/peer"
56)
57
58// NewServerHandlerTransport returns a ServerTransport handling gRPC
59// from inside an http.Handler. It requires that the http Server
60// supports HTTP/2.
61func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTransport, error) {
62	if r.ProtoMajor != 2 {
63		return nil, errors.New("gRPC requires HTTP/2")
64	}
65	if r.Method != "POST" {
66		return nil, errors.New("invalid gRPC request method")
67	}
68	if !strings.Contains(r.Header.Get("Content-Type"), "application/grpc") {
69		return nil, errors.New("invalid gRPC request content-type")
70	}
71	if _, ok := w.(http.Flusher); !ok {
72		return nil, errors.New("gRPC requires a ResponseWriter supporting http.Flusher")
73	}
74	if _, ok := w.(http.CloseNotifier); !ok {
75		return nil, errors.New("gRPC requires a ResponseWriter supporting http.CloseNotifier")
76	}
77
78	st := &serverHandlerTransport{
79		rw:       w,
80		req:      r,
81		closedCh: make(chan struct{}),
82		writes:   make(chan func()),
83	}
84
85	if v := r.Header.Get("grpc-timeout"); v != "" {
86		to, err := timeoutDecode(v)
87		if err != nil {
88			return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err)
89		}
90		st.timeoutSet = true
91		st.timeout = to
92	}
93
94	var metakv []string
95	for k, vv := range r.Header {
96		k = strings.ToLower(k)
97		if isReservedHeader(k) {
98			continue
99		}
100		for _, v := range vv {
101			if k == "user-agent" {
102				// user-agent is special. Copying logic of http_util.go.
103				if i := strings.LastIndex(v, " "); i == -1 {
104					// There is no application user agent string being set
105					continue
106				} else {
107					v = v[:i]
108				}
109			}
110			metakv = append(metakv, k, v)
111
112		}
113	}
114	st.headerMD = metadata.Pairs(metakv...)
115
116	return st, nil
117}
118
119// serverHandlerTransport is an implementation of ServerTransport
120// which replies to exactly one gRPC request (exactly one HTTP request),
121// using the net/http.Handler interface. This http.Handler is guaranteed
122// at this point to be speaking over HTTP/2, so it's able to speak valid
123// gRPC.
124type serverHandlerTransport struct {
125	rw               http.ResponseWriter
126	req              *http.Request
127	timeoutSet       bool
128	timeout          time.Duration
129	didCommonHeaders bool
130
131	headerMD metadata.MD
132
133	closeOnce sync.Once
134	closedCh  chan struct{} // closed on Close
135
136	// writes is a channel of code to run serialized in the
137	// ServeHTTP (HandleStreams) goroutine. The channel is closed
138	// when WriteStatus is called.
139	writes chan func()
140}
141
142func (ht *serverHandlerTransport) Close() error {
143	ht.closeOnce.Do(ht.closeCloseChanOnce)
144	return nil
145}
146
147func (ht *serverHandlerTransport) closeCloseChanOnce() { close(ht.closedCh) }
148
149func (ht *serverHandlerTransport) RemoteAddr() net.Addr { return strAddr(ht.req.RemoteAddr) }
150
// strAddr is a net.Addr backed by either a TCP "ip:port" string, or
// the empty string if unknown.
type strAddr string

// Network reports "tcp" for a non-empty address and "" for an unknown one.
//
// Per the documentation on net/http.Request.RemoteAddr, when that field is
// set it holds the IP:port of the peer (hence, TCP):
// https://golang.org/pkg/net/http/#Request
//
// To support Unix sockets later, we can add our own grpc-specific
// convention for the format of RemoteAddr, or, probably better, attach the
// address to the context and use that from serverHandlerTransport.RemoteAddr.
func (a strAddr) Network() string {
	if a == "" {
		return ""
	}
	return "tcp"
}

// String returns the address exactly as it was recorded.
func (a strAddr) String() string {
	s := string(a)
	return s
}
172
173// do runs fn in the ServeHTTP goroutine.
174func (ht *serverHandlerTransport) do(fn func()) error {
175	select {
176	case ht.writes <- fn:
177		return nil
178	case <-ht.closedCh:
179		return ErrConnClosing
180	}
181}
182
183func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error {
184	err := ht.do(func() {
185		ht.writeCommonHeaders(s)
186
187		// And flush, in case no header or body has been sent yet.
188		// This forces a separation of headers and trailers if this is the
189		// first call (for example, in end2end tests's TestNoService).
190		ht.rw.(http.Flusher).Flush()
191
192		h := ht.rw.Header()
193		h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
194		if statusDesc != "" {
195			h.Set("Grpc-Message", statusDesc)
196		}
197		if md := s.Trailer(); len(md) > 0 {
198			for k, vv := range md {
199				for _, v := range vv {
200					// http2 ResponseWriter mechanism to
201					// send undeclared Trailers after the
202					// headers have possibly been written.
203					h.Add(http2.TrailerPrefix+k, v)
204				}
205			}
206		}
207	})
208	close(ht.writes)
209	return err
210}
211
212// writeCommonHeaders sets common headers on the first write
213// call (Write, WriteHeader, or WriteStatus).
214func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
215	if ht.didCommonHeaders {
216		return
217	}
218	ht.didCommonHeaders = true
219
220	h := ht.rw.Header()
221	h["Date"] = nil // suppress Date to make tests happy; TODO: restore
222	h.Set("Content-Type", "application/grpc")
223
224	// Predeclare trailers we'll set later in WriteStatus (after the body).
225	// This is a SHOULD in the HTTP RFC, and the way you add (known)
226	// Trailers per the net/http.ResponseWriter contract.
227	// See https://golang.org/pkg/net/http/#ResponseWriter
228	// and https://golang.org/pkg/net/http/#example_ResponseWriter_trailers
229	h.Add("Trailer", "Grpc-Status")
230	h.Add("Trailer", "Grpc-Message")
231
232	if s.sendCompress != "" {
233		h.Set("Grpc-Encoding", s.sendCompress)
234	}
235}
236
237func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error {
238	return ht.do(func() {
239		ht.writeCommonHeaders(s)
240		ht.rw.Write(data)
241		if !opts.Delay {
242			ht.rw.(http.Flusher).Flush()
243		}
244	})
245}
246
247func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
248	return ht.do(func() {
249		ht.writeCommonHeaders(s)
250		h := ht.rw.Header()
251		for k, vv := range md {
252			for _, v := range vv {
253				h.Add(k, v)
254			}
255		}
256		ht.rw.WriteHeader(200)
257		ht.rw.(http.Flusher).Flush()
258	})
259}
260
261func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
262	// With this transport type there will be exactly 1 stream: this HTTP request.
263
264	var ctx context.Context
265	var cancel context.CancelFunc
266	if ht.timeoutSet {
267		ctx, cancel = context.WithTimeout(context.Background(), ht.timeout)
268	} else {
269		ctx, cancel = context.WithCancel(context.Background())
270	}
271
272	// requestOver is closed when either the request's context is done
273	// or the status has been written via WriteStatus.
274	requestOver := make(chan struct{})
275
276	// clientGone receives a single value if peer is gone, either
277	// because the underlying connection is dead or because the
278	// peer sends an http2 RST_STREAM.
279	clientGone := ht.rw.(http.CloseNotifier).CloseNotify()
280	go func() {
281		select {
282		case <-requestOver:
283			return
284		case <-ht.closedCh:
285		case <-clientGone:
286		}
287		cancel()
288	}()
289
290	req := ht.req
291
292	s := &Stream{
293		id:            0,            // irrelevant
294		windowHandler: func(int) {}, // nothing
295		cancel:        cancel,
296		buf:           newRecvBuffer(),
297		st:            ht,
298		method:        req.URL.Path,
299		recvCompress:  req.Header.Get("grpc-encoding"),
300	}
301	pr := &peer.Peer{
302		Addr: ht.RemoteAddr(),
303	}
304	if req.TLS != nil {
305		pr.AuthInfo = credentials.TLSInfo{*req.TLS}
306	}
307	ctx = metadata.NewContext(ctx, ht.headerMD)
308	ctx = peer.NewContext(ctx, pr)
309	s.ctx = newContextWithStream(ctx, s)
310	s.dec = &recvBufferReader{ctx: s.ctx, recv: s.buf}
311
312	// readerDone is closed when the Body.Read-ing goroutine exits.
313	readerDone := make(chan struct{})
314	go func() {
315		defer close(readerDone)
316		for {
317			buf := make([]byte, 1024) // TODO: minimize garbage, optimize recvBuffer code/ownership
318			n, err := req.Body.Read(buf)
319			if n > 0 {
320				s.buf.put(&recvMsg{data: buf[:n]})
321			}
322			if err != nil {
323				s.buf.put(&recvMsg{err: mapRecvMsgError(err)})
324				return
325			}
326		}
327	}()
328
329	// startStream is provided by the *grpc.Server's serveStreams.
330	// It starts a goroutine serving s and exits immediately.
331	// The goroutine that is started is the one that then calls
332	// into ht, calling WriteHeader, Write, WriteStatus, Close, etc.
333	startStream(s)
334
335	ht.runStream()
336	close(requestOver)
337
338	// Wait for reading goroutine to finish.
339	req.Body.Close()
340	<-readerDone
341}
342
343func (ht *serverHandlerTransport) runStream() {
344	for {
345		select {
346		case fn, ok := <-ht.writes:
347			if !ok {
348				return
349			}
350			fn()
351		case <-ht.closedCh:
352			return
353		}
354	}
355}
356
357// mapRecvMsgError returns the non-nil err into the appropriate
358// error value as expected by callers of *grpc.parser.recvMsg.
359// In particular, in can only be:
360//   * io.EOF
361//   * io.ErrUnexpectedEOF
362//   * of type transport.ConnectionError
363//   * of type transport.StreamError
364func mapRecvMsgError(err error) error {
365	if err == io.EOF || err == io.ErrUnexpectedEOF {
366		return err
367	}
368	if se, ok := err.(http2.StreamError); ok {
369		if code, ok := http2ErrConvTab[se.Code]; ok {
370			return StreamError{
371				Code: code,
372				Desc: se.Error(),
373			}
374		}
375	}
376	return ConnectionError{Desc: err.Error()}
377}
378