1// Copyright 2012 James Cooper. All rights reserved.
2// Use of this source code is governed by a MIT-style
3// license that can be found in the LICENSE file.
4
5// Package gorp provides a simple way to marshal Go structs to and from
6// SQL databases.  It uses the database/sql package, and should work with any
7// compliant database/sql driver.
8//
9// Source code and project home:
10// https://github.com/go-gorp/gorp
11//
12package gorp
13
14import (
15	"context"
16	"database/sql"
17	"database/sql/driver"
18	"fmt"
19	"reflect"
20	"regexp"
21	"strings"
22	"time"
23)
24
// OracleString is a nullable string for which the empty string is null:
// Value reports both an invalid and an empty string as nil.
// TODO: move to dialect/oracle?, rename to String?
type OracleString struct {
	sql.NullString
}

// Scan implements the Scanner interface. A nil value resets the receiver to
// the invalid empty string; anything else is marked valid and handed to the
// embedded sql.NullString for conversion.
func (s *OracleString) Scan(value interface{}) error {
	if value != nil {
		s.Valid = true
		return s.NullString.Scan(value)
	}
	s.String = ""
	s.Valid = false
	return nil
}

// Value implements the driver Valuer interface. It yields the string only
// when it is both valid and non-empty, and nil otherwise.
func (s OracleString) Value() (driver.Value, error) {
	if s.Valid && s.String != "" {
		return s.String, nil
	}
	return nil, nil
}
48
// SqlTyper is a type that returns its database type.  Most of the
// time, a type can just implement "database/sql/driver".Valuer; but when
// its Value method returns nil for its empty value, it needs to implement
// SqlTyper as well to have its column type detected properly during table
// creation.
type SqlTyper interface {
	// SqlType returns a driver.Valuer describing the database type of the
	// implementing type.
	SqlType() driver.Valuer
}
56
// dummyField is a placeholder for fields that exist in the DB table but do
// not exist in the struct.
type dummyField struct{}

// Scan implements the Scanner interface. The scanned value is discarded and
// nil is always returned.
func (*dummyField) Scan(interface{}) error {
	return nil
}
64
// zeroVal is the zero reflect.Value (an invalid Value), kept as a sentinel
// that reflection lookup results can be compared against.
var zeroVal reflect.Value

