1// Copyright (c) 2017 Uber Technologies, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package remote
16
17import (
18	"encoding/json"
19	"io"
20	"net/http"
21	"net/http/httptest"
22	"net/url"
23	"testing"
24	"time"
25
26	"github.com/stretchr/testify/assert"
27	"github.com/stretchr/testify/require"
28	"github.com/uber/jaeger-lib/metrics"
29	"github.com/uber/jaeger-lib/metrics/metricstest"
30	"go.uber.org/atomic"
31
32	"github.com/uber/jaeger-client-go"
33	"github.com/uber/jaeger-client-go/internal/baggage"
34	thrift "github.com/uber/jaeger-client-go/thrift-gen/baggage"
35)
36
// Fixtures shared by the tests in this file.
const (
	// service is the service name every restriction manager is created for.
	service = "svc"

	// expectedKey is the baggage key present in testRestrictions.
	expectedKey = "key"

	// expectedSize is the max value length configured for expectedKey.
	expectedSize = 10
)
42
43var (
44	testRestrictions = []*thrift.BaggageRestriction{
45		{BaggageKey: expectedKey, MaxValueLength: int32(expectedSize)},
46	}
47)
48
// Compile-time assertion that *RestrictionManager implements io.Closer.
var _ io.Closer = (*RestrictionManager)(nil)
50
51type baggageHandler struct {
52	returnError  *atomic.Bool
53	restrictions []*thrift.BaggageRestriction
54}
55
// ServeHTTP implements http.Handler. While returnError is set it answers with
// 500 Internal Server Error; otherwise it writes h.restrictions as a JSON body
// with a Content-Type of application/json.
func (h *baggageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if h.returnError.Load() {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	body, err := json.Marshal(h.restrictions)
	if err != nil {
		// Surface an encoding failure as a server error instead of silently
		// replying 200 with an empty body, which would mislead the client.
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// Headers must be set before the first Write commits the response.
	w.Header().Set("Content-Type", "application/json")
	// A write error cannot be reported anywhere useful from a handler; the
	// client under test observes any failure directly.
	_, _ = w.Write(body)
}
65
66func (h *baggageHandler) setReturnError(b bool) {
67	h.returnError.Store(b)
68}
69
// withHTTPServer runs f against a freshly started httptest server backed by a
// baggageHandler that serves the given restrictions. The handler starts in
// error mode (every request gets a 500) until f calls setReturnError(false).
// f also receives jaeger metrics wired to an in-memory metricstest factory so
// it can assert on emitted counters. The server is closed once f returns.
func withHTTPServer(
	restrictions []*thrift.BaggageRestriction,
	f func(
		metrics *jaeger.Metrics,
		factory *metricstest.Factory,
		handler *baggageHandler,
		server *httptest.Server,
	),
) {
	h := &baggageHandler{
		returnError:  atomic.NewBool(true),
		restrictions: restrictions,
	}
	srv := httptest.NewServer(h)
	defer srv.Close()

	mf := metricstest.NewFactory(0)
	f(jaeger.NewMetrics(mf, nil), mf, h, srv)
}
88
// TestNewRemoteRestrictionManager checks that a manager pointed at a healthy
// endpoint becomes ready, applies the fetched restriction for expectedKey,
// disallows keys that have no restriction, and records one successful update.
func TestNewRemoteRestrictionManager(t *testing.T) {
	withHTTPServer(
		testRestrictions,
		func(
			m *jaeger.Metrics,
			factory *metricstest.Factory,
			handler *baggageHandler,
			server *httptest.Server,
		) {
			handler.setReturnError(false)

			mgr := NewRestrictionManager(
				service,
				Options.HostPort(getHostPort(t, server.URL)),
				Options.Metrics(m),
				Options.Logger(jaeger.NullLogger),
			)
			defer mgr.Close()

			// Initialization is asynchronous; poll for up to ~100ms.
			for attempt := 0; attempt < 100 && !mgr.isReady(); attempt++ {
				time.Sleep(time.Millisecond)
			}
			require.True(t, mgr.isReady())

			assert.EqualValues(t,
				baggage.NewRestriction(true, expectedSize),
				mgr.GetRestriction(service, expectedKey))

			// A key absent from the fetched restrictions is not allowed.
			assert.EqualValues(t,
				baggage.NewRestriction(false, 0),
				mgr.GetRestriction(service, "bad-key"))

			factory.AssertCounterMetrics(t, metricstest.ExpectedMetric{
				Name:  "jaeger.tracer.baggage_restrictions_updates",
				Tags:  map[string]string{"result": "ok"},
				Value: 1,
			})
		})
}
131
// TestDenyBaggageOnInitializationError checks that, with
// DenyBaggageOnInitializationFailure enabled, a manager whose initial fetch
// fails disallows every key and records the failed update. It also checks that
// the manager applies the real restrictions once a later fetch succeeds.
func TestDenyBaggageOnInitializationError(t *testing.T) {
	withHTTPServer(
		testRestrictions,
		func(
			m *jaeger.Metrics,
			factory *metricstest.Factory,
			handler *baggageHandler,
			server *httptest.Server,
		) {
			// The handler starts in error mode, so initialization fails.
			mgr := NewRestrictionManager(
				service,
				Options.DenyBaggageOnInitializationFailure(true),
				Options.HostPort(getHostPort(t, server.URL)),
				Options.Metrics(m),
				Options.Logger(jaeger.NullLogger),
			)
			// Close the manager before withHTTPServer shuts the server down so
			// that nothing the manager runs in the background outlives the test.
			defer mgr.Close()
			require.False(t, mgr.isReady())

			metricName := "jaeger.tracer.baggage_restrictions_updates"
			metricTags := map[string]string{"result": "err"}
			key := metrics.GetKey(metricName, metricTags, "|", "=")
			// Wait (up to ~100ms) until the async initialization call is complete.
			for i := 0; i < 100; i++ {
				counters, _ := factory.Snapshot()
				if _, ok := counters[key]; ok {
					break
				}
				time.Sleep(time.Millisecond)
			}

			factory.AssertCounterMetrics(t,
				metricstest.ExpectedMetric{
					Name:  metricName,
					Tags:  metricTags,
					Value: 1,
				},
			)

			// DenyBaggageOnInitializationFailure should not allow any key to be written.
			restriction := mgr.GetRestriction(service, expectedKey)
			assert.EqualValues(t, baggage.NewRestriction(false, 0), restriction)

			// Have the HTTP server return restrictions and trigger a refresh.
			handler.setReturnError(false)
			mgr.updateRestrictions()

			// Wait (up to ~100ms) until the manager has retrieved the restrictions.
			for i := 0; i < 100; i++ {
				if mgr.isReady() {
					break
				}
				time.Sleep(time.Millisecond)
			}
			require.True(t, mgr.isReady())

			restriction = mgr.GetRestriction(service, expectedKey)
			assert.EqualValues(t, baggage.NewRestriction(true, expectedSize), restriction)
		})
}
191
// TestAllowBaggageOnInitializationFailure checks that, without
// DenyBaggageOnInitializationFailure, a manager whose initial fetch fails is
// not ready but still allows baggage keys to be written.
func TestAllowBaggageOnInitializationFailure(t *testing.T) {
	withHTTPServer(
		testRestrictions,
		func(
			m *jaeger.Metrics,
			factory *metricstest.Factory,
			handler *baggageHandler,
			server *httptest.Server,
		) {
			// The handler starts in error mode, so every fetch fails.
			mgr := NewRestrictionManager(
				service,
				Options.RefreshInterval(time.Millisecond),
				Options.HostPort(getHostPort(t, server.URL)),
				Options.Metrics(m),
				Options.Logger(jaeger.NullLogger),
			)
			// Close the manager before withHTTPServer shuts the server down.
			// With a 1ms refresh interval it would presumably keep polling the
			// closed endpoint after the test finishes.
			defer mgr.Close()
			require.False(t, mgr.isReady())

			// AllowBaggageOnInitializationFailure should allow any key to be written.
			// NOTE(review): 2048 is presumably the manager's default max value
			// length — confirm against the RestrictionManager options.
			restriction := mgr.GetRestriction(service, expectedKey)
			assert.EqualValues(t, baggage.NewRestriction(true, 2048), restriction)
		})
}
215
// getHostPort returns the host[:port] component of the URL s. It fails the
// test immediately if s cannot be parsed. Tests use it to turn an httptest
// server URL into the value passed to Options.HostPort.
func getHostPort(t *testing.T, s string) string {
	// Mark as a helper so a parse failure is reported at the caller's line.
	t.Helper()
	u, err := url.Parse(s)
	require.NoError(t, err, "Failed to parse url")
	return u.Host
}
221