1package httpd_test
2
3import (
4	"bytes"
5	"context"
6	"encoding/json"
7	"errors"
8	"fmt"
9	"io"
10	"log"
11	"math"
12	"mime/multipart"
13	"net/http"
14	"net/http/httptest"
15	"net/url"
16	"os"
17	"reflect"
18	"strings"
19	"sync/atomic"
20	"testing"
21	"time"
22
23	jwt "github.com/dgrijalva/jwt-go"
24	"github.com/gogo/protobuf/proto"
25	"github.com/golang/snappy"
26	"github.com/google/go-cmp/cmp"
27	"github.com/influxdata/flux"
28	"github.com/influxdata/flux/lang"
29	"github.com/influxdata/influxdb/flux/client"
30	"github.com/influxdata/influxdb/internal"
31	"github.com/influxdata/influxdb/logger"
32	"github.com/influxdata/influxdb/models"
33	"github.com/influxdata/influxdb/prometheus/remote"
34	"github.com/influxdata/influxdb/query"
35	"github.com/influxdata/influxdb/services/httpd"
36	"github.com/influxdata/influxdb/services/meta"
37	"github.com/influxdata/influxdb/storage/reads"
38	"github.com/influxdata/influxdb/storage/reads/datatypes"
39	"github.com/influxdata/influxdb/tsdb"
40	"github.com/influxdata/influxql"
41)
42
43// Ensure the handler returns results from a query (including nil results).
44func TestHandler_Query(t *testing.T) {
45	h := NewHandler(false)
46	h.StatementExecutor.ExecuteStatementFn = func(stmt influxql.Statement, ctx *query.ExecutionContext) error {
47		if stmt.String() != `SELECT * FROM bar` {
48			t.Fatalf("unexpected query: %s", stmt.String())
49		} else if ctx.Database != `foo` {
50			t.Fatalf("unexpected db: %s", ctx.Database)
51		}
52		ctx.Results <- &query.Result{StatementID: 1, Series: models.Rows([]*models.Row{{Name: "series0"}})}
53		ctx.Results <- &query.Result{StatementID: 2, Series: models.Rows([]*models.Row{{Name: "series1"}})}
54		return nil
55	}
56
57	w := httptest.NewRecorder()
58	h.ServeHTTP(w, MustNewJSONRequest("GET", "/query?db=foo&q=SELECT+*+FROM+bar", nil))
59	if w.Code != http.StatusOK {
60		t.Fatalf("unexpected status: %d", w.Code)
61	} else if body := strings.TrimSpace(w.Body.String()); body != `{"results":[{"statement_id":1,"series":[{"name":"series0"}]},{"statement_id":2,"series":[{"name":"series1"}]}]}` {
62		t.Fatalf("unexpected body: %s", body)
63	}
64}
65
66// Ensure the handler returns results from a query passed as a file.
67func TestHandler_Query_File(t *testing.T) {
68	h := NewHandler(false)
69	h.StatementExecutor.ExecuteStatementFn = func(stmt influxql.Statement, ctx *query.ExecutionContext) error {
70		if stmt.String() != `SELECT * FROM bar` {
71			t.Fatalf("unexpected query: %s", stmt.String())
72		} else if ctx.Database != `foo` {
73			t.Fatalf("unexpected db: %s", ctx.Database)
74		}
75		ctx.Results <- &query.Result{StatementID: 1, Series: models.Rows([]*models.Row{{Name: "series0"}})}
76		ctx.Results <- &query.Result{StatementID: 2, Series: models.Rows([]*models.Row{{Name: "series1"}})}
77		return nil
78	}
79
80	var body bytes.Buffer
81	writer := multipart.NewWriter(&body)
82	part, err := writer.CreateFormFile("q", "")
83	if err != nil {
84		t.Fatal(err)
85	}
86	io.WriteString(part, "SELECT * FROM bar")
87
88	if err := writer.Close(); err != nil {
89		t.Fatal(err)
90	}
91
92	r := MustNewJSONRequest("POST", "/query?db=foo", &body)
93	r.Header.Set("Content-Type", writer.FormDataContentType())
94
95	w := httptest.NewRecorder()
96	h.ServeHTTP(w, r)
97	if w.Code != http.StatusOK {
98		t.Fatalf("unexpected status: %d", w.Code)
99	} else if body := strings.TrimSpace(w.Body.String()); body != `{"results":[{"statement_id":1,"series":[{"name":"series0"}]},{"statement_id":2,"series":[{"name":"series1"}]}]}` {
100		t.Fatalf("unexpected body: %s", body)
101	}
102}
103
104// Test query with user authentication.
105func TestHandler_Query_Auth(t *testing.T) {
106	// Create the handler to be tested.
107	h := NewHandler(true)
108
109	// Set mock meta client functions for the handler to use.
110	h.MetaClient.AdminUserExistsFn = func() bool { return true }
111
112	h.MetaClient.UserFn = func(username string) (meta.User, error) {
113		if username != "user1" {
114			return nil, meta.ErrUserNotFound
115		}
116		return &meta.UserInfo{
117			Name:  "user1",
118			Hash:  "abcd",
119			Admin: true,
120		}, nil
121	}
122
123	h.MetaClient.AuthenticateFn = func(u, p string) (meta.User, error) {
124		if u != "user1" {
125			return nil, fmt.Errorf("unexpected user: exp: user1, got: %s", u)
126		} else if p != "abcd" {
127			return nil, fmt.Errorf("unexpected password: exp: abcd, got: %s", p)
128		}
129		return h.MetaClient.User(u)
130	}
131
132	// Set mock query authorizer for handler to use.
133	h.QueryAuthorizer.AuthorizeQueryFn = func(u meta.User, query *influxql.Query, database string) error {
134		return nil
135	}
136
137	// Set mock statement executor for handler to use.
138	h.StatementExecutor.ExecuteStatementFn = func(stmt influxql.Statement, ctx *query.ExecutionContext) error {
139		if stmt.String() != `SELECT * FROM bar` {
140			t.Fatalf("unexpected query: %s", stmt.String())
141		} else if ctx.Database != `foo` {
142			t.Fatalf("unexpected db: %s", ctx.Database)
143		}
144		ctx.Results <- &query.Result{StatementID: 1, Series: models.Rows([]*models.Row{{Name: "series0"}})}
145		ctx.Results <- &query.Result{StatementID: 2, Series: models.Rows([]*models.Row{{Name: "series1"}})}
146		return nil
147	}
148
149	// Test the handler with valid user and password in the URL parameters.
150	w := httptest.NewRecorder()
151	h.ServeHTTP(w, MustNewJSONRequest("GET", "/query?u=user1&p=abcd&db=foo&q=SELECT+*+FROM+bar", nil))
152	if w.Code != http.StatusOK {
153		t.Fatalf("unexpected status: %d: %s", w.Code, w.Body.String())
154	} else if body := strings.TrimSpace(w.Body.String()); body != `{"results":[{"statement_id":1,"series":[{"name":"series0"}]},{"statement_id":2,"series":[{"name":"series1"}]}]}` {
155		t.Fatalf("unexpected body: %s", body)
156	}
157
158	// Test the handler with valid user and password using basic auth.
159	w = httptest.NewRecorder()
160	r := MustNewJSONRequest("GET", "/query?db=foo&q=SELECT+*+FROM+bar", nil)
161	r.SetBasicAuth("user1", "abcd")
162	h.ServeHTTP(w, r)
163	if w.Code != http.StatusOK {
164		t.Fatalf("unexpected status: %d: %s", w.Code, w.Body.String())
165	} else if body := strings.TrimSpace(w.Body.String()); body != `{"results":[{"statement_id":1,"series":[{"name":"series0"}]},{"statement_id":2,"series":[{"name":"series1"}]}]}` {
166		t.Fatalf("unexpected body: %s", body)
167	}
168
169	// Test the handler with valid JWT bearer token.
170	req := MustNewJSONRequest("GET", "/query?db=foo&q=SELECT+*+FROM+bar", nil)
171	// Create a signed JWT token string and add it to the request header.
172	_, signedToken := MustJWTToken("user1", h.Config.SharedSecret, false)
173	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signedToken))
174
175	w = httptest.NewRecorder()
176	h.ServeHTTP(w, req)
177	if w.Code != http.StatusOK {
178		t.Fatalf("unexpected status: %d: %s", w.Code, w.Body.String())
179	} else if body := strings.TrimSpace(w.Body.String()); body != `{"results":[{"statement_id":1,"series":[{"name":"series0"}]},{"statement_id":2,"series":[{"name":"series1"}]}]}` {
180		t.Fatalf("unexpected body: %s", body)
181	}
182
183	// Test the handler with JWT token signed with invalid key.
184	req = MustNewJSONRequest("GET", "/query?db=foo&q=SELECT+*+FROM+bar", nil)
185	// Create a signed JWT token string and add it to the request header.
186	_, signedToken = MustJWTToken("user1", "invalid key", false)
187	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signedToken))
188
189	w = httptest.NewRecorder()
190	h.ServeHTTP(w, req)
191	if w.Code != http.StatusUnauthorized {
192		t.Fatalf("unexpected status: %d: %s", w.Code, w.Body.String())
193	} else if body := strings.TrimSpace(w.Body.String()); body != `{"error":"signature is invalid"}` {
194		t.Fatalf("unexpected body: %s", body)
195	}
196
197	// Test handler with valid JWT token carrying non-existent user.
198	_, signedToken = MustJWTToken("bad_user", h.Config.SharedSecret, false)
199	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signedToken))
200
201	w = httptest.NewRecorder()
202	h.ServeHTTP(w, req)
203	if w.Code != http.StatusUnauthorized {
204		t.Fatalf("unexpected status: %d: %s", w.Code, w.Body.String())
205	} else if body := strings.TrimSpace(w.Body.String()); body != `{"error":"user not found"}` {
206		t.Fatalf("unexpected body: %s", body)
207	}
208
209	// Test handler with expired JWT token.
210	_, signedToken = MustJWTToken("user1", h.Config.SharedSecret, true)
211	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signedToken))
212
213	w = httptest.NewRecorder()
214	h.ServeHTTP(w, req)
215	if w.Code != http.StatusUnauthorized {
216		t.Fatalf("unexpected status: %d: %s", w.Code, w.Body.String())
217	} else if !strings.Contains(w.Body.String(), `{"error":"Token is expired`) {
218		t.Fatalf("unexpected body: %s", w.Body.String())
219	}
220
221	// Test handler with JWT token that has no expiration set.
222	token, _ := MustJWTToken("user1", h.Config.SharedSecret, false)
223	delete(token.Claims.(jwt.MapClaims), "exp")
224	signedToken, err := token.SignedString([]byte(h.Config.SharedSecret))
225	if err != nil {
226		t.Fatal(err)
227	}
228	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signedToken))
229	w = httptest.NewRecorder()
230	h.ServeHTTP(w, req)
231	if w.Code != http.StatusUnauthorized {
232		t.Fatalf("unexpected status: %d: %s", w.Code, w.Body.String())
233	} else if body := strings.TrimSpace(w.Body.String()); body != `{"error":"token expiration required"}` {
234		t.Fatalf("unexpected body: %s", body)
235	}
236
237	// Test that auth fails if shared secret is blank.
238	origSecret := h.Config.SharedSecret
239	h.Config.SharedSecret = ""
240	token, _ = MustJWTToken("user1", h.Config.SharedSecret, false)
241	signedToken, err = token.SignedString([]byte(h.Config.SharedSecret))
242	if err != nil {
243		t.Fatal(err)
244	}
245	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signedToken))
246	w = httptest.NewRecorder()
247	h.ServeHTTP(w, req)
248	if w.Code != http.StatusUnauthorized {
249		t.Fatalf("unexpected status: %d: %s", w.Code, w.Body.String())
250	} else if body := strings.TrimSpace(w.Body.String()); body != `{"error":"bearer auth disabled"}` {
251		t.Fatalf("unexpected body: %s", body)
252	}
253	h.Config.SharedSecret = origSecret
254
255	// Test the handler with valid user and password in the url and invalid in
256	// basic auth (prioritize url).
257	w = httptest.NewRecorder()
258	r = MustNewJSONRequest("GET", "/query?u=user1&p=abcd&db=foo&q=SELECT+*+FROM+bar", nil)
259	r.SetBasicAuth("user1", "efgh")
260	h.ServeHTTP(w, r)
261	if w.Code != http.StatusOK {
262		t.Fatalf("unexpected status: %d: %s", w.Code, w.Body.String())
263	} else if body := strings.TrimSpace(w.Body.String()); body != `{"results":[{"statement_id":1,"series":[{"name":"series0"}]},{"statement_id":2,"series":[{"name":"series1"}]}]}` {
264		t.Fatalf("unexpected body: %s", body)
265	}
266}
267
268// Ensure the handler returns results from a query (including nil results).
269func TestHandler_QueryRegex(t *testing.T) {
270	h := NewHandler(false)
271	h.StatementExecutor.ExecuteStatementFn = func(stmt influxql.Statement, ctx *query.ExecutionContext) error {
272		if stmt.String() != `SELECT * FROM test WHERE url =~ /http\:\/\/www.akamai\.com/` {
273			t.Fatalf("unexpected query: %s", stmt.String())
274		} else if ctx.Database != `test` {
275			t.Fatalf("unexpected db: %s", ctx.Database)
276		}
277		ctx.Results <- nil
278		return nil
279	}
280
281	w := httptest.NewRecorder()
282	h.ServeHTTP(w, MustNewRequest("GET", "/query?db=test&q=SELECT%20%2A%20FROM%20test%20WHERE%20url%20%3D~%20%2Fhttp%5C%3A%5C%2F%5C%2Fwww.akamai%5C.com%2F", nil))
283}
284
285// Ensure the handler merges results from the same statement.
286func TestHandler_Query_MergeResults(t *testing.T) {
287	h := NewHandler(false)
288	h.StatementExecutor.ExecuteStatementFn = func(stmt influxql.Statement, ctx *query.ExecutionContext) error {
289		ctx.Results <- &query.Result{StatementID: 1, Series: models.Rows([]*models.Row{{Name: "series0"}})}
290		ctx.Results <- &query.Result{StatementID: 1, Series: models.Rows([]*models.Row{{Name: "series1"}})}
291		return nil
292	}
293
294	w := httptest.NewRecorder()
295	h.ServeHTTP(w, MustNewJSONRequest("GET", "/query?db=foo&q=SELECT+*+FROM+bar", nil))
296	if w.Code != http.StatusOK {
297		t.Fatalf("unexpected status: %d", w.Code)
298	} else if body := strings.TrimSpace(w.Body.String()); body != `{"results":[{"statement_id":1,"series":[{"name":"series0"},{"name":"series1"}]}]}` {
299		t.Fatalf("unexpected body: %s", body)
300	}
301}
302
303// Ensure the handler merges results from the same statement.
304func TestHandler_Query_MergeEmptyResults(t *testing.T) {
305	h := NewHandler(false)
306	h.StatementExecutor.ExecuteStatementFn = func(stmt influxql.Statement, ctx *query.ExecutionContext) error {
307		ctx.Results <- &query.Result{StatementID: 1, Series: models.Rows{}}
308		ctx.Results <- &query.Result{StatementID: 1, Series: models.Rows([]*models.Row{{Name: "series1"}})}
309		return nil
310	}
311
312	w := httptest.NewRecorder()
313	h.ServeHTTP(w, MustNewJSONRequest("GET", "/query?db=foo&q=SELECT+*+FROM+bar", nil))
314	if w.Code != http.StatusOK {
315		t.Fatalf("unexpected status: %d", w.Code)
316	} else if body := strings.TrimSpace(w.Body.String()); body != `{"results":[{"statement_id":1,"series":[{"name":"series1"}]}]}` {
317		t.Fatalf("unexpected body: %s", body)
318	}
319}
320
321// Ensure the handler can parse chunked and chunk size query parameters.
322func TestHandler_Query_Chunked(t *testing.T) {
323	h := NewHandler(false)
324	h.StatementExecutor.ExecuteStatementFn = func(stmt influxql.Statement, ctx *query.ExecutionContext) error {
325		if ctx.ChunkSize != 2 {
326			t.Fatalf("unexpected chunk size: %d", ctx.ChunkSize)
327		}
328		ctx.Results <- &query.Result{StatementID: 1, Series: models.Rows([]*models.Row{{Name: "series0"}})}
329		ctx.Results <- &query.Result{StatementID: 1, Series: models.Rows([]*models.Row{{Name: "series1"}})}
330		return nil
331	}
332
333	w := httptest.NewRecorder()
334	h.ServeHTTP(w, MustNewJSONRequest("GET", "/query?db=foo&q=SELECT+*+FROM+bar&chunked=true&chunk_size=2", nil))
335	if w.Code != http.StatusOK {
336		t.Fatalf("unexpected status: %d", w.Code)
337	} else if w.Body.String() != `{"results":[{"statement_id":1,"series":[{"name":"series0"}]}]}
338{"results":[{"statement_id":1,"series":[{"name":"series1"}]}]}
339` {
340		t.Fatalf("unexpected body: %s", w.Body.String())
341	}
342}
343
344// Ensure the handler can accept an async query.
345func TestHandler_Query_Async(t *testing.T) {
346	done := make(chan struct{})
347	h := NewHandler(false)
348	h.StatementExecutor.ExecuteStatementFn = func(stmt influxql.Statement, ctx *query.ExecutionContext) error {
349		if stmt.String() != `SELECT * FROM bar` {
350			t.Fatalf("unexpected query: %s", stmt.String())
351		} else if ctx.Database != `foo` {
352			t.Fatalf("unexpected db: %s", ctx.Database)
353		}
354		ctx.Results <- &query.Result{StatementID: 1, Series: models.Rows([]*models.Row{{Name: "series0"}})}
355		ctx.Results <- &query.Result{StatementID: 2, Series: models.Rows([]*models.Row{{Name: "series1"}})}
356		close(done)
357		return nil
358	}
359
360	w := httptest.NewRecorder()
361	h.ServeHTTP(w, MustNewJSONRequest("GET", "/query?db=foo&q=SELECT+*+FROM+bar&async=true", nil))
362	if w.Code != http.StatusNoContent {
363		t.Fatalf("unexpected status: %d", w.Code)
364	} else if body := strings.TrimSpace(w.Body.String()); body != `` {
365		t.Fatalf("unexpected body: %s", body)
366	}
367
368	// Wait to make sure the async query runs and completes.
369	timer := time.NewTimer(100 * time.Millisecond)
370	defer timer.Stop()
371
372	select {
373	case <-timer.C:
374		t.Fatal("timeout while waiting for async query to complete")
375	case <-done:
376	}
377}
378
379// Ensure the handler returns a status 400 if the query is not passed in.
380func TestHandler_Query_ErrQueryRequired(t *testing.T) {
381	h := NewHandler(false)
382	w := httptest.NewRecorder()
383	h.ServeHTTP(w, MustNewJSONRequest("GET", "/query", nil))
384	if w.Code != http.StatusBadRequest {
385		t.Fatalf("unexpected status: %d", w.Code)
386	} else if body := strings.TrimSpace(w.Body.String()); body != `{"error":"missing required parameter \"q\""}` {
387		t.Fatalf("unexpected body: %s", body)
388	}
389}
390
391// Ensure the handler returns a status 400 if the query cannot be parsed.
392func TestHandler_Query_ErrInvalidQuery(t *testing.T) {
393	h := NewHandler(false)
394	w := httptest.NewRecorder()
395	h.ServeHTTP(w, MustNewJSONRequest("GET", "/query?q=SELECT", nil))
396	if w.Code != http.StatusBadRequest {
397		t.Fatalf("unexpected status: %d", w.Code)
398	} else if body := strings.TrimSpace(w.Body.String()); body != `{"error":"error parsing query: found EOF, expected identifier, string, number, bool at line 1, char 8"}` {
399		t.Fatalf("unexpected body: %s", body)
400	}
401}
402
403// Ensure the handler returns an appropriate 401 or 403 status when authentication or authorization fails.
404func TestHandler_Query_ErrAuthorize(t *testing.T) {
405	h := NewHandler(true)
406	h.QueryAuthorizer.AuthorizeQueryFn = func(u meta.User, q *influxql.Query, db string) error {
407		return errors.New("marker")
408	}
409	h.MetaClient.AdminUserExistsFn = func() bool { return true }
410	h.MetaClient.AuthenticateFn = func(u, p string) (meta.User, error) {
411
412		users := []meta.UserInfo{
413			{
414				Name:  "admin",
415				Hash:  "admin",
416				Admin: true,
417			},
418			{
419				Name: "user1",
420				Hash: "abcd",
421				Privileges: map[string]influxql.Privilege{
422					"db0": influxql.ReadPrivilege,
423				},
424			},
425		}
426
427		for _, user := range users {
428			if u == user.Name {
429				if p == user.Hash {
430					return &user, nil
431				}
432				return nil, meta.ErrAuthenticate
433			}
434		}
435		return nil, meta.ErrUserNotFound
436	}
437
438	for i, tt := range []struct {
439		user     string
440		password string
441		query    string
442		code     int
443	}{
444		{
445			query: "/query?q=SHOW+DATABASES",
446			code:  http.StatusUnauthorized,
447		},
448		{
449			user:     "user1",
450			password: "abcd",
451			query:    "/query?q=SHOW+DATABASES",
452			code:     http.StatusForbidden,
453		},
454		{
455			user:     "user2",
456			password: "abcd",
457			query:    "/query?q=SHOW+DATABASES",
458			code:     http.StatusUnauthorized,
459		},
460	} {
461		w := httptest.NewRecorder()
462		r := MustNewJSONRequest("GET", tt.query, nil)
463		params := r.URL.Query()
464		if tt.user != "" {
465			params.Set("u", tt.user)
466		}
467		if tt.password != "" {
468			params.Set("p", tt.password)
469		}
470		r.URL.RawQuery = params.Encode()
471
472		h.ServeHTTP(w, r)
473		if w.Code != tt.code {
474			t.Errorf("%d. unexpected status: got=%d exp=%d\noutput: %s", i, w.Code, tt.code, w.Body.String())
475		}
476	}
477}
478
479// Ensure the handler returns a status 200 if an error is returned in the result.
480func TestHandler_Query_ErrResult(t *testing.T) {
481	h := NewHandler(false)
482	h.StatementExecutor.ExecuteStatementFn = func(stmt influxql.Statement, ctx *query.ExecutionContext) error {
483		return errors.New("measurement not found")
484	}
485
486	w := httptest.NewRecorder()
487	h.ServeHTTP(w, MustNewJSONRequest("GET", "/query?db=foo&q=SHOW+SERIES+from+bin", nil))
488	if w.Code != http.StatusOK {
489		t.Fatalf("unexpected status: %d", w.Code)
490	} else if body := strings.TrimSpace(w.Body.String()); body != `{"results":[{"statement_id":0,"error":"measurement not found"}]}` {
491		t.Fatalf("unexpected body: %s", body)
492	}
493}
494
495// Ensure that closing the HTTP connection causes the query to be interrupted.
496func TestHandler_Query_CloseNotify(t *testing.T) {
497	// Avoid leaking a goroutine when this fails.
498	done := make(chan struct{})
499	defer close(done)
500
501	interrupted := make(chan struct{})
502	h := NewHandler(false)
503	h.StatementExecutor.ExecuteStatementFn = func(stmt influxql.Statement, ctx *query.ExecutionContext) error {
504		select {
505		case <-ctx.Done():
506		case <-done:
507		}
508		close(interrupted)
509		return nil
510	}
511
512	s := httptest.NewServer(h)
513	defer s.Close()
514
515	// Parse the URL and generate a query request.
516	u, err := url.Parse(s.URL)
517	if err != nil {
518		t.Fatal(err)
519	}
520	u.Path = "/query"
521
522	values := url.Values{}
523	values.Set("q", "SELECT * FROM cpu")
524	values.Set("db", "db0")
525	values.Set("rp", "rp0")
526	values.Set("chunked", "true")
527	u.RawQuery = values.Encode()
528
529	req, err := http.NewRequest("GET", u.String(), nil)
530	if err != nil {
531		t.Fatal(err)
532	}
533
534	// Perform the request and retrieve the response.
535	resp, err := http.DefaultClient.Do(req)
536	if err != nil {
537		t.Fatal(err)
538	}
539
540	// Validate that the interrupted channel has NOT been closed yet.
541	timer := time.NewTimer(100 * time.Millisecond)
542	select {
543	case <-interrupted:
544		timer.Stop()
545		t.Fatal("query interrupted unexpectedly")
546	case <-timer.C:
547	}
548
549	// Close the response body which should abort the query in the handler.
550	resp.Body.Close()
551
552	// The query should abort within 100 milliseconds.
553	timer.Reset(100 * time.Millisecond)
554	select {
555	case <-interrupted:
556		timer.Stop()
557	case <-timer.C:
558		t.Fatal("timeout while waiting for query to abort")
559	}
560}
561
562// Ensure the prometheus remote write works with valid values.
563func TestHandler_PromWrite(t *testing.T) {
564	req := &remote.WriteRequest{
565		Timeseries: []*remote.TimeSeries{
566			{
567				Labels: []*remote.LabelPair{
568					{Name: "host", Value: "a"},
569					{Name: "region", Value: "west"},
570				},
571				Samples: []*remote.Sample{
572					{TimestampMs: 1, Value: 1.2},
573					{TimestampMs: 3, Value: 14.5},
574					{TimestampMs: 6, Value: 222.99},
575				},
576			},
577		},
578	}
579
580	data, err := proto.Marshal(req)
581	if err != nil {
582		t.Fatal("couldn't marshal prometheus request")
583	}
584	compressed := snappy.Encode(nil, data)
585
586	b := bytes.NewReader(compressed)
587	h := NewHandler(false)
588	h.MetaClient.DatabaseFn = func(name string) *meta.DatabaseInfo {
589		return &meta.DatabaseInfo{}
590	}
591
592	var called bool
593	h.PointsWriter.WritePointsFn = func(db, rp string, _ models.ConsistencyLevel, _ meta.User, points []models.Point) error {
594		called = true
595
596		if got, exp := len(points), 3; got != exp {
597			t.Fatalf("got %d points, expected %d\n\npoints:\n%v", got, exp, points)
598		}
599
600		expFields := []models.Fields{
601			models.Fields{"value": req.Timeseries[0].Samples[0].Value},
602			models.Fields{"value": req.Timeseries[0].Samples[1].Value},
603			models.Fields{"value": req.Timeseries[0].Samples[2].Value},
604		}
605
606		expTS := []int64{
607			req.Timeseries[0].Samples[0].TimestampMs * int64(time.Millisecond),
608			req.Timeseries[0].Samples[1].TimestampMs * int64(time.Millisecond),
609			req.Timeseries[0].Samples[2].TimestampMs * int64(time.Millisecond),
610		}
611
612		for i, point := range points {
613			if got, exp := point.UnixNano(), expTS[i]; got != exp {
614				t.Fatalf("got time %d, expected %d\npoint:\n%v", got, exp, point)
615			}
616
617			exp := models.Tags{models.Tag{Key: []byte("host"), Value: []byte("a")}, models.Tag{Key: []byte("region"), Value: []byte("west")}}
618			if got := point.Tags(); !reflect.DeepEqual(got, exp) {
619				t.Fatalf("got tags: %v, expected: %v\npoint:\n%v", got, exp, point)
620			}
621
622			gotFields, err := point.Fields()
623			if err != nil {
624				t.Fatal(err.Error())
625			}
626
627			if got, exp := gotFields, expFields[i]; !reflect.DeepEqual(got, exp) {
628				t.Fatalf("got fields %v, expected %v\npoint:\n%v", got, exp, point)
629			}
630		}
631		return nil
632	}
633
634	w := httptest.NewRecorder()
635	h.ServeHTTP(w, MustNewRequest("POST", "/api/v1/prom/write?db=foo", b))
636	if !called {
637		t.Fatal("WritePoints: expected call")
638	}
639
640	if w.Code != http.StatusNoContent {
641		t.Fatalf("unexpected status: %d", w.Code)
642	}
643}
644
645// Ensure the prometheus remote write works with invalid values.
646func TestHandler_PromWrite_Dropped(t *testing.T) {
647	req := &remote.WriteRequest{
648		Timeseries: []*remote.TimeSeries{
649			{
650				Labels: []*remote.LabelPair{
651					{Name: "host", Value: "a"},
652					{Name: "region", Value: "west"},
653				},
654				Samples: []*remote.Sample{
655					{TimestampMs: 1, Value: 1.2},
656					{TimestampMs: 2, Value: math.NaN()},
657					{TimestampMs: 3, Value: 14.5},
658					{TimestampMs: 4, Value: math.Inf(-1)},
659					{TimestampMs: 5, Value: math.Inf(1)},
660					{TimestampMs: 6, Value: 222.99},
661					{TimestampMs: 7, Value: math.Inf(-1)},
662					{TimestampMs: 8, Value: math.Inf(1)},
663					{TimestampMs: 9, Value: math.Inf(1)},
664				},
665			},
666		},
667	}
668
669	data, err := proto.Marshal(req)
670	if err != nil {
671		t.Fatal("couldn't marshal prometheus request")
672	}
673	compressed := snappy.Encode(nil, data)
674
675	b := bytes.NewReader(compressed)
676	h := NewHandler(false)
677	h.MetaClient.DatabaseFn = func(name string) *meta.DatabaseInfo {
678		return &meta.DatabaseInfo{}
679	}
680
681	var called bool
682	h.PointsWriter.WritePointsFn = func(db, rp string, _ models.ConsistencyLevel, _ meta.User, points []models.Point) error {
683		called = true
684
685		if got, exp := len(points), 3; got != exp {
686			t.Fatalf("got %d points, expected %d\n\npoints:\n%v", got, exp, points)
687		}
688
689		expFields := []models.Fields{
690			models.Fields{"value": req.Timeseries[0].Samples[0].Value},
691			models.Fields{"value": req.Timeseries[0].Samples[2].Value},
692			models.Fields{"value": req.Timeseries[0].Samples[5].Value},
693		}
694
695		expTS := []int64{
696			req.Timeseries[0].Samples[0].TimestampMs * int64(time.Millisecond),
697			req.Timeseries[0].Samples[2].TimestampMs * int64(time.Millisecond),
698			req.Timeseries[0].Samples[5].TimestampMs * int64(time.Millisecond),
699		}
700
701		for i, point := range points {
702			if got, exp := point.UnixNano(), expTS[i]; got != exp {
703				t.Fatalf("got time %d, expected %d\npoint:\n%v", got, exp, point)
704			}
705
706			exp := models.Tags{models.Tag{Key: []byte("host"), Value: []byte("a")}, models.Tag{Key: []byte("region"), Value: []byte("west")}}
707			if got := point.Tags(); !reflect.DeepEqual(got, exp) {
708				t.Fatalf("got tags: %v, expected: %v\npoint:\n%v", got, exp, point)
709			}
710
711			gotFields, err := point.Fields()
712			if err != nil {
713				t.Fatal(err.Error())
714			}
715
716			if got, exp := gotFields, expFields[i]; !reflect.DeepEqual(got, exp) {
717				t.Fatalf("got fields %v, expected %v\npoint:\n%v", got, exp, point)
718			}
719		}
720		return nil
721	}
722
723	w := httptest.NewRecorder()
724	h.ServeHTTP(w, MustNewRequest("POST", "/api/v1/prom/write?db=foo", b))
725	if !called {
726		t.Fatal("WritePoints: expected call")
727	}
728
729	if w.Code != http.StatusNoContent {
730		t.Fatalf("unexpected status: %d", w.Code)
731	}
732}
733
// mustMakeBigString returns a string made of sz 'a' bytes. Like the original
// byte-append loop, it panics if sz is negative.
func mustMakeBigString(sz int) string {
	return strings.Repeat("a", sz)
}
741
742func TestHandler_PromWrite_Error(t *testing.T) {
743	req := &remote.WriteRequest{
744		Timeseries: []*remote.TimeSeries{
745			{
746				// Invalid tag key
747				Labels:  []*remote.LabelPair{{Name: mustMakeBigString(models.MaxKeyLength), Value: "a"}},
748				Samples: []*remote.Sample{{TimestampMs: 1, Value: 1.2}},
749			},
750		},
751	}
752
753	data, err := proto.Marshal(req)
754	if err != nil {
755		t.Fatal("couldn't marshal prometheus request")
756	}
757	compressed := snappy.Encode(nil, data)
758
759	b := bytes.NewReader(compressed)
760	h := NewHandler(false)
761	h.MetaClient.DatabaseFn = func(name string) *meta.DatabaseInfo {
762		return &meta.DatabaseInfo{}
763	}
764
765	var called bool
766	h.PointsWriter.WritePointsFn = func(db, rp string, _ models.ConsistencyLevel, _ meta.User, points []models.Point) error {
767		called = true
768		return nil
769	}
770
771	w := httptest.NewRecorder()
772	h.ServeHTTP(w, MustNewRequest("POST", "/api/v1/prom/write?db=foo", b))
773	if w.Code != http.StatusBadRequest {
774		t.Fatalf("unexpected status: %d", w.Code)
775	}
776
777	if got, exp := strings.TrimSpace(w.Body.String()), `{"error":"max key length exceeded: 65572 \u003e 65535"}`; got != exp {
778		t.Fatalf("got error %q, expected %q", got, exp)
779	}
780
781	if called {
782		t.Fatal("WritePoints called but should not be")
783	}
784}
785
786// Ensure Prometheus remote read requests are converted to the correct InfluxQL query and
787// data is returned
788func TestHandler_PromRead(t *testing.T) {
789	req := &remote.ReadRequest{
790		Queries: []*remote.Query{{
791			Matchers: []*remote.LabelMatcher{
792				{
793					Type:  remote.MatchType_EQUAL,
794					Name:  "__name__",
795					Value: "value",
796				},
797			},
798			StartTimestampMs: 1,
799			EndTimestampMs:   2,
800		}},
801	}
802	data, err := proto.Marshal(req)
803	if err != nil {
804		t.Fatal("couldn't marshal prometheus request")
805	}
806	compressed := snappy.Encode(nil, data)
807	b := bytes.NewReader(compressed)
808	h := NewHandler(false)
809	w := httptest.NewRecorder()
810
811	// Number of results in the result set
812	var i int64
813	h.Store.ResultSet.NextFn = func() bool {
814		i++
815		return i <= 2
816	}
817
818	// data for each cursor.
819	h.Store.ResultSet.CursorFn = func() tsdb.Cursor {
820		cursor := internal.NewFloatArrayCursorMock()
821
822		var i int64
823		cursor.NextFn = func() *tsdb.FloatArray {
824			i++
825			ts := []int64{22000000 * i, 10000000000 * i}
826			vs := []float64{2.3, 2992.33}
827			if i > 2 {
828				ts, vs = nil, nil
829			}
830			return &tsdb.FloatArray{Timestamps: ts, Values: vs}
831		}
832
833		return cursor
834	}
835
836	// Tags for each cursor.
837	h.Store.ResultSet.TagsFn = func() models.Tags {
838		return models.NewTags(map[string]string{
839			"host":         fmt.Sprintf("server-%d", i),
840			"_measurement": "mem",
841		})
842	}
843
844	h.ServeHTTP(w, MustNewRequest("POST", "/api/v1/prom/read?db=foo&rp=bar", b))
845	if w.Code != http.StatusOK {
846		t.Fatalf("unexpected status: %d", w.Code)
847	}
848
849	reqBuf, err := snappy.Decode(nil, w.Body.Bytes())
850	if err != nil {
851		t.Fatal(err)
852	}
853
854	var resp remote.ReadResponse
855	if err := proto.Unmarshal(reqBuf, &resp); err != nil {
856		t.Fatal(err)
857	}
858
859	expResults := []*remote.QueryResult{
860		{
861			Timeseries: []*remote.TimeSeries{
862				{
863					Labels: []*remote.LabelPair{
864						{Name: "host", Value: "server-1"},
865					},
866					Samples: []*remote.Sample{
867						{TimestampMs: 22, Value: 2.3},
868						{TimestampMs: 10000, Value: 2992.33},
869						{TimestampMs: 44, Value: 2.3},
870						{TimestampMs: 20000, Value: 2992.33},
871					},
872				},
873				{
874					Labels: []*remote.LabelPair{
875						{Name: "host", Value: "server-2"},
876					},
877					Samples: []*remote.Sample{
878						{TimestampMs: 22, Value: 2.3},
879						{TimestampMs: 10000, Value: 2992.33},
880						{TimestampMs: 44, Value: 2.3},
881						{TimestampMs: 20000, Value: 2992.33},
882					},
883				},
884			},
885		},
886	}
887
888	if !reflect.DeepEqual(resp.Results, expResults) {
889		t.Fatalf("Results differ:\n%v", cmp.Diff(resp.Results, expResults))
890	}
891}
892
// TestHandler_PromRead_NoResults sends a Prometheus remote-read request to a
// handler whose mock store has not been customised, and checks that the reply
// is 200 OK with a body that snappy-decodes into a remote.ReadResponse.
func TestHandler_PromRead_NoResults(t *testing.T) {
	q := &remote.Query{
		Matchers: []*remote.LabelMatcher{{
			Type:  remote.MatchType_EQUAL,
			Name:  "__name__",
			Value: "value",
		}},
		StartTimestampMs: 0,
		EndTimestampMs:   models.MaxNanoTime / int64(time.Millisecond),
	}
	payload, err := proto.Marshal(&remote.ReadRequest{Queries: []*remote.Query{q}})
	if err != nil {
		t.Fatal("couldn't marshal prometheus request")
	}
	body := bytes.NewReader(snappy.Encode(nil, payload))

	handler := NewHandler(false)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, MustNewJSONRequest("POST", "/api/v1/prom/read?db=foo", body))
	if rec.Code != http.StatusOK {
		t.Fatalf("unexpected status: %d", rec.Code)
	}

	raw, err := snappy.Decode(nil, rec.Body.Bytes())
	if err != nil {
		t.Fatal(err.Error())
	}

	var resp remote.ReadResponse
	if err := proto.Unmarshal(raw, &resp); err != nil {
		t.Fatal(err.Error())
	}
}
928
// TestHandler_PromRead_UnsupportedCursors drives the Prometheus remote-read
// endpoint with a mock result set whose cursor is one of the types listed in
// `unsupported`. For each one, the handler must still answer 200 with a body
// that snappy/protobuf-decodes, and must log a message containing
// "cursor_type=".
//
// Each cursor type runs as its own subtest (named by its Go type). That way a
// failure identifies the offending cursor instead of only printing the shared
// assertion message.
func TestHandler_PromRead_UnsupportedCursors(t *testing.T) {
	req := &remote.ReadRequest{Queries: []*remote.Query{{
		Matchers: []*remote.LabelMatcher{
			{
				Type:  remote.MatchType_EQUAL,
				Name:  "__name__",
				Value: "value",
			},
		},
		StartTimestampMs: 0,
		EndTimestampMs:   models.MaxNanoTime / int64(time.Millisecond),
	}}}
	data, err := proto.Marshal(req)
	if err != nil {
		t.Fatal("couldn't marshal prometheus request")
	}
	compressed := snappy.Encode(nil, data)

	unsupported := []tsdb.Cursor{
		internal.NewIntegerArrayCursorMock(),
		internal.NewBooleanArrayCursorMock(),
		internal.NewUnsignedArrayCursorMock(),
		internal.NewStringArrayCursorMock(),
	}

	for _, cursor := range unsupported {
		t.Run(fmt.Sprintf("%T", cursor), func(t *testing.T) {
			h := NewHandler(false)
			w := httptest.NewRecorder()
			var lb bytes.Buffer
			h.Logger = logger.New(&lb)

			// NextFn returns true on its first call and false afterwards, so
			// the mock result set is iterated exactly once.
			more := true
			h.Store.ResultSet.NextFn = func() bool { defer func() { more = false }(); return more }

			// Set the cursor type that will be returned while iterating over
			// the mock store.
			h.Store.ResultSet.CursorFn = func() tsdb.Cursor {
				return cursor
			}

			b := bytes.NewReader(compressed)
			h.ServeHTTP(w, MustNewJSONRequest("POST", "/api/v1/prom/read?db=foo", b))
			if w.Code != http.StatusOK {
				t.Fatalf("unexpected status: %d", w.Code)
			}
			reqBuf, err := snappy.Decode(nil, w.Body.Bytes())
			if err != nil {
				t.Fatal(err.Error())
			}

			var resp remote.ReadResponse
			if err := proto.Unmarshal(reqBuf, &resp); err != nil {
				t.Fatal(err.Error())
			}

			if !strings.Contains(lb.String(), "cursor_type=") {
				t.Fatalf("got log message %q, expected to contain \"cursor_type\"", lb.String())
			}
		})
	}
}
989
990func TestHandler_Flux_DisabledByDefault(t *testing.T) {
991	h := NewHandler(false)
992	w := httptest.NewRecorder()
993
994	body := bytes.NewBufferString(`from(bucket:"db/rp") |> range(start:-1h) |> last()`)
995	h.ServeHTTP(w, MustNewRequest("POST", "/api/v2/query", body))
996	if got := w.Code; !cmp.Equal(got, http.StatusForbidden) {
997		t.Fatalf("unexpected status: %d", got)
998	}
999
1000	exp := "Flux query service disabled. Verify flux-enabled=true in the [http] section of the InfluxDB config.\n"
1001	if got := string(w.Body.Bytes()); !cmp.Equal(got, exp) {
1002		t.Fatalf("unexpected body -got/+exp\n%s", cmp.Diff(got, exp))
1003	}
1004}
1005
// TestHandler_PromRead_NilResultSet mocks Store.Read returning (nil, nil). It
// checks that the Prometheus read endpoint still answers 200 with
// Content-Type application/x-protobuf and Content-Encoding snappy, and that
// the body decodes to a ReadResponse holding a single empty QueryResult.
func TestHandler_PromRead_NilResultSet(t *testing.T) {
	req := &remote.ReadRequest{
		Queries: []*remote.Query{{
			Matchers: []*remote.LabelMatcher{
				{
					Type:  remote.MatchType_EQUAL,
					Name:  "__name__",
					Value: "value",
				},
			},
			StartTimestampMs: 1,
			EndTimestampMs:   2,
		}},
	}
	data, err := proto.Marshal(req)
	if err != nil {
		// t.Fatalf, not log.Fatal: log.Fatal calls os.Exit, which kills the
		// whole test binary without reporting which test failed or why.
		t.Fatalf("couldn't marshal prometheus request: %v", err)
	}
	compressed := snappy.Encode(nil, data)
	b := bytes.NewReader(compressed)

	h := NewHandler(false)

	// Mocks the case when Store.Read() returns nil, nil
	h.Handler.Store.(*internal.StorageStoreMock).ReadFn = func(ctx context.Context, req *datatypes.ReadRequest) (reads.ResultSet, error) {
		return nil, nil
	}

	w := httptest.NewRecorder()

	h.ServeHTTP(w, MustNewRequest("POST", "/api/v1/prom/read?db=foo&rp=bar", b))
	if w.Code != http.StatusOK {
		t.Fatalf("unexpected status: %d", w.Code)
	}

	if w.Header().Get("Content-Type") != "application/x-protobuf" {
		t.Fatalf("Got unexpected \"Content-Type\" header value:\n%v", cmp.Diff("application/x-protobuf", w.Header().Get("Content-Type")))
	}
	if w.Header().Get("Content-Encoding") != "snappy" {
		t.Fatalf("Got unexpected \"Content-Encoding\" header value:\n%v", cmp.Diff("snappy", w.Header().Get("Content-Encoding")))
	}

	decompressed, err := snappy.Decode(nil, w.Body.Bytes())
	if err != nil {
		t.Fatal(err)
	}

	resp := new(remote.ReadResponse)
	err = proto.Unmarshal(decompressed, resp)
	if err != nil {
		t.Fatal(err)
	}

	expected := &remote.ReadResponse{
		Results: []*remote.QueryResult{{}},
	}
	if !reflect.DeepEqual(resp, expected) {
		t.Fatalf("Results differ:\n%v", cmp.Diff(expected, resp))
	}
}
1066
// TestHandler_Flux_QueryJSON posts a JSON-encoded client.QueryRequest carrying
// a Flux query string. It asserts that the controller receives a
// lang.FluxCompiler holding exactly that query and that the response is 200.
func TestHandler_Flux_QueryJSON(t *testing.T) {
	const qry = "foo"

	h := NewHandlerWithConfig(NewHandlerConfig(WithFlux(), WithNoLog()))
	var called bool
	h.Controller.QueryFn = func(ctx context.Context, compiler flux.Compiler) (flux.Query, error) {
		want := flux.CompilerType(lang.FluxCompilerType)
		if got := compiler.CompilerType(); got != want {
			t.Fatalf("unexpected compiler type -got/+exp\n%s", cmp.Diff(got, want))
		}
		fc, ok := compiler.(lang.FluxCompiler)
		if !ok {
			t.Fatal("expected lang.FluxCompiler")
		}
		if fc.Query != qry {
			t.Fatalf("unexpected query -got/+exp\n%s", cmp.Diff(fc.Query, qry))
		}
		called = true
		return internal.NewFluxQueryMock(), nil
	}

	body := new(bytes.Buffer)
	if err := json.NewEncoder(body).Encode(client.QueryRequest{Query: qry}); err != nil {
		t.Fatalf("unexpected JSON encoding error: %q", err.Error())
	}
	req := MustNewRequest("POST", "/api/v2/query", body)
	req.Header.Add("content-type", "application/json")

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	if got := rec.Code; !cmp.Equal(got, http.StatusOK) {
		t.Fatalf("unexpected status: %d", got)
	}
	if !called {
		t.Fatalf("expected QueryFn to be called")
	}
}
1103
// TestHandler_Flux_SpecJSON posts a JSON client.QueryRequest whose Spec is an
// empty flux.Spec. It asserts that the controller is handed a compiler of type
// lang.SpecCompilerType and that the response is 200.
func TestHandler_Flux_SpecJSON(t *testing.T) {
	h := NewHandlerWithConfig(NewHandlerConfig(WithFlux(), WithNoLog()))

	var called bool
	h.Controller.QueryFn = func(ctx context.Context, compiler flux.Compiler) (flux.Query, error) {
		want := flux.CompilerType(lang.SpecCompilerType)
		if got := compiler.CompilerType(); got != want {
			t.Fatalf("unexpected compiler type -got/+exp\n%s", cmp.Diff(got, want))
		}
		called = true
		return internal.NewFluxQueryMock(), nil
	}

	body := new(bytes.Buffer)
	if err := json.NewEncoder(body).Encode(client.QueryRequest{Spec: &flux.Spec{}}); err != nil {
		t.Fatalf("unexpected JSON encoding error: %q", err.Error())
	}
	req := MustNewRequest("POST", "/api/v2/query", body)
	req.Header.Add("content-type", "application/json")

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	if got := rec.Code; !cmp.Equal(got, http.StatusOK) {
		t.Fatalf("unexpected status: %d", got)
	}
	if !called {
		t.Fatalf("expected QueryFn to be called")
	}
}
1134
1135func TestHandler_Flux_QueryText(t *testing.T) {
1136	h := NewHandlerWithConfig(NewHandlerConfig(WithFlux(), WithNoLog()))
1137	called := false
1138	qry := "bar"
1139	h.Controller.QueryFn = func(ctx context.Context, compiler flux.Compiler) (i flux.Query, e error) {
1140		if exp := flux.CompilerType(lang.FluxCompilerType); compiler.CompilerType() != exp {
1141			t.Fatalf("unexpected compiler type -got/+exp\n%s", cmp.Diff(compiler.CompilerType(), exp))
1142		}
1143		if c, ok := compiler.(lang.FluxCompiler); !ok {
1144			t.Fatal("expected lang.FluxCompiler")
1145		} else if exp := qry; c.Query != exp {
1146			t.Fatalf("unexpected query -got/+exp\n%s", cmp.Diff(c.Query, exp))
1147		}
1148		called = true
1149		return internal.NewFluxQueryMock(), nil
1150	}
1151
1152	req := MustNewRequest("POST", "/api/v2/query", bytes.NewBufferString(qry))
1153	req.Header.Add("content-type", "application/vnd.flux")
1154
1155	w := httptest.NewRecorder()
1156	h.ServeHTTP(w, req)
1157	if got := w.Code; !cmp.Equal(got, http.StatusOK) {
1158		t.Fatalf("unexpected status: %d", got)
1159	}
1160
1161	if !called {
1162		t.Fatalf("expected QueryFn to be called")
1163	}
1164}
1165
// TestHandler_Flux is a table test for the Flux query endpoint with Flux
// enabled:
//   - a request without a Content-Type is rejected with 400 and a JSON error body;
//   - a JSON-encoded query request is accepted with 200.
func TestHandler_Flux(t *testing.T) {
	// jsonQuery returns a body holding qs encoded as a client.QueryRequest.
	jsonQuery := func(qs string) io.Reader {
		buf := new(bytes.Buffer)
		if err := json.NewEncoder(buf).Encode(&client.QueryRequest{Query: qs}); err != nil {
			t.Fatalf("unexpected JSON encoding error: %q", err.Error())
		}
		return buf
	}

	type testCase struct {
		name    string
		reqFn   func() *http.Request
		expCode int
		expBody string // empty means "don't check the body"
	}

	for _, tc := range []testCase{
		{
			name:    "no media type",
			reqFn:   func() *http.Request { return MustNewRequest("POST", "/api/v2/query", nil) },
			expCode: http.StatusBadRequest,
			expBody: "{\"error\":\"mime: no media type\"}\n",
		},
		{
			name: "200 OK",
			reqFn: func() *http.Request {
				r := MustNewRequest("POST", "/api/v2/query", jsonQuery("foo"))
				r.Header.Add("content-type", "application/json")
				return r
			},
			expCode: http.StatusOK,
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			h := NewHandlerWithConfig(NewHandlerConfig(WithFlux(), WithNoLog()))
			rec := httptest.NewRecorder()
			h.ServeHTTP(rec, tc.reqFn())
			if code := rec.Code; !cmp.Equal(code, tc.expCode) {
				t.Fatalf("unexpected status: %d", code)
			}

			if tc.expBody == "" {
				return
			}
			if body := string(rec.Body.Bytes()); !cmp.Equal(body, tc.expBody) {
				t.Fatalf("unexpected body -got/+exp\n%s", cmp.Diff(body, tc.expBody))
			}
		})
	}
}
1218
// TestHandler_Flux_Auth checks token authentication on the Flux query
// endpoint. A request whose Authorization header is "Token user1:abcd" must
// succeed (200). The same request with a wrong password must be rejected
// (401).
func TestHandler_Flux_Auth(t *testing.T) {
	// Create the handler to be tested.
	h := NewHandlerWithConfig(NewHandlerConfig(WithFlux(), WithNoLog(), WithAuthentication()))
	h.MetaClient.AdminUserExistsFn = func() bool { return true }
	h.MetaClient.UserFn = func(username string) (meta.User, error) {
		if username != "user1" {
			return nil, meta.ErrUserNotFound
		}
		return &meta.UserInfo{
			Name:  "user1",
			Hash:  "abcd",
			Admin: true,
		}, nil
	}
	h.MetaClient.AuthenticateFn = func(u, p string) (meta.User, error) {
		if u != "user1" {
			return nil, fmt.Errorf("unexpected user: exp: user1, got: %s", u)
		} else if p != "abcd" {
			return nil, fmt.Errorf("unexpected password: exp: abcd, got: %s", p)
		}
		return h.MetaClient.User(u)
	}

	h.Controller.QueryFn = func(ctx context.Context, compiler flux.Compiler) (i flux.Query, e error) {
		return internal.NewFluxQueryMock(), nil
	}

	// newReq builds a fresh request for each ServeHTTP call. The body is a
	// bytes.Buffer that the first call may drain, so reusing one request
	// would hand later calls an empty body.
	newReq := func(credentials string) *http.Request {
		req := MustNewRequest("POST", "/api/v2/query", bytes.NewBufferString("bar"))
		req.Header.Set("content-type", "application/vnd.flux")
		req.Header.Set("Authorization", "Token "+credentials)
		return req
	}

	// Valid user and password in the Authorization header: accepted.
	w := httptest.NewRecorder()
	h.ServeHTTP(w, newReq("user1:abcd"))
	if got := w.Code; !cmp.Equal(got, http.StatusOK) {
		t.Fatalf("unexpected status: %d", got)
	}

	// Wrong password: rejected.
	w = httptest.NewRecorder()
	h.ServeHTTP(w, newReq("user1:efgh"))
	if got := w.Code; !cmp.Equal(got, http.StatusUnauthorized) {
		t.Fatalf("unexpected status: %d", got)
	}
}
1263
// TestHandler_Ping ensures both GET and HEAD /ping answer 204 No Content.
//
// Each method gets its own ResponseRecorder. A recorder keeps the first status
// code written to it, so reusing one would make the second check always see
// the first request's status.
// TODO: This should be expanded to verify the MetaClient check in servePing is working correctly
func TestHandler_Ping(t *testing.T) {
	h := NewHandler(false)
	for _, method := range []string{"GET", "HEAD"} {
		w := httptest.NewRecorder()
		h.ServeHTTP(w, MustNewRequest(method, "/ping", nil))
		if w.Code != http.StatusNoContent {
			t.Fatalf("%s /ping: unexpected status: %d", method, w.Code)
		}
	}
}
1278
// TestHandler_Version ensures that responses from /ping, /query, /write and an
// unknown route all carry X-Influxdb-Version "0.0.0" and X-Influxdb-Build
// "OSS". These are the values NewHandlerWithConfig assigns.
//
// Headers are read through ResponseRecorder.Result() rather than the
// deprecated HeaderMap field.
func TestHandler_Version(t *testing.T) {
	h := NewHandler(false)
	h.StatementExecutor.ExecuteStatementFn = func(stmt influxql.Statement, ctx *query.ExecutionContext) error {
		return nil
	}
	tests := []struct {
		method   string
		endpoint string
		body     io.Reader
	}{
		{
			method:   "GET",
			endpoint: "/ping",
			body:     nil,
		},
		{
			method:   "GET",
			endpoint: "/query?db=foo&q=SELECT+*+FROM+bar",
			body:     nil,
		},
		{
			method:   "POST",
			endpoint: "/write",
			body:     bytes.NewReader(make([]byte, 10)),
		},
		{
			method:   "GET",
			endpoint: "/notfound",
			body:     nil,
		},
	}

	for _, test := range tests {
		w := httptest.NewRecorder()
		h.ServeHTTP(w, MustNewRequest(test.method, test.endpoint, test.body))
		header := w.Result().Header

		if v := header["X-Influxdb-Version"]; len(v) > 0 {
			if v[0] != "0.0.0" {
				t.Fatalf("%s %s: unexpected version: %s", test.method, test.endpoint, v)
			}
		} else {
			t.Fatalf("%s %s: Header entry 'X-Influxdb-Version' not present", test.method, test.endpoint)
		}

		if v := header["X-Influxdb-Build"]; len(v) > 0 {
			if v[0] != "OSS" {
				t.Fatalf("%s %s: unexpected BuildType: %s", test.method, test.endpoint, v)
			}
		} else {
			t.Fatalf("%s %s: Header entry 'X-Influxdb-Build' not present", test.method, test.endpoint)
		}
	}
}
1332
// TestHandler_Status ensures both GET and HEAD /status answer 204 No Content.
//
// Each method gets its own ResponseRecorder. A recorder keeps the first status
// code written to it, so reusing one would make the HEAD check vacuous.
func TestHandler_Status(t *testing.T) {
	h := NewHandler(false)
	for _, method := range []string{"GET", "HEAD"} {
		w := httptest.NewRecorder()
		h.ServeHTTP(w, MustNewRequest(method, "/status", nil))
		if w.Code != http.StatusNoContent {
			t.Fatalf("%s /status: unexpected status: %d", method, w.Code)
		}
	}
}
1346
// TestHandler_HandleBadRequestBody posts ten zero bytes to /write with no db
// parameter and expects 400 Bad Request.
func TestHandler_HandleBadRequestBody(t *testing.T) {
	payload := bytes.NewReader(make([]byte, 10))
	handler := NewHandler(false)
	rec := httptest.NewRecorder()

	handler.ServeHTTP(rec, MustNewRequest("POST", "/write", payload))
	if got := rec.Code; got != http.StatusBadRequest {
		t.Fatalf("unexpected status: %d", got)
	}
}
1357
// TestHandler_Write_EntityTooLarge_ContentLength writes a 100-byte body from a
// *bytes.Reader, for which http.NewRequest sets ContentLength, to a handler
// whose MaxBodySize is 5. It expects 413 Request Entity Too Large.
func TestHandler_Write_EntityTooLarge_ContentLength(t *testing.T) {
	payload := bytes.NewReader(make([]byte, 100))

	h := NewHandler(false)
	h.Config.MaxBodySize = 5
	h.MetaClient.DatabaseFn = func(string) *meta.DatabaseInfo { return &meta.DatabaseInfo{} }

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, MustNewRequest("POST", "/write?db=foo", payload))
	if got := rec.Code; got != http.StatusRequestEntityTooLarge {
		t.Fatalf("unexpected status: %d", got)
	}
}
1372
// TestHandler_Write_SuppressLog verifies that, with SuppressWriteLog enabled,
// a successful write (204) leaves the handler's CLF logger output empty.
func TestHandler_Write_SuppressLog(t *testing.T) {
	var logOut bytes.Buffer

	cfg := httpd.NewConfig()
	cfg.SuppressWriteLog = true
	h := NewHandlerWithConfig(cfg)
	h.CLFLogger = log.New(&logOut, "", log.LstdFlags)
	h.MetaClient.DatabaseFn = func(string) *meta.DatabaseInfo { return &meta.DatabaseInfo{} }
	h.PointsWriter.WritePointsFn = func(database, retentionPolicy string, consistencyLevel models.ConsistencyLevel, user meta.User, points []models.Point) error {
		return nil
	}

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, MustNewRequest("POST", "/write?db=foo", strings.NewReader("cpu,host=server01 value=2\n")))
	if got := rec.Code; got != http.StatusNoContent {
		t.Fatalf("unexpected status: %d", got)
	}

	// Any bytes in the log mean the write was logged despite suppression.
	if n := logOut.Len(); n > 0 {
		t.Fatalf("expected no bytes to be written to the log, got %d", n)
	}
}
1398
1399// onlyReader implements io.Reader only to ensure Request.ContentLength is not set
1400type onlyReader struct {
1401	r io.Reader
1402}
1403
1404func (o onlyReader) Read(p []byte) (n int, err error) {
1405	return o.r.Read(p)
1406}
1407
// TestHandler_Write_EntityTooLarge_NoContentLength is the streaming
// counterpart of the ContentLength variant. The 100-byte body is wrapped in
// onlyReader, so the request carries no ContentLength, and the handler must
// still reply 413 once MaxBodySize (5) is exceeded.
func TestHandler_Write_EntityTooLarge_NoContentLength(t *testing.T) {
	h := NewHandler(false)
	h.Config.MaxBodySize = 5
	h.MetaClient.DatabaseFn = func(string) *meta.DatabaseInfo { return &meta.DatabaseInfo{} }

	payload := onlyReader{bytes.NewReader(make([]byte, 100))}
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, MustNewRequest("POST", "/write?db=foo", payload))
	if got := rec.Code; got != http.StatusRequestEntityTooLarge {
		t.Fatalf("unexpected status: %d", got)
	}
}
1422
1423// TestHandler_Write_NegativeMaxBodySize verifies no error occurs if MaxBodySize is < 0
1424func TestHandler_Write_NegativeMaxBodySize(t *testing.T) {
1425	b := bytes.NewReader([]byte(`foo n=1`))
1426	h := NewHandler(false)
1427	h.Config.MaxBodySize = -1
1428	h.MetaClient.DatabaseFn = func(name string) *meta.DatabaseInfo {
1429		return &meta.DatabaseInfo{}
1430	}
1431	called := false
1432	h.PointsWriter.WritePointsFn = func(_, _ string, _ models.ConsistencyLevel, _ meta.User, _ []models.Point) error {
1433		called = true
1434		return nil
1435	}
1436
1437	w := httptest.NewRecorder()
1438	h.ServeHTTP(w, MustNewRequest("POST", "/write?db=foo", b))
1439	if !called {
1440		t.Fatal("WritePoints: expected call")
1441	}
1442	if w.Code != http.StatusNoContent {
1443		t.Fatalf("unexpected status: %d", w.Code)
1444	}
1445}
1446
// TestHandler_XForwardedFor checks that the first space-separated field of
// the CLF log line is the X-Forwarded-For value followed by the request's
// RemoteAddr, joined with a comma.
func TestHandler_XForwardedFor(t *testing.T) {
	var logged bytes.Buffer
	h := NewHandler(false)
	h.CLFLogger = log.New(&logged, "", 0)

	req := MustNewRequest("GET", "/query", nil)
	req.RemoteAddr = "127.0.0.1"
	req.Header.Set("X-Forwarded-For", "192.168.0.1")
	h.ServeHTTP(httptest.NewRecorder(), req)

	host := strings.Split(logged.String(), " ")[0]
	if host != "192.168.0.1,127.0.0.1" {
		t.Errorf("unexpected host ip address: %s", host)
	}
}
1463
// TestHandler_XRequestId checks which request ID appears in the CLF log line
// and in the Request-Id / X-Request-Id response headers:
//   - a non-empty X-Request-Id wins, whatever the header-name casing;
//   - otherwise a non-empty Request-Id is used;
//   - if both are empty, a 36-character ID is generated.
//
// The ID is read from a fixed field of the space-split log line. If the line
// has too few fields, the test fails with the line itself instead of
// panicking with an index-out-of-range error.
func TestHandler_XRequestId(t *testing.T) {
	var buf bytes.Buffer
	h := NewHandler(false)
	h.CLFLogger = log.New(&buf, "", 0)

	cases := []map[string]string{
		{"X-Request-Id": "abc123", "Request-Id": ""},          // X-Request-Id is used.
		{"X-REQUEST-ID": "cde", "Request-Id": ""},             // X-REQUEST-ID is used.
		{"X-Request-Id": "", "Request-Id": "foobarzoo"},       // Request-Id is used.
		{"X-Request-Id": "abc123", "Request-Id": "foobarzoo"}, // X-Request-Id takes precedence.
		{"X-Request-Id": "", "Request-Id": ""},                // v1 UUID generated.
	}

	for _, c := range cases {
		t.Run(fmt.Sprint(c), func(t *testing.T) {
			buf.Reset()
			req := MustNewRequest("GET", "/ping", nil)
			req.RemoteAddr = "127.0.0.1"

			// Set the relevant request ID headers
			allEmpty := true
			for k, v := range c {
				req.Header.Set(k, v)
				if v != "" {
					allEmpty = false
				}
			}

			w := httptest.NewRecorder()
			h.ServeHTTP(w, req)

			// Split up the HTTP log line. The request ID is currently located
			// at idIndex. If the log line format changes, this index will need
			// to be updated.
			line := buf.String()
			parts := strings.Split(line, " ")
			const idIndex = 12
			if len(parts) <= idIndex {
				t.Fatalf("log line has %d space-separated fields, need at least %d: %q", len(parts), idIndex+1, line)
			}
			id := parts[idIndex]

			switch {
			case allEmpty:
				// If neither header is set then we expect a v1 UUID to be generated.
				if got, exp := len(id), 36; got != exp {
					t.Fatalf("got ID of length %d, expected one of length %d", got, exp)
				}
			case c["X-Request-Id"] != "":
				if exp := c["X-Request-Id"]; id != exp {
					t.Fatalf("got ID of %q, expected %q", id, exp)
				}
			case c["X-REQUEST-ID"] != "":
				if exp := c["X-REQUEST-ID"]; id != exp {
					t.Fatalf("got ID of %q, expected %q", id, exp)
				}
			default:
				if exp := c["Request-Id"]; id != exp {
					t.Fatalf("got ID of %q, expected %q", id, exp)
				}
			}

			// Check response headers
			if got := w.Header().Get("Request-Id"); got != id {
				t.Fatalf("Request-Id header was %s, expected %s", got, id)
			} else if got := w.Header().Get("X-Request-Id"); got != id {
				t.Fatalf("X-Request-Id header was %s, expected %s", got, id)
			}
		})
	}
}
1529
// TestThrottler_Handler exercises the HTTP middleware returned by
// httpd.Throttler.Handler:
//   - OK: with NewThrottler(2, 98), 100 concurrent requests are all served and
//     the wrapped handler never observes more than 2 in flight at once.
//   - ErrTimeout: with both slots held and EnqueueTimeout = 1ms, a third
//     request gets 503 "request throttled, exceeds timeout".
//   - ErrFull: with both slots held and a third request presumably sitting in
//     the queue, a fourth request gets 503 "request throttled, queue full".
//
// NOTE(review): the NewThrottler arguments look like (concurrency limit,
// max queued requests) judging by how the subtests use them. Confirm against
// httpd.NewThrottler.
func TestThrottler_Handler(t *testing.T) {
	t.Run("OK", func(t *testing.T) {
		throttler := httpd.NewThrottler(2, 98)

		// Send the total number of concurrent requests to the channel.
		// concurrentCh is unbuffered, so each handler blocks on the send until
		// the reader loop below consumes its count.
		var concurrentN int32
		concurrentCh := make(chan int)

		h := throttler.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			atomic.AddInt32(&concurrentN, 1)
			concurrentCh <- int(atomic.LoadInt32(&concurrentN))
			time.Sleep(1 * time.Millisecond)
			atomic.AddInt32(&concurrentN, -1)
		}))

		// Execute requests concurrently.
		const n = 100
		for i := 0; i < n; i++ {
			go func() { h.ServeHTTP(nil, nil) }()
		}

		// Read the number of concurrent requests for every execution.
		// 2 is the concurrency limit passed to NewThrottler above.
		for i := 0; i < n; i++ {
			if v := <-concurrentCh; v > 2 {
				t.Fatalf("concurrent requests exceed maximum: %d", v)
			}
		}
	})

	t.Run("ErrTimeout", func(t *testing.T) {
		throttler := httpd.NewThrottler(2, 1)
		throttler.EnqueueTimeout = 1 * time.Millisecond

		// Unbuffered begin/end channels: each handler signals on begin when it
		// starts, then blocks on end until the test releases it.
		begin, end := make(chan struct{}), make(chan struct{})
		h := throttler.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			begin <- struct{}{}
			end <- struct{}{}
		}))

		// First two requests should execute immediately.
		go func() { h.ServeHTTP(nil, nil) }()
		go func() { h.ServeHTTP(nil, nil) }()

		// Both handlers are now running (and blocked on end).
		<-begin
		<-begin

		// Third request should be enqueued but timeout.
		w := httptest.NewRecorder()
		h.ServeHTTP(w, nil)
		if w.Code != http.StatusServiceUnavailable {
			t.Fatalf("unexpected status code: %d", w.Code)
		} else if body := w.Body.String(); body != "request throttled, exceeds timeout\n" {
			t.Fatalf("unexpected response body: %q", body)
		}

		// Allow 2 existing requests to complete.
		<-end
		<-end
	})

	t.Run("ErrFull", func(t *testing.T) {
		// Time allowed for the background requests below to reach the
		// throttler. It is longer under CI, presumably because shared CI
		// machines schedule goroutines more slowly.
		delay := 100 * time.Millisecond
		if os.Getenv("CI") != "" {
			delay = 2 * time.Second
		}

		throttler := httpd.NewThrottler(2, 1)

		// Handlers block on the unbuffered resp channel until the test reads.
		resp := make(chan struct{})
		h := throttler.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			resp <- struct{}{}
		}))

		// First two requests should execute immediately and third should be queued.
		go func() { h.ServeHTTP(nil, nil) }()
		go func() { h.ServeHTTP(nil, nil) }()
		go func() { h.ServeHTTP(nil, nil) }()
		time.Sleep(delay)

		// Fourth request should fail when trying to enqueue.
		w := httptest.NewRecorder()
		h.ServeHTTP(w, nil)
		if w.Code != http.StatusServiceUnavailable {
			t.Fatalf("unexpected status code: %d", w.Code)
		} else if body := w.Body.String(); body != "request throttled, queue full\n" {
			t.Fatalf("unexpected response body: %q", body)
		}

		// Allow 3 existing requests to complete.
		<-resp
		<-resp
		<-resp
	})
}
1624
1625// NewHandler represents a test wrapper for httpd.Handler.
1626type Handler struct {
1627	*httpd.Handler
1628	MetaClient        *internal.MetaClientMock
1629	StatementExecutor HandlerStatementExecutor
1630	QueryAuthorizer   HandlerQueryAuthorizer
1631	PointsWriter      HandlerPointsWriter
1632	Store             *internal.StorageStoreMock
1633	Controller        *internal.FluxControllerMock
1634}
1635
1636type configOption func(c *httpd.Config)
1637
1638func WithAuthentication() configOption {
1639	return func(c *httpd.Config) {
1640		c.AuthEnabled = true
1641		c.SharedSecret = "super secret key"
1642	}
1643}
1644
1645func WithFlux() configOption {
1646	return func(c *httpd.Config) {
1647		c.FluxEnabled = true
1648	}
1649}
1650
1651func WithNoLog() configOption {
1652	return func(c *httpd.Config) {
1653		c.LogEnabled = false
1654	}
1655}
1656
1657// NewHandlerConfig returns a new instance of httpd.Config with
1658// authentication configured.
1659func NewHandlerConfig(opts ...configOption) httpd.Config {
1660	config := httpd.NewConfig()
1661	for _, opt := range opts {
1662		opt(&config)
1663	}
1664	return config
1665}
1666
1667// NewHandler returns a new instance of Handler.
1668func NewHandler(requireAuthentication bool) *Handler {
1669	var opts []configOption
1670	if requireAuthentication {
1671		opts = append(opts, WithAuthentication())
1672	}
1673
1674	return NewHandlerWithConfig(NewHandlerConfig(opts...))
1675}
1676
// NewHandlerWithConfig creates an httpd.Handler from config and wires in
// fresh mocks: meta client, storage store, Flux controller, statement executor
// (through a new query.Executor), query authorizer and points writer. It sets
// Version to "0.0.0" and BuildType to "OSS". Under `go test -v`, the handler
// logs to stdout.
func NewHandlerWithConfig(config httpd.Config) *Handler {
	h := &Handler{
		Handler:    httpd.NewHandler(config),
		MetaClient: &internal.MetaClientMock{},
		Store:      internal.NewStorageStoreMock(),
		Controller: internal.NewFluxControllerMock(),
	}

	inner := h.Handler
	inner.MetaClient = h.MetaClient
	inner.Store = h.Store
	inner.QueryExecutor = query.NewExecutor()
	inner.QueryExecutor.StatementExecutor = &h.StatementExecutor
	inner.QueryAuthorizer = &h.QueryAuthorizer
	inner.PointsWriter = &h.PointsWriter
	inner.Version = "0.0.0"
	inner.BuildType = "OSS"
	inner.Controller = h.Controller

	if testing.Verbose() {
		inner.Logger = logger.New(os.Stdout)
	}
	return h
}
1703
1704// HandlerStatementExecutor is a mock implementation of Handler.StatementExecutor.
1705type HandlerStatementExecutor struct {
1706	ExecuteStatementFn func(stmt influxql.Statement, ctx *query.ExecutionContext) error
1707}
1708
1709func (e *HandlerStatementExecutor) ExecuteStatement(stmt influxql.Statement, ctx *query.ExecutionContext) error {
1710	return e.ExecuteStatementFn(stmt, ctx)
1711}
1712
1713// HandlerQueryAuthorizer is a mock implementation of Handler.QueryAuthorizer.
1714type HandlerQueryAuthorizer struct {
1715	AuthorizeQueryFn func(u meta.User, query *influxql.Query, database string) error
1716}
1717
1718func (a *HandlerQueryAuthorizer) AuthorizeQuery(u meta.User, query *influxql.Query, database string) error {
1719	return a.AuthorizeQueryFn(u, query, database)
1720}
1721
1722type HandlerPointsWriter struct {
1723	WritePointsFn func(database, retentionPolicy string, consistencyLevel models.ConsistencyLevel, user meta.User, points []models.Point) error
1724}
1725
1726func (h *HandlerPointsWriter) WritePoints(database, retentionPolicy string, consistencyLevel models.ConsistencyLevel, user meta.User, points []models.Point) error {
1727	return h.WritePointsFn(database, retentionPolicy, consistencyLevel, user, points)
1728}
1729
// MustNewRequest builds a request with http.NewRequest. If construction fails,
// it panics with the error text. It is intended for tests only.
func MustNewRequest(method, urlStr string, body io.Reader) *http.Request {
	req, err := http.NewRequest(method, urlStr, body)
	if err == nil {
		return req
	}
	panic(err.Error())
}
1738
// MustNewJSONRequest returns a new HTTP request whose Accept header is
// "application/json", asking the handler for a JSON response. It panics if the
// request cannot be built.
func MustNewJSONRequest(method, urlStr string, body io.Reader) *http.Request {
	r := MustNewRequest(method, urlStr, body)
	r.Header.Set("Accept", "application/json")
	return r
}
1745
// MustJWTToken creates an HS512 JWT carrying "username" and "exp" claims,
// signs it with secret, and returns both the token and its signed string. The
// token expires 10 minutes from now, or it already expired one second ago
// when expired is true. It panics if signing fails.
func MustJWTToken(username, secret string, expired bool) (*jwt.Token, string) {
	lifetime := 10 * time.Minute
	if expired {
		lifetime = -time.Second
	}

	token := jwt.New(jwt.GetSigningMethod("HS512"))
	claims := token.Claims.(jwt.MapClaims)
	claims["username"] = username
	claims["exp"] = time.Now().Add(lifetime).Unix()

	signed, err := token.SignedString([]byte(secret))
	if err != nil {
		panic(err)
	}
	return token, signed
}
1761