1// Copyright 2012 James Cooper. All rights reserved. 2// Use of this source code is governed by a MIT-style 3// license that can be found in the LICENSE file. 4 5// Package gorp provides a simple way to marshal Go structs to and from 6// SQL databases. It uses the database/sql package, and should work with any 7// compliant database/sql driver. 8// 9// Source code and project home: 10// https://github.com/go-gorp/gorp 11// 12package gorp 13 14import ( 15 "context" 16 "database/sql" 17 "database/sql/driver" 18 "fmt" 19 "reflect" 20 "regexp" 21 "strings" 22 "time" 23) 24 25// OracleString (empty string is null) 26// TODO: move to dialect/oracle?, rename to String? 27type OracleString struct { 28 sql.NullString 29} 30 31// Scan implements the Scanner interface. 32func (os *OracleString) Scan(value interface{}) error { 33 if value == nil { 34 os.String, os.Valid = "", false 35 return nil 36 } 37 os.Valid = true 38 return os.NullString.Scan(value) 39} 40 41// Value implements the driver Valuer interface. 42func (os OracleString) Value() (driver.Value, error) { 43 if !os.Valid || os.String == "" { 44 return nil, nil 45 } 46 return os.String, nil 47} 48 49// SqlTyper is a type that returns its database type. Most of the 50// time, the type can just use "database/sql/driver".Valuer; but when 51// it returns nil for its empty value, it needs to implement SqlTyper 52// to have its column type detected properly during table creation. 53type SqlTyper interface { 54 SqlType() driver.Valuer 55} 56 57// for fields that exists in DB table, but not exists in struct 58type dummyField struct{} 59 60// Scan implements the Scanner interface. 61func (nt *dummyField) Scan(value interface{}) error { 62 return nil 63} 64 65var zeroVal reflect.Value 66var versFieldConst = "[gorp_ver_field]" 67 68// The TypeConverter interface provides a way to map a value of one 69// type to another type when persisting to, or loading from, a database. 
70// 71// Example use cases: Implement type converter to convert bool types to "y"/"n" strings, 72// or serialize a struct member as a JSON blob. 73type TypeConverter interface { 74 // ToDb converts val to another type. Called before INSERT/UPDATE operations 75 ToDb(val interface{}) (interface{}, error) 76 77 // FromDb returns a CustomScanner appropriate for this type. This will be used 78 // to hold values returned from SELECT queries. 79 // 80 // In particular the CustomScanner returned should implement a Binder 81 // function appropriate for the Go type you wish to convert the db value to 82 // 83 // If bool==false, then no custom scanner will be used for this field. 84 FromDb(target interface{}) (CustomScanner, bool) 85} 86 87// Executor exposes the sql.DB and sql.Tx Exec function so that it can be used 88// on internal functions that convert named parameters for the Exec function. 89type executor interface { 90 Exec(query string, args ...interface{}) (sql.Result, error) 91 ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) 92} 93 94// SqlExecutor exposes gorp operations that can be run from Pre/Post 95// hooks. This hides whether the current operation that triggered the 96// hook is in a transaction. 97// 98// See the DbMap function docs for each of the functions below for more 99// information. 
type SqlExecutor interface {
	// Struct-level operations.
	Get(i interface{}, keys ...interface{}) (interface{}, error)
	Insert(list ...interface{}) error
	Update(list ...interface{}) (int64, error)
	Delete(list ...interface{}) (int64, error)

	// Raw statement execution.
	Exec(query string, args ...interface{}) (sql.Result, error)
	ExecNoTimeout(query string, args ...interface{}) (sql.Result, error)

	// SELECT helpers.
	Select(i interface{}, query string, args ...interface{}) ([]interface{}, error)
	SelectOne(holder interface{}, query string, args ...interface{}) error
	SelectInt(query string, args ...interface{}) (int64, error)
	SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error)
	SelectFloat(query string, args ...interface{}) (float64, error)
	SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error)
	SelectStr(query string, args ...interface{}) (string, error)
	SelectNullStr(query string, args ...interface{}) (sql.NullString, error)

	// Pass-throughs mirroring the database/sql query methods.
	Query(query string, args ...interface{}) (*sql.Rows, error)
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	QueryRow(query string, args ...interface{}) *sql.Row
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}

