1// Copyright (c) 2017-2021 Snowflake Computing Inc. All right reserved.
2
3package gosnowflake
4
5import (
6	"context"
7	"database/sql/driver"
8	"io"
9	"math/rand"
10	"strings"
11	"sync"
12	"time"
13
14	"github.com/google/uuid"
15)
16
// contextKey is the type of every context key this package defines. Using a
// distinct named type means these keys cannot collide with plain string keys
// (or keys of any other type) stored in the same context by other packages.
type contextKey string

// Context keys backing the exported With* helpers in this file.
const (
	// multiStatementCount controls the number of queries to execute in a
	// single API call. Set by WithMultiStatement; the value is an int.
	multiStatementCount contextKey = "MULTI_STATEMENT_COUNT"
	// asyncMode tells the server not to block the request on executing the
	// entire query. Set to true by WithAsyncMode.
	asyncMode contextKey = "ASYNC_MODE_QUERY"
	// queryIDChannel holds the channel on which the query ID is delivered.
	// Set by WithQueryIDChan; the value is a chan<- string.
	queryIDChannel contextKey = "QUERY_ID_CHANNEL"
	// snowflakeRequestIDKey is an optional key for specifying the request ID.
	// Set by WithRequestID; the value is a uuid.UUID.
	snowflakeRequestIDKey contextKey = "SNOWFLAKE_REQUEST_ID"
	// fetchResultByID holds the query ID whose result should be fetched.
	// Set by WithFetchResultByID; the value is a string.
	fetchResultByID contextKey = "SF_FETCH_RESULT_BY_ID"
	// fileStreamFile holds the stream whose contents are uploaded via PUT.
	// Set by WithFileStream; the value is an io.Reader.
	fileStreamFile contextKey = "STREAMING_PUT_FILE"
	// fileTransferOptions allows the user to pass in custom file transfer
	// options. Set by WithFileTransferOptions; the value is a
	// *SnowflakeFileTransferOptions.
	fileTransferOptions contextKey = "FILE_TRANSFER_OPTIONS"
	// enableHigherPrecision requests that numbers which do not fit their
	// native Go type be returned as *big.Int or *big.Float. Set to true by
	// WithHigherPrecision.
	enableHigherPrecision contextKey = "ENABLE_HIGHER_PRECISION"
)

