1// Unless explicitly stated otherwise all files in this repository are licensed
2// under the Apache License Version 2.0.
3// This product includes software developed at Datadog (https://www.datadoghq.com/).
4// Copyright 2016 Datadog, Inc.
5
6package twirp
7
8import (
9	"context"
10	"fmt"
11	"net"
12	"net/http"
13	"testing"
14
15	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
16	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer"
17	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
18	"gopkg.in/DataDog/dd-trace-go.v1/internal/globalconfig"
19
20	"github.com/stretchr/testify/assert"
21	"github.com/twitchtv/twirp"
22	"github.com/twitchtv/twirp/ctxsetters"
23	"github.com/twitchtv/twirp/example"
24)
25
// mockClient is a test double for the HTTP client passed to WrapClient. It
// never touches the network: Do either fails with err or fabricates a
// response carrying the status code in code.
type mockClient struct {
	code int   // HTTP status code reported when err is nil
	err  error // when non-nil, returned by Do instead of a response
}

// Do returns mc.err when it is set. Otherwise it returns a synthetic response
// with status mc.code whose protocol fields mirror req. The caller's request
// is never modified.
func (mc *mockClient) Do(req *http.Request) (*http.Response, error) {
	if mc.err != nil {
		return nil, mc.err
	}
	// http.Response.Request documents that the attached request's Body is nil
	// (it has already been consumed). Clear it on a shallow copy so the
	// request owned by the caller keeps its body.
	sent := req.WithContext(req.Context())
	sent.Body = nil

	// Real clients always return a non-nil Body and Header. Mirror that so
	// code under test can safely close the body or read headers.
	return &http.Response{
		Status:     fmt.Sprintf("%d %s", mc.code, http.StatusText(mc.code)),
		StatusCode: mc.code,
		Proto:      req.Proto,
		ProtoMajor: req.ProtoMajor,
		ProtoMinor: req.ProtoMinor,
		Header:     make(http.Header),
		Body:       http.NoBody,
		Request:    sent,
	}, nil
}
47
48func TestClient(t *testing.T) {
49	mt := mocktracer.Start()
50	defer mt.Stop()
51
52	ctx := context.Background()
53	ctx = ctxsetters.WithPackageName(ctx, "twirp.test")
54	ctx = ctxsetters.WithServiceName(ctx, "Example")
55	ctx = ctxsetters.WithMethodName(ctx, "Method")
56
57	url := "http://localhost/twirp/twirp.test/Example/Method"
58
59	t.Run("success", func(t *testing.T) {
60		defer mt.Reset()
61		assert := assert.New(t)
62
63		mc := &mockClient{code: 200}
64		wc := WrapClient(mc)
65
66		req, err := http.NewRequest("POST", url, nil)
67		assert.NoError(err)
68		req = req.WithContext(ctx)
69
70		_, err = wc.Do(req)
71		assert.NoError(err)
72
73		spans := mt.FinishedSpans()
74		assert.Len(spans, 1)
75		span := spans[0]
76		assert.Equal(ext.SpanTypeHTTP, span.Tag(ext.SpanType))
77		assert.Equal("twirp.request", span.OperationName())
78		assert.Equal("twirp.request", span.Tag(ext.ResourceName))
79		assert.Equal("twirp.test", span.Tag("twirp.package"))
80		assert.Equal("Example", span.Tag("twirp.service"))
81		assert.Equal("Method", span.Tag("twirp.method"))
82		assert.Equal("200", span.Tag(ext.HTTPCode))
83	})
84
85	t.Run("server-error", func(t *testing.T) {
86		defer mt.Reset()
87		assert := assert.New(t)
88
89		mc := &mockClient{code: 500}
90		wc := WrapClient(mc)
91
92		req, err := http.NewRequest("POST", url, nil)
93		assert.NoError(err)
94		req = req.WithContext(ctx)
95
96		_, err = wc.Do(req)
97		assert.NoError(err)
98
99		spans := mt.FinishedSpans()
100		assert.Len(spans, 1)
101		span := spans[0]
102		assert.Equal(ext.SpanTypeHTTP, span.Tag(ext.SpanType))
103		assert.Equal("twirp.request", span.OperationName())
104		assert.Equal("twirp.request", span.Tag(ext.ResourceName))
105		assert.Equal("twirp.test", span.Tag("twirp.package"))
106		assert.Equal("Example", span.Tag("twirp.service"))
107		assert.Equal("Method", span.Tag("twirp.method"))
108		assert.Equal("500", span.Tag(ext.HTTPCode))
109		assert.Equal(true, span.Tag(ext.Error).(bool))
110	})
111
112	t.Run("timeout", func(t *testing.T) {
113		defer mt.Reset()
114		assert := assert.New(t)
115
116		mc := &mockClient{err: context.DeadlineExceeded}
117		wc := WrapClient(mc)
118
119		req, err := http.NewRequest("POST", url, nil)
120		assert.NoError(err)
121		req = req.WithContext(ctx)
122
123		_, err = wc.Do(req)
124		assert.Equal(context.DeadlineExceeded, err)
125
126		spans := mt.FinishedSpans()
127		assert.Len(spans, 1)
128		span := spans[0]
129		assert.Equal(ext.SpanTypeHTTP, span.Tag(ext.SpanType))
130		assert.Equal("twirp.request", span.OperationName())
131		assert.Equal("twirp.request", span.Tag(ext.ResourceName))
132		assert.Equal("twirp.test", span.Tag("twirp.package"))
133		assert.Equal("Example", span.Tag("twirp.service"))
134		assert.Equal("Method", span.Tag("twirp.method"))
135		assert.Equal(context.DeadlineExceeded, span.Tag(ext.Error))
136	})
137}
138
139func mockServer(hooks *twirp.ServerHooks, assert *assert.Assertions, twerr twirp.Error) {
140	ctx := context.Background()
141	ctx = ctxsetters.WithPackageName(ctx, "twirp.test")
142	ctx = ctxsetters.WithServiceName(ctx, "Example")
143	ctx, err := hooks.RequestReceived(ctx)
144	assert.NoError(err)
145
146	ctx = ctxsetters.WithMethodName(ctx, "Method")
147	ctx, err = hooks.RequestRouted(ctx)
148	assert.NoError(err)
149
150	if twerr != nil {
151		ctx = ctxsetters.WithStatusCode(ctx, twirp.ServerHTTPStatusFromErrorCode(twerr.Code()))
152		ctx = hooks.Error(ctx, twerr)
153	} else {
154		ctx = hooks.ResponsePrepared(ctx)
155		ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK)
156	}
157
158	hooks.ResponseSent(ctx)
159}
160
161func TestServerHooks(t *testing.T) {
162	mt := mocktracer.Start()
163	defer mt.Stop()
164	hooks := NewServerHooks(WithServiceName("twirp-test"), WithAnalytics(true))
165
166	t.Run("success", func(t *testing.T) {
167		defer mt.Reset()
168		assert := assert.New(t)
169
170		mockServer(hooks, assert, nil)
171
172		spans := mt.FinishedSpans()
173		assert.Len(spans, 1)
174		span := spans[0]
175		assert.Equal(ext.SpanTypeWeb, span.Tag(ext.SpanType))
176		assert.Equal("twirp-test", span.Tag(ext.ServiceName))
177		assert.Equal("twirp.Example", span.OperationName())
178		assert.Equal("twirp.test", span.Tag("twirp.package"))
179		assert.Equal("Example", span.Tag("twirp.service"))
180		assert.Equal("Method", span.Tag("twirp.method"))
181		assert.Equal("200", span.Tag(ext.HTTPCode))
182	})
183
184	t.Run("error", func(t *testing.T) {
185		defer mt.Reset()
186		assert := assert.New(t)
187
188		mockServer(hooks, assert, twirp.InternalError("something bad or unexpected happened"))
189
190		spans := mt.FinishedSpans()
191		assert.Len(spans, 1)
192		span := spans[0]
193		assert.Equal(ext.SpanTypeWeb, span.Tag(ext.SpanType))
194		assert.Equal("twirp-test", span.Tag(ext.ServiceName))
195		assert.Equal("twirp.Example", span.OperationName())
196		assert.Equal("twirp.test", span.Tag("twirp.package"))
197		assert.Equal("Example", span.Tag("twirp.service"))
198		assert.Equal("Method", span.Tag("twirp.method"))
199		assert.Equal("500", span.Tag(ext.HTTPCode))
200		assert.Equal("twirp error internal: something bad or unexpected happened", span.Tag(ext.Error).(error).Error())
201	})
202
203	t.Run("chained", func(t *testing.T) {
204		defer mt.Reset()
205		assert := assert.New(t)
206
207		otherHooks := &twirp.ServerHooks{
208			RequestReceived: func(ctx context.Context) (context.Context, error) {
209				_, ctx = tracer.StartSpanFromContext(ctx, "other.span.name")
210				return ctx, nil
211			},
212			ResponseSent: func(ctx context.Context) {
213				span, ok := tracer.SpanFromContext(ctx)
214				if !ok {
215					return
216				}
217				span.Finish()
218			},
219		}
220		mockServer(twirp.ChainHooks(hooks, otherHooks), assert, twirp.InternalError("something bad or unexpected happened"))
221
222		spans := mt.FinishedSpans()
223		assert.Len(spans, 2)
224		span := spans[0]
225		assert.Equal(ext.SpanTypeWeb, span.Tag(ext.SpanType))
226		assert.Equal("twirp-test", span.Tag(ext.ServiceName))
227		assert.Equal("twirp.Example", span.OperationName())
228		assert.Equal("twirp.test", span.Tag("twirp.package"))
229		assert.Equal("Example", span.Tag("twirp.service"))
230		assert.Equal("Method", span.Tag("twirp.method"))
231		assert.Equal("500", span.Tag(ext.HTTPCode))
232		assert.Equal("twirp error internal: something bad or unexpected happened", span.Tag(ext.Error).(error).Error())
233
234		span = spans[1]
235		assert.Equal("other.span.name", span.OperationName())
236	})
237}
238
239func TestAnalyticsSettings(t *testing.T) {
240	assertRate := func(t *testing.T, mt mocktracer.Tracer, rate interface{}, opts ...Option) {
241		hooks := NewServerHooks(opts...)
242		assert := assert.New(t)
243		mockServer(hooks, assert, nil)
244
245		spans := mt.FinishedSpans()
246		assert.Len(spans, 1)
247		s := spans[0]
248		assert.Equal(rate, s.Tag(ext.EventSampleRate))
249	}
250
251	t.Run("defaults", func(t *testing.T) {
252		mt := mocktracer.Start()
253		defer mt.Stop()
254
255		assertRate(t, mt, nil)
256	})
257
258	t.Run("global", func(t *testing.T) {
259		mt := mocktracer.Start()
260		defer mt.Stop()
261
262		rate := globalconfig.AnalyticsRate()
263		defer globalconfig.SetAnalyticsRate(rate)
264		globalconfig.SetAnalyticsRate(0.4)
265
266		assertRate(t, mt, 0.4)
267	})
268
269	t.Run("enabled", func(t *testing.T) {
270		mt := mocktracer.Start()
271		defer mt.Stop()
272
273		assertRate(t, mt, 1.0, WithAnalytics(true))
274	})
275
276	t.Run("disabled", func(t *testing.T) {
277		mt := mocktracer.Start()
278		defer mt.Stop()
279
280		assertRate(t, mt, nil, WithAnalytics(false))
281	})
282
283	t.Run("override", func(t *testing.T) {
284		mt := mocktracer.Start()
285		defer mt.Stop()
286
287		rate := globalconfig.AnalyticsRate()
288		defer globalconfig.SetAnalyticsRate(rate)
289		globalconfig.SetAnalyticsRate(0.4)
290
291		assertRate(t, mt, 0.23, WithAnalyticsRate(0.23))
292	})
293}
294
// notifyListener wraps a net.Listener and closes ch the first time Accept is
// called. TestHaberdash uses this to wait until the serving goroutine has
// started accepting connections.
type notifyListener struct {
	net.Listener
	ch chan<- struct{}
}

