1// Copyright (C) 2019 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
2//
3// Use of this source code is governed by an MIT-style
4// license that can be found in the LICENSE file.
5
6// +build go1.13,cgo
7
8package sqlite3
9
10import (
11	"context"
12	"database/sql"
13	"database/sql/driver"
14	"errors"
15	"os"
16	"testing"
17)
18
19func TestBeginTxCancel(t *testing.T) {
20	srcTempFilename := TempFilename(t)
21	defer os.Remove(srcTempFilename)
22
23	db, err := sql.Open("sqlite3", srcTempFilename)
24	if err != nil {
25		t.Fatal(err)
26	}
27
28	db.SetMaxOpenConns(10)
29	db.SetMaxIdleConns(5)
30
31	defer db.Close()
32	initDatabase(t, db, 100)
33
34	// create several go-routines to expose racy issue
35	for i := 0; i < 1000; i++ {
36		func() {
37			ctx, cancel := context.WithCancel(context.Background())
38			conn, err := db.Conn(ctx)
39			if err != nil {
40				t.Fatal(err)
41			}
42			defer func() {
43				if err := conn.Close(); err != nil {
44					t.Error(err)
45				}
46			}()
47
48			err = conn.Raw(func(driverConn interface{}) error {
49				d, ok := driverConn.(driver.ConnBeginTx)
50				if !ok {
51					t.Fatal("unexpected: wrong type")
52				}
53				// checks that conn.Raw can be used to get *SQLiteConn
54				if _, ok = driverConn.(*SQLiteConn); !ok {
55					t.Fatalf("conn.Raw() driverConn type=%T, expected *SQLiteConn", driverConn)
56				}
57
58				go cancel() // make it cancel concurrently with exec("BEGIN");
59				tx, err := d.BeginTx(ctx, driver.TxOptions{})
60				switch err {
61				case nil:
62					switch err := tx.Rollback(); err {
63					case nil, sql.ErrTxDone:
64					default:
65						return err
66					}
67				case context.Canceled:
68				default:
69					// must not fail with "cannot start a transaction within a transaction"
70					return err
71				}
72				return nil
73			})
74			if err != nil {
75				t.Fatal(err)
76			}
77		}()
78	}
79}
80
81func TestStmtReadonly(t *testing.T) {
82	db, err := sql.Open("sqlite3", ":memory:")
83	if err != nil {
84		t.Fatal(err)
85	}
86
87	_, err = db.Exec("CREATE TABLE t (count INT)")
88	if err != nil {
89		t.Fatal(err)
90	}
91
92	isRO := func(query string) bool {
93		c, err := db.Conn(context.Background())
94		if err != nil {
95			return false
96		}
97
98		var ro bool
99		c.Raw(func(dc interface{}) error {
100			stmt, err := dc.(*SQLiteConn).Prepare(query)
101			if err != nil {
102				return err
103			}
104			if stmt == nil {
105				return errors.New("stmt is nil")
106			}
107			ro = stmt.(*SQLiteStmt).Readonly()
108			return nil
109		})
110		return ro // On errors ro will remain false.
111	}
112
113	if !isRO(`select * from t`) {
114		t.Error("select not seen as read-only")
115	}
116	if isRO(`insert into t values (1), (2)`) {
117		t.Error("insert seen as read-only")
118	}
119}
120