/*
Copyright 2014 SAP SE

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package driver

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"io"
	"reflect"
	"sync"
	"time"

	"github.com/SAP/go-hdb/driver/sqltrace"

	p "github.com/SAP/go-hdb/internal/protocol"
)

// DriverVersion is the version number of the hdb driver.
const DriverVersion = "0.14.1"

// DriverName is the driver name to use with sql.Open for hdb databases.
const DriverName = "hdb"

// Transaction isolation levels supported by hdb.
// The values are the SQL keywords inserted into the
// "set transaction isolation level" statement.
const (
	LevelReadCommitted  = "READ COMMITTED"
	LevelRepeatableRead = "REPEATABLE READ"
	LevelSerializable   = "SERIALIZABLE"
)

// Access modes supported by hdb.
// The values are the SQL keywords inserted into the "set transaction" statement.
const (
	modeReadOnly  = "READ ONLY"
	modeReadWrite = "READ WRITE"
)

// isolationLevel maps the database/sql isolation level to the hdb isolation level.
// sql.LevelDefault is mapped to READ COMMITTED; a level without an entry in
// this map is not supported by the driver.
var isolationLevel = map[driver.IsolationLevel]string{
	driver.IsolationLevel(sql.LevelDefault):        LevelReadCommitted,
	driver.IsolationLevel(sql.LevelReadCommitted):  LevelReadCommitted,
	driver.IsolationLevel(sql.LevelRepeatableRead): LevelRepeatableRead,
	driver.IsolationLevel(sql.LevelSerializable):   LevelSerializable,
}

// readOnly maps the database/sql read only flag to the hdb access mode.
var readOnly = map[bool]string{
	true:  modeReadOnly,
	false: modeReadWrite,
}

// ErrUnsupportedIsolationLevel is the error raised if a transaction is started with an unsupported isolation level.
69var ErrUnsupportedIsolationLevel = errors.New("Unsupported isolation level") 70 71// ErrNestedTransaction is the error raised if a tranasction is created within a transaction as this is not supported by hdb. 72var ErrNestedTransaction = errors.New("Nested transactions are not supported") 73 74// needed for testing 75const driverDataFormatVersion = 1 76 77// queries 78const ( 79 pingQuery = "select 1 from dummy" 80 isolationLevelStmt = "set transaction isolation level %s" 81 accessModeStmt = "set transaction %s" 82 sessionVariable = "set %s=%s" 83) 84 85// bulk statement 86const ( 87 bulk = "b$" 88) 89 90var ( 91 flushTok = new(struct{}) 92 noFlushTok = new(struct{}) 93) 94 95var ( 96 // NoFlush is to be used as parameter in bulk statements to delay execution. 97 NoFlush = sql.Named(bulk, &noFlushTok) 98 // Flush can be used as optional parameter in bulk statements but is not required to trigger execution. 99 Flush = sql.Named(bulk, &flushTok) 100) 101 102var drv = &hdbDrv{} 103 104func init() { 105 sql.Register(DriverName, drv) 106} 107 108// driver 109 110// check if driver implements all required interfaces 111var ( 112 _ driver.Driver = (*hdbDrv)(nil) 113 _ driver.DriverContext = (*hdbDrv)(nil) 114) 115 116type hdbDrv struct{} 117 118func (d *hdbDrv) Open(dsn string) (driver.Conn, error) { 119 connector, err := NewDSNConnector(dsn) 120 if err != nil { 121 return nil, err 122 } 123 return connector.Connect(context.Background()) 124} 125 126func (d *hdbDrv) OpenConnector(dsn string) (driver.Connector, error) { 127 return NewDSNConnector(dsn) 128} 129 130// database connection 131 132// check if conn implements all required interfaces 133var ( 134 _ driver.Conn = (*conn)(nil) 135 _ driver.ConnPrepareContext = (*conn)(nil) 136 _ driver.Pinger = (*conn)(nil) 137 _ driver.ConnBeginTx = (*conn)(nil) 138 _ driver.ExecerContext = (*conn)(nil) 139 //go 1.9 issue (ExecerContext is only called if Execer is implemented) 140 _ driver.Execer = (*conn)(nil) 141 _ 
driver.QueryerContext = (*conn)(nil) 142 //go 1.9 issue (QueryerContext is only called if Queryer is implemented) 143 // QueryContext is needed for stored procedures with table output parameters. 144 _ driver.Queryer = (*conn)(nil) 145 _ driver.NamedValueChecker = (*conn)(nil) 146) 147 148type conn struct { 149 session *p.Session 150} 151 152func newConn(ctx context.Context, c *Connector) (driver.Conn, error) { 153 session, err := p.NewSession(ctx, c) 154 if err != nil { 155 return nil, err 156 } 157 conn := &conn{session: session} 158 if err := conn.init(ctx, c.sessionVariables); err != nil { 159 return nil, err 160 } 161 return conn, nil 162} 163 164func (c *conn) init(ctx context.Context, sv SessionVariables) error { 165 if sv == nil { 166 return nil 167 } 168 for k, v := range sv { 169 if _, err := c.ExecContext(ctx, fmt.Sprintf(sessionVariable, fmt.Sprintf("'%s'", k), fmt.Sprintf("'%s'", v)), nil); err != nil { 170 return err 171 } 172 } 173 return nil 174} 175 176func (c *conn) Prepare(query string) (driver.Stmt, error) { 177 panic("deprecated") 178} 179 180func (c *conn) Close() error { 181 c.session.Close() 182 return nil 183} 184 185func (c *conn) Begin() (driver.Tx, error) { 186 panic("deprecated") 187} 188 189func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) { 190 191 if c.session.IsBad() { 192 return nil, driver.ErrBadConn 193 } 194 195 if c.session.InTx() { 196 return nil, ErrNestedTransaction 197 } 198 199 level, ok := isolationLevel[opts.Isolation] 200 if !ok { 201 return nil, ErrUnsupportedIsolationLevel 202 } 203 204 done := make(chan struct{}) 205 go func() { 206 // set isolation level 207 if _, err = c.ExecContext(ctx, fmt.Sprintf(isolationLevelStmt, level), nil); err != nil { 208 goto done 209 } 210 // set access mode 211 if _, err = c.ExecContext(ctx, fmt.Sprintf(accessModeStmt, readOnly[opts.ReadOnly]), nil); err != nil { 212 goto done 213 } 214 c.session.SetInTx(true) 215 tx = newTx(c.session) 216 
done: 217 close(done) 218 }() 219 220 select { 221 case <-ctx.Done(): 222 return nil, ctx.Err() 223 case <-done: 224 return tx, err 225 } 226} 227 228// Exec implements the database/sql/driver/Execer interface. 229// delete after go 1.9 compatibility is given up. 230func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) { 231 panic("deprecated") 232} 233 234// ExecContext implements the database/sql/driver/ExecerContext interface. 235func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) { 236 if c.session.IsBad() { 237 return nil, driver.ErrBadConn 238 } 239 240 if len(args) != 0 { 241 return nil, driver.ErrSkip //fast path not possible (prepare needed) 242 } 243 244 sqltrace.Traceln(query) 245 246 done := make(chan struct{}) 247 go func() { 248 r, err = c.session.ExecDirect(query) 249 close(done) 250 }() 251 252 select { 253 case <-ctx.Done(): 254 return nil, ctx.Err() 255 case <-done: 256 return r, err 257 } 258} 259 260// Queryer implements the database/sql/driver/Queryer interface. 261// delete after go 1.9 compatibility is given up. 262func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) { 263 panic("deprecated") 264} 265 266func (c *conn) Ping(ctx context.Context) (err error) { 267 if c.session.IsBad() { 268 return driver.ErrBadConn 269 } 270 271 done := make(chan struct{}) 272 go func() { 273 _, err = c.QueryContext(ctx, pingQuery, nil) 274 close(done) 275 }() 276 277 select { 278 case <-ctx.Done(): 279 return ctx.Err() 280 case <-done: 281 return err 282 } 283} 284 285// CheckNamedValue implements NamedValueChecker interface. 
286// implemented for conn: 287// if querier or execer is called, sql checks parameters before 288// in case of parameters the method can be 'skipped' and force the prepare path 289// --> guarantee that a valid driver value is returned 290// --> if not implemented, Lob need to have a pseudo Value method to return a valid driver value 291func (c *conn) CheckNamedValue(nv *driver.NamedValue) error { 292 switch nv.Value.(type) { 293 case Lob, *Lob: 294 nv.Value = nil 295 } 296 return nil 297} 298 299//transaction 300 301// check if tx implements all required interfaces 302var ( 303 _ driver.Tx = (*tx)(nil) 304) 305 306type tx struct { 307 session *p.Session 308} 309 310func newTx(session *p.Session) *tx { 311 return &tx{ 312 session: session, 313 } 314} 315 316func (t *tx) Commit() error { 317 if t.session.IsBad() { 318 return driver.ErrBadConn 319 } 320 321 return t.session.Commit() 322} 323 324func (t *tx) Rollback() error { 325 if t.session.IsBad() { 326 return driver.ErrBadConn 327 } 328 329 return t.session.Rollback() 330} 331 332//statement 333 334var argsPool = sync.Pool{} 335 336// check if stmt implements all required interfaces 337var ( 338 _ driver.Stmt = (*stmt)(nil) 339 _ driver.StmtExecContext = (*stmt)(nil) 340 _ driver.StmtQueryContext = (*stmt)(nil) 341 _ driver.NamedValueChecker = (*stmt)(nil) 342) 343 344type stmt struct { 345 qt p.QueryType 346 session *p.Session 347 query string 348 id uint64 349 prmFieldSet *p.ParameterFieldSet 350 resultFieldSet *p.ResultFieldSet 351 bulk, noFlush bool 352 numArg int 353 args []driver.NamedValue 354} 355 356func newStmt(qt p.QueryType, session *p.Session, query string, id uint64, prmFieldSet *p.ParameterFieldSet, resultFieldSet *p.ResultFieldSet) (*stmt, error) { 357 return &stmt{qt: qt, session: session, query: query, id: id, prmFieldSet: prmFieldSet, resultFieldSet: resultFieldSet}, nil 358} 359 360func (s *stmt) Close() error { 361 if s.args != nil { 362 if len(s.args) != 0 { 363 sqltrace.Tracef("close: %s - 
not flushed records: %d)", s.query, int(len(s.args)/s.NumInput())) 364 } 365 argsPool.Put(s.args) 366 s.args = nil 367 } 368 return s.session.DropStatementID(s.id) 369} 370 371func (s *stmt) NumInput() int { 372 return s.prmFieldSet.NumInputField() 373} 374 375func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { 376 panic("deprecated") 377} 378 379func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) { 380 if s.session.IsBad() { 381 return nil, driver.ErrBadConn 382 } 383 384 numField := s.prmFieldSet.NumInputField() 385 if len(args) != numField { 386 return nil, fmt.Errorf("invalid number of arguments %d - %d expected", len(args), numField) 387 } 388 389 sqltrace.Tracef("%s %v", s.query, args) 390 391 // init noFlush 392 noFlush := s.noFlush 393 s.noFlush = false 394 395 var _args []driver.NamedValue 396 397 done := make(chan struct{}) 398 399 if !s.bulk { 400 go func() { 401 r, err = s.session.Exec(s.id, s.prmFieldSet, args) 402 close(done) 403 }() 404 goto done 405 } 406 407 if s.args == nil { 408 s.args, _ = argsPool.Get().([]driver.NamedValue) 409 if s.args == nil { 410 s.args = make([]driver.NamedValue, 0, len(args)*1000) 411 } 412 s.args = s.args[:0] 413 } 414 415 s.args = append(s.args, args...) 
416 s.numArg++ 417 418 if noFlush && s.numArg < maxSmallint { //TODO: check why bigArgument count does not work 419 return driver.ResultNoRows, nil 420 } 421 422 _args, _ = argsPool.Get().([]driver.NamedValue) 423 if _args == nil || cap(_args) < len(s.args) { 424 _args = make([]driver.NamedValue, len(s.args)) 425 } 426 _args = _args[:len(s.args)] 427 428 copy(_args, s.args) 429 s.args = s.args[:0] 430 s.numArg = 0 431 432 go func() { 433 r, err = s.session.Exec(s.id, s.prmFieldSet, _args) 434 argsPool.Put(_args) 435 close(done) 436 }() 437 438done: 439 select { 440 case <-ctx.Done(): 441 return nil, ctx.Err() 442 case <-done: 443 return r, err 444 } 445} 446 447func (s *stmt) Query(args []driver.Value) (rows driver.Rows, err error) { 448 panic("deprecated") 449} 450 451// Deprecated: see NamedValueChecker. 452//func (s *stmt) ColumnConverter(idx int) driver.ValueConverter { 453//} 454 455// CheckNamedValue implements NamedValueChecker interface. 456func (s *stmt) CheckNamedValue(nv *driver.NamedValue) error { 457 if nv.Name == bulk { 458 if ptr, ok := nv.Value.(**struct{}); ok { 459 switch ptr { 460 case &noFlushTok: 461 s.bulk, s.noFlush = true, true 462 return driver.ErrRemoveArgument 463 case &flushTok: 464 return driver.ErrRemoveArgument 465 } 466 } 467 } 468 return checkNamedValue(s.prmFieldSet, nv) 469} 470 471// driver.Rows drop-in replacement if driver Query or QueryRow is used for statements that doesn't return rows 472var noColumns = []string{} 473var noResult = new(noResultType) 474 475// check if noResultType implements all required interfaces 476var ( 477 _ driver.Rows = (*noResultType)(nil) 478) 479 480type noResultType struct{} 481 482func (r *noResultType) Columns() []string { return noColumns } 483func (r *noResultType) Close() error { return nil } 484func (r *noResultType) Next(dest []driver.Value) error { return io.EOF } 485 486// rows 487type rows struct { 488} 489 490// query result 491 492// check if queryResult implements all required 
interfaces 493var ( 494 _ driver.Rows = (*queryResult)(nil) 495 _ driver.RowsColumnTypeDatabaseTypeName = (*queryResult)(nil) // go 1.8 496 _ driver.RowsColumnTypeLength = (*queryResult)(nil) // go 1.8 497 _ driver.RowsColumnTypeNullable = (*queryResult)(nil) // go 1.8 498 _ driver.RowsColumnTypePrecisionScale = (*queryResult)(nil) // go 1.8 499 _ driver.RowsColumnTypeScanType = (*queryResult)(nil) // go 1.8 500) 501 502type queryResult struct { 503 session *p.Session 504 id uint64 505 resultFieldSet *p.ResultFieldSet 506 fieldValues *p.FieldValues 507 pos int 508 attrs p.PartAttributes 509 columns []string 510 lastErr error 511} 512 513func newQueryResult(session *p.Session, id uint64, resultFieldSet *p.ResultFieldSet, fieldValues *p.FieldValues, attrs p.PartAttributes) (driver.Rows, error) { 514 columns := make([]string, resultFieldSet.NumField()) 515 for i := 0; i < len(columns); i++ { 516 columns[i] = resultFieldSet.Field(i).Name() 517 } 518 519 return &queryResult{ 520 session: session, 521 id: id, 522 resultFieldSet: resultFieldSet, 523 fieldValues: fieldValues, 524 attrs: attrs, 525 columns: columns, 526 }, nil 527} 528 529func (r *queryResult) Columns() []string { 530 return r.columns 531} 532 533func (r *queryResult) Close() error { 534 // if lastError is set, attrs are nil 535 if r.lastErr != nil { 536 return r.lastErr 537 } 538 539 if !r.attrs.ResultsetClosed() { 540 return r.session.CloseResultsetID(r.id) 541 } 542 return nil 543} 544 545func (r *queryResult) Next(dest []driver.Value) error { 546 if r.session.IsBad() { 547 return driver.ErrBadConn 548 } 549 550 if r.pos >= r.fieldValues.NumRow() { 551 if r.attrs.LastPacket() { 552 return io.EOF 553 } 554 555 var err error 556 557 if r.attrs, err = r.session.FetchNext(r.id, r.resultFieldSet, r.fieldValues); err != nil { 558 r.lastErr = err //fieldValues and attrs are nil 559 return err 560 } 561 562 if r.attrs.NoRows() { 563 return io.EOF 564 } 565 566 r.pos = 0 567 568 } 569 570 r.fieldValues.Row(r.pos, 
dest) 571 r.pos++ 572 573 return nil 574} 575 576func (r *queryResult) ColumnTypeDatabaseTypeName(idx int) string { 577 return r.resultFieldSet.Field(idx).TypeCode().TypeName() 578} 579 580func (r *queryResult) ColumnTypeLength(idx int) (int64, bool) { 581 return r.resultFieldSet.Field(idx).TypeLength() 582} 583 584func (r *queryResult) ColumnTypePrecisionScale(idx int) (int64, int64, bool) { 585 return r.resultFieldSet.Field(idx).TypePrecisionScale() 586} 587 588func (r *queryResult) ColumnTypeNullable(idx int) (bool, bool) { 589 return r.resultFieldSet.Field(idx).Nullable(), true 590} 591 592var ( 593 scanTypeUnknown = reflect.TypeOf(new(interface{})).Elem() 594 scanTypeTinyint = reflect.TypeOf(uint8(0)) 595 scanTypeSmallint = reflect.TypeOf(int16(0)) 596 scanTypeInteger = reflect.TypeOf(int32(0)) 597 scanTypeBigint = reflect.TypeOf(int64(0)) 598 scanTypeReal = reflect.TypeOf(float32(0.0)) 599 scanTypeDouble = reflect.TypeOf(float64(0.0)) 600 scanTypeTime = reflect.TypeOf(time.Time{}) 601 scanTypeString = reflect.TypeOf(string("")) 602 scanTypeBytes = reflect.TypeOf([]byte{}) 603 scanTypeDecimal = reflect.TypeOf(Decimal{}) 604 scanTypeLob = reflect.TypeOf(Lob{}) 605) 606 607func (r *queryResult) ColumnTypeScanType(idx int) reflect.Type { 608 switch r.resultFieldSet.Field(idx).TypeCode().DataType() { 609 default: 610 return scanTypeUnknown 611 case p.DtTinyint: 612 return scanTypeTinyint 613 case p.DtSmallint: 614 return scanTypeSmallint 615 case p.DtInteger: 616 return scanTypeInteger 617 case p.DtBigint: 618 return scanTypeBigint 619 case p.DtReal: 620 return scanTypeReal 621 case p.DtDouble: 622 return scanTypeDouble 623 case p.DtTime: 624 return scanTypeTime 625 case p.DtDecimal: 626 return scanTypeDecimal 627 case p.DtString: 628 return scanTypeString 629 case p.DtBytes: 630 return scanTypeBytes 631 case p.DtLob: 632 return scanTypeLob 633 } 634} 635