1package mssql
2
3import (
4	"context"
5	"database/sql"
6	"database/sql/driver"
7	"encoding/binary"
8	"errors"
9	"fmt"
10	"io"
11	"math"
12	"net"
13	"reflect"
14	"strings"
15	"time"
16	"unicode"
17
18	"github.com/denisenkom/go-mssqldb/internal/querytext"
19)
20
// ReturnStatus may be used to return the return value from a proc.
//
//   var rs mssql.ReturnStatus
//   _, err := db.Exec("theproc", &rs)
//   log.Printf("return status = %d", rs)
//
// When a ReturnStatus token is seen while a response is being processed,
// its value is stored through the pointer recorded on the connection, if
// one has been recorded.
type ReturnStatus int32
27
// driverInstance is the driver registered as "mssql"; it has query text
// processing enabled.
var driverInstance = &Driver{processQueryText: true}

// driverInstanceNoProcess is the driver registered as "sqlserver" and used by
// NewConnector; it has query text processing disabled.
var driverInstanceNoProcess = &Driver{processQueryText: false}
30
31func init() {
32	sql.Register("mssql", driverInstance)
33	sql.Register("sqlserver", driverInstanceNoProcess)
34	createDialer = func(p *connectParams) Dialer {
35		return netDialer{&net.Dialer{KeepAlive: p.keepAlive}}
36	}
37}
38
// createDialer builds the Dialer used when a Connector does not provide its
// own. It is assigned in init.
var createDialer func(p *connectParams) Dialer
40
// netDialer adapts a *net.Dialer to the Dialer interface.
type netDialer struct {
	nd *net.Dialer // underlying standard-library dialer
}
44
45func (d netDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
46	return d.nd.DialContext(ctx, network, addr)
47}
48
// Driver is the database/sql driver implementation for this package.
type Driver struct {
	// log receives driver log output; it is replaced by SetLogger.
	log optionalLogger

	// processQueryText is copied onto every Conn the driver creates. When
	// true, prepared query text is passed through querytext.ParseParams.
	processQueryText bool
}
54
55// OpenConnector opens a new connector. Useful to dial with a context.
56func (d *Driver) OpenConnector(dsn string) (*Connector, error) {
57	params, err := parseConnectParams(dsn)
58	if err != nil {
59		return nil, err
60	}
61	return &Connector{
62		params: params,
63		driver: d,
64	}, nil
65}
66
67func (d *Driver) Open(dsn string) (driver.Conn, error) {
68	return d.open(context.Background(), dsn)
69}
70
71func SetLogger(logger Logger) {
72	driverInstance.SetLogger(logger)
73	driverInstanceNoProcess.SetLogger(logger)
74}
75
76func (d *Driver) SetLogger(logger Logger) {
77	d.log = optionalLogger{logger}
78}
79
80// NewConnector creates a new connector from a DSN.
81// The returned connector may be used with sql.OpenDB.
82func NewConnector(dsn string) (*Connector, error) {
83	params, err := parseConnectParams(dsn)
84	if err != nil {
85		return nil, err
86	}
87	c := &Connector{
88		params: params,
89		driver: driverInstanceNoProcess,
90	}
91	return c, nil
92}
93
// Connector holds the parsed DSN and is ready to make a new connection
// at any time.
//
// In the future, settings that cannot be passed through a string DSN
// may be set directly on the connector.
type Connector struct {
	// params is the result of parsing the DSN the connector was created from.
	params connectParams
	// driver is the Driver instance this connector belongs to.
	driver *Driver

	// SessionInitSQL is executed after marking a given session to be reset.
	// When not present, the next query will still reset the session to the
	// database defaults.
	//
	// When present the connection will immediately mark the session to
	// be reset, then execute the SessionInitSQL text to set up the session
	// that may be different from the base database defaults.
	//
	// For example, the application relies on the following defaults
	// but is not allowed to set them at the database system level.
	//
	//    SET XACT_ABORT ON;
	//    SET TEXTSIZE -1;
	//    SET ANSI_NULLS ON;
	//    SET LOCK_TIMEOUT 10000;
	//
	// SessionInitSQL should not attempt to manually call sp_reset_connection.
	// This will happen at the TDS layer.
	//
	// SessionInitSQL is optional. The session will be reset even if
	// SessionInitSQL is empty.
	SessionInitSQL string

	// Dialer sets a custom dialer for all network operations.
	// If Dialer is not set, the dialer produced by createDialer is used.
	Dialer Dialer
}
130
// Dialer is the interface used to open network connections. It matches the
// DialContext method of net.Dialer, so a custom implementation can be
// supplied through Connector.Dialer.
type Dialer interface {
	DialContext(ctx context.Context, network string, addr string) (net.Conn, error)
}
134
135func (c *Connector) getDialer(p *connectParams) Dialer {
136	if c != nil && c.Dialer != nil {
137		return c.Dialer
138	}
139	return createDialer(p)
140}
141
// Conn is a single connection to the server. It also serves as the
// driver.Tx returned by Begin/BeginTx (see processBeginResponse).
type Conn struct {
	connector      *Connector      // connector that created this connection; nil when opened via Driver.Open
	sess           *tdsSession     // underlying TDS session
	transactionCtx context.Context // context of the most recent begin; used when processing Commit/Rollback responses
	resetSession   bool            // when true, the next request sent carries the session-reset flag

	processQueryText bool // copied from the Driver; enables query text parsing in prepareContext
	connectionGood   bool // false once the connection is considered unusable

	outs         map[string]interface{} // passed to processResponse and cleared after each request
	returnStatus *ReturnStatus          // destination for a received ReturnStatus token, if non-nil
}
154
155func (c *Conn) setReturnStatus(s ReturnStatus) {
156	if c.returnStatus == nil {
157		return
158	}
159	*c.returnStatus = s
160}
161
162func (c *Conn) checkBadConn(err error) error {
163	// this is a hack to address Issue #275
164	// we set connectionGood flag to false if
165	// error indicates that connection is not usable
166	// but we return actual error instead of ErrBadConn
167	// this will cause connection to stay in a pool
168	// but next request to this connection will return ErrBadConn
169
170	// it might be possible to revise this hack after
171	// https://github.com/golang/go/issues/20807
172	// is implemented
173	switch err {
174	case nil:
175		return nil
176	case io.EOF:
177		c.connectionGood = false
178		return driver.ErrBadConn
179	case driver.ErrBadConn:
180		// It is an internal programming error if driver.ErrBadConn
181		// is ever passed to this function. driver.ErrBadConn should
182		// only ever be returned in response to a *mssql.Conn.connectionGood == false
183		// check in the external facing API.
184		panic("driver.ErrBadConn in checkBadConn. This should not happen.")
185	}
186
187	switch err.(type) {
188	case net.Error:
189		c.connectionGood = false
190		return err
191	case StreamError:
192		c.connectionGood = false
193		return err
194	default:
195		return err
196	}
197}
198
// clearOuts drops the connection's reference to the current output
// parameter map. Callers invoke it right after handing the map to
// processResponse.
func (c *Conn) clearOuts() {
	c.outs = nil
}
202
203func (c *Conn) simpleProcessResp(ctx context.Context) error {
204	tokchan := make(chan tokenStruct, 5)
205	go processResponse(ctx, c.sess, tokchan, c.outs)
206	c.clearOuts()
207	for tok := range tokchan {
208		switch token := tok.(type) {
209		case doneStruct:
210			if token.isError() {
211				return c.checkBadConn(token.getError())
212			}
213		case error:
214			return c.checkBadConn(token)
215		}
216	}
217	return nil
218}
219
220func (c *Conn) Commit() error {
221	if !c.connectionGood {
222		return driver.ErrBadConn
223	}
224	if err := c.sendCommitRequest(); err != nil {
225		return c.checkBadConn(err)
226	}
227	return c.simpleProcessResp(c.transactionCtx)
228}
229
230func (c *Conn) sendCommitRequest() error {
231	headers := []headerStruct{
232		{hdrtype: dataStmHdrTransDescr,
233			data: transDescrHdr{c.sess.tranid, 1}.pack()},
234	}
235	reset := c.resetSession
236	c.resetSession = false
237	if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
238		if c.sess.logFlags&logErrors != 0 {
239			c.sess.log.Printf("Failed to send CommitXact with %v", err)
240		}
241		c.connectionGood = false
242		return fmt.Errorf("Faild to send CommitXact: %v", err)
243	}
244	return nil
245}
246
247func (c *Conn) Rollback() error {
248	if !c.connectionGood {
249		return driver.ErrBadConn
250	}
251	if err := c.sendRollbackRequest(); err != nil {
252		return c.checkBadConn(err)
253	}
254	return c.simpleProcessResp(c.transactionCtx)
255}
256
257func (c *Conn) sendRollbackRequest() error {
258	headers := []headerStruct{
259		{hdrtype: dataStmHdrTransDescr,
260			data: transDescrHdr{c.sess.tranid, 1}.pack()},
261	}
262	reset := c.resetSession
263	c.resetSession = false
264	if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
265		if c.sess.logFlags&logErrors != 0 {
266			c.sess.log.Printf("Failed to send RollbackXact with %v", err)
267		}
268		c.connectionGood = false
269		return fmt.Errorf("Failed to send RollbackXact: %v", err)
270	}
271	return nil
272}
273
274func (c *Conn) Begin() (driver.Tx, error) {
275	return c.begin(context.Background(), isolationUseCurrent)
276}
277
278func (c *Conn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx, err error) {
279	if !c.connectionGood {
280		return nil, driver.ErrBadConn
281	}
282	err = c.sendBeginRequest(ctx, tdsIsolation)
283	if err != nil {
284		return nil, c.checkBadConn(err)
285	}
286	tx, err = c.processBeginResponse(ctx)
287	if err != nil {
288		return nil, c.checkBadConn(err)
289	}
290	return
291}
292
293func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error {
294	c.transactionCtx = ctx
295	headers := []headerStruct{
296		{hdrtype: dataStmHdrTransDescr,
297			data: transDescrHdr{0, 1}.pack()},
298	}
299	reset := c.resetSession
300	c.resetSession = false
301	if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, "", reset); err != nil {
302		if c.sess.logFlags&logErrors != 0 {
303			c.sess.log.Printf("Failed to send BeginXact with %v", err)
304		}
305		c.connectionGood = false
306		return fmt.Errorf("Failed to send BeginXact: %v", err)
307	}
308	return nil
309}
310
311func (c *Conn) processBeginResponse(ctx context.Context) (driver.Tx, error) {
312	if err := c.simpleProcessResp(ctx); err != nil {
313		return nil, err
314	}
315	// successful BEGINXACT request will return sess.tranid
316	// for started transaction
317	return c, nil
318}
319
320func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) {
321	params, err := parseConnectParams(dsn)
322	if err != nil {
323		return nil, err
324	}
325	return d.connect(ctx, nil, params)
326}
327
328// connect to the server, using the provided context for dialing only.
329func (d *Driver) connect(ctx context.Context, c *Connector, params connectParams) (*Conn, error) {
330	sess, err := connect(ctx, c, d.log, params)
331	if err != nil {
332		// main server failed, try fail-over partner
333		if params.failOverPartner == "" {
334			return nil, err
335		}
336
337		params.host = params.failOverPartner
338		if params.failOverPort != 0 {
339			params.port = params.failOverPort
340		}
341
342		sess, err = connect(ctx, c, d.log, params)
343		if err != nil {
344			// fail-over partner also failed, now fail
345			return nil, err
346		}
347	}
348
349	conn := &Conn{
350		connector:        c,
351		sess:             sess,
352		transactionCtx:   context.Background(),
353		processQueryText: d.processQueryText,
354		connectionGood:   true,
355	}
356
357	return conn, nil
358}
359
360func (c *Conn) Close() error {
361	return c.sess.buf.transport.Close()
362}
363
// Stmt is a statement bound to a connection.
//
// Field order matters: the struct is built with positional composite
// literals elsewhere in this file.
type Stmt struct {
	c          *Conn          // connection the statement runs on
	query      string         // query text (after parameter processing, when enabled)
	paramCount int            // value reported by NumInput; -1 when the text was not parsed
	notifSub   *queryNotifSub // optional query notification subscription; nil when unset
}
370
// queryNotifSub holds the values sent in the query notification header.
// It is built with a positional composite literal in SetQueryNotification,
// so field order matters.
type queryNotifSub struct {
	msgText string // notification message text (the id argument of SetQueryNotification)
	options string // notification options string
	timeout uint32 // timeout in milliseconds
}
376
377func (c *Conn) Prepare(query string) (driver.Stmt, error) {
378	if !c.connectionGood {
379		return nil, driver.ErrBadConn
380	}
381	if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
382		return c.prepareCopyIn(context.Background(), query)
383	}
384	return c.prepareContext(context.Background(), query)
385}
386
387func (c *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error) {
388	paramCount := -1
389	if c.processQueryText {
390		query, paramCount = querytext.ParseParams(query)
391	}
392	return &Stmt{c, query, paramCount, nil}, nil
393}
394
// Close is a no-op and always returns nil.
func (s *Stmt) Close() error {
	return nil
}
398
399func (s *Stmt) SetQueryNotification(id, options string, timeout time.Duration) {
400	// 2.2.5.3.1 Query Notifications Header
401	// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/e168d373-a7b7-41aa-b6ca-25985466a7e0
402	// Timeout in milliseconds in TDS protocol.
403	to := uint32(timeout / time.Millisecond)
404	if to < 1 {
405		to = 1
406	}
407	s.notifSub = &queryNotifSub{id, options, to}
408}
409
// NumInput returns the statement's stored parameter count. The value is -1
// when the query text was not parsed for parameters (see prepareContext).
func (s *Stmt) NumInput() int {
	return s.paramCount
}
413
// sendQuery sends the statement to the server.
//
// With no arguments the query text is sent as a SQL batch. With arguments it
// is sent as an RPC: directly as a stored procedure call when isProc reports
// that the text looks like a procedure name, otherwise through
// sp_ExecuteSql with the query text and parameter declarations as the first
// two parameters.
//
// A pending session-reset flag on the connection is consumed and sent with
// the request. If sending fails, the connection is marked bad and a
// descriptive error is returned; errors from building the parameters are
// returned as-is without marking the connection bad.
func (s *Stmt) sendQuery(args []namedValue) (err error) {
	// Every request carries the current transaction descriptor.
	headers := []headerStruct{
		{hdrtype: dataStmHdrTransDescr,
			data: transDescrHdr{s.c.sess.tranid, 1}.pack()},
	}

	// Add the query notification header when a subscription was requested.
	if s.notifSub != nil {
		headers = append(headers,
			headerStruct{
				hdrtype: dataStmHdrQueryNotif,
				data: queryNotifHdr{
					s.notifSub.msgText,
					s.notifSub.options,
					s.notifSub.timeout,
				}.pack(),
			})
	}

	conn := s.c

	// no need to check number of parameters here, it is checked by database/sql
	if conn.sess.logFlags&logSQL != 0 {
		conn.sess.log.Println(s.query)
	}
	// Log each argument; unnamed ones are shown by 1-based position.
	if conn.sess.logFlags&logParams != 0 && len(args) > 0 {
		for i := 0; i < len(args); i++ {
			if len(args[i].Name) > 0 {
				s.c.sess.log.Printf("\t@%s\t%v\n", args[i].Name, args[i].Value)
			} else {
				s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i].Value)
			}
		}
	}

	// The reset flag applies to exactly one request, so clear it as it is read.
	reset := conn.resetSession
	conn.resetSession = false
	if len(args) == 0 {
		if err = sendSqlBatch72(conn.sess.buf, s.query, headers, reset); err != nil {
			if conn.sess.logFlags&logErrors != 0 {
				conn.sess.log.Printf("Failed to send SqlBatch with %v", err)
			}
			conn.connectionGood = false
			return fmt.Errorf("failed to send SQL Batch: %v", err)
		}
	} else {
		proc := sp_ExecuteSql
		var params []param
		if isProc(s.query) {
			// The text is a procedure name: call it directly with the arguments.
			proc.name = s.query
			params, _, err = s.makeRPCParams(args, true)
			if err != nil {
				return
			}
		} else {
			// makeRPCParams leaves the first two slots free when isProc is
			// false; they hold the statement text and the parameter
			// declaration list.
			var decls []string
			params, decls, err = s.makeRPCParams(args, false)
			if err != nil {
				return
			}
			params[0] = makeStrParam(s.query)
			params[1] = makeStrParam(strings.Join(decls, ","))
		}
		if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil {
			if conn.sess.logFlags&logErrors != 0 {
				conn.sess.log.Printf("Failed to send Rpc with %v", err)
			}
			conn.connectionGood = false
			return fmt.Errorf("Failed to send RPC: %v", err)
		}
	}
	return
}
486
// isProc takes the query text in s and determines if it is a stored proc name
// or SQL text.
//
// The text is treated as a procedure name when it is non-empty, contains no
// newline, carriage return, single quote or semicolon anywhere, and contains
// no whitespace outside of [bracketed] identifiers. Inside brackets, "]]"
// is treated as an escaped closing bracket.
func isProc(s string) bool {
	if s == "" {
		return false
	}

	type scanState int
	const (
		stOutside scanState = iota // between identifier parts
		stText                     // inside an unquoted identifier
		stEscaped                  // inside a [bracketed] identifier
	)

	state := stOutside
	var prev rune // rune seen on the previous iteration (0 before the first)
	for _, cur := range s {
		last := prev
		prev = cur

		// No newlines or string sequences, regardless of state.
		if cur == '\n' || cur == '\r' || cur == '\'' || cur == ';' {
			return false
		}

		if state == stEscaped {
			// Anything else, including spaces, is allowed inside brackets.
			if cur == ']' {
				state = stOutside
			}
			continue
		}

		// Whitespace is rejected both between parts and inside plain text.
		if unicode.IsSpace(cur) {
			return false
		}

		if state == stText {
			if cur == '.' {
				state = stOutside
			}
			continue
		}

		// state == stOutside
		switch {
		case cur == '[', cur == ']' && last == ']':
			// Either a new bracketed part starts, or the previous ']' was
			// the first half of an escaped "]]".
			state = stEscaped
		case unicode.IsLetter(cur):
			state = stText
		}
	}
	return true
}
540
541func (s *Stmt) makeRPCParams(args []namedValue, isProc bool) ([]param, []string, error) {
542	var err error
543	var offset int
544	if !isProc {
545		offset = 2
546	}
547	params := make([]param, len(args)+offset)
548	decls := make([]string, len(args))
549	for i, val := range args {
550		params[i+offset], err = s.makeParam(val.Value)
551		if err != nil {
552			return nil, nil, err
553		}
554		var name string
555		if len(val.Name) > 0 {
556			name = "@" + val.Name
557		} else if !isProc {
558			name = fmt.Sprintf("@p%d", val.Ordinal)
559		}
560		params[i+offset].Name = name
561		decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+offset].ti))
562	}
563	return params, decls, nil
564}
565
// namedValue is the driver's internal argument representation. Values of
// driver.NamedValue are converted to it directly (see QueryContext and
// ExecContext), so the field layout must stay compatible with that type.
type namedValue struct {
	Name    string       // parameter name without the leading "@"; empty for positional arguments
	Ordinal int          // 1-based position of the argument
	Value   driver.Value // argument value
}
571
572func convertOldArgs(args []driver.Value) []namedValue {
573	list := make([]namedValue, len(args))
574	for i, v := range args {
575		list[i] = namedValue{
576			Ordinal: i + 1,
577			Value:   v,
578		}
579	}
580	return list
581}
582
583func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
584	return s.queryContext(context.Background(), convertOldArgs(args))
585}
586
587func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver.Rows, err error) {
588	if !s.c.connectionGood {
589		return nil, driver.ErrBadConn
590	}
591	if err = s.sendQuery(args); err != nil {
592		return nil, s.c.checkBadConn(err)
593	}
594	return s.processQueryResponse(ctx)
595}
596
597func (s *Stmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
598	tokchan := make(chan tokenStruct, 5)
599	ctx, cancel := context.WithCancel(ctx)
600	go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
601	s.c.clearOuts()
602	// process metadata
603	var cols []columnStruct
604loop:
605	for tok := range tokchan {
606		switch token := tok.(type) {
607		// By ignoring DONE token we effectively
608		// skip empty result-sets.
609		// This improves results in queries like that:
610		// set nocount on; select 1
611		// see TestIgnoreEmptyResults test
612		//case doneStruct:
613		//break loop
614		case []columnStruct:
615			cols = token
616			break loop
617		case doneStruct:
618			if token.isError() {
619				cancel()
620				return nil, s.c.checkBadConn(token.getError())
621			}
622		case ReturnStatus:
623			s.c.setReturnStatus(token)
624		case error:
625			cancel()
626			return nil, s.c.checkBadConn(token)
627		}
628	}
629	res = &Rows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel}
630	return
631}
632
633func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
634	return s.exec(context.Background(), convertOldArgs(args))
635}
636
637func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result, err error) {
638	if !s.c.connectionGood {
639		return nil, driver.ErrBadConn
640	}
641	if err = s.sendQuery(args); err != nil {
642		return nil, s.c.checkBadConn(err)
643	}
644	if res, err = s.processExec(ctx); err != nil {
645		return nil, s.c.checkBadConn(err)
646	}
647	return
648}
649
650func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) {
651	tokchan := make(chan tokenStruct, 5)
652	go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
653	s.c.clearOuts()
654	var rowCount int64
655	for token := range tokchan {
656		switch token := token.(type) {
657		case doneInProcStruct:
658			if token.Status&doneCount != 0 {
659				rowCount += int64(token.RowCount)
660			}
661		case doneStruct:
662			if token.Status&doneCount != 0 {
663				rowCount += int64(token.RowCount)
664			}
665			if token.isError() {
666				return nil, token.getError()
667			}
668		case ReturnStatus:
669			s.c.setReturnStatus(token)
670		case error:
671			return nil, token
672		}
673	}
674	return &Result{s.c, rowCount}, nil
675}
676
// Rows iterates over the result sets of a query by reading tokens from the
// response channel.
type Rows struct {
	stmt    *Stmt            // statement that produced these rows
	cols    []columnStruct   // column metadata of the current result set
	tokchan chan tokenStruct // response tokens; set to nil by Close

	// nextCols holds the metadata of the following result set once Next has
	// run into it; nil when no further result set is known.
	nextCols []columnStruct

	// cancel cancels the context response processing was started with.
	cancel func()
}
686
687func (rc *Rows) Close() error {
688	rc.cancel()
689	for _ = range rc.tokchan {
690	}
691	rc.tokchan = nil
692	return nil
693}
694
695func (rc *Rows) Columns() (res []string) {
696	res = make([]string, len(rc.cols))
697	for i, col := range rc.cols {
698		res[i] = col.ColName
699	}
700	return
701}
702
703func (rc *Rows) Next(dest []driver.Value) error {
704	if !rc.stmt.c.connectionGood {
705		return driver.ErrBadConn
706	}
707	if rc.nextCols != nil {
708		return io.EOF
709	}
710	for tok := range rc.tokchan {
711		switch tokdata := tok.(type) {
712		case []columnStruct:
713			rc.nextCols = tokdata
714			return io.EOF
715		case []interface{}:
716			for i := range dest {
717				dest[i] = tokdata[i]
718			}
719			return nil
720		case doneStruct:
721			if tokdata.isError() {
722				return rc.stmt.c.checkBadConn(tokdata.getError())
723			}
724		case ReturnStatus:
725			rc.stmt.c.setReturnStatus(tokdata)
726		case error:
727			return rc.stmt.c.checkBadConn(tokdata)
728		}
729	}
730	return io.EOF
731}
732
// HasNextResultSet reports whether metadata for a following result set has
// already been read by Next.
func (rc *Rows) HasNextResultSet() bool {
	return rc.nextCols != nil
}
736
737func (rc *Rows) NextResultSet() error {
738	rc.cols = rc.nextCols
739	rc.nextCols = nil
740	if rc.cols == nil {
741		return io.EOF
742	}
743	return nil
744}
745
746// It should return
747// the value type that can be used to scan types into. For example, the database
748// column type "bigint" this should return "reflect.TypeOf(int64(0))".
749func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
750	return makeGoLangScanType(r.cols[index].ti)
751}
752
753// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
754// database system type name without the length. Type names should be uppercase.
755// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
756// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
757// "TIMESTAMP".
758func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
759	return makeGoLangTypeName(r.cols[index].ti)
760}
761
762// RowsColumnTypeLength may be implemented by Rows. It should return the length
763// of the column type if the column is a variable length type. If the column is
764// not a variable length type ok should return false.
765// If length is not limited other than system limits, it should return math.MaxInt64.
766// The following are examples of returned values for various types:
767//   TEXT          (math.MaxInt64, true)
768//   varchar(10)   (10, true)
769//   nvarchar(10)  (10, true)
770//   decimal       (0, false)
771//   int           (0, false)
772//   bytea(30)     (30, true)
773func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
774	return makeGoLangTypeLength(r.cols[index].ti)
775}
776
777// It should return
778// the precision and scale for decimal types. If not applicable, ok should be false.
779// The following are examples of returned values for various types:
780//   decimal(38, 4)    (38, 4, true)
781//   int               (0, 0, false)
782//   decimal           (math.MaxInt64, math.MaxInt64, true)
783func (r *Rows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
784	return makeGoLangTypePrecisionScale(r.cols[index].ti)
785}
786
787// The nullable value should
788// be true if it is known the column may be null, or false if the column is known
789// to be not nullable.
790// If the column nullability is unknown, ok should be false.
791func (r *Rows) ColumnTypeNullable(index int) (nullable, ok bool) {
792	nullable = r.cols[index].Flags&colFlagNullable != 0
793	ok = true
794	return
795}
796
797func makeStrParam(val string) (res param) {
798	res.ti.TypeId = typeNVarChar
799	res.buffer = str2ucs2(val)
800	res.ti.Size = len(res.buffer)
801	return
802}
803
// makeParam converts a driver value into an RPC parameter (type info plus
// encoded buffer).
//
// A nil value becomes a typeNull parameter. Values of the sql.Null* types
// are encoded as a typed parameter with an empty or nil buffer (per the
// comments below, only null values are expected to reach those cases).
// Types not listed here are delegated to makeParamExtra.
func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
	if val == nil {
		res.ti.TypeId = typeNull
		res.buffer = nil
		res.ti.Size = 0
		return
	}
	switch val := val.(type) {
	case int64:
		// 8-byte little-endian integer.
		res.ti.TypeId = typeIntN
		res.buffer = make([]byte, 8)
		res.ti.Size = 8
		binary.LittleEndian.PutUint64(res.buffer, uint64(val))
	case sql.NullInt64:
		// only null values should be getting here
		res.ti.TypeId = typeIntN
		res.ti.Size = 8
		res.buffer = []byte{}

	case float64:
		// 8-byte little-endian IEEE 754 bit pattern.
		res.ti.TypeId = typeFltN
		res.ti.Size = 8
		res.buffer = make([]byte, 8)
		binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(val))
	case sql.NullFloat64:
		// only null values should be getting here
		res.ti.TypeId = typeFltN
		res.ti.Size = 8
		res.buffer = []byte{}

	case []byte:
		// The caller's slice is used directly as the buffer (no copy).
		res.ti.TypeId = typeBigVarBin
		res.ti.Size = len(val)
		res.buffer = val
	case string:
		res = makeStrParam(val)
	case sql.NullString:
		// only null values should be getting here
		res.ti.TypeId = typeNVarChar
		res.buffer = nil
		res.ti.Size = 8000
	case bool:
		// Single byte: 1 for true, 0 for false.
		res.ti.TypeId = typeBitN
		res.ti.Size = 1
		res.buffer = make([]byte, 1)
		if val {
			res.buffer[0] = 1
		}
	case sql.NullBool:
		// only null values should be getting here
		res.ti.TypeId = typeBitN
		res.ti.Size = 1
		res.buffer = []byte{}

	case time.Time:
		// The encoding depends on the TDS version reported at login:
		// datetimeoffset with scale 7 from verTDS73 on, datetime otherwise.
		if s.c.sess.loginAck.TDSVersion >= verTDS73 {
			res.ti.TypeId = typeDateTimeOffsetN
			res.ti.Scale = 7
			res.buffer = encodeDateTimeOffset(val, int(res.ti.Scale))
			res.ti.Size = len(res.buffer)
		} else {
			res.ti.TypeId = typeDateTimeN
			res.buffer = encodeDateTime(val)
			res.ti.Size = len(res.buffer)
		}
	default:
		return s.makeParamExtra(val)
	}
	return
}
874
// Result is the driver result of an executed statement. It is built with a
// positional composite literal in processExec, so field order matters.
type Result struct {
	c            *Conn // connection the statement ran on
	rowsAffected int64 // total of the row counts reported in the response
}
879
// RowsAffected returns the stored affected-row total. The error is always
// nil.
func (r *Result) RowsAffected() (int64, error) {
	return r.rowsAffected, nil
}
883
884var _ driver.Pinger = &Conn{}
885
886// Ping is used to check if the remote server is available and satisfies the Pinger interface.
887func (c *Conn) Ping(ctx context.Context) error {
888	if !c.connectionGood {
889		return driver.ErrBadConn
890	}
891	stmt := &Stmt{c, `select 1;`, 0, nil}
892	_, err := stmt.ExecContext(ctx, nil)
893	return err
894}
895
896var _ driver.ConnBeginTx = &Conn{}
897
898// BeginTx satisfies ConnBeginTx.
899func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
900	if !c.connectionGood {
901		return nil, driver.ErrBadConn
902	}
903	if opts.ReadOnly {
904		return nil, errors.New("Read-only transactions are not supported")
905	}
906
907	var tdsIsolation isoLevel
908	switch sql.IsolationLevel(opts.Isolation) {
909	case sql.LevelDefault:
910		tdsIsolation = isolationUseCurrent
911	case sql.LevelReadUncommitted:
912		tdsIsolation = isolationReadUncommited
913	case sql.LevelReadCommitted:
914		tdsIsolation = isolationReadCommited
915	case sql.LevelWriteCommitted:
916		return nil, errors.New("LevelWriteCommitted isolation level is not supported")
917	case sql.LevelRepeatableRead:
918		tdsIsolation = isolationRepeatableRead
919	case sql.LevelSnapshot:
920		tdsIsolation = isolationSnapshot
921	case sql.LevelSerializable:
922		tdsIsolation = isolationSerializable
923	case sql.LevelLinearizable:
924		return nil, errors.New("LevelLinearizable isolation level is not supported")
925	default:
926		return nil, errors.New("Isolation level is not supported or unknown")
927	}
928	return c.begin(ctx, tdsIsolation)
929}
930
931func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
932	if !c.connectionGood {
933		return nil, driver.ErrBadConn
934	}
935	if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
936		return c.prepareCopyIn(ctx, query)
937	}
938
939	return c.prepareContext(ctx, query)
940}
941
942func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
943	if !s.c.connectionGood {
944		return nil, driver.ErrBadConn
945	}
946	list := make([]namedValue, len(args))
947	for i, nv := range args {
948		list[i] = namedValue(nv)
949	}
950	return s.queryContext(ctx, list)
951}
952
953func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
954	if !s.c.connectionGood {
955		return nil, driver.ErrBadConn
956	}
957	list := make([]namedValue, len(args))
958	for i, nv := range args {
959		list[i] = namedValue(nv)
960	}
961	return s.exec(ctx, list)
962}
963