1// Copyright (c) 2017-2021 Snowflake Computing Inc. All right reserved. 2 3package gosnowflake 4 5import ( 6 "context" 7 "database/sql/driver" 8 "io" 9 "math/rand" 10 "strings" 11 "sync" 12 "time" 13 14 "github.com/google/uuid" 15) 16 17type contextKey string 18 19const ( 20 // multiStatementCount controls the number of queries to execute in a single API call 21 multiStatementCount contextKey = "MULTI_STATEMENT_COUNT" 22 // asyncMode tells the server to not block the request on executing the entire query 23 asyncMode contextKey = "ASYNC_MODE_QUERY" 24 // queryIDChannel is the channel to receive the query ID from 25 queryIDChannel contextKey = "QUERY_ID_CHANNEL" 26 // snowflakeRequestIDKey is optional context key to specify request id 27 snowflakeRequestIDKey contextKey = "SNOWFLAKE_REQUEST_ID" 28 // fetchResultByID the queryID of query result to fetch 29 fetchResultByID contextKey = "SF_FETCH_RESULT_BY_ID" 30 // fileStreamFile is the address of the file to be uploaded via PUT 31 fileStreamFile contextKey = "STREAMING_PUT_FILE" 32 // fileTransferOptions allows the user to pass in custom 33 fileTransferOptions contextKey = "FILE_TRANSFER_OPTIONS" 34 // enableHigherPrecision returns numbers with higher precision in a *big format 35 enableHigherPrecision contextKey = "ENABLE_HIGHER_PRECISION" 36) 37 38const ( 39 describeOnly contextKey = "DESCRIBE_ONLY" 40 cancelRetry contextKey = "CANCEL_RETRY" 41 streamChunkDownload contextKey = "STREAM_CHUNK_DOWNLOAD" 42) 43 44// WithMultiStatement returns a context that allows the user to execute the desired number of sql queries in one query 45func WithMultiStatement(ctx context.Context, num int) (context.Context, error) { 46 return context.WithValue(ctx, multiStatementCount, num), nil 47} 48 49// WithAsyncMode returns a context that allows execution of query in async mode 50func WithAsyncMode(ctx context.Context) context.Context { 51 return context.WithValue(ctx, asyncMode, true) 52} 53 54// WithQueryIDChan returns a context that contains the channel to receive the query ID 55func WithQueryIDChan(ctx context.Context, c chan<- string) context.Context { 56 return context.WithValue(ctx, queryIDChannel, c) 57} 58 59// WithRequestID returns a new context with the specified snowflake request id 60func WithRequestID(ctx context.Context, requestID uuid.UUID) context.Context { 61 return context.WithValue(ctx, snowflakeRequestIDKey, requestID) 62} 63 64// WithStreamDownloader returns a context that allows the use of a stream based chunk downloader 65func WithStreamDownloader(ctx context.Context) context.Context { 66 return context.WithValue(ctx, streamChunkDownload, true) 67} 68 69// WithFetchResultByID returns a context that allows retrieving the result by query ID 70func WithFetchResultByID(ctx context.Context, queryID string) context.Context { 71 return context.WithValue(ctx, fetchResultByID, queryID) 72} 73 74// WithFileStream returns a context that contains the address of the file stream to be PUT 75func WithFileStream(ctx context.Context, reader io.Reader) context.Context { 76 return context.WithValue(ctx, fileStreamFile, reader) 77} 78 79// WithFileTransferOptions returns a context that contains the address of file transfer options 80func WithFileTransferOptions(ctx context.Context, options *SnowflakeFileTransferOptions) context.Context { 81 return context.WithValue(ctx, fileTransferOptions, options) 82} 83 84// WithDescribeOnly returns a context that enables a describe only query 85func WithDescribeOnly(ctx context.Context) context.Context { 86 return context.WithValue(ctx, describeOnly, true) 87} 88 89// WithHigherPrecision returns a context that enables higher precision by 90// returning a *big.Int or *big.Float variable when querying rows for column 91// types with numbers that don't fit into its native Golang counterpart 92func WithHigherPrecision(ctx context.Context) context.Context { 93 return context.WithValue(ctx, enableHigherPrecision, true) 94} 95 96// Get the request ID from the context if specified, otherwise generate one 97func getOrGenerateRequestIDFromContext(ctx context.Context) uuid.UUID { 98 requestID, ok := ctx.Value(snowflakeRequestIDKey).(uuid.UUID) 99 if ok && requestID != uuid.Nil { 100 return requestID 101 } 102 return uuid.New() 103} 104 105// integer min 106func intMin(a, b int) int { 107 if a < b { 108 return a 109 } 110 return b 111} 112 113// integer max 114func intMax(a, b int) int { 115 if a > b { 116 return a 117 } 118 return b 119} 120 121func int64Max(a, b int64) int64 { 122 if a > b { 123 return a 124 } 125 return b 126} 127 128func getMin(arr []int) int { 129 if len(arr) == 0 { 130 return -1 131 } 132 min := arr[0] 133 for _, v := range arr { 134 if v <= min { 135 min = v 136 } 137 } 138 return min 139} 140 141// time.Duration max 142func durationMax(d1, d2 time.Duration) time.Duration { 143 if d1-d2 > 0 { 144 return d1 145 } 146 return d2 147} 148 149// time.Duration min 150func durationMin(d1, d2 time.Duration) time.Duration { 151 if d1-d2 < 0 { 152 return d1 153 } 154 return d2 155} 156 157// toNamedValues converts a slice of driver.Value to a slice of driver.NamedValue for Go 1.8 SQL package 158func toNamedValues(values []driver.Value) []driver.NamedValue { 159 namedValues := make([]driver.NamedValue, len(values)) 160 for idx, value := range values { 161 namedValues[idx] = driver.NamedValue{Name: "", Ordinal: idx + 1, Value: value} 162 } 163 return namedValues 164} 165 166// TokenAccessor manages the session token and master token 167type TokenAccessor interface { 168 GetTokens() (token string, masterToken string, sessionID int64) 169 SetTokens(token string, masterToken string, sessionID int64) 170 Lock() error 171 Unlock() 172} 173 174type simpleTokenAccessor struct { 175 token string 176 masterToken string 177 sessionID int64 178 accessorLock sync.Mutex // Used to implement accessor's Lock and Unlock 179 tokenLock sync.RWMutex // Used to synchronize SetTokens and GetTokens 180} 181 182func getSimpleTokenAccessor() TokenAccessor { 183 return &simpleTokenAccessor{sessionID: -1} 184} 185 186func (sta *simpleTokenAccessor) Lock() error { 187 sta.accessorLock.Lock() 188 return nil 189} 190 191func (sta *simpleTokenAccessor) Unlock() { 192 sta.accessorLock.Unlock() 193} 194 195func (sta *simpleTokenAccessor) GetTokens() (token string, masterToken string, sessionID int64) { 196 sta.tokenLock.RLock() 197 defer sta.tokenLock.RUnlock() 198 return sta.token, sta.masterToken, sta.sessionID 199} 200 201func (sta *simpleTokenAccessor) SetTokens(token string, masterToken string, sessionID int64) { 202 sta.tokenLock.Lock() 203 defer sta.tokenLock.Unlock() 204 sta.token = token 205 sta.masterToken = masterToken 206 sta.sessionID = sessionID 207} 208 209func escapeForCSV(value string) string { 210 if value == "" { 211 return "\"\"" 212 } 213 if strings.Contains(value, "\"") || strings.Contains(value, "\n") || 214 strings.Contains(value, ",") || strings.Contains(value, "\\") { 215 return "\"" + strings.ReplaceAll(value, "\"", "\"\"") + "\"" 216 } 217 return value 218} 219 220func randomString(n int) string { 221 rand.Seed(time.Now().UnixNano()) 222 alpha := []rune("abcdefghijklmnopqrstuvwxyz") 223 b := make([]rune, n) 224 for i := range b { 225 b[i] = alpha[rand.Intn(len(alpha))] 226 } 227 return string(b) 228} 229