1// Copyright 2015 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:build !plan9
6// +build !plan9
7
8package ctxhttp
9
10import (
11	"context"
12	"io"
13	"io/ioutil"
14	"net/http"
15	"net/http/httptest"
16	"testing"
17	"time"
18)
19
20func TestGo17Context(t *testing.T) {
21	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
22		io.WriteString(w, "ok")
23	}))
24	defer ts.Close()
25	ctx := context.Background()
26	resp, err := Get(ctx, http.DefaultClient, ts.URL)
27	if resp == nil || err != nil {
28		t.Fatalf("error received from client: %v %v", err, resp)
29	}
30	resp.Body.Close()
31}
32
// Fixtures shared by the handlers and tests in this file.
const (
	// requestBody is the payload the test handlers write to the response.
	requestBody = "ok"

	// requestDuration is how long okHandler sleeps before responding.
	requestDuration = 100 * time.Millisecond
)
37
38func okHandler(w http.ResponseWriter, r *http.Request) {
39	time.Sleep(requestDuration)
40	io.WriteString(w, requestBody)
41}
42
43func TestNoTimeout(t *testing.T) {
44	ts := httptest.NewServer(http.HandlerFunc(okHandler))
45	defer ts.Close()
46
47	ctx := context.Background()
48	res, err := Get(ctx, nil, ts.URL)
49	if err != nil {
50		t.Fatal(err)
51	}
52	defer res.Body.Close()
53	slurp, err := ioutil.ReadAll(res.Body)
54	if err != nil {
55		t.Fatal(err)
56	}
57	if string(slurp) != requestBody {
58		t.Errorf("body = %q; want %q", slurp, requestBody)
59	}
60}
61
62func TestCancelBeforeHeaders(t *testing.T) {
63	ctx, cancel := context.WithCancel(context.Background())
64
65	blockServer := make(chan struct{})
66	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
67		cancel()
68		<-blockServer
69		io.WriteString(w, requestBody)
70	}))
71	defer ts.Close()
72	defer close(blockServer)
73
74	res, err := Get(ctx, nil, ts.URL)
75	if err == nil {
76		res.Body.Close()
77		t.Fatal("Get returned unexpected nil error")
78	}
79	if err != context.Canceled {
80		t.Errorf("err = %v; want %v", err, context.Canceled)
81	}
82}
83
84func TestCancelAfterHangingRequest(t *testing.T) {
85	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
86		w.WriteHeader(http.StatusOK)
87		w.(http.Flusher).Flush()
88		<-w.(http.CloseNotifier).CloseNotify()
89	}))
90	defer ts.Close()
91
92	ctx, cancel := context.WithCancel(context.Background())
93	resp, err := Get(ctx, nil, ts.URL)
94	if err != nil {
95		t.Fatalf("unexpected error in Get: %v", err)
96	}
97
98	// Cancel befer reading the body.
99	// Reading Request.Body should fail, since the request was
100	// canceled before anything was written.
101	cancel()
102
103	done := make(chan struct{})
104
105	go func() {
106		b, err := ioutil.ReadAll(resp.Body)
107		if len(b) != 0 || err == nil {
108			t.Errorf(`Read got (%q, %v); want ("", error)`, b, err)
109		}
110		close(done)
111	}()
112
113	select {
114	case <-time.After(1 * time.Second):
115		t.Errorf("Test timed out")
116	case <-done:
117	}
118}
119