1// Copyright 2014 beego Author. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// +build go1.8
16
// Package orm provides an ORM for MySQL, PostgreSQL and SQLite.
//
// Simple usage:
19//
20//	package main
21//
22//	import (
23//		"fmt"
24//		"github.com/astaxie/beego/orm"
25//		_ "github.com/go-sql-driver/mysql" // import your used driver
26//	)
27//
28//	// Model Struct
29//	type User struct {
30//		Id   int    `orm:"auto"`
31//		Name string `orm:"size(100)"`
32//	}
33//
34//	func init() {
35//		orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30)
36//	}
37//
38//	func main() {
39//		o := orm.NewOrm()
40//		user := User{Name: "slene"}
41//		// insert
42//		id, err := o.Insert(&user)
43//		// update
44//		user.Name = "astaxie"
45//		num, err := o.Update(&user)
46//		// read one
47//		u := User{Id: user.Id}
48//		err = o.Read(&u)
49//		// delete
50//		num, err = o.Delete(&u)
51//	}
52//
53// more docs: http://beego.me/docs/mvc/model/overview.md
54package orm
55
56import (
57	"context"
58	"database/sql"
59	"errors"
60	"fmt"
61	"os"
62	"reflect"
63	"sync"
64	"time"
65)
66
// DebugQueries identifies the query-debugging option. It is the first value
// of an iota sequence and therefore equals 0.
const DebugQueries = iota
71
72// Define common vars
73var (
74	Debug            = false
75	DebugLog         = NewLog(os.Stdout)
76	DefaultRowsLimit = -1
77	DefaultRelsDepth = 2
78	DefaultTimeLoc   = time.Local
79	ErrTxHasBegan    = errors.New("<Ormer.Begin> transaction already begin")
80	ErrTxDone        = errors.New("<Ormer.Commit/Rollback> transaction not begin")
81	ErrMultiRows     = errors.New("<QuerySeter> return multi rows")
82	ErrNoRows        = errors.New("<QuerySeter> no row found")
83	ErrStmtClosed    = errors.New("<QuerySeter> stmt already closed")
84	ErrArgs          = errors.New("<Ormer> args error may be empty")
85	ErrNotImplement  = errors.New("have not implement")
86)
87
type (
	// Params is a string-keyed map of arbitrary values (presumably one row of
	// column name -> value as exchanged with the query APIs).
	Params map[string]interface{}

	// ParamsList is an ordered list of arbitrary values (presumably one row of
	// column values in positional order).
	ParamsList []interface{}
)
93
94type orm struct {
95	alias *alias
96	db    dbQuerier
97	isTx  bool
98}
99
100var _ Ormer = new(orm)
101
102// get model info and model reflect value
103func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) {
104	val := reflect.ValueOf(md)
105	ind = reflect.Indirect(val)
106	typ := ind.Type()
107	if needPtr && val.Kind() != reflect.Ptr {
108		panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
109	}
110	name := getFullName(typ)
111	if mi, ok := modelCache.getByFullName(name); ok {
112		return mi, ind
113	}
114	panic(fmt.Errorf("<Ormer> table: `%s` not found, make sure it was registered with `RegisterModel()`", name))
115}
116
117// get field info from model info by given field name
118func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
119	fi, ok := mi.fields.GetByAny(name)
120	if !ok {
121		panic(fmt.Errorf("<Ormer> cannot find field `%s` for model `%s`", name, mi.fullName))
122	}
123	return fi
124}
125
126// read data to model
127func (o *orm) Read(md interface{}, cols ...string) error {
128	mi, ind := o.getMiInd(md, true)
129	return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
130}
131
132// read data to model, like Read(), but use "SELECT FOR UPDATE" form
133func (o *orm) ReadForUpdate(md interface{}, cols ...string) error {
134	mi, ind := o.getMiInd(md, true)
135	return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true)
136}
137
138// Try to read a row from the database, or insert one if it doesn't exist
139func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) {
140	cols = append([]string{col1}, cols...)
141	mi, ind := o.getMiInd(md, true)
142	err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
143	if err == ErrNoRows {
144		// Create
145		id, err := o.Insert(md)
146		return (err == nil), id, err
147	}
148
149	id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex)
150	if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
151		id = int64(vid.Uint())
152	} else if mi.fields.pk.rel {
153		return o.ReadOrCreate(vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name)
154	} else {
155		id = vid.Int()
156	}
157
158	return false, id, err
159}
160
161// insert model data to database
162func (o *orm) Insert(md interface{}) (int64, error) {
163	mi, ind := o.getMiInd(md, true)
164	id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
165	if err != nil {
166		return id, err
167	}
168
169	o.setPk(mi, ind, id)
170
171	return id, nil
172}
173
174// set auto pk field
175func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) {
176	if mi.fields.pk.auto {
177		if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
178			ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id))
179		} else {
180			ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id)
181		}
182	}
183}
184
185// insert some models to database
186func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
187	var cnt int64
188
189	sind := reflect.Indirect(reflect.ValueOf(mds))
190
191	switch sind.Kind() {
192	case reflect.Array, reflect.Slice:
193		if sind.Len() == 0 {
194			return cnt, ErrArgs
195		}
196	default:
197		return cnt, ErrArgs
198	}
199
200	if bulk <= 1 {
201		for i := 0; i < sind.Len(); i++ {
202			ind := reflect.Indirect(sind.Index(i))
203			mi, _ := o.getMiInd(ind.Interface(), false)
204			id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
205			if err != nil {
206				return cnt, err
207			}
208
209			o.setPk(mi, ind, id)
210
211			cnt++
212		}
213	} else {
214		mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
215		return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ)
216	}
217	return cnt, nil
218}
219
220// InsertOrUpdate data to database
221func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) {
222	mi, ind := o.getMiInd(md, true)
223	id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...)
224	if err != nil {
225		return id, err
226	}
227
228	o.setPk(mi, ind, id)
229
230	return id, nil
231}
232
233// update model to database.
234// cols set the columns those want to update.
235func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
236	mi, ind := o.getMiInd(md, true)
237	return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
238}
239
240// delete model in database
241// cols shows the delete conditions values read from. default is pk
242func (o *orm) Delete(md interface{}, cols ...string) (int64, error) {
243	mi, ind := o.getMiInd(md, true)
244	num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols)
245	if err != nil {
246		return num, err
247	}
248	if num > 0 {
249		o.setPk(mi, ind, 0)
250	}
251	return num, nil
252}
253
254// create a models to models queryer
255func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
256	mi, ind := o.getMiInd(md, true)
257	fi := o.getFieldInfo(mi, name)
258
259	switch {
260	case fi.fieldType == RelManyToMany:
261	case fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough:
262	default:
263		panic(fmt.Errorf("<Ormer.QueryM2M> model `%s` . name `%s` is not a m2m field", fi.name, mi.fullName))
264	}
265
266	return newQueryM2M(md, o, mi, fi, ind)
267}
268
269// load related models to md model.
270// args are limit, offset int and order string.
271//
272// example:
273// 	orm.LoadRelated(post,"Tags")
274// 	for _,tag := range post.Tags{...}
275//
276// make sure the relation is defined in model struct tags.
277func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) {
278	_, fi, ind, qseter := o.queryRelated(md, name)
279
280	qs := qseter.(*querySet)
281
282	var relDepth int
283	var limit, offset int64
284	var order string
285	for i, arg := range args {
286		switch i {
287		case 0:
288			if v, ok := arg.(bool); ok {
289				if v {
290					relDepth = DefaultRelsDepth
291				}
292			} else if v, ok := arg.(int); ok {
293				relDepth = v
294			}
295		case 1:
296			limit = ToInt64(arg)
297		case 2:
298			offset = ToInt64(arg)
299		case 3:
300			order, _ = arg.(string)
301		}
302	}
303
304	switch fi.fieldType {
305	case RelOneToOne, RelForeignKey, RelReverseOne:
306		limit = 1
307		offset = 0
308	}
309
310	qs.limit = limit
311	qs.offset = offset
312	qs.relDepth = relDepth
313
314	if len(order) > 0 {
315		qs.orders = []string{order}
316	}
317
318	find := ind.FieldByIndex(fi.fieldIndex)
319
320	var nums int64
321	var err error
322	switch fi.fieldType {
323	case RelOneToOne, RelForeignKey, RelReverseOne:
324		val := reflect.New(find.Type().Elem())
325		container := val.Interface()
326		err = qs.One(container)
327		if err == nil {
328			find.Set(val)
329			nums = 1
330		}
331	default:
332		nums, err = qs.All(find.Addr().Interface())
333	}
334
335	return nums, err
336}
337
338// return a QuerySeter for related models to md model.
339// it can do all, update, delete in QuerySeter.
340// example:
341// 	qs := orm.QueryRelated(post,"Tag")
342//  qs.All(&[]*Tag{})
343//
344func (o *orm) QueryRelated(md interface{}, name string) QuerySeter {
345	// is this api needed ?
346	_, _, _, qs := o.queryRelated(md, name)
347	return qs
348}
349
350// get QuerySeter for related models to md model
351func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) {
352	mi, ind := o.getMiInd(md, true)
353	fi := o.getFieldInfo(mi, name)
354
355	_, _, exist := getExistPk(mi, ind)
356	if !exist {
357		panic(ErrMissPK)
358	}
359
360	var qs *querySet
361
362	switch fi.fieldType {
363	case RelOneToOne, RelForeignKey, RelManyToMany:
364		if !fi.inModel {
365			break
366		}
367		qs = o.getRelQs(md, mi, fi)
368	case RelReverseOne, RelReverseMany:
369		if !fi.inModel {
370			break
371		}
372		qs = o.getReverseQs(md, mi, fi)
373	}
374
375	if qs == nil {
376		panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel/reverse field", md, name))
377	}
378
379	return mi, fi, ind, qs
380}
381
382// get reverse relation QuerySeter
383func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
384	switch fi.fieldType {
385	case RelReverseOne, RelReverseMany:
386	default:
387		panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available reverse field", fi.name, mi.fullName))
388	}
389
390	var q *querySet
391
392	if fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough {
393		q = newQuerySet(o, fi.relModelInfo).(*querySet)
394		q.cond = NewCondition().And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md)
395	} else {
396		q = newQuerySet(o, fi.reverseFieldInfo.mi).(*querySet)
397		q.cond = NewCondition().And(fi.reverseFieldInfo.column, md)
398	}
399
400	return q
401}
402
403// get relation QuerySeter
404func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
405	switch fi.fieldType {
406	case RelOneToOne, RelForeignKey, RelManyToMany:
407	default:
408		panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel field", fi.name, mi.fullName))
409	}
410
411	q := newQuerySet(o, fi.relModelInfo).(*querySet)
412	q.cond = NewCondition()
413
414	if fi.fieldType == RelManyToMany {
415		q.cond = q.cond.And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md)
416	} else {
417		q.cond = q.cond.And(fi.reverseFieldInfo.column, md)
418	}
419
420	return q
421}
422
423// return a QuerySeter for table operations.
424// table name can be string or struct.
425// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
426func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
427	var name string
428	if table, ok := ptrStructOrTableName.(string); ok {
429		name = nameStrategyMap[defaultNameStrategy](table)
430		if mi, ok := modelCache.get(name); ok {
431			qs = newQuerySet(o, mi)
432		}
433	} else {
434		name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName)))
435		if mi, ok := modelCache.getByFullName(name); ok {
436			qs = newQuerySet(o, mi)
437		}
438	}
439	if qs == nil {
440		panic(fmt.Errorf("<Ormer.QueryTable> table name: `%s` not exists", name))
441	}
442	return
443}
444
445// switch to another registered database driver by given name.
446func (o *orm) Using(name string) error {
447	if o.isTx {
448		panic(fmt.Errorf("<Ormer.Using> transaction has been start, cannot change db"))
449	}
450	if al, ok := dataBaseCache.get(name); ok {
451		o.alias = al
452		if Debug {
453			o.db = newDbQueryLog(al, al.DB)
454		} else {
455			o.db = al.DB
456		}
457	} else {
458		return fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", name)
459	}
460	return nil
461}
462
463// begin transaction
464func (o *orm) Begin() error {
465	return o.BeginTx(context.Background(), nil)
466}
467
468func (o *orm) BeginTx(ctx context.Context, opts *sql.TxOptions) error {
469	if o.isTx {
470		return ErrTxHasBegan
471	}
472	var tx *sql.Tx
473	tx, err := o.db.(txer).BeginTx(ctx, opts)
474	if err != nil {
475		return err
476	}
477	o.isTx = true
478	if Debug {
479		o.db.(*dbQueryLog).SetDB(tx)
480	} else {
481		o.db = tx
482	}
483	return nil
484}
485
486// commit transaction
487func (o *orm) Commit() error {
488	if !o.isTx {
489		return ErrTxDone
490	}
491	err := o.db.(txEnder).Commit()
492	if err == nil {
493		o.isTx = false
494		o.Using(o.alias.Name)
495	} else if err == sql.ErrTxDone {
496		return ErrTxDone
497	}
498	return err
499}
500
501// rollback transaction
502func (o *orm) Rollback() error {
503	if !o.isTx {
504		return ErrTxDone
505	}
506	err := o.db.(txEnder).Rollback()
507	if err == nil {
508		o.isTx = false
509		o.Using(o.alias.Name)
510	} else if err == sql.ErrTxDone {
511		return ErrTxDone
512	}
513	return err
514}
515
516// return a raw query seter for raw sql string.
517func (o *orm) Raw(query string, args ...interface{}) RawSeter {
518	return newRawSet(o, query, args)
519}
520
521// return current using database Driver
522func (o *orm) Driver() Driver {
523	return driver(o.alias.Name)
524}
525
526// return sql.DBStats for current database
527func (o *orm) DBStats() *sql.DBStats {
528	if o.alias != nil && o.alias.DB != nil {
529		stats := o.alias.DB.DB.Stats()
530		return &stats
531	}
532	return nil
533}
534
535// NewOrm create new orm
536func NewOrm() Ormer {
537	BootStrap() // execute only once
538
539	o := new(orm)
540	err := o.Using("default")
541	if err != nil {
542		panic(err)
543	}
544	return o
545}
546
547// NewOrmWithDB create a new ormer object with specify *sql.DB for query
548func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
549	var al *alias
550
551	if dr, ok := drivers[driverName]; ok {
552		al = new(alias)
553		al.DbBaser = dbBasers[dr]
554		al.Driver = dr
555	} else {
556		return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
557	}
558
559	al.Name = aliasName
560	al.DriverName = driverName
561	al.DB = &DB{
562		RWMutex:        new(sync.RWMutex),
563		DB:             db,
564		stmtDecorators: newStmtDecoratorLruWithEvict(),
565	}
566
567	detectTZ(al)
568
569	o := new(orm)
570	o.alias = al
571
572	if Debug {
573		o.db = newDbQueryLog(o.alias, db)
574	} else {
575		o.db = db
576	}
577
578	return o, nil
579}
580