1// Copyright 2018, OpenCensus Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package ochttp_test
16
17import (
18	"fmt"
19	"io/ioutil"
20	"net/http"
21	"net/http/httptest"
22	"strings"
23	"sync"
24	"testing"
25
26	"go.opencensus.io/plugin/ochttp"
27	"go.opencensus.io/stats/view"
28	"go.opencensus.io/trace"
29)
30
// reqCount is the number of concurrent requests each client test issues; it
// is also the count every checked view is expected to report afterwards.
const (
	reqCount = 5
)
32
33func TestClientNew(t *testing.T) {
34	server := httptest.NewServer(http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
35		resp.Write([]byte("Hello, world!"))
36	}))
37	defer server.Close()
38
39	if err := view.Register(
40		ochttp.ClientSentBytesDistribution,
41		ochttp.ClientReceivedBytesDistribution,
42		ochttp.ClientRoundtripLatencyDistribution,
43		ochttp.ClientCompletedCount,
44	); err != nil {
45		t.Fatalf("Failed to register ochttp.DefaultClientViews error: %v", err)
46	}
47
48	views := []string{
49		"opencensus.io/http/client/sent_bytes",
50		"opencensus.io/http/client/received_bytes",
51		"opencensus.io/http/client/roundtrip_latency",
52		"opencensus.io/http/client/completed_count",
53	}
54	for _, name := range views {
55		v := view.Find(name)
56		if v == nil {
57			t.Errorf("view not found %q", name)
58			continue
59		}
60	}
61
62	var wg sync.WaitGroup
63	var tr ochttp.Transport
64	errs := make(chan error, reqCount)
65	wg.Add(reqCount)
66
67	for i := 0; i < reqCount; i++ {
68		go func() {
69			defer wg.Done()
70			req, err := http.NewRequest("POST", server.URL, strings.NewReader("req-body"))
71			if err != nil {
72				errs <- fmt.Errorf("error creating request: %v", err)
73			}
74			resp, err := tr.RoundTrip(req)
75			if err != nil {
76				errs <- fmt.Errorf("response error: %v", err)
77			}
78			if err := resp.Body.Close(); err != nil {
79				errs <- fmt.Errorf("error closing response body: %v", err)
80			}
81			if got, want := resp.StatusCode, 200; got != want {
82				errs <- fmt.Errorf("resp.StatusCode=%d; wantCount %d", got, want)
83			}
84		}()
85	}
86
87	go func() {
88		wg.Wait()
89		close(errs)
90	}()
91
92	for err := range errs {
93		if err != nil {
94			t.Fatal(err)
95		}
96	}
97
98	for _, viewName := range views {
99		v := view.Find(viewName)
100		if v == nil {
101			t.Errorf("view not found %q", viewName)
102			continue
103		}
104		rows, err := view.RetrieveData(v.Name)
105		if err != nil {
106			t.Error(err)
107			continue
108		}
109		if got, want := len(rows), 1; got != want {
110			t.Errorf("len(%q) = %d; want %d", viewName, got, want)
111			continue
112		}
113		data := rows[0].Data
114		var count int64
115		switch data := data.(type) {
116		case *view.CountData:
117			count = data.Value
118		case *view.DistributionData:
119			count = data.Count
120		default:
121			t.Errorf("Unknown data type: %v", data)
122			continue
123		}
124		if got := count; got != reqCount {
125			t.Fatalf("%s = %d; want %d", viewName, got, reqCount)
126		}
127	}
128}
129
130func TestClientOld(t *testing.T) {
131	server := httptest.NewServer(http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
132		resp.Write([]byte("Hello, world!"))
133	}))
134	defer server.Close()
135
136	if err := view.Register(ochttp.DefaultClientViews...); err != nil {
137		t.Fatalf("Failed to register ochttp.DefaultClientViews error: %v", err)
138	}
139
140	views := []string{
141		"opencensus.io/http/client/request_count",
142		"opencensus.io/http/client/latency",
143		"opencensus.io/http/client/request_bytes",
144		"opencensus.io/http/client/response_bytes",
145	}
146	for _, name := range views {
147		v := view.Find(name)
148		if v == nil {
149			t.Errorf("view not found %q", name)
150			continue
151		}
152	}
153
154	var wg sync.WaitGroup
155	var tr ochttp.Transport
156	errs := make(chan error, reqCount)
157	wg.Add(reqCount)
158
159	for i := 0; i < reqCount; i++ {
160		go func() {
161			defer wg.Done()
162			req, err := http.NewRequest("POST", server.URL, strings.NewReader("req-body"))
163			if err != nil {
164				errs <- fmt.Errorf("error creating request: %v", err)
165			}
166			resp, err := tr.RoundTrip(req)
167			if err != nil {
168				errs <- fmt.Errorf("response error: %v", err)
169			}
170			if err := resp.Body.Close(); err != nil {
171				errs <- fmt.Errorf("error closing response body: %v", err)
172			}
173			if got, want := resp.StatusCode, 200; got != want {
174				errs <- fmt.Errorf("resp.StatusCode=%d; wantCount %d", got, want)
175			}
176		}()
177	}
178
179	go func() {
180		wg.Wait()
181		close(errs)
182	}()
183
184	for err := range errs {
185		if err != nil {
186			t.Fatal(err)
187		}
188	}
189
190	for _, viewName := range views {
191		v := view.Find(viewName)
192		if v == nil {
193			t.Errorf("view not found %q", viewName)
194			continue
195		}
196		rows, err := view.RetrieveData(v.Name)
197		if err != nil {
198			t.Error(err)
199			continue
200		}
201		if got, want := len(rows), 1; got != want {
202			t.Errorf("len(%q) = %d; want %d", viewName, got, want)
203			continue
204		}
205		data := rows[0].Data
206		var count int64
207		switch data := data.(type) {
208		case *view.CountData:
209			count = data.Value
210		case *view.DistributionData:
211			count = data.Count
212		default:
213			t.Errorf("Unknown data type: %v", data)
214			continue
215		}
216		if got := count; got != reqCount {
217			t.Fatalf("%s = %d; want %d", viewName, got, reqCount)
218		}
219	}
220}
221
222var noTrace = trace.StartOptions{Sampler: trace.NeverSample()}
223
224func BenchmarkTransportNoTrace(b *testing.B) {
225	benchmarkClientServer(b, &ochttp.Transport{StartOptions: noTrace})
226}
227
228func BenchmarkTransport(b *testing.B) {
229	benchmarkClientServer(b, &ochttp.Transport{})
230}
231
232func benchmarkClientServer(b *testing.B, transport *ochttp.Transport) {
233	b.ReportAllocs()
234	ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
235		fmt.Fprintf(rw, "Hello world.\n")
236	}))
237	defer ts.Close()
238	transport.StartOptions.Sampler = trace.AlwaysSample()
239	var client http.Client
240	client.Transport = transport
241	b.ResetTimer()
242
243	for i := 0; i < b.N; i++ {
244		res, err := client.Get(ts.URL)
245		if err != nil {
246			b.Fatalf("Get: %v", err)
247		}
248		all, err := ioutil.ReadAll(res.Body)
249		res.Body.Close()
250		if err != nil {
251			b.Fatal("ReadAll:", err)
252		}
253		body := string(all)
254		if body != "Hello world.\n" {
255			b.Fatal("Got body:", body)
256		}
257	}
258}
259
260func BenchmarkTransportParallel64NoTrace(b *testing.B) {
261	benchmarkClientServerParallel(b, 64, &ochttp.Transport{StartOptions: noTrace})
262}
263
264func BenchmarkTransportParallel64(b *testing.B) {
265	benchmarkClientServerParallel(b, 64, &ochttp.Transport{})
266}
267
268func benchmarkClientServerParallel(b *testing.B, parallelism int, transport *ochttp.Transport) {
269	b.ReportAllocs()
270	ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
271		fmt.Fprintf(rw, "Hello world.\n")
272	}))
273	defer ts.Close()
274
275	var c http.Client
276	transport.Base = &http.Transport{
277		MaxIdleConns:        parallelism,
278		MaxIdleConnsPerHost: parallelism,
279	}
280	transport.StartOptions.Sampler = trace.AlwaysSample()
281	c.Transport = transport
282
283	b.ResetTimer()
284
285	// TODO(ramonza): replace with b.RunParallel (it didn't work when I tried)
286
287	var wg sync.WaitGroup
288	wg.Add(parallelism)
289	for i := 0; i < parallelism; i++ {
290		iterations := b.N / parallelism
291		if i == 0 {
292			iterations += b.N % parallelism
293		}
294		go func() {
295			defer wg.Done()
296			for j := 0; j < iterations; j++ {
297				res, err := c.Get(ts.URL)
298				if err != nil {
299					b.Logf("Get: %v", err)
300					return
301				}
302				all, err := ioutil.ReadAll(res.Body)
303				res.Body.Close()
304				if err != nil {
305					b.Logf("ReadAll: %v", err)
306					return
307				}
308				body := string(all)
309				if body != "Hello world.\n" {
310					panic("Got body: " + body)
311				}
312			}
313		}()
314	}
315	wg.Wait()
316}
317