1// Copyright 2012-present Oliver Eilhard. All rights reserved.
2// Use of this source code is governed by a MIT-license.
3// See http://olivere.mit-license.org/license.txt for details.
4
5package elastic
6
7import "fmt"
8
9// A bool query matches documents matching boolean
10// combinations of other queries.
11// For more details, see:
12// https://www.elastic.co/guide/en/elasticsearch/reference/6.7/query-dsl-bool-query.html
13type BoolQuery struct {
14	Query
15	mustClauses        []Query
16	mustNotClauses     []Query
17	filterClauses      []Query
18	shouldClauses      []Query
19	boost              *float64
20	minimumShouldMatch string
21	adjustPureNegative *bool
22	queryName          string
23}
24
25// Creates a new bool query.
26func NewBoolQuery() *BoolQuery {
27	return &BoolQuery{
28		mustClauses:    make([]Query, 0),
29		mustNotClauses: make([]Query, 0),
30		filterClauses:  make([]Query, 0),
31		shouldClauses:  make([]Query, 0),
32	}
33}
34
35func (q *BoolQuery) Must(queries ...Query) *BoolQuery {
36	q.mustClauses = append(q.mustClauses, queries...)
37	return q
38}
39
40func (q *BoolQuery) MustNot(queries ...Query) *BoolQuery {
41	q.mustNotClauses = append(q.mustNotClauses, queries...)
42	return q
43}
44
45func (q *BoolQuery) Filter(filters ...Query) *BoolQuery {
46	q.filterClauses = append(q.filterClauses, filters...)
47	return q
48}
49
50func (q *BoolQuery) Should(queries ...Query) *BoolQuery {
51	q.shouldClauses = append(q.shouldClauses, queries...)
52	return q
53}
54
55func (q *BoolQuery) Boost(boost float64) *BoolQuery {
56	q.boost = &boost
57	return q
58}
59
60func (q *BoolQuery) MinimumShouldMatch(minimumShouldMatch string) *BoolQuery {
61	q.minimumShouldMatch = minimumShouldMatch
62	return q
63}
64
65func (q *BoolQuery) MinimumNumberShouldMatch(minimumNumberShouldMatch int) *BoolQuery {
66	q.minimumShouldMatch = fmt.Sprintf("%d", minimumNumberShouldMatch)
67	return q
68}
69
70func (q *BoolQuery) AdjustPureNegative(adjustPureNegative bool) *BoolQuery {
71	q.adjustPureNegative = &adjustPureNegative
72	return q
73}
74
75func (q *BoolQuery) QueryName(queryName string) *BoolQuery {
76	q.queryName = queryName
77	return q
78}
79
80// Creates the query source for the bool query.
81func (q *BoolQuery) Source() (interface{}, error) {
82	// {
83	//	"bool" : {
84	//		"must" : {
85	//			"term" : { "user" : "kimchy" }
86	//		},
87	//		"must_not" : {
88	//			"range" : {
89	//				"age" : { "from" : 10, "to" : 20 }
90	//			}
91	//		},
92	//    "filter" : [
93	//      ...
94	//    ]
95	//		"should" : [
96	//			{
97	//				"term" : { "tag" : "wow" }
98	//			},
99	//			{
100	//				"term" : { "tag" : "elasticsearch" }
101	//			}
102	//		],
103	//		"minimum_should_match" : 1,
104	//		"boost" : 1.0
105	//	}
106	// }
107
108	query := make(map[string]interface{})
109
110	boolClause := make(map[string]interface{})
111	query["bool"] = boolClause
112
113	// must
114	if len(q.mustClauses) == 1 {
115		src, err := q.mustClauses[0].Source()
116		if err != nil {
117			return nil, err
118		}
119		boolClause["must"] = src
120	} else if len(q.mustClauses) > 1 {
121		var clauses []interface{}
122		for _, subQuery := range q.mustClauses {
123			src, err := subQuery.Source()
124			if err != nil {
125				return nil, err
126			}
127			clauses = append(clauses, src)
128		}
129		boolClause["must"] = clauses
130	}
131
132	// must_not
133	if len(q.mustNotClauses) == 1 {
134		src, err := q.mustNotClauses[0].Source()
135		if err != nil {
136			return nil, err
137		}
138		boolClause["must_not"] = src
139	} else if len(q.mustNotClauses) > 1 {
140		var clauses []interface{}
141		for _, subQuery := range q.mustNotClauses {
142			src, err := subQuery.Source()
143			if err != nil {
144				return nil, err
145			}
146			clauses = append(clauses, src)
147		}
148		boolClause["must_not"] = clauses
149	}
150
151	// filter
152	if len(q.filterClauses) == 1 {
153		src, err := q.filterClauses[0].Source()
154		if err != nil {
155			return nil, err
156		}
157		boolClause["filter"] = src
158	} else if len(q.filterClauses) > 1 {
159		var clauses []interface{}
160		for _, subQuery := range q.filterClauses {
161			src, err := subQuery.Source()
162			if err != nil {
163				return nil, err
164			}
165			clauses = append(clauses, src)
166		}
167		boolClause["filter"] = clauses
168	}
169
170	// should
171	if len(q.shouldClauses) == 1 {
172		src, err := q.shouldClauses[0].Source()
173		if err != nil {
174			return nil, err
175		}
176		boolClause["should"] = src
177	} else if len(q.shouldClauses) > 1 {
178		var clauses []interface{}
179		for _, subQuery := range q.shouldClauses {
180			src, err := subQuery.Source()
181			if err != nil {
182				return nil, err
183			}
184			clauses = append(clauses, src)
185		}
186		boolClause["should"] = clauses
187	}
188
189	if q.boost != nil {
190		boolClause["boost"] = *q.boost
191	}
192	if q.minimumShouldMatch != "" {
193		boolClause["minimum_should_match"] = q.minimumShouldMatch
194	}
195	if q.adjustPureNegative != nil {
196		boolClause["adjust_pure_negative"] = *q.adjustPureNegative
197	}
198	if q.queryName != "" {
199		boolClause["_name"] = q.queryName
200	}
201
202	return query, nil
203}
204