// Copyright (c) 2021 Snowflake Computing Inc. All rights reserved.
2
3package gosnowflake
4
5import (
6	"context"
7	"fmt"
8	"testing"
9)
10
11func TestAsyncMode(t *testing.T) {
12	ctx := WithAsyncMode(context.Background())
13	numrows := 100000
14	cnt := 0
15	var idx int
16	var v string
17
18	runTests(t, dsn, func(dbt *DBTest) {
19		rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectRandomGenerator, numrows))
20		defer rows.Close()
21
22		// Next() will block and wait until results are available
23		for rows.Next() {
24			if err := rows.Scan(&idx, &v); err != nil {
25				t.Fatal(err)
26			}
27			cnt++
28		}
29		logger.Infof("NextResultSet: %v", rows.NextResultSet())
30
31		if cnt != numrows {
32			t.Errorf("number of rows didn't match. expected: %v, got: %v", numrows, cnt)
33		}
34
35		dbt.mustExec("create or replace table test_async_exec (value boolean)")
36		res := dbt.mustExecContext(ctx, "insert into test_async_exec values (true)")
37		count, err := res.RowsAffected()
38		if err != nil {
39			t.Fatalf("res.RowsAffected() returned error: %v", err)
40		}
41		if count != 1 {
42			t.Fatalf("expected 1 affected row, got %d", count)
43		}
44	})
45}
46
47func TestAsyncQueryFail(t *testing.T) {
48	ctx := WithAsyncMode(context.Background())
49	runTests(t, dsn, func(dbt *DBTest) {
50		rows := dbt.mustQueryContext(ctx, "selectt 1")
51		defer rows.Close()
52
53		if rows.Next() {
54			t.Fatal("should have no rows available")
55		} else {
56			if err := rows.Err(); err == nil {
57				t.Fatal("should return a syntax error")
58			}
59		}
60	})
61}
62
63func TestMultipleAsyncQueries(t *testing.T) {
64	ctx := WithAsyncMode(context.Background())
65	s1 := "foo"
66	s2 := "bar"
67	ch1 := make(chan string)
68	ch2 := make(chan string)
69
70	runTests(t, dsn, func(dbt *DBTest) {
71		rows1 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s1, 30))
72		defer rows1.Close()
73		rows2 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s2, 10))
74		defer rows2.Close()
75
76		go retrieveRows(rows1, ch1)
77		go retrieveRows(rows2, ch2)
78		select {
79		case res := <-ch1:
80			t.Fatalf("value %v should not have been called earlier.", res)
81		case res := <-ch2:
82			if res != s2 {
83				t.Fatalf("query failed. expected: %v, got: %v", s2, res)
84			}
85		}
86	})
87}
88
89func retrieveRows(rows *RowsExtended, ch chan string) {
90	var s string
91	for rows.Next() {
92		if err := rows.Scan(&s); err != nil {
93			ch <- err.Error()
94			close(ch)
95			return
96		}
97	}
98	ch <- s
99	close(ch)
100}
101