1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package sql
6
7import (
8	"context"
9	"database/sql/driver"
10	"errors"
11)
12
// ctxDriverPrepare prepares query on ci.
//
// If ci implements driver.ConnPrepareContext, ctx is handed to the driver
// directly. Otherwise the legacy Prepare method is used, and ctx is checked
// only after Prepare returns successfully. If ctx is done by then, the new
// statement is closed and ctx.Err() is returned in its place.
func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver.Stmt, error) {
	if preparer, ok := ci.(driver.ConnPrepareContext); ok {
		return preparer.PrepareContext(ctx, query)
	}

	stmt, err := ci.Prepare(query)
	if err != nil {
		return stmt, err
	}
	if ctxErr := ctx.Err(); ctxErr != nil {
		// Best effort: the caller is told about the context, not about Close.
		stmt.Close()
		return nil, ctxErr
	}
	return stmt, nil
}
28
29func ctxDriverExec(ctx context.Context, execerCtx driver.ExecerContext, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) {
30	if execerCtx != nil {
31		return execerCtx.ExecContext(ctx, query, nvdargs)
32	}
33	dargs, err := namedValueToValue(nvdargs)
34	if err != nil {
35		return nil, err
36	}
37
38	select {
39	default:
40	case <-ctx.Done():
41		return nil, ctx.Err()
42	}
43	return execer.Exec(query, dargs)
44}
45
46func ctxDriverQuery(ctx context.Context, queryerCtx driver.QueryerContext, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
47	if queryerCtx != nil {
48		return queryerCtx.QueryContext(ctx, query, nvdargs)
49	}
50	dargs, err := namedValueToValue(nvdargs)
51	if err != nil {
52		return nil, err
53	}
54
55	select {
56	default:
57	case <-ctx.Done():
58		return nil, ctx.Err()
59	}
60	return queryer.Query(query, dargs)
61}
62
63func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) {
64	if siCtx, is := si.(driver.StmtExecContext); is {
65		return siCtx.ExecContext(ctx, nvdargs)
66	}
67	dargs, err := namedValueToValue(nvdargs)
68	if err != nil {
69		return nil, err
70	}
71
72	select {
73	default:
74	case <-ctx.Done():
75		return nil, ctx.Err()
76	}
77	return si.Exec(dargs)
78}
79
80func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
81	if siCtx, is := si.(driver.StmtQueryContext); is {
82		return siCtx.QueryContext(ctx, nvdargs)
83	}
84	dargs, err := namedValueToValue(nvdargs)
85	if err != nil {
86		return nil, err
87	}
88
89	select {
90	default:
91	case <-ctx.Done():
92		return nil, ctx.Err()
93	}
94	return si.Query(dargs)
95}
96
// errLevelNotSupported is the error for a transaction isolation level that
// cannot be honored.
// NOTE(review): nothing in this part of the file references it; confirm it is
// still used elsewhere in the package.
var (
	errLevelNotSupported = errors.New("sql: selected isolation level is not supported")
)
98
99func ctxDriverBegin(ctx context.Context, opts *TxOptions, ci driver.Conn) (driver.Tx, error) {
100	if ciCtx, is := ci.(driver.ConnBeginTx); is {
101		dopts := driver.TxOptions{}
102		if opts != nil {
103			dopts.Isolation = driver.IsolationLevel(opts.Isolation)
104			dopts.ReadOnly = opts.ReadOnly
105		}
106		return ciCtx.BeginTx(ctx, dopts)
107	}
108
109	if opts != nil {
110		// Check the transaction level. If the transaction level is non-default
111		// then return an error here as the BeginTx driver value is not supported.
112		if opts.Isolation != LevelDefault {
113			return nil, errors.New("sql: driver does not support non-default isolation level")
114		}
115
116		// If a read-only transaction is requested return an error as the
117		// BeginTx driver value is not supported.
118		if opts.ReadOnly {
119			return nil, errors.New("sql: driver does not support read-only transactions")
120		}
121	}
122
123	if ctx.Done() == nil {
124		return ci.Begin()
125	}
126
127	txi, err := ci.Begin()
128	if err == nil {
129		select {
130		default:
131		case <-ctx.Done():
132			txi.Rollback()
133			return nil, ctx.Err()
134		}
135	}
136	return txi, err
137}
138
// namedValueToValue flattens named arguments into positional driver.Values,
// for drivers that only accept positional parameters.
//
// Any argument with a non-empty Name is rejected, because such drivers cannot
// bind by name. On success the result has exactly one entry per input, in
// order, and is non-nil even when named is empty.
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
	values := make([]driver.Value, 0, len(named))
	for i := range named {
		if named[i].Name != "" {
			return nil, errors.New("sql: driver does not support the use of Named Parameters")
		}
		values = append(values, named[i].Value)
	}
	return values, nil
}
149