// Copyright 2012-2015 Oliver Eilhard. All rights reserved.
// Use of this source code is governed by an MIT license.
// See http://olivere.mit-license.org/license.txt for details.

package elastic

// PhraseSuggester implements the phrase suggester of Elasticsearch.
// For more details, see
// http://www.elasticsearch.org/guide/reference/api/search/phrase-suggest/
type PhraseSuggester struct {
	Suggester
	name           string
	text           string
	field          string
	analyzer       string
	size           *int
	shardSize      *int
	contextQueries []SuggesterContextQuery

	// fields specific to a phrase suggester
	maxErrors               *float32
	separator               *string
	realWordErrorLikelihood *float32
	confidence              *float32
	generators              map[string][]CandidateGenerator
	gramSize                *int
	smoothingModel          SmoothingModel
	forceUnigrams           *bool
	tokenLimit              *int
	preTag, postTag         *string
	collateQuery            *string
	collateFilter           *string
	collatePreference       *string
	collateParams           map[string]interface{}
	collatePrune            *bool
}

// NewPhraseSuggester creates a new phrase suggester.
func NewPhraseSuggester(name string) PhraseSuggester {
	return PhraseSuggester{
		name:           name,
		contextQueries: make([]SuggesterContextQuery, 0),
		collateParams:  make(map[string]interface{}),
	}
}
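
// A minimal usage sketch (the suggestion name, text and field below are
// illustrative assumptions, not part of this package):
//
//	s := NewPhraseSuggester("my-phrase-suggestion").
//		Text("noble prize").
//		Field("title").
//		Size(1)
//	src := s.Source(true) // a map keyed by "my-phrase-suggestion"
//	_ = src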

func (q PhraseSuggester) Name() string {
	return q.name
}

func (q PhraseSuggester) Text(text string) PhraseSuggester {
	q.text = text
	return q
}

func (q PhraseSuggester) Field(field string) PhraseSuggester {
	q.field = field
	return q
}

func (q PhraseSuggester) Analyzer(analyzer string) PhraseSuggester {
	q.analyzer = analyzer
	return q
}

func (q PhraseSuggester) Size(size int) PhraseSuggester {
	q.size = &size
	return q
}

func (q PhraseSuggester) ShardSize(shardSize int) PhraseSuggester {
	q.shardSize = &shardSize
	return q
}

func (q PhraseSuggester) ContextQuery(query SuggesterContextQuery) PhraseSuggester {
	q.contextQueries = append(q.contextQueries, query)
	return q
}

func (q PhraseSuggester) ContextQueries(queries ...SuggesterContextQuery) PhraseSuggester {
	q.contextQueries = append(q.contextQueries, queries...)
	return q
}

// GramSize sets the maximum size of the n-grams (shingles) in the field.
// Values smaller than 1 are ignored.
func (q PhraseSuggester) GramSize(gramSize int) PhraseSuggester {
	if gramSize >= 1 {
		q.gramSize = &gramSize
	}
	return q
}

func (q PhraseSuggester) MaxErrors(maxErrors float32) PhraseSuggester {
	q.maxErrors = &maxErrors
	return q
}

func (q PhraseSuggester) Separator(separator string) PhraseSuggester {
	q.separator = &separator
	return q
}

func (q PhraseSuggester) RealWordErrorLikelihood(realWordErrorLikelihood float32) PhraseSuggester {
	q.realWordErrorLikelihood = &realWordErrorLikelihood
	return q
}

func (q PhraseSuggester) Confidence(confidence float32) PhraseSuggester {
	q.confidence = &confidence
	return q
}

// CandidateGenerator adds a candidate generator. Generators are grouped
// by their Type when the source is serialized.
func (q PhraseSuggester) CandidateGenerator(generator CandidateGenerator) PhraseSuggester {
	if q.generators == nil {
		q.generators = make(map[string][]CandidateGenerator)
	}
	typ := generator.Type()
	if _, found := q.generators[typ]; !found {
		q.generators[typ] = make([]CandidateGenerator, 0)
	}
	q.generators[typ] = append(q.generators[typ], generator)
	return q
}

// CandidateGenerators adds all given candidate generators.
func (q PhraseSuggester) CandidateGenerators(generators ...CandidateGenerator) PhraseSuggester {
	for _, g := range generators {
		q = q.CandidateGenerator(g)
	}
	return q
}
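
// For example, a direct generator can be attached like this (the field
// name "title.trigram" is only an assumption for illustration):
//
//	s := NewPhraseSuggester("my-phrase-suggestion").
//		Field("title").
//		CandidateGenerator(NewDirectCandidateGenerator("title.trigram"))
//	// The generator ends up under the "direct_generator" key of the source.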

func (q PhraseSuggester) ClearCandidateGenerator() PhraseSuggester {
	q.generators = nil
	return q
}

func (q PhraseSuggester) ForceUnigrams(forceUnigrams bool) PhraseSuggester {
	q.forceUnigrams = &forceUnigrams
	return q
}

func (q PhraseSuggester) SmoothingModel(smoothingModel SmoothingModel) PhraseSuggester {
	q.smoothingModel = smoothingModel
	return q
}

func (q PhraseSuggester) TokenLimit(tokenLimit int) PhraseSuggester {
	q.tokenLimit = &tokenLimit
	return q
}

// Highlight sets the pre_tag and post_tag that wrap changed tokens
// when highlighting suggestions.
func (q PhraseSuggester) Highlight(preTag, postTag string) PhraseSuggester {
	q.preTag = &preTag
	q.postTag = &postTag
	return q
}

// CollateQuery sets a query that is run once per suggestion to check
// whether the suggestion would match any documents.
func (q PhraseSuggester) CollateQuery(collateQuery string) PhraseSuggester {
	q.collateQuery = &collateQuery
	return q
}

func (q PhraseSuggester) CollateFilter(collateFilter string) PhraseSuggester {
	q.collateFilter = &collateFilter
	return q
}

func (q PhraseSuggester) CollatePreference(collatePreference string) PhraseSuggester {
	q.collatePreference = &collatePreference
	return q
}

func (q PhraseSuggester) CollateParams(collateParams map[string]interface{}) PhraseSuggester {
	q.collateParams = collateParams
	return q
}

func (q PhraseSuggester) CollatePrune(collatePrune bool) PhraseSuggester {
	q.collatePrune = &collatePrune
	return q
}
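
// A collate sketch; the template and the "field_name" parameter are
// illustrative assumptions and follow the Elasticsearch query template
// syntax, in which the current suggestion is exposed as {{suggestion}}:
//
//	s := NewPhraseSuggester("my-phrase-suggestion").
//		Field("title").
//		CollateQuery(`{"match_phrase": {"{{field_name}}": "{{suggestion}}"}}`).
//		CollateParams(map[string]interface{}{"field_name": "title"}).
//		CollatePrune(true)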

// phraseSuggesterRequest is necessary because the order in which
// the JSON elements are sent to Elasticsearch is relevant.
// We got into trouble when using plain maps because the text element
// needs to go before the phrase element.
type phraseSuggesterRequest struct {
	Text   string      `json:"text"`
	Phrase interface{} `json:"phrase"`
}
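
// With a suggester named "my-phrase-suggestion", the serialized request
// takes roughly the following shape (the text and field values are
// illustrative assumptions):
//
//	{
//	  "my-phrase-suggestion": {
//	    "text": "noble prize",
//	    "phrase": {
//	      "field": "title",
//	      "size": 1
//	    }
//	  }
//	}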

// Source generates the source for the phrase suggester.
// If includeName is true, the output is wrapped in a map keyed by the
// suggester's name.
func (q PhraseSuggester) Source(includeName bool) interface{} {
	ps := &phraseSuggesterRequest{}

	if q.text != "" {
		ps.Text = q.text
	}

	suggester := make(map[string]interface{})
	ps.Phrase = suggester

	if q.analyzer != "" {
		suggester["analyzer"] = q.analyzer
	}
	if q.field != "" {
		suggester["field"] = q.field
	}
	if q.size != nil {
		suggester["size"] = *q.size
	}
	if q.shardSize != nil {
		suggester["shard_size"] = *q.shardSize
	}
	switch len(q.contextQueries) {
	case 0:
	case 1:
		suggester["context"] = q.contextQueries[0].Source()
	default:
		ctxq := make([]interface{}, 0)
		for _, query := range q.contextQueries {
			ctxq = append(ctxq, query.Source())
		}
		suggester["context"] = ctxq
	}

	// Phrase-specific parameters
	if q.realWordErrorLikelihood != nil {
		suggester["real_word_error_likelihood"] = *q.realWordErrorLikelihood
	}
	if q.confidence != nil {
		suggester["confidence"] = *q.confidence
	}
	if q.separator != nil {
		suggester["separator"] = *q.separator
	}
	if q.maxErrors != nil {
		suggester["max_errors"] = *q.maxErrors
	}
	if q.gramSize != nil {
		suggester["gram_size"] = *q.gramSize
	}
	if q.forceUnigrams != nil {
		suggester["force_unigrams"] = *q.forceUnigrams
	}
	if q.tokenLimit != nil {
		suggester["token_limit"] = *q.tokenLimit
	}
	if len(q.generators) > 0 {
		for typ, generators := range q.generators {
			arr := make([]interface{}, 0)
			for _, g := range generators {
				arr = append(arr, g.Source())
			}
			suggester[typ] = arr
		}
	}
	if q.smoothingModel != nil {
		x := make(map[string]interface{})
		x[q.smoothingModel.Type()] = q.smoothingModel.Source()
		suggester["smoothing"] = x
	}
	if q.preTag != nil {
		hl := make(map[string]string)
		hl["pre_tag"] = *q.preTag
		if q.postTag != nil {
			hl["post_tag"] = *q.postTag
		}
		suggester["highlight"] = hl
	}
	if q.collateQuery != nil || q.collateFilter != nil {
		collate := make(map[string]interface{})
		suggester["collate"] = collate
		if q.collateQuery != nil {
			collate["query"] = *q.collateQuery
		}
		if q.collateFilter != nil {
			collate["filter"] = *q.collateFilter
		}
		if q.collatePreference != nil {
			collate["preference"] = *q.collatePreference
		}
		if len(q.collateParams) > 0 {
			collate["params"] = q.collateParams
		}
		if q.collatePrune != nil {
			collate["prune"] = *q.collatePrune
		}
	}

	if !includeName {
		return ps
	}

	source := make(map[string]interface{})
	source[q.name] = ps
	return source
}

// -- Smoothing models --

// SmoothingModel represents a smoothing model for the phrase suggester.
type SmoothingModel interface {
	Type() string
	Source() interface{}
}

// StupidBackoffSmoothingModel implements a stupid backoff smoothing model.
// See http://www.elasticsearch.org/guide/en/elasticsearch/reference/current/search-suggesters-phrase.html#_smoothing_models
// for details about smoothing models.
type StupidBackoffSmoothingModel struct {
	discount float64
}

// NewStupidBackoffSmoothingModel creates a new stupid backoff smoothing
// model with the given discount.
func NewStupidBackoffSmoothingModel(discount float64) *StupidBackoffSmoothingModel {
	return &StupidBackoffSmoothingModel{
		discount: discount,
	}
}

func (sm *StupidBackoffSmoothingModel) Type() string {
	return "stupid_backoff"
}

func (sm *StupidBackoffSmoothingModel) Source() interface{} {
	source := make(map[string]interface{})
	source["discount"] = sm.discount
	return source
}
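
// Attaching a smoothing model to a phrase suggester, e.g. stupid backoff
// (0.4 is only an illustrative discount value here):
//
//	s := NewPhraseSuggester("my-phrase-suggestion").
//		Field("title").
//		SmoothingModel(NewStupidBackoffSmoothingModel(0.4))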

// --

// LaplaceSmoothingModel implements a Laplace smoothing model.
// See http://www.elasticsearch.org/guide/en/elasticsearch/reference/current/search-suggesters-phrase.html#_smoothing_models
// for details about smoothing models.
type LaplaceSmoothingModel struct {
	alpha float64
}

// NewLaplaceSmoothingModel creates a new Laplace smoothing model with the
// given additive constant alpha.
func NewLaplaceSmoothingModel(alpha float64) *LaplaceSmoothingModel {
	return &LaplaceSmoothingModel{
		alpha: alpha,
	}
}

func (sm *LaplaceSmoothingModel) Type() string {
	return "laplace"
}

func (sm *LaplaceSmoothingModel) Source() interface{} {
	source := make(map[string]interface{})
	source["alpha"] = sm.alpha
	return source
}

// --

// LinearInterpolationSmoothingModel implements a linear interpolation
// smoothing model.
// See http://www.elasticsearch.org/guide/en/elasticsearch/reference/current/search-suggesters-phrase.html#_smoothing_models
// for details about smoothing models.
type LinearInterpolationSmoothingModel struct {
	trigramLambda float64
	bigramLambda  float64
	unigramLambda float64
}

// NewLinearInterpolationSmoothingModel creates a new linear interpolation
// smoothing model with the given weights.
func NewLinearInterpolationSmoothingModel(trigramLambda, bigramLambda, unigramLambda float64) *LinearInterpolationSmoothingModel {
	return &LinearInterpolationSmoothingModel{
		trigramLambda: trigramLambda,
		bigramLambda:  bigramLambda,
		unigramLambda: unigramLambda,
	}
}

func (sm *LinearInterpolationSmoothingModel) Type() string {
	return "linear_interpolation"
}

func (sm *LinearInterpolationSmoothingModel) Source() interface{} {
	source := make(map[string]interface{})
	source["trigram_lambda"] = sm.trigramLambda
	source["bigram_lambda"] = sm.bigramLambda
	source["unigram_lambda"] = sm.unigramLambda
	return source
}
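
// A linear interpolation sketch; all three weights must be supplied and
// are typically chosen so that they sum to 1 (the concrete values here
// are illustrative assumptions):
//
//	sm := NewLinearInterpolationSmoothingModel(0.7, 0.2, 0.1)
//	s := NewPhraseSuggester("my-phrase-suggestion").SmoothingModel(sm)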

// -- CandidateGenerator --

// CandidateGenerator is a candidate generator for the phrase suggester.
type CandidateGenerator interface {
	Type() string
	Source() interface{}
}

// DirectCandidateGenerator implements a direct candidate generator.
// See http://www.elasticsearch.org/guide/en/elasticsearch/reference/current/search-suggesters-phrase.html#_smoothing_models
// for more details.
type DirectCandidateGenerator struct {
	field          string
	preFilter      *string
	postFilter     *string
	suggestMode    *string
	accuracy       *float64
	size           *int
	sort           *string
	stringDistance *string
	maxEdits       *int
	maxInspections *int
	maxTermFreq    *float64
	prefixLength   *int
	minWordLength  *int
	minDocFreq     *float64
}

// NewDirectCandidateGenerator creates a new direct candidate generator
// for the given field.
func NewDirectCandidateGenerator(field string) *DirectCandidateGenerator {
	return &DirectCandidateGenerator{
		field: field,
	}
}

func (g *DirectCandidateGenerator) Type() string {
	return "direct_generator"
}

func (g *DirectCandidateGenerator) Field(field string) *DirectCandidateGenerator {
	g.field = field
	return g
}

func (g *DirectCandidateGenerator) PreFilter(preFilter string) *DirectCandidateGenerator {
	g.preFilter = &preFilter
	return g
}

func (g *DirectCandidateGenerator) PostFilter(postFilter string) *DirectCandidateGenerator {
	g.postFilter = &postFilter
	return g
}

func (g *DirectCandidateGenerator) SuggestMode(suggestMode string) *DirectCandidateGenerator {
	g.suggestMode = &suggestMode
	return g
}

func (g *DirectCandidateGenerator) Accuracy(accuracy float64) *DirectCandidateGenerator {
	g.accuracy = &accuracy
	return g
}

func (g *DirectCandidateGenerator) Size(size int) *DirectCandidateGenerator {
	g.size = &size
	return g
}

func (g *DirectCandidateGenerator) Sort(sort string) *DirectCandidateGenerator {
	g.sort = &sort
	return g
}

func (g *DirectCandidateGenerator) StringDistance(stringDistance string) *DirectCandidateGenerator {
	g.stringDistance = &stringDistance
	return g
}

func (g *DirectCandidateGenerator) MaxEdits(maxEdits int) *DirectCandidateGenerator {
	g.maxEdits = &maxEdits
	return g
}

func (g *DirectCandidateGenerator) MaxInspections(maxInspections int) *DirectCandidateGenerator {
	g.maxInspections = &maxInspections
	return g
}

func (g *DirectCandidateGenerator) MaxTermFreq(maxTermFreq float64) *DirectCandidateGenerator {
	g.maxTermFreq = &maxTermFreq
	return g
}

func (g *DirectCandidateGenerator) PrefixLength(prefixLength int) *DirectCandidateGenerator {
	g.prefixLength = &prefixLength
	return g
}

func (g *DirectCandidateGenerator) MinWordLength(minWordLength int) *DirectCandidateGenerator {
	g.minWordLength = &minWordLength
	return g
}

func (g *DirectCandidateGenerator) MinDocFreq(minDocFreq float64) *DirectCandidateGenerator {
	g.minDocFreq = &minDocFreq
	return g
}
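
// A configured direct generator might look as follows; the field name and
// parameter values are illustrative assumptions only:
//
//	g := NewDirectCandidateGenerator("title.trigram").
//		SuggestMode("always").
//		MaxEdits(1).
//		MinWordLength(1)
//	s := NewPhraseSuggester("my-phrase-suggestion").CandidateGenerator(g)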

// Source generates the source for the direct candidate generator.
func (g *DirectCandidateGenerator) Source() interface{} {
	source := make(map[string]interface{})
	if g.field != "" {
		source["field"] = g.field
	}
	if g.suggestMode != nil {
		source["suggest_mode"] = *g.suggestMode
	}
	if g.accuracy != nil {
		source["accuracy"] = *g.accuracy
	}
	if g.size != nil {
		source["size"] = *g.size
	}
	if g.sort != nil {
		source["sort"] = *g.sort
	}
	if g.stringDistance != nil {
		source["string_distance"] = *g.stringDistance
	}
	if g.maxEdits != nil {
		source["max_edits"] = *g.maxEdits
	}
	if g.maxInspections != nil {
		source["max_inspections"] = *g.maxInspections
	}
	if g.maxTermFreq != nil {
		source["max_term_freq"] = *g.maxTermFreq
	}
	if g.prefixLength != nil {
		source["prefix_length"] = *g.prefixLength
	}
	if g.minWordLength != nil {
		source["min_word_length"] = *g.minWordLength
	}
	if g.minDocFreq != nil {
		source["min_doc_freq"] = *g.minDocFreq
	}
	if g.preFilter != nil {
		source["pre_filter"] = *g.preFilter
	}
	if g.postFilter != nil {
		source["post_filter"] = *g.postFilter
	}
	return source
}