1package sqlbuilder
2
3import (
4	"database/sql/driver"
5	"fmt"
6	"strings"
7
8	db "upper.io/db.v3"
9	"upper.io/db.v3/internal/sqladapter/exql"
10)
11
12type templateWithUtils struct {
13	*exql.Template
14}
15
16func newTemplateWithUtils(template *exql.Template) *templateWithUtils {
17	return &templateWithUtils{template}
18}
19
20func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, []interface{}) {
21	switch t := in.(type) {
22	case db.RawValue:
23		return exql.RawValue(t.String()), t.Arguments()
24	case db.Function:
25		fnName := t.Name()
26		fnArgs := []interface{}{}
27		args, _ := toInterfaceArguments(t.Arguments())
28		fragments := []string{}
29		for i := range args {
30			frag, args := tu.PlaceholderValue(args[i])
31			fragment, err := frag.Compile(tu.Template)
32			if err == nil {
33				fragments = append(fragments, fragment)
34				fnArgs = append(fnArgs, args...)
35			}
36		}
37		return exql.RawValue(fnName + `(` + strings.Join(fragments, `, `) + `)`), fnArgs
38	default:
39		// Value must be escaped.
40		return sqlPlaceholder, []interface{}{in}
41	}
42}
43
44// toWhereWithArguments converts the given parameters into a exql.Where
45// value.
46func (tu *templateWithUtils) toWhereWithArguments(term interface{}) (where exql.Where, args []interface{}) {
47	args = []interface{}{}
48
49	switch t := term.(type) {
50	case []interface{}:
51		if len(t) > 0 {
52			if s, ok := t[0].(string); ok {
53				if strings.ContainsAny(s, "?") || len(t) == 1 {
54					s, args = Preprocess(s, t[1:])
55					where.Conditions = []exql.Fragment{exql.RawValue(s)}
56				} else {
57					var val interface{}
58					key := s
59					if len(t) > 2 {
60						val = t[1:]
61					} else {
62						val = t[1]
63					}
64					cv, v := tu.toColumnValues(db.NewConstraint(key, val))
65					args = append(args, v...)
66					for i := range cv.ColumnValues {
67						where.Conditions = append(where.Conditions, cv.ColumnValues[i])
68					}
69				}
70				return
71			}
72		}
73		for i := range t {
74			w, v := tu.toWhereWithArguments(t[i])
75			if len(w.Conditions) == 0 {
76				continue
77			}
78			args = append(args, v...)
79			where.Conditions = append(where.Conditions, w.Conditions...)
80		}
81		return
82	case db.RawValue:
83		r, v := Preprocess(t.Raw(), t.Arguments())
84		where.Conditions = []exql.Fragment{exql.RawValue(r)}
85		args = append(args, v...)
86		return
87	case db.Constraints:
88		for _, c := range t.Constraints() {
89			w, v := tu.toWhereWithArguments(c)
90			if len(w.Conditions) == 0 {
91				continue
92			}
93			args = append(args, v...)
94			where.Conditions = append(where.Conditions, w.Conditions...)
95		}
96		return
97	case db.Compound:
98		var cond exql.Where
99
100		for _, c := range t.Sentences() {
101			w, v := tu.toWhereWithArguments(c)
102			if len(w.Conditions) == 0 {
103				continue
104			}
105			args = append(args, v...)
106			cond.Conditions = append(cond.Conditions, w.Conditions...)
107		}
108
109		if len(cond.Conditions) > 0 {
110			var frag exql.Fragment
111			switch t.Operator() {
112			case db.OperatorNone, db.OperatorAnd:
113				q := exql.And(cond)
114				frag = &q
115			case db.OperatorOr:
116				q := exql.Or(cond)
117				frag = &q
118			default:
119				panic(fmt.Sprintf("Unknown type %T", t))
120			}
121			where.Conditions = append(where.Conditions, frag)
122		}
123
124		return
125	case db.Constraint:
126		cv, v := tu.toColumnValues(t)
127		args = append(args, v...)
128		where.Conditions = append(where.Conditions, cv.ColumnValues...)
129		return where, args
130	}
131
132	panic(fmt.Sprintf("Unknown condition type %T", term))
133}
134
135func (tu *templateWithUtils) comparisonOperatorMapper(t db.ComparisonOperator) string {
136	if t == db.ComparisonOperatorNone {
137		return ""
138	}
139	if tu.ComparisonOperator != nil {
140		if op, ok := tu.ComparisonOperator[t]; ok {
141			return op
142		}
143	}
144	if op, ok := comparisonOperators[t]; ok {
145		return op
146	}
147	panic(fmt.Sprintf("unsupported comparison operator %v", t))
148}
149
150func (tu *templateWithUtils) toColumnValues(term interface{}) (cv exql.ColumnValues, args []interface{}) {
151	args = []interface{}{}
152
153	switch t := term.(type) {
154	case db.Constraint:
155		columnValue := exql.ColumnValue{}
156
157		// Getting column and operator.
158		if column, ok := t.Key().(string); ok {
159			chunks := strings.SplitN(strings.TrimSpace(column), " ", 2)
160			columnValue.Column = exql.ColumnWithName(chunks[0])
161			if len(chunks) > 1 {
162				columnValue.Operator = chunks[1]
163			}
164		} else {
165			if rawValue, ok := t.Key().(db.RawValue); ok {
166				columnValue.Column = exql.RawValue(rawValue.Raw())
167				args = append(args, rawValue.Arguments()...)
168			} else {
169				columnValue.Column = exql.RawValue(fmt.Sprintf("%v", t.Key()))
170			}
171		}
172
173		switch value := t.Value().(type) {
174		case db.Function:
175			fnName, fnArgs := value.Name(), value.Arguments()
176			if len(fnArgs) == 0 {
177				// A function with no arguments.
178				fnName = fnName + "()"
179			} else {
180				// A function with one or more arguments.
181				fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")"
182			}
183			fnName, fnArgs = Preprocess(fnName, fnArgs)
184			columnValue.Value = exql.RawValue(fnName)
185			args = append(args, fnArgs...)
186		case db.RawValue:
187			q, a := Preprocess(value.Raw(), value.Arguments())
188			columnValue.Value = exql.RawValue(q)
189			args = append(args, a...)
190		case driver.Valuer:
191			columnValue.Value = exql.RawValue("?")
192			args = append(args, value)
193		case db.Comparison:
194			wrapper := &operatorWrapper{
195				tu: tu,
196				cv: &columnValue,
197				op: value,
198			}
199
200			q, a := wrapper.preprocess()
201			q, a = Preprocess(q, a)
202
203			columnValue = exql.ColumnValue{
204				Column: exql.RawValue(q),
205			}
206			if a != nil {
207				args = append(args, a...)
208			}
209
210			cv.ColumnValues = append(cv.ColumnValues, &columnValue)
211			return cv, args
212		default:
213			wrapper := &operatorWrapper{
214				tu: tu,
215				cv: &columnValue,
216				v:  value,
217			}
218
219			q, a := wrapper.preprocess()
220			q, a = Preprocess(q, a)
221
222			columnValue = exql.ColumnValue{
223				Column: exql.RawValue(q),
224			}
225			if a != nil {
226				args = append(args, a...)
227			}
228
229			cv.ColumnValues = append(cv.ColumnValues, &columnValue)
230			return cv, args
231		}
232
233		if columnValue.Operator == "" {
234			columnValue.Operator = tu.comparisonOperatorMapper(db.ComparisonOperatorEqual)
235		}
236		cv.ColumnValues = append(cv.ColumnValues, &columnValue)
237		return cv, args
238	case db.RawValue:
239		columnValue := exql.ColumnValue{}
240		p, q := Preprocess(t.Raw(), t.Arguments())
241		columnValue.Column = exql.RawValue(p)
242		cv.ColumnValues = append(cv.ColumnValues, &columnValue)
243		args = append(args, q...)
244		return cv, args
245	case db.Constraints:
246		for _, constraint := range t.Constraints() {
247			p, q := tu.toColumnValues(constraint)
248			cv.ColumnValues = append(cv.ColumnValues, p.ColumnValues...)
249			args = append(args, q...)
250		}
251		return cv, args
252	}
253
254	panic(fmt.Sprintf("Unknown term type %T.", term))
255}
256
257func (tu *templateWithUtils) setColumnValues(term interface{}) (cv exql.ColumnValues, args []interface{}) {
258	args = []interface{}{}
259
260	switch t := term.(type) {
261	case []interface{}:
262		l := len(t)
263		for i := 0; i < l; i++ {
264			column, isString := t[i].(string)
265
266			if !isString {
267				p, q := tu.setColumnValues(t[i])
268				cv.ColumnValues = append(cv.ColumnValues, p.ColumnValues...)
269				args = append(args, q...)
270				continue
271			}
272
273			if !strings.ContainsAny(column, tu.AssignmentOperator) {
274				column = column + " " + tu.AssignmentOperator + " ?"
275			}
276
277			chunks := strings.SplitN(column, tu.AssignmentOperator, 2)
278
279			column = chunks[0]
280			format := strings.TrimSpace(chunks[1])
281
282			columnValue := exql.ColumnValue{
283				Column:   exql.ColumnWithName(column),
284				Operator: tu.AssignmentOperator,
285				Value:    exql.RawValue(format),
286			}
287
288			ps := strings.Count(format, "?")
289			if i+ps < l {
290				for j := 0; j < ps; j++ {
291					args = append(args, t[i+j+1])
292				}
293				i = i + ps
294			} else {
295				panic(fmt.Sprintf("Format string %q has more placeholders than given arguments.", format))
296			}
297
298			cv.ColumnValues = append(cv.ColumnValues, &columnValue)
299		}
300		return cv, args
301	case db.RawValue:
302		columnValue := exql.ColumnValue{}
303		p, q := Preprocess(t.Raw(), t.Arguments())
304		columnValue.Column = exql.RawValue(p)
305		cv.ColumnValues = append(cv.ColumnValues, &columnValue)
306		args = append(args, q...)
307		return cv, args
308	}
309
310	panic(fmt.Sprintf("Unknown term type %T.", term))
311}
312