1// Copyright 2012-present Oliver Eilhard. All rights reserved.
2// Use of this source code is governed by a MIT-license.
3// See http://olivere.mit-license.org/license.txt for details.
4
5package elastic
6
7import (
8	"bytes"
9	"context"
10	"encoding/json"
11	"errors"
12	"fmt"
13	"log"
14	"net"
15	"net/http"
16	"net/http/httptest"
17	"net/url"
18	"reflect"
19	"regexp"
20	"strings"
21	"sync"
22	"testing"
23	"time"
24
25	"github.com/fortytw2/leaktest"
26
27	"github.com/olivere/elastic/v7/config"
28)
29
30func findConn(s string, slice ...*conn) (int, bool) {
31	for i, t := range slice {
32		if s == t.URL() {
33			return i, true
34		}
35	}
36	return -1, false
37}
38
39// -- NewClient --
40
// TestClientDefaults verifies the configuration NewClient produces without
// options: health checks and sniffing are enabled with the package default
// timeouts and intervals, basic auth is off, and GET bodies are sent as GET.
func TestClientDefaults(t *testing.T) {
	client, err := NewClient()
	if err != nil {
		t.Fatal(err)
	}
	if !client.healthcheckEnabled {
		t.Errorf("expected health checks to be enabled, got: %v", client.healthcheckEnabled)
	}
	if client.healthcheckTimeoutStartup != DefaultHealthcheckTimeoutStartup {
		t.Errorf("expected health checks timeout on startup = %v, got: %v", DefaultHealthcheckTimeoutStartup, client.healthcheckTimeoutStartup)
	}
	if client.healthcheckTimeout != DefaultHealthcheckTimeout {
		t.Errorf("expected health checks timeout = %v, got: %v", DefaultHealthcheckTimeout, client.healthcheckTimeout)
	}
	if client.healthcheckInterval != DefaultHealthcheckInterval {
		t.Errorf("expected health checks interval = %v, got: %v", DefaultHealthcheckInterval, client.healthcheckInterval)
	}
	if !client.snifferEnabled {
		t.Errorf("expected sniffing to be enabled, got: %v", client.snifferEnabled)
	}
	if client.snifferTimeoutStartup != DefaultSnifferTimeoutStartup {
		t.Errorf("expected sniffer timeout on startup = %v, got: %v", DefaultSnifferTimeoutStartup, client.snifferTimeoutStartup)
	}
	if client.snifferTimeout != DefaultSnifferTimeout {
		t.Errorf("expected sniffer timeout = %v, got: %v", DefaultSnifferTimeout, client.snifferTimeout)
	}
	if client.snifferInterval != DefaultSnifferInterval {
		t.Errorf("expected sniffer interval = %v, got: %v", DefaultSnifferInterval, client.snifferInterval)
	}
	if client.basicAuth {
		t.Errorf("expected no basic auth; got: %v", client.basicAuth)
	}
	if client.basicAuthUsername != "" {
		t.Errorf("expected no basic auth username; got: %q", client.basicAuthUsername)
	}
	if client.basicAuthPassword != "" {
		// Report the password field that was checked, not the username.
		t.Errorf("expected no basic auth password; got: %q", client.basicAuthPassword)
	}
	if client.sendGetBodyAs != "GET" {
		t.Errorf("expected sendGetBodyAs to be GET; got: %q", client.sendGetBodyAs)
	}
}
83
// TestClientWithoutURL creates a client with no explicit URL. The client is
// expected to sniff the cluster at DefaultURL and end up with at least one
// connection. Outside Travis CI, one of those connections must be DefaultURL.
func TestClientWithoutURL(t *testing.T) {
	cli, err := NewClient()
	if err != nil {
		t.Fatal(err)
	}
	if len(cli.conns) == 0 {
		t.Fatalf("expected at least 1 node in the cluster, got: %d (%v)", len(cli.conns), cli.conns)
	}
	if isTravis() {
		// The exact-URL check is skipped on Travis CI.
		return
	}
	if _, ok := findConn(DefaultURL, cli.conns...); !ok {
		t.Errorf("expected to find node with default URL of %s in %v", DefaultURL, cli.conns)
	}
}
101
102func TestClientWithSingleURL(t *testing.T) {
103	client, err := NewClient(SetURL("http://127.0.0.1:9200"))
104	if err != nil {
105		t.Fatal(err)
106	}
107	// Two things should happen here:
108	// 1. The client starts sniffing the cluster on DefaultURL
109	// 2. The sniffing process should find (at least) one node in the cluster, i.e. the DefaultURL
110	if len(client.conns) == 0 {
111		t.Fatalf("expected at least 1 node in the cluster, got: %d (%v)", len(client.conns), client.conns)
112	}
113	if !isTravis() {
114		if _, found := findConn(DefaultURL, client.conns...); !found {
115			t.Errorf("expected to find node with default URL of %s in %v", DefaultURL, client.conns)
116		}
117	}
118}
119
// TestClientWithMultipleURLs configures two URLs and expects sniffing to
// keep exactly one connection. Only 127.0.0.1:9200 is expected to answer.
// Outside Travis CI, that connection must be DefaultURL.
func TestClientWithMultipleURLs(t *testing.T) {
	urls := []string{"http://127.0.0.1:9200", "http://127.0.0.1:9201"}
	cli, err := NewClient(SetURL(urls...))
	if err != nil {
		t.Fatal(err)
	}
	if n := len(cli.conns); n != 1 {
		t.Fatalf("expected exactly 1 node in the local cluster, got: %d (%v)", n, cli.conns)
	}
	if !isTravis() && cli.conns[0].URL() != DefaultURL {
		t.Errorf("expected to find node with default URL of %s in %v", DefaultURL, cli.conns)
	}
}
135
136func TestClientWithInvalidURLs(t *testing.T) {
137	client, err := NewClient(SetURL(" http://foo.com", "http://[fe80::%31%25en0]:8080/"))
138	if err == nil {
139		t.Fatal("expected error, got nil")
140	}
141	if want, have := `first path segment in URL cannot contain colon`, err.Error(); !strings.Contains(have, want) {
142		t.Fatalf("expected error to contain %q, have %q", want, have)
143	}
144	if client != nil {
145		t.Fatal("expected client == nil")
146	}
147}
148
// TestClientWithBasicAuth checks that SetBasicAuth enables basic auth and
// stores the given credentials.
func TestClientWithBasicAuth(t *testing.T) {
	cli, err := NewClient(SetBasicAuth("user", "secret"))
	if err != nil {
		t.Fatal(err)
	}
	if !cli.basicAuth {
		t.Errorf("expected basic auth; got: %v", cli.basicAuth)
	}
	expect := func(field, got, want string) {
		if got != want {
			t.Errorf("expected basic auth %s %q; got: %q", field, want, got)
		}
	}
	expect("username", cli.basicAuthUsername, "user")
	expect("password", cli.basicAuthPassword, "secret")
}
164
// TestClientWithBasicAuthInUserInfo checks that credentials embedded in the
// URLs' user info enable basic auth. The credentials of the first URL win.
func TestClientWithBasicAuthInUserInfo(t *testing.T) {
	cli, err := NewClient(SetURL("http://user1:secret1@localhost:9200", "http://user2:secret2@localhost:9200"))
	if err != nil {
		t.Fatal(err)
	}
	if !cli.basicAuth {
		t.Errorf("expected basic auth; got: %v", cli.basicAuth)
	}
	expect := func(field, got, want string) {
		if got != want {
			t.Errorf("expected basic auth %s %q; got: %q", field, want, got)
		}
	}
	expect("username", cli.basicAuthUsername, "user1")
	expect("password", cli.basicAuthPassword, "secret1")
}
180
// TestClientWithXpackSecurity connects to an X-Pack secured node on port 9210
// with credentials in the URL (user "elastic", password "elastic"). It checks
// that basic auth is picked up from the user info.
func TestClientWithXpackSecurity(t *testing.T) {
	cli, err := NewClient(SetURL("http://elastic:elastic@127.0.0.1:9210"))
	if err != nil {
		t.Fatal(err)
	}
	if !cli.basicAuth {
		t.Errorf("expected basic auth; got: %v", cli.basicAuth)
	}
	expect := func(field, got, want string) {
		if got != want {
			t.Errorf("expected basic auth %s %q; got: %q", field, want, got)
		}
	}
	expect("username", cli.basicAuthUsername, "elastic")
	expect("password", cli.basicAuthPassword, "elastic")
}
197
// TestClientFromConfig builds a client from a parsed config URL. It expects
// sniffing to find at least one node. Outside Travis CI, one of the
// connections must be DefaultURL.
func TestClientFromConfig(t *testing.T) {
	cfg, err := config.Parse("http://127.0.0.1:9200")
	if err != nil {
		t.Fatal(err)
	}
	cli, err := NewClientFromConfig(cfg)
	if err != nil {
		t.Fatal(err)
	}
	if len(cli.conns) == 0 {
		t.Fatalf("expected at least 1 node in the cluster, got: %d (%v)", len(cli.conns), cli.conns)
	}
	if isTravis() {
		// The exact-URL check is skipped on Travis CI.
		return
	}
	if _, ok := findConn(DefaultURL, cli.conns...); !ok {
		t.Errorf("expected to find node with default URL of %s in %v", DefaultURL, cli.conns)
	}
}
219
// TestClientDialFromConfig dials a client from a parsed config URL. It
// expects sniffing to find at least one node. Outside Travis CI, one of the
// connections must be DefaultURL.
func TestClientDialFromConfig(t *testing.T) {
	cfg, err := config.Parse("http://127.0.0.1:9200")
	if err != nil {
		t.Fatal(err)
	}
	cli, err := DialWithConfig(context.Background(), cfg)
	if err != nil {
		t.Fatal(err)
	}
	if len(cli.conns) == 0 {
		t.Fatalf("expected at least 1 node in the cluster, got: %d (%v)", len(cli.conns), cli.conns)
	}
	if isTravis() {
		// The exact-URL check is skipped on Travis CI.
		return
	}
	if _, ok := findConn(DefaultURL, cli.conns...); !ok {
		t.Errorf("expected to find node with default URL of %s in %v", DefaultURL, cli.conns)
	}
}
241
242func TestClientDialContext(t *testing.T) {
243	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
244	defer cancel()
245	client, err := DialContext(ctx, SetURL("http://localhost:9200"))
246	if err != nil {
247		t.Fatalf("expected successful connection, got %v", err)
248	}
249	client.Stop()
250}
251
252func TestClientDialContextTimeoutFromHealthcheck(t *testing.T) {
253	start := time.Now().UTC()
254	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
255	defer cancel()
256	_, err := DialContext(ctx, SetURL("http://localhost:9299"), SetHealthcheckTimeoutStartup(5*time.Second))
257	if !IsContextErr(err) {
258		t.Fatal(err)
259	}
260	if time.Since(start) < 3*time.Second {
261		t.Fatalf("early timeout")
262	}
263	if time.Since(start) >= 5*time.Second {
264		t.Fatalf("timeout probably due to healthcheck, not context cancellation")
265	}
266}
267
268func TestClientDialContextTimeoutFromSniffer(t *testing.T) {
269	start := time.Now().UTC()
270	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
271	defer cancel()
272	_, err := DialContext(ctx, SetURL("http://localhost:9299"), SetHealthcheck(false))
273	if !IsContextErr(err) {
274		t.Fatal(err)
275	}
276	if time.Since(start) < 3*time.Second {
277		t.Fatalf("early timeout")
278	}
279	if time.Since(start) >= 5*time.Second {
280		t.Fatalf("timeout probably not caused by context cancellation")
281	}
282}
283
// TestClientSniffSuccess sniffs one unreachable URL (19200) and one live URL
// (9200). It expects exactly one connection to survive.
func TestClientSniffSuccess(t *testing.T) {
	cli, err := NewClient(SetURL("http://127.0.0.1:19200", "http://127.0.0.1:9200"))
	if err != nil {
		t.Fatal(err)
	}
	if got := len(cli.conns); got != 1 {
		t.Fatalf("expected exactly 1 node in the local cluster, got: %d (%v)", got, cli.conns)
	}
}
294
// TestClientSniffFailure configures only URLs with no node behind them. It
// expects client creation to fail because sniffing finds no nodes.
func TestClientSniffFailure(t *testing.T) {
	if _, err := NewClient(SetURL("http://127.0.0.1:19200", "http://127.0.0.1:19201")); err == nil {
		t.Fatalf("expected cluster to fail with no nodes found")
	}
}
301
302func TestClientSnifferCallback(t *testing.T) {
303	var calls int
304	cb := func(node *NodesInfoNode) bool {
305		calls++
306		return false
307	}
308	_, err := NewClient(
309		SetURL("http://127.0.0.1:19200", "http://127.0.0.1:9200"),
310		SetSnifferCallback(cb))
311	if err == nil {
312		t.Fatalf("expected cluster to fail with no nodes found")
313	}
314	if calls != 1 {
315		t.Fatalf("expected 1 call to the sniffer callback, got %d", calls)
316	}
317}
318
319func TestClientSniffDisabled(t *testing.T) {
320	client, err := NewClient(SetSniff(false), SetURL("http://127.0.0.1:9200", "http://127.0.0.1:9201"))
321	if err != nil {
322		t.Fatal(err)
323	}
324	// The client should not sniff, so it should have two connections.
325	if len(client.conns) != 2 {
326		t.Fatalf("expected 2 nodes, got: %d (%v)", len(client.conns), client.conns)
327	}
328	// Make two requests, so that both connections are being used
329	for i := 0; i < len(client.conns); i++ {
330		client.Flush().Do(context.TODO())
331	}
332	// The first connection (127.0.0.1:9200) should now be okay.
333	if i, found := findConn("http://127.0.0.1:9200", client.conns...); !found {
334		t.Fatalf("expected connection to %q to be found", "http://127.0.0.1:9200")
335	} else {
336		if conn := client.conns[i]; conn.IsDead() {
337			t.Fatal("expected connection to be alive, but it is dead")
338		}
339	}
340	// The second connection (127.0.0.1:9201) should now be marked as dead.
341	if i, found := findConn("http://127.0.0.1:9201", client.conns...); !found {
342		t.Fatalf("expected connection to %q to be found", "http://127.0.0.1:9201")
343	} else {
344		if conn := client.conns[i]; !conn.IsDead() {
345			t.Fatal("expected connection to be dead, but it is alive")
346		}
347	}
348}
349
350func TestClientWillMarkConnectionsAsAliveWhenAllAreDead(t *testing.T) {
351	client, err := NewClient(SetURL("http://127.0.0.1:9201"),
352		SetSniff(false), SetHealthcheck(false), SetMaxRetries(0))
353	if err != nil {
354		t.Fatal(err)
355	}
356	// We should have a connection.
357	if len(client.conns) != 1 {
358		t.Fatalf("expected 1 node, got: %d (%v)", len(client.conns), client.conns)
359	}
360
361	// Make a request, so that the connections is marked as dead.
362	client.Flush().Do(context.TODO())
363
364	// The connection should now be marked as dead.
365	if i, found := findConn("http://127.0.0.1:9201", client.conns...); !found {
366		t.Fatalf("expected connection to %q to be found", "http://127.0.0.1:9201")
367	} else {
368		if conn := client.conns[i]; !conn.IsDead() {
369			t.Fatalf("expected connection to be dead, got: %v", conn)
370		}
371	}
372
373	// Now send another request and the connection should be marked as alive again.
374	client.Flush().Do(context.TODO())
375
376	if i, found := findConn("http://127.0.0.1:9201", client.conns...); !found {
377		t.Fatalf("expected connection to %q to be found", "http://127.0.0.1:9201")
378	} else {
379		if conn := client.conns[i]; conn.IsDead() {
380			t.Fatalf("expected connection to be alive, got: %v", conn)
381		}
382	}
383}
384
385func TestClientWithRequiredPlugins(t *testing.T) {
386	_, err := NewClient(SetRequiredPlugins("no-such-plugin"))
387	if err == nil {
388		t.Fatal("expected error when creating client")
389	}
390	if got, want := err.Error(), "elastic: plugin no-such-plugin not found"; got != want {
391		t.Fatalf("expected error %q; got: %q", want, got)
392	}
393}
394
395func TestClientHealthcheckStartupTimeout(t *testing.T) {
396	start := time.Now()
397	_, err := NewClient(SetURL("http://localhost:9299"), SetHealthcheckTimeoutStartup(5*time.Second))
398	duration := time.Since(start)
399	if !IsConnErr(err) {
400		t.Fatal(err)
401	}
402	if !strings.Contains(err.Error(), "connection refused") {
403		t.Fatalf("expected error to contain %q, have %q", "connection refused", err.Error())
404	}
405	if duration < 5*time.Second {
406		t.Fatalf("expected a timeout in more than 5 seconds; got: %v", duration)
407	}
408}
409
// TestClientHealthcheckTimeoutLeak checks that a healthcheck request is
// canceled after its timeout instead of being left hanging.
//
// The test server's handler blocks until it is notified that the client
// closed the connection, then records that in reqDone. When *http.Server
// has Shutdown (Go 1.8+), goroutine leaks are also checked via leaktest.
// The Shutdown type assertion is a leftover from Go 1.7 support; on Go 1.7
// only the server-side reqDone signal is checked.
func TestClientHealthcheckTimeoutLeak(t *testing.T) {
	mux := http.NewServeMux()

	var reqDoneMu sync.Mutex
	var reqDone bool
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		cn, ok := w.(http.CloseNotifier)
		if !ok {
			// This runs on a server goroutine, where t.Fatalf is not allowed
			// (FailNow must be called from the test goroutine). Report the
			// failure and return instead.
			t.Errorf("Writer is not CloseNotifier, but %v", reflect.TypeOf(w))
			return
		}
		<-cn.CloseNotify()
		reqDoneMu.Lock()
		reqDone = true
		reqDoneMu.Unlock()
	})

	lis, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Couldn't setup listener: %v", err)
	}
	addr := lis.Addr().String()

	srv := &http.Server{
		Handler: mux,
	}
	go srv.Serve(lis)

	cli := &Client{
		c: &http.Client{},
		conns: []*conn{
			&conn{
				url: "http://" + addr + "/",
			},
		},
	}

	type closer interface {
		Shutdown(context.Context) error
	}

	// pre-Go1.8 Server can't Shutdown
	cl, isServerCloseable := (interface{}(srv)).(closer)

	// A server that cannot be shut down would itself show up as a leak, so
	// leaks are only monitored when Shutdown is available.
	if isServerCloseable {
		defer leaktest.CheckTimeout(t, time.Second*10)()
	}

	cli.healthcheck(context.Background(), time.Millisecond*500, true)

	if isServerCloseable {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		defer cancel()
		cl.Shutdown(ctx)
	}

	// Give the handler time to observe the closed connection.
	<-time.After(time.Second)
	reqDoneMu.Lock()
	done := reqDone
	reqDoneMu.Unlock()
	if !done {
		t.Fatal("Request wasn't canceled or stopped")
	}
}
483
// TestClientSniffUpdatingNodeURL checks that re-sniffing keeps a node's ID
// but updates its connection URL when the node reports a different publish
// address.
//
// The test server answers /_nodes/http with a single node (nodeID). Its
// http.publish_address is the host part of nodeURL, and the server counts
// every request in n. The handler runs on server goroutines while the test
// reads n and rewrites nodeURL, so both are only accessed under mu.
func TestClientSniffUpdatingNodeURL(t *testing.T) {
	const nodeID = "3DWDurZJQvWyWIOFnEB7VA"

	var (
		mu      sync.Mutex // guards nodeURL and n
		nodeURL string
		n       int
	)
	setNodeURL := func(s string) {
		mu.Lock()
		nodeURL = s
		mu.Unlock()
	}
	currentNodeURL := func() string {
		mu.Lock()
		defer mu.Unlock()
		return nodeURL
	}
	handlerCalls := func() int {
		mu.Lock()
		defer mu.Unlock()
		return n
	}

	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mu.Lock()
		n++
		reported := nodeURL
		mu.Unlock()

		w.Header().Set("Content-Type", "application/json")
		if r.URL.Path != "/_nodes/http" {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		u, err := url.Parse(reported)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		fmt.Fprintf(w, `{
			"cluster_name": "elasticsearch",
			"nodes": {
				%q: {
					"name": "elasticsearch",
					"http": {
						"publish_address": %q
					}
				}
			}
		}`, nodeID, u.Host)
		fmt.Fprintln(w)
	})
	ts := httptest.NewServer(h)
	defer ts.Close()

	setNodeURL(ts.URL)

	client, err := NewSimpleClient(SetURL(ts.URL), SetSniff(true))
	if err != nil {
		t.Fatal(err)
	}

	if want, have := 0, handlerCalls(); want != have {
		t.Fatalf("expected %d calls to handler; got %d", want, have)
	}
	if want, have := 1, len(client.conns); want != have {
		t.Fatalf("expected %d connections; got %d", want, have)
	}
	if want, have := currentNodeURL(), client.conns[0].URL(); want != have {
		t.Fatalf("expected URL=%q; got %q", want, have)
	}

	err = client.sniff(context.Background(), 2*time.Second)
	if err != nil {
		t.Fatal(err)
	}
	if want, have := 1, handlerCalls(); want != have {
		t.Fatalf("expected %d calls to handler; got %d", want, have)
	}
	if want, have := 1, len(client.conns); want != have {
		t.Fatalf("expected %d connections; got %d", want, have)
	}
	if want, have := nodeID, client.conns[0].NodeID(); want != have {
		t.Fatalf("expected NodeID=%q; got %q", want, have)
	}
	if want, have := currentNodeURL(), client.conns[0].URL(); want != have {
		t.Fatalf("expected URL=%q; got %q", want, have)
	}
	oldNodeID := client.conns[0].NodeID()
	oldURL := client.conns[0].URL()

	setNodeURL("http://127.0.0.1:9999") // some other nodeURL to report

	err = client.sniff(context.Background(), 2*time.Second)
	if err != nil {
		t.Fatal(err)
	}
	if want, have := 2, handlerCalls(); want != have {
		t.Fatalf("expected %d calls to handler; got %d", want, have)
	}
	if want, have := 1, len(client.conns); want != have {
		t.Fatalf("expected %d connections; got %d", want, have)
	}
	newNodeID := client.conns[0].NodeID()
	newURL := client.conns[0].URL()

	// NodeID mustn't change
	if newNodeID != oldNodeID {
		t.Fatalf("expected NodeID=%q; got %q", oldNodeID, newNodeID)
	}
	// URL must have changed
	if newURL == oldURL {
		t.Fatalf("expected to update URL=%q to %q", oldURL, newURL)
	}
}
578
579// -- NewSimpleClient --
580
// TestSimpleClientDefaults verifies the configuration NewSimpleClient
// produces without options: health checks and sniffing are disabled with
// all of their timeouts and intervals set to off, basic auth is off, and
// GET bodies are sent as GET.
func TestSimpleClientDefaults(t *testing.T) {
	client, err := NewSimpleClient()
	if err != nil {
		t.Fatal(err)
	}
	if client.healthcheckEnabled {
		t.Errorf("expected health checks to be disabled, got: %v", client.healthcheckEnabled)
	}
	if client.healthcheckTimeoutStartup != off {
		t.Errorf("expected health checks timeout on startup = %v, got: %v", off, client.healthcheckTimeoutStartup)
	}
	if client.healthcheckTimeout != off {
		t.Errorf("expected health checks timeout = %v, got: %v", off, client.healthcheckTimeout)
	}
	if client.healthcheckInterval != off {
		t.Errorf("expected health checks interval = %v, got: %v", off, client.healthcheckInterval)
	}
	if client.snifferEnabled {
		t.Errorf("expected sniffing to be disabled, got: %v", client.snifferEnabled)
	}
	if client.snifferTimeoutStartup != off {
		t.Errorf("expected sniffer timeout on startup = %v, got: %v", off, client.snifferTimeoutStartup)
	}
	if client.snifferTimeout != off {
		t.Errorf("expected sniffer timeout = %v, got: %v", off, client.snifferTimeout)
	}
	if client.snifferInterval != off {
		t.Errorf("expected sniffer interval = %v, got: %v", off, client.snifferInterval)
	}
	if client.basicAuth {
		t.Errorf("expected no basic auth; got: %v", client.basicAuth)
	}
	if client.basicAuthUsername != "" {
		t.Errorf("expected no basic auth username; got: %q", client.basicAuthUsername)
	}
	if client.basicAuthPassword != "" {
		// Report the password field that was checked, not the username.
		t.Errorf("expected no basic auth password; got: %q", client.basicAuthPassword)
	}
	if client.sendGetBodyAs != "GET" {
		t.Errorf("expected sendGetBodyAs to be GET; got: %q", client.sendGetBodyAs)
	}
}
623
624// -- Start and stop --
625
// TestClientStartAndStop drives a default client through Stop and Start
// transitions. After each step it checks IsRunning. Repeating Stop or Start
// must leave the state unchanged.
func TestClientStartAndStop(t *testing.T) {
	client, err := NewClient()
	if err != nil {
		t.Fatal(err)
	}

	steps := []struct {
		apply       func()
		wantRunning bool
	}{
		{nil, true},          // freshly created
		{client.Stop, false}, // Stop
		{client.Stop, false}, // Stop again => no-op
		{client.Start, true}, // Start
		{client.Start, true}, // Start again => no-op
	}
	for _, step := range steps {
		if step.apply != nil {
			step.apply()
		}
		running := client.IsRunning()
		switch {
		case step.wantRunning && !running:
			t.Fatalf("expected background processes to run; got: %v", running)
		case !step.wantRunning && running:
			t.Fatalf("expected background processes to be stopped; got: %v", running)
		}
	}
}
665
// TestClientStartAndStopWithSnifferAndHealthchecksDisabled runs the same
// Stop/Start sequence as TestClientStartAndStop, on a client with sniffing
// and health checks turned off. IsRunning must still follow every transition.
func TestClientStartAndStopWithSnifferAndHealthchecksDisabled(t *testing.T) {
	client, err := NewClient(SetSniff(false), SetHealthcheck(false))
	if err != nil {
		t.Fatal(err)
	}

	steps := []struct {
		apply       func()
		wantRunning bool
	}{
		{nil, true},          // freshly created
		{client.Stop, false}, // Stop
		{client.Stop, false}, // Stop again => no-op
		{client.Start, true}, // Start
		{client.Start, true}, // Start again => no-op
	}
	for _, step := range steps {
		if step.apply != nil {
			step.apply()
		}
		running := client.IsRunning()
		switch {
		case step.wantRunning && !running:
			t.Fatalf("expected background processes to run; got: %v", running)
		case !step.wantRunning && running:
			t.Fatalf("expected background processes to be stopped; got: %v", running)
		}
	}
}
705
706// -- Sniffing --
707
// TestClientSniffNode calls sniffNode directly against DefaultURL. Within 2
// seconds it must return exactly one node whose URL matches a numeric host
// on port 9200.
func TestClientSniffNode(t *testing.T) {
	client, err := NewClient()
	if err != nil {
		t.Fatal(err)
	}

	// Buffered so the sniffing goroutine can always deliver its result and
	// exit, even if the test has stopped waiting after the timeout.
	ch := make(chan []*conn, 1)
	go func() { ch <- client.sniffNode(context.Background(), DefaultURL) }()

	select {
	case nodes := <-ch:
		if len(nodes) != 1 {
			t.Fatalf("expected %d nodes; got: %d", 1, len(nodes))
		}
		pattern := `http:\/\/[\d\.]+:9200`
		matched, err := regexp.MatchString(pattern, nodes[0].URL())
		if err != nil {
			t.Fatal(err)
		}
		if !matched {
			t.Fatalf("expected node URL pattern %q; got: %q", pattern, nodes[0].URL())
		}
	case <-time.After(2 * time.Second):
		t.Fatal("expected no timeout in sniff node")
	}
}
735
// TestClientSniffOnDefaultURL re-sniffs a default client. Within 2 seconds
// sniffing must succeed and leave exactly one connection whose URL matches
// a numeric host on port 9200.
func TestClientSniffOnDefaultURL(t *testing.T) {
	client, err := NewClient()
	if err != nil || client == nil {
		// Surface the creation error instead of discarding it.
		t.Fatalf("no client returned: %v", err)
	}

	ch := make(chan error, 1)
	go func() {
		ch <- client.sniff(context.Background(), DefaultSnifferTimeoutStartup)
	}()

	select {
	case err := <-ch:
		if err != nil {
			t.Fatalf("expected sniff to succeed; got: %v", err)
		}
		if len(client.conns) != 1 {
			t.Fatalf("expected %d nodes; got: %d", 1, len(client.conns))
		}
		pattern := `http:\/\/[\d\.]+:9200`
		matched, err := regexp.MatchString(pattern, client.conns[0].URL())
		if err != nil {
			t.Fatal(err)
		}
		if !matched {
			t.Fatalf("expected node URL pattern %q; got: %q", pattern, client.conns[0].URL())
		}
	case <-time.After(2 * time.Second):
		t.Fatal("expected no timeout in sniff")
	}
}
768
// TestClientSniffTimeoutLeak checks that a sniff request is canceled after
// its timeout instead of being left hanging.
//
// The test server's handler blocks until it is notified that the client
// closed the connection, then records that in reqDone. When *http.Server
// has Shutdown (Go 1.8+), goroutine leaks are also checked via leaktest.
// The Shutdown type assertion is a leftover from Go 1.7 support; on Go 1.7
// only the server-side reqDone signal is checked.
func TestClientSniffTimeoutLeak(t *testing.T) {
	mux := http.NewServeMux()

	var reqDoneMu sync.Mutex
	var reqDone bool
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		cn, ok := w.(http.CloseNotifier)
		if !ok {
			// This runs on a server goroutine, where t.Fatalf is not allowed
			// (FailNow must be called from the test goroutine). Report the
			// failure and return instead.
			t.Errorf("Writer is not CloseNotifier, but %v", reflect.TypeOf(w))
			return
		}
		<-cn.CloseNotify()
		reqDoneMu.Lock()
		reqDone = true
		reqDoneMu.Unlock()
	})

	lis, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Couldn't setup listener: %v", err)
	}
	addr := lis.Addr().String()

	srv := &http.Server{
		Handler: mux,
	}
	go srv.Serve(lis)

	cli := &Client{
		c: &http.Client{},
		conns: []*conn{
			&conn{
				url: "http://" + addr + "/",
			},
		},
		snifferEnabled: true,
	}

	type closer interface {
		Shutdown(context.Context) error
	}

	// pre-Go1.8 Server can't Shutdown
	cl, isServerCloseable := (interface{}(srv)).(closer)

	// A server that cannot be shut down would itself show up as a leak, so
	// leaks are only monitored when Shutdown is available.
	if isServerCloseable {
		defer leaktest.CheckTimeout(t, time.Second*10)()
	}

	cli.sniff(context.Background(), time.Millisecond*500)

	if isServerCloseable {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		defer cancel()
		cl.Shutdown(ctx)
	}

	// Give the handler time to observe the closed connection.
	<-time.After(time.Second)
	reqDoneMu.Lock()
	done := reqDone
	reqDoneMu.Unlock()
	if !done {
		t.Fatal("Request wasn't canceled or stopped")
	}
}
843
// TestClientExtractHostname checks how extractHostname combines a scheme and
// a published address. That includes the "hostname/ip:port" form, where
// the hostname is kept and the IP is dropped.
func TestClientExtractHostname(t *testing.T) {
	client, err := NewClient(SetSniff(false), SetHealthcheck(false))
	if err != nil {
		t.Fatal(err)
	}

	cases := []struct {
		scheme, address, want string
	}{
		{"http", "127.0.0.1:9200", "http://127.0.0.1:9200"},
		{"https", "127.0.0.1:9200", "https://127.0.0.1:9200"},
		{"http", "127.0.0.1:19200", "http://127.0.0.1:19200"},
		{"http", "myelk.local/10.1.0.24:9200", "http://myelk.local:9200"},
	}
	for _, tc := range cases {
		if got := client.extractHostname(tc.scheme, tc.address); got != tc.want {
			t.Errorf("expected %q; got: %q", tc.want, got)
		}
	}
}
883
884// -- Selector --
885
// TestClientSelectConnHealthy checks that with two healthy connections, next
// hands them out in round-robin order: 1st, 2nd, 1st.
func TestClientSelectConnHealthy(t *testing.T) {
	client, err := NewClient(
		SetSniff(false),
		SetHealthcheck(false),
		SetURL("http://127.0.0.1:9200", "http://127.0.0.1:9201"))
	if err != nil {
		t.Fatal(err)
	}

	// Both are healthy, so we should get both URLs in round-robin
	client.conns[0].MarkAsHealthy()
	client.conns[1].MarkAsHealthy()

	// Each entry is the index in client.conns expected from successive calls.
	for call, idx := range []int{0, 1, 0} {
		c, err := client.next()
		if err != nil {
			t.Fatal(err)
		}
		// Print the expected URL first and the selected URL second.
		if want, got := client.conns[idx].URL(), c.URL(); got != want {
			t.Fatalf("call #%d: expected %s; got: %s", call+1, want, got)
		}
	}
}
924
// TestClientSelectConnHealthyAndDead checks that when the 2nd connection is
// dead, next keeps returning the healthy 1st connection.
func TestClientSelectConnHealthyAndDead(t *testing.T) {
	client, err := NewClient(
		SetSniff(false),
		SetHealthcheck(false),
		SetURL("http://127.0.0.1:9200", "http://127.0.0.1:9201"))
	if err != nil {
		t.Fatal(err)
	}

	// 1st is healthy, second is dead
	client.conns[0].MarkAsHealthy()
	client.conns[1].MarkAsDead()

	// Each entry is the index in client.conns expected from successive calls.
	for call, idx := range []int{0, 0, 0} {
		c, err := client.next()
		if err != nil {
			t.Fatal(err)
		}
		// Print the expected URL first and the selected URL second.
		if want, got := client.conns[idx].URL(), c.URL(); got != want {
			t.Fatalf("call #%d: expected %s; got: %s", call+1, want, got)
		}
	}
}
963
// TestClientSelectConnDeadAndHealthy checks that when the 1st connection is
// dead, next keeps returning the healthy 2nd connection.
func TestClientSelectConnDeadAndHealthy(t *testing.T) {
	client, err := NewClient(
		SetSniff(false),
		SetHealthcheck(false),
		SetURL("http://127.0.0.1:9200", "http://127.0.0.1:9201"))
	if err != nil {
		t.Fatal(err)
	}

	// 1st is dead, 2nd is healthy
	client.conns[0].MarkAsDead()
	client.conns[1].MarkAsHealthy()

	// Each entry is the index in client.conns expected from successive calls.
	for call, idx := range []int{1, 1, 1} {
		c, err := client.next()
		if err != nil {
			t.Fatal(err)
		}
		// Print the expected URL first and the selected URL second.
		if want, got := client.conns[idx].URL(), c.URL(); got != want {
			t.Fatalf("call #%d: expected %s; got: %s", call+1, want, got)
		}
	}
}
1002
// TestClientSelectConnAllDead marks both connections dead. The first call
// to next must fail with a connection error and return no connection; per
// the test's intent, that call also revives the connections. The following
// two calls must each succeed with a connection.
func TestClientSelectConnAllDead(t *testing.T) {
	client, err := NewClient(
		SetSniff(false),
		SetHealthcheck(false),
		SetURL("http://127.0.0.1:9200", "http://127.0.0.1:9201"))
	if err != nil {
		t.Fatal(err)
	}

	client.conns[0].MarkAsDead()
	client.conns[1].MarkAsDead()

	c, err := client.next()
	if !IsConnErr(err) {
		t.Fatal(err)
	}
	if c != nil {
		t.Fatalf("expected no connection; got: %v", c)
	}

	for attempt := 0; attempt < 2; attempt++ {
		c, err = client.next()
		if err != nil {
			t.Fatalf("expected no error; got: %v", err)
		}
		if c == nil {
			t.Fatalf("expected connection; got: %v", c)
		}
	}
}
1042
1043// -- ElasticsearchVersion --
1044
// TestElasticsearchVersion asks the node at DefaultURL for its version and
// expects a non-empty version string.
func TestElasticsearchVersion(t *testing.T) {
	client, err := NewClient()
	if err != nil {
		t.Fatal(err)
	}
	v, err := client.ElasticsearchVersion(DefaultURL)
	switch {
	case err != nil:
		t.Fatal(err)
	case v == "":
		t.Errorf("expected a version number, got: %q", v)
	}
}
1058
1059// -- IndexNames --
1060
1061func TestIndexNames(t *testing.T) {
1062	client := setupTestClientAndCreateIndex(t)
1063	names, err := client.IndexNames()
1064	if err != nil {
1065		t.Fatal(err)
1066	}
1067	if len(names) == 0 {
1068		t.Fatalf("expected some index names, got: %d", len(names))
1069	}
1070	var found bool
1071	for _, name := range names {
1072		if name == testIndexName {
1073			found = true
1074			break
1075		}
1076	}
1077	if !found {
1078		t.Fatalf("expected to find index %q; got: %v", testIndexName, found)
1079	}
1080}
1081
1082// -- PerformRequest --
1083
1084func TestPerformRequest(t *testing.T) {
1085	client, err := NewClient()
1086	if err != nil {
1087		t.Fatal(err)
1088	}
1089	res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
1090		Method: "GET",
1091		Path:   "/",
1092	})
1093	if err != nil {
1094		t.Fatal(err)
1095	}
1096	if res == nil {
1097		t.Fatal("expected response to be != nil")
1098	}
1099
1100	ret := new(PingResult)
1101	if err := json.Unmarshal(res.Body, ret); err != nil {
1102		t.Fatalf("expected no error on decode; got: %v", err)
1103	}
1104	if ret.ClusterName == "" {
1105		t.Errorf("expected cluster name; got: %q", ret.ClusterName)
1106	}
1107}
1108
1109func TestPerformRequestWithSimpleClient(t *testing.T) {
1110	client, err := NewSimpleClient()
1111	if err != nil {
1112		t.Fatal(err)
1113	}
1114	res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
1115		Method: "GET",
1116		Path:   "/",
1117	})
1118	if err != nil {
1119		t.Fatal(err)
1120	}
1121	if res == nil {
1122		t.Fatal("expected response to be != nil")
1123	}
1124
1125	ret := new(PingResult)
1126	if err := json.Unmarshal(res.Body, ret); err != nil {
1127		t.Fatalf("expected no error on decode; got: %v", err)
1128	}
1129	if ret.ClusterName == "" {
1130		t.Errorf("expected cluster name; got: %q", ret.ClusterName)
1131	}
1132}
1133
1134func TestPerformRequestWithLogger(t *testing.T) {
1135	var w bytes.Buffer
1136	out := log.New(&w, "LOGGER ", log.LstdFlags)
1137
1138	client, err := NewClient(SetInfoLog(out), SetSniff(false))
1139	if err != nil {
1140		t.Fatal(err)
1141	}
1142
1143	res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
1144		Method: "GET",
1145		Path:   "/",
1146	})
1147	if err != nil {
1148		t.Fatal(err)
1149	}
1150	if res == nil {
1151		t.Fatal("expected response to be != nil")
1152	}
1153
1154	ret := new(PingResult)
1155	if err := json.Unmarshal(res.Body, ret); err != nil {
1156		t.Fatalf("expected no error on decode; got: %v", err)
1157	}
1158	if ret.ClusterName == "" {
1159		t.Errorf("expected cluster name; got: %q", ret.ClusterName)
1160	}
1161
1162	got := w.String()
1163	pattern := `^LOGGER \d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2} GET http://.*/ \[status:200, request:\d+\.\d{3}s\]\n`
1164	matched, err := regexp.MatchString(pattern, got)
1165	if err != nil {
1166		t.Fatalf("expected log line to match %q; got: %v", pattern, err)
1167	}
1168	if !matched {
1169		t.Errorf("expected log line to match %q; got: %v", pattern, got)
1170	}
1171}
1172
1173func TestPerformRequestWithLoggerAndTracer(t *testing.T) {
1174	var lw bytes.Buffer
1175	lout := log.New(&lw, "LOGGER ", log.LstdFlags)
1176
1177	var tw bytes.Buffer
1178	tout := log.New(&tw, "TRACER ", log.LstdFlags)
1179
1180	client, err := NewClient(SetInfoLog(lout), SetTraceLog(tout), SetSniff(false))
1181	if err != nil {
1182		t.Fatal(err)
1183	}
1184
1185	res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
1186		Method: "GET",
1187		Path:   "/",
1188	})
1189	if err != nil {
1190		t.Fatal(err)
1191	}
1192	if res == nil {
1193		t.Fatal("expected response to be != nil")
1194	}
1195
1196	ret := new(PingResult)
1197	if err := json.Unmarshal(res.Body, ret); err != nil {
1198		t.Fatalf("expected no error on decode; got: %v", err)
1199	}
1200	if ret.ClusterName == "" {
1201		t.Errorf("expected cluster name; got: %q", ret.ClusterName)
1202	}
1203
1204	lgot := lw.String()
1205	if lgot == "" {
1206		t.Errorf("expected logger output; got: %q", lgot)
1207	}
1208
1209	tgot := tw.String()
1210	if tgot == "" {
1211		t.Errorf("expected tracer output; got: %q", tgot)
1212	}
1213}
1214func TestPerformRequestWithTracerOnError(t *testing.T) {
1215	var tw bytes.Buffer
1216	tout := log.New(&tw, "TRACER ", log.LstdFlags)
1217
1218	client, err := NewClient(SetTraceLog(tout), SetSniff(false))
1219	if err != nil {
1220		t.Fatal(err)
1221	}
1222
1223	client.PerformRequest(context.TODO(), PerformRequestOptions{
1224		Method: "GET",
1225		Path:   "/no-such-index",
1226	})
1227
1228	tgot := tw.String()
1229	if tgot == "" {
1230		t.Errorf("expected tracer output; got: %q", tgot)
1231	}
1232}
1233
// customLogger is a minimal logger for SetInfoLog that collects every
// formatted message in an in-memory buffer, one message per line.
type customLogger struct {
	out bytes.Buffer
}

