1package lightstep
2
3import (
4	"encoding/base64"
5	"io"
6	"io/ioutil"
7
8	"github.com/gogo/protobuf/proto"
9	lightstep "github.com/lightstep/lightstep-tracer-common/golang/gogo/lightsteppb"
10	"github.com/opentracing/opentracing-go"
11)
12
13// BinaryCarrier is used as the format parameter in inject/extract for lighstep binary propagation.
14const BinaryCarrier = opentracing.Binary
15
16var theBinaryPropagator binaryPropagator
17
18type binaryPropagator struct{}
19
// Inject serializes spanContext into opaqueCarrier as a base64-encoded
// (standard, padded alphabet) lightstep BinaryCarrier protobuf.
//
// Supported carriers:
//   - io.Writer: the encoded bytes are written to it;
//   - *string: the pointee is replaced with the encoded text;
//   - *[]byte: the pointee is replaced with a new slice holding the encoded bytes.
//
// It returns opentracing.ErrInvalidSpanContext when spanContext is not this
// package's SpanContext, opentracing.ErrInvalidCarrier for any other carrier
// type or for a nil *string / *[]byte, and otherwise any error from
// marshalling or from the writer.
func (binaryPropagator) Inject(
	spanContext opentracing.SpanContext,
	opaqueCarrier interface{},
) error {
	sc, ok := spanContext.(SpanContext)
	if !ok {
		return opentracing.ErrInvalidSpanContext
	}
	// The sampled flag is always encoded as true.
	data, err := proto.Marshal(&lightstep.BinaryCarrier{
		BasicCtx: &lightstep.BasicTracerCarrier{
			TraceId:      sc.TraceID,
			SpanId:       sc.SpanID,
			Sampled:      true,
			BaggageItems: sc.Baggage,
		},
	})
	if err != nil {
		return err
	}

	// Encode once; every carrier kind receives the same base64 bytes.
	encoded := make([]byte, base64.StdEncoding.EncodedLen(len(data)))
	base64.StdEncoding.Encode(encoded, data)

	switch carrier := opaqueCarrier.(type) {
	case io.Writer:
		// io.Writer must return a non-nil error on a short write, so the
		// byte count can be ignored.
		_, err = carrier.Write(encoded)
		return err
	case *string:
		// A nil pointer would panic on dereference; reject it instead.
		if carrier == nil {
			return opentracing.ErrInvalidCarrier
		}
		*carrier = string(encoded)
	case *[]byte:
		if carrier == nil {
			return opentracing.ErrInvalidCarrier
		}
		*carrier = encoded
	default:
		return opentracing.ErrInvalidCarrier
	}
	return nil
}
56
// Extract decodes a SpanContext from opaqueCarrier, which must hold the
// base64-encoded (standard, padded alphabet) lightstep BinaryCarrier
// protobuf produced by Inject.
//
// Supported carriers are io.Reader (read to EOF), string, *string, []byte,
// and *[]byte. It returns opentracing.ErrInvalidCarrier for any other
// carrier type, for a nil *string or *[]byte, and when the decoded message
// has no basic tracer context; read, base64, and protobuf errors are
// returned unchanged. Only the trace ID, span ID, and baggage are restored.
func (binaryPropagator) Extract(
	opaqueCarrier interface{},
) (opentracing.SpanContext, error) {
	var (
		data []byte
		err  error
	)

	switch carrier := opaqueCarrier.(type) {
	case io.Reader:
		var buf []byte
		if buf, err = ioutil.ReadAll(carrier); err != nil {
			return nil, err
		}
		// Plain assignment (not :=) so a decode failure reaches the shared
		// err check below rather than being lost in a shadowed variable.
		data, err = decodeBase64Bytes(buf)
	case *string:
		if carrier == nil {
			return nil, opentracing.ErrInvalidCarrier
		}
		data, err = base64.StdEncoding.DecodeString(*carrier)
	case string:
		data, err = base64.StdEncoding.DecodeString(carrier)
	case *[]byte:
		if carrier == nil {
			return nil, opentracing.ErrInvalidCarrier
		}
		data, err = decodeBase64Bytes(*carrier)
	case []byte:
		data, err = decodeBase64Bytes(carrier)
	default:
		return nil, opentracing.ErrInvalidCarrier
	}
	if err != nil {
		return nil, err
	}

	pb := &lightstep.BinaryCarrier{}
	if err = proto.Unmarshal(data, pb); err != nil {
		return nil, err
	}
	if pb.BasicCtx == nil {
		return nil, opentracing.ErrInvalidCarrier
	}

	return SpanContext{
		TraceID: pb.BasicCtx.TraceId,
		SpanID:  pb.BasicCtx.SpanId,
		Baggage: pb.BasicCtx.BaggageItems,
	}, nil
}
103
// decodeBase64Bytes decodes in using the standard, padded base64 alphabet.
// On success it returns the decoded bytes; on failure it returns a nil slice
// together with the decoder's error.
func decodeBase64Bytes(in []byte) ([]byte, error) {
	enc := base64.StdEncoding
	// DecodedLen is an upper bound; the real length comes back from Decode.
	out := make([]byte, enc.DecodedLen(len(in)))
	written, err := enc.Decode(out, in)
	if err == nil {
		return out[:written], nil
	}
	return nil, err
}
112