1// Copyright (c) 2017-2021 Snowflake Computing Inc. All right reserved.
2
3package gosnowflake
4
5//lint:file-ignore U1000 Ignore all unused code
6
7import (
8	"bytes"
9	"context"
10	"crypto/rsa"
11	"database/sql"
12	"flag"
13	"fmt"
14	"math/big"
15	"net/http"
16	"net/url"
17	"os"
18	"os/signal"
19	"reflect"
20	"strings"
21	"syscall"
22	"testing"
23	"time"
24
25	"github.com/google/uuid"
26)
27
// Connection settings shared by the tests in this file. init fills them from
// the SNOWFLAKE_TEST_* environment variables.
var (
	// Credentials and account identifier.
	user, pass, account string

	// Session context: database, schema, warehouse and role.
	dbname, schemaname, warehouse, rolename string

	// dsn is the connection string assembled by createDSN.
	dsn string

	// Endpoint pieces used when building the DSN.
	host, port, protocol string

	// customPrivateKey reports whether the user has specified the private key path.
	customPrivateKey bool

	// testPrivKey is the valid private key used for all test cases.
	testPrivKey *rsa.PrivateKey
)
43
// Query templates and fixtures used by the tests.
const (
	// selectNumberSQL casts a literal to NUMBER; the verbs are the literal,
	// the precision and the scale, in that order.
	selectNumberSQL = "SELECT %s::NUMBER(%v, %v) AS C"

	// selectVariousTypes returns one row with six columns (C1..C6) of
	// different types.
	selectVariousTypes = "SELECT 1.0::NUMBER(30,2) as C1, 2::NUMBER(38,0) AS C2, 't3' AS C3, 4.2::DOUBLE AS C4, 'abcd'::BINARY AS C5, true AS C6"

	// selectRandomGenerator produces generated rows; the verb is the row count.
	selectRandomGenerator = "SELECT SEQ8(), RANDSTR(1000, RANDOM()) FROM TABLE(GENERATOR(ROWCOUNT=>%v))"

	// PSTLocation is the time zone name "America/Los_Angeles".
	PSTLocation = "America/Los_Angeles"
)
50
51// The tests require the following parameters in the environment variables.
52// SNOWFLAKE_TEST_USER, SNOWFLAKE_TEST_PASSWORD, SNOWFLAKE_TEST_ACCOUNT,
53// SNOWFLAKE_TEST_DATABASE, SNOWFLAKE_TEST_SCHEMA, SNOWFLAKE_TEST_WAREHOUSE.
54// Optionally you may specify SNOWFLAKE_TEST_PROTOCOL, SNOWFLAKE_TEST_HOST
55// and SNOWFLAKE_TEST_PORT to specify the endpoint.
56func init() {
57	// get environment variables
58	env := func(key, defaultValue string) string {
59		if value := os.Getenv(key); value != "" {
60			return value
61		}
62		return defaultValue
63	}
64	user = env("SNOWFLAKE_TEST_USER", "testuser")
65	pass = env("SNOWFLAKE_TEST_PASSWORD", "testpassword")
66	account = env("SNOWFLAKE_TEST_ACCOUNT", "testaccount")
67	dbname = env("SNOWFLAKE_TEST_DATABASE", "testdb")
68	schemaname = env("SNOWFLAKE_TEST_SCHEMA", "public")
69	rolename = env("SNOWFLAKE_TEST_ROLE", "sysadmin")
70	warehouse = env("SNOWFLAKE_TEST_WAREHOUSE", "testwarehouse")
71
72	protocol = env("SNOWFLAKE_TEST_PROTOCOL", "https")
73	host = os.Getenv("SNOWFLAKE_TEST_HOST")
74	port = env("SNOWFLAKE_TEST_PORT", "443")
75	if host == "" {
76		host = fmt.Sprintf("%s.snowflakecomputing.com", account)
77	} else {
78		host = fmt.Sprintf("%s:%s", host, port)
79	}
80
81	setupPrivateKey()
82
83	createDSN("UTC")
84}
85
86func createDSN(timezone string) {
87	dsn = fmt.Sprintf("%s:%s@%s/%s/%s", user, pass, host, dbname, schemaname)
88
89	parameters := url.Values{}
90	parameters.Add("timezone", timezone)
91	if protocol != "" {
92		parameters.Add("protocol", protocol)
93	}
94	if account != "" {
95		parameters.Add("account", account)
96	}
97	if warehouse != "" {
98		parameters.Add("warehouse", warehouse)
99	}
100	if rolename != "" {
101		parameters.Add("role", rolename)
102	}
103
104	if len(parameters) > 0 {
105		dsn += "?" + parameters.Encode()
106	}
107}
108
109// setup creates a test schema so that all tests can run in the same schema
110func setup() (string, error) {
111	env := func(key, defaultValue string) string {
112		if value := os.Getenv(key); value != "" {
113			return value
114		}
115		return defaultValue
116	}
117
118	orgSchemaname := schemaname
119	if env("GITHUB_WORKFLOW", "") != "" {
120		githubRunnerID := env("RUNNER_TRACKING_ID", "GITHUB_RUNNER_ID")
121		githubRunnerID = strings.ReplaceAll(githubRunnerID, "-", "_")
122		githubSha := env("GITHUB_SHA", "GITHUB_SHA")
123		schemaname = fmt.Sprintf("%v_%v", githubRunnerID, githubSha)
124	} else {
125		schemaname = fmt.Sprintf("golang_%v", time.Now().UnixNano())
126	}
127	var db *sql.DB
128	var err error
129	if db, err = sql.Open("snowflake", dsn); err != nil {
130		return "", fmt.Errorf("failed to open db. %v, err: %v", dsn, err)
131	}
132	defer db.Close()
133	if _, err = db.Exec(fmt.Sprintf("CREATE OR REPLACE SCHEMA %v", schemaname)); err != nil {
134		return "", fmt.Errorf("failed to create schema. %v", err)
135	}
136	createDSN("UTC")
137	return orgSchemaname, nil
138}
139
140// teardown drops the test schema
141func teardown() error {
142	var db *sql.DB
143	var err error
144	if db, err = sql.Open("snowflake", dsn); err != nil {
145		return fmt.Errorf("failed to open db. %v, err: %v", dsn, err)
146	}
147	defer db.Close()
148	if _, err = db.Exec(fmt.Sprintf("DROP SCHEMA IF EXISTS %v", schemaname)); err != nil {
149		return fmt.Errorf("failed to create schema. %v", err)
150	}
151	return nil
152}
153
154func TestMain(m *testing.M) {
155	flag.Parse()
156	signal.Ignore(syscall.SIGQUIT)
157	if value := os.Getenv("SKIP_SETUP"); value != "" {
158		os.Exit(m.Run())
159	}
160
161	if _, err := setup(); err != nil {
162		panic(err)
163	}
164	ret := m.Run()
165	if err := teardown(); err != nil {
166		panic(err)
167	}
168	os.Exit(ret)
169}
170
// DBTest bundles the running test with the database handle it works on, so
// the must* helpers can issue statements and fail the test on error.
type DBTest struct {
	*testing.T
	// db is the handle all helper methods run their statements on.
	db *sql.DB
}
175
// RowsExtended wraps *sql.Rows together with the channel that mustQuery and
// mustQueryContext create so that Close can notify their watcher goroutine.
type RowsExtended struct {
	// rows is the underlying result set every method delegates to.
	rows *sql.Rows
	// closeChan points at the channel Close sends on and then closes.
	closeChan *chan bool
}
180
181func (rs *RowsExtended) Close() error {
182	*rs.closeChan <- true
183	close(*rs.closeChan)
184	return rs.rows.Close()
185}
186
// ColumnTypes forwards to the ColumnTypes method of the wrapped rows.
func (rs *RowsExtended) ColumnTypes() ([]*sql.ColumnType, error) {
	return rs.rows.ColumnTypes()
}
190
// Columns forwards to the Columns method of the wrapped rows.
func (rs *RowsExtended) Columns() ([]string, error) {
	return rs.rows.Columns()
}
194
// Err forwards to the Err method of the wrapped rows.
func (rs *RowsExtended) Err() error {
	return rs.rows.Err()
}
198
// Next forwards to the Next method of the wrapped rows.
func (rs *RowsExtended) Next() bool {
	return rs.rows.Next()
}
202
// NextResultSet forwards to the NextResultSet method of the wrapped rows.
func (rs *RowsExtended) NextResultSet() bool {
	return rs.rows.NextResultSet()
}
206
// Scan forwards dest to the Scan method of the wrapped rows.
func (rs *RowsExtended) Scan(dest ...interface{}) error {
	return rs.rows.Scan(dest...)
}
210
211func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *RowsExtended) {
212	// handler interrupt signal
213	ctx, cancel := context.WithCancel(context.Background())
214	c := make(chan os.Signal, 1)
215	c0 := make(chan bool, 1)
216	signal.Notify(c, os.Interrupt)
217	defer func() {
218		signal.Stop(c)
219	}()
220	go func() {
221		select {
222		case <-c:
223			fmt.Println("Caught signal, canceling...")
224			cancel()
225		case <-ctx.Done():
226			fmt.Println("Done")
227		case <-c0:
228		}
229		close(c)
230	}()
231
232	rs, err := dbt.db.QueryContext(ctx, query, args...)
233	if err != nil {
234		dbt.fail("query", query, err)
235	}
236	return &RowsExtended{
237		rows:      rs,
238		closeChan: &c0,
239	}
240}
241
242func (dbt *DBTest) mustQueryContext(ctx context.Context, query string, args ...interface{}) (rows *RowsExtended) {
243	// handler interrupt signal
244	ctx, cancel := context.WithCancel(ctx)
245	c := make(chan os.Signal, 1)
246	c0 := make(chan bool, 1)
247	signal.Notify(c, os.Interrupt)
248	defer func() {
249		signal.Stop(c)
250	}()
251	go func() {
252		select {
253		case <-c:
254			fmt.Println("Caught signal, canceling...")
255			cancel()
256		case <-ctx.Done():
257			fmt.Println("Done")
258		case <-c0:
259		}
260		close(c)
261	}()
262
263	rs, err := dbt.db.QueryContext(ctx, query, args...)
264	if err != nil {
265		dbt.fail("query", query, err)
266	}
267	return &RowsExtended{
268		rows:      rs,
269		closeChan: &c0,
270	}
271}
272
273func (dbt *DBTest) mustQueryAssertCount(query string, expected int, args ...interface{}) {
274	rows := dbt.mustQuery(query, args...)
275	cnt := 0
276	for rows.Next() {
277		cnt++
278	}
279	if cnt != expected {
280		dbt.Fatalf("expected %v, got %v", expected, cnt)
281	}
282}
283
284func (dbt *DBTest) fail(method, query string, err error) {
285	if len(query) > 300 {
286		query = "[query too large to print]"
287	}
288	dbt.Fatalf("error on %s [%s]: %s", method, query, err.Error())
289}
290
291func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) {
292	res, err := dbt.db.Exec(query, args...)
293	if err != nil {
294		dbt.fail("exec", query, err)
295	}
296	return res
297}
298
299func (dbt *DBTest) mustExecContext(ctx context.Context, query string, args ...interface{}) (res sql.Result) {
300	res, err := dbt.db.ExecContext(ctx, query, args...)
301	if err != nil {
302		dbt.fail("exec context", query, err)
303	}
304	return res
305}
306
307func (dbt *DBTest) mustDecimalSize(ct *sql.ColumnType) (pr int64, sc int64) {
308	var ok bool
309	pr, sc, ok = ct.DecimalSize()
310	if !ok {
311		dbt.Fatalf("failed to get decimal size. %v", ct)
312	}
313	return pr, sc
314}
315
316func (dbt *DBTest) mustFailDecimalSize(ct *sql.ColumnType) {
317	var ok bool
318	if _, _, ok = ct.DecimalSize(); ok {
319		dbt.Fatalf("should not return decimal size. %v", ct)
320	}
321}
322
323func (dbt *DBTest) mustLength(ct *sql.ColumnType) (cLen int64) {
324	var ok bool
325	cLen, ok = ct.Length()
326	if !ok {
327		dbt.Fatalf("failed to get length. %v", ct)
328	}
329	return cLen
330}
331
332func (dbt *DBTest) mustFailLength(ct *sql.ColumnType) {
333	var ok bool
334	if _, ok = ct.Length(); ok {
335		dbt.Fatalf("should not return length. %v", ct)
336	}
337}
338
339func (dbt *DBTest) mustNullable(ct *sql.ColumnType) (canNull bool) {
340	var ok bool
341	canNull, ok = ct.Nullable()
342	if !ok {
343		dbt.Fatalf("failed to get length. %v", ct)
344	}
345	return canNull
346}
347
348func (dbt *DBTest) mustPrepare(query string) (stmt *sql.Stmt) {
349	stmt, err := dbt.db.Prepare(query)
350	if err != nil {
351		dbt.fail("prepare", query, err)
352	}
353	return stmt
354}
355
356func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
357	db, err := sql.Open("snowflake", dsn)
358	if err != nil {
359		t.Fatalf("error connecting: %s", err.Error())
360	}
361	defer db.Close()
362
363	if _, err = db.Exec("DROP TABLE IF EXISTS test"); err != nil {
364		t.Fatalf("failed to drop table: %v", err)
365	}
366
367	dbt := &DBTest{t, db}
368	for _, test := range tests {
369		test(dbt)
370		dbt.db.Exec("DROP TABLE IF EXISTS test")
371	}
372}
373
// runningOnGithubAction reports whether the GITHUB_ACTIONS environment
// variable is set to a non-empty value.
func runningOnGithubAction() bool {
	return len(os.Getenv("GITHUB_ACTIONS")) > 0
}
377
// runningOnAWS reports whether CLOUD_PROVIDER is exactly "AWS".
func runningOnAWS() bool {
	provider := os.Getenv("CLOUD_PROVIDER")
	return provider == "AWS"
}
381
// runningOnAzure reports whether CLOUD_PROVIDER is exactly "AZURE".
func runningOnAzure() bool {
	provider := os.Getenv("CLOUD_PROVIDER")
	return provider == "AZURE"
}
385
// runningOnGCP reports whether CLOUD_PROVIDER is exactly "GCP".
func runningOnGCP() bool {
	provider := os.Getenv("CLOUD_PROVIDER")
	return provider == "GCP"
}
389
390func TestBogusUserPasswordParameters(t *testing.T) {
391	invalidDNS := fmt.Sprintf("%s:%s@%s", "bogus", pass, host)
392	invalidUserPassErrorTests(invalidDNS, t)
393	invalidDNS = fmt.Sprintf("%s:%s@%s", user, "INVALID_PASSWORD", host)
394	invalidUserPassErrorTests(invalidDNS, t)
395}
396
397func invalidUserPassErrorTests(invalidDNS string, t *testing.T) {
398	parameters := url.Values{}
399	if protocol != "" {
400		parameters.Add("protocol", protocol)
401	}
402	if account != "" {
403		parameters.Add("account", account)
404	}
405	invalidDNS += "?" + parameters.Encode()
406	db, err := sql.Open("snowflake", invalidDNS)
407	if err != nil {
408		t.Fatalf("error creating a connection object: %s", err.Error())
409	}
410	// actual connection won't happen until run a query
411	defer db.Close()
412	if _, err = db.Exec("SELECT 1"); err == nil {
413		t.Fatal("should cause an error.")
414	}
415	if driverErr, ok := err.(*SnowflakeError); ok {
416		if driverErr.Number != 390100 {
417			t.Fatalf("wrong error code: %v", driverErr)
418		}
419		if !strings.Contains(driverErr.Error(), "390100") {
420			t.Fatalf("error message should included the error code. got: %v", driverErr.Error())
421		}
422	} else {
423		t.Fatalf("wrong error code: %v", err)
424	}
425}
426
427func TestBogusHostNameParameters(t *testing.T) {
428	invalidDNS := fmt.Sprintf("%s:%s@%s", user, pass, "INVALID_HOST:1234")
429	invalidHostErrorTests(invalidDNS, []string{"no such host", "verify account name is correct", "HTTP Status: 403", "Temporary failure in name resolution"}, t)
430	invalidDNS = fmt.Sprintf("%s:%s@%s", user, pass, "INVALID_HOST")
431	invalidHostErrorTests(invalidDNS, []string{"read: connection reset by peer", "EOF", "verify account name is correct", "HTTP Status: 403", "Temporary failure in name resolution"}, t)
432}
433
434func invalidHostErrorTests(invalidDNS string, mstr []string, t *testing.T) {
435	parameters := url.Values{}
436	if protocol != "" {
437		parameters.Add("protocol", protocol)
438	}
439	if account != "" {
440		parameters.Add("account", account)
441	}
442	parameters.Add("loginTimeout", "10")
443	invalidDNS += "?" + parameters.Encode()
444	db, err := sql.Open("snowflake", invalidDNS)
445	if err != nil {
446		t.Fatalf("error creating a connection object: %s", err.Error())
447	}
448	// actual connection won't happen until run a query
449	defer db.Close()
450	if _, err = db.Exec("SELECT 1"); err == nil {
451		t.Fatal("should cause an error.")
452	}
453	found := false
454	for _, m := range mstr {
455		if strings.Contains(err.Error(), m) {
456			found = true
457		}
458	}
459	if !found {
460		t.Fatalf("wrong error: %v", err)
461	}
462}
463
464func TestCommentOnlyQuery(t *testing.T) {
465	runTests(t, dsn, func(dbt *DBTest) {
466		query := "--"
467		// just a comment, no query
468		rows, err := dbt.db.Query(query)
469		if err == nil {
470			rows.Close()
471			dbt.fail("query", query, err)
472		}
473		if driverErr, ok := err.(*SnowflakeError); ok {
474			if driverErr.Number != 900 { // syntax error
475				dbt.fail("query", query, err)
476			}
477		}
478	})
479}
480
481func TestEmptyQuery(t *testing.T) {
482	runTests(t, dsn, func(dbt *DBTest) {
483		query := "select 1 from dual where 1=0"
484		// just a comment, no query
485		rows := dbt.db.QueryRow(query)
486		var v1 interface{}
487		if err := rows.Scan(&v1); err != sql.ErrNoRows {
488			dbt.Errorf("should fail. err: %v", err)
489		}
490		rows = dbt.db.QueryRowContext(context.Background(), query)
491		if err := rows.Scan(&v1); err != sql.ErrNoRows {
492			dbt.Errorf("should fail. err: %v", err)
493		}
494	})
495}
496
497func TestEmptyQueryWithRequestID(t *testing.T) {
498	runTests(t, dsn, func(dbt *DBTest) {
499		query := "select 1"
500		ctx := WithRequestID(context.Background(), uuid.New())
501		rows := dbt.db.QueryRowContext(ctx, query)
502		var v1 interface{}
503		if err := rows.Scan(&v1); err != nil {
504			dbt.Errorf("should not have failed with valid request id. err: %v", err)
505		}
506	})
507}
508
509func TestCRUD(t *testing.T) {
510	runTests(t, dsn, func(dbt *DBTest) {
511		// Create Table
512		dbt.mustExec("CREATE TABLE test (value BOOLEAN)")
513
514		// Test for unexpected Data
515		var out bool
516		rows := dbt.mustQuery("SELECT * FROM test")
517		defer rows.Close()
518		if rows.Next() {
519			dbt.Error("unexpected Data in empty table")
520		}
521
522		// Create Data
523		res := dbt.mustExec("INSERT INTO test VALUES (true)")
524		count, err := res.RowsAffected()
525		if err != nil {
526			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
527		}
528		if count != 1 {
529			dbt.Fatalf("expected 1 affected row, got %d", count)
530		}
531
532		id, err := res.LastInsertId()
533		if err != nil {
534			dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error())
535		}
536		if id != -1 {
537			dbt.Fatalf(
538				"expected InsertId -1, got %d. Snowflake doesn't support last insert ID", id)
539		}
540
541		// Read
542		rows = dbt.mustQuery("SELECT value FROM test")
543		defer rows.Close()
544		if rows.Next() {
545			rows.Scan(&out)
546			if !out {
547				dbt.Errorf("%t should be true", out)
548			}
549			if rows.Next() {
550				dbt.Error("unexpected Data")
551			}
552		} else {
553			dbt.Error("no Data")
554		}
555
556		// Update
557		res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true)
558		count, err = res.RowsAffected()
559		if err != nil {
560			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
561		}
562		if count != 1 {
563			dbt.Fatalf("expected 1 affected row, got %d", count)
564		}
565
566		// Check Update
567		rows = dbt.mustQuery("SELECT value FROM test")
568		defer rows.Close()
569		if rows.Next() {
570			rows.Scan(&out)
571			if out {
572				dbt.Errorf("%t should be true", out)
573			}
574			if rows.Next() {
575				dbt.Error("unexpected Data")
576			}
577		} else {
578			dbt.Error("no Data")
579		}
580
581		// Delete
582		res = dbt.mustExec("DELETE FROM test WHERE value = ?", false)
583		count, err = res.RowsAffected()
584		if err != nil {
585			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
586		}
587		if count != 1 {
588			dbt.Fatalf("expected 1 affected row, got %d", count)
589		}
590
591		// Check for unexpected rows
592		res = dbt.mustExec("DELETE FROM test")
593		count, err = res.RowsAffected()
594		if err != nil {
595			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
596		}
597		if count != 0 {
598			dbt.Fatalf("expected 0 affected row, got %d", count)
599		}
600	})
601}
602
603func TestInt(t *testing.T) {
604	testInt(t, false)
605}
606
607func testInt(t *testing.T, json bool) {
608	runTests(t, dsn, func(dbt *DBTest) {
609		types := []string{"INT", "INTEGER"}
610		in := int64(42)
611		var out int64
612		var rows *RowsExtended
613
614		// SIGNED
615		for _, v := range types {
616			if json {
617				dbt.mustExec(forceJSON)
618			}
619			dbt.mustExec("CREATE TABLE test (value " + v + ")")
620			dbt.mustExec("INSERT INTO test VALUES (?)", in)
621			rows = dbt.mustQuery("SELECT value FROM test")
622			defer rows.Close()
623			if rows.Next() {
624				rows.Scan(&out)
625				if in != out {
626					dbt.Errorf("%s: %d != %d", v, in, out)
627				}
628			} else {
629				dbt.Errorf("%s: no data", v)
630			}
631
632			dbt.mustExec("DROP TABLE IF EXISTS test")
633		}
634	})
635}
636
637func TestArrowBigInt(t *testing.T) {
638	db := openDB(t)
639	dbt := &DBTest{t, db}
640
641	testcases := []struct {
642		num  string
643		prec int
644		sc   int
645	}{
646		{"10000000000000000000000000000000000000", 38, 0},
647		{"-10000000000000000000000000000000000000", 38, 0},
648		{"12345678901234567890123456789012345678", 38, 0},
649		{"-12345678901234567890123456789012345678", 38, 0},
650		{"99999999999999999999999999999999999999", 38, 0},
651		{"-99999999999999999999999999999999999999", 38, 0},
652	}
653
654	for _, tc := range testcases {
655		rows := dbt.mustQueryContext(WithHigherPrecision(context.Background()),
656			fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
657		if !rows.Next() {
658			dbt.Error("failed to query")
659		}
660		defer rows.Close()
661		var v *big.Int
662		if err := rows.Scan(&v); err != nil {
663			dbt.Errorf("failed to scan. %#v", err)
664		}
665
666		b, ok := new(big.Int).SetString(tc.num, 10)
667		if !ok {
668			dbt.Errorf("failed to convert %v big.Int.", tc.num)
669		}
670		if v.Cmp(b) != 0 {
671			dbt.Errorf("big.Int value mismatch: expected %v, got %v", b, v)
672		}
673	}
674}
675
676func TestFloat32(t *testing.T) {
677	testFloat32(t, false)
678}
679
680func testFloat32(t *testing.T, json bool) {
681	runTests(t, dsn, func(dbt *DBTest) {
682		types := [2]string{"FLOAT", "DOUBLE"}
683		in := float32(42.23)
684		var out float32
685		var rows *RowsExtended
686		for _, v := range types {
687			if json {
688				dbt.mustExec(forceJSON)
689			}
690			dbt.mustExec("CREATE TABLE test (value " + v + ")")
691			dbt.mustExec("INSERT INTO test VALUES (?)", in)
692			rows = dbt.mustQuery("SELECT value FROM test")
693			defer rows.Close()
694			if rows.Next() {
695				err := rows.Scan(&out)
696				if err != nil {
697					dbt.Errorf("failed to scan data: %v", err)
698				}
699				if in != out {
700					dbt.Errorf("%s: %g != %g", v, in, out)
701				}
702			} else {
703				dbt.Errorf("%s: no data", v)
704			}
705			dbt.mustExec("DROP TABLE IF EXISTS test")
706		}
707	})
708}
709
710func TestFloat64(t *testing.T) {
711	testFloat64(t, false)
712}
713
714func testFloat64(t *testing.T, json bool) {
715	runTests(t, dsn, func(dbt *DBTest) {
716		types := [2]string{"FLOAT", "DOUBLE"}
717		expected := 42.23
718		var out float64
719		var rows *RowsExtended
720		for _, v := range types {
721			if json {
722				dbt.mustExec(forceJSON)
723			}
724			dbt.mustExec("CREATE TABLE test (value " + v + ")")
725			dbt.mustExec("INSERT INTO test VALUES (42.23)")
726			rows = dbt.mustQuery("SELECT value FROM test")
727			defer rows.Close()
728			if rows.Next() {
729				rows.Scan(&out)
730				if expected != out {
731					dbt.Errorf("%s: %g != %g", v, expected, out)
732				}
733			} else {
734				dbt.Errorf("%s: no data", v)
735			}
736			dbt.mustExec("DROP TABLE IF EXISTS test")
737		}
738	})
739}
740
741func TestArrowBigFloat(t *testing.T) {
742	db := openDB(t)
743	dbt := &DBTest{t, db}
744
745	testcases := []struct {
746		num  string
747		prec int
748		sc   int
749	}{
750		{"1.23", 30, 2},
751		{"1.0000000000000000000000000000000000000", 38, 37},
752		{"-1.0000000000000000000000000000000000000", 38, 37},
753		{"1.2345678901234567890123456789012345678", 38, 37},
754		{"-1.2345678901234567890123456789012345678", 38, 37},
755		{"9.9999999999999999999999999999999999999", 38, 37},
756		{"-9.9999999999999999999999999999999999999", 38, 37},
757	}
758
759	for _, tc := range testcases {
760		rows := dbt.mustQueryContext(WithHigherPrecision(context.Background()),
761			fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
762		if !rows.Next() {
763			dbt.Error("failed to query")
764		}
765		defer rows.Close()
766		var v *big.Float
767		if err := rows.Scan(&v); err != nil {
768			dbt.Errorf("failed to scan. %#v", err)
769		}
770
771		prec := v.Prec()
772		b, ok := new(big.Float).SetPrec(prec).SetString(tc.num)
773		if !ok {
774			dbt.Errorf("failed to convert %v to big.Float.", tc.num)
775		}
776		if v.Cmp(b) != 0 {
777			dbt.Errorf("big.Float value mismatch: expected %v, got %v", b, v)
778		}
779	}
780}
781
782func TestArrowIntPrecision(t *testing.T) {
783	db := openDB(t)
784	dbt := &DBTest{t, db}
785
786	intTestcases := []struct {
787		num  string
788		prec int
789		sc   int
790	}{
791		{"10000000000000000000000000000000000000", 38, 0},
792		{"-10000000000000000000000000000000000000", 38, 0},
793		{"12345678901234567890123456789012345678", 38, 0},
794		{"-12345678901234567890123456789012345678", 38, 0},
795		{"99999999999999999999999999999999999999", 38, 0},
796		{"-99999999999999999999999999999999999999", 38, 0},
797	}
798
799	t.Run("arrow_disabled_scan_int64", func(t *testing.T) {
800		for _, tc := range intTestcases {
801			dbt.mustExec(forceJSON)
802			rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
803			if !rows.Next() {
804				dbt.Error("failed to query")
805			}
806			defer rows.Close()
807			var v int64
808			if err := rows.Scan(&v); err == nil {
809				dbt.Error("should fail to scan")
810			}
811		}
812	})
813	t.Run("arrow_disabled_scan_string", func(t *testing.T) {
814		for _, tc := range intTestcases {
815			dbt.mustExec(forceJSON)
816			rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
817			if !rows.Next() {
818				dbt.Error("failed to query")
819			}
820			defer rows.Close()
821			var v int64
822			if err := rows.Scan(&v); err == nil {
823				dbt.Error("should fail to scan")
824			}
825		}
826	})
827	t.Run("arrow_enabled_scan_big_int", func(t *testing.T) {
828		for _, tc := range intTestcases {
829			rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
830			if !rows.Next() {
831				dbt.Error("failed to query")
832			}
833			defer rows.Close()
834			var v string
835			if err := rows.Scan(&v); err != nil {
836				dbt.Errorf("failed to scan. %#v", err)
837			}
838			if !strings.EqualFold(v, tc.num) {
839				dbt.Errorf("int value mismatch: expected %v, got %v", tc.num, v)
840			}
841		}
842	})
843	t.Run("arrow_high_precision_enabled_scan_big_int", func(t *testing.T) {
844		for _, tc := range intTestcases {
845			rows := dbt.mustQueryContext(
846				WithHigherPrecision(context.Background()),
847				fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
848			if !rows.Next() {
849				dbt.Error("failed to query")
850			}
851			defer rows.Close()
852			var v *big.Int
853			if err := rows.Scan(&v); err != nil {
854				dbt.Errorf("failed to scan. %#v", err)
855			}
856
857			b, ok := new(big.Int).SetString(tc.num, 10)
858			if !ok {
859				dbt.Errorf("failed to convert %v big.Int.", tc.num)
860			}
861			if v.Cmp(b) != 0 {
862				dbt.Errorf("big.Int value mismatch: expected %v, got %v", b, v)
863			}
864		}
865	})
866}
867
868// TestArrowFloatPrecision tests the different variable types allowed in the
869// rows.Scan() method. Note that for lower precision types we do not attempt
870// to check the value as precision could be lost.
871func TestArrowFloatPrecision(t *testing.T) {
872	db := openDB(t)
873	dbt := &DBTest{t, db}
874
875	fltTestcases := []struct {
876		num  string
877		prec int
878		sc   int
879	}{
880		{"1.23", 30, 2},
881		{"1.0000000000000000000000000000000000000", 38, 37},
882		{"-1.0000000000000000000000000000000000000", 38, 37},
883		{"1.2345678901234567890123456789012345678", 38, 37},
884		{"-1.2345678901234567890123456789012345678", 38, 37},
885		{"9.9999999999999999999999999999999999999", 38, 37},
886		{"-9.9999999999999999999999999999999999999", 38, 37},
887	}
888
889	t.Run("arrow_disabled_scan_float64", func(t *testing.T) {
890		for _, tc := range fltTestcases {
891			dbt.mustExec(forceJSON)
892			rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
893			if !rows.Next() {
894				dbt.Error("failed to query")
895			}
896			defer rows.Close()
897			var v float64
898			if err := rows.Scan(&v); err != nil {
899				dbt.Errorf("failed to scan. %#v", err)
900			}
901		}
902	})
903	t.Run("arrow_disabled_scan_float32", func(t *testing.T) {
904		for _, tc := range fltTestcases {
905			dbt.mustExec(forceJSON)
906			rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
907			if !rows.Next() {
908				dbt.Error("failed to query")
909			}
910			defer rows.Close()
911			var v float32
912			if err := rows.Scan(&v); err != nil {
913				dbt.Errorf("failed to scan. %#v", err)
914			}
915		}
916	})
917	t.Run("arrow_disabled_scan_string", func(t *testing.T) {
918		for _, tc := range fltTestcases {
919			dbt.mustExec(forceJSON)
920			rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
921			if !rows.Next() {
922				dbt.Error("failed to query")
923			}
924			defer rows.Close()
925			var v string
926			if err := rows.Scan(&v); err != nil {
927				dbt.Errorf("failed to scan. %#v", err)
928			}
929			if !strings.EqualFold(v, tc.num) {
930				dbt.Errorf("int value mismatch: expected %v, got %v", tc.num, v)
931			}
932		}
933	})
934	t.Run("arrow_enabled_scan_float64", func(t *testing.T) {
935		for _, tc := range fltTestcases {
936			rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
937			if !rows.Next() {
938				dbt.Error("failed to query")
939			}
940			defer rows.Close()
941			var v float64
942			if err := rows.Scan(&v); err != nil {
943				dbt.Errorf("failed to scan. %#v", err)
944			}
945		}
946	})
947	t.Run("arrow_enabled_scan_float32", func(t *testing.T) {
948		for _, tc := range fltTestcases {
949			rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
950			if !rows.Next() {
951				dbt.Error("failed to query")
952			}
953			defer rows.Close()
954			var v float32
955			if err := rows.Scan(&v); err != nil {
956				dbt.Errorf("failed to scan. %#v", err)
957			}
958		}
959	})
960	t.Run("arrow_enabled_scan_string", func(t *testing.T) {
961		for _, tc := range fltTestcases {
962			rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
963			if !rows.Next() {
964				dbt.Error("failed to query")
965			}
966			defer rows.Close()
967			var v string
968			if err := rows.Scan(&v); err != nil {
969				dbt.Errorf("failed to scan. %#v", err)
970			}
971		}
972	})
973	t.Run("arrow_high_precision_enabled_scan_big_float", func(t *testing.T) {
974		for _, tc := range fltTestcases {
975			rows := dbt.mustQueryContext(
976				WithHigherPrecision(context.Background()),
977				fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc))
978			if !rows.Next() {
979				dbt.Error("failed to query")
980			}
981			defer rows.Close()
982			var v *big.Float
983			if err := rows.Scan(&v); err != nil {
984				dbt.Errorf("failed to scan. %#v", err)
985			}
986
987			prec := v.Prec()
988			b, ok := new(big.Float).SetPrec(prec).SetString(tc.num)
989			if !ok {
990				dbt.Errorf("failed to convert %v to big.Float.", tc.num)
991			}
992			if v.Cmp(b) != 0 {
993				dbt.Errorf("big.Float value mismatch: expected %v, got %v", b, v)
994			}
995		}
996	})
997}
998
999func TestArrowVariousTypes(t *testing.T) {
1000	runTests(t, dsn, func(dbt *DBTest) {
1001		rows := dbt.mustQueryContext(
1002			WithHigherPrecision(context.Background()), selectVariousTypes)
1003		defer rows.Close()
1004		if !rows.Next() {
1005			dbt.Error("failed to query")
1006		}
1007		cc, err := rows.Columns()
1008		if err != nil {
1009			dbt.Errorf("columns: %v", cc)
1010		}
1011		ct, err := rows.ColumnTypes()
1012		if err != nil {
1013			dbt.Errorf("column types: %v", ct)
1014		}
1015		var v1 *big.Float
1016		var v2 int
1017		var v3 string
1018		var v4 float64
1019		var v5 []byte
1020		var v6 bool
1021		if err = rows.Scan(&v1, &v2, &v3, &v4, &v5, &v6); err != nil {
1022			dbt.Errorf("failed to scan: %#v", err)
1023		}
1024		if v1.Cmp(big.NewFloat(1.0)) != 0 {
1025			dbt.Errorf("failed to scan. %#v", *v1)
1026		}
1027		if ct[0].Name() != "C1" || ct[1].Name() != "C2" || ct[2].Name() != "C3" || ct[3].Name() != "C4" || ct[4].Name() != "C5" || ct[5].Name() != "C6" {
1028			dbt.Errorf("failed to get column names: %#v", ct)
1029		}
1030		if ct[0].ScanType() != reflect.TypeOf(float64(0)) {
1031			dbt.Errorf("failed to get scan type. expected: %v, got: %v", reflect.TypeOf(float64(0)), ct[0].ScanType())
1032		}
1033		if ct[1].ScanType() != reflect.TypeOf(int64(0)) {
1034			dbt.Errorf("failed to get scan type. expected: %v, got: %v", reflect.TypeOf(int64(0)), ct[1].ScanType())
1035		}
1036		var pr, sc int64
1037		var cLen int64
1038		pr, sc = dbt.mustDecimalSize(ct[0])
1039		if pr != 30 || sc != 2 {
1040			dbt.Errorf("failed to get precision and scale. %#v", ct[0])
1041		}
1042		dbt.mustFailLength(ct[0])
1043		if canNull := dbt.mustNullable(ct[0]); canNull {
1044			dbt.Errorf("failed to get nullable. %#v", ct[0])
1045		}
1046		if cLen != 0 {
1047			dbt.Errorf("failed to get length. %#v", ct[0])
1048		}
1049		if v2 != 2 {
1050			dbt.Errorf("failed to scan. %#v", v2)
1051		}
1052		pr, sc = dbt.mustDecimalSize(ct[1])
1053		if pr != 38 || sc != 0 {
1054			dbt.Errorf("failed to get precision and scale. %#v", ct[1])
1055		}
1056		dbt.mustFailLength(ct[1])
1057		if canNull := dbt.mustNullable(ct[1]); canNull {
1058			dbt.Errorf("failed to get nullable. %#v", ct[1])
1059		}
1060		if v3 != "t3" {
1061			dbt.Errorf("failed to scan. %#v", v3)
1062		}
1063		dbt.mustFailDecimalSize(ct[2])
1064		if cLen = dbt.mustLength(ct[2]); cLen != 2 {
1065			dbt.Errorf("failed to get length. %#v", ct[2])
1066		}
1067		if canNull := dbt.mustNullable(ct[2]); canNull {
1068			dbt.Errorf("failed to get nullable. %#v", ct[2])
1069		}
1070		if v4 != 4.2 {
1071			dbt.Errorf("failed to scan. %#v", v4)
1072		}
1073		dbt.mustFailDecimalSize(ct[3])
1074		dbt.mustFailLength(ct[3])
1075		if canNull := dbt.mustNullable(ct[3]); canNull {
1076			dbt.Errorf("failed to get nullable. %#v", ct[3])
1077		}
1078		if !bytes.Equal(v5, []byte{0xab, 0xcd}) {
1079			dbt.Errorf("failed to scan. %#v", v5)
1080		}
1081		dbt.mustFailDecimalSize(ct[4])
1082		if cLen = dbt.mustLength(ct[4]); cLen != 8388608 { // BINARY
1083			dbt.Errorf("failed to get length. %#v", ct[4])
1084		}
1085		if canNull := dbt.mustNullable(ct[4]); canNull {
1086			dbt.Errorf("failed to get nullable. %#v", ct[4])
1087		}
1088		if !v6 {
1089			dbt.Errorf("failed to scan. %#v", v6)
1090		}
1091		dbt.mustFailDecimalSize(ct[5])
1092		dbt.mustFailLength(ct[5])
1093		/*canNull = dbt.mustNullable(ct[5])
1094		if canNull {
1095			dbt.Errorf("failed to get nullable. %#v", ct[5])
1096		}*/
1097	})
1098}
1099
1100func TestString(t *testing.T) {
1101	testString(t, false)
1102}
1103
1104func testString(t *testing.T, json bool) {
1105	runTests(t, dsn, func(dbt *DBTest) {
1106		if json {
1107			dbt.mustExec(forceJSON)
1108		}
1109		types := []string{"CHAR(255)", "VARCHAR(255)", "TEXT", "STRING"}
1110		in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах  น่าฟังเอย"
1111		var out string
1112		var rows *RowsExtended
1113
1114		for _, v := range types {
1115			dbt.mustExec("CREATE TABLE test (value " + v + ")")
1116			dbt.mustExec("INSERT INTO test VALUES (?)", in)
1117
1118			rows = dbt.mustQuery("SELECT value FROM test")
1119			defer rows.Close()
1120			if rows.Next() {
1121				rows.Scan(&out)
1122				if in != out {
1123					dbt.Errorf("%s: %s != %s", v, in, out)
1124				}
1125			} else {
1126				dbt.Errorf("%s: no data", v)
1127			}
1128			dbt.mustExec("DROP TABLE IF EXISTS test")
1129		}
1130
1131		// BLOB (Snowflake doesn't support BLOB type but STRING covers large text data)
1132		dbt.mustExec("CREATE TABLE test (id int, value STRING)")
1133
1134		id := 2
1135		in = `Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam
1136			nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam
1137			erat, sed diam voluptua. At vero eos et accusam et justo duo
1138			dolores et ea rebum. Stet clita kasd gubergren, no sea takimata
1139			sanctus est Lorem ipsum dolor sit amet. Lorem ipsum dolor sit amet,
1140			consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt
1141			ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero
1142			eos et accusam et justo duo dolores et ea rebum. Stet clita kasd
1143			gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet.`
1144		dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in)
1145
1146		if err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out); err != nil {
1147			dbt.Fatalf("Error on BLOB-Query: %s", err.Error())
1148		} else if out != in {
1149			dbt.Errorf("BLOB: %s != %s", in, out)
1150		}
1151	})
1152}
1153
// tcDateTimeTimestamp groups the timeTest cases that are executed against one
// database type.
type tcDateTimeTimestamp struct {
	// dbtype is the type name the source string is cast to in the query.
	dbtype  string
	// tlayout is the time layout used to render time values for this type.
	tlayout string
	// tests are the individual cases run for dbtype.
	tests   []timeTest
}
1159
// timeTest is a single date/time fetch case: a source literal and the value
// expected back from the database.
type timeTest struct {
	s string    // source date time string
	t time.Time // expected fetched data
}

