1/*
2Copyright 2014 SAP SE
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package driver
18
19import (
20	"context"
21	"database/sql"
22	"database/sql/driver"
23	"errors"
24	"fmt"
25	"io"
26	"reflect"
27	"sync"
28	"time"
29
30	"github.com/SAP/go-hdb/driver/sqltrace"
31
32	p "github.com/SAP/go-hdb/internal/protocol"
33)
34
const (
	// DriverVersion is the version number of the hdb driver.
	DriverVersion = "0.14.1"
	// DriverName is the driver name to use with sql.Open for hdb databases.
	DriverName = "hdb"
)

// Transaction isolation levels supported by hdb.
const (
	LevelReadCommitted  = "READ COMMITTED"
	LevelRepeatableRead = "REPEATABLE READ"
	LevelSerializable   = "SERIALIZABLE"
)

// Transaction access modes supported by hdb.
const (
	modeReadOnly  = "READ ONLY"
	modeReadWrite = "READ WRITE"
)

// isolationLevel translates a database/sql isolation level into the hdb
// isolation level name. sql.LevelDefault is treated as READ COMMITTED; any
// level that is not a key of this map is unsupported.
var isolationLevel = map[driver.IsolationLevel]string{
	driver.IsolationLevel(sql.LevelDefault):        LevelReadCommitted,
	driver.IsolationLevel(sql.LevelReadCommitted):  LevelReadCommitted,
	driver.IsolationLevel(sql.LevelRepeatableRead): LevelRepeatableRead,
	driver.IsolationLevel(sql.LevelSerializable):   LevelSerializable,
}

// readOnly translates the read-only flag of the transaction options into the
// hdb access mode.
var readOnly = map[bool]string{
	false: modeReadWrite,
	true:  modeReadOnly,
}
67
var (
	// ErrUnsupportedIsolationLevel is the error raised if a transaction is started with a not supported isolation level.
	ErrUnsupportedIsolationLevel = errors.New("unsupported isolation level")

	// ErrNestedTransaction is the error raised if a transaction is created within a transaction as this is not supported by hdb.
	ErrNestedTransaction = errors.New("nested transactions are not supported")
)
73
// driverDataFormatVersion is kept as a named constant because tests need it.
const driverDataFormatVersion = 1

// SQL statement templates issued by the driver itself.
const (
	pingQuery          = "select 1 from dummy"
	isolationLevelStmt = "set transaction isolation level %s"
	accessModeStmt     = "set transaction %s"
	sessionVariable    = "set %s=%s"
)

// bulk is the argument name under which the Flush / NoFlush control tokens are
// passed to bulk statements.
const bulk = "b$"

// Control tokens for bulk statements. Pointers to distinct zero-size variables
// may compare equal, so a token is identified by the address of its variable
// (&flushTok / &noFlushTok), never by the pointer value it holds.
var (
	flushTok   = new(struct{})
	noFlushTok = new(struct{})
)

var (
	// Flush can be used as optional parameter in bulk statements but is not required to trigger execution.
	Flush = sql.Named(bulk, &flushTok)
	// NoFlush is to be used as parameter in bulk statements to delay execution.
	NoFlush = sql.Named(bulk, &noFlushTok)
)
101
102var drv = &hdbDrv{}
103
104func init() {
105	sql.Register(DriverName, drv)
106}
107
108// driver
109
110//  check if driver implements all required interfaces
111var (
112	_ driver.Driver        = (*hdbDrv)(nil)
113	_ driver.DriverContext = (*hdbDrv)(nil)
114)
115
116type hdbDrv struct{}
117
118func (d *hdbDrv) Open(dsn string) (driver.Conn, error) {
119	connector, err := NewDSNConnector(dsn)
120	if err != nil {
121		return nil, err
122	}
123	return connector.Connect(context.Background())
124}
125
126func (d *hdbDrv) OpenConnector(dsn string) (driver.Connector, error) {
127	return NewDSNConnector(dsn)
128}
129
130// database connection
131
132//  check if conn implements all required interfaces
133var (
134	_ driver.Conn               = (*conn)(nil)
135	_ driver.ConnPrepareContext = (*conn)(nil)
136	_ driver.Pinger             = (*conn)(nil)
137	_ driver.ConnBeginTx        = (*conn)(nil)
138	_ driver.ExecerContext      = (*conn)(nil)
139	//go 1.9 issue (ExecerContext is only called if Execer is implemented)
140	_ driver.Execer         = (*conn)(nil)
141	_ driver.QueryerContext = (*conn)(nil)
142	//go 1.9 issue (QueryerContext is only called if Queryer is implemented)
143	// QueryContext is needed for stored procedures with table output parameters.
144	_ driver.Queryer           = (*conn)(nil)
145	_ driver.NamedValueChecker = (*conn)(nil)
146)
147
148type conn struct {
149	session *p.Session
150}
151
152func newConn(ctx context.Context, c *Connector) (driver.Conn, error) {
153	session, err := p.NewSession(ctx, c)
154	if err != nil {
155		return nil, err
156	}
157	conn := &conn{session: session}
158	if err := conn.init(ctx, c.sessionVariables); err != nil {
159		return nil, err
160	}
161	return conn, nil
162}
163
164func (c *conn) init(ctx context.Context, sv SessionVariables) error {
165	if sv == nil {
166		return nil
167	}
168	for k, v := range sv {
169		if _, err := c.ExecContext(ctx, fmt.Sprintf(sessionVariable, fmt.Sprintf("'%s'", k), fmt.Sprintf("'%s'", v)), nil); err != nil {
170			return err
171		}
172	}
173	return nil
174}
175
176func (c *conn) Prepare(query string) (driver.Stmt, error) {
177	panic("deprecated")
178}
179
180func (c *conn) Close() error {
181	c.session.Close()
182	return nil
183}
184
185func (c *conn) Begin() (driver.Tx, error) {
186	panic("deprecated")
187}
188
189func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) {
190
191	if c.session.IsBad() {
192		return nil, driver.ErrBadConn
193	}
194
195	if c.session.InTx() {
196		return nil, ErrNestedTransaction
197	}
198
199	level, ok := isolationLevel[opts.Isolation]
200	if !ok {
201		return nil, ErrUnsupportedIsolationLevel
202	}
203
204	done := make(chan struct{})
205	go func() {
206		// set isolation level
207		if _, err = c.ExecContext(ctx, fmt.Sprintf(isolationLevelStmt, level), nil); err != nil {
208			goto done
209		}
210		// set access mode
211		if _, err = c.ExecContext(ctx, fmt.Sprintf(accessModeStmt, readOnly[opts.ReadOnly]), nil); err != nil {
212			goto done
213		}
214		c.session.SetInTx(true)
215		tx = newTx(c.session)
216	done:
217		close(done)
218	}()
219
220	select {
221	case <-ctx.Done():
222		return nil, ctx.Err()
223	case <-done:
224		return tx, err
225	}
226}
227
228// Exec implements the database/sql/driver/Execer interface.
229// delete after go 1.9 compatibility is given up.
230func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
231	panic("deprecated")
232}
233
234// ExecContext implements the database/sql/driver/ExecerContext interface.
235func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) {
236	if c.session.IsBad() {
237		return nil, driver.ErrBadConn
238	}
239
240	if len(args) != 0 {
241		return nil, driver.ErrSkip //fast path not possible (prepare needed)
242	}
243
244	sqltrace.Traceln(query)
245
246	done := make(chan struct{})
247	go func() {
248		r, err = c.session.ExecDirect(query)
249		close(done)
250	}()
251
252	select {
253	case <-ctx.Done():
254		return nil, ctx.Err()
255	case <-done:
256		return r, err
257	}
258}
259
260// Queryer implements the database/sql/driver/Queryer interface.
261// delete after go 1.9 compatibility is given up.
262func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
263	panic("deprecated")
264}
265
266func (c *conn) Ping(ctx context.Context) (err error) {
267	if c.session.IsBad() {
268		return driver.ErrBadConn
269	}
270
271	done := make(chan struct{})
272	go func() {
273		_, err = c.QueryContext(ctx, pingQuery, nil)
274		close(done)
275	}()
276
277	select {
278	case <-ctx.Done():
279		return ctx.Err()
280	case <-done:
281		return err
282	}
283}
284
285// CheckNamedValue implements NamedValueChecker interface.
286// implemented for conn:
287// if querier or execer is called, sql checks parameters before
288// in case of parameters the method can be 'skipped' and force the prepare path
289// --> guarantee that a valid driver value is returned
290// --> if not implemented, Lob need to have a pseudo Value method to return a valid driver value
291func (c *conn) CheckNamedValue(nv *driver.NamedValue) error {
292	switch nv.Value.(type) {
293	case Lob, *Lob:
294		nv.Value = nil
295	}
296	return nil
297}
298
299//transaction
300
301//  check if tx implements all required interfaces
302var (
303	_ driver.Tx = (*tx)(nil)
304)
305
306type tx struct {
307	session *p.Session
308}
309
310func newTx(session *p.Session) *tx {
311	return &tx{
312		session: session,
313	}
314}
315
316func (t *tx) Commit() error {
317	if t.session.IsBad() {
318		return driver.ErrBadConn
319	}
320
321	return t.session.Commit()
322}
323
324func (t *tx) Rollback() error {
325	if t.session.IsBad() {
326		return driver.ErrBadConn
327	}
328
329	return t.session.Rollback()
330}
331
332//statement
333
334var argsPool = sync.Pool{}
335
336//  check if stmt implements all required interfaces
337var (
338	_ driver.Stmt              = (*stmt)(nil)
339	_ driver.StmtExecContext   = (*stmt)(nil)
340	_ driver.StmtQueryContext  = (*stmt)(nil)
341	_ driver.NamedValueChecker = (*stmt)(nil)
342)
343
344type stmt struct {
345	qt             p.QueryType
346	session        *p.Session
347	query          string
348	id             uint64
349	prmFieldSet    *p.ParameterFieldSet
350	resultFieldSet *p.ResultFieldSet
351	bulk, noFlush  bool
352	numArg         int
353	args           []driver.NamedValue
354}
355
356func newStmt(qt p.QueryType, session *p.Session, query string, id uint64, prmFieldSet *p.ParameterFieldSet, resultFieldSet *p.ResultFieldSet) (*stmt, error) {
357	return &stmt{qt: qt, session: session, query: query, id: id, prmFieldSet: prmFieldSet, resultFieldSet: resultFieldSet}, nil
358}
359
360func (s *stmt) Close() error {
361	if s.args != nil {
362		if len(s.args) != 0 {
363			sqltrace.Tracef("close: %s - not flushed records: %d)", s.query, int(len(s.args)/s.NumInput()))
364		}
365		argsPool.Put(s.args)
366		s.args = nil
367	}
368	return s.session.DropStatementID(s.id)
369}
370
371func (s *stmt) NumInput() int {
372	return s.prmFieldSet.NumInputField()
373}
374
375func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
376	panic("deprecated")
377}
378
379func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) {
380	if s.session.IsBad() {
381		return nil, driver.ErrBadConn
382	}
383
384	numField := s.prmFieldSet.NumInputField()
385	if len(args) != numField {
386		return nil, fmt.Errorf("invalid number of arguments %d - %d expected", len(args), numField)
387	}
388
389	sqltrace.Tracef("%s %v", s.query, args)
390
391	// init noFlush
392	noFlush := s.noFlush
393	s.noFlush = false
394
395	var _args []driver.NamedValue
396
397	done := make(chan struct{})
398
399	if !s.bulk {
400		go func() {
401			r, err = s.session.Exec(s.id, s.prmFieldSet, args)
402			close(done)
403		}()
404		goto done
405	}
406
407	if s.args == nil {
408		s.args, _ = argsPool.Get().([]driver.NamedValue)
409		if s.args == nil {
410			s.args = make([]driver.NamedValue, 0, len(args)*1000)
411		}
412		s.args = s.args[:0]
413	}
414
415	s.args = append(s.args, args...)
416	s.numArg++
417
418	if noFlush && s.numArg < maxSmallint { //TODO: check why bigArgument count does not work
419		return driver.ResultNoRows, nil
420	}
421
422	_args, _ = argsPool.Get().([]driver.NamedValue)
423	if _args == nil || cap(_args) < len(s.args) {
424		_args = make([]driver.NamedValue, len(s.args))
425	}
426	_args = _args[:len(s.args)]
427
428	copy(_args, s.args)
429	s.args = s.args[:0]
430	s.numArg = 0
431
432	go func() {
433		r, err = s.session.Exec(s.id, s.prmFieldSet, _args)
434		argsPool.Put(_args)
435		close(done)
436	}()
437
438done:
439	select {
440	case <-ctx.Done():
441		return nil, ctx.Err()
442	case <-done:
443		return r, err
444	}
445}
446
447func (s *stmt) Query(args []driver.Value) (rows driver.Rows, err error) {
448	panic("deprecated")
449}
450
451// Deprecated: see NamedValueChecker.
452//func (s *stmt) ColumnConverter(idx int) driver.ValueConverter {
453//}
454
455// CheckNamedValue implements NamedValueChecker interface.
456func (s *stmt) CheckNamedValue(nv *driver.NamedValue) error {
457	if nv.Name == bulk {
458		if ptr, ok := nv.Value.(**struct{}); ok {
459			switch ptr {
460			case &noFlushTok:
461				s.bulk, s.noFlush = true, true
462				return driver.ErrRemoveArgument
463			case &flushTok:
464				return driver.ErrRemoveArgument
465			}
466		}
467	}
468	return checkNamedValue(s.prmFieldSet, nv)
469}
470
// noResult is a driver.Rows drop-in used when Query or QueryRow is called for a
// statement that does not return rows. It has no columns and no rows.
var (
	noColumns = []string{}
	noResult  = &noResultType{}
)

// Compile-time check that noResultType implements driver.Rows.
var _ driver.Rows = (*noResultType)(nil)

// noResultType is the type of noResult.
type noResultType struct{}

// Columns returns an empty, non-nil column list.
func (r *noResultType) Columns() []string { return noColumns }

// Close does nothing and always succeeds.
func (r *noResultType) Close() error { return nil }

// Next always reports io.EOF because there are no rows.
func (r *noResultType) Next(dest []driver.Value) error { return io.EOF }

// rows is an empty placeholder type.
type rows struct{}
489
490// query result
491
492//  check if queryResult implements all required interfaces
493var (
494	_ driver.Rows                           = (*queryResult)(nil)
495	_ driver.RowsColumnTypeDatabaseTypeName = (*queryResult)(nil) // go 1.8
496	_ driver.RowsColumnTypeLength           = (*queryResult)(nil) // go 1.8
497	_ driver.RowsColumnTypeNullable         = (*queryResult)(nil) // go 1.8
498	_ driver.RowsColumnTypePrecisionScale   = (*queryResult)(nil) // go 1.8
499	_ driver.RowsColumnTypeScanType         = (*queryResult)(nil) // go 1.8
500)
501
502type queryResult struct {
503	session        *p.Session
504	id             uint64
505	resultFieldSet *p.ResultFieldSet
506	fieldValues    *p.FieldValues
507	pos            int
508	attrs          p.PartAttributes
509	columns        []string
510	lastErr        error
511}
512
513func newQueryResult(session *p.Session, id uint64, resultFieldSet *p.ResultFieldSet, fieldValues *p.FieldValues, attrs p.PartAttributes) (driver.Rows, error) {
514	columns := make([]string, resultFieldSet.NumField())
515	for i := 0; i < len(columns); i++ {
516		columns[i] = resultFieldSet.Field(i).Name()
517	}
518
519	return &queryResult{
520		session:        session,
521		id:             id,
522		resultFieldSet: resultFieldSet,
523		fieldValues:    fieldValues,
524		attrs:          attrs,
525		columns:        columns,
526	}, nil
527}
528
529func (r *queryResult) Columns() []string {
530	return r.columns
531}
532
533func (r *queryResult) Close() error {
534	// if lastError is set, attrs are nil
535	if r.lastErr != nil {
536		return r.lastErr
537	}
538
539	if !r.attrs.ResultsetClosed() {
540		return r.session.CloseResultsetID(r.id)
541	}
542	return nil
543}
544
545func (r *queryResult) Next(dest []driver.Value) error {
546	if r.session.IsBad() {
547		return driver.ErrBadConn
548	}
549
550	if r.pos >= r.fieldValues.NumRow() {
551		if r.attrs.LastPacket() {
552			return io.EOF
553		}
554
555		var err error
556
557		if r.attrs, err = r.session.FetchNext(r.id, r.resultFieldSet, r.fieldValues); err != nil {
558			r.lastErr = err //fieldValues and attrs are nil
559			return err
560		}
561
562		if r.attrs.NoRows() {
563			return io.EOF
564		}
565
566		r.pos = 0
567
568	}
569
570	r.fieldValues.Row(r.pos, dest)
571	r.pos++
572
573	return nil
574}
575
576func (r *queryResult) ColumnTypeDatabaseTypeName(idx int) string {
577	return r.resultFieldSet.Field(idx).TypeCode().TypeName()
578}
579
580func (r *queryResult) ColumnTypeLength(idx int) (int64, bool) {
581	return r.resultFieldSet.Field(idx).TypeLength()
582}
583
584func (r *queryResult) ColumnTypePrecisionScale(idx int) (int64, int64, bool) {
585	return r.resultFieldSet.Field(idx).TypePrecisionScale()
586}
587
588func (r *queryResult) ColumnTypeNullable(idx int) (bool, bool) {
589	return r.resultFieldSet.Field(idx).Nullable(), true
590}
591
592var (
593	scanTypeUnknown  = reflect.TypeOf(new(interface{})).Elem()
594	scanTypeTinyint  = reflect.TypeOf(uint8(0))
595	scanTypeSmallint = reflect.TypeOf(int16(0))
596	scanTypeInteger  = reflect.TypeOf(int32(0))
597	scanTypeBigint   = reflect.TypeOf(int64(0))
598	scanTypeReal     = reflect.TypeOf(float32(0.0))
599	scanTypeDouble   = reflect.TypeOf(float64(0.0))
600	scanTypeTime     = reflect.TypeOf(time.Time{})
601	scanTypeString   = reflect.TypeOf(string(""))
602	scanTypeBytes    = reflect.TypeOf([]byte{})
603	scanTypeDecimal  = reflect.TypeOf(Decimal{})
604	scanTypeLob      = reflect.TypeOf(Lob{})
605)
606
607func (r *queryResult) ColumnTypeScanType(idx int) reflect.Type {
608	switch r.resultFieldSet.Field(idx).TypeCode().DataType() {
609	default:
610		return scanTypeUnknown
611	case p.DtTinyint:
612		return scanTypeTinyint
613	case p.DtSmallint:
614		return scanTypeSmallint
615	case p.DtInteger:
616		return scanTypeInteger
617	case p.DtBigint:
618		return scanTypeBigint
619	case p.DtReal:
620		return scanTypeReal
621	case p.DtDouble:
622		return scanTypeDouble
623	case p.DtTime:
624		return scanTypeTime
625	case p.DtDecimal:
626		return scanTypeDecimal
627	case p.DtString:
628		return scanTypeString
629	case p.DtBytes:
630		return scanTypeBytes
631	case p.DtLob:
632		return scanTypeLob
633	}
634}
635