1// Copyright 2018, OpenCensus Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package tracecontext
16
17import (
18	"fmt"
19	"net/http"
20	"reflect"
21	"strings"
22	"testing"
23
24	"go.opencensus.io/trace"
25	"go.opencensus.io/trace/tracestate"
26)
27
28var (
29	tpHeader        = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
30	traceID         = trace.TraceID{75, 249, 47, 53, 119, 179, 77, 166, 163, 206, 146, 157, 14, 14, 71, 54}
31	spanID          = trace.SpanID{0, 240, 103, 170, 11, 169, 2, 183}
32	traceOpt        = trace.TraceOptions(1)
33	oversizeValue   = strings.Repeat("a", maxTracestateLen/2)
34	oversizeEntry1  = tracestate.Entry{Key: "foo", Value: oversizeValue}
35	oversizeEntry2  = tracestate.Entry{Key: "hello", Value: oversizeValue}
36	entry1          = tracestate.Entry{Key: "foo", Value: "bar"}
37	entry2          = tracestate.Entry{Key: "hello", Value: "world   example"}
38	oversizeTs, _   = tracestate.New(nil, oversizeEntry1, oversizeEntry2)
39	defaultTs, _    = tracestate.New(nil, nil...)
40	nonDefaultTs, _ = tracestate.New(nil, entry1, entry2)
41)
42
43func TestHTTPFormat_FromRequest(t *testing.T) {
44	tests := []struct {
45		name   string
46		header string
47		wantSc trace.SpanContext
48		wantOk bool
49	}{
50		{
51			name:   "future version",
52			header: "02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
53			wantSc: trace.SpanContext{
54				TraceID:      trace.TraceID{75, 249, 47, 53, 119, 179, 77, 166, 163, 206, 146, 157, 14, 14, 71, 54},
55				SpanID:       trace.SpanID{0, 240, 103, 170, 11, 169, 2, 183},
56				TraceOptions: trace.TraceOptions(1),
57			},
58			wantOk: true,
59		},
60		{
61			name:   "zero trace ID and span ID",
62			header: "00-00000000000000000000000000000000-0000000000000000-01",
63			wantSc: trace.SpanContext{},
64			wantOk: false,
65		},
66		{
67			name:   "valid header",
68			header: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
69			wantSc: trace.SpanContext{
70				TraceID:      trace.TraceID{75, 249, 47, 53, 119, 179, 77, 166, 163, 206, 146, 157, 14, 14, 71, 54},
71				SpanID:       trace.SpanID{0, 240, 103, 170, 11, 169, 2, 183},
72				TraceOptions: trace.TraceOptions(1),
73			},
74			wantOk: true,
75		},
76		{
77			name:   "missing options",
78			header: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7",
79			wantSc: trace.SpanContext{},
80			wantOk: false,
81		},
82		{
83			name:   "empty options",
84			header: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-",
85			wantSc: trace.SpanContext{},
86			wantOk: false,
87		},
88	}
89
90	f := &HTTPFormat{}
91	for _, tt := range tests {
92		t.Run(tt.name, func(t *testing.T) {
93			req, _ := http.NewRequest("GET", "http://example.com", nil)
94			req.Header.Set("traceparent", tt.header)
95
96			gotSc, gotOk := f.SpanContextFromRequest(req)
97			if !reflect.DeepEqual(gotSc, tt.wantSc) {
98				t.Errorf("HTTPFormat.FromRequest() gotSc = %v, want %v", gotSc, tt.wantSc)
99			}
100			if gotOk != tt.wantOk {
101				t.Errorf("HTTPFormat.FromRequest() gotOk = %v, want %v", gotOk, tt.wantOk)
102			}
103
104			gotSc, gotOk = f.SpanContextFromHeaders(tt.header, "")
105			if !reflect.DeepEqual(gotSc, tt.wantSc) {
106				t.Errorf("HTTPFormat.SpanContextFromHeaders() gotTs = %v, want %v", gotSc.Tracestate, tt.wantSc.Tracestate)
107			}
108			if gotOk != tt.wantOk {
109				t.Errorf("HTTPFormat.SpanContextFromHeaders() gotOk = %v, want %v", gotOk, tt.wantOk)
110			}
111		})
112	}
113}
114
115func TestHTTPFormat_ToRequest(t *testing.T) {
116	tests := []struct {
117		sc         trace.SpanContext
118		wantHeader string
119	}{
120		{
121			sc: trace.SpanContext{
122				TraceID:      trace.TraceID{75, 249, 47, 53, 119, 179, 77, 166, 163, 206, 146, 157, 14, 14, 71, 54},
123				SpanID:       trace.SpanID{0, 240, 103, 170, 11, 169, 2, 183},
124				TraceOptions: trace.TraceOptions(1),
125			},
126			wantHeader: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
127		},
128	}
129	for _, tt := range tests {
130		t.Run(tt.wantHeader, func(t *testing.T) {
131			f := &HTTPFormat{}
132			req, _ := http.NewRequest("GET", "http://example.com", nil)
133			f.SpanContextToRequest(tt.sc, req)
134
135			h := req.Header.Get("traceparent")
136			if got, want := h, tt.wantHeader; got != want {
137				t.Errorf("HTTPFormat.ToRequest() header = %v, want %v", got, want)
138			}
139
140			gotTp, _ := f.SpanContextToHeaders(tt.sc)
141			if gotTp != tt.wantHeader {
142				t.Errorf("HTTPFormat.SpanContextToHeaders() tracestate header = %v, want %v", gotTp, tt.wantHeader)
143			}
144		})
145	}
146}
147
148func TestHTTPFormatTracestate_FromRequest(t *testing.T) {
149	scWithNonDefaultTracestate := trace.SpanContext{
150		TraceID:      traceID,
151		SpanID:       spanID,
152		TraceOptions: traceOpt,
153		Tracestate:   nonDefaultTs,
154	}
155
156	scWithDefaultTracestate := trace.SpanContext{
157		TraceID:      traceID,
158		SpanID:       spanID,
159		TraceOptions: traceOpt,
160		Tracestate:   defaultTs,
161	}
162
163	tests := []struct {
164		name     string
165		tpHeader string
166		tsHeader string
167		wantSc   trace.SpanContext
168		wantOk   bool
169	}{
170		{
171			name:     "tracestate invalid entries delimiter",
172			tpHeader: tpHeader,
173			tsHeader: "foo=bar;hello=world",
174			wantSc:   scWithDefaultTracestate,
175			wantOk:   true,
176		},
177		{
178			name:     "tracestate invalid key-value delimiter",
179			tpHeader: tpHeader,
180			tsHeader: "foo=bar,hello-world",
181			wantSc:   scWithDefaultTracestate,
182			wantOk:   true,
183		},
184		{
185			name:     "tracestate invalid value character",
186			tpHeader: tpHeader,
187			tsHeader: "foo=bar,hello=world   example   \u00a0  ",
188			wantSc:   scWithDefaultTracestate,
189			wantOk:   true,
190		},
191		{
192			name:     "tracestate blank key-value",
193			tpHeader: tpHeader,
194			tsHeader: "foo=bar,    ",
195			wantSc:   scWithDefaultTracestate,
196			wantOk:   true,
197		},
198		{
199			name:     "tracestate oversize header",
200			tpHeader: tpHeader,
201			tsHeader: fmt.Sprintf("foo=%s,hello=%s", oversizeValue, oversizeValue),
202			wantSc:   scWithDefaultTracestate,
203			wantOk:   true,
204		},
205		{
206			name:     "tracestate valid",
207			tpHeader: tpHeader,
208			tsHeader: "foo=bar   ,   hello=world   example",
209			wantSc:   scWithNonDefaultTracestate,
210			wantOk:   true,
211		},
212	}
213
214	f := &HTTPFormat{}
215	for _, tt := range tests {
216		t.Run(tt.name, func(t *testing.T) {
217			req, _ := http.NewRequest("GET", "http://example.com", nil)
218			req.Header.Set("traceparent", tt.tpHeader)
219			req.Header.Set("tracestate", tt.tsHeader)
220
221			gotSc, gotOk := f.SpanContextFromRequest(req)
222			if !reflect.DeepEqual(gotSc, tt.wantSc) {
223				t.Errorf("HTTPFormat.FromRequest() gotTs = %v, want %v", gotSc.Tracestate, tt.wantSc.Tracestate)
224			}
225			if gotOk != tt.wantOk {
226				t.Errorf("HTTPFormat.FromRequest() gotOk = %v, want %v", gotOk, tt.wantOk)
227			}
228
229			gotSc, gotOk = f.SpanContextFromHeaders(tt.tpHeader, tt.tsHeader)
230			if !reflect.DeepEqual(gotSc, tt.wantSc) {
231				t.Errorf("HTTPFormat.SpanContextFromHeaders() gotTs = %v, want %v", gotSc.Tracestate, tt.wantSc.Tracestate)
232			}
233			if gotOk != tt.wantOk {
234				t.Errorf("HTTPFormat.SpanContextFromHeaders() gotOk = %v, want %v", gotOk, tt.wantOk)
235			}
236		})
237	}
238}
239
240func TestHTTPFormatTracestate_ToRequest(t *testing.T) {
241	tests := []struct {
242		name       string
243		sc         trace.SpanContext
244		wantHeader string
245	}{
246		{
247			name: "valid span context with default tracestate",
248			sc: trace.SpanContext{
249				TraceID:      traceID,
250				SpanID:       spanID,
251				TraceOptions: traceOpt,
252			},
253			wantHeader: "",
254		},
255		{
256			name: "valid span context with non default tracestate",
257			sc: trace.SpanContext{
258				TraceID:      traceID,
259				SpanID:       spanID,
260				TraceOptions: traceOpt,
261				Tracestate:   nonDefaultTs,
262			},
263			wantHeader: "foo=bar,hello=world   example",
264		},
265		{
266			name: "valid span context with oversize tracestate",
267			sc: trace.SpanContext{
268				TraceID:      traceID,
269				SpanID:       spanID,
270				TraceOptions: traceOpt,
271				Tracestate:   oversizeTs,
272			},
273			wantHeader: "",
274		},
275	}
276	for _, tt := range tests {
277		t.Run(tt.name, func(t *testing.T) {
278			f := &HTTPFormat{}
279			req, _ := http.NewRequest("GET", "http://example.com", nil)
280			f.SpanContextToRequest(tt.sc, req)
281
282			h := req.Header.Get("tracestate")
283			if got, want := h, tt.wantHeader; got != want {
284				t.Errorf("HTTPFormat.ToRequest() tracestate header = %v, want %v", got, want)
285			}
286
287			_, gotTs := f.SpanContextToHeaders(tt.sc)
288			if gotTs != tt.wantHeader {
289				t.Errorf("HTTPFormat.SpanContextToHeaders() tracestate header = %v, want %v", gotTs, tt.wantHeader)
290			}
291		})
292	}
293}
294