// DynamicTable is implemented by types whose database table name is
// chosen at runtime. This lets one Go struct back several tables, e.g.
// sharded or per-period tables.
type DynamicTable interface {
	// TableName reports the table name currently in effect.
	TableName() string
	// SetTableName changes the table name the value maps to.
	SetTableName(name string)
}

// Compile-time check that DbMap and Transaction implement the SqlExecutor
// interface.
132var _, _ SqlExecutor = &DbMap{}, &Transaction{} 133 134func argsString(args ...interface{}) string { 135 var margs string 136 for i, a := range args { 137 var v interface{} = a 138 if x, ok := v.(driver.Valuer); ok { 139 y, err := x.Value() 140 if err == nil { 141 v = y 142 } 143 } 144 switch v.(type) { 145 case string: 146 v = fmt.Sprintf("%q", v) 147 default: 148 v = fmt.Sprintf("%v", v) 149 } 150 margs += fmt.Sprintf("%d:%s", i+1, v) 151 if i+1 < len(args) { 152 margs += " " 153 } 154 } 155 return margs 156} 157 158// Calls the Exec function on the executor, but attempts to expand any eligible named 159// query arguments first. 160func exec(e SqlExecutor, query string, doTimeout bool, args ...interface{}) (sql.Result, error) { 161 var dbMap *DbMap 162 var executor executor 163 switch m := e.(type) { 164 case *DbMap: 165 executor = m.Db 166 dbMap = m 167 case *Transaction: 168 executor = m.tx 169 dbMap = m.dbmap 170 } 171 172 if len(args) == 1 { 173 query, args = maybeExpandNamedQuery(dbMap, query, args) 174 } 175 176 if doTimeout { 177 ctx, cancel := context.WithTimeout(context.Background(), dbMap.QueryTimeout) 178 defer cancel() 179 return executor.ExecContext(ctx, query, args...) 180 } 181 182 return executor.Exec(query, args...) 183} 184 185// maybeExpandNamedQuery checks the given arg to see if it's eligible to be used 186// as input to a named query. If so, it rewrites the query to use 187// dialect-dependent bindvars and instantiates the corresponding slice of 188// parameters by extracting data from the map / struct. 189// If not, returns the input values unchanged. 
190func maybeExpandNamedQuery(m *DbMap, query string, args []interface{}) (string, []interface{}) { 191 var ( 192 arg = args[0] 193 argval = reflect.ValueOf(arg) 194 ) 195 if argval.Kind() == reflect.Ptr { 196 argval = argval.Elem() 197 } 198 199 if argval.Kind() == reflect.Map && argval.Type().Key().Kind() == reflect.String { 200 return expandNamedQuery(m, query, func(key string) reflect.Value { 201 return argval.MapIndex(reflect.ValueOf(key)) 202 }) 203 } 204 if argval.Kind() != reflect.Struct { 205 return query, args 206 } 207 if _, ok := arg.(time.Time); ok { 208 // time.Time is driver.Value 209 return query, args 210 } 211 if _, ok := arg.(driver.Valuer); ok { 212 // driver.Valuer will be converted to driver.Value. 213 return query, args 214 } 215 216 return expandNamedQuery(m, query, argval.FieldByName) 217} 218 219var keyRegexp = regexp.MustCompile(`:[[:word:]]+`) 220 221// expandNamedQuery accepts a query with placeholders of the form ":key", and a 222// single arg of Kind Struct or Map[string]. It returns the query with the 223// dialect's placeholders, and a slice of args ready for positional insertion 224// into the query. 
225func expandNamedQuery(m *DbMap, query string, keyGetter func(key string) reflect.Value) (string, []interface{}) { 226 var ( 227 n int 228 args []interface{} 229 ) 230 return keyRegexp.ReplaceAllStringFunc(query, func(key string) string { 231 val := keyGetter(key[1:]) 232 if !val.IsValid() { 233 return key 234 } 235 args = append(args, val.Interface()) 236 newVar := m.Dialect.BindVar(n) 237 n++ 238 return newVar 239 }), args 240} 241 242func columnToFieldIndex(m *DbMap, t reflect.Type, name string, cols []string) ([][]int, error) { 243 colToFieldIndex := make([][]int, len(cols)) 244 245 // check if type t is a mapped table - if so we'll 246 // check the table for column aliasing below 247 tableMapped := false 248 table := tableOrNil(m, t, name) 249 if table != nil { 250 tableMapped = true 251 } 252 253 // Loop over column names and find field in i to bind to 254 // based on column name. all returned columns must match 255 // a field in the i struct 256 missingColNames := []string{} 257 for x := range cols { 258 colName := strings.ToLower(cols[x]) 259 field, found := t.FieldByNameFunc(func(fieldName string) bool { 260 field, _ := t.FieldByName(fieldName) 261 cArguments := strings.Split(field.Tag.Get("db"), ",") 262 fieldName = cArguments[0] 263 264 if fieldName == "" || fieldName == "-" { 265 fieldName = field.Name 266 } 267 if tableMapped { 268 colMap := colMapOrNil(table, fieldName) 269 if colMap != nil && colMap.ColumnName != "-" { 270 fieldName = colMap.ColumnName 271 } 272 } 273 return colName == strings.ToLower(fieldName) 274 }) 275 if found { 276 colToFieldIndex[x] = field.Index 277 } 278 if colToFieldIndex[x] == nil { 279 missingColNames = append(missingColNames, colName) 280 } 281 } 282 if len(missingColNames) > 0 { 283 return colToFieldIndex, &NoFieldInTypeError{ 284 TypeName: t.Name(), 285 MissingColNames: missingColNames, 286 } 287 } 288 return colToFieldIndex, nil 289} 290 291func fieldByName(val reflect.Value, fieldName string) *reflect.Value { 292 
// try to find field by exact match 293 f := val.FieldByName(fieldName) 294 295 if f != zeroVal { 296 return &f 297 } 298 299 // try to find by case insensitive match - only the Postgres driver 300 // seems to require this - in the case where columns are aliased in the sql 301 fieldNameL := strings.ToLower(fieldName) 302 fieldCount := val.NumField() 303 t := val.Type() 304 for i := 0; i < fieldCount; i++ { 305 sf := t.Field(i) 306 if strings.ToLower(sf.Name) == fieldNameL { 307 f := val.Field(i) 308 return &f 309 } 310 } 311 312 return nil 313} 314 315// toSliceType returns the element type of the given object, if the object is a 316// "*[]*Element" or "*[]Element". If not, returns nil. 317// err is returned if the user was trying to pass a pointer-to-slice but failed. 318func toSliceType(i interface{}) (reflect.Type, error) { 319 t := reflect.TypeOf(i) 320 if t.Kind() != reflect.Ptr { 321 // If it's a slice, return a more helpful error message 322 if t.Kind() == reflect.Slice { 323 return nil, fmt.Errorf("gorp: cannot SELECT into a non-pointer slice: %v", t) 324 } 325 return nil, nil 326 } 327 if t = t.Elem(); t.Kind() != reflect.Slice { 328 return nil, nil 329 } 330 return t.Elem(), nil 331} 332 333func toType(i interface{}) (reflect.Type, error) { 334 t := reflect.TypeOf(i) 335 336 // If a Pointer to a type, follow 337 for t.Kind() == reflect.Ptr { 338 t = t.Elem() 339 } 340 341 if t.Kind() != reflect.Struct { 342 return nil, fmt.Errorf("gorp: cannot SELECT into this type: %v", reflect.TypeOf(i)) 343 } 344 return t, nil 345} 346 347type foundTable struct { 348 table *TableMap 349 dynName *string 350} 351 352func tableFor(m *DbMap, t reflect.Type, i interface{}) (*foundTable, error) { 353 if dyn, isDynamic := i.(DynamicTable); isDynamic { 354 tableName := dyn.TableName() 355 table, err := m.DynamicTableFor(tableName, true) 356 if err != nil { 357 return nil, err 358 } 359 return &foundTable{ 360 table: table, 361 dynName: &tableName, 362 }, nil 363 } 364 table, 
err := m.TableFor(t, true) 365 if err != nil { 366 return nil, err 367 } 368 return &foundTable{table: table}, nil 369} 370 371func get(m *DbMap, exec SqlExecutor, i interface{}, 372 keys ...interface{}) (interface{}, error) { 373 374 t, err := toType(i) 375 if err != nil { 376 return nil, err 377 } 378 379 foundTable, err := tableFor(m, t, i) 380 if err != nil { 381 return nil, err 382 } 383 table := foundTable.table 384 385 plan := table.bindGet() 386 387 v := reflect.New(t) 388 if foundTable.dynName != nil { 389 retDyn := v.Interface().(DynamicTable) 390 retDyn.SetTableName(*foundTable.dynName) 391 } 392 393 dest := make([]interface{}, len(plan.argFields)) 394 395 conv := m.TypeConverter 396 custScan := make([]CustomScanner, 0) 397 398 for x, fieldName := range plan.argFields { 399 f := v.Elem().FieldByName(fieldName) 400 target := f.Addr().Interface() 401 if conv != nil { 402 scanner, ok := conv.FromDb(target) 403 if ok { 404 target = scanner.Holder 405 custScan = append(custScan, scanner) 406 } 407 } 408 dest[x] = target 409 } 410 411 ctx, cancel := context.WithTimeout(context.Background(), m.QueryTimeout) 412 defer cancel() 413 row := exec.QueryRowContext(ctx, plan.query, keys...) 414 415 err = row.Scan(dest...) 
416 if err != nil { 417 if err == sql.ErrNoRows { 418 err = nil 419 } 420 return nil, err 421 } 422 423 for _, c := range custScan { 424 err = c.Bind() 425 if err != nil { 426 return nil, err 427 } 428 } 429 430 if v, ok := v.Interface().(HasPostGet); ok { 431 err := v.PostGet(exec) 432 if err != nil { 433 return nil, err 434 } 435 } 436 437 return v.Interface(), nil 438} 439 440func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { 441 count := int64(0) 442 for _, ptr := range list { 443 table, elem, err := m.tableForPointer(ptr, true) 444 if err != nil { 445 return -1, err 446 } 447 448 eval := elem.Addr().Interface() 449 if v, ok := eval.(HasPreDelete); ok { 450 err = v.PreDelete(exec) 451 if err != nil { 452 return -1, err 453 } 454 } 455 456 bi, err := table.bindDelete(elem) 457 if err != nil { 458 return -1, err 459 } 460 461 res, err := exec.Exec(bi.query, bi.args...) 462 if err != nil { 463 return -1, err 464 } 465 rows, err := res.RowsAffected() 466 if err != nil { 467 return -1, err 468 } 469 470 if rows == 0 && bi.existingVersion > 0 { 471 return lockError(m, exec, table.TableName, 472 bi.existingVersion, elem, bi.keys...) 473 } 474 475 count += rows 476 477 if v, ok := eval.(HasPostDelete); ok { 478 err := v.PostDelete(exec) 479 if err != nil { 480 return -1, err 481 } 482 } 483 } 484 485 return count, nil 486} 487 488func update(m *DbMap, exec SqlExecutor, colFilter ColumnFilter, list ...interface{}) (int64, error) { 489 count := int64(0) 490 for _, ptr := range list { 491 table, elem, err := m.tableForPointer(ptr, true) 492 if err != nil { 493 return -1, err 494 } 495 496 eval := elem.Addr().Interface() 497 if v, ok := eval.(HasPreUpdate); ok { 498 err = v.PreUpdate(exec) 499 if err != nil { 500 return -1, err 501 } 502 } 503 504 bi, err := table.bindUpdate(elem, colFilter) 505 if err != nil { 506 return -1, err 507 } 508 509 res, err := exec.Exec(bi.query, bi.args...) 
510 if err != nil { 511 return -1, err 512 } 513 514 rows, err := res.RowsAffected() 515 if err != nil { 516 return -1, err 517 } 518 519 if rows == 0 && bi.existingVersion > 0 { 520 return lockError(m, exec, table.TableName, 521 bi.existingVersion, elem, bi.keys...) 522 } 523 524 if bi.versField != "" { 525 elem.FieldByName(bi.versField).SetInt(bi.existingVersion + 1) 526 } 527 528 count += rows 529 530 if v, ok := eval.(HasPostUpdate); ok { 531 err = v.PostUpdate(exec) 532 if err != nil { 533 return -1, err 534 } 535 } 536 } 537 return count, nil 538} 539 540func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error { 541 for _, ptr := range list { 542 table, elem, err := m.tableForPointer(ptr, false) 543 if err != nil { 544 return err 545 } 546 547 eval := elem.Addr().Interface() 548 if v, ok := eval.(HasPreInsert); ok { 549 err := v.PreInsert(exec) 550 if err != nil { 551 return err 552 } 553 } 554 555 bi, err := table.bindInsert(elem) 556 if err != nil { 557 return err 558 } 559 560 if bi.autoIncrIdx > -1 { 561 f := elem.FieldByName(bi.autoIncrFieldName) 562 switch inserter := m.Dialect.(type) { 563 case IntegerAutoIncrInserter: 564 id, err := inserter.InsertAutoIncr(exec, bi.query, bi.args...) 565 if err != nil { 566 return err 567 } 568 k := f.Kind() 569 if (k == reflect.Int) || (k == reflect.Int16) || (k == reflect.Int32) || (k == reflect.Int64) { 570 f.SetInt(id) 571 } else if (k == reflect.Uint) || (k == reflect.Uint16) || (k == reflect.Uint32) || (k == reflect.Uint64) { 572 f.SetUint(uint64(id)) 573 } else { 574 return fmt.Errorf("gorp: cannot set autoincrement value on non-Int field. SQL=%s autoIncrIdx=%d autoIncrFieldName=%s", bi.query, bi.autoIncrIdx, bi.autoIncrFieldName) 575 } 576 case TargetedAutoIncrInserter: 577 err := inserter.InsertAutoIncrToTarget(exec, bi.query, f.Addr().Interface(), bi.args...) 
578 if err != nil { 579 return err 580 } 581 case TargetQueryInserter: 582 var idQuery = table.ColMap(bi.autoIncrFieldName).GeneratedIdQuery 583 if idQuery == "" { 584 return fmt.Errorf("gorp: cannot set %s value if its ColumnMap.GeneratedIdQuery is empty", bi.autoIncrFieldName) 585 } 586 err := inserter.InsertQueryToTarget(exec, bi.query, idQuery, f.Addr().Interface(), bi.args...) 587 if err != nil { 588 return err 589 } 590 default: 591 return fmt.Errorf("gorp: cannot use autoincrement fields on dialects that do not implement an autoincrementing interface") 592 } 593 } else { 594 _, err := exec.Exec(bi.query, bi.args...) 595 if err != nil { 596 return err 597 } 598 } 599 600 if v, ok := eval.(HasPostInsert); ok { 601 err := v.PostInsert(exec) 602 if err != nil { 603 return err 604 } 605 } 606 } 607 return nil 608} 609