1package pq
2
3import (
4	"context"
5	"database/sql"
6	"runtime"
7	"strings"
8	"testing"
9	"time"
10)
11
12func TestMultipleSimpleQuery(t *testing.T) {
13	db := openTestConn(t)
14	defer db.Close()
15
16	rows, err := db.Query("select 1; set time zone default; select 2; select 3")
17	if err != nil {
18		t.Fatal(err)
19	}
20	defer rows.Close()
21
22	var i int
23	for rows.Next() {
24		if err := rows.Scan(&i); err != nil {
25			t.Fatal(err)
26		}
27		if i != 1 {
28			t.Fatalf("expected 1, got %d", i)
29		}
30	}
31	if !rows.NextResultSet() {
32		t.Fatal("expected more result sets", rows.Err())
33	}
34	for rows.Next() {
35		if err := rows.Scan(&i); err != nil {
36			t.Fatal(err)
37		}
38		if i != 2 {
39			t.Fatalf("expected 2, got %d", i)
40		}
41	}
42
43	// Make sure that if we ignore a result we can still query.
44
45	rows, err = db.Query("select 4; select 5")
46	if err != nil {
47		t.Fatal(err)
48	}
49	defer rows.Close()
50
51	for rows.Next() {
52		if err := rows.Scan(&i); err != nil {
53			t.Fatal(err)
54		}
55		if i != 4 {
56			t.Fatalf("expected 4, got %d", i)
57		}
58	}
59	if !rows.NextResultSet() {
60		t.Fatal("expected more result sets", rows.Err())
61	}
62	for rows.Next() {
63		if err := rows.Scan(&i); err != nil {
64			t.Fatal(err)
65		}
66		if i != 5 {
67			t.Fatalf("expected 5, got %d", i)
68		}
69	}
70	if rows.NextResultSet() {
71		t.Fatal("unexpected result set")
72	}
73}
74
const (
	// contextRaceIterations is the number of rounds each context-cancellation
	// test spends racing a context cancel against an operation that has just
	// completed.
	contextRaceIterations = 100
)
76
77func TestContextCancelExec(t *testing.T) {
78	db := openTestConn(t)
79	defer db.Close()
80
81	ctx, cancel := context.WithCancel(context.Background())
82
83	// Delay execution for just a bit until db.ExecContext has begun.
84	defer time.AfterFunc(time.Millisecond*10, cancel).Stop()
85
86	// Not canceled until after the exec has started.
87	if _, err := db.ExecContext(ctx, "select pg_sleep(1)"); err == nil {
88		t.Fatal("expected error")
89	} else if err.Error() != "pq: canceling statement due to user request" {
90		t.Fatalf("unexpected error: %s", err)
91	}
92
93	// Context is already canceled, so error should come before execution.
94	if _, err := db.ExecContext(ctx, "select pg_sleep(1)"); err == nil {
95		t.Fatal("expected error")
96	} else if err.Error() != "context canceled" {
97		t.Fatalf("unexpected error: %s", err)
98	}
99
100	for i := 0; i < contextRaceIterations; i++ {
101		func() {
102			ctx, cancel := context.WithCancel(context.Background())
103			defer cancel()
104			if _, err := db.ExecContext(ctx, "select 1"); err != nil {
105				t.Fatal(err)
106			}
107		}()
108
109		if _, err := db.Exec("select 1"); err != nil {
110			t.Fatal(err)
111		}
112	}
113}
114
115func TestContextCancelQuery(t *testing.T) {
116	db := openTestConn(t)
117	defer db.Close()
118
119	ctx, cancel := context.WithCancel(context.Background())
120
121	// Delay execution for just a bit until db.QueryContext has begun.
122	defer time.AfterFunc(time.Millisecond*10, cancel).Stop()
123
124	// Not canceled until after the exec has started.
125	if _, err := db.QueryContext(ctx, "select pg_sleep(1)"); err == nil {
126		t.Fatal("expected error")
127	} else if err.Error() != "pq: canceling statement due to user request" {
128		t.Fatalf("unexpected error: %s", err)
129	}
130
131	// Context is already canceled, so error should come before execution.
132	if _, err := db.QueryContext(ctx, "select pg_sleep(1)"); err == nil {
133		t.Fatal("expected error")
134	} else if err.Error() != "context canceled" {
135		t.Fatalf("unexpected error: %s", err)
136	}
137
138	for i := 0; i < contextRaceIterations; i++ {
139		func() {
140			ctx, cancel := context.WithCancel(context.Background())
141			rows, err := db.QueryContext(ctx, "select 1")
142			cancel()
143			if err != nil {
144				t.Fatal(err)
145			} else if err := rows.Close(); err != nil {
146				t.Fatal(err)
147			}
148		}()
149
150		if rows, err := db.Query("select 1"); err != nil {
151			t.Fatal(err)
152		} else if err := rows.Close(); err != nil {
153			t.Fatal(err)
154		}
155	}
156}
157
158// TestIssue617 tests that a failed query in QueryContext doesn't lead to a
159// goroutine leak.
160func TestIssue617(t *testing.T) {
161	db := openTestConn(t)
162	defer db.Close()
163
164	const N = 10
165
166	numGoroutineStart := runtime.NumGoroutine()
167	for i := 0; i < N; i++ {
168		func() {
169			ctx, cancel := context.WithCancel(context.Background())
170			defer cancel()
171			_, err := db.QueryContext(ctx, `SELECT * FROM DOESNOTEXIST`)
172			pqErr, _ := err.(*Error)
173			// Expecting "pq: relation \"doesnotexist\" does not exist" error.
174			if err == nil || pqErr == nil || pqErr.Code != "42P01" {
175				t.Fatalf("expected undefined table error, got %v", err)
176			}
177		}()
178	}
179	numGoroutineFinish := runtime.NumGoroutine()
180
181	// We use N/2 and not N because the GC and other actors may increase or
182	// decrease the number of goroutines.
183	if numGoroutineFinish-numGoroutineStart >= N/2 {
184		t.Errorf("goroutine leak detected, was %d, now %d", numGoroutineStart, numGoroutineFinish)
185	}
186}
187
188func TestContextCancelBegin(t *testing.T) {
189	db := openTestConn(t)
190	defer db.Close()
191
192	ctx, cancel := context.WithCancel(context.Background())
193	tx, err := db.BeginTx(ctx, nil)
194	if err != nil {
195		t.Fatal(err)
196	}
197
198	// Delay execution for just a bit until tx.Exec has begun.
199	defer time.AfterFunc(time.Millisecond*10, cancel).Stop()
200
201	// Not canceled until after the exec has started.
202	if _, err := tx.Exec("select pg_sleep(1)"); err == nil {
203		t.Fatal("expected error")
204	} else if err.Error() != "pq: canceling statement due to user request" {
205		t.Fatalf("unexpected error: %s", err)
206	}
207
208	// Transaction is canceled, so expect an error.
209	if _, err := tx.Query("select pg_sleep(1)"); err == nil {
210		t.Fatal("expected error")
211	} else if err != sql.ErrTxDone {
212		t.Fatalf("unexpected error: %s", err)
213	}
214
215	// Context is canceled, so cannot begin a transaction.
216	if _, err := db.BeginTx(ctx, nil); err == nil {
217		t.Fatal("expected error")
218	} else if err.Error() != "context canceled" {
219		t.Fatalf("unexpected error: %s", err)
220	}
221
222	for i := 0; i < contextRaceIterations; i++ {
223		func() {
224			ctx, cancel := context.WithCancel(context.Background())
225			tx, err := db.BeginTx(ctx, nil)
226			cancel()
227			if err != nil {
228				t.Fatal(err)
229			} else if err := tx.Rollback(); err != nil &&
230				err.Error() != "pq: canceling statement due to user request" &&
231				err != sql.ErrTxDone {
232				t.Fatal(err)
233			}
234		}()
235
236		if tx, err := db.Begin(); err != nil {
237			t.Fatal(err)
238		} else if err := tx.Rollback(); err != nil {
239			t.Fatal(err)
240		}
241	}
242}
243
244func TestTxOptions(t *testing.T) {
245	db := openTestConn(t)
246	defer db.Close()
247	ctx := context.Background()
248
249	tests := []struct {
250		level     sql.IsolationLevel
251		isolation string
252	}{
253		{
254			level:     sql.LevelDefault,
255			isolation: "",
256		},
257		{
258			level:     sql.LevelReadUncommitted,
259			isolation: "read uncommitted",
260		},
261		{
262			level:     sql.LevelReadCommitted,
263			isolation: "read committed",
264		},
265		{
266			level:     sql.LevelRepeatableRead,
267			isolation: "repeatable read",
268		},
269		{
270			level:     sql.LevelSerializable,
271			isolation: "serializable",
272		},
273	}
274
275	for _, test := range tests {
276		for _, ro := range []bool{true, false} {
277			tx, err := db.BeginTx(ctx, &sql.TxOptions{
278				Isolation: test.level,
279				ReadOnly:  ro,
280			})
281			if err != nil {
282				t.Fatal(err)
283			}
284
285			var isolation string
286			err = tx.QueryRow("select current_setting('transaction_isolation')").Scan(&isolation)
287			if err != nil {
288				t.Fatal(err)
289			}
290
291			if test.isolation != "" && isolation != test.isolation {
292				t.Errorf("wrong isolation level: %s != %s", isolation, test.isolation)
293			}
294
295			var isRO string
296			err = tx.QueryRow("select current_setting('transaction_read_only')").Scan(&isRO)
297			if err != nil {
298				t.Fatal(err)
299			}
300
301			if ro != (isRO == "on") {
302				t.Errorf("read/[write,only] not set: %t != %s for level %s",
303					ro, isRO, test.isolation)
304			}
305
306			tx.Rollback()
307		}
308	}
309
310	_, err := db.BeginTx(ctx, &sql.TxOptions{
311		Isolation: sql.LevelLinearizable,
312	})
313	if err == nil {
314		t.Fatal("expected LevelLinearizable to fail")
315	}
316	if !strings.Contains(err.Error(), "isolation level not supported") {
317		t.Errorf("Expected error to mention isolation level, got %q", err)
318	}
319}
320