1// Copyright 2018, OpenCensus Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
// Package tracecontext contains an HTTP propagator for the TraceContext standard.
// See https://github.com/w3c/distributed-tracing for more information.
17package tracecontext // import "go.opencensus.io/plugin/ochttp/propagation/tracecontext"
18
19import (
20	"encoding/hex"
21	"fmt"
22	"net/http"
23	"net/textproto"
24	"regexp"
25	"strings"
26
27	"go.opencensus.io/trace"
28	"go.opencensus.io/trace/propagation"
29	"go.opencensus.io/trace/tracestate"
30)
31
// Constants describing the W3C TraceContext headers handled by this package.
const (
	// Names of the two propagated headers.
	traceparentHeader = "traceparent"
	tracestateHeader  = "tracestate"

	// supportedVersion is the traceparent version written by SpanContextToRequest.
	supportedVersion = 0
	// maxVersion is the largest traceparent version accepted when parsing;
	// anything above it is rejected.
	maxVersion = 254

	// maxTracestateLen bounds the length of the tracestate header, both when
	// parsing an incoming header and when writing an outgoing one.
	maxTracestateLen = 512

	// trimOWSRegexFmt matches a string containing at least one character that
	// is not a space or horizontal tab, capturing it without its leading and
	// trailing spaces/tabs (optional whitespace, "OWS").
	trimOWSRegexFmt = `^[\x09\x20]*(.*[^\x20\x09])[\x09\x20]*$`
)

// trimOWSRegExp is the compiled form of trimOWSRegexFmt, built once at init.
var (
	trimOWSRegExp = regexp.MustCompile(trimOWSRegexFmt)
)
42
43var _ propagation.HTTPFormat = (*HTTPFormat)(nil)
44
45// HTTPFormat implements the TraceContext trace propagation format.
46type HTTPFormat struct{}
47
48// SpanContextFromRequest extracts a span context from incoming requests.
49func (f *HTTPFormat) SpanContextFromRequest(req *http.Request) (sc trace.SpanContext, ok bool) {
50	h, ok := getRequestHeader(req, traceparentHeader, false)
51	if !ok {
52		return trace.SpanContext{}, false
53	}
54	sections := strings.Split(h, "-")
55	if len(sections) < 4 {
56		return trace.SpanContext{}, false
57	}
58
59	if len(sections[0]) != 2 {
60		return trace.SpanContext{}, false
61	}
62	ver, err := hex.DecodeString(sections[0])
63	if err != nil {
64		return trace.SpanContext{}, false
65	}
66	version := int(ver[0])
67	if version > maxVersion {
68		return trace.SpanContext{}, false
69	}
70
71	if version == 0 && len(sections) != 4 {
72		return trace.SpanContext{}, false
73	}
74
75	if len(sections[1]) != 32 {
76		return trace.SpanContext{}, false
77	}
78	tid, err := hex.DecodeString(sections[1])
79	if err != nil {
80		return trace.SpanContext{}, false
81	}
82	copy(sc.TraceID[:], tid)
83
84	if len(sections[2]) != 16 {
85		return trace.SpanContext{}, false
86	}
87	sid, err := hex.DecodeString(sections[2])
88	if err != nil {
89		return trace.SpanContext{}, false
90	}
91	copy(sc.SpanID[:], sid)
92
93	opts, err := hex.DecodeString(sections[3])
94	if err != nil || len(opts) < 1 {
95		return trace.SpanContext{}, false
96	}
97	sc.TraceOptions = trace.TraceOptions(opts[0])
98
99	// Don't allow all zero trace or span ID.
100	if sc.TraceID == [16]byte{} || sc.SpanID == [8]byte{} {
101		return trace.SpanContext{}, false
102	}
103
104	sc.Tracestate = tracestateFromRequest(req)
105	return sc, true
106}
107
// getRequestHeader looks up the header called name on req, combining repeated
// fields as described in RFC 7230 section 3.2.2.
//
// The lookup is case-insensitive because name is canonicalized first.
//   - Absent header: returns "" and ok == false.
//   - Single occurrence: returns that value and ok == true.
//   - Multiple occurrences: returns the values joined with ",". ok equals
//     commaSeparated, so repeated fields are accepted only for headers whose
//     value is a comma-separated list.
func getRequestHeader(req *http.Request, name string, commaSeparated bool) (hdr string, ok bool) {
	values := req.Header[textproto.CanonicalMIMEHeaderKey(name)]
	if len(values) == 0 {
		return "", false
	}
	if len(values) == 1 {
		return values[0], true
	}
	return strings.Join(values, ","), commaSeparated
}
125
126// TODO(rghetia): return an empty Tracestate when parsing tracestate header encounters an error.
127// Revisit to return additional boolean value to indicate parsing error when following issues
128// are resolved.
129// https://github.com/w3c/distributed-tracing/issues/172
130// https://github.com/w3c/distributed-tracing/issues/175
131func tracestateFromRequest(req *http.Request) *tracestate.Tracestate {
132	h, _ := getRequestHeader(req, tracestateHeader, true)
133	if h == "" {
134		return nil
135	}
136
137	var entries []tracestate.Entry
138	pairs := strings.Split(h, ",")
139	hdrLenWithoutOWS := len(pairs) - 1 // Number of commas
140	for _, pair := range pairs {
141		matches := trimOWSRegExp.FindStringSubmatch(pair)
142		if matches == nil {
143			return nil
144		}
145		pair = matches[1]
146		hdrLenWithoutOWS += len(pair)
147		if hdrLenWithoutOWS > maxTracestateLen {
148			return nil
149		}
150		kv := strings.Split(pair, "=")
151		if len(kv) != 2 {
152			return nil
153		}
154		entries = append(entries, tracestate.Entry{Key: kv[0], Value: kv[1]})
155	}
156	ts, err := tracestate.New(nil, entries...)
157	if err != nil {
158		return nil
159	}
160
161	return ts
162}
163
164func tracestateToRequest(sc trace.SpanContext, req *http.Request) {
165	var pairs = make([]string, 0, len(sc.Tracestate.Entries()))
166	if sc.Tracestate != nil {
167		for _, entry := range sc.Tracestate.Entries() {
168			pairs = append(pairs, strings.Join([]string{entry.Key, entry.Value}, "="))
169		}
170		h := strings.Join(pairs, ",")
171
172		if h != "" && len(h) <= maxTracestateLen {
173			req.Header.Set(tracestateHeader, h)
174		}
175	}
176}
177
178// SpanContextToRequest modifies the given request to include traceparent and tracestate headers.
179func (f *HTTPFormat) SpanContextToRequest(sc trace.SpanContext, req *http.Request) {
180	h := fmt.Sprintf("%x-%x-%x-%x",
181		[]byte{supportedVersion},
182		sc.TraceID[:],
183		sc.SpanID[:],
184		[]byte{byte(sc.TraceOptions)})
185	req.Header.Set(traceparentHeader, h)
186	tracestateToRequest(sc, req)
187}
188