// Printf formats the message according to format, appends it to the buffer,
// and terminates it with a newline.
func (l *customLogger) Printf(format string, v ...interface{}) {
	fmt.Fprintf(&l.out, format, v...)
	l.out.WriteByte('\n')
}
1241
1242func TestPerformRequestWithCustomLogger(t *testing.T) {
1243	logger := &customLogger{}
1244
1245	client, err := NewClient(SetInfoLog(logger), SetSniff(false))
1246	if err != nil {
1247		t.Fatal(err)
1248	}
1249
1250	res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
1251		Method: "GET",
1252		Path:   "/",
1253	})
1254	if err != nil {
1255		t.Fatal(err)
1256	}
1257	if res == nil {
1258		t.Fatal("expected response to be != nil")
1259	}
1260
1261	ret := new(PingResult)
1262	if err := json.Unmarshal(res.Body, ret); err != nil {
1263		t.Fatalf("expected no error on decode; got: %v", err)
1264	}
1265	if ret.ClusterName == "" {
1266		t.Errorf("expected cluster name; got: %q", ret.ClusterName)
1267	}
1268
1269	got := logger.out.String()
1270	pattern := `^GET http://.*/ \[status:200, request:\d+\.\d{3}s\]\n`
1271	matched, err := regexp.MatchString(pattern, got)
1272	if err != nil {
1273		t.Fatalf("expected log line to match %q; got: %v", pattern, err)
1274	}
1275	if !matched {
1276		t.Errorf("expected log line to match %q; got: %v", pattern, got)
1277	}
1278}
1279
1280func TestPerformRequestWithMaxResponseSize(t *testing.T) {
1281	client, err := NewClient()
1282	if err != nil {
1283		t.Fatal(err)
1284	}
1285	res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
1286		Method:          "GET",
1287		Path:            "/",
1288		MaxResponseSize: 1000,
1289	})
1290	if err != nil {
1291		t.Fatal(err)
1292	}
1293	if res == nil {
1294		t.Fatal("expected response to be != nil")
1295	}
1296
1297	_, err = client.PerformRequest(context.TODO(), PerformRequestOptions{
1298		Method:          "GET",
1299		Path:            "/",
1300		MaxResponseSize: 100,
1301	})
1302	if err != ErrResponseSize {
1303		t.Fatal("expected response size error")
1304	}
1305}
1306
1307func TestPerformRequestOnNoConnectionsWithHealthcheckRevival(t *testing.T) {
1308	fail := func(r *http.Request) (*http.Response, error) {
1309		return nil, errors.New("request failed")
1310	}
1311	tr := &failingTransport{path: "/fail", fail: fail}
1312	httpClient := &http.Client{Transport: tr}
1313	client, err := NewClient(SetHttpClient(httpClient), SetMaxRetries(0), SetHealthcheck(true))
1314	if err != nil {
1315		t.Fatal(err)
1316	}
1317
1318	// Run against a failing endpoint to mark connection as dead
1319	res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
1320		Method: "GET",
1321		Path:   "/fail",
1322	})
1323	if err == nil {
1324		t.Fatal(err)
1325	}
1326	if res != nil {
1327		t.Fatal("expected no response")
1328	}
1329
1330	// Forced healthcheck should bring connection back to life and complete request
1331	res, err = client.PerformRequest(context.TODO(), PerformRequestOptions{
1332		Method: "GET",
1333		Path:   "/",
1334	})
1335	if err != nil {
1336		t.Fatal(err)
1337	}
1338	if res == nil {
1339		t.Fatal("expected response to be != nil")
1340	}
1341}
1342
// failingTransport is an http.RoundTripper for tests. It hands every request
// whose URL path starts with path to the fail callback. All other requests,
// and every request when fail is nil, go to next, or to
// http.DefaultTransport when next is nil.
type failingTransport struct {
	path string                                      // path prefix to look for
	fail func(*http.Request) (*http.Response, error) // call when path prefix is found
	next http.RoundTripper                           // next round-tripper (use http.DefaultTransport if nil)
}