// Accept closes and clears the notification channel on its first call.
// Later calls skip that step, so the channel is closed only once. Every call
// then delegates to the embedded listener.
func (n *notifyListener) Accept() (net.Conn, error) {
	if ready := n.ch; ready != nil {
		n.ch = nil
		close(ready)
	}
	return n.Listener.Accept()
}
307
308type haberdasher int32
309
310func (h haberdasher) MakeHat(ctx context.Context, size *example.Size) (*example.Hat, error) {
311	if size.Inches != int32(h) {
312		return nil, twirp.InvalidArgumentError("Inches", "Only size of %d is allowed")
313	}
314	hat := &example.Hat{
315		Size:  size.Inches,
316		Color: "purple",
317		Name:  "doggie beanie",
318	}
319	return hat, nil
320}
321
322func TestHaberdash(t *testing.T) {
323	mt := mocktracer.Start()
324	defer mt.Stop()
325	assert := assert.New(t)
326
327	l, err := net.Listen("tcp4", "127.0.0.1:0")
328	assert.NoError(err)
329	defer l.Close()
330
331	readyCh := make(chan struct{})
332	nl := &notifyListener{Listener: l, ch: readyCh}
333
334	server := WrapServer(example.NewHaberdasherServer(haberdasher(6), NewServerHooks()))
335	errCh := make(chan error)
336	go func() {
337		err := http.Serve(nl, server)
338		if err != nil && err != http.ErrServerClosed {
339			errCh <- err
340		}
341		close(errCh)
342	}()
343
344	select {
345	case <-readyCh:
346		break
347	case err := <-errCh:
348		assert.FailNow("server not started", err)
349	}
350
351	client := example.NewHaberdasherJSONClient("http://"+nl.Addr().String(), WrapClient(&http.Client{}))
352	hat, err := client.MakeHat(context.Background(), &example.Size{Inches: 6})
353	assert.NoError(err)
354	assert.Equal("purple", hat.Color)
355
356	spans := mt.FinishedSpans()
357	assert.Len(spans, 3)
358	assert.Equal(ext.SpanTypeWeb, spans[0].Tag(ext.SpanType))
359	assert.Equal(ext.SpanTypeWeb, spans[1].Tag(ext.SpanType))
360	assert.Equal(ext.SpanTypeHTTP, spans[2].Tag(ext.SpanType))
361}
362