1// Copyright 2012-present Oliver Eilhard. All rights reserved. 2// Use of this source code is governed by a MIT-license. 3// See http://olivere.mit-license.org/license.txt for details. 4 5package elastic 6 7import ( 8 "context" 9 "encoding/json" 10 "io/ioutil" 11 "net/http" 12 "strings" 13 "testing" 14) 15 16func TestUpdateByQueryBuildURL(t *testing.T) { 17 client := setupTestClient(t) 18 19 tests := []struct { 20 Indices []string 21 Types []string 22 Expected string 23 ExpectErr bool 24 }{ 25 { 26 []string{}, 27 []string{}, 28 "", 29 true, 30 }, 31 { 32 []string{"index1"}, 33 []string{}, 34 "/index1/_update_by_query", 35 false, 36 }, 37 { 38 []string{"index1", "index2"}, 39 []string{}, 40 "/index1%2Cindex2/_update_by_query", 41 false, 42 }, 43 { 44 []string{}, 45 []string{"type1"}, 46 "", 47 true, 48 }, 49 { 50 []string{"index1"}, 51 []string{"type1"}, 52 "/index1/type1/_update_by_query", 53 false, 54 }, 55 { 56 []string{"index1", "index2"}, 57 []string{"type1", "type2"}, 58 "/index1%2Cindex2/type1%2Ctype2/_update_by_query", 59 false, 60 }, 61 } 62 63 for i, test := range tests { 64 builder := client.UpdateByQuery().Index(test.Indices...).Type(test.Types...) 65 err := builder.Validate() 66 if err != nil { 67 if !test.ExpectErr { 68 t.Errorf("case #%d: %v", i+1, err) 69 continue 70 } 71 } else { 72 // err == nil 73 if test.ExpectErr { 74 t.Errorf("case #%d: expected error", i+1) 75 continue 76 } 77 path, _, _ := builder.buildURL() 78 if path != test.Expected { 79 t.Errorf("case #%d: expected %q; got: %q", i+1, test.Expected, path) 80 } 81 } 82 } 83} 84 85func TestUpdateByQueryBodyWithQuery(t *testing.T) { 86 client := setupTestClient(t) 87 out, err := client.UpdateByQuery().Query(NewTermQuery("user", "olivere")).getBody() 88 if err != nil { 89 t.Fatal(err) 90 } 91 b, err := json.Marshal(out) 92 if err != nil { 93 t.Fatal(err) 94 } 95 got := string(b) 96 want := `{"query":{"term":{"user":"olivere"}}}` 97 if got != want { 98 t.Fatalf("\ngot %s\nwant %s", got, want) 99 } 100} 101 102func TestUpdateByQueryBodyWithQueryAndScript(t *testing.T) { 103 client := setupTestClient(t) 104 out, err := client.UpdateByQuery(). 105 Query(NewTermQuery("user", "olivere")). 106 Script(NewScriptInline("ctx._source.likes++")). 107 getBody() 108 if err != nil { 109 t.Fatal(err) 110 } 111 b, err := json.Marshal(out) 112 if err != nil { 113 t.Fatal(err) 114 } 115 got := string(b) 116 want := `{"query":{"term":{"user":"olivere"}},"script":{"source":"ctx._source.likes++"}}` 117 if got != want { 118 t.Fatalf("\ngot %s\nwant %s", got, want) 119 } 120} 121 122func TestUpdateByQuery(t *testing.T) { 123 client := setupTestClientAndCreateIndexAndAddDocs(t) //, SetTraceLog(log.New(os.Stdout, "", 0))) 124 esversion, err := client.ElasticsearchVersion(DefaultURL) 125 if err != nil { 126 t.Fatal(err) 127 } 128 if esversion < "2.3.0" { 129 t.Skipf("Elasticsearch %v does not support update-by-query yet", esversion) 130 } 131 132 sourceCount, err := client.Count(testIndexName).Do(context.TODO()) 133 if err != nil { 134 t.Fatal(err) 135 } 136 if sourceCount <= 0 { 137 t.Fatalf("expected more than %d documents; got: %d", 0, sourceCount) 138 } 139 140 res, err := client.UpdateByQuery(testIndexName).ProceedOnVersionConflict().Do(context.TODO()) 141 if err != nil { 142 t.Fatal(err) 143 } 144 if res == nil { 145 t.Fatal("response is nil") 146 } 147 if res.Updated != sourceCount { 148 t.Fatalf("expected %d; got: %d", sourceCount, res.Updated) 149 } 150} 151 152func TestUpdateByQueryAsync(t *testing.T) { 153 client := setupTestClientAndCreateIndexAndAddDocs(t) //, SetTraceLog(log.New(os.Stdout, "", 0))) 154 esversion, err := client.ElasticsearchVersion(DefaultURL) 155 if err != nil { 156 t.Fatal(err) 157 } 158 if esversion < "2.3.0" { 159 t.Skipf("Elasticsearch %v does not support update-by-query yet", esversion) 160 } 161 162 sourceCount, err := client.Count(testIndexName).Do(context.TODO()) 163 if err != nil { 164 t.Fatal(err) 165 } 166 if sourceCount <= 0 { 167 t.Fatalf("expected more than %d documents; got: %d", 0, sourceCount) 168 } 169 170 res, err := client.UpdateByQuery(testIndexName). 171 ProceedOnVersionConflict(). 172 Slices("auto"). 173 DoAsync(context.TODO()) 174 if err != nil { 175 t.Fatal(err) 176 } 177 if res == nil { 178 t.Fatal("expected result != nil") 179 } 180 if res.TaskId == "" { 181 t.Errorf("expected a task id, got %+v", res) 182 } 183 184 tasksGetTask := client.TasksGetTask() 185 taskStatus, err := tasksGetTask.TaskId(res.TaskId).Do(context.TODO()) 186 if err != nil { 187 t.Fatal(err) 188 } 189 if taskStatus == nil { 190 t.Fatal("expected task status result != nil") 191 } 192} 193 194func TestUpdateByQueryConflict(t *testing.T) { 195 fail := func(r *http.Request) (*http.Response, error) { 196 body := `{ 197 "took": 3, 198 "timed_out": false, 199 "total": 1, 200 "updated": 0, 201 "deleted": 0, 202 "batches": 1, 203 "version_conflicts": 1, 204 "noops": 0, 205 "retries": { 206 "bulk": 0, 207 "search": 0 208 }, 209 "throttled_millis": 0, 210 "requests_per_second": -1, 211 "throttled_until_millis": 0, 212 "failures": [ 213 { 214 "index": "a", 215 "type": "_doc", 216 "id": "yjsmdGsBm363wfQmSbhj", 217 "cause": { 218 "type": "version_conflict_engine_exception", 219 "reason": "[_doc][yjsmdGsBm363wfQmSbhj]: version conflict, current version [4] is different than the one provided [3]", 220 "index_uuid": "1rmL3mt8TimwshF-M1DxdQ", 221 "shard": "0", 222 "index": "a" 223 }, 224 "status": 409 225 } 226 ] 227 }` 228 return &http.Response{ 229 StatusCode: http.StatusConflict, 230 Body: ioutil.NopCloser(strings.NewReader(body)), 231 ContentLength: int64(len(body)), 232 }, nil 233 } 234 235 // Run against a failing endpoint and see if PerformRequest 236 // retries correctly. 237 tr := &failingTransport{path: "/example/_update_by_query", fail: fail} 238 httpClient := &http.Client{Transport: tr} 239 client, err := NewClient(SetHttpClient(httpClient), SetHealthcheck(false)) 240 if err != nil { 241 t.Fatal(err) 242 } 243 res, err := client.UpdateByQuery("example").ProceedOnVersionConflict().Do(context.TODO()) 244 if err != nil { 245 t.Fatalf("mock should not be failed %+v", err) 246 } 247 if res.Took != 3 { 248 t.Errorf("took should be 3, got %d", res.Took) 249 } 250 if res.Total != 1 { 251 t.Errorf("total should be 1, got %d", res.Total) 252 } 253 if res.VersionConflicts != 1 { 254 t.Errorf("total should be 1, got %d", res.VersionConflicts) 255 } 256 if len(res.Failures) != 1 { 257 t.Errorf("failures length should be 1, got %d", len(res.Failures)) 258 } 259 expected := bulkIndexByScrollResponseFailure{Index: "a", Type: "_doc", Id: "yjsmdGsBm363wfQmSbhj", Status: 409} 260 if res.Failures[0] != expected { 261 t.Errorf("failures should be %+v, got %+v", expected, res.Failures[0]) 262 } 263} 264