// RoundTrip implements http.RoundTripper.
func (tr *failingTransport) RoundTrip(r *http.Request) (*http.Response, error) {
	if tr.fail != nil && strings.HasPrefix(r.URL.Path, tr.path) {
		return tr.fail(r)
	}
	delegate := tr.next
	if delegate == nil {
		delegate = http.DefaultTransport
	}
	return delegate.RoundTrip(r)
}
1360
1361func TestPerformRequestRetryOnHttpError(t *testing.T) {
1362	var numFailedReqs int
1363	fail := func(r *http.Request) (*http.Response, error) {
1364		numFailedReqs += 1
1365		return nil, errors.New("request failed")
1366	}
1367
1368	// Run against a failing endpoint and see if PerformRequest
1369	// retries correctly.
1370	tr := &failingTransport{path: "/fail", fail: fail}
1371	httpClient := &http.Client{Transport: tr}
1372
1373	client, err := NewClient(SetHttpClient(httpClient), SetMaxRetries(5), SetHealthcheck(false))
1374	if err != nil {
1375		t.Fatal(err)
1376	}
1377
1378	res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
1379		Method: "GET",
1380		Path:   "/fail",
1381	})
1382	if err == nil {
1383		t.Fatal("expected error")
1384	}
1385	if res != nil {
1386		t.Fatal("expected no response")
1387	}
1388	// Connection should be marked as dead after it failed
1389	if numFailedReqs != 5 {
1390		t.Errorf("expected %d failed requests; got: %d", 5, numFailedReqs)
1391	}
1392}
1393
1394func TestPerformRequestNoRetryOnValidButUnsuccessfulHttpStatus(t *testing.T) {
1395	var numFailedReqs int
1396	fail := func(r *http.Request) (*http.Response, error) {
1397		numFailedReqs += 1
1398		return &http.Response{Request: r, StatusCode: 500, Body: http.NoBody}, nil
1399	}
1400
1401	// Run against a failing endpoint and see if PerformRequest
1402	// retries correctly.
1403	tr := &failingTransport{path: "/fail", fail: fail}
1404	httpClient := &http.Client{Transport: tr}
1405
1406	client, err := NewClient(SetHttpClient(httpClient), SetMaxRetries(5), SetHealthcheck(false))
1407	if err != nil {
1408		t.Fatal(err)
1409	}
1410
1411	res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
1412		Method: "GET",
1413		Path:   "/fail",
1414	})
1415	if err == nil {
1416		t.Fatal("expected error")
1417	}
1418	if res == nil {
1419		t.Fatal("expected response, got nil")
1420	}
1421	if want, got := 500, res.StatusCode; want != got {
1422		t.Fatalf("expected status code = %d, got %d", want, got)
1423	}
1424	// Retry should not have triggered additional requests because
1425	if numFailedReqs != 1 {
1426		t.Errorf("expected %d failed requests; got: %d", 1, numFailedReqs)
1427	}
1428}
1429
// failingBody is a value whose JSON encoding always fails. Tests use it as
// PerformRequestOptions.Body to exercise the error path for a body that
// cannot be encoded.
type failingBody struct{}

