1package agent
2
3import (
4	"bytes"
5	"encoding/json"
6	"fmt"
7	"net/http"
8	"net/http/httptest"
9	"reflect"
10	"sync/atomic"
11	"testing"
12
13	"github.com/hashicorp/consul/testrpc"
14
15	"github.com/hashicorp/consul/agent/structs"
16	"github.com/hashicorp/consul/types"
17	"github.com/stretchr/testify/require"
18)
19
20// MockPreparedQuery is a fake endpoint that we inject into the Consul server
21// in order to observe the RPC calls made by these HTTP endpoints. This lets
22// us make sure that the request is being formed properly without having to
23// set up a realistic environment for prepared queries, which is a huge task and
24// already done in detail inside the prepared query endpoint's unit tests. If we
25// can prove this formats proper requests into that then we should be good to
26// go. We will do a single set of end-to-end tests in here to make sure that the
27// server is wired up to the right endpoint when not "injected".
28type MockPreparedQuery struct {
29	applyFn   func(*structs.PreparedQueryRequest, *string) error
30	getFn     func(*structs.PreparedQuerySpecificRequest, *structs.IndexedPreparedQueries) error
31	listFn    func(*structs.DCSpecificRequest, *structs.IndexedPreparedQueries) error
32	executeFn func(*structs.PreparedQueryExecuteRequest, *structs.PreparedQueryExecuteResponse) error
33	explainFn func(*structs.PreparedQueryExecuteRequest, *structs.PreparedQueryExplainResponse) error
34}
35
36func (m *MockPreparedQuery) Apply(args *structs.PreparedQueryRequest,
37	reply *string) (err error) {
38	if m.applyFn != nil {
39		return m.applyFn(args, reply)
40	}
41	return fmt.Errorf("should not have called Apply")
42}
43
44func (m *MockPreparedQuery) Get(args *structs.PreparedQuerySpecificRequest,
45	reply *structs.IndexedPreparedQueries) error {
46	if m.getFn != nil {
47		return m.getFn(args, reply)
48	}
49	return fmt.Errorf("should not have called Get")
50}
51
52func (m *MockPreparedQuery) List(args *structs.DCSpecificRequest,
53	reply *structs.IndexedPreparedQueries) error {
54	if m.listFn != nil {
55		return m.listFn(args, reply)
56	}
57	return fmt.Errorf("should not have called List")
58}
59
60func (m *MockPreparedQuery) Execute(args *structs.PreparedQueryExecuteRequest,
61	reply *structs.PreparedQueryExecuteResponse) error {
62	if m.executeFn != nil {
63		return m.executeFn(args, reply)
64	}
65	return fmt.Errorf("should not have called Execute")
66}
67
68func (m *MockPreparedQuery) Explain(args *structs.PreparedQueryExecuteRequest,
69	reply *structs.PreparedQueryExplainResponse) error {
70	if m.explainFn != nil {
71		return m.explainFn(args, reply)
72	}
73	return fmt.Errorf("should not have called Explain")
74}
75
76func TestPreparedQuery_Create(t *testing.T) {
77	if testing.Short() {
78		t.Skip("too slow for testing.Short")
79	}
80
81	t.Parallel()
82	a := NewTestAgent(t, "")
83	defer a.Shutdown()
84
85	m := MockPreparedQuery{
86		applyFn: func(args *structs.PreparedQueryRequest, reply *string) error {
87			expected := &structs.PreparedQueryRequest{
88				Datacenter: "dc1",
89				Op:         structs.PreparedQueryCreate,
90				Query: &structs.PreparedQuery{
91					Name:    "my-query",
92					Session: "my-session",
93					Service: structs.ServiceQuery{
94						Service: "my-service",
95						Failover: structs.QueryDatacenterOptions{
96							NearestN:    4,
97							Datacenters: []string{"dc1", "dc2"},
98						},
99						IgnoreCheckIDs: []types.CheckID{"broken_check"},
100						OnlyPassing:    true,
101						Tags:           []string{"foo", "bar"},
102						NodeMeta:       map[string]string{"somekey": "somevalue"},
103						ServiceMeta:    map[string]string{"env": "prod"},
104					},
105					DNS: structs.QueryDNSOptions{
106						TTL: "10s",
107					},
108				},
109				WriteRequest: structs.WriteRequest{
110					Token: "my-token",
111				},
112			}
113			if !reflect.DeepEqual(args, expected) {
114				t.Fatalf("bad: %v", args)
115			}
116
117			*reply = "my-id"
118			return nil
119		},
120	}
121	if err := a.registerEndpoint("PreparedQuery", &m); err != nil {
122		t.Fatalf("err: %v", err)
123	}
124
125	body := bytes.NewBuffer(nil)
126	enc := json.NewEncoder(body)
127	raw := map[string]interface{}{
128		"Name":    "my-query",
129		"Session": "my-session",
130		"Service": map[string]interface{}{
131			"Service": "my-service",
132			"Failover": map[string]interface{}{
133				"NearestN":    4,
134				"Datacenters": []string{"dc1", "dc2"},
135			},
136			"IgnoreCheckIDs": []string{"broken_check"},
137			"OnlyPassing":    true,
138			"Tags":           []string{"foo", "bar"},
139			"NodeMeta":       map[string]string{"somekey": "somevalue"},
140			"ServiceMeta":    map[string]string{"env": "prod"},
141		},
142		"DNS": map[string]interface{}{
143			"TTL": "10s",
144		},
145	}
146	if err := enc.Encode(raw); err != nil {
147		t.Fatalf("err: %v", err)
148	}
149
150	req, _ := http.NewRequest("POST", "/v1/query?token=my-token", body)
151	resp := httptest.NewRecorder()
152	obj, err := a.srv.PreparedQueryGeneral(resp, req)
153	if err != nil {
154		t.Fatalf("err: %v", err)
155	}
156	if resp.Code != 200 {
157		t.Fatalf("bad code: %d", resp.Code)
158	}
159	r, ok := obj.(preparedQueryCreateResponse)
160	if !ok {
161		t.Fatalf("unexpected: %T", obj)
162	}
163	if r.ID != "my-id" {
164		t.Fatalf("bad ID: %s", r.ID)
165	}
166}
167
168func TestPreparedQuery_List(t *testing.T) {
169	if testing.Short() {
170		t.Skip("too slow for testing.Short")
171	}
172
173	t.Parallel()
174	t.Run("", func(t *testing.T) {
175		a := NewTestAgent(t, "")
176		defer a.Shutdown()
177
178		m := MockPreparedQuery{
179			listFn: func(args *structs.DCSpecificRequest, reply *structs.IndexedPreparedQueries) error {
180				// Return an empty response.
181				return nil
182			},
183		}
184		if err := a.registerEndpoint("PreparedQuery", &m); err != nil {
185			t.Fatalf("err: %v", err)
186		}
187
188		body := bytes.NewBuffer(nil)
189		req, _ := http.NewRequest("GET", "/v1/query", body)
190		resp := httptest.NewRecorder()
191		obj, err := a.srv.PreparedQueryGeneral(resp, req)
192		if err != nil {
193			t.Fatalf("err: %v", err)
194		}
195		if resp.Code != 200 {
196			t.Fatalf("bad code: %d", resp.Code)
197		}
198		r, ok := obj.(structs.PreparedQueries)
199		if !ok {
200			t.Fatalf("unexpected: %T", obj)
201		}
202		if r == nil || len(r) != 0 {
203			t.Fatalf("bad: %v", r)
204		}
205	})
206
207	t.Run("", func(t *testing.T) {
208		a := NewTestAgent(t, "")
209		defer a.Shutdown()
210
211		m := MockPreparedQuery{
212			listFn: func(args *structs.DCSpecificRequest, reply *structs.IndexedPreparedQueries) error {
213				expected := &structs.DCSpecificRequest{
214					Datacenter: "dc1",
215					QueryOptions: structs.QueryOptions{
216						Token:             "my-token",
217						RequireConsistent: true,
218					},
219				}
220				if !reflect.DeepEqual(args, expected) {
221					t.Fatalf("bad: %v", args)
222				}
223
224				query := &structs.PreparedQuery{
225					ID: "my-id",
226				}
227				reply.Queries = append(reply.Queries, query)
228				return nil
229			},
230		}
231		if err := a.registerEndpoint("PreparedQuery", &m); err != nil {
232			t.Fatalf("err: %v", err)
233		}
234
235		body := bytes.NewBuffer(nil)
236		req, _ := http.NewRequest("GET", "/v1/query?token=my-token&consistent=true", body)
237		resp := httptest.NewRecorder()
238		obj, err := a.srv.PreparedQueryGeneral(resp, req)
239		if err != nil {
240			t.Fatalf("err: %v", err)
241		}
242		if resp.Code != 200 {
243			t.Fatalf("bad code: %d", resp.Code)
244		}
245		r, ok := obj.(structs.PreparedQueries)
246		if !ok {
247			t.Fatalf("unexpected: %T", obj)
248		}
249		if len(r) != 1 || r[0].ID != "my-id" {
250			t.Fatalf("bad: %v", r)
251		}
252	})
253}
254
255func TestPreparedQuery_Execute(t *testing.T) {
256	if testing.Short() {
257		t.Skip("too slow for testing.Short")
258	}
259
260	t.Parallel()
261	t.Run("", func(t *testing.T) {
262		a := NewTestAgent(t, "")
263		defer a.Shutdown()
264
265		m := MockPreparedQuery{
266			executeFn: func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error {
267				// Just return an empty response.
268				return nil
269			},
270		}
271		if err := a.registerEndpoint("PreparedQuery", &m); err != nil {
272			t.Fatalf("err: %v", err)
273		}
274
275		body := bytes.NewBuffer(nil)
276		req, _ := http.NewRequest("GET", "/v1/query/my-id/execute", body)
277		resp := httptest.NewRecorder()
278		obj, err := a.srv.PreparedQuerySpecific(resp, req)
279		if err != nil {
280			t.Fatalf("err: %v", err)
281		}
282		if resp.Code != 200 {
283			t.Fatalf("bad code: %d", resp.Code)
284		}
285		r, ok := obj.(structs.PreparedQueryExecuteResponse)
286		if !ok {
287			t.Fatalf("unexpected: %T", obj)
288		}
289		if r.Nodes == nil || len(r.Nodes) != 0 {
290			t.Fatalf("bad: %v", r)
291		}
292	})
293
294	t.Run("", func(t *testing.T) {
295		a := NewTestAgent(t, "")
296		defer a.Shutdown()
297
298		m := MockPreparedQuery{
299			executeFn: func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error {
300				expected := &structs.PreparedQueryExecuteRequest{
301					Datacenter:    "dc1",
302					QueryIDOrName: "my-id",
303					Limit:         5,
304					Source: structs.QuerySource{
305						Datacenter: "dc1",
306						Node:       "my-node",
307					},
308					Agent: structs.QuerySource{
309						Datacenter: a.Config.Datacenter,
310						Node:       a.Config.NodeName,
311					},
312					QueryOptions: structs.QueryOptions{
313						Token:             "my-token",
314						RequireConsistent: true,
315					},
316				}
317				if !reflect.DeepEqual(args, expected) {
318					t.Fatalf("bad: %v", args)
319				}
320
321				// Just set something so we can tell this is returned.
322				reply.Failovers = 99
323				return nil
324			},
325		}
326		if err := a.registerEndpoint("PreparedQuery", &m); err != nil {
327			t.Fatalf("err: %v", err)
328		}
329
330		body := bytes.NewBuffer(nil)
331		req, _ := http.NewRequest("GET", "/v1/query/my-id/execute?token=my-token&consistent=true&near=my-node&limit=5", body)
332		resp := httptest.NewRecorder()
333		obj, err := a.srv.PreparedQuerySpecific(resp, req)
334		if err != nil {
335			t.Fatalf("err: %v", err)
336		}
337		if resp.Code != 200 {
338			t.Fatalf("bad code: %d", resp.Code)
339		}
340		r, ok := obj.(structs.PreparedQueryExecuteResponse)
341		if !ok {
342			t.Fatalf("unexpected: %T", obj)
343		}
344		if r.Failovers != 99 {
345			t.Fatalf("bad: %v", r)
346		}
347	})
348
349	t.Run("", func(t *testing.T) {
350		a := NewTestAgent(t, "")
351		defer a.Shutdown()
352
353		m := MockPreparedQuery{
354			executeFn: func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error {
355				expected := &structs.PreparedQueryExecuteRequest{
356					Datacenter:    "dc1",
357					QueryIDOrName: "my-id",
358					Limit:         5,
359					Source: structs.QuerySource{
360						Datacenter: "dc1",
361						Node:       "_ip",
362						Ip:         "127.0.0.1",
363					},
364					Agent: structs.QuerySource{
365						Datacenter: a.Config.Datacenter,
366						Node:       a.Config.NodeName,
367					},
368					QueryOptions: structs.QueryOptions{
369						Token:             "my-token",
370						RequireConsistent: true,
371					},
372				}
373				if !reflect.DeepEqual(args, expected) {
374					t.Fatalf("bad: %v", args)
375				}
376
377				// Just set something so we can tell this is returned.
378				reply.Failovers = 99
379				return nil
380			},
381		}
382		if err := a.registerEndpoint("PreparedQuery", &m); err != nil {
383			t.Fatalf("err: %v", err)
384		}
385
386		body := bytes.NewBuffer(nil)
387		req, _ := http.NewRequest("GET", "/v1/query/my-id/execute?token=my-token&consistent=true&near=_ip&limit=5", body)
388		req.Header.Add("X-Forwarded-For", "127.0.0.1")
389		resp := httptest.NewRecorder()
390		obj, err := a.srv.PreparedQuerySpecific(resp, req)
391		if err != nil {
392			t.Fatalf("err: %v", err)
393		}
394		if resp.Code != 200 {
395			t.Fatalf("bad code: %d", resp.Code)
396		}
397		r, ok := obj.(structs.PreparedQueryExecuteResponse)
398		if !ok {
399			t.Fatalf("unexpected: %T", obj)
400		}
401		if r.Failovers != 99 {
402			t.Fatalf("bad: %v", r)
403		}
404	})
405
406	t.Run("", func(t *testing.T) {
407		a := NewTestAgent(t, "")
408		defer a.Shutdown()
409
410		m := MockPreparedQuery{
411			executeFn: func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error {
412				expected := &structs.PreparedQueryExecuteRequest{
413					Datacenter:    "dc1",
414					QueryIDOrName: "my-id",
415					Limit:         5,
416					Source: structs.QuerySource{
417						Datacenter: "dc1",
418						Node:       "_ip",
419						Ip:         "198.18.0.1",
420					},
421					Agent: structs.QuerySource{
422						Datacenter: a.Config.Datacenter,
423						Node:       a.Config.NodeName,
424					},
425					QueryOptions: structs.QueryOptions{
426						Token:             "my-token",
427						RequireConsistent: true,
428					},
429				}
430				if !reflect.DeepEqual(args, expected) {
431					t.Fatalf("bad: %v", args)
432				}
433
434				// Just set something so we can tell this is returned.
435				reply.Failovers = 99
436				return nil
437			},
438		}
439		if err := a.registerEndpoint("PreparedQuery", &m); err != nil {
440			t.Fatalf("err: %v", err)
441		}
442
443		body := bytes.NewBuffer(nil)
444		req, _ := http.NewRequest("GET", "/v1/query/my-id/execute?token=my-token&consistent=true&near=_ip&limit=5", body)
445		req.Header.Add("X-Forwarded-For", "198.18.0.1")
446		resp := httptest.NewRecorder()
447		obj, err := a.srv.PreparedQuerySpecific(resp, req)
448		if err != nil {
449			t.Fatalf("err: %v", err)
450		}
451		if resp.Code != 200 {
452			t.Fatalf("bad code: %d", resp.Code)
453		}
454		r, ok := obj.(structs.PreparedQueryExecuteResponse)
455		if !ok {
456			t.Fatalf("unexpected: %T", obj)
457		}
458		if r.Failovers != 99 {
459			t.Fatalf("bad: %v", r)
460		}
461
462		req, _ = http.NewRequest("GET", "/v1/query/my-id/execute?token=my-token&consistent=true&near=_ip&limit=5", body)
463		req.Header.Add("X-Forwarded-For", "198.18.0.1, 198.19.0.1")
464		resp = httptest.NewRecorder()
465		obj, err = a.srv.PreparedQuerySpecific(resp, req)
466		if err != nil {
467			t.Fatalf("err: %v", err)
468		}
469		if resp.Code != 200 {
470			t.Fatalf("bad code: %d", resp.Code)
471		}
472		r, ok = obj.(structs.PreparedQueryExecuteResponse)
473		if !ok {
474			t.Fatalf("unexpected: %T", obj)
475		}
476		if r.Failovers != 99 {
477			t.Fatalf("bad: %v", r)
478		}
479	})
480
481	// Ensure the proper params are set when no special args are passed
482	t.Run("", func(t *testing.T) {
483		a := NewTestAgent(t, "")
484		defer a.Shutdown()
485
486		m := MockPreparedQuery{
487			executeFn: func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error {
488				if args.Source.Node != "" {
489					t.Fatalf("expect node to be empty, got %q", args.Source.Node)
490				}
491				expect := structs.QuerySource{
492					Datacenter: a.Config.Datacenter,
493					Node:       a.Config.NodeName,
494				}
495				if !reflect.DeepEqual(args.Agent, expect) {
496					t.Fatalf("expect: %#v\nactual: %#v", expect, args.Agent)
497				}
498				return nil
499			},
500		}
501		if err := a.registerEndpoint("PreparedQuery", &m); err != nil {
502			t.Fatalf("err: %v", err)
503		}
504
505		req, _ := http.NewRequest("GET", "/v1/query/my-id/execute", nil)
506		resp := httptest.NewRecorder()
507		if _, err := a.srv.PreparedQuerySpecific(resp, req); err != nil {
508			t.Fatalf("err: %v", err)
509		}
510	})
511
512	// Ensure WAN translation occurs for a response outside of the local DC.
513	t.Run("", func(t *testing.T) {
514		a := NewTestAgent(t, `
515			datacenter = "dc1"
516			translate_wan_addrs = true
517		`)
518		defer a.Shutdown()
519
520		m := MockPreparedQuery{
521			executeFn: func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error {
522				nodesResponse := make(structs.CheckServiceNodes, 1)
523				nodesResponse[0].Node = &structs.Node{
524					Node: "foo", Address: "127.0.0.1",
525					TaggedAddresses: map[string]string{
526						"wan": "127.0.0.2",
527					},
528				}
529				nodesResponse[0].Service = &structs.NodeService{
530					Service: "foo",
531					Address: "10.0.1.1",
532					Port:    8080,
533					TaggedAddresses: map[string]structs.ServiceAddress{
534						"wan": {
535							Address: "198.18.0.1",
536							Port:    80,
537						},
538					},
539				}
540				reply.Nodes = nodesResponse
541				reply.Datacenter = "dc2"
542				return nil
543			},
544		}
545		require.NoError(t, a.registerEndpoint("PreparedQuery", &m))
546
547		body := bytes.NewBuffer(nil)
548		req, _ := http.NewRequest("GET", "/v1/query/my-id/execute?dc=dc2", body)
549		resp := httptest.NewRecorder()
550		obj, err := a.srv.PreparedQuerySpecific(resp, req)
551		require.NoError(t, err)
552		require.Equal(t, 200, resp.Code)
553		r, ok := obj.(structs.PreparedQueryExecuteResponse)
554		require.True(t, ok, "unexpected: %T", obj)
555		require.NotNil(t, r.Nodes)
556		require.Len(t, r.Nodes, 1)
557
558		node := r.Nodes[0]
559		require.NotNil(t, node.Node)
560		require.Equal(t, "127.0.0.2", node.Node.Address)
561		require.NotNil(t, node.Service)
562		require.Equal(t, "198.18.0.1", node.Service.Address)
563		require.Equal(t, 80, node.Service.Port)
564	})
565
566	// Ensure WAN translation doesn't occur for the local DC.
567	t.Run("", func(t *testing.T) {
568		a := NewTestAgent(t, `
569			datacenter = "dc1"
570			translate_wan_addrs = true
571		`)
572		defer a.Shutdown()
573
574		m := MockPreparedQuery{
575			executeFn: func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error {
576				nodesResponse := make(structs.CheckServiceNodes, 1)
577				nodesResponse[0].Node = &structs.Node{
578					Node: "foo", Address: "127.0.0.1",
579					TaggedAddresses: map[string]string{
580						"wan": "127.0.0.2",
581					},
582				}
583				reply.Nodes = nodesResponse
584				reply.Datacenter = "dc1"
585				return nil
586			},
587		}
588		if err := a.registerEndpoint("PreparedQuery", &m); err != nil {
589			t.Fatalf("err: %v", err)
590		}
591
592		body := bytes.NewBuffer(nil)
593		req, _ := http.NewRequest("GET", "/v1/query/my-id/execute?dc=dc2", body)
594		resp := httptest.NewRecorder()
595		obj, err := a.srv.PreparedQuerySpecific(resp, req)
596		if err != nil {
597			t.Fatalf("err: %v", err)
598		}
599		if resp.Code != 200 {
600			t.Fatalf("bad code: %d", resp.Code)
601		}
602		r, ok := obj.(structs.PreparedQueryExecuteResponse)
603		if !ok {
604			t.Fatalf("unexpected: %T", obj)
605		}
606		if r.Nodes == nil || len(r.Nodes) != 1 {
607			t.Fatalf("bad: %v", r)
608		}
609
610		node := r.Nodes[0]
611		if node.Node.Address != "127.0.0.1" {
612			t.Fatalf("bad: %v", node.Node)
613		}
614	})
615
616	t.Run("", func(t *testing.T) {
617		a := NewTestAgent(t, "")
618		defer a.Shutdown()
619
620		body := bytes.NewBuffer(nil)
621		req, _ := http.NewRequest("GET", "/v1/query/not-there/execute", body)
622		resp := httptest.NewRecorder()
623		if _, err := a.srv.PreparedQuerySpecific(resp, req); err != nil {
624			t.Fatalf("err: %v", err)
625		}
626		if resp.Code != 404 {
627			t.Fatalf("bad code: %d", resp.Code)
628		}
629	})
630}
631
632func TestPreparedQuery_ExecuteCached(t *testing.T) {
633	if testing.Short() {
634		t.Skip("too slow for testing.Short")
635	}
636
637	t.Parallel()
638
639	a := NewTestAgent(t, "")
640	defer a.Shutdown()
641
642	failovers := int32(99)
643
644	m := MockPreparedQuery{
645		executeFn: func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error {
646			// Just set something so we can tell this is returned.
647			reply.Failovers = int(atomic.LoadInt32(&failovers))
648			return nil
649		},
650	}
651	if err := a.registerEndpoint("PreparedQuery", &m); err != nil {
652		t.Fatalf("err: %v", err)
653	}
654
655	doRequest := func(expectFailovers int, expectCache string, revalidate bool) {
656		body := bytes.NewBuffer(nil)
657		req, _ := http.NewRequest("GET", "/v1/query/my-id/execute?cached", body)
658
659		if revalidate {
660			req.Header.Set("Cache-Control", "must-revalidate")
661		}
662
663		resp := httptest.NewRecorder()
664		obj, err := a.srv.PreparedQuerySpecific(resp, req)
665
666		require := require.New(t)
667		require.NoError(err)
668		require.Equal(200, resp.Code)
669
670		r, ok := obj.(structs.PreparedQueryExecuteResponse)
671		require.True(ok)
672		require.Equal(expectFailovers, r.Failovers)
673
674		require.Equal(expectCache, resp.Header().Get("X-Cache"))
675	}
676
677	// Should be a miss at first
678	doRequest(99, "MISS", false)
679
680	// Change the actual response
681	atomic.StoreInt32(&failovers, 66)
682
683	// Request again, should be a cache hit and have the cached (not current)
684	// value.
685	doRequest(99, "HIT", false)
686
687	// Request with max age that should invalidate cache. note that this will be
688	// sent as max-age=0 as that uses seconds but that should cause immediate
689	// invalidation rather than being ignored as an unset value.
690	doRequest(66, "MISS", true)
691}
692
693func TestPreparedQuery_Explain(t *testing.T) {
694	if testing.Short() {
695		t.Skip("too slow for testing.Short")
696	}
697
698	t.Parallel()
699	t.Run("", func(t *testing.T) {
700		a := NewTestAgent(t, "")
701		defer a.Shutdown()
702
703		m := MockPreparedQuery{
704			explainFn: func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExplainResponse) error {
705				expected := &structs.PreparedQueryExecuteRequest{
706					Datacenter:    "dc1",
707					QueryIDOrName: "my-id",
708					Limit:         5,
709					Source: structs.QuerySource{
710						Datacenter: "dc1",
711						Node:       "my-node",
712					},
713					Agent: structs.QuerySource{
714						Datacenter: a.Config.Datacenter,
715						Node:       a.Config.NodeName,
716					},
717					QueryOptions: structs.QueryOptions{
718						Token:             "my-token",
719						RequireConsistent: true,
720					},
721				}
722				if !reflect.DeepEqual(args, expected) {
723					t.Fatalf("bad: %v", args)
724				}
725
726				// Just set something so we can tell this is returned.
727				reply.Query.Name = "hello"
728				return nil
729			},
730		}
731		if err := a.registerEndpoint("PreparedQuery", &m); err != nil {
732			t.Fatalf("err: %v", err)
733		}
734
735		body := bytes.NewBuffer(nil)
736		req, _ := http.NewRequest("GET", "/v1/query/my-id/explain?token=my-token&consistent=true&near=my-node&limit=5", body)
737		resp := httptest.NewRecorder()
738		obj, err := a.srv.PreparedQuerySpecific(resp, req)
739		if err != nil {
740			t.Fatalf("err: %v", err)
741		}
742		if resp.Code != 200 {
743			t.Fatalf("bad code: %d", resp.Code)
744		}
745		r, ok := obj.(structs.PreparedQueryExplainResponse)
746		if !ok {
747			t.Fatalf("unexpected: %T", obj)
748		}
749		if r.Query.Name != "hello" {
750			t.Fatalf("bad: %v", r)
751		}
752	})
753
754	t.Run("", func(t *testing.T) {
755		a := NewTestAgent(t, "")
756		defer a.Shutdown()
757
758		body := bytes.NewBuffer(nil)
759		req, _ := http.NewRequest("GET", "/v1/query/not-there/explain", body)
760		resp := httptest.NewRecorder()
761		if _, err := a.srv.PreparedQuerySpecific(resp, req); err != nil {
762			t.Fatalf("err: %v", err)
763		}
764		if resp.Code != 404 {
765			t.Fatalf("bad code: %d", resp.Code)
766		}
767	})
768
769	// Ensure that Connect is passed through
770	t.Run("", func(t *testing.T) {
771		a := NewTestAgent(t, "")
772		defer a.Shutdown()
773		require := require.New(t)
774
775		m := MockPreparedQuery{
776			executeFn: func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error {
777				require.True(args.Connect)
778				return nil
779			},
780		}
781		require.NoError(a.registerEndpoint("PreparedQuery", &m))
782
783		body := bytes.NewBuffer(nil)
784		req, _ := http.NewRequest("GET", "/v1/query/my-id/execute?connect=true", body)
785		resp := httptest.NewRecorder()
786		_, err := a.srv.PreparedQuerySpecific(resp, req)
787		require.NoError(err)
788		require.Equal(200, resp.Code)
789	})
790}
791
792func TestPreparedQuery_Get(t *testing.T) {
793	if testing.Short() {
794		t.Skip("too slow for testing.Short")
795	}
796
797	t.Parallel()
798	t.Run("", func(t *testing.T) {
799		a := NewTestAgent(t, "")
800		defer a.Shutdown()
801
802		m := MockPreparedQuery{
803			getFn: func(args *structs.PreparedQuerySpecificRequest, reply *structs.IndexedPreparedQueries) error {
804				expected := &structs.PreparedQuerySpecificRequest{
805					Datacenter: "dc1",
806					QueryID:    "my-id",
807					QueryOptions: structs.QueryOptions{
808						Token:             "my-token",
809						RequireConsistent: true,
810					},
811				}
812				if !reflect.DeepEqual(args, expected) {
813					t.Fatalf("bad: %v", args)
814				}
815
816				query := &structs.PreparedQuery{
817					ID: "my-id",
818				}
819				reply.Queries = append(reply.Queries, query)
820				return nil
821			},
822		}
823		if err := a.registerEndpoint("PreparedQuery", &m); err != nil {
824			t.Fatalf("err: %v", err)
825		}
826
827		body := bytes.NewBuffer(nil)
828		req, _ := http.NewRequest("GET", "/v1/query/my-id?token=my-token&consistent=true", body)
829		resp := httptest.NewRecorder()
830		obj, err := a.srv.PreparedQuerySpecific(resp, req)
831		if err != nil {
832			t.Fatalf("err: %v", err)
833		}
834		if resp.Code != 200 {
835			t.Fatalf("bad code: %d", resp.Code)
836		}
837		r, ok := obj.(structs.PreparedQueries)
838		if !ok {
839			t.Fatalf("unexpected: %T", obj)
840		}
841		if len(r) != 1 || r[0].ID != "my-id" {
842			t.Fatalf("bad: %v", r)
843		}
844	})
845
846	t.Run("", func(t *testing.T) {
847		a := NewTestAgent(t, "")
848		defer a.Shutdown()
849
850		body := bytes.NewBuffer(nil)
851		req, _ := http.NewRequest("GET", "/v1/query/f004177f-2c28-83b7-4229-eacc25fe55d1", body)
852		resp := httptest.NewRecorder()
853		if _, err := a.srv.PreparedQuerySpecific(resp, req); err != nil {
854			t.Fatalf("err: %v", err)
855		}
856		if resp.Code != 404 {
857			t.Fatalf("bad code: %d", resp.Code)
858		}
859	})
860}
861
862func TestPreparedQuery_Update(t *testing.T) {
863	if testing.Short() {
864		t.Skip("too slow for testing.Short")
865	}
866
867	t.Parallel()
868	a := NewTestAgent(t, "")
869	defer a.Shutdown()
870
871	m := MockPreparedQuery{
872		applyFn: func(args *structs.PreparedQueryRequest, reply *string) error {
873			expected := &structs.PreparedQueryRequest{
874				Datacenter: "dc1",
875				Op:         structs.PreparedQueryUpdate,
876				Query: &structs.PreparedQuery{
877					ID:      "my-id",
878					Name:    "my-query",
879					Session: "my-session",
880					Service: structs.ServiceQuery{
881						Service: "my-service",
882						Failover: structs.QueryDatacenterOptions{
883							NearestN:    4,
884							Datacenters: []string{"dc1", "dc2"},
885						},
886						OnlyPassing: true,
887						Tags:        []string{"foo", "bar"},
888						NodeMeta:    map[string]string{"somekey": "somevalue"},
889					},
890					DNS: structs.QueryDNSOptions{
891						TTL: "10s",
892					},
893				},
894				WriteRequest: structs.WriteRequest{
895					Token: "my-token",
896				},
897			}
898			if !reflect.DeepEqual(args, expected) {
899				t.Fatalf("bad: %v", args)
900			}
901
902			*reply = "don't care"
903			return nil
904		},
905	}
906	if err := a.registerEndpoint("PreparedQuery", &m); err != nil {
907		t.Fatalf("err: %v", err)
908	}
909
910	body := bytes.NewBuffer(nil)
911	enc := json.NewEncoder(body)
912	raw := map[string]interface{}{
913		"ID":      "this should get ignored",
914		"Name":    "my-query",
915		"Session": "my-session",
916		"Service": map[string]interface{}{
917			"Service": "my-service",
918			"Failover": map[string]interface{}{
919				"NearestN":    4,
920				"Datacenters": []string{"dc1", "dc2"},
921			},
922			"OnlyPassing": true,
923			"Tags":        []string{"foo", "bar"},
924			"NodeMeta":    map[string]string{"somekey": "somevalue"},
925		},
926		"DNS": map[string]interface{}{
927			"TTL": "10s",
928		},
929	}
930	if err := enc.Encode(raw); err != nil {
931		t.Fatalf("err: %v", err)
932	}
933
934	req, _ := http.NewRequest("PUT", "/v1/query/my-id?token=my-token", body)
935	resp := httptest.NewRecorder()
936	if _, err := a.srv.PreparedQuerySpecific(resp, req); err != nil {
937		t.Fatalf("err: %v", err)
938	}
939	if resp.Code != 200 {
940		t.Fatalf("bad code: %d", resp.Code)
941	}
942}
943
944func TestPreparedQuery_Delete(t *testing.T) {
945	if testing.Short() {
946		t.Skip("too slow for testing.Short")
947	}
948
949	t.Parallel()
950	a := NewTestAgent(t, "")
951	defer a.Shutdown()
952
953	m := MockPreparedQuery{
954		applyFn: func(args *structs.PreparedQueryRequest, reply *string) error {
955			expected := &structs.PreparedQueryRequest{
956				Datacenter: "dc1",
957				Op:         structs.PreparedQueryDelete,
958				Query: &structs.PreparedQuery{
959					ID: "my-id",
960				},
961				WriteRequest: structs.WriteRequest{
962					Token: "my-token",
963				},
964			}
965			if !reflect.DeepEqual(args, expected) {
966				t.Fatalf("bad: %v", args)
967			}
968
969			*reply = "don't care"
970			return nil
971		},
972	}
973	if err := a.registerEndpoint("PreparedQuery", &m); err != nil {
974		t.Fatalf("err: %v", err)
975	}
976
977	body := bytes.NewBuffer(nil)
978	enc := json.NewEncoder(body)
979	raw := map[string]interface{}{
980		"ID": "this should get ignored",
981	}
982	if err := enc.Encode(raw); err != nil {
983		t.Fatalf("err: %v", err)
984	}
985
986	req, _ := http.NewRequest("DELETE", "/v1/query/my-id?token=my-token", body)
987	resp := httptest.NewRecorder()
988	if _, err := a.srv.PreparedQuerySpecific(resp, req); err != nil {
989		t.Fatalf("err: %v", err)
990	}
991	if resp.Code != 200 {
992		t.Fatalf("bad code: %d", resp.Code)
993	}
994}
995
996func TestPreparedQuery_parseLimit(t *testing.T) {
997	t.Parallel()
998	body := bytes.NewBuffer(nil)
999	req, _ := http.NewRequest("GET", "/v1/query", body)
1000	limit := 99
1001	if err := parseLimit(req, &limit); err != nil {
1002		t.Fatalf("err: %v", err)
1003	}
1004	if limit != 0 {
1005		t.Fatalf("bad limit: %d", limit)
1006	}
1007
1008	req, _ = http.NewRequest("GET", "/v1/query?limit=11", body)
1009	if err := parseLimit(req, &limit); err != nil {
1010		t.Fatalf("err: %v", err)
1011	}
1012	if limit != 11 {
1013		t.Fatalf("bad limit: %d", limit)
1014	}
1015
1016	req, _ = http.NewRequest("GET", "/v1/query?limit=bob", body)
1017	if err := parseLimit(req, &limit); err == nil {
1018		t.Fatalf("bad: %v", err)
1019	}
1020}
1021
1022// Since we've done exhaustive testing of the calls into the endpoints above
1023// this is just a basic end-to-end sanity check to make sure things are wired
1024// correctly when calling through to the real endpoints.
1025func TestPreparedQuery_Integration(t *testing.T) {
1026	if testing.Short() {
1027		t.Skip("too slow for testing.Short")
1028	}
1029
1030	t.Parallel()
1031	a := NewTestAgent(t, "")
1032	defer a.Shutdown()
1033	testrpc.WaitForTestAgent(t, a.RPC, "dc1")
1034
1035	// Register a node and a service.
1036	{
1037		args := &structs.RegisterRequest{
1038			Datacenter: "dc1",
1039			Node:       a.Config.NodeName,
1040			Address:    "127.0.0.1",
1041			Service: &structs.NodeService{
1042				Service: "my-service",
1043			},
1044		}
1045		var out struct{}
1046		if err := a.RPC("Catalog.Register", args, &out); err != nil {
1047			t.Fatalf("err: %v", err)
1048		}
1049	}
1050
1051	// Create a query.
1052	var id string
1053	{
1054		body := bytes.NewBuffer(nil)
1055		enc := json.NewEncoder(body)
1056		raw := map[string]interface{}{
1057			"Name": "my-query",
1058			"Service": map[string]interface{}{
1059				"Service": "my-service",
1060			},
1061		}
1062		if err := enc.Encode(raw); err != nil {
1063			t.Fatalf("err: %v", err)
1064		}
1065
1066		req, _ := http.NewRequest("POST", "/v1/query", body)
1067		resp := httptest.NewRecorder()
1068		obj, err := a.srv.PreparedQueryGeneral(resp, req)
1069		if err != nil {
1070			t.Fatalf("err: %v", err)
1071		}
1072		if resp.Code != 200 {
1073			t.Fatalf("bad code: %d", resp.Code)
1074		}
1075		r, ok := obj.(preparedQueryCreateResponse)
1076		if !ok {
1077			t.Fatalf("unexpected: %T", obj)
1078		}
1079		id = r.ID
1080	}
1081
1082	// List them all.
1083	{
1084		body := bytes.NewBuffer(nil)
1085		req, _ := http.NewRequest("GET", "/v1/query?token=root", body)
1086		resp := httptest.NewRecorder()
1087		obj, err := a.srv.PreparedQueryGeneral(resp, req)
1088		if err != nil {
1089			t.Fatalf("err: %v", err)
1090		}
1091		if resp.Code != 200 {
1092			t.Fatalf("bad code: %d", resp.Code)
1093		}
1094		r, ok := obj.(structs.PreparedQueries)
1095		if !ok {
1096			t.Fatalf("unexpected: %T", obj)
1097		}
1098		if len(r) != 1 {
1099			t.Fatalf("bad: %v", r)
1100		}
1101	}
1102
1103	// Execute it.
1104	{
1105		body := bytes.NewBuffer(nil)
1106		req, _ := http.NewRequest("GET", "/v1/query/"+id+"/execute", body)
1107		resp := httptest.NewRecorder()
1108		obj, err := a.srv.PreparedQuerySpecific(resp, req)
1109		if err != nil {
1110			t.Fatalf("err: %v", err)
1111		}
1112		if resp.Code != 200 {
1113			t.Fatalf("bad code: %d", resp.Code)
1114		}
1115		r, ok := obj.(structs.PreparedQueryExecuteResponse)
1116		if !ok {
1117			t.Fatalf("unexpected: %T", obj)
1118		}
1119		if len(r.Nodes) != 1 {
1120			t.Fatalf("bad: %v", r)
1121		}
1122	}
1123
1124	// Read it back.
1125	{
1126		body := bytes.NewBuffer(nil)
1127		req, _ := http.NewRequest("GET", "/v1/query/"+id, body)
1128		resp := httptest.NewRecorder()
1129		obj, err := a.srv.PreparedQuerySpecific(resp, req)
1130		if err != nil {
1131			t.Fatalf("err: %v", err)
1132		}
1133		if resp.Code != 200 {
1134			t.Fatalf("bad code: %d", resp.Code)
1135		}
1136		r, ok := obj.(structs.PreparedQueries)
1137		if !ok {
1138			t.Fatalf("unexpected: %T", obj)
1139		}
1140		if len(r) != 1 {
1141			t.Fatalf("bad: %v", r)
1142		}
1143	}
1144
1145	// Make an update to it.
1146	{
1147		body := bytes.NewBuffer(nil)
1148		enc := json.NewEncoder(body)
1149		raw := map[string]interface{}{
1150			"Name": "my-query",
1151			"Service": map[string]interface{}{
1152				"Service":     "my-service",
1153				"OnlyPassing": true,
1154			},
1155		}
1156		if err := enc.Encode(raw); err != nil {
1157			t.Fatalf("err: %v", err)
1158		}
1159
1160		req, _ := http.NewRequest("PUT", "/v1/query/"+id, body)
1161		resp := httptest.NewRecorder()
1162		if _, err := a.srv.PreparedQuerySpecific(resp, req); err != nil {
1163			t.Fatalf("err: %v", err)
1164		}
1165		if resp.Code != 200 {
1166			t.Fatalf("bad code: %d", resp.Code)
1167		}
1168	}
1169
1170	// Delete it.
1171	{
1172		body := bytes.NewBuffer(nil)
1173		req, _ := http.NewRequest("DELETE", "/v1/query/"+id, body)
1174		resp := httptest.NewRecorder()
1175		if _, err := a.srv.PreparedQuerySpecific(resp, req); err != nil {
1176			t.Fatalf("err: %v", err)
1177		}
1178		if resp.Code != 200 {
1179			t.Fatalf("bad code: %d", resp.Code)
1180		}
1181	}
1182}
1183