1// Copyright 2015 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// +build !plan9
6
7package ctxhttp
8
9import (
10	"context"
11	"io"
12	"io/ioutil"
13	"net/http"
14	"net/http/httptest"
15	"testing"
16	"time"
17)
18
19func TestGo17Context(t *testing.T) {
20	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
21		io.WriteString(w, "ok")
22	}))
23	defer ts.Close()
24	ctx := context.Background()
25	resp, err := Get(ctx, http.DefaultClient, ts.URL)
26	if resp == nil || err != nil {
27		t.Fatalf("error received from client: %v %v", err, resp)
28	}
29	resp.Body.Close()
30}
31
// Fixtures shared by okHandler and the tests that use it.
const (
	// requestDuration is how long okHandler sleeps before writing its body.
	requestDuration = time.Second / 10

	// requestBody is the exact response body okHandler writes.
	requestBody = "ok"
)
36
37func okHandler(w http.ResponseWriter, r *http.Request) {
38	time.Sleep(requestDuration)
39	io.WriteString(w, requestBody)
40}
41
42func TestNoTimeout(t *testing.T) {
43	ts := httptest.NewServer(http.HandlerFunc(okHandler))
44	defer ts.Close()
45
46	ctx := context.Background()
47	res, err := Get(ctx, nil, ts.URL)
48	if err != nil {
49		t.Fatal(err)
50	}
51	defer res.Body.Close()
52	slurp, err := ioutil.ReadAll(res.Body)
53	if err != nil {
54		t.Fatal(err)
55	}
56	if string(slurp) != requestBody {
57		t.Errorf("body = %q; want %q", slurp, requestBody)
58	}
59}
60
61func TestCancelBeforeHeaders(t *testing.T) {
62	ctx, cancel := context.WithCancel(context.Background())
63
64	blockServer := make(chan struct{})
65	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
66		cancel()
67		<-blockServer
68		io.WriteString(w, requestBody)
69	}))
70	defer ts.Close()
71	defer close(blockServer)
72
73	res, err := Get(ctx, nil, ts.URL)
74	if err == nil {
75		res.Body.Close()
76		t.Fatal("Get returned unexpected nil error")
77	}
78	if err != context.Canceled {
79		t.Errorf("err = %v; want %v", err, context.Canceled)
80	}
81}
82
83func TestCancelAfterHangingRequest(t *testing.T) {
84	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
85		w.WriteHeader(http.StatusOK)
86		w.(http.Flusher).Flush()
87		<-w.(http.CloseNotifier).CloseNotify()
88	}))
89	defer ts.Close()
90
91	ctx, cancel := context.WithCancel(context.Background())
92	resp, err := Get(ctx, nil, ts.URL)
93	if err != nil {
94		t.Fatalf("unexpected error in Get: %v", err)
95	}
96
97	// Cancel befer reading the body.
98	// Reading Request.Body should fail, since the request was
99	// canceled before anything was written.
100	cancel()
101
102	done := make(chan struct{})
103
104	go func() {
105		b, err := ioutil.ReadAll(resp.Body)
106		if len(b) != 0 || err == nil {
107			t.Errorf(`Read got (%q, %v); want ("", error)`, b, err)
108		}
109		close(done)
110	}()
111
112	select {
113	case <-time.After(1 * time.Second):
114		t.Errorf("Test timed out")
115	case <-done:
116	}
117}
118