1// Copyright (c) 2017-2021 Snowflake Computing Inc. All right reserved. 2 3package gosnowflake 4 5import ( 6 "bytes" 7 "context" 8 "database/sql" 9 "database/sql/driver" 10 "encoding/json" 11 "fmt" 12 "io" 13 "net/http" 14 "net/url" 15 "regexp" 16 "strconv" 17 "strings" 18 "sync" 19 "sync/atomic" 20 "time" 21 22 "github.com/google/uuid" 23) 24 25const ( 26 httpHeaderContentType = "Content-Type" 27 httpHeaderAccept = "accept" 28 httpHeaderUserAgent = "User-Agent" 29 httpHeaderServiceName = "X-Snowflake-Service" 30 httpHeaderContentLength = "Content-Length" 31 httpHeaderHost = "Host" 32 httpHeaderValueOctetStream = "application/octet-stream" 33 httpHeaderContentEncoding = "Content-Encoding" 34) 35 36const ( 37 statementTypeIDMulti = int64(0x1000) 38 39 statementTypeIDDml = int64(0x3000) 40 statementTypeIDInsert = statementTypeIDDml + int64(0x100) 41 statementTypeIDUpdate = statementTypeIDDml + int64(0x200) 42 statementTypeIDDelete = statementTypeIDDml + int64(0x300) 43 statementTypeIDMerge = statementTypeIDDml + int64(0x400) 44 statementTypeIDMultiTableInsert = statementTypeIDDml + int64(0x500) 45) 46 47const ( 48 sessionClientSessionKeepAlive = "client_session_keep_alive" 49 sessionClientValidateDefaultParameters = "CLIENT_VALIDATE_DEFAULT_PARAMETERS" 50 sessionArrayBindStageThreshold = "client_stage_array_binding_threshold" 51 serviceName = "service_name" 52) 53 54type resultType string 55 56const ( 57 snowflakeResultType contextKey = "snowflakeResultType" 58 execResultType resultType = "exec" 59 queryResultType resultType = "query" 60) 61 62type snowflakeConn struct { 63 ctx context.Context 64 cfg *Config 65 rest *snowflakeRestful 66 SequenceCounter uint64 67 QueryID string 68 SQLState string 69 telemetry *snowflakeTelemetry 70 internal InternalClient 71} 72 73var queryIDPattern = `[\w\-_]+` 74var queryIDRegexp = regexp.MustCompile(queryIDPattern) 75 76const ( 77 urlQueriesResultFmt string = "/queries/%s/result" 78) 79 80// isDml returns true if the statement type code is in the range of DML. 81func (sc *snowflakeConn) isDml(v int64) bool { 82 return statementTypeIDDml <= v && v <= statementTypeIDMultiTableInsert 83} 84 85// isMultiStmt returns true if the statement type code is of type multistatement 86// Note that the statement type code is also equivalent to type INSERT, so an additional check of the name is required 87func (sc *snowflakeConn) isMultiStmt(data *execResponseData) bool { 88 return data.StatementTypeID == statementTypeIDMulti && data.RowType[0].Name == "multiple statement execution" 89} 90 91func (sc *snowflakeConn) exec( 92 ctx context.Context, 93 query string, 94 noResult bool, 95 isInternal bool, 96 describeOnly bool, 97 bindings []driver.NamedValue) ( 98 *execResponse, error) { 99 var err error 100 counter := atomic.AddUint64(&sc.SequenceCounter, 1) // query sequence counter 101 102 req := execRequest{ 103 SQLText: query, 104 AsyncExec: noResult, 105 Parameters: map[string]interface{}{}, 106 IsInternal: isInternal, 107 DescribeOnly: describeOnly, 108 SequenceID: counter, 109 } 110 if key := ctx.Value(multiStatementCount); key != nil { 111 req.Parameters[string(multiStatementCount)] = key 112 } 113 logger.WithContext(ctx).Infof("parameters: %v", req.Parameters) 114 115 requestID := getOrGenerateRequestIDFromContext(ctx) 116 if len(bindings) > 0 { 117 arrayBindThreshold := sc.getArrayBindStageThreshold() 118 numBinds := arrayBindValueCount(bindings) 119 if 0 < arrayBindThreshold && arrayBindThreshold <= numBinds && !describeOnly && isArrayBind(bindings) { 120 // bulk array insert binding 121 uploader := bindUploader{ 122 sc: sc, 123 ctx: ctx, 124 stagePath: "@" + bindStageName + "/" + requestID.String(), 125 } 126 uploader.upload(bindings) 127 req.Bindings = nil 128 req.BindStage = uploader.stagePath 129 } else { 130 // variable or array binding 131 req.Bindings, err = getBindValues(bindings) 132 if err != nil { 133 return nil, err 134 } 135 req.BindStage = "" 136 } 137 } 138 logger.WithContext(ctx).Infof("bindings: %v", req.Bindings) 139 140 headers := getHeaders() 141 if isFileTransfer(query) { 142 headers[httpHeaderAccept] = headerContentTypeApplicationJSON 143 } 144 if serviceName, ok := sc.cfg.Params[serviceName]; ok { 145 headers[httpHeaderServiceName] = *serviceName 146 } 147 148 jsonBody, err := json.Marshal(req) 149 if err != nil { 150 return nil, err 151 } 152 153 var data *execResponse 154 data, err = sc.rest.FuncPostQuery(ctx, sc.rest, &url.Values{}, headers, jsonBody, sc.rest.RequestTimeout, requestID, sc.cfg) 155 if err != nil { 156 return data, err 157 } 158 var code int 159 if data.Code != "" { 160 code, err = strconv.Atoi(data.Code) 161 if err != nil { 162 code = -1 163 return data, err 164 } 165 } else { 166 code = -1 167 } 168 logger.WithContext(ctx).Infof("Success: %v, Code: %v", data.Success, code) 169 if !data.Success { 170 return nil, &SnowflakeError{ 171 Number: code, 172 SQLState: data.Data.SQLState, 173 Message: data.Message, 174 QueryID: data.Data.QueryID, 175 } 176 } 177 if isFileTransfer(query) { 178 sfa := snowflakeFileTransferAgent{ 179 sc: sc, 180 data: &data.Data, 181 command: query, 182 options: new(SnowflakeFileTransferOptions), 183 } 184 if fs := getFileStream(ctx); fs != nil { 185 sfa.sourceStream = fs 186 if isInternal { 187 sfa.data.AutoCompress = false 188 } 189 } 190 if op := getFileTransferOptions(ctx); op != nil { 191 sfa.options = op 192 } 193 if sfa.options.multiPartThreshold == 0 { 194 sfa.options.multiPartThreshold = dataSizeThreshold 195 } 196 if err = sfa.execute(); err != nil { 197 return nil, err 198 } 199 data, err = sfa.result() 200 if err != nil { 201 return nil, err 202 } 203 } 204 205 logger.WithContext(ctx).Info("Exec/Query SUCCESS") 206 sc.cfg.Database = data.Data.FinalDatabaseName 207 sc.cfg.Schema = data.Data.FinalSchemaName 208 sc.cfg.Role = data.Data.FinalRoleName 209 sc.cfg.Warehouse = data.Data.FinalWarehouseName 210 sc.QueryID = data.Data.QueryID 211 sc.SQLState = data.Data.SQLState 212 sc.populateSessionParameters(data.Data.Parameters) 213 return data, err 214} 215 216func (sc *snowflakeConn) Begin() (driver.Tx, error) { 217 return sc.BeginTx(sc.ctx, driver.TxOptions{}) 218} 219 220func (sc *snowflakeConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { 221 logger.WithContext(ctx).Info("BeginTx") 222 if opts.ReadOnly { 223 return nil, &SnowflakeError{ 224 Number: ErrNoReadOnlyTransaction, 225 SQLState: SQLStateFeatureNotSupported, 226 Message: errMsgNoReadOnlyTransaction, 227 } 228 } 229 if int(opts.Isolation) != int(sql.LevelDefault) { 230 return nil, &SnowflakeError{ 231 Number: ErrNoDefaultTransactionIsolationLevel, 232 SQLState: SQLStateFeatureNotSupported, 233 Message: errMsgNoDefaultTransactionIsolationLevel, 234 } 235 } 236 if sc.rest == nil { 237 return nil, driver.ErrBadConn 238 } 239 isDesc := isDescribeOnly(ctx) 240 _, err := sc.exec(ctx, "BEGIN", false /* noResult */, false /* isInternal */, isDesc, nil) 241 if err != nil { 242 return nil, err 243 } 244 return &snowflakeTx{sc}, err 245} 246 247func (sc *snowflakeConn) cleanup() { 248 // must flush log buffer while the process is running. 249 sc.rest = nil 250 sc.cfg = nil 251} 252 253func (sc *snowflakeConn) Close() (err error) { 254 logger.WithContext(sc.ctx).Infoln("Close") 255 sc.telemetry.sendBatch() 256 sc.stopHeartBeat() 257 258 if !sc.cfg.KeepSessionAlive { 259 err = sc.rest.FuncCloseSession(sc.ctx, sc.rest, sc.rest.RequestTimeout) 260 if err != nil { 261 logger.Error(err) 262 } 263 } 264 sc.cleanup() 265 return nil 266} 267 268func (sc *snowflakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 269 logger.WithContext(sc.ctx).Infoln("Prepare") 270 if sc.rest == nil { 271 return nil, driver.ErrBadConn 272 } 273 noResult := isAsyncMode(ctx) 274 data, err := sc.exec(ctx, query, noResult, false /* isInternal */, true /* describeOnly */, []driver.NamedValue{}) 275 if err != nil { 276 if data != nil { 277 code, err := strconv.Atoi(data.Code) 278 if err != nil { 279 return nil, err 280 } 281 return nil, &SnowflakeError{ 282 Number: code, 283 SQLState: data.Data.SQLState, 284 Message: err.Error(), 285 QueryID: data.Data.QueryID, 286 } 287 } 288 return nil, err 289 } 290 stmt := &snowflakeStmt{ 291 sc: sc, 292 query: query, 293 } 294 return stmt, nil 295} 296 297func (sc *snowflakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 298 logger.WithContext(ctx).Infof("Exec: %#v, %v", query, args) 299 if sc.rest == nil { 300 return nil, driver.ErrBadConn 301 } 302 noResult := isAsyncMode(ctx) 303 isDesc := isDescribeOnly(ctx) 304 // TODO handle isInternal 305 ctx = setResultType(ctx, execResultType) 306 data, err := sc.exec(ctx, query, noResult, false /* isInternal */, isDesc, args) 307 if err != nil { 308 logger.WithContext(ctx).Infof("error: %v", err) 309 if data != nil { 310 code, err := strconv.Atoi(data.Code) 311 if err != nil { 312 return nil, err 313 } 314 return nil, &SnowflakeError{ 315 Number: code, 316 SQLState: data.Data.SQLState, 317 Message: err.Error(), 318 QueryID: data.Data.QueryID} 319 } 320 return nil, err 321 } 322 323 // if async exec, return result object right away 324 if noResult { 325 return data.Data.AsyncResult, nil 326 } 327 328 if sc.isDml(data.Data.StatementTypeID) { 329 // collects all values from the returned row sets 330 updatedRows, err := updateRows(data.Data) 331 if err != nil { 332 return nil, err 333 } 334 logger.WithContext(ctx).Debugf("number of updated rows: %#v", updatedRows) 335 return &snowflakeResult{ 336 affectedRows: updatedRows, 337 insertID: -1, 338 queryID: sc.QueryID, 339 }, nil // last insert id is not supported by Snowflake 340 } else if sc.isMultiStmt(&data.Data) { 341 return sc.handleMultiExec(ctx, data.Data) 342 } 343 logger.Debug("DDL") 344 return driver.ResultNoRows, nil 345} 346 347func (sc *snowflakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 348 qid, err := getResumeQueryID(ctx) 349 if err != nil { 350 return nil, err 351 } 352 if qid == "" { 353 return sc.queryContextInternal(ctx, query, args) 354 } 355 356 // first we will check the status of this particular query to find out if there is result to fetch 357 err = sc.checkQueryStatus(ctx, qid) 358 if err == nil || (err != nil && err.(*SnowflakeError).Number == ErrQueryIsRunning) { 359 // the query is running. Rows object will be returned from here. 360 return sc.buildRowsForRunningQuery(ctx, qid) 361 } 362 363 return nil, err 364} 365 366func (sc *snowflakeConn) queryContextInternal(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 367 logger.WithContext(ctx).Infof("Query: %#v, %v", query, args) 368 if sc.rest == nil { 369 return nil, driver.ErrBadConn 370 } 371 372 noResult := isAsyncMode(ctx) 373 isDesc := isDescribeOnly(ctx) 374 ctx = setResultType(ctx, queryResultType) 375 // TODO: handle isInternal 376 data, err := sc.exec(ctx, query, noResult, false /* isInternal */, isDesc, args) 377 if err != nil { 378 logger.WithContext(ctx).Errorf("error: %v", err) 379 if data != nil { 380 code, err := strconv.Atoi(data.Code) 381 if err != nil { 382 return nil, err 383 } 384 return nil, &SnowflakeError{ 385 Number: code, 386 SQLState: data.Data.SQLState, 387 Message: err.Error(), 388 QueryID: data.Data.QueryID} 389 } 390 return nil, err 391 } 392 393 // if async query, return row object right away 394 if noResult { 395 return data.Data.AsyncRows, nil 396 } 397 398 rows := new(snowflakeRows) 399 rows.sc = sc 400 rows.queryID = sc.QueryID 401 402 if sc.isMultiStmt(&data.Data) { 403 // handleMultiQuery is responsible to fill rows with childResults 404 if err = sc.handleMultiQuery(ctx, data.Data, rows); err != nil { 405 return nil, err 406 } 407 } else { 408 rows.addDownloader(populateChunkDownloader(ctx, sc, data.Data)) 409 } 410 411 rows.ChunkDownloader.start() 412 return rows, err 413} 414 415func (sc *snowflakeConn) Prepare(query string) (driver.Stmt, error) { 416 return sc.PrepareContext(sc.ctx, query) 417} 418 419func (sc *snowflakeConn) Exec( 420 query string, 421 args []driver.Value) ( 422 driver.Result, error) { 423 return sc.ExecContext(sc.ctx, query, toNamedValues(args)) 424} 425 426func (sc *snowflakeConn) Query( 427 query string, 428 args []driver.Value) ( 429 driver.Rows, error) { 430 return sc.QueryContext(sc.ctx, query, toNamedValues(args)) 431} 432 433func (sc *snowflakeConn) Ping(ctx context.Context) error { 434 logger.WithContext(ctx).Infoln("Ping") 435 if sc.rest == nil { 436 return driver.ErrBadConn 437 } 438 noResult := isAsyncMode(ctx) 439 isDesc := isDescribeOnly(ctx) 440 // TODO: handle isInternal 441 _, err := sc.exec(ctx, "SELECT 1", noResult, false /* isInternal */, isDesc, []driver.NamedValue{}) 442 return err 443} 444 445// CheckNamedValue determines which types are handled by this driver aside from 446// the instances captured by driver.Value 447func (sc *snowflakeConn) CheckNamedValue(nv *driver.NamedValue) error { 448 if supported := supportedArrayBind(nv); !supported { 449 return driver.ErrSkip 450 } 451 return nil 452} 453 454func (sc *snowflakeConn) populateSessionParameters(parameters []nameValueParameter) { 455 // other session parameters (not all) 456 logger.WithContext(sc.ctx).Infof("params: %#v", parameters) 457 for _, param := range parameters { 458 v := "" 459 switch param.Value.(type) { 460 case int64: 461 if vv, ok := param.Value.(int64); ok { 462 v = strconv.FormatInt(vv, 10) 463 } 464 case float64: 465 if vv, ok := param.Value.(float64); ok { 466 v = strconv.FormatFloat(vv, 'g', -1, 64) 467 } 468 case bool: 469 if vv, ok := param.Value.(bool); ok { 470 v = strconv.FormatBool(vv) 471 } 472 default: 473 if vv, ok := param.Value.(string); ok { 474 v = vv 475 } 476 } 477 logger.Debugf("parameter. name: %v, value: %v", param.Name, v) 478 sc.cfg.Params[strings.ToLower(param.Name)] = &v 479 } 480} 481 482func (sc *snowflakeConn) isClientSessionKeepAliveEnabled() bool { 483 v, ok := sc.cfg.Params[sessionClientSessionKeepAlive] 484 if !ok { 485 return false 486 } 487 return strings.Compare(*v, "true") == 0 488} 489 490func (sc *snowflakeConn) getArrayBindStageThreshold() int { 491 v, ok := sc.cfg.Params[sessionArrayBindStageThreshold] 492 if !ok { 493 return 0 494 } 495 num, err := strconv.Atoi(*v) 496 if err != nil { 497 return 0 498 } 499 return num 500} 501 502func (sc *snowflakeConn) startHeartBeat() { 503 if !sc.isClientSessionKeepAliveEnabled() { 504 return 505 } 506 sc.rest.HeartBeat = &heartbeat{ 507 restful: sc.rest, 508 } 509 sc.rest.HeartBeat.start() 510} 511 512func (sc *snowflakeConn) stopHeartBeat() { 513 if !sc.isClientSessionKeepAliveEnabled() { 514 return 515 } 516 sc.rest.HeartBeat.stop() 517} 518 519func (sc *snowflakeConn) handleMultiExec(ctx context.Context, data execResponseData) (driver.Result, error) { 520 var updatedRows int64 521 childResults := getChildResults(data.ResultIDs, data.ResultTypes) 522 for _, child := range childResults { 523 resultPath := fmt.Sprintf(urlQueriesResultFmt, child.id) 524 childData, err := sc.getQueryResultResp(ctx, resultPath) 525 if err != nil { 526 logger.Errorf("error: %v", err) 527 code, err := strconv.Atoi(childData.Code) 528 if err != nil { 529 return nil, err 530 } 531 if childData != nil { 532 return nil, &SnowflakeError{ 533 Number: code, 534 SQLState: childData.Data.SQLState, 535 Message: err.Error(), 536 QueryID: childData.Data.QueryID} 537 } 538 return nil, err 539 } 540 if sc.isDml(childData.Data.StatementTypeID) { 541 count, err := updateRows(childData.Data) 542 if err != nil { 543 logger.WithContext(ctx).Errorf("error: %v", err) 544 if childData != nil { 545 code, err := strconv.Atoi(childData.Code) 546 if err != nil { 547 return nil, err 548 } 549 return nil, &SnowflakeError{ 550 Number: code, 551 SQLState: childData.Data.SQLState, 552 Message: err.Error(), 553 QueryID: childData.Data.QueryID} 554 } 555 return nil, err 556 } 557 updatedRows += count 558 } 559 } 560 logger.WithContext(ctx).Infof("number of updated rows: %#v", updatedRows) 561 return &snowflakeResult{ 562 affectedRows: updatedRows, 563 insertID: -1, 564 queryID: sc.QueryID, 565 }, nil 566} 567 568// Fill the correspondent rows and add chunk downloader into the rows when iterate the childResults 569func (sc *snowflakeConn) handleMultiQuery(ctx context.Context, data execResponseData, rows *snowflakeRows) error { 570 childResults := getChildResults(data.ResultIDs, data.ResultTypes) 571 572 for _, child := range childResults { 573 if err := sc.rowsForRunningQuery(ctx, child.id, rows); err != nil { 574 return err 575 } 576 } 577 return nil 578} 579 580func setResultType(ctx context.Context, resType resultType) context.Context { 581 return context.WithValue(ctx, snowflakeResultType, resType) 582} 583 584func getResultType(ctx context.Context) resultType { 585 return ctx.Value(snowflakeResultType).(resultType) 586} 587 588func updateRows(data execResponseData) (int64, error) { 589 var count int64 590 for i, n := 0, len(data.RowType); i < n; i++ { 591 v, err := strconv.ParseInt(*data.RowSet[0][i], 10, 64) 592 if err != nil { 593 return -1, err 594 } 595 count += v 596 } 597 return count, nil 598} 599 600type childResult struct { 601 id string 602 typ string 603} 604 605func getChildResults(IDs string, types string) []childResult { 606 if IDs == "" { 607 return nil 608 } 609 queryIDs := strings.Split(IDs, ",") 610 resultTypes := strings.Split(types, ",") 611 res := make([]childResult, len(queryIDs)) 612 for i, id := range queryIDs { 613 res[i] = childResult{id, resultTypes[i]} 614 } 615 return res 616} 617 618func (sc *snowflakeConn) getQueryResultResp(ctx context.Context, resultPath string) (*execResponse, error) { 619 headers := getHeaders() 620 if serviceName, ok := sc.cfg.Params[serviceName]; ok { 621 headers[httpHeaderServiceName] = *serviceName 622 } 623 param := make(url.Values) 624 param.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String()) 625 param.Add("clientStartTime", strconv.FormatInt(time.Now().Unix(), 10)) 626 param.Add(requestGUIDKey, uuid.New().String()) 627 token, _, _ := sc.rest.TokenAccessor.GetTokens() 628 if token != "" { 629 headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) 630 } 631 url := sc.rest.getFullURL(resultPath, ¶m) 632 res, err := sc.rest.FuncGet(ctx, sc.rest, url, headers, sc.rest.RequestTimeout) 633 if err != nil { 634 logger.WithContext(ctx).Errorf("failed to get response. err: %v", err) 635 return nil, err 636 } 637 var respd *execResponse 638 err = json.NewDecoder(res.Body).Decode(&respd) 639 if err != nil { 640 logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) 641 return nil, err 642 } 643 return respd, nil 644} 645 646// checkQueryStatus returns the status given the query ID. If successful, 647// the error will be nil, indicating there is a complete query result to fetch. 648// Other than nil, there are three error types that can be returned: 649// 1. ErrQueryStatus, if GS cannot return any kind of status due to any reason, 650// i.e. connection, permission, if a query was just submitted, etc. 651// 2, ErrQueryReportedError, if the requested query was terminated or aborted 652// and GS returned an error status included in query. SFQueryFailedWithError 653// 3, ErrQueryIsRunning, if the requested query is still running and might have 654// a complete result later, these statuses were listed in query. SFQueryRunning 655func (sc *snowflakeConn) checkQueryStatus(ctx context.Context, qid string) error { 656 headers := make(map[string]string) 657 param := make(url.Values) 658 param.Add(requestGUIDKey, uuid.New().String()) 659 if tok, _, _ := sc.rest.TokenAccessor.GetTokens(); tok != "" { 660 headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, tok) 661 } 662 resultPath := fmt.Sprintf("/monitoring/queries/%s", qid) 663 url := sc.rest.getFullURL(resultPath, ¶m) 664 665 res, err := sc.rest.FuncGet(ctx, sc.rest, url, headers, sc.rest.RequestTimeout) 666 if err != nil { 667 logger.WithContext(ctx).Errorf("failed to get response. err: %v", err) 668 return err 669 } 670 var statusResp = statusResponse{} 671 672 err = json.NewDecoder(res.Body).Decode(&statusResp) 673 if err != nil { 674 logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) 675 return err 676 } 677 678 if !statusResp.Success || len(statusResp.Data.Queries) == 0 { 679 logger.WithContext(ctx).Errorf("status query returned not-success or no status returned.") 680 return &SnowflakeError{ 681 Number: ErrQueryStatus, 682 Message: "status query returned not-success or no status returned. Please retry"} 683 } 684 685 var queryRet = statusResp.Data.Queries[0] 686 if queryRet.ErrorCode != 0 { 687 return &SnowflakeError{ 688 Number: ErrQueryStatus, 689 Message: fmt.Sprintf("server ErrorCode=%d, ErrorMessage=%s", 690 queryRet.ErrorCode, queryRet.ErrorMessage), 691 IncludeQueryID: true, 692 QueryID: qid, 693 } 694 } 695 696 // returned errorCode is 0. Now check what is the returned status of the query. 697 var qstatus = strToSFQueryStatus(queryRet.Status) 698 if sfqStatusIsAnError(qstatus) { 699 return &SnowflakeError{ 700 Number: ErrQueryReportedError, 701 Message: fmt.Sprintf("%s: status from server: [%s]", 702 queryRet.ErrorMessage, queryRet.Status), 703 IncludeQueryID: true, 704 QueryID: qid, 705 } 706 } 707 708 if sfqStatusIsStillRunning(qstatus) { 709 return &SnowflakeError{ 710 Number: ErrQueryIsRunning, 711 Message: fmt.Sprintf("%s: status from server: [%s]", 712 queryRet.ErrorMessage, queryRet.Status), 713 IncludeQueryID: true, 714 QueryID: qid, 715 } 716 } 717 718 //success 719 return nil 720} 721 722// Fetch query result for a query id from /queries/<qid>/result endpoint. 723func (sc *snowflakeConn) rowsForRunningQuery(ctx context.Context, qid string, rows *snowflakeRows) error { 724 resultPath := fmt.Sprintf(urlQueriesResultFmt, qid) 725 resp, err := sc.getQueryResultResp(ctx, resultPath) 726 if err != nil { 727 logger.WithContext(ctx).Errorf("error: %v", err) 728 if resp != nil { 729 code, err := strconv.Atoi(resp.Code) 730 if err != nil { 731 return err 732 } 733 return &SnowflakeError{ 734 Number: code, 735 SQLState: resp.Data.SQLState, 736 Message: err.Error(), 737 QueryID: resp.Data.QueryID} 738 } 739 return err 740 } 741 rows.addDownloader(populateChunkDownloader(ctx, sc, resp.Data)) 742 return nil 743} 744 745// prepare a Rows object to return for query of 'qid' 746func (sc *snowflakeConn) buildRowsForRunningQuery(ctx context.Context, qid string) (driver.Rows, error) { 747 rows := new(snowflakeRows) 748 rows.sc = sc 749 rows.queryID = qid 750 err := sc.rowsForRunningQuery(ctx, qid, rows) 751 if err != nil { 752 return nil, err 753 } 754 rows.ChunkDownloader.start() 755 return rows, err 756} 757 758func isAsyncMode(ctx context.Context) bool { 759 val := ctx.Value(asyncMode) 760 if val == nil { 761 return false 762 } 763 a, ok := val.(bool) 764 return ok && a 765} 766 767func getResumeQueryID(ctx context.Context) (string, error) { 768 val := ctx.Value(fetchResultByID) 769 if val == nil { 770 return "", nil 771 } 772 strVal, ok := val.(string) 773 if !ok { 774 return "", fmt.Errorf("failed to cast val %+v to string", val) 775 } 776 // so there is a queryID in context for which we want to fetch the result 777 if !queryIDRegexp.MatchString(strVal) { 778 return strVal, &SnowflakeError{ 779 Number: ErrQueryIDFormat, 780 Message: "Invalid QID", 781 QueryID: strVal} 782 } 783 return strVal, nil 784} 785 786func getAsync( 787 ctx context.Context, 788 sr *snowflakeRestful, 789 headers map[string]string, 790 URL *url.URL, 791 timeout time.Duration, 792 res *snowflakeResult, 793 rows *snowflakeRows, 794 cfg *Config, 795) { 796 resType := getResultType(ctx) 797 var errChannel chan error 798 sfError := &SnowflakeError{ 799 Number: -1, 800 } 801 if resType == execResultType { 802 errChannel = res.errChannel 803 sfError.QueryID = res.queryID 804 } else { 805 errChannel = rows.errChannel 806 sfError.QueryID = rows.queryID 807 } 808 defer close(errChannel) 809 token, _, _ := sr.TokenAccessor.GetTokens() 810 headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) 811 resp, err := sr.FuncGet(ctx, sr, URL, headers, timeout) 812 if err != nil { 813 logger.WithContext(ctx).Errorf("failed to get response. err: %v", err) 814 sfError.Message = err.Error() 815 errChannel <- sfError 816 close(errChannel) 817 return 818 } 819 if resp.Body != nil { 820 defer resp.Body.Close() 821 } 822 823 respd := execResponse{} 824 err = json.NewDecoder(resp.Body).Decode(&respd) 825 resp.Body.Close() 826 if err != nil { 827 logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) 828 sfError.Message = err.Error() 829 errChannel <- sfError 830 close(errChannel) 831 return 832 } 833 834 sc := &snowflakeConn{rest: sr, cfg: cfg} 835 if respd.Success { 836 if resType == execResultType { 837 res.insertID = -1 838 if sc.isDml(respd.Data.StatementTypeID) { 839 res.affectedRows, _ = updateRows(respd.Data) 840 } else if sc.isMultiStmt(&respd.Data) { 841 r, err := sc.handleMultiExec(ctx, respd.Data) 842 if err != nil { 843 res.errChannel <- err 844 close(errChannel) 845 return 846 } 847 res.affectedRows, err = r.RowsAffected() 848 if err != nil { 849 res.errChannel <- err 850 close(errChannel) 851 return 852 } 853 } 854 res.queryID = respd.Data.QueryID 855 res.errChannel <- nil // mark exec status complete 856 } else { 857 rows.sc = sc 858 rows.queryID = respd.Data.QueryID 859 if sc.isMultiStmt(&respd.Data) { 860 err = sc.handleMultiQuery(ctx, respd.Data, rows) 861 if err != nil { 862 rows.errChannel <- err 863 close(errChannel) 864 return 865 } 866 } else { 867 rows.addDownloader(populateChunkDownloader(ctx, sc, respd.Data)) 868 } 869 rows.ChunkDownloader.start() 870 rows.errChannel <- nil // mark query status complete 871 } 872 } else { 873 var code int 874 if respd.Code != "" { 875 code, err = strconv.Atoi(respd.Code) 876 if err != nil { 877 code = -1 878 } 879 } else { 880 code = -1 881 } 882 errChannel <- &SnowflakeError{ 883 Number: code, 884 SQLState: respd.Data.SQLState, 885 Message: respd.Message, 886 QueryID: respd.Data.QueryID, 887 } 888 } 889} 890 891func getQueryIDChan(ctx context.Context) chan<- string { 892 v := ctx.Value(queryIDChannel) 893 if v == nil { 894 return nil 895 } 896 c, _ := v.(chan<- string) 897 return c 898} 899 900func getFileStream(ctx context.Context) *bytes.Buffer { 901 s := ctx.Value(fileStreamFile) 902 r, ok := s.(io.Reader) 903 if !ok { 904 return nil 905 } 906 buf := new(bytes.Buffer) 907 buf.ReadFrom(r) 908 return buf 909} 910 911func getFileTransferOptions(ctx context.Context) *SnowflakeFileTransferOptions { 912 v := ctx.Value(fileTransferOptions) 913 if v == nil { 914 return nil 915 } 916 o, ok := v.(*SnowflakeFileTransferOptions) 917 if !ok { 918 return nil 919 } 920 return o 921} 922 923func isDescribeOnly(ctx context.Context) bool { 924 v := ctx.Value(describeOnly) 925 if v == nil { 926 return false 927 } 928 d, ok := v.(bool) 929 return ok && d 930} 931 932// returns snowflake chunk downloader by default or stream based chunk 933// downloader if option provided through context 934func populateChunkDownloader(ctx context.Context, sc *snowflakeConn, data execResponseData) chunkDownloader { 935 if useStreamDownloader(ctx) && data.QueryResultFormat == "json" { 936 // stream chunk downloading only works for row based data formats, i.e. json 937 fetcher := &httpStreamChunkFetcher{ 938 ctx: ctx, 939 client: sc.rest.Client, 940 clientIP: sc.cfg.ClientIP, 941 headers: data.ChunkHeaders, 942 qrmk: data.Qrmk, 943 } 944 return newStreamChunkDownloader(ctx, fetcher, data.Total, data.RowType, data.RowSet, data.Chunks) 945 } 946 947 return &snowflakeChunkDownloader{ 948 sc: sc, 949 ctx: ctx, 950 CurrentChunk: make([]chunkRowType, len(data.RowSet)), 951 ChunkMetas: data.Chunks, 952 Total: data.Total, 953 TotalRowIndex: int64(-1), 954 CellCount: len(data.RowType), 955 Qrmk: data.Qrmk, 956 QueryResultFormat: data.QueryResultFormat, 957 ChunkHeader: data.ChunkHeaders, 958 FuncDownload: downloadChunk, 959 FuncDownloadHelper: downloadChunkHelper, 960 FuncGet: getChunk, 961 RowSet: rowSetType{ 962 RowType: data.RowType, 963 JSON: data.RowSet, 964 RowSetBase64: data.RowSetBase64, 965 }, 966 } 967} 968 969func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, error) { 970 sc := &snowflakeConn{ 971 SequenceCounter: 0, 972 ctx: ctx, 973 cfg: &config, 974 } 975 var st http.RoundTripper = SnowflakeTransport 976 if sc.cfg.Transporter == nil { 977 if sc.cfg.InsecureMode { 978 // no revocation check with OCSP. Think twice when you want to enable this option. 979 st = snowflakeInsecureTransport 980 } else { 981 // set OCSP fail open mode 982 ocspResponseCacheLock.Lock() 983 atomic.StoreUint32((*uint32)(&ocspFailOpen), uint32(sc.cfg.OCSPFailOpen)) 984 ocspResponseCacheLock.Unlock() 985 } 986 } else { 987 // use the custom transport 988 st = sc.cfg.Transporter 989 } 990 var tokenAccessor TokenAccessor 991 if sc.cfg.TokenAccessor != nil { 992 tokenAccessor = sc.cfg.TokenAccessor 993 } else { 994 tokenAccessor = getSimpleTokenAccessor() 995 } 996 if sc.cfg.DisableTelemetry { 997 sc.telemetry.enabled = false 998 } 999 1000 // authenticate 1001 sc.rest = &snowflakeRestful{ 1002 Host: sc.cfg.Host, 1003 Port: sc.cfg.Port, 1004 Protocol: sc.cfg.Protocol, 1005 Client: &http.Client{ 1006 // request timeout including reading response body 1007 Timeout: sc.cfg.ClientTimeout, 1008 Transport: st, 1009 }, 1010 TokenAccessor: tokenAccessor, 1011 LoginTimeout: sc.cfg.LoginTimeout, 1012 RequestTimeout: sc.cfg.RequestTimeout, 1013 FuncPost: postRestful, 1014 FuncGet: getRestful, 1015 FuncPostQuery: postRestfulQuery, 1016 FuncPostQueryHelper: postRestfulQueryHelper, 1017 FuncRenewSession: renewRestfulSession, 1018 FuncPostAuth: postAuth, 1019 FuncCloseSession: closeSession, 1020 FuncCancelQuery: cancelQuery, 1021 FuncPostAuthSAML: postAuthSAML, 1022 FuncPostAuthOKTA: postAuthOKTA, 1023 FuncGetSSO: getSSO, 1024 } 1025 sc.telemetry = &snowflakeTelemetry{ 1026 flushSize: defaultFlushSize, 1027 sr: sc.rest, 1028 mutex: &sync.Mutex{}, 1029 enabled: true, 1030 } 1031 return sc, nil 1032} 1033