1// Copyright 2014 beego Author. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package orm
16
17import (
18	"context"
19	"database/sql"
20	"fmt"
21	"reflect"
22	"sync"
23	"time"
24
25	lru "github.com/hashicorp/golang-lru"
26)
27
// DriverType identifies the kind of database behind a driver.
// It is an int-based enumeration whose values are the DR* constants.
type DriverType int
30
// Enumeration of the supported database driver types.
// The zero value is skipped, so every named constant is non-zero.
const (
	_          DriverType = iota // int enum type
	DRMySQL                      // mysql
	DRSqlite                     // sqlite
	DROracle                     // oracle
	DRPostgres                   // pgsql
	DRTiDB                       // TiDB
)
40
// driver is a string-based database driver handle. Its string value is
// returned by Name and is also the key Type uses to look up the registered
// alias in dataBaseCache.
type driver string
43
44// get type constant int of current driver..
45func (d driver) Type() DriverType {
46	a, _ := dataBaseCache.get(string(d))
47	return a.Driver
48}
49
// Name returns the driver's name, i.e. its underlying string value.
func (d driver) Name() string {
	return string(d)
}
54
55// check driver iis implemented Driver interface or not.
56var _ Driver = new(driver)
57
// Package-level registries:
//   - dataBaseCache: the registered database aliases, keyed by alias name.
//   - drivers: database/sql driver name -> DriverType; RegisterDriver adds
//     entries to it.
//   - dbBasers: the dbBaser implementation used for each DriverType.
var (
	dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
	drivers       = map[string]DriverType{
		"mysql":    DRMySQL,
		"postgres": DRPostgres,
		"sqlite3":  DRSqlite,
		"tidb":     DRTiDB,
		"oracle":   DROracle,
		"oci8":     DROracle, // github.com/mattn/go-oci8
		"ora":      DROracle, // github.com/rana/ora
	}
	dbBasers = map[DriverType]dbBaser{
		DRMySQL:    newdbBaseMysql(),
		DRSqlite:   newdbBaseSqlite(),
		DROracle:   newdbBaseOracle(),
		DRPostgres: newdbBasePostgres(),
		DRTiDB:     newdbBaseTidb(),
	}
)
77
// _dbCache is a registry of database aliases keyed by alias name.
// mux guards cache; the add and get methods acquire it before touching the map.
type _dbCache struct {
	mux   sync.RWMutex
	cache map[string]*alias
}
83
84// add database alias with original name.
85func (ac *_dbCache) add(name string, al *alias) (added bool) {
86	ac.mux.Lock()
87	defer ac.mux.Unlock()
88	if _, ok := ac.cache[name]; !ok {
89		ac.cache[name] = al
90		added = true
91	}
92	return
93}
94
95// get database alias if cached.
96func (ac *_dbCache) get(name string) (al *alias, ok bool) {
97	ac.mux.RLock()
98	defer ac.mux.RUnlock()
99	al, ok = ac.cache[name]
100	return
101}
102
103// get default alias.
104func (ac *_dbCache) getDefault() (al *alias) {
105	al, _ = ac.get("default")
106	return
107}
108
// DB wraps a *sql.DB together with an LRU cache of prepared statements
// (stmtDecorators) keyed by query text. The embedded RWMutex is taken by
// getStmtDecorator around cache lookups and inserts.
type DB struct {
	*sync.RWMutex
	DB             *sql.DB
	stmtDecorators *lru.Cache
}
114
// Begin starts a transaction on the underlying *sql.DB.
func (d *DB) Begin() (*sql.Tx, error) {
	return d.DB.Begin()
}
118
// BeginTx starts a transaction on the underlying *sql.DB, passing ctx and
// opts through unchanged.
func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
	return d.DB.BeginTx(ctx, opts)
}
122
123//su must call release to release *sql.Stmt after using
124func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) {
125	d.RLock()
126	c, ok := d.stmtDecorators.Get(query)
127	if ok {
128		c.(*stmtDecorator).acquire()
129		d.RUnlock()
130		return c.(*stmtDecorator), nil
131	}
132	d.RUnlock()
133
134	d.Lock()
135	c, ok = d.stmtDecorators.Get(query)
136	if ok {
137		c.(*stmtDecorator).acquire()
138		d.Unlock()
139		return c.(*stmtDecorator), nil
140	}
141
142	stmt, err := d.Prepare(query)
143	if err != nil {
144		d.Unlock()
145		return nil, err
146	}
147	sd := newStmtDecorator(stmt)
148	sd.acquire()
149	d.stmtDecorators.Add(query, sd)
150	d.Unlock()
151
152	return sd, nil
153}
154
// Prepare creates a prepared statement on the underlying *sql.DB.
// It does not read or update the statement cache.
func (d *DB) Prepare(query string) (*sql.Stmt, error) {
	return d.DB.Prepare(query)
}
158
// PrepareContext creates a prepared statement on the underlying *sql.DB using
// ctx. It does not read or update the statement cache.
func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
	return d.DB.PrepareContext(ctx, query)
}
162
163func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
164	sd, err := d.getStmtDecorator(query)
165	if err != nil {
166		return nil, err
167	}
168	stmt := sd.getStmt()
169	defer sd.release()
170	return stmt.Exec(args...)
171}
172
173func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
174	sd, err := d.getStmtDecorator(query)
175	if err != nil {
176		return nil, err
177	}
178	stmt := sd.getStmt()
179	defer sd.release()
180	return stmt.ExecContext(ctx, args...)
181}
182
183func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
184	sd, err := d.getStmtDecorator(query)
185	if err != nil {
186		return nil, err
187	}
188	stmt := sd.getStmt()
189	defer sd.release()
190	return stmt.Query(args...)
191}
192
193func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
194	sd, err := d.getStmtDecorator(query)
195	if err != nil {
196		return nil, err
197	}
198	stmt := sd.getStmt()
199	defer sd.release()
200	return stmt.QueryContext(ctx, args...)
201}
202
203func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row {
204	sd, err := d.getStmtDecorator(query)
205	if err != nil {
206		panic(err)
207	}
208	stmt := sd.getStmt()
209	defer sd.release()
210	return stmt.QueryRow(args...)
211
212}
213
214func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
215	sd, err := d.getStmtDecorator(query)
216	if err != nil {
217		panic(err)
218	}
219	stmt := sd.getStmt()
220	defer sd.release()
221	return stmt.QueryRowContext(ctx, args)
222}
223
// alias describes one registered database: its alias name, driver type and
// driver name, data source string, connection-pool limits, the wrapped DB
// handle, the dbBaser selected for its driver type, the time zone used for
// it, and the storage engine name that detectTZ fills in for MySQL.
type alias struct {
	Name         string
	Driver       DriverType
	DriverName   string
	DataSource   string
	MaxIdleConns int
	MaxOpenConns int
	DB           *DB
	DbBaser      dbBaser
	TZ           *time.Location
	Engine       string
}
236
237func detectTZ(al *alias) {
238	// orm timezone system match database
239	// default use Local
240	al.TZ = DefaultTimeLoc
241
242	if al.DriverName == "sphinx" {
243		return
244	}
245
246	switch al.Driver {
247	case DRMySQL:
248		row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
249		var tz string
250		row.Scan(&tz)
251		if len(tz) >= 8 {
252			if tz[0] != '-' {
253				tz = "+" + tz
254			}
255			t, err := time.Parse("-07:00:00", tz)
256			if err == nil {
257				if t.Location().String() != "" {
258					al.TZ = t.Location()
259				}
260			} else {
261				DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
262			}
263		}
264
265		// get default engine from current database
266		row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'")
267		var engine string
268		var tx bool
269		row.Scan(&engine, &tx)
270
271		if engine != "" {
272			al.Engine = engine
273		} else {
274			al.Engine = "INNODB"
275		}
276
277	case DRSqlite, DROracle:
278		al.TZ = time.UTC
279
280	case DRPostgres:
281		row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
282		var tz string
283		row.Scan(&tz)
284		loc, err := time.LoadLocation(tz)
285		if err == nil {
286			al.TZ = loc
287		} else {
288			DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
289		}
290	}
291}
292
293func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
294	al := new(alias)
295	al.Name = aliasName
296	al.DriverName = driverName
297	al.DB = &DB{
298		RWMutex:        new(sync.RWMutex),
299		DB:             db,
300		stmtDecorators: newStmtDecoratorLruWithEvict(),
301	}
302
303	if dr, ok := drivers[driverName]; ok {
304		al.DbBaser = dbBasers[dr]
305		al.Driver = dr
306	} else {
307		return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
308	}
309
310	err := db.Ping()
311	if err != nil {
312		return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
313	}
314
315	if !dataBaseCache.add(aliasName, al) {
316		return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
317	}
318
319	return al, nil
320}
321
// AddAliasWthDB registers an already opened *sql.DB under aliasName, using
// driverName to select the driver type. It returns whatever error
// addAliasWthDB reports.
func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
	_, err := addAliasWthDB(aliasName, driverName, db)
	return err
}
327
328// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
329func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
330	var (
331		err error
332		db  *sql.DB
333		al  *alias
334	)
335
336	db, err = sql.Open(driverName, dataSource)
337	if err != nil {
338		err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
339		goto end
340	}
341
342	al, err = addAliasWthDB(aliasName, driverName, db)
343	if err != nil {
344		goto end
345	}
346
347	al.DataSource = dataSource
348
349	detectTZ(al)
350
351	for i, v := range params {
352		switch i {
353		case 0:
354			SetMaxIdleConns(al.Name, v)
355		case 1:
356			SetMaxOpenConns(al.Name, v)
357		}
358	}
359
360end:
361	if err != nil {
362		if db != nil {
363			db.Close()
364		}
365		DebugLog.Println(err.Error())
366	}
367
368	return err
369}
370
371// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
372func RegisterDriver(driverName string, typ DriverType) error {
373	if t, ok := drivers[driverName]; !ok {
374		drivers[driverName] = typ
375	} else {
376		if t != typ {
377			return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName)
378		}
379	}
380	return nil
381}
382
383// SetDataBaseTZ Change the database default used timezone
384func SetDataBaseTZ(aliasName string, tz *time.Location) error {
385	if al, ok := dataBaseCache.get(aliasName); ok {
386		al.TZ = tz
387	} else {
388		return fmt.Errorf("DataBase alias name `%s` not registered", aliasName)
389	}
390	return nil
391}
392
393// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name
394func SetMaxIdleConns(aliasName string, maxIdleConns int) {
395	al := getDbAlias(aliasName)
396	al.MaxIdleConns = maxIdleConns
397	al.DB.DB.SetMaxIdleConns(maxIdleConns)
398}
399
400// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
401func SetMaxOpenConns(aliasName string, maxOpenConns int) {
402	al := getDbAlias(aliasName)
403	al.MaxOpenConns = maxOpenConns
404	al.DB.DB.SetMaxOpenConns(maxOpenConns)
405	// for tip go 1.2
406	if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() {
407		fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)})
408	}
409}
410
411// GetDB Get *sql.DB from registered database by db alias name.
412// Use "default" as alias name if you not set.
413func GetDB(aliasNames ...string) (*sql.DB, error) {
414	var name string
415	if len(aliasNames) > 0 {
416		name = aliasNames[0]
417	} else {
418		name = "default"
419	}
420	al, ok := dataBaseCache.get(name)
421	if ok {
422		return al.DB.DB, nil
423	}
424	return nil, fmt.Errorf("DataBase of alias name `%s` not found", name)
425}
426
// stmtDecorator pairs a prepared statement with a WaitGroup that counts its
// current users: acquire adds one, release removes one, and destroy waits for
// the count to drop to zero before closing the statement.
type stmtDecorator struct {
	wg   sync.WaitGroup
	stmt *sql.Stmt
}
431
// getStmt returns the wrapped prepared statement.
func (s *stmtDecorator) getStmt() *sql.Stmt {
	return s.stmt
}
435
// acquire registers one more user of the statement by adding one to the
// WaitGroup counter.
// It is called while a DB lock is held (see getStmtDecorator), so it must not
// do any additional work here.
// TODO: consider refactoring this.
func (s *stmtDecorator) acquire() {
	s.wg.Add(1)
}
443
// release marks one user of the statement as finished by decrementing the
// WaitGroup counter; it is the counterpart of acquire.
func (s *stmtDecorator) release() {
	s.wg.Done()
}
447
448//garbage recycle for stmt
449func (s *stmtDecorator) destroy() {
450	go func() {
451		s.wg.Wait()
452		_ = s.stmt.Close()
453	}()
454}
455
456func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator {
457	return &stmtDecorator{
458		stmt: sqlStmt,
459	}
460}
461
462func newStmtDecoratorLruWithEvict() *lru.Cache {
463	// temporarily solution
464	// we fixed this problem in v2.x
465	cache, _ := lru.NewWithEvict(50, func(key interface{}, value interface{}) {
466		value.(*stmtDecorator).destroy()
467	})
468	return cache
469}
470