1package sqlx 2 3// Named Query Support 4// 5// * BindMap - bind query bindvars to map/struct args 6// * NamedExec, NamedQuery - named query w/ struct or map 7// * NamedStmt - a pre-compiled named query which is a prepared statement 8// 9// Internal Interfaces: 10// 11// * compileNamedQuery - rebind a named query, returning a query and list of names 12// * bindArgs, bindMapArgs, bindAnyArgs - given a list of names, return an arglist 13// 14import ( 15 "bytes" 16 "database/sql" 17 "errors" 18 "fmt" 19 "reflect" 20 "regexp" 21 "strconv" 22 "unicode" 23 24 "github.com/jmoiron/sqlx/reflectx" 25) 26 27// NamedStmt is a prepared statement that executes named queries. Prepare it 28// how you would execute a NamedQuery, but pass in a struct or map when executing. 29type NamedStmt struct { 30 Params []string 31 QueryString string 32 Stmt *Stmt 33} 34 35// Close closes the named statement. 36func (n *NamedStmt) Close() error { 37 return n.Stmt.Close() 38} 39 40// Exec executes a named statement using the struct passed. 41// Any named placeholder parameters are replaced with fields from arg. 42func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) { 43 args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) 44 if err != nil { 45 return *new(sql.Result), err 46 } 47 return n.Stmt.Exec(args...) 48} 49 50// Query executes a named statement using the struct argument, returning rows. 51// Any named placeholder parameters are replaced with fields from arg. 52func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) { 53 args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) 54 if err != nil { 55 return nil, err 56 } 57 return n.Stmt.Query(args...) 58} 59 60// QueryRow executes a named statement against the database. Because sqlx cannot 61// create a *sql.Row with an error condition pre-set for binding errors, sqlx 62// returns a *sqlx.Row instead. 63// Any named placeholder parameters are replaced with fields from arg. 64func (n *NamedStmt) QueryRow(arg interface{}) *Row { 65 args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) 66 if err != nil { 67 return &Row{err: err} 68 } 69 return n.Stmt.QueryRowx(args...) 70} 71 72// MustExec execs a NamedStmt, panicing on error 73// Any named placeholder parameters are replaced with fields from arg. 74func (n *NamedStmt) MustExec(arg interface{}) sql.Result { 75 res, err := n.Exec(arg) 76 if err != nil { 77 panic(err) 78 } 79 return res 80} 81 82// Queryx using this NamedStmt 83// Any named placeholder parameters are replaced with fields from arg. 84func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) { 85 r, err := n.Query(arg) 86 if err != nil { 87 return nil, err 88 } 89 return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err 90} 91 92// QueryRowx this NamedStmt. Because of limitations with QueryRow, this is 93// an alias for QueryRow. 94// Any named placeholder parameters are replaced with fields from arg. 95func (n *NamedStmt) QueryRowx(arg interface{}) *Row { 96 return n.QueryRow(arg) 97} 98 99// Select using this NamedStmt 100// Any named placeholder parameters are replaced with fields from arg. 101func (n *NamedStmt) Select(dest interface{}, arg interface{}) error { 102 rows, err := n.Queryx(arg) 103 if err != nil { 104 return err 105 } 106 // if something happens here, we want to make sure the rows are Closed 107 defer rows.Close() 108 return scanAll(rows, dest, false) 109} 110 111// Get using this NamedStmt 112// Any named placeholder parameters are replaced with fields from arg. 113func (n *NamedStmt) Get(dest interface{}, arg interface{}) error { 114 r := n.QueryRowx(arg) 115 return r.scanAny(dest, false) 116} 117 118// Unsafe creates an unsafe version of the NamedStmt 119func (n *NamedStmt) Unsafe() *NamedStmt { 120 r := &NamedStmt{Params: n.Params, Stmt: n.Stmt, QueryString: n.QueryString} 121 r.Stmt.unsafe = true 122 return r 123} 124 125// A union interface of preparer and binder, required to be able to prepare 126// named statements (as the bindtype must be determined). 127type namedPreparer interface { 128 Preparer 129 binder 130} 131 132func prepareNamed(p namedPreparer, query string) (*NamedStmt, error) { 133 bindType := BindType(p.DriverName()) 134 q, args, err := compileNamedQuery([]byte(query), bindType) 135 if err != nil { 136 return nil, err 137 } 138 stmt, err := Preparex(p, q) 139 if err != nil { 140 return nil, err 141 } 142 return &NamedStmt{ 143 QueryString: q, 144 Params: args, 145 Stmt: stmt, 146 }, nil 147} 148 149// convertMapStringInterface attempts to convert v to map[string]interface{}. 150// Unlike v.(map[string]interface{}), this function works on named types that 151// are convertible to map[string]interface{} as well. 152func convertMapStringInterface(v interface{}) (map[string]interface{}, bool) { 153 var m map[string]interface{} 154 mtype := reflect.TypeOf(m) 155 t := reflect.TypeOf(v) 156 if !t.ConvertibleTo(mtype) { 157 return nil, false 158 } 159 return reflect.ValueOf(v).Convert(mtype).Interface().(map[string]interface{}), true 160 161} 162 163func bindAnyArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) { 164 if maparg, ok := convertMapStringInterface(arg); ok { 165 return bindMapArgs(names, maparg) 166 } 167 return bindArgs(names, arg, m) 168} 169 170// private interface to generate a list of interfaces from a given struct 171// type, given a list of names to pull out of the struct. Used by public 172// BindStruct interface. 173func bindArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) { 174 arglist := make([]interface{}, 0, len(names)) 175 176 // grab the indirected value of arg 177 v := reflect.ValueOf(arg) 178 for v = reflect.ValueOf(arg); v.Kind() == reflect.Ptr; { 179 v = v.Elem() 180 } 181 182 err := m.TraversalsByNameFunc(v.Type(), names, func(i int, t []int) error { 183 if len(t) == 0 { 184 return fmt.Errorf("could not find name %s in %#v", names[i], arg) 185 } 186 187 val := reflectx.FieldByIndexesReadOnly(v, t) 188 arglist = append(arglist, val.Interface()) 189 190 return nil 191 }) 192 193 return arglist, err 194} 195 196// like bindArgs, but for maps. 197func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, error) { 198 arglist := make([]interface{}, 0, len(names)) 199 200 for _, name := range names { 201 val, ok := arg[name] 202 if !ok { 203 return arglist, fmt.Errorf("could not find name %s in %#v", name, arg) 204 } 205 arglist = append(arglist, val) 206 } 207 return arglist, nil 208} 209 210// bindStruct binds a named parameter query with fields from a struct argument. 211// The rules for binding field names to parameter names follow the same 212// conventions as for StructScan, including obeying the `db` struct tags. 213func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { 214 bound, names, err := compileNamedQuery([]byte(query), bindType) 215 if err != nil { 216 return "", []interface{}{}, err 217 } 218 219 arglist, err := bindAnyArgs(names, arg, m) 220 if err != nil { 221 return "", []interface{}{}, err 222 } 223 224 return bound, arglist, nil 225} 226 227var valuesReg = regexp.MustCompile(`\)\s*(?i)VALUES\s*\(`) 228 229func findMatchingClosingBracketIndex(s string) int { 230 count := 0 231 for i, ch := range s { 232 if ch == '(' { 233 count++ 234 } 235 if ch == ')' { 236 count-- 237 if count == 0 { 238 return i 239 } 240 } 241 } 242 return 0 243} 244 245func fixBound(bound string, loop int) string { 246 loc := valuesReg.FindStringIndex(bound) 247 // defensive guard when "VALUES (...)" not found 248 if len(loc) < 2 { 249 return bound 250 } 251 252 openingBracketIndex := loc[1] - 1 253 index := findMatchingClosingBracketIndex(bound[openingBracketIndex:]) 254 // defensive guard. must have closing bracket 255 if index == 0 { 256 return bound 257 } 258 closingBracketIndex := openingBracketIndex + index + 1 259 260 var buffer bytes.Buffer 261 262 buffer.WriteString(bound[0:closingBracketIndex]) 263 for i := 0; i < loop-1; i++ { 264 buffer.WriteString(",") 265 buffer.WriteString(bound[openingBracketIndex:closingBracketIndex]) 266 } 267 buffer.WriteString(bound[closingBracketIndex:]) 268 return buffer.String() 269} 270 271// bindArray binds a named parameter query with fields from an array or slice of 272// structs argument. 273func bindArray(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { 274 // do the initial binding with QUESTION; if bindType is not question, 275 // we can rebind it at the end. 276 bound, names, err := compileNamedQuery([]byte(query), QUESTION) 277 if err != nil { 278 return "", []interface{}{}, err 279 } 280 arrayValue := reflect.ValueOf(arg) 281 arrayLen := arrayValue.Len() 282 if arrayLen == 0 { 283 return "", []interface{}{}, fmt.Errorf("length of array is 0: %#v", arg) 284 } 285 var arglist = make([]interface{}, 0, len(names)*arrayLen) 286 for i := 0; i < arrayLen; i++ { 287 elemArglist, err := bindAnyArgs(names, arrayValue.Index(i).Interface(), m) 288 if err != nil { 289 return "", []interface{}{}, err 290 } 291 arglist = append(arglist, elemArglist...) 292 } 293 if arrayLen > 1 { 294 bound = fixBound(bound, arrayLen) 295 } 296 // adjust binding type if we weren't on question 297 if bindType != QUESTION { 298 bound = Rebind(bindType, bound) 299 } 300 return bound, arglist, nil 301} 302 303// bindMap binds a named parameter query with a map of arguments. 304func bindMap(bindType int, query string, args map[string]interface{}) (string, []interface{}, error) { 305 bound, names, err := compileNamedQuery([]byte(query), bindType) 306 if err != nil { 307 return "", []interface{}{}, err 308 } 309 310 arglist, err := bindMapArgs(names, args) 311 return bound, arglist, err 312} 313 314// -- Compilation of Named Queries 315 316// Allow digits and letters in bind params; additionally runes are 317// checked against underscores, meaning that bind params can have be 318// alphanumeric with underscores. Mind the difference between unicode 319// digits and numbers, where '5' is a digit but '五' is not. 320var allowedBindRunes = []*unicode.RangeTable{unicode.Letter, unicode.Digit} 321 322// FIXME: this function isn't safe for unicode named params, as a failing test 323// can testify. This is not a regression but a failure of the original code 324// as well. It should be modified to range over runes in a string rather than 325// bytes, even though this is less convenient and slower. Hopefully the 326// addition of the prepared NamedStmt (which will only do this once) will make 327// up for the slightly slower ad-hoc NamedExec/NamedQuery. 328 329// compile a NamedQuery into an unbound query (using the '?' bindvar) and 330// a list of names. 331func compileNamedQuery(qs []byte, bindType int) (query string, names []string, err error) { 332 names = make([]string, 0, 10) 333 rebound := make([]byte, 0, len(qs)) 334 335 inName := false 336 last := len(qs) - 1 337 currentVar := 1 338 name := make([]byte, 0, 10) 339 340 for i, b := range qs { 341 // a ':' while we're in a name is an error 342 if b == ':' { 343 // if this is the second ':' in a '::' escape sequence, append a ':' 344 if inName && i > 0 && qs[i-1] == ':' { 345 rebound = append(rebound, ':') 346 inName = false 347 continue 348 } else if inName { 349 err = errors.New("unexpected `:` while reading named param at " + strconv.Itoa(i)) 350 return query, names, err 351 } 352 inName = true 353 name = []byte{} 354 } else if inName && i > 0 && b == '=' && len(name) == 0 { 355 rebound = append(rebound, ':', '=') 356 inName = false 357 continue 358 // if we're in a name, and this is an allowed character, continue 359 } else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_' || b == '.') && i != last { 360 // append the byte to the name if we are in a name and not on the last byte 361 name = append(name, b) 362 // if we're in a name and it's not an allowed character, the name is done 363 } else if inName { 364 inName = false 365 // if this is the final byte of the string and it is part of the name, then 366 // make sure to add it to the name 367 if i == last && unicode.IsOneOf(allowedBindRunes, rune(b)) { 368 name = append(name, b) 369 } 370 // add the string representation to the names list 371 names = append(names, string(name)) 372 // add a proper bindvar for the bindType 373 switch bindType { 374 // oracle only supports named type bind vars even for positional 375 case NAMED: 376 rebound = append(rebound, ':') 377 rebound = append(rebound, name...) 378 case QUESTION, UNKNOWN: 379 rebound = append(rebound, '?') 380 case DOLLAR: 381 rebound = append(rebound, '$') 382 for _, b := range strconv.Itoa(currentVar) { 383 rebound = append(rebound, byte(b)) 384 } 385 currentVar++ 386 case AT: 387 rebound = append(rebound, '@', 'p') 388 for _, b := range strconv.Itoa(currentVar) { 389 rebound = append(rebound, byte(b)) 390 } 391 currentVar++ 392 } 393 // add this byte to string unless it was not part of the name 394 if i != last { 395 rebound = append(rebound, b) 396 } else if !unicode.IsOneOf(allowedBindRunes, rune(b)) { 397 rebound = append(rebound, b) 398 } 399 } else { 400 // this is a normal byte and should just go onto the rebound query 401 rebound = append(rebound, b) 402 } 403 } 404 405 return string(rebound), names, err 406} 407 408// BindNamed binds a struct or a map to a query with named parameters. 409// DEPRECATED: use sqlx.Named` instead of this, it may be removed in future. 410func BindNamed(bindType int, query string, arg interface{}) (string, []interface{}, error) { 411 return bindNamedMapper(bindType, query, arg, mapper()) 412} 413 414// Named takes a query using named parameters and an argument and 415// returns a new query with a list of args that can be executed by 416// a database. The return value uses the `?` bindvar. 417func Named(query string, arg interface{}) (string, []interface{}, error) { 418 return bindNamedMapper(QUESTION, query, arg, mapper()) 419} 420 421func bindNamedMapper(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { 422 t := reflect.TypeOf(arg) 423 k := t.Kind() 424 switch { 425 case k == reflect.Map && t.Key().Kind() == reflect.String: 426 m, ok := convertMapStringInterface(arg) 427 if !ok { 428 return "", nil, fmt.Errorf("sqlx.bindNamedMapper: unsupported map type: %T", arg) 429 } 430 return bindMap(bindType, query, m) 431 case k == reflect.Array || k == reflect.Slice: 432 return bindArray(bindType, query, arg, m) 433 default: 434 return bindStruct(bindType, query, arg, m) 435 } 436} 437 438// NamedQuery binds a named query and then runs Query on the result using the 439// provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with 440// map[string]interface{} types. 441func NamedQuery(e Ext, query string, arg interface{}) (*Rows, error) { 442 q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) 443 if err != nil { 444 return nil, err 445 } 446 return e.Queryx(q, args...) 447} 448 449// NamedExec uses BindStruct to get a query executable by the driver and 450// then runs Exec on the result. Returns an error from the binding 451// or the query execution itself. 452func NamedExec(e Ext, query string, arg interface{}) (sql.Result, error) { 453 q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) 454 if err != nil { 455 return nil, err 456 } 457 return e.Exec(q, args...) 458} 459