1// Copyright 2012-present Oliver Eilhard. All rights reserved.
2// Use of this source code is governed by a MIT-license.
3// See http://olivere.mit-license.org/license.txt for details.
4
5package elastic
6
7import (
8	"context"
9	"encoding/json"
10	"io/ioutil"
11	"net/http"
12	"strings"
13	"testing"
14)
15
16func TestUpdateByQueryBuildURL(t *testing.T) {
17	client := setupTestClient(t)
18
19	tests := []struct {
20		Indices   []string
21		Types     []string
22		Expected  string
23		ExpectErr bool
24	}{
25		{
26			[]string{},
27			[]string{},
28			"",
29			true,
30		},
31		{
32			[]string{"index1"},
33			[]string{},
34			"/index1/_update_by_query",
35			false,
36		},
37		{
38			[]string{"index1", "index2"},
39			[]string{},
40			"/index1%2Cindex2/_update_by_query",
41			false,
42		},
43		{
44			[]string{},
45			[]string{"type1"},
46			"",
47			true,
48		},
49		{
50			[]string{"index1"},
51			[]string{"type1"},
52			"/index1/type1/_update_by_query",
53			false,
54		},
55		{
56			[]string{"index1", "index2"},
57			[]string{"type1", "type2"},
58			"/index1%2Cindex2/type1%2Ctype2/_update_by_query",
59			false,
60		},
61	}
62
63	for i, test := range tests {
64		builder := client.UpdateByQuery().Index(test.Indices...).Type(test.Types...)
65		err := builder.Validate()
66		if err != nil {
67			if !test.ExpectErr {
68				t.Errorf("case #%d: %v", i+1, err)
69				continue
70			}
71		} else {
72			// err == nil
73			if test.ExpectErr {
74				t.Errorf("case #%d: expected error", i+1)
75				continue
76			}
77			path, _, _ := builder.buildURL()
78			if path != test.Expected {
79				t.Errorf("case #%d: expected %q; got: %q", i+1, test.Expected, path)
80			}
81		}
82	}
83}
84
85func TestUpdateByQueryBodyWithQuery(t *testing.T) {
86	client := setupTestClient(t)
87	out, err := client.UpdateByQuery().Query(NewTermQuery("user", "olivere")).getBody()
88	if err != nil {
89		t.Fatal(err)
90	}
91	b, err := json.Marshal(out)
92	if err != nil {
93		t.Fatal(err)
94	}
95	got := string(b)
96	want := `{"query":{"term":{"user":"olivere"}}}`
97	if got != want {
98		t.Fatalf("\ngot  %s\nwant %s", got, want)
99	}
100}
101
102func TestUpdateByQueryBodyWithQueryAndScript(t *testing.T) {
103	client := setupTestClient(t)
104	out, err := client.UpdateByQuery().
105		Query(NewTermQuery("user", "olivere")).
106		Script(NewScriptInline("ctx._source.likes++")).
107		getBody()
108	if err != nil {
109		t.Fatal(err)
110	}
111	b, err := json.Marshal(out)
112	if err != nil {
113		t.Fatal(err)
114	}
115	got := string(b)
116	want := `{"query":{"term":{"user":"olivere"}},"script":{"source":"ctx._source.likes++"}}`
117	if got != want {
118		t.Fatalf("\ngot  %s\nwant %s", got, want)
119	}
120}
121
122func TestUpdateByQuery(t *testing.T) {
123	client := setupTestClientAndCreateIndexAndAddDocs(t) //, SetTraceLog(log.New(os.Stdout, "", 0)))
124	esversion, err := client.ElasticsearchVersion(DefaultURL)
125	if err != nil {
126		t.Fatal(err)
127	}
128	if esversion < "2.3.0" {
129		t.Skipf("Elasticsearch %v does not support update-by-query yet", esversion)
130	}
131
132	sourceCount, err := client.Count(testIndexName).Do(context.TODO())
133	if err != nil {
134		t.Fatal(err)
135	}
136	if sourceCount <= 0 {
137		t.Fatalf("expected more than %d documents; got: %d", 0, sourceCount)
138	}
139
140	res, err := client.UpdateByQuery(testIndexName).ProceedOnVersionConflict().Do(context.TODO())
141	if err != nil {
142		t.Fatal(err)
143	}
144	if res == nil {
145		t.Fatal("response is nil")
146	}
147	if res.Updated != sourceCount {
148		t.Fatalf("expected %d; got: %d", sourceCount, res.Updated)
149	}
150}
151
152func TestUpdateByQueryAsync(t *testing.T) {
153	client := setupTestClientAndCreateIndexAndAddDocs(t) //, SetTraceLog(log.New(os.Stdout, "", 0)))
154	esversion, err := client.ElasticsearchVersion(DefaultURL)
155	if err != nil {
156		t.Fatal(err)
157	}
158	if esversion < "2.3.0" {
159		t.Skipf("Elasticsearch %v does not support update-by-query yet", esversion)
160	}
161
162	sourceCount, err := client.Count(testIndexName).Do(context.TODO())
163	if err != nil {
164		t.Fatal(err)
165	}
166	if sourceCount <= 0 {
167		t.Fatalf("expected more than %d documents; got: %d", 0, sourceCount)
168	}
169
170	res, err := client.UpdateByQuery(testIndexName).
171		ProceedOnVersionConflict().
172		Slices("auto").
173		DoAsync(context.TODO())
174	if err != nil {
175		t.Fatal(err)
176	}
177	if res == nil {
178		t.Fatal("expected result != nil")
179	}
180	if res.TaskId == "" {
181		t.Errorf("expected a task id, got %+v", res)
182	}
183
184	tasksGetTask := client.TasksGetTask()
185	taskStatus, err := tasksGetTask.TaskId(res.TaskId).Do(context.TODO())
186	if err != nil {
187		t.Fatal(err)
188	}
189	if taskStatus == nil {
190		t.Fatal("expected task status result != nil")
191	}
192}
193
194func TestUpdateByQueryConflict(t *testing.T) {
195	fail := func(r *http.Request) (*http.Response, error) {
196		body := `{
197			"took": 3,
198			"timed_out": false,
199			"total": 1,
200			"updated": 0,
201			"deleted": 0,
202			"batches": 1,
203			"version_conflicts": 1,
204			"noops": 0,
205			"retries": {
206			  "bulk": 0,
207			  "search": 0
208			},
209			"throttled_millis": 0,
210			"requests_per_second": -1,
211			"throttled_until_millis": 0,
212			"failures": [
213			  {
214				"index": "a",
215				"type": "_doc",
216				"id": "yjsmdGsBm363wfQmSbhj",
217				"cause": {
218				  "type": "version_conflict_engine_exception",
219				  "reason": "[_doc][yjsmdGsBm363wfQmSbhj]: version conflict, current version [4] is different than the one provided [3]",
220				  "index_uuid": "1rmL3mt8TimwshF-M1DxdQ",
221				  "shard": "0",
222				  "index": "a"
223				},
224				"status": 409
225			  }
226			]
227		   }`
228		return &http.Response{
229			StatusCode:    http.StatusConflict,
230			Body:          ioutil.NopCloser(strings.NewReader(body)),
231			ContentLength: int64(len(body)),
232		}, nil
233	}
234
235	// Run against a failing endpoint and see if PerformRequest
236	// retries correctly.
237	tr := &failingTransport{path: "/example/_update_by_query", fail: fail}
238	httpClient := &http.Client{Transport: tr}
239	client, err := NewClient(SetHttpClient(httpClient), SetHealthcheck(false))
240	if err != nil {
241		t.Fatal(err)
242	}
243	res, err := client.UpdateByQuery("example").ProceedOnVersionConflict().Do(context.TODO())
244	if err != nil {
245		t.Fatalf("mock should not be failed %+v", err)
246	}
247	if res.Took != 3 {
248		t.Errorf("took should be 3, got %d", res.Took)
249	}
250	if res.Total != 1 {
251		t.Errorf("total should be 1, got %d", res.Total)
252	}
253	if res.VersionConflicts != 1 {
254		t.Errorf("total should be 1, got %d", res.VersionConflicts)
255	}
256	if len(res.Failures) != 1 {
257		t.Errorf("failures length should be 1, got %d", len(res.Failures))
258	}
259	expected := bulkIndexByScrollResponseFailure{Index: "a", Type: "_doc", Id: "yjsmdGsBm363wfQmSbhj", Status: 409}
260	if res.Failures[0] != expected {
261		t.Errorf("failures should be %+v, got %+v", expected, res.Failures[0])
262	}
263}
264