// Additional context keys used by the driver.
const (
	// describeOnly requests a describe-only query. Set to true by
	// WithDescribeOnly.
	describeOnly contextKey = "DESCRIBE_ONLY"
	// cancelRetry is not set by any of the With* helpers shown here.
	// NOTE(review): presumably controls retrying of query cancellation —
	// confirm against the code that reads it.
	cancelRetry contextKey = "CANCEL_RETRY"
	// streamChunkDownload enables the stream-based chunk downloader. Set to
	// true by WithStreamDownloader.
	streamChunkDownload contextKey = "STREAM_CHUNK_DOWNLOAD"
)
43
44// WithMultiStatement returns a context that allows the user to execute the desired number of sql queries in one query
45func WithMultiStatement(ctx context.Context, num int) (context.Context, error) {
46	return context.WithValue(ctx, multiStatementCount, num), nil
47}
48
49// WithAsyncMode returns a context that allows execution of query in async mode
50func WithAsyncMode(ctx context.Context) context.Context {
51	return context.WithValue(ctx, asyncMode, true)
52}
53
54// WithQueryIDChan returns a context that contains the channel to receive the query ID
55func WithQueryIDChan(ctx context.Context, c chan<- string) context.Context {
56	return context.WithValue(ctx, queryIDChannel, c)
57}
58
59// WithRequestID returns a new context with the specified snowflake request id
60func WithRequestID(ctx context.Context, requestID uuid.UUID) context.Context {
61	return context.WithValue(ctx, snowflakeRequestIDKey, requestID)
62}
63
64// WithStreamDownloader returns a context that allows the use of a stream based chunk downloader
65func WithStreamDownloader(ctx context.Context) context.Context {
66	return context.WithValue(ctx, streamChunkDownload, true)
67}
68
69// WithFetchResultByID returns a context that allows retrieving the result by query ID
70func WithFetchResultByID(ctx context.Context, queryID string) context.Context {
71	return context.WithValue(ctx, fetchResultByID, queryID)
72}
73
74// WithFileStream returns a context that contains the address of the file stream to be PUT
75func WithFileStream(ctx context.Context, reader io.Reader) context.Context {
76	return context.WithValue(ctx, fileStreamFile, reader)
77}
78
79// WithFileTransferOptions returns a context that contains the address of file transfer options
80func WithFileTransferOptions(ctx context.Context, options *SnowflakeFileTransferOptions) context.Context {
81	return context.WithValue(ctx, fileTransferOptions, options)
82}
83
84// WithDescribeOnly returns a context that enables a describe only query
85func WithDescribeOnly(ctx context.Context) context.Context {
86	return context.WithValue(ctx, describeOnly, true)
87}
88
89// WithHigherPrecision returns a context that enables higher precision by
90// returning a *big.Int or *big.Float variable when querying rows for column
91// types with numbers that don't fit into its native Golang counterpart
92func WithHigherPrecision(ctx context.Context) context.Context {
93	return context.WithValue(ctx, enableHigherPrecision, true)
94}
95
96// Get the request ID from the context if specified, otherwise generate one
97func getOrGenerateRequestIDFromContext(ctx context.Context) uuid.UUID {
98	requestID, ok := ctx.Value(snowflakeRequestIDKey).(uuid.UUID)
99	if ok && requestID != uuid.Nil {
100		return requestID
101	}
102	return uuid.New()
103}
104
// intMin returns the smaller of a and b.
func intMin(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
112
// intMax returns the larger of a and b.
func intMax(a, b int) int {
	if b >= a {
		return b
	}
	return a
}
120
// int64Max returns the larger of a and b.
func int64Max(a, b int64) int64 {
	if b >= a {
		return b
	}
	return a
}
127
// getMin returns the smallest element of arr, or -1 when arr is empty.
// NOTE(review): -1 is also a legitimate minimum, so callers cannot tell an
// empty slice from one whose smallest value is -1 — confirm callers only pass
// non-negative values.
func getMin(arr []int) int {
	if len(arr) == 0 {
		return -1
	}
	smallest := arr[0]
	for _, v := range arr[1:] {
		if v < smallest {
			smallest = v
		}
	}
	return smallest
}
140
141// time.Duration max
142func durationMax(d1, d2 time.Duration) time.Duration {
143	if d1-d2 > 0 {
144		return d1
145	}
146	return d2
147}
148
149// time.Duration min
150func durationMin(d1, d2 time.Duration) time.Duration {
151	if d1-d2 < 0 {
152		return d1
153	}
154	return d2
155}
156
157// toNamedValues converts a slice of driver.Value to a slice of driver.NamedValue for Go 1.8 SQL package
158func toNamedValues(values []driver.Value) []driver.NamedValue {
159	namedValues := make([]driver.NamedValue, len(values))
160	for idx, value := range values {
161		namedValues[idx] = driver.NamedValue{Name: "", Ordinal: idx + 1, Value: value}
162	}
163	return namedValues
164}
165
// TokenAccessor manages the session token and master token.
//
// GetTokens and SetTokens read and write the (token, masterToken, sessionID)
// triple. Lock and Unlock expose an additional lock to callers.
// NOTE(review): presumably used to serialize multi-step work such as a token
// renewal — confirm the intended protocol with the callers.
type TokenAccessor interface {
	GetTokens() (token string, masterToken string, sessionID int64)
	SetTokens(token string, masterToken string, sessionID int64)
	Lock() error
	Unlock()
}

// simpleTokenAccessor is an in-memory TokenAccessor. It uses two independent
// locks: accessorLock only backs Lock/Unlock, while tokenLock guards the three
// token fields, so GetTokens/SetTokens never wait on a caller holding Lock.
type simpleTokenAccessor struct {
	token        string
	masterToken  string
	sessionID    int64
	accessorLock sync.Mutex   // Used to implement accessor's Lock and Unlock
	tokenLock    sync.RWMutex // Used to synchronize SetTokens and GetTokens
}

// Compile-time check that *simpleTokenAccessor satisfies TokenAccessor.
var _ TokenAccessor = (*simpleTokenAccessor)(nil)

// getSimpleTokenAccessor returns a fresh in-memory accessor with empty tokens
// and a session ID of -1.
func getSimpleTokenAccessor() TokenAccessor {
	accessor := new(simpleTokenAccessor)
	accessor.sessionID = -1
	return accessor
}

// Lock acquires the accessor-level lock. It always returns nil.
func (s *simpleTokenAccessor) Lock() error {
	s.accessorLock.Lock()
	return nil
}

// Unlock releases the accessor-level lock acquired by Lock.
func (s *simpleTokenAccessor) Unlock() {
	s.accessorLock.Unlock()
}

// GetTokens returns a consistent snapshot of the stored tokens and session ID.
func (s *simpleTokenAccessor) GetTokens() (token string, masterToken string, sessionID int64) {
	s.tokenLock.RLock()
	token, masterToken, sessionID = s.token, s.masterToken, s.sessionID
	s.tokenLock.RUnlock()
	return token, masterToken, sessionID
}

// SetTokens atomically replaces the stored tokens and session ID.
func (s *simpleTokenAccessor) SetTokens(token string, masterToken string, sessionID int64) {
	s.tokenLock.Lock()
	s.token, s.masterToken, s.sessionID = token, masterToken, sessionID
	s.tokenLock.Unlock()
}
208
// escapeForCSV renders value as a single CSV field.
//
// An empty string becomes an explicit pair of double quotes. A value that
// contains a double quote, newline, comma or backslash is wrapped in double
// quotes, with every embedded double quote doubled. Any other value is returned
// unchanged.
// NOTE(review): a carriage return alone does not trigger quoting — confirm
// that this matches the loader's CSV settings.
func escapeForCSV(value string) string {
	switch {
	case value == "":
		return "\"\""
	case strings.ContainsAny(value, "\"\n,\\"):
		return "\"" + strings.ReplaceAll(value, "\"", "\"\"") + "\""
	default:
		return value
	}
}
219
// randomStringMu guards randomStringRand: a *rand.Rand, unlike the global
// math/rand functions, is not safe for concurrent use.
var randomStringMu sync.Mutex

// randomStringRand is seeded once at package initialization. Re-seeding on
// every call, as was done before, made calls that land on the same clock
// reading return identical strings. It also reset the global math/rand source
// shared by the rest of the program.
var randomStringRand = rand.New(rand.NewSource(time.Now().UnixNano()))

// randomString returns a string of n random lowercase ASCII letters (a-z).
// It returns "" when n <= 0. The output comes from math/rand and is NOT
// suitable for secrets or security-sensitive identifiers.
func randomString(n int) string {
	if n <= 0 {
		return ""
	}
	const alpha = "abcdefghijklmnopqrstuvwxyz"
	b := make([]byte, n)
	randomStringMu.Lock()
	for i := range b {
		b[i] = alpha[randomStringRand.Intn(len(alpha))]
	}
	randomStringMu.Unlock()
	return string(b)
}
229