1package promclient_test
2
3import (
4	"context"
5	"errors"
6	"sort"
7	"strings"
8	"testing"
9
10	"github.com/grafana/grafana/pkg/tsdb/prometheus/promclient"
11
12	apiv1 "github.com/prometheus/client_golang/api/prometheus/v1"
13
14	"github.com/stretchr/testify/require"
15)
16
17func TestCache_GetClient(t *testing.T) {
18	t.Run("it caches the client for a set of auth headers", func(t *testing.T) {
19		tc := setupCacheContext()
20
21		c, err := tc.providerCache.GetClient(headers)
22		require.Nil(t, err)
23
24		c2, err := tc.providerCache.GetClient(headers)
25		require.Nil(t, err)
26
27		require.Equal(t, c, c2)
28		require.Equal(t, 1, tc.clientProvider.numCalls)
29	})
30
31	t.Run("it returns different clients when the headers differ", func(t *testing.T) {
32		tc := setupCacheContext()
33		h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"}
34		h2 := map[string]string{"Authorization": "token2", "X-ID-Token": "id-token"}
35
36		c, err := tc.providerCache.GetClient(h1)
37		require.Nil(t, err)
38
39		c2, err := tc.providerCache.GetClient(h2)
40		require.Nil(t, err)
41
42		require.NotEqual(t, c, c2)
43		require.Equal(t, 2, tc.clientProvider.numCalls)
44	})
45
46	t.Run("it returns from the cache when headers are the same", func(t *testing.T) {
47		tc := setupCacheContext()
48		h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"}
49		h2 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"}
50
51		c, err := tc.providerCache.GetClient(h1)
52		require.Nil(t, err)
53
54		c2, err := tc.providerCache.GetClient(h2)
55		require.Nil(t, err)
56
57		require.Equal(t, c, c2)
58		require.Equal(t, 1, tc.clientProvider.numCalls)
59	})
60
61	t.Run("it doesn't cache anything when an error occurs", func(t *testing.T) {
62		tc := setupCacheContext()
63		tc.clientProvider.errors <- errors.New("something bad")
64
65		_, err := tc.providerCache.GetClient(headers)
66		require.EqualError(t, err, "something bad")
67
68		c, err := tc.providerCache.GetClient(headers)
69		require.Nil(t, err)
70
71		require.NotNil(t, c)
72		require.Equal(t, 2, tc.clientProvider.numCalls)
73	})
74}
75
// cacheTestContext bundles the ProviderCache under test with the fake provider
// it was constructed from, so subtests can inspect how the provider was used.
type cacheTestContext struct {
	// providerCache is the cache being exercised.
	providerCache  *promclient.ProviderCache
	// clientProvider is the fake passed to NewProviderCache; tests read its
	// numCalls and prime failures through its errors channel.
	clientProvider *fakePromClientProvider
}
80
81func setupCacheContext() *cacheTestContext {
82	fp := newFakePromClientProvider()
83	p, err := promclient.NewProviderCache(fp)
84	if err != nil {
85		panic(err)
86	}
87
88	return &cacheTestContext{
89		providerCache:  p,
90		clientProvider: fp,
91	}
92}
93
94func newFakePromClientProvider() *fakePromClientProvider {
95	return &fakePromClientProvider{
96		errors: make(chan error, 1),
97	}
98}
99
// fakePromClientProvider is a test double handed to promclient.NewProviderCache.
// It records each GetClient call and can be primed to return an error. Its
// fields are updated without synchronization, so it is not safe for
// concurrent use.
type fakePromClientProvider struct {
	// headers holds the header map from the most recent GetClient call.
	headers  map[string]string
	// numCalls counts GetClient invocations; tests use it to detect cache hits.
	numCalls int
	// errors holds an optional error that the next GetClient call consumes
	// and returns; it is nil unless created via newFakePromClientProvider.
	errors   chan error
}
105
106func (p *fakePromClientProvider) GetClient(h map[string]string) (apiv1.API, error) {
107	p.headers = h
108	p.numCalls++
109
110	var err error
111	select {
112	case err = <-p.errors:
113	default:
114	}
115
116	var config []string
117	for _, v := range h {
118		config = append(config, v)
119	}
120	sort.Strings(config) //because map
121	return &fakePromClient{config: strings.Join(config, "")}, err
122}
123
// fakePromClient is a stub apiv1.API. In this file only Config is implemented;
// every other method is promoted from the embedded apiv1.API, which the fake
// provider leaves nil, so calling one would panic.
type fakePromClient struct {
	apiv1.API
	// config is reported as the YAML of Config; the fake provider sets it to
	// the sorted, concatenated header values.
	config string
}
128
129func (c *fakePromClient) Config(ctx context.Context) (apiv1.ConfigResult, error) {
130	return apiv1.ConfigResult{YAML: c.config}, nil
131}
132