1// Copyright 2014 beego Author. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package orm 16 17import ( 18 "context" 19 "database/sql" 20 "fmt" 21 "reflect" 22 "sync" 23 "time" 24 25 lru "github.com/hashicorp/golang-lru" 26) 27 28// DriverType database driver constant int. 29type DriverType int 30 31// Enum the Database driver 32const ( 33 _ DriverType = iota // int enum type 34 DRMySQL // mysql 35 DRSqlite // sqlite 36 DROracle // oracle 37 DRPostgres // pgsql 38 DRTiDB // TiDB 39) 40 41// database driver string. 42type driver string 43 44// get type constant int of current driver.. 45func (d driver) Type() DriverType { 46 a, _ := dataBaseCache.get(string(d)) 47 return a.Driver 48} 49 50// get name of current driver 51func (d driver) Name() string { 52 return string(d) 53} 54 55// check driver iis implemented Driver interface or not. 56var _ Driver = new(driver) 57 58var ( 59 dataBaseCache = &_dbCache{cache: make(map[string]*alias)} 60 drivers = map[string]DriverType{ 61 "mysql": DRMySQL, 62 "postgres": DRPostgres, 63 "sqlite3": DRSqlite, 64 "tidb": DRTiDB, 65 "oracle": DROracle, 66 "oci8": DROracle, // github.com/mattn/go-oci8 67 "ora": DROracle, //https://github.com/rana/ora 68 } 69 dbBasers = map[DriverType]dbBaser{ 70 DRMySQL: newdbBaseMysql(), 71 DRSqlite: newdbBaseSqlite(), 72 DROracle: newdbBaseOracle(), 73 DRPostgres: newdbBasePostgres(), 74 DRTiDB: newdbBaseTidb(), 75 } 76) 77 78// database alias cacher. 79type _dbCache struct { 80 mux sync.RWMutex 81 cache map[string]*alias 82} 83 84// add database alias with original name. 85func (ac *_dbCache) add(name string, al *alias) (added bool) { 86 ac.mux.Lock() 87 defer ac.mux.Unlock() 88 if _, ok := ac.cache[name]; !ok { 89 ac.cache[name] = al 90 added = true 91 } 92 return 93} 94 95// get database alias if cached. 96func (ac *_dbCache) get(name string) (al *alias, ok bool) { 97 ac.mux.RLock() 98 defer ac.mux.RUnlock() 99 al, ok = ac.cache[name] 100 return 101} 102 103// get default alias. 104func (ac *_dbCache) getDefault() (al *alias) { 105 al, _ = ac.get("default") 106 return 107} 108 109type DB struct { 110 *sync.RWMutex 111 DB *sql.DB 112 stmtDecorators *lru.Cache 113} 114 115func (d *DB) Begin() (*sql.Tx, error) { 116 return d.DB.Begin() 117} 118 119func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { 120 return d.DB.BeginTx(ctx, opts) 121} 122 123//su must call release to release *sql.Stmt after using 124func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { 125 d.RLock() 126 c, ok := d.stmtDecorators.Get(query) 127 if ok { 128 c.(*stmtDecorator).acquire() 129 d.RUnlock() 130 return c.(*stmtDecorator), nil 131 } 132 d.RUnlock() 133 134 d.Lock() 135 c, ok = d.stmtDecorators.Get(query) 136 if ok { 137 c.(*stmtDecorator).acquire() 138 d.Unlock() 139 return c.(*stmtDecorator), nil 140 } 141 142 stmt, err := d.Prepare(query) 143 if err != nil { 144 d.Unlock() 145 return nil, err 146 } 147 sd := newStmtDecorator(stmt) 148 sd.acquire() 149 d.stmtDecorators.Add(query, sd) 150 d.Unlock() 151 152 return sd, nil 153} 154 155func (d *DB) Prepare(query string) (*sql.Stmt, error) { 156 return d.DB.Prepare(query) 157} 158 159func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { 160 return d.DB.PrepareContext(ctx, query) 161} 162 163func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { 164 sd, err := d.getStmtDecorator(query) 165 if err != nil { 166 return nil, err 167 } 168 stmt := sd.getStmt() 169 defer sd.release() 170 return stmt.Exec(args...) 171} 172 173func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { 174 sd, err := d.getStmtDecorator(query) 175 if err != nil { 176 return nil, err 177 } 178 stmt := sd.getStmt() 179 defer sd.release() 180 return stmt.ExecContext(ctx, args...) 181} 182 183func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { 184 sd, err := d.getStmtDecorator(query) 185 if err != nil { 186 return nil, err 187 } 188 stmt := sd.getStmt() 189 defer sd.release() 190 return stmt.Query(args...) 191} 192 193func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { 194 sd, err := d.getStmtDecorator(query) 195 if err != nil { 196 return nil, err 197 } 198 stmt := sd.getStmt() 199 defer sd.release() 200 return stmt.QueryContext(ctx, args...) 201} 202 203func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { 204 sd, err := d.getStmtDecorator(query) 205 if err != nil { 206 panic(err) 207 } 208 stmt := sd.getStmt() 209 defer sd.release() 210 return stmt.QueryRow(args...) 211 212} 213 214func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { 215 sd, err := d.getStmtDecorator(query) 216 if err != nil { 217 panic(err) 218 } 219 stmt := sd.getStmt() 220 defer sd.release() 221 return stmt.QueryRowContext(ctx, args) 222} 223 224type alias struct { 225 Name string 226 Driver DriverType 227 DriverName string 228 DataSource string 229 MaxIdleConns int 230 MaxOpenConns int 231 DB *DB 232 DbBaser dbBaser 233 TZ *time.Location 234 Engine string 235} 236 237func detectTZ(al *alias) { 238 // orm timezone system match database 239 // default use Local 240 al.TZ = DefaultTimeLoc 241 242 if al.DriverName == "sphinx" { 243 return 244 } 245 246 switch al.Driver { 247 case DRMySQL: 248 row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)") 249 var tz string 250 row.Scan(&tz) 251 if len(tz) >= 8 { 252 if tz[0] != '-' { 253 tz = "+" + tz 254 } 255 t, err := time.Parse("-07:00:00", tz) 256 if err == nil { 257 if t.Location().String() != "" { 258 al.TZ = t.Location() 259 } 260 } else { 261 DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) 262 } 263 } 264 265 // get default engine from current database 266 row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'") 267 var engine string 268 var tx bool 269 row.Scan(&engine, &tx) 270 271 if engine != "" { 272 al.Engine = engine 273 } else { 274 al.Engine = "INNODB" 275 } 276 277 case DRSqlite, DROracle: 278 al.TZ = time.UTC 279 280 case DRPostgres: 281 row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')") 282 var tz string 283 row.Scan(&tz) 284 loc, err := time.LoadLocation(tz) 285 if err == nil { 286 al.TZ = loc 287 } else { 288 DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) 289 } 290 } 291} 292 293func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { 294 al := new(alias) 295 al.Name = aliasName 296 al.DriverName = driverName 297 al.DB = &DB{ 298 RWMutex: new(sync.RWMutex), 299 DB: db, 300 stmtDecorators: newStmtDecoratorLruWithEvict(), 301 } 302 303 if dr, ok := drivers[driverName]; ok { 304 al.DbBaser = dbBasers[dr] 305 al.Driver = dr 306 } else { 307 return nil, fmt.Errorf("driver name `%s` have not registered", driverName) 308 } 309 310 err := db.Ping() 311 if err != nil { 312 return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) 313 } 314 315 if !dataBaseCache.add(aliasName, al) { 316 return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) 317 } 318 319 return al, nil 320} 321 322// AddAliasWthDB add a aliasName for the drivename 323func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { 324 _, err := addAliasWthDB(aliasName, driverName, db) 325 return err 326} 327 328// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. 329func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { 330 var ( 331 err error 332 db *sql.DB 333 al *alias 334 ) 335 336 db, err = sql.Open(driverName, dataSource) 337 if err != nil { 338 err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) 339 goto end 340 } 341 342 al, err = addAliasWthDB(aliasName, driverName, db) 343 if err != nil { 344 goto end 345 } 346 347 al.DataSource = dataSource 348 349 detectTZ(al) 350 351 for i, v := range params { 352 switch i { 353 case 0: 354 SetMaxIdleConns(al.Name, v) 355 case 1: 356 SetMaxOpenConns(al.Name, v) 357 } 358 } 359 360end: 361 if err != nil { 362 if db != nil { 363 db.Close() 364 } 365 DebugLog.Println(err.Error()) 366 } 367 368 return err 369} 370 371// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. 372func RegisterDriver(driverName string, typ DriverType) error { 373 if t, ok := drivers[driverName]; !ok { 374 drivers[driverName] = typ 375 } else { 376 if t != typ { 377 return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName) 378 } 379 } 380 return nil 381} 382 383// SetDataBaseTZ Change the database default used timezone 384func SetDataBaseTZ(aliasName string, tz *time.Location) error { 385 if al, ok := dataBaseCache.get(aliasName); ok { 386 al.TZ = tz 387 } else { 388 return fmt.Errorf("DataBase alias name `%s` not registered", aliasName) 389 } 390 return nil 391} 392 393// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name 394func SetMaxIdleConns(aliasName string, maxIdleConns int) { 395 al := getDbAlias(aliasName) 396 al.MaxIdleConns = maxIdleConns 397 al.DB.DB.SetMaxIdleConns(maxIdleConns) 398} 399 400// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name 401func SetMaxOpenConns(aliasName string, maxOpenConns int) { 402 al := getDbAlias(aliasName) 403 al.MaxOpenConns = maxOpenConns 404 al.DB.DB.SetMaxOpenConns(maxOpenConns) 405 // for tip go 1.2 406 if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() { 407 fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)}) 408 } 409} 410 411// GetDB Get *sql.DB from registered database by db alias name. 412// Use "default" as alias name if you not set. 413func GetDB(aliasNames ...string) (*sql.DB, error) { 414 var name string 415 if len(aliasNames) > 0 { 416 name = aliasNames[0] 417 } else { 418 name = "default" 419 } 420 al, ok := dataBaseCache.get(name) 421 if ok { 422 return al.DB.DB, nil 423 } 424 return nil, fmt.Errorf("DataBase of alias name `%s` not found", name) 425} 426 427type stmtDecorator struct { 428 wg sync.WaitGroup 429 stmt *sql.Stmt 430} 431 432func (s *stmtDecorator) getStmt() *sql.Stmt { 433 return s.stmt 434} 435 436// acquire will add one 437// since this method will be used inside read lock scope, 438// so we can not do more things here 439// we should think about refactor this 440func (s *stmtDecorator) acquire() { 441 s.wg.Add(1) 442} 443 444func (s *stmtDecorator) release() { 445 s.wg.Done() 446} 447 448//garbage recycle for stmt 449func (s *stmtDecorator) destroy() { 450 go func() { 451 s.wg.Wait() 452 _ = s.stmt.Close() 453 }() 454} 455 456func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator { 457 return &stmtDecorator{ 458 stmt: sqlStmt, 459 } 460} 461 462func newStmtDecoratorLruWithEvict() *lru.Cache { 463 // temporarily solution 464 // we fixed this problem in v2.x 465 cache, _ := lru.NewWithEvict(50, func(key interface{}, value interface{}) { 466 value.(*stmtDecorator).destroy() 467 }) 468 return cache 469} 470