1// Copyright (c) 2017 Uber Technologies, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package jaeger
16
17import (
18	"errors"
19	"fmt"
20	"testing"
21	"time"
22
23	"github.com/stretchr/testify/assert"
24	"github.com/stretchr/testify/require"
25	"github.com/uber/jaeger-lib/metrics"
26	mTestutils "github.com/uber/jaeger-lib/metrics/metricstest"
27
28	"github.com/uber/jaeger-client-go/log"
29	"github.com/uber/jaeger-client-go/testutils"
30	"github.com/uber/jaeger-client-go/thrift-gen/sampling"
31)
32
// TestRemotelyControlledSampler_updateRace drives a RemotelyControlledSampler
// from several goroutines at once: strategy refreshes, sampling decisions and
// the span lifecycle callbacks. Running it under `go test -race` surfaces any
// unsynchronized access. It makes no assertions of its own.
func TestRemotelyControlledSampler_updateRace(t *testing.T) {
	m := &Metrics{
		SamplerRetrieved: metrics.NullCounter,
		SamplerUpdated:   metrics.NullCounter,
	}
	initSampler, _ := NewProbabilisticSampler(0.123)
	logger := log.NullLogger
	fetcher := &testSamplingStrategyFetcher{response: []byte("probabilistic")}
	parser := new(testSamplingStrategyParser)
	updaters := []SamplerUpdater{new(ProbabilisticSamplerUpdater)}
	sampler := NewRemotelyControlledSampler(
		"test",
		SamplerOptions.Metrics(m),
		SamplerOptions.MaxOperations(42),
		SamplerOptions.OperationNameLateBinding(true),
		SamplerOptions.InitialSampler(initSampler),
		SamplerOptions.Logger(logger),
		SamplerOptions.SamplingServerURL("my url"),
		SamplerOptions.SamplingRefreshInterval(time.Millisecond),
		SamplerOptions.SamplingStrategyFetcher(fetcher),
		SamplerOptions.SamplingStrategyParser(parser),
		SamplerOptions.Updaters(updaters...),
	)

	s := makeSpan(1, "test")

	// Each worker is called in a tight loop on its own goroutine.
	workers := []func(){
		func() { sampler.UpdateSampler() },
		func() { sampler.IsSampled(TraceID{Low: 1}, "test") },
		func() { sampler.OnCreateSpan(s) },
		func() { sampler.OnSetTag(s, "test", 1) },
		func() { sampler.OnFinishSpan(s) },
		func() { sampler.OnSetOperationName(s, "test") },
	}

	end := make(chan struct{})
	// Buffered so an exiting goroutine never blocks on its completion signal.
	done := make(chan struct{}, len(workers))

	// accessor calls f repeatedly until end is closed, then signals on done.
	accessor := func(f func()) {
		defer func() { done <- struct{}{} }()
		for {
			select {
			case <-end:
				return
			default:
				f()
			}
		}
	}

	for _, f := range workers {
		go accessor(f)
	}

	time.Sleep(100 * time.Millisecond)
	close(end)
	// Join every accessor before closing the sampler. This ensures none of them
	// is still calling into it, and none outlives the test.
	for range workers {
		<-done
	}
	sampler.Close()
}
99
// testSamplingStrategyFetcher is a stub fetcher, used with
// SamplerOptions.SamplingStrategyFetcher. It returns a fixed payload
// regardless of the service being queried.
type testSamplingStrategyFetcher struct {
	// response is returned verbatim by every Fetch call.
	response []byte
}

