1// Copyright (c) 2017-2021 Snowflake Computing Inc. All right reserved.
2
3package gosnowflake
4
5import (
6	"bytes"
7	"context"
8	"database/sql"
9	"database/sql/driver"
10	"encoding/json"
11	"fmt"
12	"io"
13	"net/http"
14	"net/url"
15	"regexp"
16	"strconv"
17	"strings"
18	"sync"
19	"sync/atomic"
20	"time"
21
22	"github.com/google/uuid"
23)
24
// HTTP header names, and a header value, used when talking to the Snowflake
// REST endpoints.
const (
	httpHeaderAccept           = "accept"
	httpHeaderContentEncoding  = "Content-Encoding"
	httpHeaderContentLength    = "Content-Length"
	httpHeaderContentType      = "Content-Type"
	httpHeaderHost             = "Host"
	httpHeaderServiceName      = "X-Snowflake-Service"
	httpHeaderUserAgent        = "User-Agent"
	httpHeaderValueOctetStream = "application/octet-stream"
)
35
// Statement type codes compared against execResponseData.StatementTypeID.
// The DML kinds occupy the contiguous range
// [statementTypeIDDml, statementTypeIDMultiTableInsert], which isDml relies on.
const (
	// statementTypeIDMulti is the code checked (together with the first
	// column name) by isMultiStmt.
	statementTypeIDMulti = int64(0x1000)

	statementTypeIDDml              = int64(0x3000)
	statementTypeIDInsert           = statementTypeIDDml + 0x100
	statementTypeIDUpdate           = statementTypeIDDml + 0x200
	statementTypeIDDelete           = statementTypeIDDml + 0x300
	statementTypeIDMerge            = statementTypeIDDml + 0x400
	statementTypeIDMultiTableInsert = statementTypeIDDml + 0x500
)
46
// Parameter names; several of them are used as keys into Config.Params,
// which populateSessionParameters fills with lower-cased names.
const (
	serviceName                            = "service_name"
	sessionArrayBindStageThreshold         = "client_stage_array_binding_threshold"
	sessionClientSessionKeepAlive          = "client_session_keep_alive"
	sessionClientValidateDefaultParameters = "CLIENT_VALIDATE_DEFAULT_PARAMETERS"
)
53
54type resultType string
55
56const (
57	snowflakeResultType contextKey = "snowflakeResultType"
58	execResultType      resultType = "exec"
59	queryResultType     resultType = "query"
60)
61
62type snowflakeConn struct {
63	ctx             context.Context
64	cfg             *Config
65	rest            *snowflakeRestful
66	SequenceCounter uint64
67	QueryID         string
68	SQLState        string
69	telemetry       *snowflakeTelemetry
70	internal        InternalClient
71}
72
// queryIDPattern describes the characters allowed in a query ID: word
// characters and hyphens. The pattern is unanchored, so
// queryIDRegexp.MatchString succeeds when any substring matches.
var (
	queryIDPattern = `[\w\-_]+`
	queryIDRegexp  = regexp.MustCompile(queryIDPattern)
)

