1// Copyright 2012-present Oliver Eilhard. All rights reserved.
2// Use of this source code is governed by a MIT-license.
3// See http://olivere.mit-license.org/license.txt for details.
4
5package elastic
6
7import (
8	"context"
9	"encoding/json"
10	"testing"
11)
12
13func TestUpdateByQueryBuildURL(t *testing.T) {
14	client := setupTestClient(t)
15
16	tests := []struct {
17		Indices   []string
18		Types     []string
19		Expected  string
20		ExpectErr bool
21	}{
22		{
23			[]string{},
24			[]string{},
25			"",
26			true,
27		},
28		{
29			[]string{"index1"},
30			[]string{},
31			"/index1/_update_by_query",
32			false,
33		},
34		{
35			[]string{"index1", "index2"},
36			[]string{},
37			"/index1%2Cindex2/_update_by_query",
38			false,
39		},
40		{
41			[]string{},
42			[]string{"type1"},
43			"",
44			true,
45		},
46		{
47			[]string{"index1"},
48			[]string{"type1"},
49			"/index1/type1/_update_by_query",
50			false,
51		},
52		{
53			[]string{"index1", "index2"},
54			[]string{"type1", "type2"},
55			"/index1%2Cindex2/type1%2Ctype2/_update_by_query",
56			false,
57		},
58	}
59
60	for i, test := range tests {
61		builder := client.UpdateByQuery().Index(test.Indices...).Type(test.Types...)
62		err := builder.Validate()
63		if err != nil {
64			if !test.ExpectErr {
65				t.Errorf("case #%d: %v", i+1, err)
66				continue
67			}
68		} else {
69			// err == nil
70			if test.ExpectErr {
71				t.Errorf("case #%d: expected error", i+1)
72				continue
73			}
74			path, _, _ := builder.buildURL()
75			if path != test.Expected {
76				t.Errorf("case #%d: expected %q; got: %q", i+1, test.Expected, path)
77			}
78		}
79	}
80}
81
82func TestUpdateByQueryBodyWithQuery(t *testing.T) {
83	client := setupTestClient(t)
84	out, err := client.UpdateByQuery().Query(NewTermQuery("user", "olivere")).getBody()
85	if err != nil {
86		t.Fatal(err)
87	}
88	b, err := json.Marshal(out)
89	if err != nil {
90		t.Fatal(err)
91	}
92	got := string(b)
93	want := `{"query":{"term":{"user":"olivere"}}}`
94	if got != want {
95		t.Fatalf("\ngot  %s\nwant %s", got, want)
96	}
97}
98
99func TestUpdateByQueryBodyWithQueryAndScript(t *testing.T) {
100	client := setupTestClient(t)
101	out, err := client.UpdateByQuery().
102		Query(NewTermQuery("user", "olivere")).
103		Script(NewScriptInline("ctx._source.likes++")).
104		getBody()
105	if err != nil {
106		t.Fatal(err)
107	}
108	b, err := json.Marshal(out)
109	if err != nil {
110		t.Fatal(err)
111	}
112	got := string(b)
113	want := `{"query":{"term":{"user":"olivere"}},"script":{"source":"ctx._source.likes++"}}`
114	if got != want {
115		t.Fatalf("\ngot  %s\nwant %s", got, want)
116	}
117}
118
119func TestUpdateByQuery(t *testing.T) {
120	client := setupTestClientAndCreateIndexAndAddDocs(t) //, SetTraceLog(log.New(os.Stdout, "", 0)))
121	esversion, err := client.ElasticsearchVersion(DefaultURL)
122	if err != nil {
123		t.Fatal(err)
124	}
125	if esversion < "2.3.0" {
126		t.Skipf("Elasticsearch %v does not support update-by-query yet", esversion)
127	}
128
129	sourceCount, err := client.Count(testIndexName).Do(context.TODO())
130	if err != nil {
131		t.Fatal(err)
132	}
133	if sourceCount <= 0 {
134		t.Fatalf("expected more than %d documents; got: %d", 0, sourceCount)
135	}
136
137	res, err := client.UpdateByQuery(testIndexName).ProceedOnVersionConflict().Do(context.TODO())
138	if err != nil {
139		t.Fatal(err)
140	}
141	if res == nil {
142		t.Fatal("response is nil")
143	}
144	if res.Updated != sourceCount {
145		t.Fatalf("expected %d; got: %d", sourceCount, res.Updated)
146	}
147}
148
149func TestUpdateByQueryAsync(t *testing.T) {
150	client := setupTestClientAndCreateIndexAndAddDocs(t) //, SetTraceLog(log.New(os.Stdout, "", 0)))
151	esversion, err := client.ElasticsearchVersion(DefaultURL)
152	if err != nil {
153		t.Fatal(err)
154	}
155	if esversion < "2.3.0" {
156		t.Skipf("Elasticsearch %v does not support update-by-query yet", esversion)
157	}
158
159	sourceCount, err := client.Count(testIndexName).Do(context.TODO())
160	if err != nil {
161		t.Fatal(err)
162	}
163	if sourceCount <= 0 {
164		t.Fatalf("expected more than %d documents; got: %d", 0, sourceCount)
165	}
166
167	res, err := client.UpdateByQuery(testIndexName).
168		ProceedOnVersionConflict().
169		Slices("auto").
170		DoAsync(context.TODO())
171	if err != nil {
172		t.Fatal(err)
173	}
174	if res == nil {
175		t.Fatal("expected result != nil")
176	}
177	if res.TaskId == "" {
178		t.Errorf("expected a task id, got %+v", res)
179	}
180
181	tasksGetTask := client.TasksGetTask()
182	taskStatus, err := tasksGetTask.TaskId(res.TaskId).Do(context.TODO())
183	if err != nil {
184		t.Fatal(err)
185	}
186	if taskStatus == nil {
187		t.Fatal("expected task status result != nil")
188	}
189}
190