// Fetch ignores serviceName and returns the configured response with a nil error.
func (c *testSamplingStrategyFetcher) Fetch(serviceName string) ([]byte, error) {
	return c.response, nil
}
107
108type testSamplingStrategyParser struct {
109}
110
111func (p *testSamplingStrategyParser) Parse(response []byte) (interface{}, error) {
112	strategy := new(sampling.SamplingStrategyResponse)
113
114	switch string(response) {
115	case "probabilistic":
116		strategy.StrategyType = sampling.SamplingStrategyType_PROBABILISTIC
117		strategy.ProbabilisticSampling = &sampling.ProbabilisticSamplingStrategy{
118			SamplingRate: 0.85,
119		}
120		return strategy, nil
121	}
122
123	return nil, errors.New("unknown strategy test request")
124}
125
// TestRemoteSamplerOptions verifies that each SamplerOptions setter is
// reflected in the corresponding field of the constructed sampler.
func TestRemoteSamplerOptions(t *testing.T) {
	m := new(Metrics)
	initSampler, _ := NewProbabilisticSampler(0.123)
	logger := log.NullLogger
	fetcher := new(fakeSamplingFetcher)
	parser := new(samplingStrategyParser)
	updaters := []SamplerUpdater{new(ProbabilisticSamplerUpdater)}
	sampler := NewRemotelyControlledSampler(
		"test",
		SamplerOptions.Metrics(m),
		SamplerOptions.MaxOperations(42),
		SamplerOptions.OperationNameLateBinding(true),
		SamplerOptions.InitialSampler(initSampler),
		SamplerOptions.Logger(logger),
		SamplerOptions.SamplingServerURL("my url"),
		SamplerOptions.SamplingRefreshInterval(42*time.Second),
		SamplerOptions.SamplingStrategyFetcher(fetcher),
		SamplerOptions.SamplingStrategyParser(parser),
		SamplerOptions.Updaters(updaters...),
	)
	// Stop the sampler's timer-based updates when the test ends so its
	// background work does not outlive the test.
	defer sampler.Close()

	assert.Same(t, m, sampler.metrics)
	assert.Equal(t, 42, sampler.posParams.MaxOperations)
	assert.True(t, sampler.posParams.OperationNameLateBinding)
	assert.Same(t, initSampler, sampler.Sampler())
	assert.Same(t, logger, sampler.logger)
	assert.Equal(t, "my url", sampler.samplingServerURL)
	assert.Equal(t, 42*time.Second, sampler.samplingRefreshInterval)
	assert.Same(t, fetcher, sampler.samplingFetcher)
	assert.Same(t, parser, sampler.samplingParser)
	assert.Same(t, updaters[0], sampler.updaters[0])
}
157
// TestRemoteSamplerOptionsDefaults checks the defaults filled in when no
// options are supplied: a probabilistic initial sampler with rate 0.001, plus
// a non-nil logger and metrics, a non-empty server URL and a non-zero refresh
// interval.
func TestRemoteSamplerOptionsDefaults(t *testing.T) {
	options := new(samplerOptions).applyOptionsAndDefaults()
	// require, not assert: sampler is dereferenced on the next line and would
	// be nil if the default were not a *ProbabilisticSampler.
	sampler, ok := options.sampler.(*ProbabilisticSampler)
	require.True(t, ok)
	assert.Equal(t, 0.001, sampler.samplingRate)

	assert.NotNil(t, options.logger)
	assert.NotEmpty(t, options.samplingServerURL)
	assert.NotNil(t, options.metrics)
	assert.NotZero(t, options.samplingRefreshInterval)
}
169
// initAgent starts a mock agent and builds a RemotelyControlledSampler for
// service "client app" that fetches its strategy from that agent. It starts
// from a 0.001 probabilistic sampler and reports metrics into the returned
// factory. The sampler is closed before returning so that timer-based updates
// are stopped and tests call UpdateSampler themselves.
func initAgent(t *testing.T) (*testutils.MockAgent, *RemotelyControlledSampler, *mTestutils.Factory) {
	agent, err := testutils.StartMockAgent()
	require.NoError(t, err)

	factory := mTestutils.NewFactory(0)
	samplerMetrics := NewMetrics(factory, nil)
	initial, _ := NewProbabilisticSampler(0.001)
	samplingURL := "http://" + agent.SamplingServerAddr()

	remote := NewRemotelyControlledSampler(
		"client app",
		SamplerOptions.Metrics(samplerMetrics),
		SamplerOptions.SamplingServerURL(samplingURL),
		SamplerOptions.MaxOperations(testDefaultMaxOperations),
		SamplerOptions.InitialSampler(initial),
		SamplerOptions.Logger(log.NullLogger),
		SamplerOptions.SamplingRefreshInterval(time.Minute),
	)
	// Stop timer-based updates; tests trigger UpdateSampler manually.
	remote.Close()

	return agent, remote, factory
}
191
192func makeSpan(id uint64, operationName string) *Span {
193	return &Span{
194		context: SpanContext{
195			traceID:       TraceID{Low: id},
196			samplingState: new(samplingState),
197		},
198		operationName: operationName,
199	}
200}
201
// TestRemotelyControlledSampler checks three things:
//   - UpdateSampler installs the probabilistic strategy served by the mock
//     agent and counts a successful query and update.
//   - The installed sampler makes the expected decisions.
//   - A tick delivered to pollControllerWithTicker triggers the same update.
func TestRemotelyControlledSampler(t *testing.T) {
	agent, remoteSampler, metricsFactory := initAgent(t)
	defer agent.Close()

	defaultSampler := newProbabilisticSampler(0.001)
	remoteSampler.setSampler(defaultSampler)

	agent.AddSamplingStrategy("client app",
		getSamplingStrategyResponse(sampling.SamplingStrategyType_PROBABILISTIC, testDefaultSamplingProbability))
	remoteSampler.UpdateSampler()
	metricsFactory.AssertCounterMetrics(t, []mTestutils.ExpectedMetric{
		{Name: "jaeger.tracer.sampler_queries", Tags: map[string]string{"result": "ok"}, Value: 1},
		{Name: "jaeger.tracer.sampler_updates", Tags: map[string]string{"result": "ok"}, Value: 1},
	}...)
	// require, not assert: s1 is dereferenced below and would be nil on failure.
	s1, ok := remoteSampler.Sampler().(*ProbabilisticSampler)
	require.True(t, ok)
	assert.EqualValues(t, testDefaultSamplingProbability, s1.samplingRate, "Sampler should have been updated")

	decision := remoteSampler.OnCreateSpan(makeSpan(testMaxID+10, testOperationName))
	assert.False(t, decision.Sample)
	assert.Equal(t, testProbabilisticExpectedTags, decision.Tags)
	decision = remoteSampler.OnCreateSpan(makeSpan(testMaxID-10, testOperationName))
	assert.True(t, decision.Sample)
	assert.Equal(t, testProbabilisticExpectedTags, decision.Tags)

	// Reset to the default sampler so the timer-driven update below is observable.
	remoteSampler.setSampler(defaultSampler)

	// Drive pollControllerWithTicker with a hand-built ticker so the test
	// decides exactly when an update fires.
	c := make(chan time.Time)
	ticker := &time.Ticker{C: c}
	go remoteSampler.pollControllerWithTicker(ticker)

	c <- time.Now() // force update based on timer
	// Give the triggered update time to complete.
	time.Sleep(10 * time.Millisecond)
	// NOTE(review): initAgent already closed this sampler, so this Close may not
	// stop the goroutine started above — confirm whether it is leaked.
	remoteSampler.Close()

	// require, not assert: s2 is dereferenced below and would be nil on failure.
	s2, ok := remoteSampler.Sampler().(*ProbabilisticSampler)
	require.True(t, ok)
	assert.EqualValues(t, testDefaultSamplingProbability, s2.samplingRate, "Sampler should have been updated from timer")

	assert.False(t, remoteSampler.Equal(remoteSampler)) // for code coverage only
}
243
244func makeSamplerTags(key string, value interface{}) []Tag {
245	return []Tag{
246		{"sampler.type", key},
247		{"sampler.param", value},
248	}
249}
250
// TestRemotelyControlledSampler_updateSampler feeds per-operation strategies
// to the sampler through the mock agent. It checks that the sampler switches
// to a PerOperationSampler with the expected default rate, and that decisions
// for testOperationName carry the expected tags. The 1.1 cases expect rates
// above 1.0 to end up as 1.0.
func TestRemotelyControlledSampler_updateSampler(t *testing.T) {
	tests := []struct {
		probabilities              map[string]float64
		defaultProbability         float64
		expectedDefaultProbability float64
		expectedTags               []Tag
	}{
		{
			probabilities:              map[string]float64{testOperationName: 1.1},
			defaultProbability:         testDefaultSamplingProbability,
			expectedDefaultProbability: testDefaultSamplingProbability,
			expectedTags:               makeSamplerTags("probabilistic", 1.0),
		},
		{
			probabilities:              map[string]float64{testOperationName: testDefaultSamplingProbability},
			defaultProbability:         testDefaultSamplingProbability,
			expectedDefaultProbability: testDefaultSamplingProbability,
			expectedTags:               testProbabilisticExpectedTags,
		},
		{
			probabilities: map[string]float64{
				testOperationName:          testDefaultSamplingProbability,
				testFirstTimeOperationName: testDefaultSamplingProbability,
			},
			defaultProbability:         testDefaultSamplingProbability,
			expectedDefaultProbability: testDefaultSamplingProbability,
			expectedTags:               testProbabilisticExpectedTags,
		},
		{
			probabilities:              map[string]float64{"new op": 1.1},
			defaultProbability:         testDefaultSamplingProbability,
			expectedDefaultProbability: testDefaultSamplingProbability,
			expectedTags:               testProbabilisticExpectedTags,
		},
		{
			probabilities:              map[string]float64{"new op": 1.1},
			defaultProbability:         1.1,
			expectedDefaultProbability: 1.0,
			expectedTags:               makeSamplerTags("probabilistic", 1.0),
		},
	}

	for i, test := range tests {
		t.Run(fmt.Sprintf("test_%d", i), func(t *testing.T) {
			agent, sampler, metricsFactory := initAgent(t)
			defer agent.Close()

			initSampler, ok := sampler.Sampler().(*ProbabilisticSampler)
			assert.True(t, ok)

			res := &sampling.SamplingStrategyResponse{
				StrategyType: sampling.SamplingStrategyType_PROBABILISTIC,
				OperationSampling: &sampling.PerOperationSamplingStrategies{
					DefaultSamplingProbability:       test.defaultProbability,
					DefaultLowerBoundTracesPerSecond: 0.001,
				},
			}
			for opName, prob := range test.probabilities {
				res.OperationSampling.PerOperationStrategies = append(res.OperationSampling.PerOperationStrategies,
					&sampling.OperationSamplingStrategy{
						Operation: opName,
						ProbabilisticSampling: &sampling.ProbabilisticSamplingStrategy{
							SamplingRate: prob,
						},
					},
				)
			}

			agent.AddSamplingStrategy("client app", res)
			sampler.UpdateSampler()

			metricsFactory.AssertCounterMetrics(t,
				mTestutils.ExpectedMetric{
					Name: "jaeger.tracer.sampler_updates", Tags: map[string]string{"result": "ok"}, Value: 1,
				},
			)

			// require, not assert: s is dereferenced below and would be nil if the
			// type assertion failed.
			s, ok := sampler.Sampler().(*PerOperationSampler)
			require.True(t, ok)
			assert.NotEqual(t, initSampler, sampler.Sampler(), "Sampler should have been updated")
			assert.Equal(t, test.expectedDefaultProbability, s.defaultSampler.SamplingRate())

			// First call is always sampled
			decision := sampler.OnCreateSpan(makeSpan(testMaxID+10, testOperationName))
			assert.True(t, decision.Sample)

			decision = sampler.OnCreateSpan(makeSpan(testMaxID-10, testOperationName))
			assert.True(t, decision.Sample)
			assert.Equal(t, test.expectedTags, decision.Tags)
		})
	}
}
343
// TestRemotelyControlledSampler_updateDefaultRate checks three behaviors across
// successive updates:
//   - A new default per-operation rate takes effect.
//   - An operation-specific rate overrides the default.
//   - Removing the override falls back to the default again.
func TestRemotelyControlledSampler_updateDefaultRate(t *testing.T) {
	agent, sampler, _ := initAgent(t)
	defer agent.Close()

	// res is mutated in place below. This presumably relies on the mock agent
	// keeping a reference to res rather than a copy — confirm in testutils.
	res := &sampling.SamplingStrategyResponse{
		StrategyType: sampling.SamplingStrategyType_PROBABILISTIC,
		OperationSampling: &sampling.PerOperationSamplingStrategies{
			DefaultSamplingProbability: 0.5,
		},
	}
	agent.AddSamplingStrategy("client app", res)
	sampler.UpdateSampler()

	// Check what rate we get for a specific operation
	decision := sampler.OnCreateSpan(makeSpan(0, testOperationName))
	assert.True(t, decision.Sample)
	assert.Equal(t, makeSamplerTags("probabilistic", 0.5), decision.Tags)

	// Change the default and update
	res.OperationSampling.DefaultSamplingProbability = 0.1
	sampler.UpdateSampler()

	// Check sampling rate has changed
	decision = sampler.OnCreateSpan(makeSpan(0, testOperationName))
	assert.True(t, decision.Sample)
	assert.Equal(t, makeSamplerTags("probabilistic", 0.1), decision.Tags)

	// Add an operation-specific rate
	res.OperationSampling.PerOperationStrategies = []*sampling.OperationSamplingStrategy{{
		Operation: testOperationName,
		ProbabilisticSampling: &sampling.ProbabilisticSamplingStrategy{
			SamplingRate: 0.2,
		},
	}}
	sampler.UpdateSampler()

	// Check we get the requested rate
	decision = sampler.OnCreateSpan(makeSpan(0, testOperationName))
	assert.True(t, decision.Sample)
	assert.Equal(t, makeSamplerTags("probabilistic", 0.2), decision.Tags)

	// Now remove the operation-specific rate
	res.OperationSampling.PerOperationStrategies = nil
	sampler.UpdateSampler()

	// Check we get the default rate. A fresh decision is needed here; asserting
	// on the previous one would only re-check the pre-update result.
	decision = sampler.OnCreateSpan(makeSpan(0, testOperationName))
	assert.True(t, decision.Sample)
	assert.Equal(t, makeSamplerTags("probabilistic", 0.1), decision.Tags)
}
395
// TestSamplerQueryError checks what happens when fetching the strategy fails.
// The current sampler must be left in place, and the failure must be counted
// as a sampler_queries metric with result "err".
func TestSamplerQueryError(t *testing.T) {
	agent, sampler, metricsFactory := initAgent(t)
	defer agent.Close()

	// Swap in a fetcher whose every call fails.
	sampler.samplingFetcher = &fakeSamplingFetcher{}

	before, isProbabilistic := sampler.Sampler().(*ProbabilisticSampler)
	assert.True(t, isProbabilistic)

	// initAgent has already called Close; this calls it a second time.
	sampler.Close()

	sampler.UpdateSampler()
	assert.Equal(t, before, sampler.Sampler(), "Sampler should not have been updated due to query error")

	queryFailure := mTestutils.ExpectedMetric{
		Name:  "jaeger.tracer.sampler_queries",
		Tags:  map[string]string{"result": "err"},
		Value: 1,
	}
	metricsFactory.AssertCounterMetrics(t, queryFailure)
}
415
416type fakeSamplingFetcher struct{}
417
418func (c *fakeSamplingFetcher) Fetch(serviceName string) ([]byte, error) {
419	return nil, errors.New("query error")
420}
421
// TestRemotelyControlledSampler_updateSamplerFromAdaptiveSampler starts each
// step from a per-operation (adaptive) sampler. It checks that UpdateSampler
// switches to a probabilistic or rate-limiting sampler when the agent serves
// those strategies, and that it accepts a per-operation-only strategy. All
// three updates must be counted as successful queries and updates.
func TestRemotelyControlledSampler_updateSamplerFromAdaptiveSampler(t *testing.T) {
	agent, remoteSampler, metricsFactory := initAgent(t)
	defer agent.Close()
	// initAgent has already called Close; this calls it a second time.
	remoteSampler.Close()

	strategies := &sampling.PerOperationSamplingStrategies{
		DefaultSamplingProbability:       testDefaultSamplingProbability,
		DefaultLowerBoundTracesPerSecond: 1.0,
	}
	adaptiveSampler := NewPerOperationSampler(PerOperationSamplerParams{
		MaxOperations: testDefaultMaxOperations,
		Strategies:    strategies,
	})

	// updateFrom resets the sampler to adaptiveSampler, publishes res through
	// the mock agent, and then pulls it in with UpdateSampler.
	updateFrom := func(res *sampling.SamplingStrategyResponse) {
		remoteSampler.setSampler(adaptiveSampler)
		agent.AddSamplingStrategy("client app", res)
		remoteSampler.UpdateSampler()
	}

	updateFrom(getSamplingStrategyResponse(sampling.SamplingStrategyType_PROBABILISTIC, 0.5))
	_, isProbabilistic := remoteSampler.Sampler().(*ProbabilisticSampler)
	require.True(t, isProbabilistic)

	updateFrom(getSamplingStrategyResponse(sampling.SamplingStrategyType_RATE_LIMITING, 1))
	_, isRateLimiting := remoteSampler.Sampler().(*RateLimitingSampler)
	require.True(t, isRateLimiting)

	// A response carrying only OperationSampling is expected to update the
	// existing adaptive sampler.
	updateFrom(&sampling.SamplingStrategyResponse{OperationSampling: strategies})

	metricsFactory.AssertCounterMetrics(t,
		mTestutils.ExpectedMetric{Name: "jaeger.tracer.sampler_queries", Tags: map[string]string{"result": "ok"}, Value: 3},
		mTestutils.ExpectedMetric{Name: "jaeger.tracer.sampler_updates", Tags: map[string]string{"result": "ok"}, Value: 3},
	)
}
470
// TestRemotelyControlledSampler_updateRateLimitingOrProbabilisticSampler runs
// updateSamplerViaUpdaters over probabilistic and rate-limiting strategies. It
// checks three outcomes:
//   - An unchanged strategy leaves a sampler equal to the initial one.
//   - A changed strategy installs a sampler that Equal()s the expected one.
//   - An invalid strategy returns an error.
func TestRemotelyControlledSampler_updateRateLimitingOrProbabilisticSampler(t *testing.T) {
	probabilisticSampler, err := NewProbabilisticSampler(0.002)
	require.NoError(t, err)
	otherProbabilisticSampler, err := NewProbabilisticSampler(0.003)
	require.NoError(t, err)
	maxProbabilisticSampler, err := NewProbabilisticSampler(1.0)
	require.NoError(t, err)

	rateLimitingSampler := NewRateLimitingSampler(2)
	otherRateLimitingSampler := NewRateLimitingSampler(3)

	testCases := []struct {
		res                  *sampling.SamplingStrategyResponse
		initSampler          SamplerV2
		expectedSampler      Sampler
		shouldErr            bool
		referenceEquivalence bool
		caption              string
	}{
		{
			res:                  getSamplingStrategyResponse(sampling.SamplingStrategyType_PROBABILISTIC, 1.5),
			initSampler:          probabilisticSampler,
			expectedSampler:      maxProbabilisticSampler,
			shouldErr:            true,
			referenceEquivalence: false,
			caption:              "invalid probabilistic strategy",
		},
		{
			res:                  getSamplingStrategyResponse(sampling.SamplingStrategyType_PROBABILISTIC, 0.002),
			initSampler:          probabilisticSampler,
			expectedSampler:      probabilisticSampler,
			shouldErr:            false,
			referenceEquivalence: true,
			caption:              "unchanged probabilistic strategy",
		},
		{
			res:                  getSamplingStrategyResponse(sampling.SamplingStrategyType_PROBABILISTIC, 0.003),
			initSampler:          probabilisticSampler,
			expectedSampler:      otherProbabilisticSampler,
			shouldErr:            false,
			referenceEquivalence: false,
			caption:              "valid probabilistic strategy",
		},
		{
			res:                  getSamplingStrategyResponse(sampling.SamplingStrategyType_RATE_LIMITING, 2),
			initSampler:          rateLimitingSampler,
			expectedSampler:      rateLimitingSampler,
			shouldErr:            false,
			referenceEquivalence: true,
			caption:              "unchanged rate limiting strategy",
		},
		{
			res:                  getSamplingStrategyResponse(sampling.SamplingStrategyType_RATE_LIMITING, 3),
			initSampler:          rateLimitingSampler,
			expectedSampler:      otherRateLimitingSampler,
			shouldErr:            false,
			referenceEquivalence: false,
			caption:              "valid rate limiting strategy",
		},
		{
			res:                  &sampling.SamplingStrategyResponse{},
			initSampler:          rateLimitingSampler,
			expectedSampler:      rateLimitingSampler,
			shouldErr:            true,
			referenceEquivalence: true,
			caption:              "invalid strategy",
		},
	}

	for _, tc := range testCases {
		testCase := tc // capture loop var
		t.Run(testCase.caption, func(t *testing.T) {
			remoteSampler := NewRemotelyControlledSampler(
				"test",
				SamplerOptions.InitialSampler(testCase.initSampler.(Sampler)),
				SamplerOptions.Updaters(
					new(ProbabilisticSamplerUpdater),
					new(RateLimitingSamplerUpdater),
				),
			)
			// Stop the sampler's timer-based updates when the subtest ends so its
			// background work does not outlive the subtest.
			defer remoteSampler.Close()

			err := remoteSampler.updateSamplerViaUpdaters(testCase.res)
			if testCase.shouldErr {
				require.Error(t, err)
				return
			}
			if testCase.referenceEquivalence {
				assert.Equal(t, testCase.expectedSampler, remoteSampler.Sampler())
				return
			}

			// Named equaler so it does not shadow the predeclared identifier comparable.
			type equaler interface {
				Equal(other Sampler) bool
			}
			es, esOk := testCase.expectedSampler.(equaler)
			require.True(t, esOk, "expected sampler %+v must implement Equal()", testCase.expectedSampler)
			assert.True(t, es.Equal(remoteSampler.Sampler().(Sampler)),
				"sampler.Equal: want=%+v, have=%+v", testCase.expectedSampler, remoteSampler.Sampler())
		})
	}
}
570
571func getSamplingStrategyResponse(strategyType sampling.SamplingStrategyType, value float64) *sampling.SamplingStrategyResponse {
572	if strategyType == sampling.SamplingStrategyType_PROBABILISTIC {
573		return &sampling.SamplingStrategyResponse{
574			StrategyType: sampling.SamplingStrategyType_PROBABILISTIC,
575			ProbabilisticSampling: &sampling.ProbabilisticSamplingStrategy{
576				SamplingRate: value,
577			},
578		}
579	}
580	if strategyType == sampling.SamplingStrategyType_RATE_LIMITING {
581		return &sampling.SamplingStrategyResponse{
582			StrategyType: sampling.SamplingStrategyType_RATE_LIMITING,
583			RateLimitingSampling: &sampling.RateLimitingSamplingStrategy{
584				MaxTracesPerSecond: int16(value),
585			},
586		}
587	}
588	return nil
589}
590
// TestRemotelyControlledSampler_printErrorForBrokenUpstream checks the log
// output when the sampling server URL is unusable ("invalid address"). A
// manual update must log a "failed to fetch sampling strategy:" message.
func TestRemotelyControlledSampler_printErrorForBrokenUpstream(t *testing.T) {
	buffered := &log.BytesBufferLogger{}
	remote := NewRemotelyControlledSampler(
		"client app",
		SamplerOptions.Logger(buffered),
		SamplerOptions.SamplingServerURL("invalid address"),
	)
	// Stop timer-based updates, then trigger exactly one update by hand.
	remote.Close()
	remote.UpdateSampler()

	logged := buffered.String()
	assert.Contains(t, logged, "failed to fetch sampling strategy:")
}
602