1// Copyright 2012-present Oliver Eilhard. All rights reserved.
2// Use of this source code is governed by a MIT-license.
3// See http://olivere.mit-license.org/license.txt for details.
4
5package elastic
6
7import (
8	"context"
9	"fmt"
10	"io"
11	"net/http"
12	"net/url"
13	"strings"
14	"sync"
15
16	"github.com/olivere/elastic/v7/uritemplates"
17)
18
// DefaultScrollKeepAlive is the keep-alive duration ("5m", i.e. five minutes)
// that a scroll cursor gets when no other value is configured.
const DefaultScrollKeepAlive = "5m"
23
// ScrollService iterates over pages of search results from Elasticsearch.
//
// The first call to Do sends a search request that opens the scroll; once a
// scroll id is stored on the service, further calls to Do fetch the
// follow-up pages with that id.
type ScrollService struct {
	// client performs the HTTP requests; retrier, if set, is passed along
	// with every request.
	client  *Client
	retrier Retrier

	// Options common to the requests issued by this service.
	pretty     *bool       // pretty format the returned JSON response
	human      *bool       // return human readable values for statistics
	errorTrace *bool       // include the stack trace of returned errors
	filterPath []string    // list of filters used to reduce the response
	headers    http.Header // custom request-level HTTP headers

	// Search and scroll settings. body, when non-nil, is sent instead of
	// the source generated from ss.
	indices           []string
	types             []string
	keepAlive         string
	body              interface{}
	ss                *SearchSource
	size              *int
	routing           string
	preference        string
	ignoreUnavailable *bool
	allowNoIndices    *bool
	expandWildcards   string
	maxResponseSize   int64

	// mu guards scrollId, the identifier of the scroll in progress
	// (empty until the first page has been fetched or ScrollId is called).
	mu       sync.RWMutex
	scrollId string
}
51
52// NewScrollService initializes and returns a new ScrollService.
53func NewScrollService(client *Client) *ScrollService {
54	builder := &ScrollService{
55		client:    client,
56		ss:        NewSearchSource(),
57		keepAlive: DefaultScrollKeepAlive,
58	}
59	return builder
60}
61
62// Pretty tells Elasticsearch whether to return a formatted JSON response.
63func (s *ScrollService) Pretty(pretty bool) *ScrollService {
64	s.pretty = &pretty
65	return s
66}
67
68// Human specifies whether human readable values should be returned in
69// the JSON response, e.g. "7.5mb".
70func (s *ScrollService) Human(human bool) *ScrollService {
71	s.human = &human
72	return s
73}
74
75// ErrorTrace specifies whether to include the stack trace of returned errors.
76func (s *ScrollService) ErrorTrace(errorTrace bool) *ScrollService {
77	s.errorTrace = &errorTrace
78	return s
79}
80
81// FilterPath specifies a list of filters used to reduce the response.
82func (s *ScrollService) FilterPath(filterPath ...string) *ScrollService {
83	s.filterPath = filterPath
84	return s
85}
86
87// Header adds a header to the request.
88func (s *ScrollService) Header(name string, value string) *ScrollService {
89	if s.headers == nil {
90		s.headers = http.Header{}
91	}
92	s.headers.Add(name, value)
93	return s
94}
95
96// Headers specifies the headers of the request.
97func (s *ScrollService) Headers(headers http.Header) *ScrollService {
98	s.headers = headers
99	return s
100}
101
102// Retrier allows to set specific retry logic for this ScrollService.
103// If not specified, it will use the client's default retrier.
104func (s *ScrollService) Retrier(retrier Retrier) *ScrollService {
105	s.retrier = retrier
106	return s
107}
108
109// Index sets the name of one or more indices to iterate over.
110func (s *ScrollService) Index(indices ...string) *ScrollService {
111	if s.indices == nil {
112		s.indices = make([]string, 0)
113	}
114	s.indices = append(s.indices, indices...)
115	return s
116}
117
118// Type sets the name of one or more types to iterate over.
119//
120// Deprecated: Types are in the process of being removed. Instead of using a type, prefer to
121// filter on a field on the document.
122func (s *ScrollService) Type(types ...string) *ScrollService {
123	if s.types == nil {
124		s.types = make([]string, 0)
125	}
126	s.types = append(s.types, types...)
127	return s
128}
129
130// Scroll is an alias for KeepAlive, the time to keep
131// the cursor alive (e.g. "5m" for 5 minutes).
132func (s *ScrollService) Scroll(keepAlive string) *ScrollService {
133	s.keepAlive = keepAlive
134	return s
135}
136
137// KeepAlive sets the maximum time after which the cursor will expire.
138// It is "5m" by default.
139func (s *ScrollService) KeepAlive(keepAlive string) *ScrollService {
140	s.keepAlive = keepAlive
141	return s
142}
143
144// Size specifies the number of documents Elasticsearch should return
145// from each shard, per page.
146func (s *ScrollService) Size(size int) *ScrollService {
147	s.size = &size
148	return s
149}
150
151// Highlight allows to highlight search results on one or more fields
152func (s *ScrollService) Highlight(highlight *Highlight) *ScrollService {
153	s.ss = s.ss.Highlight(highlight)
154	return s
155}
156
157// Body sets the raw body to send to Elasticsearch. This can be e.g. a string,
158// a map[string]interface{} or anything that can be serialized into JSON.
159// Notice that setting the body disables the use of SearchSource and many
160// other properties of the ScanService.
161func (s *ScrollService) Body(body interface{}) *ScrollService {
162	s.body = body
163	return s
164}
165
166// SearchSource sets the search source builder to use with this iterator.
167// Notice that only a certain number of properties can be used when scrolling,
168// e.g. query and sorting.
169func (s *ScrollService) SearchSource(searchSource *SearchSource) *ScrollService {
170	s.ss = searchSource
171	if s.ss == nil {
172		s.ss = NewSearchSource()
173	}
174	return s
175}
176
177// Query sets the query to perform, e.g. a MatchAllQuery.
178func (s *ScrollService) Query(query Query) *ScrollService {
179	s.ss = s.ss.Query(query)
180	return s
181}
182
183// PostFilter is executed as the last filter. It only affects the
184// search hits but not facets. See
185// https://www.elastic.co/guide/en/elasticsearch/reference/7.0/search-request-post-filter.html
186// for details.
187func (s *ScrollService) PostFilter(postFilter Query) *ScrollService {
188	s.ss = s.ss.PostFilter(postFilter)
189	return s
190}
191
192// Slice allows slicing the scroll request into several batches.
193// This is supported in Elasticsearch 5.0 or later.
194// See https://www.elastic.co/guide/en/elasticsearch/reference/7.0/search-request-scroll.html#sliced-scroll
195// for details.
196func (s *ScrollService) Slice(sliceQuery Query) *ScrollService {
197	s.ss = s.ss.Slice(sliceQuery)
198	return s
199}
200
201// FetchSource indicates whether the response should contain the stored
202// _source for every hit.
203func (s *ScrollService) FetchSource(fetchSource bool) *ScrollService {
204	s.ss = s.ss.FetchSource(fetchSource)
205	return s
206}
207
208// FetchSourceContext indicates how the _source should be fetched.
209func (s *ScrollService) FetchSourceContext(fetchSourceContext *FetchSourceContext) *ScrollService {
210	s.ss = s.ss.FetchSourceContext(fetchSourceContext)
211	return s
212}
213
214// Version can be set to true to return a version for each search hit.
215// See https://www.elastic.co/guide/en/elasticsearch/reference/7.0/search-request-version.html.
216func (s *ScrollService) Version(version bool) *ScrollService {
217	s.ss = s.ss.Version(version)
218	return s
219}
220
221// Sort adds a sort order. This can have negative effects on the performance
222// of the scroll operation as Elasticsearch needs to sort first.
223func (s *ScrollService) Sort(field string, ascending bool) *ScrollService {
224	s.ss = s.ss.Sort(field, ascending)
225	return s
226}
227
228// SortWithInfo specifies a sort order. Notice that sorting can have a
229// negative impact on scroll performance.
230func (s *ScrollService) SortWithInfo(info SortInfo) *ScrollService {
231	s.ss = s.ss.SortWithInfo(info)
232	return s
233}
234
235// SortBy specifies a sort order. Notice that sorting can have a
236// negative impact on scroll performance.
237func (s *ScrollService) SortBy(sorter ...Sorter) *ScrollService {
238	s.ss = s.ss.SortBy(sorter...)
239	return s
240}
241
242// TrackTotalHits controls if the total hit count for the query should be tracked.
243//
244// See https://www.elastic.co/guide/en/elasticsearch/reference/7.1/search-request-track-total-hits.html
245// for details.
246func (s *ScrollService) TrackTotalHits(trackTotalHits interface{}) *ScrollService {
247	s.ss = s.ss.TrackTotalHits(trackTotalHits)
248	return s
249}
250
251// Routing is a list of specific routing values to control the shards
252// the search will be executed on.
253func (s *ScrollService) Routing(routings ...string) *ScrollService {
254	s.routing = strings.Join(routings, ",")
255	return s
256}
257
258// Preference sets the preference to execute the search. Defaults to
259// randomize across shards ("random"). Can be set to "_local" to prefer
260// local shards, "_primary" to execute on primary shards only,
261// or a custom value which guarantees that the same order will be used
262// across different requests.
263func (s *ScrollService) Preference(preference string) *ScrollService {
264	s.preference = preference
265	return s
266}
267
268// IgnoreUnavailable indicates whether the specified concrete indices
269// should be ignored when unavailable (missing or closed).
270func (s *ScrollService) IgnoreUnavailable(ignoreUnavailable bool) *ScrollService {
271	s.ignoreUnavailable = &ignoreUnavailable
272	return s
273}
274
275// AllowNoIndices indicates whether to ignore if a wildcard indices
276// expression resolves into no concrete indices. (This includes `_all` string
277// or when no indices have been specified).
278func (s *ScrollService) AllowNoIndices(allowNoIndices bool) *ScrollService {
279	s.allowNoIndices = &allowNoIndices
280	return s
281}
282
283// ExpandWildcards indicates whether to expand wildcard expression to
284// concrete indices that are open, closed or both.
285func (s *ScrollService) ExpandWildcards(expandWildcards string) *ScrollService {
286	s.expandWildcards = expandWildcards
287	return s
288}
289
290// MaxResponseSize sets an upper limit on the response body size that we accept,
291// to guard against OOM situations.
292func (s *ScrollService) MaxResponseSize(maxResponseSize int64) *ScrollService {
293	s.maxResponseSize = maxResponseSize
294	return s
295}
296
297// ScrollId specifies the identifier of a scroll in action.
298func (s *ScrollService) ScrollId(scrollId string) *ScrollService {
299	s.mu.Lock()
300	s.scrollId = scrollId
301	s.mu.Unlock()
302	return s
303}
304
305// Do returns the next search result. It will return io.EOF as error if there
306// are no more search results.
307func (s *ScrollService) Do(ctx context.Context) (*SearchResult, error) {
308	s.mu.RLock()
309	nextScrollId := s.scrollId
310	s.mu.RUnlock()
311	if len(nextScrollId) == 0 {
312		return s.first(ctx)
313	}
314	return s.next(ctx)
315}
316
317// Clear cancels the current scroll operation. If you don't do this manually,
318// the scroll will be expired automatically by Elasticsearch. You can control
319// how long a scroll cursor is kept alive with the KeepAlive func.
320func (s *ScrollService) Clear(ctx context.Context) error {
321	s.mu.RLock()
322	scrollId := s.scrollId
323	s.mu.RUnlock()
324	if len(scrollId) == 0 {
325		return nil
326	}
327
328	path := "/_search/scroll"
329	params := url.Values{}
330	if v := s.pretty; v != nil {
331		params.Set("pretty", fmt.Sprint(*v))
332	}
333	if v := s.human; v != nil {
334		params.Set("human", fmt.Sprint(*v))
335	}
336	if v := s.errorTrace; v != nil {
337		params.Set("error_trace", fmt.Sprint(*v))
338	}
339	if len(s.filterPath) > 0 {
340		params.Set("filter_path", strings.Join(s.filterPath, ","))
341	}
342	body := struct {
343		ScrollId []string `json:"scroll_id,omitempty"`
344	}{
345		ScrollId: []string{scrollId},
346	}
347
348	_, err := s.client.PerformRequest(ctx, PerformRequestOptions{
349		Method:  "DELETE",
350		Path:    path,
351		Params:  params,
352		Body:    body,
353		Retrier: s.retrier,
354	})
355	if err != nil {
356		return err
357	}
358
359	return nil
360}
361
362// -- First --
363
364// first takes the first page of search results.
365func (s *ScrollService) first(ctx context.Context) (*SearchResult, error) {
366	// Get URL and parameters for request
367	path, params, err := s.buildFirstURL()
368	if err != nil {
369		return nil, err
370	}
371
372	// Get HTTP request body
373	body, err := s.bodyFirst()
374	if err != nil {
375		return nil, err
376	}
377
378	// Get HTTP response
379	res, err := s.client.PerformRequest(ctx, PerformRequestOptions{
380		Method:          "POST",
381		Path:            path,
382		Params:          params,
383		Body:            body,
384		Retrier:         s.retrier,
385		Headers:         s.headers,
386		MaxResponseSize: s.maxResponseSize,
387	})
388	if err != nil {
389		return nil, err
390	}
391
392	// Return operation response
393	ret := new(SearchResult)
394	if err := s.client.decoder.Decode(res.Body, ret); err != nil {
395		return nil, err
396	}
397	s.mu.Lock()
398	s.scrollId = ret.ScrollId
399	s.mu.Unlock()
400	if ret.Hits == nil || len(ret.Hits.Hits) == 0 {
401		return ret, io.EOF
402	}
403	return ret, nil
404}
405
406// buildFirstURL builds the URL for retrieving the first page.
407func (s *ScrollService) buildFirstURL() (string, url.Values, error) {
408	// Build URL
409	var err error
410	var path string
411	if len(s.indices) == 0 && len(s.types) == 0 {
412		path = "/_search"
413	} else if len(s.indices) > 0 && len(s.types) == 0 {
414		path, err = uritemplates.Expand("/{index}/_search", map[string]string{
415			"index": strings.Join(s.indices, ","),
416		})
417	} else if len(s.indices) == 0 && len(s.types) > 0 {
418		path, err = uritemplates.Expand("/_all/{typ}/_search", map[string]string{
419			"typ": strings.Join(s.types, ","),
420		})
421	} else {
422		path, err = uritemplates.Expand("/{index}/{typ}/_search", map[string]string{
423			"index": strings.Join(s.indices, ","),
424			"typ":   strings.Join(s.types, ","),
425		})
426	}
427	if err != nil {
428		return "", url.Values{}, err
429	}
430
431	// Add query string parameters
432	params := url.Values{}
433	if v := s.pretty; v != nil {
434		params.Set("pretty", fmt.Sprint(*v))
435	}
436	if v := s.human; v != nil {
437		params.Set("human", fmt.Sprint(*v))
438	}
439	if v := s.errorTrace; v != nil {
440		params.Set("error_trace", fmt.Sprint(*v))
441	}
442	if len(s.filterPath) > 0 {
443		// Always add "hits._scroll_id", otherwise we cannot scroll
444		var found bool
445		for _, path := range s.filterPath {
446			if path == "_scroll_id" {
447				found = true
448				break
449			}
450		}
451		if !found {
452			s.filterPath = append(s.filterPath, "_scroll_id")
453		}
454		params.Set("filter_path", strings.Join(s.filterPath, ","))
455	}
456	if s.size != nil && *s.size > 0 {
457		params.Set("size", fmt.Sprintf("%d", *s.size))
458	}
459	if len(s.keepAlive) > 0 {
460		params.Set("scroll", s.keepAlive)
461	}
462	if len(s.routing) > 0 {
463		params.Set("routing", s.routing)
464	}
465	if len(s.preference) > 0 {
466		params.Set("preference", s.preference)
467	}
468	if s.allowNoIndices != nil {
469		params.Set("allow_no_indices", fmt.Sprintf("%v", *s.allowNoIndices))
470	}
471	if len(s.expandWildcards) > 0 {
472		params.Set("expand_wildcards", s.expandWildcards)
473	}
474	if s.ignoreUnavailable != nil {
475		params.Set("ignore_unavailable", fmt.Sprintf("%v", *s.ignoreUnavailable))
476	}
477
478	return path, params, nil
479}
480
481// bodyFirst returns the request to fetch the first batch of results.
482func (s *ScrollService) bodyFirst() (interface{}, error) {
483	var err error
484	var body interface{}
485
486	if s.body != nil {
487		body = s.body
488	} else {
489		// Use _doc sort by default if none is specified
490		if !s.ss.hasSort() {
491			// Use efficient sorting when no user-defined query/body is specified
492			s.ss = s.ss.SortBy(SortByDoc{})
493		}
494
495		// Body from search source
496		body, err = s.ss.Source()
497		if err != nil {
498			return nil, err
499		}
500	}
501
502	return body, nil
503}
504
505// -- Next --
506
507func (s *ScrollService) next(ctx context.Context) (*SearchResult, error) {
508	// Get URL for request
509	path, params, err := s.buildNextURL()
510	if err != nil {
511		return nil, err
512	}
513
514	// Setup HTTP request body
515	body, err := s.bodyNext()
516	if err != nil {
517		return nil, err
518	}
519
520	// Get HTTP response
521	res, err := s.client.PerformRequest(ctx, PerformRequestOptions{
522		Method:          "POST",
523		Path:            path,
524		Params:          params,
525		Body:            body,
526		Retrier:         s.retrier,
527		Headers:         s.headers,
528		MaxResponseSize: s.maxResponseSize,
529	})
530	if err != nil {
531		return nil, err
532	}
533
534	// Return operation response
535	ret := new(SearchResult)
536	if err := s.client.decoder.Decode(res.Body, ret); err != nil {
537		return nil, err
538	}
539	s.mu.Lock()
540	s.scrollId = ret.ScrollId
541	s.mu.Unlock()
542	if ret.Hits == nil || len(ret.Hits.Hits) == 0 {
543		return ret, io.EOF
544	}
545	return ret, nil
546}
547
548// buildNextURL builds the URL for the operation.
549func (s *ScrollService) buildNextURL() (string, url.Values, error) {
550	path := "/_search/scroll"
551
552	// Add query string parameters
553	params := url.Values{}
554	if v := s.pretty; v != nil {
555		params.Set("pretty", fmt.Sprint(*v))
556	}
557	if v := s.human; v != nil {
558		params.Set("human", fmt.Sprint(*v))
559	}
560	if v := s.errorTrace; v != nil {
561		params.Set("error_trace", fmt.Sprint(*v))
562	}
563	if len(s.filterPath) > 0 {
564		// Always add "hits._scroll_id", otherwise we cannot scroll
565		var found bool
566		for _, path := range s.filterPath {
567			if path == "_scroll_id" {
568				found = true
569				break
570			}
571		}
572		if !found {
573			s.filterPath = append(s.filterPath, "_scroll_id")
574		}
575		params.Set("filter_path", strings.Join(s.filterPath, ","))
576	}
577
578	return path, params, nil
579}
580
581// body returns the request to fetch the next batch of results.
582func (s *ScrollService) bodyNext() (interface{}, error) {
583	s.mu.RLock()
584	body := struct {
585		Scroll   string `json:"scroll"`
586		ScrollId string `json:"scroll_id,omitempty"`
587	}{
588		Scroll:   s.keepAlive,
589		ScrollId: s.scrollId,
590	}
591	s.mu.RUnlock()
592	return body, nil
593}
594
595// DocvalueField adds a single field to load from the field data cache
596// and return as part of the search.
597func (s *ScrollService) DocvalueField(docvalueField string) *ScrollService {
598	s.ss = s.ss.DocvalueField(docvalueField)
599	return s
600}
601
602// DocvalueFieldWithFormat adds a single field to load from the field data cache
603// and return as part of the search.
604func (s *ScrollService) DocvalueFieldWithFormat(docvalueField DocvalueField) *ScrollService {
605	s.ss = s.ss.DocvalueFieldWithFormat(docvalueField)
606	return s
607}
608
609// DocvalueFields adds one or more fields to load from the field data cache
610// and return as part of the search.
611func (s *ScrollService) DocvalueFields(docvalueFields ...string) *ScrollService {
612	s.ss = s.ss.DocvalueFields(docvalueFields...)
613	return s
614}
615
616// DocvalueFieldsWithFormat adds one or more fields to load from the field data cache
617// and return as part of the search.
618func (s *ScrollService) DocvalueFieldsWithFormat(docvalueFields ...DocvalueField) *ScrollService {
619	s.ss = s.ss.DocvalueFieldsWithFormat(docvalueFields...)
620	return s
621}
622