1/*
2 *
3 * Copyright 2016 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package transport
20
21import (
22	"context"
23	"errors"
24	"fmt"
25	"io"
26	"net/http"
27	"net/http/httptest"
28	"net/url"
29	"reflect"
30	"sync"
31	"testing"
32	"time"
33
34	"github.com/golang/protobuf/proto"
35	dpb "github.com/golang/protobuf/ptypes/duration"
36	epb "google.golang.org/genproto/googleapis/rpc/errdetails"
37	"google.golang.org/grpc/codes"
38	"google.golang.org/grpc/metadata"
39	"google.golang.org/grpc/status"
40)
41
42func TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
43	type testCase struct {
44		name    string
45		req     *http.Request
46		wantErr string
47		modrw   func(http.ResponseWriter) http.ResponseWriter
48		check   func(*serverHandlerTransport, *testCase) error
49	}
50	tests := []testCase{
51		{
52			name: "http/1.1",
53			req: &http.Request{
54				ProtoMajor: 1,
55				ProtoMinor: 1,
56			},
57			wantErr: "gRPC requires HTTP/2",
58		},
59		{
60			name: "bad method",
61			req: &http.Request{
62				ProtoMajor: 2,
63				Method:     "GET",
64				Header:     http.Header{},
65				RequestURI: "/",
66			},
67			wantErr: "invalid gRPC request method",
68		},
69		{
70			name: "bad content type",
71			req: &http.Request{
72				ProtoMajor: 2,
73				Method:     "POST",
74				Header: http.Header{
75					"Content-Type": {"application/foo"},
76				},
77				RequestURI: "/service/foo.bar",
78			},
79			wantErr: "invalid gRPC request content-type",
80		},
81		{
82			name: "not flusher",
83			req: &http.Request{
84				ProtoMajor: 2,
85				Method:     "POST",
86				Header: http.Header{
87					"Content-Type": {"application/grpc"},
88				},
89				RequestURI: "/service/foo.bar",
90			},
91			modrw: func(w http.ResponseWriter) http.ResponseWriter {
92				// Return w without its Flush method
93				type onlyCloseNotifier interface {
94					http.ResponseWriter
95					http.CloseNotifier
96				}
97				return struct{ onlyCloseNotifier }{w.(onlyCloseNotifier)}
98			},
99			wantErr: "gRPC requires a ResponseWriter supporting http.Flusher",
100		},
101		{
102			name: "valid",
103			req: &http.Request{
104				ProtoMajor: 2,
105				Method:     "POST",
106				Header: http.Header{
107					"Content-Type": {"application/grpc"},
108				},
109				URL: &url.URL{
110					Path: "/service/foo.bar",
111				},
112				RequestURI: "/service/foo.bar",
113			},
114			check: func(t *serverHandlerTransport, tt *testCase) error {
115				if t.req != tt.req {
116					return fmt.Errorf("t.req = %p; want %p", t.req, tt.req)
117				}
118				if t.rw == nil {
119					return errors.New("t.rw = nil; want non-nil")
120				}
121				return nil
122			},
123		},
124		{
125			name: "with timeout",
126			req: &http.Request{
127				ProtoMajor: 2,
128				Method:     "POST",
129				Header: http.Header{
130					"Content-Type": []string{"application/grpc"},
131					"Grpc-Timeout": {"200m"},
132				},
133				URL: &url.URL{
134					Path: "/service/foo.bar",
135				},
136				RequestURI: "/service/foo.bar",
137			},
138			check: func(t *serverHandlerTransport, tt *testCase) error {
139				if !t.timeoutSet {
140					return errors.New("timeout not set")
141				}
142				if want := 200 * time.Millisecond; t.timeout != want {
143					return fmt.Errorf("timeout = %v; want %v", t.timeout, want)
144				}
145				return nil
146			},
147		},
148		{
149			name: "with bad timeout",
150			req: &http.Request{
151				ProtoMajor: 2,
152				Method:     "POST",
153				Header: http.Header{
154					"Content-Type": []string{"application/grpc"},
155					"Grpc-Timeout": {"tomorrow"},
156				},
157				URL: &url.URL{
158					Path: "/service/foo.bar",
159				},
160				RequestURI: "/service/foo.bar",
161			},
162			wantErr: `rpc error: code = Internal desc = malformed time-out: transport: timeout unit is not recognized: "tomorrow"`,
163		},
164		{
165			name: "with metadata",
166			req: &http.Request{
167				ProtoMajor: 2,
168				Method:     "POST",
169				Header: http.Header{
170					"Content-Type": []string{"application/grpc"},
171					"meta-foo":     {"foo-val"},
172					"meta-bar":     {"bar-val1", "bar-val2"},
173					"user-agent":   {"x/y a/b"},
174				},
175				URL: &url.URL{
176					Path: "/service/foo.bar",
177				},
178				RequestURI: "/service/foo.bar",
179			},
180			check: func(ht *serverHandlerTransport, tt *testCase) error {
181				want := metadata.MD{
182					"meta-bar":     {"bar-val1", "bar-val2"},
183					"user-agent":   {"x/y a/b"},
184					"meta-foo":     {"foo-val"},
185					"content-type": {"application/grpc"},
186				}
187
188				if !reflect.DeepEqual(ht.headerMD, want) {
189					return fmt.Errorf("metdata = %#v; want %#v", ht.headerMD, want)
190				}
191				return nil
192			},
193		},
194	}
195
196	for _, tt := range tests {
197		rw := newTestHandlerResponseWriter()
198		if tt.modrw != nil {
199			rw = tt.modrw(rw)
200		}
201		got, gotErr := NewServerHandlerTransport(rw, tt.req, nil)
202		if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) {
203			t.Errorf("%s: error = %q; want %q", tt.name, gotErr.Error(), tt.wantErr)
204			continue
205		}
206		if gotErr != nil {
207			continue
208		}
209		if tt.check != nil {
210			if err := tt.check(got.(*serverHandlerTransport), &tt); err != nil {
211				t.Errorf("%s: %v", tt.name, err)
212			}
213		}
214	}
215}
216
// testHandlerResponseWriter is the http.ResponseWriter used by the handler
// transport tests. Writes are captured by the embedded
// httptest.ResponseRecorder. The type also provides CloseNotify, which
// returns its closeNotify channel (buffered, capacity 1), and a no-op Flush.
type testHandlerResponseWriter struct {
	*httptest.ResponseRecorder
	closeNotify chan bool
}

// CloseNotify returns the writer's closeNotify channel.
func (rec testHandlerResponseWriter) CloseNotify() <-chan bool {
	return rec.closeNotify
}

// Flush does nothing. It takes precedence over the embedded recorder's Flush.
func (rec testHandlerResponseWriter) Flush() {}

// newTestHandlerResponseWriter returns a new testHandlerResponseWriter,
// typed as an http.ResponseWriter. It is backed by a fresh
// httptest.ResponseRecorder and a close-notify channel of capacity 1.
func newTestHandlerResponseWriter() http.ResponseWriter {
	var rec testHandlerResponseWriter
	rec.ResponseRecorder = httptest.NewRecorder()
	rec.closeNotify = make(chan bool, 1)
	return rec
}
231
232type handleStreamTest struct {
233	t     *testing.T
234	bodyw *io.PipeWriter
235	rw    testHandlerResponseWriter
236	ht    *serverHandlerTransport
237}
238
239func newHandleStreamTest(t *testing.T) *handleStreamTest {
240	bodyr, bodyw := io.Pipe()
241	req := &http.Request{
242		ProtoMajor: 2,
243		Method:     "POST",
244		Header: http.Header{
245			"Content-Type": {"application/grpc"},
246		},
247		URL: &url.URL{
248			Path: "/service/foo.bar",
249		},
250		RequestURI: "/service/foo.bar",
251		Body:       bodyr,
252	}
253	rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
254	ht, err := NewServerHandlerTransport(rw, req, nil)
255	if err != nil {
256		t.Fatal(err)
257	}
258	return &handleStreamTest{
259		t:     t,
260		bodyw: bodyw,
261		ht:    ht.(*serverHandlerTransport),
262		rw:    rw,
263	}
264}
265
266func TestHandlerTransport_HandleStreams(t *testing.T) {
267	st := newHandleStreamTest(t)
268	handleStream := func(s *Stream) {
269		if want := "/service/foo.bar"; s.method != want {
270			t.Errorf("stream method = %q; want %q", s.method, want)
271		}
272		st.bodyw.Close() // no body
273		st.ht.WriteStatus(s, status.New(codes.OK, ""))
274	}
275	st.ht.HandleStreams(
276		func(s *Stream) { go handleStream(s) },
277		func(ctx context.Context, method string) context.Context { return ctx },
278	)
279	wantHeader := http.Header{
280		"Date":         nil,
281		"Content-Type": {"application/grpc"},
282		"Trailer":      {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
283		"Grpc-Status":  {"0"},
284	}
285	if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) {
286		t.Errorf("Header+Trailer Map: %#v; want %#v", st.rw.HeaderMap, wantHeader)
287	}
288}
289
290// Tests that codes.Unimplemented will close the body, per comment in handler_server.go.
291func TestHandlerTransport_HandleStreams_Unimplemented(t *testing.T) {
292	handleStreamCloseBodyTest(t, codes.Unimplemented, "thingy is unimplemented")
293}
294
295// Tests that codes.InvalidArgument will close the body, per comment in handler_server.go.
296func TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
297	handleStreamCloseBodyTest(t, codes.InvalidArgument, "bad arg")
298}
299
300func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
301	st := newHandleStreamTest(t)
302
303	handleStream := func(s *Stream) {
304		st.ht.WriteStatus(s, status.New(statusCode, msg))
305	}
306	st.ht.HandleStreams(
307		func(s *Stream) { go handleStream(s) },
308		func(ctx context.Context, method string) context.Context { return ctx },
309	)
310	wantHeader := http.Header{
311		"Date":         nil,
312		"Content-Type": {"application/grpc"},
313		"Trailer":      {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
314		"Grpc-Status":  {fmt.Sprint(uint32(statusCode))},
315		"Grpc-Message": {encodeGrpcMessage(msg)},
316	}
317
318	if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) {
319		t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader)
320	}
321}
322
323func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
324	bodyr, bodyw := io.Pipe()
325	req := &http.Request{
326		ProtoMajor: 2,
327		Method:     "POST",
328		Header: http.Header{
329			"Content-Type": {"application/grpc"},
330			"Grpc-Timeout": {"200m"},
331		},
332		URL: &url.URL{
333			Path: "/service/foo.bar",
334		},
335		RequestURI: "/service/foo.bar",
336		Body:       bodyr,
337	}
338	rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
339	ht, err := NewServerHandlerTransport(rw, req, nil)
340	if err != nil {
341		t.Fatal(err)
342	}
343	runStream := func(s *Stream) {
344		defer bodyw.Close()
345		select {
346		case <-s.ctx.Done():
347		case <-time.After(5 * time.Second):
348			t.Errorf("timeout waiting for ctx.Done")
349			return
350		}
351		err := s.ctx.Err()
352		if err != context.DeadlineExceeded {
353			t.Errorf("ctx.Err = %v; want %v", err, context.DeadlineExceeded)
354			return
355		}
356		ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow"))
357	}
358	ht.HandleStreams(
359		func(s *Stream) { go runStream(s) },
360		func(ctx context.Context, method string) context.Context { return ctx },
361	)
362	wantHeader := http.Header{
363		"Date":         nil,
364		"Content-Type": {"application/grpc"},
365		"Trailer":      {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
366		"Grpc-Status":  {"4"},
367		"Grpc-Message": {encodeGrpcMessage("too slow")},
368	}
369	if !reflect.DeepEqual(rw.HeaderMap, wantHeader) {
370		t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader)
371	}
372}
373
374// TestHandlerTransport_HandleStreams_MultiWriteStatus ensures that
375// concurrent "WriteStatus"s do not panic writing to closed "writes" channel.
376func TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) {
377	testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
378		if want := "/service/foo.bar"; s.method != want {
379			t.Errorf("stream method = %q; want %q", s.method, want)
380		}
381		st.bodyw.Close() // no body
382
383		var wg sync.WaitGroup
384		wg.Add(5)
385		for i := 0; i < 5; i++ {
386			go func() {
387				defer wg.Done()
388				st.ht.WriteStatus(s, status.New(codes.OK, ""))
389			}()
390		}
391		wg.Wait()
392	})
393}
394
395// TestHandlerTransport_HandleStreams_WriteStatusWrite ensures that "Write"
396// following "WriteStatus" does not panic writing to closed "writes" channel.
397func TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
398	testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
399		if want := "/service/foo.bar"; s.method != want {
400			t.Errorf("stream method = %q; want %q", s.method, want)
401		}
402		st.bodyw.Close() // no body
403
404		st.ht.WriteStatus(s, status.New(codes.OK, ""))
405		st.ht.Write(s, []byte("hdr"), []byte("data"), &Options{})
406	})
407}
408
409func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) {
410	st := newHandleStreamTest(t)
411	st.ht.HandleStreams(
412		func(s *Stream) { go handleStream(st, s) },
413		func(ctx context.Context, method string) context.Context { return ctx },
414	)
415}
416
417func TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
418	errDetails := []proto.Message{
419		&epb.RetryInfo{
420			RetryDelay: &dpb.Duration{Seconds: 60},
421		},
422		&epb.ResourceInfo{
423			ResourceType: "foo bar",
424			ResourceName: "service.foo.bar",
425			Owner:        "User",
426		},
427	}
428
429	statusCode := codes.ResourceExhausted
430	msg := "you are being throttled"
431	st, err := status.New(statusCode, msg).WithDetails(errDetails...)
432	if err != nil {
433		t.Fatal(err)
434	}
435
436	stBytes, err := proto.Marshal(st.Proto())
437	if err != nil {
438		t.Fatal(err)
439	}
440
441	hst := newHandleStreamTest(t)
442	handleStream := func(s *Stream) {
443		hst.ht.WriteStatus(s, st)
444	}
445	hst.ht.HandleStreams(
446		func(s *Stream) { go handleStream(s) },
447		func(ctx context.Context, method string) context.Context { return ctx },
448	)
449	wantHeader := http.Header{
450		"Date":                    nil,
451		"Content-Type":            {"application/grpc"},
452		"Trailer":                 {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
453		"Grpc-Status":             {fmt.Sprint(uint32(statusCode))},
454		"Grpc-Message":            {encodeGrpcMessage(msg)},
455		"Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)},
456	}
457
458	if !reflect.DeepEqual(hst.rw.HeaderMap, wantHeader) {
459		t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", hst.rw.HeaderMap, wantHeader)
460	}
461}
462