1// Copyright 2016 Google Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package trace
16
17import (
18	crand "crypto/rand"
19	"encoding/binary"
20	"fmt"
21	"math/rand"
22	"sync"
23	"time"
24
25	"golang.org/x/time/rate"
26)
27
// SamplingPolicy decides, for each request, whether the request should be
// traced.
type SamplingPolicy interface {
	// Sample returns a Decision.
	// If Trace is false in the returned Decision, then the Decision should be
	// the zero value.
	Sample(p Parameters) Decision
}
34
// Parameters contains the values passed to a SamplingPolicy's Sample method.
type Parameters struct {
	// HasTraceHeader reports whether the incoming request has a valid
	// X-Cloud-Trace-Context header.
	HasTraceHeader bool
}
39
// Decision is the value returned by a call to a SamplingPolicy's Sample method.
// The zero value has Trace set to false, i.e. the request is not traced.
type Decision struct {
	Trace  bool    // Whether to trace the request.
	Sample bool    // Whether the trace is included in the random sample.
	Policy string  // Name of the sampling policy.
	Weight float64 // Sample weight to be used in statistical calculations.
}
47
// sampler is the SamplingPolicy implementation constructed by
// NewLimitedSampler. It combines random sampling with a rate limit.
type sampler struct {
	// fraction is the threshold a random draw is compared against: a request
	// is in the random sample when the draw is below it.
	fraction float64
	// skipped counts sampled requests that were rejected by the rate limit
	// since the last accepted sample; it is folded into the next Weight.
	skipped float64
	// Limiter enforces the rate limit on traced requests.
	*rate.Limiter
	// Rand supplies the random draw used by Sample.
	*rand.Rand
	// Mutex is held by Sample around the random draw and the call to sample.
	sync.Mutex
}
55
56func (s *sampler) Sample(p Parameters) Decision {
57	s.Lock()
58	x := s.Float64()
59	d := s.sample(p, time.Now(), x)
60	s.Unlock()
61	return d
62}
63
64// sample contains the a deterministic, time-independent logic of Sample.
65func (s *sampler) sample(p Parameters, now time.Time, x float64) (d Decision) {
66	d.Sample = x < s.fraction
67	d.Trace = p.HasTraceHeader || d.Sample
68	if !d.Trace {
69		// We have no reason to trace this request.
70		return Decision{}
71	}
72	// We test separately that the rate limit is not tiny before calling AllowN,
73	// because of overflow problems in x/time/rate.
74	if s.Limit() < 1e-9 || !s.AllowN(now, 1) {
75		// Rejected by the rate limit.
76		if d.Sample {
77			s.skipped++
78		}
79		return Decision{}
80	}
81	if d.Sample {
82		d.Policy, d.Weight = "default", (1.0+s.skipped)/s.fraction
83		s.skipped = 0.0
84	}
85	return
86}
87
88// NewLimitedSampler returns a sampling policy that randomly samples a given
89// fraction of requests.  It also enforces a limit on the number of traces per
90// second.  It tries to trace every request with a trace header, but will not
91// exceed the qps limit to do it.
92func NewLimitedSampler(fraction, maxqps float64) (SamplingPolicy, error) {
93	if !(fraction >= 0) {
94		return nil, fmt.Errorf("invalid fraction %f", fraction)
95	}
96	if !(maxqps >= 0) {
97		return nil, fmt.Errorf("invalid maxqps %f", maxqps)
98	}
99	// Set a limit on the number of accumulated "tokens", to limit bursts of
100	// traced requests.  Use one more than a second's worth of tokens, or 100,
101	// whichever is smaller.
102	// See https://godoc.org/golang.org/x/time/rate#NewLimiter.
103	maxTokens := 100
104	if maxqps < 99.0 {
105		maxTokens = 1 + int(maxqps)
106	}
107	var seed int64
108	if err := binary.Read(crand.Reader, binary.LittleEndian, &seed); err != nil {
109		seed = time.Now().UnixNano()
110	}
111	s := sampler{
112		fraction: fraction,
113		Limiter:  rate.NewLimiter(rate.Limit(maxqps), maxTokens),
114		Rand:     rand.New(rand.NewSource(seed)),
115	}
116	return &s, nil
117}
118