// urlQueriesResultFmt is the REST path template for fetching the result of a
// query by its ID.
const urlQueriesResultFmt string = "/queries/%s/result"
79
80// isDml returns true if the statement type code is in the range of DML.
81func (sc *snowflakeConn) isDml(v int64) bool {
82	return statementTypeIDDml <= v && v <= statementTypeIDMultiTableInsert
83}
84
85// isMultiStmt returns true if the statement type code is of type multistatement
86// Note that the statement type code is also equivalent to type INSERT, so an additional check of the name is required
87func (sc *snowflakeConn) isMultiStmt(data *execResponseData) bool {
88	return data.StatementTypeID == statementTypeIDMulti && data.RowType[0].Name == "multiple statement execution"
89}
90
91func (sc *snowflakeConn) exec(
92	ctx context.Context,
93	query string,
94	noResult bool,
95	isInternal bool,
96	describeOnly bool,
97	bindings []driver.NamedValue) (
98	*execResponse, error) {
99	var err error
100	counter := atomic.AddUint64(&sc.SequenceCounter, 1) // query sequence counter
101
102	req := execRequest{
103		SQLText:      query,
104		AsyncExec:    noResult,
105		Parameters:   map[string]interface{}{},
106		IsInternal:   isInternal,
107		DescribeOnly: describeOnly,
108		SequenceID:   counter,
109	}
110	if key := ctx.Value(multiStatementCount); key != nil {
111		req.Parameters[string(multiStatementCount)] = key
112	}
113	logger.WithContext(ctx).Infof("parameters: %v", req.Parameters)
114
115	requestID := getOrGenerateRequestIDFromContext(ctx)
116	if len(bindings) > 0 {
117		arrayBindThreshold := sc.getArrayBindStageThreshold()
118		numBinds := arrayBindValueCount(bindings)
119		if 0 < arrayBindThreshold && arrayBindThreshold <= numBinds && !describeOnly && isArrayBind(bindings) {
120			// bulk array insert binding
121			uploader := bindUploader{
122				sc:        sc,
123				ctx:       ctx,
124				stagePath: "@" + bindStageName + "/" + requestID.String(),
125			}
126			uploader.upload(bindings)
127			req.Bindings = nil
128			req.BindStage = uploader.stagePath
129		} else {
130			// variable or array binding
131			req.Bindings, err = getBindValues(bindings)
132			if err != nil {
133				return nil, err
134			}
135			req.BindStage = ""
136		}
137	}
138	logger.WithContext(ctx).Infof("bindings: %v", req.Bindings)
139
140	headers := getHeaders()
141	if isFileTransfer(query) {
142		headers[httpHeaderAccept] = headerContentTypeApplicationJSON
143	}
144	if serviceName, ok := sc.cfg.Params[serviceName]; ok {
145		headers[httpHeaderServiceName] = *serviceName
146	}
147
148	jsonBody, err := json.Marshal(req)
149	if err != nil {
150		return nil, err
151	}
152
153	var data *execResponse
154	data, err = sc.rest.FuncPostQuery(ctx, sc.rest, &url.Values{}, headers, jsonBody, sc.rest.RequestTimeout, requestID, sc.cfg)
155	if err != nil {
156		return data, err
157	}
158	var code int
159	if data.Code != "" {
160		code, err = strconv.Atoi(data.Code)
161		if err != nil {
162			code = -1
163			return data, err
164		}
165	} else {
166		code = -1
167	}
168	logger.WithContext(ctx).Infof("Success: %v, Code: %v", data.Success, code)
169	if !data.Success {
170		return nil, &SnowflakeError{
171			Number:   code,
172			SQLState: data.Data.SQLState,
173			Message:  data.Message,
174			QueryID:  data.Data.QueryID,
175		}
176	}
177	if isFileTransfer(query) {
178		sfa := snowflakeFileTransferAgent{
179			sc:      sc,
180			data:    &data.Data,
181			command: query,
182			options: new(SnowflakeFileTransferOptions),
183		}
184		if fs := getFileStream(ctx); fs != nil {
185			sfa.sourceStream = fs
186			if isInternal {
187				sfa.data.AutoCompress = false
188			}
189		}
190		if op := getFileTransferOptions(ctx); op != nil {
191			sfa.options = op
192		}
193		if sfa.options.multiPartThreshold == 0 {
194			sfa.options.multiPartThreshold = dataSizeThreshold
195		}
196		if err = sfa.execute(); err != nil {
197			return nil, err
198		}
199		data, err = sfa.result()
200		if err != nil {
201			return nil, err
202		}
203	}
204
205	logger.WithContext(ctx).Info("Exec/Query SUCCESS")
206	sc.cfg.Database = data.Data.FinalDatabaseName
207	sc.cfg.Schema = data.Data.FinalSchemaName
208	sc.cfg.Role = data.Data.FinalRoleName
209	sc.cfg.Warehouse = data.Data.FinalWarehouseName
210	sc.QueryID = data.Data.QueryID
211	sc.SQLState = data.Data.SQLState
212	sc.populateSessionParameters(data.Data.Parameters)
213	return data, err
214}
215
216func (sc *snowflakeConn) Begin() (driver.Tx, error) {
217	return sc.BeginTx(sc.ctx, driver.TxOptions{})
218}
219
220func (sc *snowflakeConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
221	logger.WithContext(ctx).Info("BeginTx")
222	if opts.ReadOnly {
223		return nil, &SnowflakeError{
224			Number:   ErrNoReadOnlyTransaction,
225			SQLState: SQLStateFeatureNotSupported,
226			Message:  errMsgNoReadOnlyTransaction,
227		}
228	}
229	if int(opts.Isolation) != int(sql.LevelDefault) {
230		return nil, &SnowflakeError{
231			Number:   ErrNoDefaultTransactionIsolationLevel,
232			SQLState: SQLStateFeatureNotSupported,
233			Message:  errMsgNoDefaultTransactionIsolationLevel,
234		}
235	}
236	if sc.rest == nil {
237		return nil, driver.ErrBadConn
238	}
239	isDesc := isDescribeOnly(ctx)
240	_, err := sc.exec(ctx, "BEGIN", false /* noResult */, false /* isInternal */, isDesc, nil)
241	if err != nil {
242		return nil, err
243	}
244	return &snowflakeTx{sc}, err
245}
246
247func (sc *snowflakeConn) cleanup() {
248	// must flush log buffer while the process is running.
249	sc.rest = nil
250	sc.cfg = nil
251}
252
253func (sc *snowflakeConn) Close() (err error) {
254	logger.WithContext(sc.ctx).Infoln("Close")
255	sc.telemetry.sendBatch()
256	sc.stopHeartBeat()
257
258	if !sc.cfg.KeepSessionAlive {
259		err = sc.rest.FuncCloseSession(sc.ctx, sc.rest, sc.rest.RequestTimeout)
260		if err != nil {
261			logger.Error(err)
262		}
263	}
264	sc.cleanup()
265	return nil
266}
267
268func (sc *snowflakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
269	logger.WithContext(sc.ctx).Infoln("Prepare")
270	if sc.rest == nil {
271		return nil, driver.ErrBadConn
272	}
273	noResult := isAsyncMode(ctx)
274	data, err := sc.exec(ctx, query, noResult, false /* isInternal */, true /* describeOnly */, []driver.NamedValue{})
275	if err != nil {
276		if data != nil {
277			code, err := strconv.Atoi(data.Code)
278			if err != nil {
279				return nil, err
280			}
281			return nil, &SnowflakeError{
282				Number:   code,
283				SQLState: data.Data.SQLState,
284				Message:  err.Error(),
285				QueryID:  data.Data.QueryID,
286			}
287		}
288		return nil, err
289	}
290	stmt := &snowflakeStmt{
291		sc:    sc,
292		query: query,
293	}
294	return stmt, nil
295}
296
297func (sc *snowflakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
298	logger.WithContext(ctx).Infof("Exec: %#v, %v", query, args)
299	if sc.rest == nil {
300		return nil, driver.ErrBadConn
301	}
302	noResult := isAsyncMode(ctx)
303	isDesc := isDescribeOnly(ctx)
304	// TODO handle isInternal
305	ctx = setResultType(ctx, execResultType)
306	data, err := sc.exec(ctx, query, noResult, false /* isInternal */, isDesc, args)
307	if err != nil {
308		logger.WithContext(ctx).Infof("error: %v", err)
309		if data != nil {
310			code, err := strconv.Atoi(data.Code)
311			if err != nil {
312				return nil, err
313			}
314			return nil, &SnowflakeError{
315				Number:   code,
316				SQLState: data.Data.SQLState,
317				Message:  err.Error(),
318				QueryID:  data.Data.QueryID}
319		}
320		return nil, err
321	}
322
323	// if async exec, return result object right away
324	if noResult {
325		return data.Data.AsyncResult, nil
326	}
327
328	if sc.isDml(data.Data.StatementTypeID) {
329		// collects all values from the returned row sets
330		updatedRows, err := updateRows(data.Data)
331		if err != nil {
332			return nil, err
333		}
334		logger.WithContext(ctx).Debugf("number of updated rows: %#v", updatedRows)
335		return &snowflakeResult{
336			affectedRows: updatedRows,
337			insertID:     -1,
338			queryID:      sc.QueryID,
339		}, nil // last insert id is not supported by Snowflake
340	} else if sc.isMultiStmt(&data.Data) {
341		return sc.handleMultiExec(ctx, data.Data)
342	}
343	logger.Debug("DDL")
344	return driver.ResultNoRows, nil
345}
346
347func (sc *snowflakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
348	qid, err := getResumeQueryID(ctx)
349	if err != nil {
350		return nil, err
351	}
352	if qid == "" {
353		return sc.queryContextInternal(ctx, query, args)
354	}
355
356	// first we will check the status of this particular query to find out if there is result to fetch
357	err = sc.checkQueryStatus(ctx, qid)
358	if err == nil || (err != nil && err.(*SnowflakeError).Number == ErrQueryIsRunning) {
359		// the query is running. Rows object will be returned from here.
360		return sc.buildRowsForRunningQuery(ctx, qid)
361	}
362
363	return nil, err
364}
365
366func (sc *snowflakeConn) queryContextInternal(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
367	logger.WithContext(ctx).Infof("Query: %#v, %v", query, args)
368	if sc.rest == nil {
369		return nil, driver.ErrBadConn
370	}
371
372	noResult := isAsyncMode(ctx)
373	isDesc := isDescribeOnly(ctx)
374	ctx = setResultType(ctx, queryResultType)
375	// TODO: handle isInternal
376	data, err := sc.exec(ctx, query, noResult, false /* isInternal */, isDesc, args)
377	if err != nil {
378		logger.WithContext(ctx).Errorf("error: %v", err)
379		if data != nil {
380			code, err := strconv.Atoi(data.Code)
381			if err != nil {
382				return nil, err
383			}
384			return nil, &SnowflakeError{
385				Number:   code,
386				SQLState: data.Data.SQLState,
387				Message:  err.Error(),
388				QueryID:  data.Data.QueryID}
389		}
390		return nil, err
391	}
392
393	// if async query, return row object right away
394	if noResult {
395		return data.Data.AsyncRows, nil
396	}
397
398	rows := new(snowflakeRows)
399	rows.sc = sc
400	rows.queryID = sc.QueryID
401
402	if sc.isMultiStmt(&data.Data) {
403		// handleMultiQuery is responsible to fill rows with childResults
404		if err = sc.handleMultiQuery(ctx, data.Data, rows); err != nil {
405			return nil, err
406		}
407	} else {
408		rows.addDownloader(populateChunkDownloader(ctx, sc, data.Data))
409	}
410
411	rows.ChunkDownloader.start()
412	return rows, err
413}
414
415func (sc *snowflakeConn) Prepare(query string) (driver.Stmt, error) {
416	return sc.PrepareContext(sc.ctx, query)
417}
418
419func (sc *snowflakeConn) Exec(
420	query string,
421	args []driver.Value) (
422	driver.Result, error) {
423	return sc.ExecContext(sc.ctx, query, toNamedValues(args))
424}
425
426func (sc *snowflakeConn) Query(
427	query string,
428	args []driver.Value) (
429	driver.Rows, error) {
430	return sc.QueryContext(sc.ctx, query, toNamedValues(args))
431}
432
433func (sc *snowflakeConn) Ping(ctx context.Context) error {
434	logger.WithContext(ctx).Infoln("Ping")
435	if sc.rest == nil {
436		return driver.ErrBadConn
437	}
438	noResult := isAsyncMode(ctx)
439	isDesc := isDescribeOnly(ctx)
440	// TODO: handle isInternal
441	_, err := sc.exec(ctx, "SELECT 1", noResult, false /* isInternal */, isDesc, []driver.NamedValue{})
442	return err
443}
444
445// CheckNamedValue determines which types are handled by this driver aside from
446// the instances captured by driver.Value
447func (sc *snowflakeConn) CheckNamedValue(nv *driver.NamedValue) error {
448	if supported := supportedArrayBind(nv); !supported {
449		return driver.ErrSkip
450	}
451	return nil
452}
453
454func (sc *snowflakeConn) populateSessionParameters(parameters []nameValueParameter) {
455	// other session parameters (not all)
456	logger.WithContext(sc.ctx).Infof("params: %#v", parameters)
457	for _, param := range parameters {
458		v := ""
459		switch param.Value.(type) {
460		case int64:
461			if vv, ok := param.Value.(int64); ok {
462				v = strconv.FormatInt(vv, 10)
463			}
464		case float64:
465			if vv, ok := param.Value.(float64); ok {
466				v = strconv.FormatFloat(vv, 'g', -1, 64)
467			}
468		case bool:
469			if vv, ok := param.Value.(bool); ok {
470				v = strconv.FormatBool(vv)
471			}
472		default:
473			if vv, ok := param.Value.(string); ok {
474				v = vv
475			}
476		}
477		logger.Debugf("parameter. name: %v, value: %v", param.Name, v)
478		sc.cfg.Params[strings.ToLower(param.Name)] = &v
479	}
480}
481
482func (sc *snowflakeConn) isClientSessionKeepAliveEnabled() bool {
483	v, ok := sc.cfg.Params[sessionClientSessionKeepAlive]
484	if !ok {
485		return false
486	}
487	return strings.Compare(*v, "true") == 0
488}
489
490func (sc *snowflakeConn) getArrayBindStageThreshold() int {
491	v, ok := sc.cfg.Params[sessionArrayBindStageThreshold]
492	if !ok {
493		return 0
494	}
495	num, err := strconv.Atoi(*v)
496	if err != nil {
497		return 0
498	}
499	return num
500}
501
502func (sc *snowflakeConn) startHeartBeat() {
503	if !sc.isClientSessionKeepAliveEnabled() {
504		return
505	}
506	sc.rest.HeartBeat = &heartbeat{
507		restful: sc.rest,
508	}
509	sc.rest.HeartBeat.start()
510}
511
512func (sc *snowflakeConn) stopHeartBeat() {
513	if !sc.isClientSessionKeepAliveEnabled() {
514		return
515	}
516	sc.rest.HeartBeat.stop()
517}
518
519func (sc *snowflakeConn) handleMultiExec(ctx context.Context, data execResponseData) (driver.Result, error) {
520	var updatedRows int64
521	childResults := getChildResults(data.ResultIDs, data.ResultTypes)
522	for _, child := range childResults {
523		resultPath := fmt.Sprintf(urlQueriesResultFmt, child.id)
524		childData, err := sc.getQueryResultResp(ctx, resultPath)
525		if err != nil {
526			logger.Errorf("error: %v", err)
527			code, err := strconv.Atoi(childData.Code)
528			if err != nil {
529				return nil, err
530			}
531			if childData != nil {
532				return nil, &SnowflakeError{
533					Number:   code,
534					SQLState: childData.Data.SQLState,
535					Message:  err.Error(),
536					QueryID:  childData.Data.QueryID}
537			}
538			return nil, err
539		}
540		if sc.isDml(childData.Data.StatementTypeID) {
541			count, err := updateRows(childData.Data)
542			if err != nil {
543				logger.WithContext(ctx).Errorf("error: %v", err)
544				if childData != nil {
545					code, err := strconv.Atoi(childData.Code)
546					if err != nil {
547						return nil, err
548					}
549					return nil, &SnowflakeError{
550						Number:   code,
551						SQLState: childData.Data.SQLState,
552						Message:  err.Error(),
553						QueryID:  childData.Data.QueryID}
554				}
555				return nil, err
556			}
557			updatedRows += count
558		}
559	}
560	logger.WithContext(ctx).Infof("number of updated rows: %#v", updatedRows)
561	return &snowflakeResult{
562		affectedRows: updatedRows,
563		insertID:     -1,
564		queryID:      sc.QueryID,
565	}, nil
566}
567
568// Fill the correspondent rows and add chunk downloader into the rows when iterate the childResults
569func (sc *snowflakeConn) handleMultiQuery(ctx context.Context, data execResponseData, rows *snowflakeRows) error {
570	childResults := getChildResults(data.ResultIDs, data.ResultTypes)
571
572	for _, child := range childResults {
573		if err := sc.rowsForRunningQuery(ctx, child.id, rows); err != nil {
574			return err
575		}
576	}
577	return nil
578}
579
580func setResultType(ctx context.Context, resType resultType) context.Context {
581	return context.WithValue(ctx, snowflakeResultType, resType)
582}
583
584func getResultType(ctx context.Context) resultType {
585	return ctx.Value(snowflakeResultType).(resultType)
586}
587
588func updateRows(data execResponseData) (int64, error) {
589	var count int64
590	for i, n := 0, len(data.RowType); i < n; i++ {
591		v, err := strconv.ParseInt(*data.RowSet[0][i], 10, 64)
592		if err != nil {
593			return -1, err
594		}
595		count += v
596	}
597	return count, nil
598}
599
// childResult identifies one sub-statement of a multi-statement execution.
type childResult struct {
	id  string // query ID of the child statement
	typ string // result type reported for the child; may be empty
}

