1package sqlx
2
3// Named Query Support
4//
//  * BindNamed, Named - bind query bindvars to map/struct args
//  * NamedExec, NamedQuery - named query w/ struct or map
//  * NamedStmt - a pre-compiled named query which is a prepared statement
8//
9// Internal Interfaces:
10//
11//  * compileNamedQuery - rebind a named query, returning a query and list of names
12//  * bindArgs, bindMapArgs, bindAnyArgs - given a list of names, return an arglist
13//
14import (
15	"bytes"
16	"database/sql"
17	"errors"
18	"fmt"
19	"reflect"
20	"regexp"
21	"strconv"
22	"unicode"
23
24	"github.com/jmoiron/sqlx/reflectx"
25)
26
// NamedStmt is a prepared statement that executes named queries.  Prepare it
// how you would execute a NamedQuery, but pass in a struct or map when executing.
type NamedStmt struct {
	// Params lists the named parameters in the order they occur in the query;
	// argument lists passed to Stmt are built in this same order.
	Params      []string
	// QueryString is the compiled query, with named parameters replaced by
	// the bindvar style of the driver the statement was prepared for.
	QueryString string
	// Stmt is the prepared statement for QueryString.
	Stmt        *Stmt
}
34
35// Close closes the named statement.
36func (n *NamedStmt) Close() error {
37	return n.Stmt.Close()
38}
39
40// Exec executes a named statement using the struct passed.
41// Any named placeholder parameters are replaced with fields from arg.
42func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) {
43	args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
44	if err != nil {
45		return *new(sql.Result), err
46	}
47	return n.Stmt.Exec(args...)
48}
49
50// Query executes a named statement using the struct argument, returning rows.
51// Any named placeholder parameters are replaced with fields from arg.
52func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) {
53	args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
54	if err != nil {
55		return nil, err
56	}
57	return n.Stmt.Query(args...)
58}
59
60// QueryRow executes a named statement against the database.  Because sqlx cannot
61// create a *sql.Row with an error condition pre-set for binding errors, sqlx
62// returns a *sqlx.Row instead.
63// Any named placeholder parameters are replaced with fields from arg.
64func (n *NamedStmt) QueryRow(arg interface{}) *Row {
65	args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
66	if err != nil {
67		return &Row{err: err}
68	}
69	return n.Stmt.QueryRowx(args...)
70}
71
72// MustExec execs a NamedStmt, panicing on error
73// Any named placeholder parameters are replaced with fields from arg.
74func (n *NamedStmt) MustExec(arg interface{}) sql.Result {
75	res, err := n.Exec(arg)
76	if err != nil {
77		panic(err)
78	}
79	return res
80}
81
82// Queryx using this NamedStmt
83// Any named placeholder parameters are replaced with fields from arg.
84func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) {
85	r, err := n.Query(arg)
86	if err != nil {
87		return nil, err
88	}
89	return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err
90}
91
92// QueryRowx this NamedStmt.  Because of limitations with QueryRow, this is
93// an alias for QueryRow.
94// Any named placeholder parameters are replaced with fields from arg.
95func (n *NamedStmt) QueryRowx(arg interface{}) *Row {
96	return n.QueryRow(arg)
97}
98
99// Select using this NamedStmt
100// Any named placeholder parameters are replaced with fields from arg.
101func (n *NamedStmt) Select(dest interface{}, arg interface{}) error {
102	rows, err := n.Queryx(arg)
103	if err != nil {
104		return err
105	}
106	// if something happens here, we want to make sure the rows are Closed
107	defer rows.Close()
108	return scanAll(rows, dest, false)
109}
110
111// Get using this NamedStmt
112// Any named placeholder parameters are replaced with fields from arg.
113func (n *NamedStmt) Get(dest interface{}, arg interface{}) error {
114	r := n.QueryRowx(arg)
115	return r.scanAny(dest, false)
116}
117
118// Unsafe creates an unsafe version of the NamedStmt
119func (n *NamedStmt) Unsafe() *NamedStmt {
120	r := &NamedStmt{Params: n.Params, Stmt: n.Stmt, QueryString: n.QueryString}
121	r.Stmt.unsafe = true
122	return r
123}
124
125// A union interface of preparer and binder, required to be able to prepare
126// named statements (as the bindtype must be determined).
127type namedPreparer interface {
128	Preparer
129	binder
130}
131
132func prepareNamed(p namedPreparer, query string) (*NamedStmt, error) {
133	bindType := BindType(p.DriverName())
134	q, args, err := compileNamedQuery([]byte(query), bindType)
135	if err != nil {
136		return nil, err
137	}
138	stmt, err := Preparex(p, q)
139	if err != nil {
140		return nil, err
141	}
142	return &NamedStmt{
143		QueryString: q,
144		Params:      args,
145		Stmt:        stmt,
146	}, nil
147}
148
// convertMapStringInterface attempts to convert v to map[string]interface{}.
// Unlike v.(map[string]interface{}), this function works on named types that
// are convertible to map[string]interface{} as well.
//
// It reports false for a nil v (no dynamic type) and for any type not
// convertible to map[string]interface{}.
func convertMapStringInterface(v interface{}) (map[string]interface{}, bool) {
	// Fast path: already the exact type, no reflection required.
	if m, ok := v.(map[string]interface{}); ok {
		return m, true
	}
	t := reflect.TypeOf(v)
	// reflect.TypeOf(nil) is a nil Type; calling ConvertibleTo on it panics.
	if t == nil {
		return nil, false
	}
	mtype := reflect.TypeOf(map[string]interface{}(nil))
	if !t.ConvertibleTo(mtype) {
		return nil, false
	}
	return reflect.ValueOf(v).Convert(mtype).Interface().(map[string]interface{}), true
}
162
163func bindAnyArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) {
164	if maparg, ok := convertMapStringInterface(arg); ok {
165		return bindMapArgs(names, maparg)
166	}
167	return bindArgs(names, arg, m)
168}
169
170// private interface to generate a list of interfaces from a given struct
171// type, given a list of names to pull out of the struct.  Used by public
172// BindStruct interface.
173func bindArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) {
174	arglist := make([]interface{}, 0, len(names))
175
176	// grab the indirected value of arg
177	v := reflect.ValueOf(arg)
178	for v = reflect.ValueOf(arg); v.Kind() == reflect.Ptr; {
179		v = v.Elem()
180	}
181
182	err := m.TraversalsByNameFunc(v.Type(), names, func(i int, t []int) error {
183		if len(t) == 0 {
184			return fmt.Errorf("could not find name %s in %#v", names[i], arg)
185		}
186
187		val := reflectx.FieldByIndexesReadOnly(v, t)
188		arglist = append(arglist, val.Interface())
189
190		return nil
191	})
192
193	return arglist, err
194}
195
// bindMapArgs is the map counterpart of bindArgs: it returns arg[name] for
// each entry of names, in order. On the first missing key it stops and
// returns the values gathered so far together with an error.
func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, error) {
	out := make([]interface{}, 0, len(names))
	for idx := range names {
		key := names[idx]
		value, found := arg[key]
		if !found {
			return out, fmt.Errorf("could not find name %s in %#v", key, arg)
		}
		out = append(out, value)
	}
	return out, nil
}
209
210// bindStruct binds a named parameter query with fields from a struct argument.
211// The rules for binding field names to parameter names follow the same
212// conventions as for StructScan, including obeying the `db` struct tags.
213func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) {
214	bound, names, err := compileNamedQuery([]byte(query), bindType)
215	if err != nil {
216		return "", []interface{}{}, err
217	}
218
219	arglist, err := bindAnyArgs(names, arg, m)
220	if err != nil {
221		return "", []interface{}{}, err
222	}
223
224	return bound, arglist, nil
225}
226
// valuesReg finds where the VALUES tuple of an INSERT begins: a ')' (end of
// the column list), optional whitespace, the keyword VALUES matched
// case-insensitively, optional whitespace, and the opening '('.
// NOTE(review): because a ')' is required before VALUES, an INSERT without an
// explicit column list ("INSERT INTO t VALUES (...)") does not match, and
// fixBound then returns the query unexpanded; confirm this is intended.
var valuesReg = regexp.MustCompile(`\)\s*(?i)VALUES\s*\(`)
228
// findMatchingClosingBracketIndex returns the byte index in s of the ')'
// that brings the parenthesis depth back to zero, which is the partner of
// the '(' at s[0] when s starts with one. It returns 0 if no such ')'
// exists. Parentheses inside quoted strings are not treated specially.
func findMatchingClosingBracketIndex(s string) int {
	depth := 0
	for pos := 0; pos < len(s); pos++ {
		// '(' and ')' are ASCII, so scanning bytes finds the same positions
		// as scanning runes would.
		switch s[pos] {
		case '(':
			depth++
		case ')':
			depth--
			if depth == 0 {
				return pos
			}
		}
	}
	return 0
}
244
245func fixBound(bound string, loop int) string {
246	loc := valuesReg.FindStringIndex(bound)
247	// defensive guard when "VALUES (...)" not found
248	if len(loc) < 2 {
249		return bound
250	}
251
252	openingBracketIndex := loc[1] - 1
253	index := findMatchingClosingBracketIndex(bound[openingBracketIndex:])
254	// defensive guard. must have closing bracket
255	if index == 0 {
256		return bound
257	}
258	closingBracketIndex := openingBracketIndex + index + 1
259
260	var buffer bytes.Buffer
261
262	buffer.WriteString(bound[0:closingBracketIndex])
263	for i := 0; i < loop-1; i++ {
264		buffer.WriteString(",")
265		buffer.WriteString(bound[openingBracketIndex:closingBracketIndex])
266	}
267	buffer.WriteString(bound[closingBracketIndex:])
268	return buffer.String()
269}
270
271// bindArray binds a named parameter query with fields from an array or slice of
272// structs argument.
273func bindArray(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) {
274	// do the initial binding with QUESTION;  if bindType is not question,
275	// we can rebind it at the end.
276	bound, names, err := compileNamedQuery([]byte(query), QUESTION)
277	if err != nil {
278		return "", []interface{}{}, err
279	}
280	arrayValue := reflect.ValueOf(arg)
281	arrayLen := arrayValue.Len()
282	if arrayLen == 0 {
283		return "", []interface{}{}, fmt.Errorf("length of array is 0: %#v", arg)
284	}
285	var arglist = make([]interface{}, 0, len(names)*arrayLen)
286	for i := 0; i < arrayLen; i++ {
287		elemArglist, err := bindAnyArgs(names, arrayValue.Index(i).Interface(), m)
288		if err != nil {
289			return "", []interface{}{}, err
290		}
291		arglist = append(arglist, elemArglist...)
292	}
293	if arrayLen > 1 {
294		bound = fixBound(bound, arrayLen)
295	}
296	// adjust binding type if we weren't on question
297	if bindType != QUESTION {
298		bound = Rebind(bindType, bound)
299	}
300	return bound, arglist, nil
301}
302
303// bindMap binds a named parameter query with a map of arguments.
304func bindMap(bindType int, query string, args map[string]interface{}) (string, []interface{}, error) {
305	bound, names, err := compileNamedQuery([]byte(query), bindType)
306	if err != nil {
307		return "", []interface{}{}, err
308	}
309
310	arglist, err := bindMapArgs(names, args)
311	return bound, arglist, err
312}
313
314// -- Compilation of Named Queries
315
// allowedBindRunes lists the Unicode classes (letters and digits) accepted in
// bind parameter names. The name scanner additionally accepts '_' and '.', so
// bind params can be alphanumeric with underscores and dots.  Mind the
// difference between unicode digits and numbers, where '5' is a digit but
// '五' is not.
var allowedBindRunes = []*unicode.RangeTable{unicode.Letter, unicode.Digit}
321
322// FIXME: this function isn't safe for unicode named params, as a failing test
323// can testify.  This is not a regression but a failure of the original code
324// as well.  It should be modified to range over runes in a string rather than
325// bytes, even though this is less convenient and slower.  Hopefully the
326// addition of the prepared NamedStmt (which will only do this once) will make
327// up for the slightly slower ad-hoc NamedExec/NamedQuery.
328
329// compile a NamedQuery into an unbound query (using the '?' bindvar) and
330// a list of names.
331func compileNamedQuery(qs []byte, bindType int) (query string, names []string, err error) {
332	names = make([]string, 0, 10)
333	rebound := make([]byte, 0, len(qs))
334
335	inName := false
336	last := len(qs) - 1
337	currentVar := 1
338	name := make([]byte, 0, 10)
339
340	for i, b := range qs {
341		// a ':' while we're in a name is an error
342		if b == ':' {
343			// if this is the second ':' in a '::' escape sequence, append a ':'
344			if inName && i > 0 && qs[i-1] == ':' {
345				rebound = append(rebound, ':')
346				inName = false
347				continue
348			} else if inName {
349				err = errors.New("unexpected `:` while reading named param at " + strconv.Itoa(i))
350				return query, names, err
351			}
352			inName = true
353			name = []byte{}
354		} else if inName && i > 0 && b == '=' && len(name) == 0 {
355			rebound = append(rebound, ':', '=')
356			inName = false
357			continue
358			// if we're in a name, and this is an allowed character, continue
359		} else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_' || b == '.') && i != last {
360			// append the byte to the name if we are in a name and not on the last byte
361			name = append(name, b)
362			// if we're in a name and it's not an allowed character, the name is done
363		} else if inName {
364			inName = false
365			// if this is the final byte of the string and it is part of the name, then
366			// make sure to add it to the name
367			if i == last && unicode.IsOneOf(allowedBindRunes, rune(b)) {
368				name = append(name, b)
369			}
370			// add the string representation to the names list
371			names = append(names, string(name))
372			// add a proper bindvar for the bindType
373			switch bindType {
374			// oracle only supports named type bind vars even for positional
375			case NAMED:
376				rebound = append(rebound, ':')
377				rebound = append(rebound, name...)
378			case QUESTION, UNKNOWN:
379				rebound = append(rebound, '?')
380			case DOLLAR:
381				rebound = append(rebound, '$')
382				for _, b := range strconv.Itoa(currentVar) {
383					rebound = append(rebound, byte(b))
384				}
385				currentVar++
386			case AT:
387				rebound = append(rebound, '@', 'p')
388				for _, b := range strconv.Itoa(currentVar) {
389					rebound = append(rebound, byte(b))
390				}
391				currentVar++
392			}
393			// add this byte to string unless it was not part of the name
394			if i != last {
395				rebound = append(rebound, b)
396			} else if !unicode.IsOneOf(allowedBindRunes, rune(b)) {
397				rebound = append(rebound, b)
398			}
399		} else {
400			// this is a normal byte and should just go onto the rebound query
401			rebound = append(rebound, b)
402		}
403	}
404
405	return string(rebound), names, err
406}
407
408// BindNamed binds a struct or a map to a query with named parameters.
409// DEPRECATED: use sqlx.Named` instead of this, it may be removed in future.
410func BindNamed(bindType int, query string, arg interface{}) (string, []interface{}, error) {
411	return bindNamedMapper(bindType, query, arg, mapper())
412}
413
414// Named takes a query using named parameters and an argument and
415// returns a new query with a list of args that can be executed by
416// a database.  The return value uses the `?` bindvar.
417func Named(query string, arg interface{}) (string, []interface{}, error) {
418	return bindNamedMapper(QUESTION, query, arg, mapper())
419}
420
421func bindNamedMapper(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) {
422	t := reflect.TypeOf(arg)
423	k := t.Kind()
424	switch {
425	case k == reflect.Map && t.Key().Kind() == reflect.String:
426		m, ok := convertMapStringInterface(arg)
427		if !ok {
428			return "", nil, fmt.Errorf("sqlx.bindNamedMapper: unsupported map type: %T", arg)
429		}
430		return bindMap(bindType, query, m)
431	case k == reflect.Array || k == reflect.Slice:
432		return bindArray(bindType, query, arg, m)
433	default:
434		return bindStruct(bindType, query, arg, m)
435	}
436}
437
438// NamedQuery binds a named query and then runs Query on the result using the
439// provided Ext (sqlx.Tx, sqlx.Db).  It works with both structs and with
440// map[string]interface{} types.
441func NamedQuery(e Ext, query string, arg interface{}) (*Rows, error) {
442	q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
443	if err != nil {
444		return nil, err
445	}
446	return e.Queryx(q, args...)
447}
448
449// NamedExec uses BindStruct to get a query executable by the driver and
450// then runs Exec on the result.  Returns an error from the binding
451// or the query execution itself.
452func NamedExec(e Ext, query string, arg interface{}) (sql.Result, error) {
453	q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
454	if err != nil {
455		return nil, err
456	}
457	return e.Exec(q, args...)
458}
459