1// This file and its contents are licensed under the Apache License 2.0.
2// Please see the included NOTICE for copyright information and
3// LICENSE for a copy of the license.
4
5package runner
6
7import (
8	"io/ioutil"
9	"os"
10	"reflect"
11	"testing"
12)
13
14func TestParseFlags(t *testing.T) {
15	defaultConfig, err := ParseFlags(&Config{}, []string{})
16	if err != nil {
17		t.Fatal("error occured on default config with no arguments")
18	}
19
20	testCases := []struct {
21		name        string
22		args        []string
23		env         map[string]string
24		result      func(Config) Config
25		shouldError bool
26	}{
27		{
28			name:   "Default config",
29			args:   []string{},
30			result: func(c Config) Config { return c },
31		},
32		{
33			name:        "Invalid flag error",
34			args:        []string{"-foo", "bar"},
35			shouldError: true,
36		},
37		{
38			name:        "Invalid config file",
39			args:        []string{"-config", "flags_test.go"},
40			shouldError: true,
41		},
42		{
43			name:        "CORS Origin regex error",
44			args:        []string{"-web-cors-origin", "["},
45			shouldError: true,
46		},
47		{
48			name: "Don't migrate",
49			args: []string{"-migrate", "false"},
50			result: func(c Config) Config {
51				c.Migrate = false
52				return c
53			},
54		},
55		{
56			name: "enable disabled-features",
57			args: []string{"-promql-enable-feature", "promql-at-modifier"},
58			result: func(c Config) Config {
59				c.APICfg.EnableFeatures = "promql-at-modifier"
60				c.APICfg.EnabledFeaturesList = []string{"promql-at-modifier"}
61				return c
62			},
63		},
64		{
65			name: "Only migrate",
66			args: []string{"-migrate", "only"},
67			result: func(c Config) Config {
68				c.Migrate = true
69				c.StopAfterMigrate = true
70				return c
71			},
72		},
73		{
74			name:        "Invalid migrate option",
75			args:        []string{"-migrate", "invalid"},
76			shouldError: true,
77		},
78		{
79			name: "Read-only mode",
80			args: []string{"-read-only"},
81			result: func(c Config) Config {
82				c.APICfg.ReadOnly = true
83				c.Migrate = false
84				c.StopAfterMigrate = false
85				c.UseVersionLease = false
86				c.InstallExtensions = false
87				c.UpgradeExtensions = false
88				return c
89			},
90		},
91		{
92			name:        "Invalid migrate option",
93			args:        []string{"-migrate", "invalid"},
94			shouldError: true,
95		},
96		{
97			name: "Running HA and read-only error",
98			args: []string{
99				"-leader-election-pg-advisory-lock-id", "1",
100				"-read-only",
101			},
102			shouldError: true,
103		},
104		{
105			name: "Running migrate and read-only error",
106			args: []string{
107				"-migrate", "true",
108				"-read-only",
109			},
110			shouldError: true,
111		},
112		{
113			name: "Running install TimescaleDB and read-only error",
114			args: []string{
115				"-install-extensions",
116				"-read-only",
117			},
118			shouldError: true,
119		},
120		{
121			name: "invalid TLS setup, missing key file",
122			args: []string{
123				"-tls-cert-file", "foo",
124			},
125			shouldError: true,
126		},
127		{
128			name: "invalid TLS setup, missing cert file",
129			args: []string{
130				"-tls-key-file", "foo",
131			},
132			shouldError: true,
133		},
134		{
135			name: "invalid auth setup",
136			args: []string{
137				"-auth-username", "foo",
138			},
139			shouldError: true,
140		},
141		{
142			name: "invalid env variable type causing parse error, PROMSCALE prefix",
143			env: map[string]string{
144				"PROMSCALE_INSTALL_EXTENSIONS": "foobar",
145			},
146			shouldError: true,
147		},
148		{
149			name: "invalid env variable type causing parse error, TS_PROM prefix",
150			env: map[string]string{
151				"TS_PROM_INSTALL_EXTENSIONS": "foobar",
152			},
153			shouldError: true,
154		},
155	}
156
157	for _, c := range testCases {
158		t.Run(c.name, func(t *testing.T) {
159			// Clearing environment variables so they don't interfere with the test.
160			os.Clearenv()
161			for name, value := range c.env {
162				if err := os.Setenv(name, value); err != nil {
163					t.Fatalf("unexpected error when setting env variable: name %s, value %s, error %s", name, value, err)
164				}
165			}
166			config, err := ParseFlags(&Config{}, c.args)
167			if c.shouldError {
168				if err == nil {
169					t.Fatal("Unexpected error result, should not be nil")
170				}
171				return
172			} else if err != nil {
173				t.Fatalf("Unexpected returned error: %s", err.Error())
174			}
175
176			expected := c.result(*defaultConfig)
177			if !reflect.DeepEqual(*config, expected) {
178				t.Fatalf("Unexpected config returned\nwanted:\n%+v\ngot:\n%+v\n", expected, *config)
179			}
180		})
181	}
182}
183
184func TestParseFlagsConfigPrecedence(t *testing.T) {
185	// Clearing environment variables so they don't interfere with the test.
186	os.Clearenv()
187	defaultConfig, err := ParseFlags(&Config{}, []string{})
188
189	if err != nil {
190		t.Fatalf("error occured on default config with no arguments: %s", err)
191	}
192
193	testCases := []struct {
194		name               string
195		args               []string
196		env                map[string]string
197		configFileContents string
198		result             func(Config) Config
199	}{
200		{
201			name:   "Default config",
202			result: func(c Config) Config { return c },
203		},
204		{
205			name:               "Config file only",
206			configFileContents: "web-listen-address: localhost:9201",
207			result: func(c Config) Config {
208				c.ListenAddr = "localhost:9201"
209				return c
210			},
211		},
212		{
213			name: "Env variable only, TS_PROM prefix",
214			env: map[string]string{
215				"TS_PROM_WEB_LISTEN_ADDRESS": "localhost:9201",
216			},
217			result: func(c Config) Config {
218				c.ListenAddr = "localhost:9201"
219				return c
220			},
221		},
222		{
223			name: "Env variable only, PROMSCALE prefix",
224			env: map[string]string{
225				"PROMSCALE_WEB_LISTEN_ADDRESS": "localhost:9201",
226			},
227			result: func(c Config) Config {
228				c.ListenAddr = "localhost:9201"
229				return c
230			},
231		},
232		{
233			// In this case, we expect that PROMSCALE prefix gets precedence.
234			name: "Env variable only, both prefixes",
235			env: map[string]string{
236				"PROMSCALE_WEB_LISTEN_ADDRESS": "localhost:9201",
237				"TS_PROM_WEB_LISTEN_ADDRESS":   "127.0.0.1:9201",
238			},
239			result: func(c Config) Config {
240				c.ListenAddr = "localhost:9201"
241				return c
242			},
243		},
244		{
245			name: "Env variable takes precedence over config file setting",
246			env: map[string]string{
247				"PROMSCALE_WEB_LISTEN_ADDRESS": "localhost:9201",
248			},
249			configFileContents: "web-listen-address: 127.0.0.1:9201",
250			result: func(c Config) Config {
251				c.ListenAddr = "localhost:9201"
252				return c
253			},
254		},
255		{
256			name: "CLI arg takes precedence over env variable",
257			args: []string{
258				"-web-listen-address", "localhost:9201",
259			},
260			env: map[string]string{
261				"PROMSCALE_WEB_LISTEN_ADDRESS": "127.0.0.1:9201",
262			},
263			result: func(c Config) Config {
264				c.ListenAddr = "localhost:9201"
265				return c
266			},
267		},
268	}
269
270	for _, c := range testCases {
271		t.Run(c.name, func(t *testing.T) {
272			// Clearing environment variables so they don't interfere with the test.
273			os.Clearenv()
274			for name, value := range c.env {
275				if err := os.Setenv(name, value); err != nil {
276					t.Fatalf("unexpected error when setting env variable: name %s, value %s, error %s", name, value, err)
277				}
278			}
279
280			var configFilePath string
281			if c.configFileContents != "" {
282				f, err := ioutil.TempFile("", "promscale.yml")
283				if err != nil {
284					t.Fatalf("unexpected error when creating config file: %s", err)
285				}
286
287				configFilePath = f.Name()
288
289				defer os.Remove(f.Name())
290
291				if _, err := f.Write([]byte(c.configFileContents)); err != nil {
292					t.Fatalf("unexpected error while writing configuration file: %s", err)
293				}
294				if err := f.Close(); err != nil {
295					t.Fatalf("unexpected error while closing configuration file: %s", err)
296				}
297
298				// Add config file path to args.
299				c.args = append(c.args, "-config="+f.Name())
300			}
301
302			config, err := ParseFlags(&Config{}, c.args)
303			if err != nil {
304				t.Fatalf("unexpected error, all test cases should pass without error: %s", err)
305			}
306
307			expected := c.result(*defaultConfig)
308
309			// Need to account for change in config file path, which
310			// would otherwise be set to default `config.yml`.
311			if configFilePath != "" {
312				expected.ConfigFile = configFilePath
313			}
314
315			if !reflect.DeepEqual(*config, expected) {
316				t.Fatalf("Unexpected config returned\nwanted:\n%+v\ngot:\n%+v\n", expected, *config)
317			}
318		})
319	}
320}
321