// MarshalJSON implements the json.Marshaler interface. It never produces any
// output and always reports an error.
func (failingBody) MarshalJSON() ([]byte, error) {
	return nil, errors.New("failing to marshal")
}
1437
1438func TestPerformRequestWithSetBodyError(t *testing.T) {
1439	client, err := NewClient()
1440	if err != nil {
1441		t.Fatal(err)
1442	}
1443	res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
1444		Method: "GET",
1445		Path:   "/",
1446		Body:   failingBody{},
1447	})
1448	if err == nil {
1449		t.Fatal("expected error")
1450	}
1451	if res != nil {
1452		t.Fatal("expected no response")
1453	}
1454}
1455
// sleepingTransport delays every request by timeout before handing it to
// http.DefaultTransport. If the request's context is done during the delay,
// the wait is cut short and the context's error is returned without sending
// the request. That way callers that cancel or time out are not held up, and
// their request goroutines are not kept alive, for the full delay.
type sleepingTransport struct {
	timeout time.Duration // how long to wait before sending the request
}

// RoundTrip implements a "sleepy" http.RoundTripper that honors request
// cancellation while it sleeps.
func (tr *sleepingTransport) RoundTrip(r *http.Request) (*http.Response, error) {
	timer := time.NewTimer(tr.timeout)
	defer timer.Stop()

	ctx := r.Context()
	select {
	case <-timer.C:
	case <-ctx.Done():
		return nil, ctx.Err()
	}
	return http.DefaultTransport.RoundTrip(r)
}
1466
1467func TestPerformRequestWithCancel(t *testing.T) {
1468	tr := &sleepingTransport{timeout: 3 * time.Second}
1469	httpClient := &http.Client{Transport: tr}
1470
1471	client, err := NewSimpleClient(SetHttpClient(httpClient), SetMaxRetries(0))
1472	if err != nil {
1473		t.Fatal(err)
1474	}
1475
1476	type result struct {
1477		res *Response
1478		err error
1479	}
1480	ctx, cancel := context.WithCancel(context.Background())
1481	defer cancel()
1482
1483	resc := make(chan result, 1)
1484	go func() {
1485		res, err := client.PerformRequest(ctx, PerformRequestOptions{
1486			Method: "GET",
1487			Path:   "/",
1488		})
1489		resc <- result{res: res, err: err}
1490	}()
1491	select {
1492	case <-time.After(1 * time.Second):
1493		cancel()
1494	case res := <-resc:
1495		t.Fatalf("expected response before cancel, got %v", res)
1496	case <-ctx.Done():
1497		t.Fatalf("expected no early termination, got ctx.Done(): %v", ctx.Err())
1498	}
1499	err = ctx.Err()
1500	if err != context.Canceled {
1501		t.Fatalf("expected error context.Canceled, got: %v", err)
1502	}
1503}
1504
1505func TestPerformRequestWithTimeout(t *testing.T) {
1506	tr := &sleepingTransport{timeout: 3 * time.Second}
1507	httpClient := &http.Client{Transport: tr}
1508
1509	client, err := NewSimpleClient(SetHttpClient(httpClient), SetMaxRetries(0))
1510	if err != nil {
1511		t.Fatal(err)
1512	}
1513
1514	type result struct {
1515		res *Response
1516		err error
1517	}
1518	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
1519	defer cancel()
1520
1521	resc := make(chan result, 1)
1522	go func() {
1523		res, err := client.PerformRequest(ctx, PerformRequestOptions{
1524			Method: "GET",
1525			Path:   "/",
1526		})
1527		resc <- result{res: res, err: err}
1528	}()
1529	select {
1530	case res := <-resc:
1531		t.Fatalf("expected timeout before response, got %v", res)
1532	case <-ctx.Done():
1533		err := ctx.Err()
1534		if err != context.DeadlineExceeded {
1535			t.Fatalf("expected error context.DeadlineExceeded, got: %v", err)
1536		}
1537	}
1538}
1539
1540func TestPerformRequestWithCustomHTTPHeadersOnRequest(t *testing.T) {
1541	client, err := NewClient()
1542	if err != nil {
1543		t.Fatal(err)
1544	}
1545	res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
1546		Method: "GET",
1547		Path:   "/_tasks",
1548		Params: url.Values{
1549			"pretty": []string{"true"},
1550		},
1551		Headers: http.Header{
1552			"X-Opaque-Id": []string{"123456"},
1553		},
1554	})
1555	if err != nil {
1556		t.Fatal(err)
1557	}
1558	if res == nil {
1559		t.Fatal("expected response to be != nil")
1560	}
1561	if want, have := "123456", res.Header.Get("X-Opaque-Id"); want != have {
1562		t.Fatalf("want response header X-Opaque-Id=%q, have %q", want, have)
1563	}
1564}
1565
1566func TestPerformRequestWithCustomHTTPHeadersOnClient(t *testing.T) {
1567	client, err := NewClient(SetHeaders(http.Header{
1568		"Custom-Id": []string{"olivere"},
1569	}))
1570	if err != nil {
1571		t.Fatal(err)
1572	}
1573	res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
1574		Method: "GET",
1575		Path:   "/_tasks",
1576		Params: url.Values{
1577			"pretty": []string{"true"},
1578		},
1579		Headers: http.Header{
1580			"X-Opaque-Id": []string{"123456"},
1581		},
1582	})
1583	if err != nil {
1584		t.Fatal(err)
1585	}
1586	if res == nil {
1587		t.Fatal("expected response to be != nil")
1588	}
1589	// Request-level headers have preference
1590	if want, have := "123456", res.Header.Get("X-Opaque-Id"); want != have {
1591		t.Fatalf("want response header X-Opaque-Id=%q, have %q", want, have)
1592	}
1593}
1594
1595func TestPerformRequestWithCustomHTTPHeadersPriority(t *testing.T) {
1596	var req *http.Request
1597	h := func(r *http.Request) (*http.Response, error) {
1598		req = new(http.Request)
1599		*req = *r
1600		return &http.Response{Request: r, StatusCode: http.StatusOK, Body: http.NoBody}, nil
1601	}
1602	tr := &failingTransport{path: "/", fail: h}
1603	httpClient := &http.Client{Transport: tr}
1604
1605	client, err := NewClient(SetHttpClient(httpClient), SetHeaders(http.Header{
1606		"Custom-Id":   []string{"olivere"},
1607		"X-Opaque-Id": []string{"sandra"}, // <- will be overridden by request-level header
1608	}), SetSniff(false), SetHealthcheck(false))
1609	if err != nil {
1610		t.Fatal(err)
1611	}
1612	res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
1613		Method: "GET",
1614		Path:   "/",
1615		Params: url.Values{
1616			"pretty": []string{"true"},
1617		},
1618		Headers: http.Header{
1619			"X-Opaque-Id": []string{"123456"}, // <- request-level has preference
1620			"X-Somewhat":  []string{"somewhat"},
1621		},
1622	})
1623	if err != nil {
1624		t.Fatal(err)
1625	}
1626	if res == nil {
1627		t.Fatal("expected response to be != nil")
1628	}
1629	if req == nil {
1630		t.Fatal("expected to record HTTP request")
1631	}
1632	if want, have := "123456", req.Header.Get("X-Opaque-Id"); want != have {
1633		t.Fatalf("want HTTP header X-Opaque-Id=%q, have %q", want, have)
1634	}
1635	if want, have := "olivere", req.Header.Get("Custom-Id"); want != have {
1636		t.Fatalf("want HTTP header Custom-Id=%q, have %q", want, have)
1637	}
1638	if want, have := "somewhat", req.Header.Get("X-Somewhat"); want != have {
1639		t.Fatalf("want HTTP header X-Somewhat=%q, have %q", want, have)
1640	}
1641}
1642
1643// -- Compression --
1644
// Note that the trace log always prints "Accept-Encoding: gzip",
// regardless of whether compression is enabled. This is caused by
// the underlying "httputil.DumpRequestOut".
1648//
1649// Use a real HTTP proxy/recorder to convince yourself that
1650// "Accept-Encoding: gzip" is NOT sent when DisableCompression
1651// is set to true.
1652//
1653// See also:
1654// https://groups.google.com/forum/#!topic/golang-nuts/ms8QNCzew8Q
1655
1656func TestPerformRequestWithCompressionEnabled(t *testing.T) {
1657	testPerformRequestWithCompression(t, &http.Client{
1658		Transport: &http.Transport{
1659			DisableCompression: true,
1660		},
1661	})
1662}
1663
1664func TestPerformRequestWithCompressionDisabled(t *testing.T) {
1665	testPerformRequestWithCompression(t, &http.Client{
1666		Transport: &http.Transport{
1667			DisableCompression: false,
1668		},
1669	})
1670}
1671
1672func testPerformRequestWithCompression(t *testing.T, hc *http.Client) {
1673	client, err := NewClient(SetHttpClient(hc), SetSniff(false))
1674	if err != nil {
1675		t.Fatal(err)
1676	}
1677	res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
1678		Method: "GET",
1679		Path:   "/",
1680	})
1681	if err != nil {
1682		t.Fatal(err)
1683	}
1684	if res == nil {
1685		t.Fatal("expected response to be != nil")
1686	}
1687
1688	ret := new(PingResult)
1689	if err := json.Unmarshal(res.Body, ret); err != nil {
1690		t.Fatalf("expected no error on decode; got: %v", err)
1691	}
1692	if ret.ClusterName == "" {
1693		t.Errorf("expected cluster name; got: %q", ret.ClusterName)
1694	}
1695}
1696