1// Copyright (c) 2021 Snowflake Computing Inc. All right reserved. 2 3package gosnowflake 4 5import ( 6 "context" 7 "fmt" 8 "testing" 9) 10 11func TestAsyncMode(t *testing.T) { 12 ctx := WithAsyncMode(context.Background()) 13 numrows := 100000 14 cnt := 0 15 var idx int 16 var v string 17 18 runTests(t, dsn, func(dbt *DBTest) { 19 rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectRandomGenerator, numrows)) 20 defer rows.Close() 21 22 // Next() will block and wait until results are available 23 for rows.Next() { 24 if err := rows.Scan(&idx, &v); err != nil { 25 t.Fatal(err) 26 } 27 cnt++ 28 } 29 logger.Infof("NextResultSet: %v", rows.NextResultSet()) 30 31 if cnt != numrows { 32 t.Errorf("number of rows didn't match. expected: %v, got: %v", numrows, cnt) 33 } 34 35 dbt.mustExec("create or replace table test_async_exec (value boolean)") 36 res := dbt.mustExecContext(ctx, "insert into test_async_exec values (true)") 37 count, err := res.RowsAffected() 38 if err != nil { 39 t.Fatalf("res.RowsAffected() returned error: %v", err) 40 } 41 if count != 1 { 42 t.Fatalf("expected 1 affected row, got %d", count) 43 } 44 }) 45} 46 47func TestAsyncQueryFail(t *testing.T) { 48 ctx := WithAsyncMode(context.Background()) 49 runTests(t, dsn, func(dbt *DBTest) { 50 rows := dbt.mustQueryContext(ctx, "selectt 1") 51 defer rows.Close() 52 53 if rows.Next() { 54 t.Fatal("should have no rows available") 55 } else { 56 if err := rows.Err(); err == nil { 57 t.Fatal("should return a syntax error") 58 } 59 } 60 }) 61} 62 63func TestMultipleAsyncQueries(t *testing.T) { 64 ctx := WithAsyncMode(context.Background()) 65 s1 := "foo" 66 s2 := "bar" 67 ch1 := make(chan string) 68 ch2 := make(chan string) 69 70 runTests(t, dsn, func(dbt *DBTest) { 71 rows1 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s1, 30)) 72 defer rows1.Close() 73 rows2 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s2, 10)) 74 defer rows2.Close() 75 76 go retrieveRows(rows1, ch1) 77 go retrieveRows(rows2, ch2) 78 select { 79 case res := <-ch1: 80 t.Fatalf("value %v should not have been called earlier.", res) 81 case res := <-ch2: 82 if res != s2 { 83 t.Fatalf("query failed. expected: %v, got: %v", s2, res) 84 } 85 } 86 }) 87} 88 89func retrieveRows(rows *RowsExtended, ch chan string) { 90 var s string 91 for rows.Next() { 92 if err := rows.Scan(&s); err != nil { 93 ch <- err.Error() 94 close(ch) 95 return 96 } 97 } 98 ch <- s 99 close(ch) 100} 101