// genQuery returns the fmt template of the cast query. The first verb takes
// the source string and the second the target type name.
func (tt timeTest) genQuery() string {
	const castTemplate = "SELECT '%s'::%s"
	return castTemplate
}
1168
1169func (tt timeTest) run(t *testing.T, dbt *DBTest, dbtype, tlayout string) {
1170	var rows *RowsExtended
1171	query := fmt.Sprintf(tt.genQuery(), tt.s, dbtype)
1172	rows = dbt.mustQuery(query)
1173	defer rows.Close()
1174	var err error
1175	if !rows.Next() {
1176		err = rows.Err()
1177		if err == nil {
1178			err = fmt.Errorf("no data")
1179		}
1180		dbt.Errorf("%s: %s", dbtype, err)
1181		return
1182	}
1183
1184	var dst interface{}
1185	if err = rows.Scan(&dst); err != nil {
1186		dbt.Errorf("%s: %s", dbtype, err)
1187		return
1188	}
1189	switch val := dst.(type) {
1190	case []uint8:
1191		str := string(val)
1192		if str == tt.s {
1193			return
1194		}
1195		dbt.Errorf("%s to string: expected %q, got %q",
1196			dbtype,
1197			tt.s,
1198			str,
1199		)
1200	case time.Time:
1201		if val.UnixNano() == tt.t.UnixNano() {
1202			return
1203		}
1204		t.Logf("source:%v, expected: %v, got:%v", tt.s, tt.t, val)
1205		dbt.Errorf("%s to string: expected %q, got %q",
1206			dbtype,
1207			tt.s,
1208			val.Format(tlayout),
1209		)
1210	default:
1211		dbt.Errorf("%s: unhandled type %T (is '%v')",
1212			dbtype, val, val,
1213		)
1214	}
1215}
1216
1217func TestSimpleDateTimeTimestampFetch(t *testing.T) {
1218	testSimpleDateTimeTimestampFetch(t, false)
1219}
1220
1221func testSimpleDateTimeTimestampFetch(t *testing.T, json bool) {
1222	var scan = func(rows *RowsExtended, cd interface{}, ct interface{}, cts interface{}) {
1223		if err := rows.Scan(cd, ct, cts); err != nil {
1224			t.Fatal(err)
1225		}
1226	}
1227	var fetchTypes = []func(*RowsExtended){
1228		func(rows *RowsExtended) {
1229			var cd, ct, cts time.Time
1230			scan(rows, &cd, &ct, &cts)
1231		},
1232		func(rows *RowsExtended) {
1233			var cd, ct, cts time.Time
1234			scan(rows, &cd, &ct, &cts)
1235		},
1236	}
1237	runTests(t, dsn, func(dbt *DBTest) {
1238		if json {
1239			dbt.mustExec(forceJSON)
1240		}
1241		for _, f := range fetchTypes {
1242			rows := dbt.mustQuery("SELECT CURRENT_DATE(), CURRENT_TIME(), CURRENT_TIMESTAMP()")
1243			defer rows.Close()
1244			if rows.Next() {
1245				f(rows)
1246			} else {
1247				t.Fatal("no results")
1248			}
1249		}
1250	})
1251}
1252
1253func TestDateTime(t *testing.T) {
1254	testDateTime(t, false)
1255}
1256
1257func testDateTime(t *testing.T, json bool) {
1258	afterTime := func(t time.Time, d string) time.Time {
1259		dur, err := time.ParseDuration(d)
1260		if err != nil {
1261			panic(err)
1262		}
1263		return t.Add(dur)
1264	}
1265	t0 := time.Time{}
1266	tstr0 := "0000-00-00 00:00:00.000000000"
1267	testcases := []tcDateTimeTimestamp{
1268		{"DATE", format[:10], []timeTest{
1269			{t: time.Date(2011, 11, 20, 0, 0, 0, 0, time.UTC)},
1270			{t: time.Date(2, 8, 2, 0, 0, 0, 0, time.UTC), s: "0002-08-02"},
1271		}},
1272		{"TIME", format[11:19], []timeTest{
1273			{t: afterTime(t0, "12345s")},
1274			{t: t0, s: tstr0[11:19]},
1275		}},
1276		{"TIME(0)", format[11:19], []timeTest{
1277			{t: afterTime(t0, "12345s")},
1278			{t: t0, s: tstr0[11:19]},
1279		}},
1280		{"TIME(1)", format[11:21], []timeTest{
1281			{t: afterTime(t0, "12345600ms")},
1282			{t: t0, s: tstr0[11:21]},
1283		}},
1284		{"TIME(6)", format[11:], []timeTest{
1285			{t: t0, s: tstr0[11:]},
1286		}},
1287		{"DATETIME", format[:19], []timeTest{
1288			{t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)},
1289		}},
1290		{"DATETIME(0)", format[:21], []timeTest{
1291			{t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)},
1292		}},
1293		{"DATETIME(1)", format[:21], []timeTest{
1294			{t: time.Date(2011, 11, 20, 21, 27, 37, 100000000, time.UTC)},
1295		}},
1296		{"DATETIME(6)", format, []timeTest{
1297			{t: time.Date(2011, 11, 20, 21, 27, 37, 123456000, time.UTC)},
1298		}},
1299		{"DATETIME(9)", format, []timeTest{
1300			{t: time.Date(2011, 11, 20, 21, 27, 37, 123456789, time.UTC)},
1301		}},
1302	}
1303	runTests(t, dsn, func(dbt *DBTest) {
1304		if json {
1305			dbt.mustExec(forceJSON)
1306		}
1307		for _, setups := range testcases {
1308			for _, setup := range setups.tests {
1309				if setup.s == "" {
1310					// fill time string wherever Go can reliable produce it
1311					setup.s = setup.t.Format(setups.tlayout)
1312				}
1313				setup.run(t, dbt, setups.dbtype, setups.tlayout)
1314			}
1315		}
1316	})
1317}
1318
1319func TestTimestampLTZ(t *testing.T) {
1320	testTimestampLTZ(t, false)
1321}
1322
1323func testTimestampLTZ(t *testing.T, json bool) {
1324	// Set session time zone in Los Angeles, same as machine
1325	createDSN(PSTLocation)
1326	location, err := time.LoadLocation(PSTLocation)
1327	if err != nil {
1328		t.Error(err)
1329	}
1330	testcases := []tcDateTimeTimestamp{
1331		{
1332			dbtype:  "TIMESTAMP_LTZ(9)",
1333			tlayout: format,
1334			tests: []timeTest{
1335				{
1336					s: "2016-12-30 05:02:03",
1337					t: time.Date(2016, 12, 30, 5, 2, 3, 0, location),
1338				},
1339				{
1340					s: "2016-12-30 05:02:03 -00:00",
1341					t: time.Date(2016, 12, 30, 5, 2, 3, 0, time.UTC),
1342				},
1343				{
1344					s: "2017-05-12 00:51:42",
1345					t: time.Date(2017, 5, 12, 0, 51, 42, 0, location),
1346				},
1347				{
1348					s: "2017-03-12 01:00:00",
1349					t: time.Date(2017, 3, 12, 1, 0, 0, 0, location),
1350				},
1351				{
1352					s: "2017-03-13 04:00:00",
1353					t: time.Date(2017, 3, 13, 4, 0, 0, 0, location),
1354				},
1355				{
1356					s: "2017-03-13 04:00:00.123456789",
1357					t: time.Date(2017, 3, 13, 4, 0, 0, 123456789, location),
1358				},
1359			},
1360		},
1361		{
1362			dbtype:  "TIMESTAMP_LTZ(8)",
1363			tlayout: format,
1364			tests: []timeTest{
1365				{
1366					s: "2017-03-13 04:00:00.123456789",
1367					t: time.Date(2017, 3, 13, 4, 0, 0, 123456780, location),
1368				},
1369			},
1370		},
1371	}
1372	runTests(t, dsn, func(dbt *DBTest) {
1373		if json {
1374			dbt.mustExec(forceJSON)
1375		}
1376		for _, setups := range testcases {
1377			for _, setup := range setups.tests {
1378				if setup.s == "" {
1379					// fill time string wherever Go can reliable produce it
1380					setup.s = setup.t.Format(setups.tlayout)
1381				}
1382				setup.run(t, dbt, setups.dbtype, setups.tlayout)
1383			}
1384		}
1385	})
1386	// Revert timezone to UTC, which is default for the test suit
1387	createDSN("UTC")
1388}
1389
1390func TestTimestampTZ(t *testing.T) {
1391	testTimestampTZ(t, false)
1392}
1393
1394func testTimestampTZ(t *testing.T, json bool) {
1395	sflo := func(offsets string) (loc *time.Location) {
1396		r, err := LocationWithOffsetString(offsets)
1397		if err != nil {
1398			return time.UTC
1399		}
1400		return r
1401	}
1402	testcases := []tcDateTimeTimestamp{
1403		{
1404			dbtype:  "TIMESTAMP_TZ(9)",
1405			tlayout: format,
1406			tests: []timeTest{
1407				{
1408					s: "2016-12-30 05:02:03 +07:00",
1409					t: time.Date(2016, 12, 30, 5, 2, 3, 0,
1410						sflo("+0700")),
1411				},
1412				{
1413					s: "2017-05-23 03:56:41 -09:00",
1414					t: time.Date(2017, 5, 23, 3, 56, 41, 0,
1415						sflo("-0900")),
1416				},
1417			},
1418		},
1419	}
1420	runTests(t, dsn, func(dbt *DBTest) {
1421		if json {
1422			dbt.mustExec(forceJSON)
1423		}
1424		for _, setups := range testcases {
1425			for _, setup := range setups.tests {
1426				if setup.s == "" {
1427					// fill time string wherever Go can reliable produce it
1428					setup.s = setup.t.Format(setups.tlayout)
1429				}
1430				setup.run(t, dbt, setups.dbtype, setups.tlayout)
1431			}
1432		}
1433	})
1434}
1435
1436func TestNULL(t *testing.T) {
1437	testNULL(t, false)
1438}
1439
1440func testNULL(t *testing.T, json bool) {
1441	runTests(t, dsn, func(dbt *DBTest) {
1442		if json {
1443			dbt.mustExec(forceJSON)
1444		}
1445		nullStmt, err := dbt.db.Prepare("SELECT NULL")
1446		if err != nil {
1447			dbt.Fatal(err)
1448		}
1449		defer nullStmt.Close()
1450
1451		nonNullStmt, err := dbt.db.Prepare("SELECT 1")
1452		if err != nil {
1453			dbt.Fatal(err)
1454		}
1455		defer nonNullStmt.Close()
1456
1457		// NullBool
1458		var nb sql.NullBool
1459		// Invalid
1460		if err = nullStmt.QueryRow().Scan(&nb); err != nil {
1461			dbt.Fatal(err)
1462		}
1463		if nb.Valid {
1464			dbt.Error("valid NullBool which should be invalid")
1465		}
1466		// Valid
1467		if err = nonNullStmt.QueryRow().Scan(&nb); err != nil {
1468			dbt.Fatal(err)
1469		}
1470		if !nb.Valid {
1471			dbt.Error("invalid NullBool which should be valid")
1472		} else if !nb.Bool {
1473			dbt.Errorf("Unexpected NullBool value: %t (should be true)", nb.Bool)
1474		}
1475
1476		// NullFloat64
1477		var nf sql.NullFloat64
1478		// Invalid
1479		if err = nullStmt.QueryRow().Scan(&nf); err != nil {
1480			dbt.Fatal(err)
1481		}
1482		if nf.Valid {
1483			dbt.Error("valid NullFloat64 which should be invalid")
1484		}
1485		// Valid
1486		if err = nonNullStmt.QueryRow().Scan(&nf); err != nil {
1487			dbt.Fatal(err)
1488		}
1489		if !nf.Valid {
1490			dbt.Error("invalid NullFloat64 which should be valid")
1491		} else if nf.Float64 != float64(1) {
1492			dbt.Errorf("unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64)
1493		}
1494
1495		// NullInt64
1496		var ni sql.NullInt64
1497		// Invalid
1498		if err = nullStmt.QueryRow().Scan(&ni); err != nil {
1499			dbt.Fatal(err)
1500		}
1501		if ni.Valid {
1502			dbt.Error("valid NullInt64 which should be invalid")
1503		}
1504		// Valid
1505		if err = nonNullStmt.QueryRow().Scan(&ni); err != nil {
1506			dbt.Fatal(err)
1507		}
1508		if !ni.Valid {
1509			dbt.Error("invalid NullInt64 which should be valid")
1510		} else if ni.Int64 != int64(1) {
1511			dbt.Errorf("unexpected NullInt64 value: %d (should be 1)", ni.Int64)
1512		}
1513
1514		// NullString
1515		var ns sql.NullString
1516		// Invalid
1517		if err = nullStmt.QueryRow().Scan(&ns); err != nil {
1518			dbt.Fatal(err)
1519		}
1520		if ns.Valid {
1521			dbt.Error("valid NullString which should be invalid")
1522		}
1523		// Valid
1524		if err = nonNullStmt.QueryRow().Scan(&ns); err != nil {
1525			dbt.Fatal(err)
1526		}
1527		if !ns.Valid {
1528			dbt.Error("invalid NullString which should be valid")
1529		} else if ns.String != `1` {
1530			dbt.Error("unexpected NullString value:" + ns.String + " (should be `1`)")
1531		}
1532
1533		// nil-bytes
1534		var b []byte
1535		// Read nil
1536		if err = nullStmt.QueryRow().Scan(&b); err != nil {
1537			dbt.Fatal(err)
1538		}
1539		if b != nil {
1540			dbt.Error("non-nil []byte which should be nil")
1541		}
1542		// Read non-nil
1543		if err = nonNullStmt.QueryRow().Scan(&b); err != nil {
1544			dbt.Fatal(err)
1545		}
1546		if b == nil {
1547			dbt.Error("nil []byte which should be non-nil")
1548		}
1549		// Insert nil
1550		b = nil
1551		success := false
1552		if err = dbt.db.QueryRow("SELECT ? IS NULL", b).Scan(&success); err != nil {
1553			dbt.Fatal(err)
1554		}
1555		if !success {
1556			dbt.Error("inserting []byte(nil) as NULL failed")
1557			t.Fatal("stopping")
1558		}
1559		// Check input==output with input==nil
1560		b = nil
1561		if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil {
1562			dbt.Fatal(err)
1563		}
1564		if b != nil {
1565			dbt.Error("non-nil echo from nil input")
1566		}
1567		// Check input==output with input!=nil
1568		b = []byte("")
1569		if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil {
1570			dbt.Fatal(err)
1571		}
1572		if b == nil {
1573			dbt.Error("nil echo from non-nil input")
1574		}
1575
1576		// Insert NULL
1577		dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)")
1578		dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2)
1579
1580		var out interface{}
1581		rows := dbt.mustQuery("SELECT * FROM test")
1582		defer rows.Close()
1583		if rows.Next() {
1584			rows.Scan(&out)
1585			if out != nil {
1586				dbt.Errorf("%v != nil", out)
1587			}
1588		} else {
1589			dbt.Error("no data")
1590		}
1591	})
1592}
1593
1594func TestVariant(t *testing.T) {
1595	testVariant(t, false)
1596}
1597
1598func testVariant(t *testing.T, json bool) {
1599	runTests(t, dsn, func(dbt *DBTest) {
1600		if json {
1601			dbt.mustExec(forceJSON)
1602		}
1603		rows := dbt.mustQuery(`select parse_json('[{"id":1, "name":"test1"},{"id":2, "name":"test2"}]')`)
1604		defer rows.Close()
1605		var v string
1606		if rows.Next() {
1607			if err := rows.Scan(&v); err != nil {
1608				t.Fatal(err)
1609			}
1610		} else {
1611			t.Fatal("no rows")
1612		}
1613	})
1614}
1615
1616func TestArray(t *testing.T) {
1617	testArray(t, false)
1618}
1619
1620func testArray(t *testing.T, json bool) {
1621	runTests(t, dsn, func(dbt *DBTest) {
1622		if json {
1623			dbt.mustExec(forceJSON)
1624		}
1625		rows := dbt.mustQuery(`select as_array(parse_json('[{"id":1, "name":"test1"},{"id":2, "name":"test2"}]'))`)
1626		defer rows.Close()
1627		var v string
1628		if rows.Next() {
1629			if err := rows.Scan(&v); err != nil {
1630				t.Fatal(err)
1631			}
1632		} else {
1633			t.Fatal("no rows")
1634		}
1635	})
1636}
1637
1638func TestLargeSetResult(t *testing.T) {
1639	CustomJSONDecoderEnabled = false
1640	testLargeSetResult(t, 100000, false)
1641}
1642
1643func testLargeSetResult(t *testing.T, numrows int, json bool) {
1644	runTests(t, dsn, func(dbt *DBTest) {
1645		if json {
1646			dbt.mustExec(forceJSON)
1647		}
1648		rows := dbt.mustQuery(fmt.Sprintf(selectRandomGenerator, numrows))
1649		defer rows.Close()
1650		cnt := 0
1651		var idx int
1652		var v string
1653		for rows.Next() {
1654			if err := rows.Scan(&idx, &v); err != nil {
1655				t.Fatal(err)
1656			}
1657			cnt++
1658		}
1659		logger.Infof("NextResultSet: %v", rows.NextResultSet())
1660
1661		if cnt != numrows {
1662			dbt.Errorf("number of rows didn't match. expected: %v, got: %v", numrows, cnt)
1663		}
1664	})
1665}
1666
1667func TestPingpongQuery(t *testing.T) {
1668	runTests(t, dsn, func(dbt *DBTest) {
1669		numrows := 1
1670		rows := dbt.mustQuery("SELECT DISTINCT 1 FROM TABLE(GENERATOR(TIMELIMIT=> 60))")
1671		defer rows.Close()
1672		cnt := 0
1673		for rows.Next() {
1674			cnt++
1675		}
1676		if cnt != numrows {
1677			dbt.Errorf("number of rows didn't match. expected: %v, got: %v", numrows, cnt)
1678		}
1679	})
1680}
1681
1682func TestDML(t *testing.T) {
1683	runTests(t, dsn, func(dbt *DBTest) {
1684		dbt.mustExec("CREATE OR REPLACE TABLE test(c1 int, c2 string)")
1685		if err := insertData(dbt, false); err != nil {
1686			dbt.Fatalf("failed to insert data: %v", err)
1687		}
1688		results, err := queryTest(dbt)
1689		if err != nil {
1690			dbt.Fatalf("failed to query test table: %v", err)
1691		}
1692		if len(*results) != 0 {
1693			dbt.Fatalf("number of returned data didn't match. expected 0, got: %v", len(*results))
1694		}
1695		if err = insertData(dbt, true); err != nil {
1696			dbt.Fatalf("failed to insert data: %v", err)
1697		}
1698		results, err = queryTest(dbt)
1699		if err != nil {
1700			dbt.Fatalf("failed to query test table: %v", err)
1701		}
1702		if len(*results) != 2 {
1703			dbt.Fatalf("number of returned data didn't match. expected 2, got: %v", len(*results))
1704		}
1705	})
1706}
1707
1708func insertData(dbt *DBTest, commit bool) error {
1709	tx, err := dbt.db.Begin()
1710	if err != nil {
1711		dbt.Fatalf("failed to begin transaction: %v", err)
1712	}
1713	res, err := tx.Exec("INSERT INTO test VALUES(1, 'test1'), (2, 'test2')")
1714	if err != nil {
1715		dbt.Fatalf("failed to insert value into test: %v", err)
1716	}
1717	n, err := res.RowsAffected()
1718	if err != nil {
1719		dbt.Fatalf("failed to rows affected: %v", err)
1720	}
1721	if n != 2 {
1722		dbt.Fatalf("failed to insert value into test. expected: 2, got: %v", n)
1723	}
1724	results, err := queryTestTx(tx)
1725	if err != nil {
1726		dbt.Fatalf("failed to query test table: %v", err)
1727	}
1728	if len(*results) != 2 {
1729		dbt.Fatalf("number of returned data didn't match. expected 2, got: %v", len(*results))
1730	}
1731	if commit {
1732		if err = tx.Commit(); err != nil {
1733			return err
1734		}
1735	} else {
1736		if err = tx.Rollback(); err != nil {
1737			return err
1738		}
1739	}
1740	return err
1741}
1742
// queryTestTx reads every row of the test table through tx and returns the
// rows as a map from c1 to c2.
//
// An error from the query, from scanning a row, or from iterating the result
// set is returned with a nil map.
func queryTestTx(tx *sql.Tx) (*map[int]string, error) {
	var c1 int
	var c2 string
	rows, err := tx.Query("SELECT c1, c2 FROM test")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	results := make(map[int]string, 2)
	for rows.Next() {
		if err = rows.Scan(&c1, &c2); err != nil {
			return nil, err
		}
		results[c1] = c2
	}
	// Next returns false both at the end of the data and on failure; Err
	// tells the two apart so a truncated read is not reported as success.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return &results, nil
}
1761
1762func queryTest(dbt *DBTest) (*map[int]string, error) {
1763	var c1 int
1764	var c2 string
1765	rows, err := dbt.db.Query("SELECT c1, c2 FROM test")
1766	if err != nil {
1767		return nil, err
1768	}
1769	defer rows.Close()
1770	results := make(map[int]string, 2)
1771	for rows.Next() {
1772		if err = rows.Scan(&c1, &c2); err != nil {
1773			return nil, err
1774		}
1775		results[c1] = c2
1776	}
1777	return &results, nil
1778}
1779
1780// Special cases where rows are already closed
1781func TestRowsClose(t *testing.T) {
1782	runTests(t, dsn, func(dbt *DBTest) {
1783		rows, err := dbt.db.Query("SELECT 1")
1784		if err != nil {
1785			dbt.Fatal(err)
1786		}
1787		if err = rows.Close(); err != nil {
1788			dbt.Fatal(err)
1789		}
1790
1791		if rows.Next() {
1792			dbt.Fatal("unexpected row after rows.Close()")
1793		}
1794		if err = rows.Err(); err != nil {
1795			dbt.Fatal(err)
1796		}
1797	})
1798}
1799
1800func TestResultNoRows(t *testing.T) {
1801	// DDL
1802	runTests(t, dsn, func(dbt *DBTest) {
1803		row, err := dbt.db.Exec("CREATE OR REPLACE TABLE test(c1 int)")
1804		if err != nil {
1805			t.Fatalf("failed to execute DDL. err: %v", err)
1806		}
1807		if _, err = row.RowsAffected(); err == nil {
1808			t.Fatal("should have failed to get RowsAffected")
1809		}
1810		if _, err = row.LastInsertId(); err == nil {
1811			t.Fatal("should have failed to get LastInsertID")
1812		}
1813	})
1814}
1815
1816func TestCancelQuery(t *testing.T) {
1817	runTests(t, dsn, func(dbt *DBTest) {
1818		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
1819		defer cancel()
1820
1821		_, err := dbt.db.QueryContext(ctx, "SELECT DISTINCT 1 FROM TABLE(GENERATOR(TIMELIMIT=> 100))")
1822		if err == nil {
1823			dbt.Fatal("No timeout error returned")
1824		}
1825		if err.Error() != "context deadline exceeded" {
1826			dbt.Fatalf("Timeout error mismatch: expect %v, receive %v", context.DeadlineExceeded, err.Error())
1827		}
1828	})
1829}
1830
1831func TestInvalidConnection(t *testing.T) {
1832	db := openDB(t)
1833	if err := db.Close(); err != nil {
1834		t.Error("should not cause error in Close")
1835	}
1836	if err := db.Close(); err != nil {
1837		t.Error("should not cause error in the second call of Close")
1838	}
1839	if _, err := db.Exec("CREATE TABLE OR REPLACE test0(c1 int)"); err == nil {
1840		t.Error("should fail to run Exec")
1841	}
1842	if _, err := db.Query("SELECT CURRENT_TIMESTAMP()"); err == nil {
1843		t.Error("should fail to run Query")
1844	}
1845	if _, err := db.Begin(); err == nil {
1846		t.Error("should fail to run Begin")
1847	}
1848}
1849
1850func TestPing(t *testing.T) {
1851	db := openDB(t)
1852	if err := db.Ping(); err != nil {
1853		t.Fatalf("failed to ping. %v, err: %v", dsn, err)
1854	}
1855	if err := db.PingContext(context.Background()); err != nil {
1856		t.Fatalf("failed to ping with context. %v, err: %v", dsn, err)
1857	}
1858	if err := db.Close(); err != nil {
1859		t.Fatalf("failed to close db. %v, err: %v", dsn, err)
1860	}
1861	if err := db.Ping(); err == nil {
1862		t.Fatal("should have failed to ping")
1863	}
1864	if err := db.PingContext(context.Background()); err == nil {
1865		t.Fatal("should have failed to ping with context")
1866	}
1867}
1868
1869func TestDoubleDollar(t *testing.T) {
1870	// no escape is required for dollar signs
1871	runTests(t, dsn, func(dbt *DBTest) {
1872		sql := `create or replace function dateErr(I double) returns date
1873language javascript strict
1874as $$
1875  var x = [
1876    0, "1400000000000",
1877    "2013-04-05",
1878    [], [1400000000000],
1879    "x1234",
1880    Number.NaN, null, undefined,
1881    {},
1882    [1400000000000,1500000000000]
1883  ];
1884  return x[I];
1885$$
1886;`
1887		dbt.mustExec(sql)
1888	})
1889}
1890
1891func TestTransactionOptions(t *testing.T) {
1892	var tx *sql.Tx
1893	var err error
1894
1895	db := openDB(t)
1896	defer db.Close()
1897
1898	tx, err = db.BeginTx(context.Background(), &sql.TxOptions{})
1899	if err != nil {
1900		t.Fatal("failed to start transaction.")
1901	}
1902	if err = tx.Rollback(); err != nil {
1903		t.Fatal("failed to rollback")
1904	}
1905	if _, err = db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}); err == nil {
1906		t.Fatal("should have failed.")
1907	}
1908	if driverErr, ok := err.(*SnowflakeError); !ok || driverErr.Number != ErrNoReadOnlyTransaction {
1909		t.Fatalf("should have returned Snowflake Error: %v", errMsgNoReadOnlyTransaction)
1910	}
1911	if _, err = db.BeginTx(context.Background(), &sql.TxOptions{Isolation: 100}); err == nil {
1912		t.Fatal("should have failed.")
1913	}
1914	if driverErr, ok := err.(*SnowflakeError); !ok || driverErr.Number != ErrNoDefaultTransactionIsolationLevel {
1915		t.Fatalf("should have returned Snowflake Error: %v", errMsgNoDefaultTransactionIsolationLevel)
1916	}
1917}
1918
1919func TestTimezoneSessionParameter(t *testing.T) {
1920	createDSN(PSTLocation)
1921	db := openDB(t)
1922	defer db.Close()
1923
1924	rows, err := db.Query("SHOW PARAMETERS LIKE 'TIMEZONE'")
1925	if err != nil {
1926		t.Errorf("failed to run show parameters. err: %v", err)
1927	}
1928	defer rows.Close()
1929	if !rows.Next() {
1930		t.Fatal("failed to get timezone.")
1931	}
1932
1933	p, err := ScanSnowflakeParameter(rows)
1934	if err != nil {
1935		t.Errorf("failed to run get timezone value. err: %v", err)
1936	}
1937	if p.Value != PSTLocation {
1938		t.Errorf("failed to get an expected timezone. got: %v", p.Value)
1939	}
1940	createDSN("UTC")
1941}
1942
1943func TestLargeSetResultCancel(t *testing.T) {
1944	runTests(t, dsn, func(dbt *DBTest) {
1945		c := make(chan error)
1946		ctx, cancel := context.WithCancel(context.Background())
1947		go func() {
1948			// attempt to run a 100 seconds query, but it should be canceled in 1 second
1949			timelimit := 100
1950			rows, err := dbt.db.QueryContext(
1951				ctx,
1952				fmt.Sprintf("SELECT COUNT(*) FROM TABLE(GENERATOR(timelimit=>%v))", timelimit))
1953			if err != nil {
1954				c <- err
1955				return
1956			}
1957			defer rows.Close()
1958			c <- nil
1959		}()
1960		// cancel after 1 second
1961		time.Sleep(time.Second)
1962		cancel()
1963		ret := <-c
1964		if ret.Error() != "context canceled" {
1965			t.Fatalf("failed to cancel. err: %v", ret)
1966		}
1967		close(c)
1968	})
1969}
1970
1971func TestValidateDatabaseParameter(t *testing.T) {
1972	baseDSN := fmt.Sprintf("%s:%s@%s", user, pass, host)
1973	testcases := []struct {
1974		dsn       string
1975		params    map[string]string
1976		errorCode int
1977	}{
1978		{
1979			dsn:       baseDSN + fmt.Sprintf("/%s/%s", "NOT_EXISTS", "NOT_EXISTS"),
1980			errorCode: ErrObjectNotExistOrAuthorized,
1981		},
1982		{
1983			dsn:       baseDSN + fmt.Sprintf("/%s/%s", dbname, "NOT_EXISTS"),
1984			errorCode: ErrObjectNotExistOrAuthorized,
1985		},
1986		{
1987			dsn: baseDSN + fmt.Sprintf("/%s/%s", dbname, schemaname),
1988			params: map[string]string{
1989				"warehouse": "NOT_EXIST",
1990			},
1991			errorCode: ErrObjectNotExistOrAuthorized,
1992		},
1993		{
1994			dsn: baseDSN + fmt.Sprintf("/%s/%s", dbname, schemaname),
1995			params: map[string]string{
1996				"role": "NOT_EXIST",
1997			},
1998			errorCode: ErrRoleNotExist,
1999		},
2000	}
2001	for idx, tc := range testcases {
2002		newDSN := tc.dsn
2003		parameters := url.Values{}
2004		if protocol != "" {
2005			parameters.Add("protocol", protocol)
2006		}
2007		if account != "" {
2008			parameters.Add("account", account)
2009		}
2010		for k, v := range tc.params {
2011			parameters.Add(k, v)
2012		}
2013		newDSN += "?" + parameters.Encode()
2014		db, err := sql.Open("snowflake", newDSN)
2015		// actual connection won't happen until run a query
2016		if err != nil {
2017			t.Fatalf("error creating a connection object: %s", err.Error())
2018		}
2019		defer db.Close()
2020		if _, err = db.Exec("SELECT 1"); err == nil {
2021			t.Fatal("should cause an error.")
2022		}
2023		if driverErr, ok := err.(*SnowflakeError); ok {
2024			if driverErr.Number != tc.errorCode { // not exist error
2025				t.Errorf("got unexpected error: %v in %v", err, idx)
2026			}
2027		}
2028	}
2029}
2030
2031func TestSpecifyWarehouseDatabase(t *testing.T) {
2032	dsn := fmt.Sprintf("%s:%s@%s/%s", user, pass, host, dbname)
2033	parameters := url.Values{}
2034	parameters.Add("account", account)
2035	parameters.Add("warehouse", warehouse)
2036	// parameters.Add("role", "nopublic") TODO: create nopublic role for test
2037	if protocol != "" {
2038		parameters.Add("protocol", protocol)
2039	}
2040	db, err := sql.Open("snowflake", dsn+"?"+parameters.Encode())
2041	if err != nil {
2042		t.Fatalf("error creating a connection object: %s", err.Error())
2043	}
2044	defer db.Close()
2045	if _, err = db.Exec("SELECT 1"); err != nil {
2046		t.Fatalf("failed to execute a select 1: %v", err)
2047	}
2048}
2049
2050func TestFetchNil(t *testing.T) {
2051	runTests(t, dsn, func(dbt *DBTest) {
2052		rows := dbt.mustQuery("SELECT * FROM values(3,4),(null, 5) order by 2")
2053		defer rows.Close()
2054		var c1 sql.NullInt64
2055		var c2 sql.NullInt64
2056
2057		var results []sql.NullInt64
2058		for rows.Next() {
2059			if err := rows.Scan(&c1, &c2); err != nil {
2060				dbt.Fatal(err)
2061			}
2062			results = append(results, c1)
2063		}
2064		if results[1].Valid {
2065			t.Errorf("First element of second row must be nil (NULL). %v", results)
2066		}
2067	})
2068}
2069
2070func TestPingInvalidHost(t *testing.T) {
2071	config := Config{
2072		Account:      "NOT_EXISTS",
2073		User:         "BOGUS_USER",
2074		Password:     "barbar",
2075		LoginTimeout: 10 * time.Second,
2076	}
2077
2078	testURL, err := DSN(&config)
2079	if err != nil {
2080		t.Fatalf("failed to parse config. config: %v, err: %v", config, err)
2081	}
2082
2083	db, err := sql.Open("snowflake", testURL)
2084	if err != nil {
2085		t.Fatalf("failed to initalize the connetion. err: %v", err)
2086	}
2087	ctx := context.Background()
2088	if err = db.PingContext(ctx); err == nil {
2089		t.Fatal("should cause an error")
2090	}
2091	if driverErr, ok := err.(*SnowflakeError); !ok || ok && driverErr.Number != ErrCodeFailedToConnect {
2092		// Failed to connect error
2093		t.Fatalf("error didn't match")
2094	}
2095}
2096
2097func TestOpenWithConfig(t *testing.T) {
2098	config, err := ParseDSN(dsn)
2099	if err != nil {
2100		t.Fatalf("failed to parse dsn. dsn: %v, err: %v", dsn, err)
2101	}
2102	driver := SnowflakeDriver{}
2103	db, err := driver.OpenWithConfig(context.Background(), *config)
2104	if err != nil {
2105		t.Fatalf("failed to open with config. config: %v, err: %v", config, err)
2106	}
2107	db.Close()
2108}
2109
// CountingTransport is an http.RoundTripper used by the tests to observe
// whether requests are sent through a custom transport.
type CountingTransport struct {
	// requests is the number of RoundTrip calls seen so far. It is a plain
	// int updated without synchronization.
	requests int
}
2113
2114func (t *CountingTransport) RoundTrip(r *http.Request) (*http.Response, error) {
2115	t.requests++
2116	return snowflakeInsecureTransport.RoundTrip(r)
2117}
2118
2119func TestOpenWithTransport(t *testing.T) {
2120	config, err := ParseDSN(dsn)
2121	if err != nil {
2122		t.Fatalf("failed to parse dsn. dsn: %v, err: %v", dsn, err)
2123	}
2124	countingTransport := CountingTransport{}
2125	var transport http.RoundTripper = &countingTransport
2126	config.Transporter = transport
2127	driver := SnowflakeDriver{}
2128	db, err := driver.OpenWithConfig(context.Background(), *config)
2129	if err != nil {
2130		t.Fatalf("failed to open with config. config: %v, err: %v", config, err)
2131	}
2132	conn := db.(*snowflakeConn)
2133	if conn.rest.Client.Transport != transport {
2134		t.Fatal("transport doesn't match")
2135	}
2136	db.Close()
2137	if countingTransport.requests == 0 {
2138		t.Fatal("transport did not receive any requests")
2139	}
2140
2141	// Test that transport override also works in insecure mode
2142	countingTransport.requests = 0
2143	config.InsecureMode = true
2144	db, err = driver.OpenWithConfig(context.Background(), *config)
2145	if err != nil {
2146		t.Fatalf("failed to open with config. config: %v, err: %v", config, err)
2147	}
2148	conn = db.(*snowflakeConn)
2149	if conn.rest.Client.Transport != transport {
2150		t.Fatal("transport doesn't match")
2151	}
2152	db.Close()
2153	if countingTransport.requests == 0 {
2154		t.Fatal("transport did not receive any requests")
2155	}
2156}
2157
2158func createDSNWithClientSessionKeepAlive() {
2159	dsn = fmt.Sprintf("%s:%s@%s/%s/%s", user, pass, host, dbname, schemaname)
2160
2161	parameters := url.Values{}
2162	parameters.Add("client_session_keep_alive", "true")
2163	if protocol != "" {
2164		parameters.Add("protocol", protocol)
2165	}
2166	if account != "" {
2167		parameters.Add("account", account)
2168	}
2169	if warehouse != "" {
2170		parameters.Add("warehouse", warehouse)
2171	}
2172	if rolename != "" {
2173		parameters.Add("role", rolename)
2174	}
2175	if len(parameters) > 0 {
2176		dsn += "?" + parameters.Encode()
2177	}
2178}
2179
2180func TestClientSessionKeepAliveParameter(t *testing.T) {
2181	// This test doesn't really validate the CLIENT_SESSION_KEEP_ALIVE functionality but simply checks
2182	// the session parameter.
2183	createDSNWithClientSessionKeepAlive()
2184	runTests(t, dsn, func(dbt *DBTest) {
2185		rows := dbt.mustQuery("SHOW PARAMETERS LIKE 'CLIENT_SESSION_KEEP_ALIVE'")
2186		if !rows.Next() {
2187			t.Fatal("failed to get timezone.")
2188		}
2189
2190		p, err := ScanSnowflakeParameter(rows.rows)
2191		if err != nil {
2192			t.Errorf("failed to run get client_session_keep_alive value. err: %v", err)
2193		}
2194		if p.Value != "true" {
2195			t.Fatalf("failed to get an expected client_session_keep_alive. got: %v", p.Value)
2196		}
2197
2198		rows = dbt.mustQuery("select count(*) from table(generator(timelimit=>30))")
2199		defer rows.Close()
2200	})
2201}
2202
2203func TestTimePrecision(t *testing.T) {
2204	runTests(t, dsn, func(dbt *DBTest) {
2205		dbt.mustExec("create or replace table z3 (t1 time(5))")
2206		rows := dbt.mustQuery("select * from z3")
2207		cols, _ := rows.ColumnTypes()
2208		if pres, _, _ := cols[0].DecimalSize(); pres != 5 {
2209			t.Fatalf("Wrong value returned. Got %v instead of 5.", pres)
2210		}
2211	})
2212}
2213