// versFieldConst is a marker string for the version field.
// NOTE(review): its consumers are not visible in this part of the file —
// confirm how it is used before changing the literal.
var versFieldConst = "[gorp_ver_field]"
67
// The TypeConverter interface provides a way to map a value of one
// type to another type when persisting to, or loading from, a database.
//
// Example use cases: implement a type converter to convert bool types to
// "y"/"n" strings, or to serialize a struct member as a JSON blob.
type TypeConverter interface {
	// ToDb converts val to another type. It is called before INSERT/UPDATE
	// operations.
	ToDb(val interface{}) (interface{}, error)

	// FromDb returns a CustomScanner appropriate for the type of target. The
	// scanner is used to hold values returned from SELECT queries.
	//
	// In particular, the returned CustomScanner should implement a Binder
	// function appropriate for the Go type you wish to convert the db value
	// to.
	//
	// If the returned bool is false, no custom scanner is used for this
	// field and the value is scanned directly into target.
	FromDb(target interface{}) (CustomScanner, bool)
}
86
// executor exposes the sql.DB and sql.Tx Exec and ExecContext functions so
// that either can be used by internal functions that convert named
// parameters before executing a statement.
type executor interface {
	Exec(query string, args ...interface{}) (sql.Result, error)
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}
93
// SqlExecutor exposes gorp operations that can be run from Pre/Post
// hooks.  This hides whether the current operation that triggered the
// hook is in a transaction.
//
// Both *DbMap and *Transaction implement this interface.
//
// See the DbMap function docs for each of the functions below for more
// information.
type SqlExecutor interface {
	Get(i interface{}, keys ...interface{}) (interface{}, error)
	Insert(list ...interface{}) error
	Update(list ...interface{}) (int64, error)
	Delete(list ...interface{}) (int64, error)
	Exec(query string, args ...interface{}) (sql.Result, error)
	// NOTE(review): presumably the same as Exec but without applying the
	// query timeout — confirm against the DbMap implementation.
	ExecNoTimeout(query string, args ...interface{}) (sql.Result, error)
	Select(i interface{}, query string,
		args ...interface{}) ([]interface{}, error)
	SelectInt(query string, args ...interface{}) (int64, error)
	SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error)
	SelectFloat(query string, args ...interface{}) (float64, error)
	SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error)
	SelectStr(query string, args ...interface{}) (string, error)
	SelectNullStr(query string, args ...interface{}) (sql.NullString, error)
	SelectOne(holder interface{}, query string, args ...interface{}) error
	Query(query string, args ...interface{}) (*sql.Rows, error)
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	QueryRow(query string, args ...interface{}) *sql.Row
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
121
// DynamicTable allows the users of gorp to dynamically
// use different database table names during runtime
// while sharing the same golang struct for in-memory data.
type DynamicTable interface {
	// TableName returns the name of the database table the value currently
	// belongs to.
	TableName() string
	// SetTableName records the name of the database table the value belongs
	// to.
	SetTableName(string)
}
129
130// Compile-time check that DbMap and Transaction implement the SqlExecutor
131// interface.
132var _, _ SqlExecutor = &DbMap{}, &Transaction{}
133
// argsString renders query arguments as one space-separated string of
// "position:value" pairs, with positions starting at 1, e.g. `1:"abc" 2:42`.
// It returns the empty string when no arguments are given.
//
// Arguments implementing driver.Valuer are shown as the value they resolve
// to; if Value returns an error the argument itself is printed instead.
// Strings are quoted with %q, everything else is printed with %v.
//
// A nil pointer whose element type implements driver.Valuer is shown as
// <nil> instead of having Value called on it, which would panic with a nil
// dereference; database/sql likewise treats such an argument as NULL.
func argsString(args ...interface{}) string {
	valuerType := reflect.TypeOf((*driver.Valuer)(nil)).Elem()

	var b strings.Builder
	for i, a := range args {
		if i > 0 {
			b.WriteByte(' ')
		}

		v := a
		if x, ok := a.(driver.Valuer); ok {
			rv := reflect.ValueOf(x)
			if rv.Kind() == reflect.Ptr && rv.IsNil() && rv.Type().Elem().Implements(valuerType) {
				// Value is defined on the element type, so calling it
				// through the nil pointer would dereference nil.
				v = nil
			} else if y, err := x.Value(); err == nil {
				v = y
			}
		}

		if s, ok := v.(string); ok {
			fmt.Fprintf(&b, "%d:%q", i+1, s)
		} else {
			fmt.Fprintf(&b, "%d:%v", i+1, v)
		}
	}
	return b.String()
}
157
158// Calls the Exec function on the executor, but attempts to expand any eligible named
159// query arguments first.
160func exec(e SqlExecutor, query string, doTimeout bool, args ...interface{}) (sql.Result, error) {
161	var dbMap *DbMap
162	var executor executor
163	switch m := e.(type) {
164	case *DbMap:
165		executor = m.Db
166		dbMap = m
167	case *Transaction:
168		executor = m.tx
169		dbMap = m.dbmap
170	}
171
172	if len(args) == 1 {
173		query, args = maybeExpandNamedQuery(dbMap, query, args)
174	}
175
176	if doTimeout {
177		ctx, cancel := context.WithTimeout(context.Background(), dbMap.QueryTimeout)
178		defer cancel()
179		return executor.ExecContext(ctx, query, args...)
180	}
181
182	return executor.Exec(query, args...)
183}
184
185// maybeExpandNamedQuery checks the given arg to see if it's eligible to be used
186// as input to a named query.  If so, it rewrites the query to use
187// dialect-dependent bindvars and instantiates the corresponding slice of
188// parameters by extracting data from the map / struct.
189// If not, returns the input values unchanged.
190func maybeExpandNamedQuery(m *DbMap, query string, args []interface{}) (string, []interface{}) {
191	var (
192		arg    = args[0]
193		argval = reflect.ValueOf(arg)
194	)
195	if argval.Kind() == reflect.Ptr {
196		argval = argval.Elem()
197	}
198
199	if argval.Kind() == reflect.Map && argval.Type().Key().Kind() == reflect.String {
200		return expandNamedQuery(m, query, func(key string) reflect.Value {
201			return argval.MapIndex(reflect.ValueOf(key))
202		})
203	}
204	if argval.Kind() != reflect.Struct {
205		return query, args
206	}
207	if _, ok := arg.(time.Time); ok {
208		// time.Time is driver.Value
209		return query, args
210	}
211	if _, ok := arg.(driver.Valuer); ok {
212		// driver.Valuer will be converted to driver.Value.
213		return query, args
214	}
215
216	return expandNamedQuery(m, query, argval.FieldByName)
217}
218
// keyRegexp matches named-query placeholders: a colon followed by one or
// more word characters (letters, digits or underscore), e.g. ":name".
var keyRegexp = regexp.MustCompile(`:[[:word:]]+`)
220
221// expandNamedQuery accepts a query with placeholders of the form ":key", and a
222// single arg of Kind Struct or Map[string].  It returns the query with the
223// dialect's placeholders, and a slice of args ready for positional insertion
224// into the query.
225func expandNamedQuery(m *DbMap, query string, keyGetter func(key string) reflect.Value) (string, []interface{}) {
226	var (
227		n    int
228		args []interface{}
229	)
230	return keyRegexp.ReplaceAllStringFunc(query, func(key string) string {
231		val := keyGetter(key[1:])
232		if !val.IsValid() {
233			return key
234		}
235		args = append(args, val.Interface())
236		newVar := m.Dialect.BindVar(n)
237		n++
238		return newVar
239	}), args
240}
241
242func columnToFieldIndex(m *DbMap, t reflect.Type, name string, cols []string) ([][]int, error) {
243	colToFieldIndex := make([][]int, len(cols))
244
245	// check if type t is a mapped table - if so we'll
246	// check the table for column aliasing below
247	tableMapped := false
248	table := tableOrNil(m, t, name)
249	if table != nil {
250		tableMapped = true
251	}
252
253	// Loop over column names and find field in i to bind to
254	// based on column name. all returned columns must match
255	// a field in the i struct
256	missingColNames := []string{}
257	for x := range cols {
258		colName := strings.ToLower(cols[x])
259		field, found := t.FieldByNameFunc(func(fieldName string) bool {
260			field, _ := t.FieldByName(fieldName)
261			cArguments := strings.Split(field.Tag.Get("db"), ",")
262			fieldName = cArguments[0]
263
264			if fieldName == "" || fieldName == "-" {
265				fieldName = field.Name
266			}
267			if tableMapped {
268				colMap := colMapOrNil(table, fieldName)
269				if colMap != nil && colMap.ColumnName != "-" {
270					fieldName = colMap.ColumnName
271				}
272			}
273			return colName == strings.ToLower(fieldName)
274		})
275		if found {
276			colToFieldIndex[x] = field.Index
277		}
278		if colToFieldIndex[x] == nil {
279			missingColNames = append(missingColNames, colName)
280		}
281	}
282	if len(missingColNames) > 0 {
283		return colToFieldIndex, &NoFieldInTypeError{
284			TypeName:        t.Name(),
285			MissingColNames: missingColNames,
286		}
287	}
288	return colToFieldIndex, nil
289}
290
// fieldByName returns a pointer to the field of struct value val named
// fieldName, or nil if there is no such field.
//
// An exact match (as resolved by reflect's FieldByName) is preferred. Failing
// that, the direct fields of val are searched for a case-insensitive match -
// only the Postgres driver seems to require this - in the case where columns
// are aliased in the sql. The first such field in declaration order wins.
//
// val must be of Kind Struct; reflect panics otherwise.
func fieldByName(val reflect.Value, fieldName string) *reflect.Value {
	// FieldByName yields the zero (invalid) Value when nothing matches.
	// IsValid is the supported way to detect that; comparing reflect.Values
	// with == / != is discouraged by the reflect documentation.
	if f := val.FieldByName(fieldName); f.IsValid() {
		return &f
	}

	fieldNameL := strings.ToLower(fieldName)
	t := val.Type()
	for i, n := 0, t.NumField(); i < n; i++ {
		if strings.ToLower(t.Field(i).Name) == fieldNameL {
			f := val.Field(i)
			return &f
		}
	}

	return nil
}
314
// toSliceType returns the element type of the given object, if the object is a
// "*[]*Element" or "*[]Element". If not, returns nil.
// err is returned if the user was trying to pass a pointer-to-slice but failed,
// i.e. passed the slice itself rather than a pointer to it.
// A nil i is not a pointer to a slice, so it yields (nil, nil) as well.
func toSliceType(i interface{}) (reflect.Type, error) {
	t := reflect.TypeOf(i)
	if t == nil {
		// reflect.TypeOf(nil) is nil; calling Kind on it would panic.
		return nil, nil
	}

	switch t.Kind() {
	case reflect.Slice:
		// If it's a slice, return a more helpful error message
		return nil, fmt.Errorf("gorp: cannot SELECT into a non-pointer slice: %v", t)
	case reflect.Ptr:
		if elem := t.Elem(); elem.Kind() == reflect.Slice {
			return elem.Elem(), nil
		}
	}
	return nil, nil
}
332
// toType returns the struct type behind i, following any number of pointer
// indirections. An error is returned when the resulting type is not a
// struct, and also when i is nil (for which no type exists at all).
func toType(i interface{}) (reflect.Type, error) {
	t := reflect.TypeOf(i)

	// If a Pointer to a type, follow. t is nil only for a nil i, in which
	// case calling Kind would panic.
	for t != nil && t.Kind() == reflect.Ptr {
		t = t.Elem()
	}

	if t == nil || t.Kind() != reflect.Struct {
		return nil, fmt.Errorf("gorp: cannot SELECT into this type: %v", reflect.TypeOf(i))
	}
	return t, nil
}
346
// foundTable is the result of tableFor: the table mapping that was resolved
// and, for DynamicTable values, the table name it was resolved from.
type foundTable struct {
	// table is the resolved table mapping.
	table *TableMap
	// dynName is the dynamic table name; nil when the value is not a
	// DynamicTable.
	dynName *string
}
351
352func tableFor(m *DbMap, t reflect.Type, i interface{}) (*foundTable, error) {
353	if dyn, isDynamic := i.(DynamicTable); isDynamic {
354		tableName := dyn.TableName()
355		table, err := m.DynamicTableFor(tableName, true)
356		if err != nil {
357			return nil, err
358		}
359		return &foundTable{
360			table:   table,
361			dynName: &tableName,
362		}, nil
363	}
364	table, err := m.TableFor(t, true)
365	if err != nil {
366		return nil, err
367	}
368	return &foundTable{table: table}, nil
369}
370
371func get(m *DbMap, exec SqlExecutor, i interface{},
372	keys ...interface{}) (interface{}, error) {
373
374	t, err := toType(i)
375	if err != nil {
376		return nil, err
377	}
378
379	foundTable, err := tableFor(m, t, i)
380	if err != nil {
381		return nil, err
382	}
383	table := foundTable.table
384
385	plan := table.bindGet()
386
387	v := reflect.New(t)
388	if foundTable.dynName != nil {
389		retDyn := v.Interface().(DynamicTable)
390		retDyn.SetTableName(*foundTable.dynName)
391	}
392
393	dest := make([]interface{}, len(plan.argFields))
394
395	conv := m.TypeConverter
396	custScan := make([]CustomScanner, 0)
397
398	for x, fieldName := range plan.argFields {
399		f := v.Elem().FieldByName(fieldName)
400		target := f.Addr().Interface()
401		if conv != nil {
402			scanner, ok := conv.FromDb(target)
403			if ok {
404				target = scanner.Holder
405				custScan = append(custScan, scanner)
406			}
407		}
408		dest[x] = target
409	}
410
411	ctx, cancel := context.WithTimeout(context.Background(), m.QueryTimeout)
412	defer cancel()
413	row := exec.QueryRowContext(ctx, plan.query, keys...)
414
415	err = row.Scan(dest...)
416	if err != nil {
417		if err == sql.ErrNoRows {
418			err = nil
419		}
420		return nil, err
421	}
422
423	for _, c := range custScan {
424		err = c.Bind()
425		if err != nil {
426			return nil, err
427		}
428	}
429
430	if v, ok := v.Interface().(HasPostGet); ok {
431		err := v.PostGet(exec)
432		if err != nil {
433			return nil, err
434		}
435	}
436
437	return v.Interface(), nil
438}
439
440func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) {
441	count := int64(0)
442	for _, ptr := range list {
443		table, elem, err := m.tableForPointer(ptr, true)
444		if err != nil {
445			return -1, err
446		}
447
448		eval := elem.Addr().Interface()
449		if v, ok := eval.(HasPreDelete); ok {
450			err = v.PreDelete(exec)
451			if err != nil {
452				return -1, err
453			}
454		}
455
456		bi, err := table.bindDelete(elem)
457		if err != nil {
458			return -1, err
459		}
460
461		res, err := exec.Exec(bi.query, bi.args...)
462		if err != nil {
463			return -1, err
464		}
465		rows, err := res.RowsAffected()
466		if err != nil {
467			return -1, err
468		}
469
470		if rows == 0 && bi.existingVersion > 0 {
471			return lockError(m, exec, table.TableName,
472				bi.existingVersion, elem, bi.keys...)
473		}
474
475		count += rows
476
477		if v, ok := eval.(HasPostDelete); ok {
478			err := v.PostDelete(exec)
479			if err != nil {
480				return -1, err
481			}
482		}
483	}
484
485	return count, nil
486}
487
488func update(m *DbMap, exec SqlExecutor, colFilter ColumnFilter, list ...interface{}) (int64, error) {
489	count := int64(0)
490	for _, ptr := range list {
491		table, elem, err := m.tableForPointer(ptr, true)
492		if err != nil {
493			return -1, err
494		}
495
496		eval := elem.Addr().Interface()
497		if v, ok := eval.(HasPreUpdate); ok {
498			err = v.PreUpdate(exec)
499			if err != nil {
500				return -1, err
501			}
502		}
503
504		bi, err := table.bindUpdate(elem, colFilter)
505		if err != nil {
506			return -1, err
507		}
508
509		res, err := exec.Exec(bi.query, bi.args...)
510		if err != nil {
511			return -1, err
512		}
513
514		rows, err := res.RowsAffected()
515		if err != nil {
516			return -1, err
517		}
518
519		if rows == 0 && bi.existingVersion > 0 {
520			return lockError(m, exec, table.TableName,
521				bi.existingVersion, elem, bi.keys...)
522		}
523
524		if bi.versField != "" {
525			elem.FieldByName(bi.versField).SetInt(bi.existingVersion + 1)
526		}
527
528		count += rows
529
530		if v, ok := eval.(HasPostUpdate); ok {
531			err = v.PostUpdate(exec)
532			if err != nil {
533				return -1, err
534			}
535		}
536	}
537	return count, nil
538}
539
540func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error {
541	for _, ptr := range list {
542		table, elem, err := m.tableForPointer(ptr, false)
543		if err != nil {
544			return err
545		}
546
547		eval := elem.Addr().Interface()
548		if v, ok := eval.(HasPreInsert); ok {
549			err := v.PreInsert(exec)
550			if err != nil {
551				return err
552			}
553		}
554
555		bi, err := table.bindInsert(elem)
556		if err != nil {
557			return err
558		}
559
560		if bi.autoIncrIdx > -1 {
561			f := elem.FieldByName(bi.autoIncrFieldName)
562			switch inserter := m.Dialect.(type) {
563			case IntegerAutoIncrInserter:
564				id, err := inserter.InsertAutoIncr(exec, bi.query, bi.args...)
565				if err != nil {
566					return err
567				}
568				k := f.Kind()
569				if (k == reflect.Int) || (k == reflect.Int16) || (k == reflect.Int32) || (k == reflect.Int64) {
570					f.SetInt(id)
571				} else if (k == reflect.Uint) || (k == reflect.Uint16) || (k == reflect.Uint32) || (k == reflect.Uint64) {
572					f.SetUint(uint64(id))
573				} else {
574					return fmt.Errorf("gorp: cannot set autoincrement value on non-Int field. SQL=%s  autoIncrIdx=%d autoIncrFieldName=%s", bi.query, bi.autoIncrIdx, bi.autoIncrFieldName)
575				}
576			case TargetedAutoIncrInserter:
577				err := inserter.InsertAutoIncrToTarget(exec, bi.query, f.Addr().Interface(), bi.args...)
578				if err != nil {
579					return err
580				}
581			case TargetQueryInserter:
582				var idQuery = table.ColMap(bi.autoIncrFieldName).GeneratedIdQuery
583				if idQuery == "" {
584					return fmt.Errorf("gorp: cannot set %s value if its ColumnMap.GeneratedIdQuery is empty", bi.autoIncrFieldName)
585				}
586				err := inserter.InsertQueryToTarget(exec, bi.query, idQuery, f.Addr().Interface(), bi.args...)
587				if err != nil {
588					return err
589				}
590			default:
591				return fmt.Errorf("gorp: cannot use autoincrement fields on dialects that do not implement an autoincrementing interface")
592			}
593		} else {
594			_, err := exec.Exec(bi.query, bi.args...)
595			if err != nil {
596				return err
597			}
598		}
599
600		if v, ok := eval.(HasPostInsert); ok {
601			err := v.PostInsert(exec)
602			if err != nil {
603				return err
604			}
605		}
606	}
607	return nil
608}
609