// getChildResults pairs the comma-separated query IDs in ids with the
// comma-separated result types in types, by position. It returns nil when
// ids is empty.
//
// If types has fewer entries than ids, the missing types are left empty.
// Previously this caused an index-out-of-range panic.
func getChildResults(ids string, types string) []childResult {
	if ids == "" {
		return nil
	}
	queryIDs := strings.Split(ids, ",")
	resultTypes := strings.Split(types, ",")
	res := make([]childResult, len(queryIDs))
	for i, id := range queryIDs {
		res[i].id = id
		if i < len(resultTypes) {
			res[i].typ = resultTypes[i]
		}
	}
	return res
}
617
618func (sc *snowflakeConn) getQueryResultResp(ctx context.Context, resultPath string) (*execResponse, error) {
619	headers := getHeaders()
620	if serviceName, ok := sc.cfg.Params[serviceName]; ok {
621		headers[httpHeaderServiceName] = *serviceName
622	}
623	param := make(url.Values)
624	param.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String())
625	param.Add("clientStartTime", strconv.FormatInt(time.Now().Unix(), 10))
626	param.Add(requestGUIDKey, uuid.New().String())
627	token, _, _ := sc.rest.TokenAccessor.GetTokens()
628	if token != "" {
629		headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
630	}
631	url := sc.rest.getFullURL(resultPath, &param)
632	res, err := sc.rest.FuncGet(ctx, sc.rest, url, headers, sc.rest.RequestTimeout)
633	if err != nil {
634		logger.WithContext(ctx).Errorf("failed to get response. err: %v", err)
635		return nil, err
636	}
637	var respd *execResponse
638	err = json.NewDecoder(res.Body).Decode(&respd)
639	if err != nil {
640		logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
641		return nil, err
642	}
643	return respd, nil
644}
645
646// checkQueryStatus returns the status given the query ID. If successful,
647// the error will be nil, indicating there is a complete query result to fetch.
648// Other than nil, there are three error types that can be returned:
649// 1. ErrQueryStatus, if GS cannot return any kind of status due to any reason,
650// i.e. connection, permission, if a query was just submitted, etc.
651// 2, ErrQueryReportedError, if the requested query was terminated or aborted
652// and GS returned an error status included in query. SFQueryFailedWithError
653// 3, ErrQueryIsRunning, if the requested query is still running and might have
654// a complete result later, these statuses were listed in query. SFQueryRunning
655func (sc *snowflakeConn) checkQueryStatus(ctx context.Context, qid string) error {
656	headers := make(map[string]string)
657	param := make(url.Values)
658	param.Add(requestGUIDKey, uuid.New().String())
659	if tok, _, _ := sc.rest.TokenAccessor.GetTokens(); tok != "" {
660		headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, tok)
661	}
662	resultPath := fmt.Sprintf("/monitoring/queries/%s", qid)
663	url := sc.rest.getFullURL(resultPath, &param)
664
665	res, err := sc.rest.FuncGet(ctx, sc.rest, url, headers, sc.rest.RequestTimeout)
666	if err != nil {
667		logger.WithContext(ctx).Errorf("failed to get response. err: %v", err)
668		return err
669	}
670	var statusResp = statusResponse{}
671
672	err = json.NewDecoder(res.Body).Decode(&statusResp)
673	if err != nil {
674		logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
675		return err
676	}
677
678	if !statusResp.Success || len(statusResp.Data.Queries) == 0 {
679		logger.WithContext(ctx).Errorf("status query returned not-success or no status returned.")
680		return &SnowflakeError{
681			Number:  ErrQueryStatus,
682			Message: "status query returned not-success or no status returned. Please retry"}
683	}
684
685	var queryRet = statusResp.Data.Queries[0]
686	if queryRet.ErrorCode != 0 {
687		return &SnowflakeError{
688			Number: ErrQueryStatus,
689			Message: fmt.Sprintf("server ErrorCode=%d, ErrorMessage=%s",
690				queryRet.ErrorCode, queryRet.ErrorMessage),
691			IncludeQueryID: true,
692			QueryID:        qid,
693		}
694	}
695
696	// returned errorCode is 0. Now check what is the returned status of the query.
697	var qstatus = strToSFQueryStatus(queryRet.Status)
698	if sfqStatusIsAnError(qstatus) {
699		return &SnowflakeError{
700			Number: ErrQueryReportedError,
701			Message: fmt.Sprintf("%s: status from server: [%s]",
702				queryRet.ErrorMessage, queryRet.Status),
703			IncludeQueryID: true,
704			QueryID:        qid,
705		}
706	}
707
708	if sfqStatusIsStillRunning(qstatus) {
709		return &SnowflakeError{
710			Number: ErrQueryIsRunning,
711			Message: fmt.Sprintf("%s: status from server: [%s]",
712				queryRet.ErrorMessage, queryRet.Status),
713			IncludeQueryID: true,
714			QueryID:        qid,
715		}
716	}
717
718	//success
719	return nil
720}
721
722// Fetch query result for a query id from /queries/<qid>/result endpoint.
723func (sc *snowflakeConn) rowsForRunningQuery(ctx context.Context, qid string, rows *snowflakeRows) error {
724	resultPath := fmt.Sprintf(urlQueriesResultFmt, qid)
725	resp, err := sc.getQueryResultResp(ctx, resultPath)
726	if err != nil {
727		logger.WithContext(ctx).Errorf("error: %v", err)
728		if resp != nil {
729			code, err := strconv.Atoi(resp.Code)
730			if err != nil {
731				return err
732			}
733			return &SnowflakeError{
734				Number:   code,
735				SQLState: resp.Data.SQLState,
736				Message:  err.Error(),
737				QueryID:  resp.Data.QueryID}
738		}
739		return err
740	}
741	rows.addDownloader(populateChunkDownloader(ctx, sc, resp.Data))
742	return nil
743}
744
745// prepare a Rows object to return for query of 'qid'
746func (sc *snowflakeConn) buildRowsForRunningQuery(ctx context.Context, qid string) (driver.Rows, error) {
747	rows := new(snowflakeRows)
748	rows.sc = sc
749	rows.queryID = qid
750	err := sc.rowsForRunningQuery(ctx, qid, rows)
751	if err != nil {
752		return nil, err
753	}
754	rows.ChunkDownloader.start()
755	return rows, err
756}
757
758func isAsyncMode(ctx context.Context) bool {
759	val := ctx.Value(asyncMode)
760	if val == nil {
761		return false
762	}
763	a, ok := val.(bool)
764	return ok && a
765}
766
767func getResumeQueryID(ctx context.Context) (string, error) {
768	val := ctx.Value(fetchResultByID)
769	if val == nil {
770		return "", nil
771	}
772	strVal, ok := val.(string)
773	if !ok {
774		return "", fmt.Errorf("failed to cast val %+v to string", val)
775	}
776	// so there is a queryID in context for which we want to fetch the result
777	if !queryIDRegexp.MatchString(strVal) {
778		return strVal, &SnowflakeError{
779			Number:  ErrQueryIDFormat,
780			Message: "Invalid QID",
781			QueryID: strVal}
782	}
783	return strVal, nil
784}
785
786func getAsync(
787	ctx context.Context,
788	sr *snowflakeRestful,
789	headers map[string]string,
790	URL *url.URL,
791	timeout time.Duration,
792	res *snowflakeResult,
793	rows *snowflakeRows,
794	cfg *Config,
795) {
796	resType := getResultType(ctx)
797	var errChannel chan error
798	sfError := &SnowflakeError{
799		Number: -1,
800	}
801	if resType == execResultType {
802		errChannel = res.errChannel
803		sfError.QueryID = res.queryID
804	} else {
805		errChannel = rows.errChannel
806		sfError.QueryID = rows.queryID
807	}
808	defer close(errChannel)
809	token, _, _ := sr.TokenAccessor.GetTokens()
810	headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
811	resp, err := sr.FuncGet(ctx, sr, URL, headers, timeout)
812	if err != nil {
813		logger.WithContext(ctx).Errorf("failed to get response. err: %v", err)
814		sfError.Message = err.Error()
815		errChannel <- sfError
816		close(errChannel)
817		return
818	}
819	if resp.Body != nil {
820		defer resp.Body.Close()
821	}
822
823	respd := execResponse{}
824	err = json.NewDecoder(resp.Body).Decode(&respd)
825	resp.Body.Close()
826	if err != nil {
827		logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
828		sfError.Message = err.Error()
829		errChannel <- sfError
830		close(errChannel)
831		return
832	}
833
834	sc := &snowflakeConn{rest: sr, cfg: cfg}
835	if respd.Success {
836		if resType == execResultType {
837			res.insertID = -1
838			if sc.isDml(respd.Data.StatementTypeID) {
839				res.affectedRows, _ = updateRows(respd.Data)
840			} else if sc.isMultiStmt(&respd.Data) {
841				r, err := sc.handleMultiExec(ctx, respd.Data)
842				if err != nil {
843					res.errChannel <- err
844					close(errChannel)
845					return
846				}
847				res.affectedRows, err = r.RowsAffected()
848				if err != nil {
849					res.errChannel <- err
850					close(errChannel)
851					return
852				}
853			}
854			res.queryID = respd.Data.QueryID
855			res.errChannel <- nil // mark exec status complete
856		} else {
857			rows.sc = sc
858			rows.queryID = respd.Data.QueryID
859			if sc.isMultiStmt(&respd.Data) {
860				err = sc.handleMultiQuery(ctx, respd.Data, rows)
861				if err != nil {
862					rows.errChannel <- err
863					close(errChannel)
864					return
865				}
866			} else {
867				rows.addDownloader(populateChunkDownloader(ctx, sc, respd.Data))
868			}
869			rows.ChunkDownloader.start()
870			rows.errChannel <- nil // mark query status complete
871		}
872	} else {
873		var code int
874		if respd.Code != "" {
875			code, err = strconv.Atoi(respd.Code)
876			if err != nil {
877				code = -1
878			}
879		} else {
880			code = -1
881		}
882		errChannel <- &SnowflakeError{
883			Number:   code,
884			SQLState: respd.Data.SQLState,
885			Message:  respd.Message,
886			QueryID:  respd.Data.QueryID,
887		}
888	}
889}
890
891func getQueryIDChan(ctx context.Context) chan<- string {
892	v := ctx.Value(queryIDChannel)
893	if v == nil {
894		return nil
895	}
896	c, _ := v.(chan<- string)
897	return c
898}
899
900func getFileStream(ctx context.Context) *bytes.Buffer {
901	s := ctx.Value(fileStreamFile)
902	r, ok := s.(io.Reader)
903	if !ok {
904		return nil
905	}
906	buf := new(bytes.Buffer)
907	buf.ReadFrom(r)
908	return buf
909}
910
911func getFileTransferOptions(ctx context.Context) *SnowflakeFileTransferOptions {
912	v := ctx.Value(fileTransferOptions)
913	if v == nil {
914		return nil
915	}
916	o, ok := v.(*SnowflakeFileTransferOptions)
917	if !ok {
918		return nil
919	}
920	return o
921}
922
923func isDescribeOnly(ctx context.Context) bool {
924	v := ctx.Value(describeOnly)
925	if v == nil {
926		return false
927	}
928	d, ok := v.(bool)
929	return ok && d
930}
931
932// returns snowflake chunk downloader by default or stream based chunk
933// downloader if option provided through context
934func populateChunkDownloader(ctx context.Context, sc *snowflakeConn, data execResponseData) chunkDownloader {
935	if useStreamDownloader(ctx) && data.QueryResultFormat == "json" {
936		// stream chunk downloading only works for row based data formats, i.e. json
937		fetcher := &httpStreamChunkFetcher{
938			ctx:      ctx,
939			client:   sc.rest.Client,
940			clientIP: sc.cfg.ClientIP,
941			headers:  data.ChunkHeaders,
942			qrmk:     data.Qrmk,
943		}
944		return newStreamChunkDownloader(ctx, fetcher, data.Total, data.RowType, data.RowSet, data.Chunks)
945	}
946
947	return &snowflakeChunkDownloader{
948		sc:                 sc,
949		ctx:                ctx,
950		CurrentChunk:       make([]chunkRowType, len(data.RowSet)),
951		ChunkMetas:         data.Chunks,
952		Total:              data.Total,
953		TotalRowIndex:      int64(-1),
954		CellCount:          len(data.RowType),
955		Qrmk:               data.Qrmk,
956		QueryResultFormat:  data.QueryResultFormat,
957		ChunkHeader:        data.ChunkHeaders,
958		FuncDownload:       downloadChunk,
959		FuncDownloadHelper: downloadChunkHelper,
960		FuncGet:            getChunk,
961		RowSet: rowSetType{
962			RowType:      data.RowType,
963			JSON:         data.RowSet,
964			RowSetBase64: data.RowSetBase64,
965		},
966	}
967}
968
969func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, error) {
970	sc := &snowflakeConn{
971		SequenceCounter: 0,
972		ctx:             ctx,
973		cfg:             &config,
974	}
975	var st http.RoundTripper = SnowflakeTransport
976	if sc.cfg.Transporter == nil {
977		if sc.cfg.InsecureMode {
978			// no revocation check with OCSP. Think twice when you want to enable this option.
979			st = snowflakeInsecureTransport
980		} else {
981			// set OCSP fail open mode
982			ocspResponseCacheLock.Lock()
983			atomic.StoreUint32((*uint32)(&ocspFailOpen), uint32(sc.cfg.OCSPFailOpen))
984			ocspResponseCacheLock.Unlock()
985		}
986	} else {
987		// use the custom transport
988		st = sc.cfg.Transporter
989	}
990	var tokenAccessor TokenAccessor
991	if sc.cfg.TokenAccessor != nil {
992		tokenAccessor = sc.cfg.TokenAccessor
993	} else {
994		tokenAccessor = getSimpleTokenAccessor()
995	}
996	if sc.cfg.DisableTelemetry {
997		sc.telemetry.enabled = false
998	}
999
1000	// authenticate
1001	sc.rest = &snowflakeRestful{
1002		Host:     sc.cfg.Host,
1003		Port:     sc.cfg.Port,
1004		Protocol: sc.cfg.Protocol,
1005		Client: &http.Client{
1006			// request timeout including reading response body
1007			Timeout:   sc.cfg.ClientTimeout,
1008			Transport: st,
1009		},
1010		TokenAccessor:       tokenAccessor,
1011		LoginTimeout:        sc.cfg.LoginTimeout,
1012		RequestTimeout:      sc.cfg.RequestTimeout,
1013		FuncPost:            postRestful,
1014		FuncGet:             getRestful,
1015		FuncPostQuery:       postRestfulQuery,
1016		FuncPostQueryHelper: postRestfulQueryHelper,
1017		FuncRenewSession:    renewRestfulSession,
1018		FuncPostAuth:        postAuth,
1019		FuncCloseSession:    closeSession,
1020		FuncCancelQuery:     cancelQuery,
1021		FuncPostAuthSAML:    postAuthSAML,
1022		FuncPostAuthOKTA:    postAuthOKTA,
1023		FuncGetSSO:          getSSO,
1024	}
1025	sc.telemetry = &snowflakeTelemetry{
1026		flushSize: defaultFlushSize,
1027		sr:        sc.rest,
1028		mutex:     &sync.Mutex{},
1029		enabled:   true,
1030	}
1031	return sc, nil
1032}
1033