1// Copyright (c) 2017-2021 Snowflake Computing Inc. All right reserved. 2 3package gosnowflake 4 5import ( 6 "context" 7 "encoding/json" 8 "fmt" 9 "io/ioutil" 10 "net/http" 11 "net/url" 12 "strconv" 13 "time" 14 15 "github.com/google/uuid" 16) 17 18// HTTP headers 19const ( 20 headerSnowflakeToken = "Snowflake Token=\"%v\"" 21 headerAuthorizationKey = "Authorization" 22 23 headerContentTypeApplicationJSON = "application/json" 24 headerAcceptTypeApplicationSnowflake = "application/snowflake" 25) 26 27// Snowflake Server Error code 28const ( 29 queryInProgressCode = "333333" 30 queryInProgressAsyncCode = "333334" 31 sessionExpiredCode = "390112" 32 queryNotExecuting = "000605" 33) 34 35// Snowflake Server Endpoints 36const ( 37 loginRequestPath = "/session/v1/login-request" 38 queryRequestPath = "/queries/v1/query-request" 39 tokenRequestPath = "/session/token-request" 40 abortRequestPath = "/queries/v1/abort-request" 41 authenticatorRequestPath = "/session/authenticator-request" 42 sessionRequestPath = "/session" 43 heartBeatPath = "/session/heartbeat" 44) 45 46// FuncGetType httpclient GET method to return http.Response 47type FuncGetType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, time.Duration) (*http.Response, error) 48 49// FuncPostType httpclient POST method to return http.Response 50type FuncPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, bool) (*http.Response, error) 51 52type snowflakeRestful struct { 53 Host string 54 Port int 55 Protocol string 56 LoginTimeout time.Duration // Login timeout 57 RequestTimeout time.Duration // request timeout 58 59 Client *http.Client 60 TokenAccessor TokenAccessor 61 HeartBeat *heartbeat 62 63 Connection *snowflakeConn 64 65 FuncPostQuery func(context.Context, *snowflakeRestful, *url.Values, map[string]string, []byte, time.Duration, uuid.UUID, *Config) (*execResponse, error) 66 FuncPostQueryHelper func(context.Context, *snowflakeRestful, *url.Values, map[string]string, []byte, time.Duration, uuid.UUID, *Config) (*execResponse, error) 67 FuncPost FuncPostType 68 FuncGet FuncGetType 69 FuncRenewSession func(context.Context, *snowflakeRestful, time.Duration) error 70 FuncPostAuth func(context.Context, *snowflakeRestful, *url.Values, map[string]string, []byte, time.Duration) (*authResponse, error) 71 FuncCloseSession func(context.Context, *snowflakeRestful, time.Duration) error 72 FuncCancelQuery func(context.Context, *snowflakeRestful, uuid.UUID, time.Duration) error 73 74 FuncPostAuthSAML func(context.Context, *snowflakeRestful, map[string]string, []byte, time.Duration) (*authResponse, error) 75 FuncPostAuthOKTA func(context.Context, *snowflakeRestful, map[string]string, []byte, string, time.Duration) (*authOKTAResponse, error) 76 FuncGetSSO func(context.Context, *snowflakeRestful, *url.Values, map[string]string, string, time.Duration) ([]byte, error) 77} 78 79func (sr *snowflakeRestful) getURL() *url.URL { 80 return &url.URL{ 81 Scheme: sr.Protocol, 82 Host: sr.Host + ":" + strconv.Itoa(sr.Port), 83 } 84} 85 86func (sr *snowflakeRestful) getFullURL(path string, params *url.Values) *url.URL { 87 ret := &url.URL{ 88 Scheme: sr.Protocol, 89 Host: sr.Host + ":" + strconv.Itoa(sr.Port), 90 Path: path, 91 } 92 if params != nil { 93 ret.RawQuery = params.Encode() 94 } 95 return ret 96} 97 98// Renew the snowflake session if the current token is still the stale token specified 99func (sr *snowflakeRestful) renewExpiredSessionToken(ctx context.Context, timeout time.Duration, expiredToken string) error { 100 err := sr.TokenAccessor.Lock() 101 if err != nil { 102 return err 103 } 104 defer sr.TokenAccessor.Unlock() 105 currentToken, _, _ := sr.TokenAccessor.GetTokens() 106 if expiredToken == currentToken || currentToken == "" { 107 // Only renew the session if the current token is still the expired token or current token is empty 108 return sr.FuncRenewSession(ctx, sr, timeout) 109 } 110 return nil 111} 112 113type renewSessionResponse struct { 114 Data renewSessionResponseMain `json:"data"` 115 Message string `json:"message"` 116 Code string `json:"code"` 117 Success bool `json:"success"` 118} 119 120type renewSessionResponseMain struct { 121 SessionToken string `json:"sessionToken"` 122 ValidityInSecondsST time.Duration `json:"validityInSecondsST"` 123 MasterToken string `json:"masterToken"` 124 ValidityInSecondsMT time.Duration `json:"validityInSecondsMT"` 125 SessionID int64 `json:"sessionId"` 126} 127 128type cancelQueryResponse struct { 129 Data interface{} `json:"data"` 130 Message string `json:"message"` 131 Code string `json:"code"` 132 Success bool `json:"success"` 133} 134 135type telemetryResponse struct { 136 Data interface{} `json:"data,omitempty"` 137 Message string `json:"message"` 138 Code string `json:"code"` 139 Success bool `json:"success"` 140 Headers map[string]string `json:"headers,omitempty"` 141} 142 143func postRestful( 144 ctx context.Context, 145 sr *snowflakeRestful, 146 fullURL *url.URL, 147 headers map[string]string, 148 body []byte, 149 timeout time.Duration, 150 raise4XX bool) ( 151 *http.Response, error) { 152 return newRetryHTTP( 153 ctx, sr.Client, http.NewRequest, fullURL, headers, timeout).doPost().setBody(body).doRaise4XX(raise4XX).execute() 154} 155 156func getRestful( 157 ctx context.Context, 158 sr *snowflakeRestful, 159 fullURL *url.URL, 160 headers map[string]string, 161 timeout time.Duration) ( 162 *http.Response, error) { 163 return newRetryHTTP( 164 ctx, sr.Client, http.NewRequest, fullURL, headers, timeout).execute() 165} 166 167func postRestfulQuery( 168 ctx context.Context, 169 sr *snowflakeRestful, 170 params *url.Values, 171 headers map[string]string, 172 body []byte, 173 timeout time.Duration, 174 requestID uuid.UUID, 175 cfg *Config) ( 176 data *execResponse, err error) { 177 178 data, err = sr.FuncPostQueryHelper(ctx, sr, params, headers, body, timeout, requestID, cfg) 179 180 // errors other than context timeout and cancel would be returned to upper layers 181 if err != context.Canceled && err != context.DeadlineExceeded { 182 return data, err 183 } 184 185 err = sr.FuncCancelQuery(context.TODO(), sr, requestID, timeout) 186 if err != nil { 187 return nil, err 188 } 189 return nil, ctx.Err() 190} 191 192func postRestfulQueryHelper( 193 ctx context.Context, 194 sr *snowflakeRestful, 195 params *url.Values, 196 headers map[string]string, 197 body []byte, 198 timeout time.Duration, 199 requestID uuid.UUID, 200 cfg *Config) ( 201 data *execResponse, err error) { 202 logger.Infof("params: %v", params) 203 params.Add(requestIDKey, requestID.String()) 204 params.Add("clientStartTime", strconv.FormatInt(time.Now().Unix(), 10)) 205 params.Add(requestGUIDKey, uuid.New().String()) 206 token, _, _ := sr.TokenAccessor.GetTokens() 207 if token != "" { 208 headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) 209 } 210 211 var resp *http.Response 212 fullURL := sr.getFullURL(queryRequestPath, params) 213 resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true) 214 215 if err != nil { 216 return nil, err 217 } 218 defer resp.Body.Close() 219 220 if resp.StatusCode == http.StatusOK { 221 logger.WithContext(ctx).Infof("postQuery: resp: %v", resp) 222 var respd execResponse 223 err = json.NewDecoder(resp.Body).Decode(&respd) 224 if err != nil { 225 logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) 226 return nil, err 227 } 228 if respd.Code == sessionExpiredCode { 229 err = sr.renewExpiredSessionToken(ctx, timeout, token) 230 if err != nil { 231 return nil, err 232 } 233 return sr.FuncPostQuery(ctx, sr, params, headers, body, timeout, requestID, cfg) 234 } 235 236 if queryIDChan := getQueryIDChan(ctx); queryIDChan != nil { 237 queryIDChan <- respd.Data.QueryID 238 close(queryIDChan) 239 ctx = WithQueryIDChan(ctx, nil) 240 } 241 242 var resultURL string 243 isSessionRenewed := false 244 noResult := isAsyncMode(ctx) 245 246 // if asynchronous query in progress, kick off retrieval but return object 247 if respd.Code == queryInProgressAsyncCode && noResult { 248 // placeholder object to return to user while retrieving results 249 rows := new(snowflakeRows) 250 res := new(snowflakeResult) 251 switch resType := getResultType(ctx); resType { 252 case execResultType: 253 res.queryID = respd.Data.QueryID 254 res.status = QueryStatusInProgress 255 res.errChannel = make(chan error) 256 respd.Data.AsyncResult = res 257 case queryResultType: 258 rows.queryID = respd.Data.QueryID 259 rows.status = QueryStatusInProgress 260 rows.errChannel = make(chan error) 261 respd.Data.AsyncRows = rows 262 default: 263 return &respd, nil 264 } 265 266 // spawn goroutine to retrieve asynchronous results 267 go getAsync(ctx, sr, headers, sr.getFullURL(respd.Data.GetResultURL, nil), timeout, res, rows, cfg) 268 return &respd, nil 269 } 270 for isSessionRenewed || respd.Code == queryInProgressCode || 271 respd.Code == queryInProgressAsyncCode { 272 if !isSessionRenewed { 273 resultURL = respd.Data.GetResultURL 274 } 275 276 logger.Info("ping pong") 277 token, _, _ := sr.TokenAccessor.GetTokens() 278 headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) 279 fullURL := sr.getFullURL(resultURL, nil) 280 281 resp, err = sr.FuncGet(ctx, sr, fullURL, headers, timeout) 282 if err != nil { 283 logger.WithContext(ctx).Errorf("failed to get response. err: %v", err) 284 return nil, err 285 } 286 respd = execResponse{} // reset the response 287 err = json.NewDecoder(resp.Body).Decode(&respd) 288 resp.Body.Close() 289 if err != nil { 290 logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) 291 return nil, err 292 } 293 if respd.Code == sessionExpiredCode { 294 err = sr.renewExpiredSessionToken(ctx, timeout, token) 295 if err != nil { 296 return nil, err 297 } 298 isSessionRenewed = true 299 } else { 300 isSessionRenewed = false 301 } 302 } 303 return &respd, nil 304 } 305 b, err := ioutil.ReadAll(resp.Body) 306 if err != nil { 307 logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err) 308 return nil, err 309 } 310 logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b) 311 logger.WithContext(ctx).Infof("Header: %v", resp.Header) 312 return nil, &SnowflakeError{ 313 Number: ErrFailedToPostQuery, 314 SQLState: SQLStateConnectionFailure, 315 Message: errMsgFailedToPostQuery, 316 MessageArgs: []interface{}{resp.StatusCode, fullURL}, 317 } 318} 319 320func closeSession(ctx context.Context, sr *snowflakeRestful, timeout time.Duration) error { 321 logger.WithContext(ctx).Info("close session") 322 params := &url.Values{} 323 params.Add("delete", "true") 324 params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String()) 325 params.Add(requestGUIDKey, uuid.New().String()) 326 fullURL := sr.getFullURL(sessionRequestPath, params) 327 328 headers := getHeaders() 329 token, _, _ := sr.TokenAccessor.GetTokens() 330 headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) 331 332 resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, false) 333 if err != nil { 334 return err 335 } 336 defer resp.Body.Close() 337 if resp.StatusCode == http.StatusOK { 338 var respd renewSessionResponse 339 err = json.NewDecoder(resp.Body).Decode(&respd) 340 if err != nil { 341 logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) 342 return err 343 } 344 if !respd.Success && respd.Code != sessionExpiredCode { 345 c, err := strconv.Atoi(respd.Code) 346 if err != nil { 347 return err 348 } 349 return &SnowflakeError{ 350 Number: c, 351 Message: respd.Message, 352 } 353 } 354 return nil 355 } 356 b, err := ioutil.ReadAll(resp.Body) 357 if err != nil { 358 logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err) 359 return err 360 } 361 logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b) 362 logger.WithContext(ctx).Infof("Header: %v", resp.Header) 363 return &SnowflakeError{ 364 Number: ErrFailedToCloseSession, 365 SQLState: SQLStateConnectionFailure, 366 Message: errMsgFailedToCloseSession, 367 MessageArgs: []interface{}{resp.StatusCode, fullURL}, 368 } 369} 370 371func renewRestfulSession(ctx context.Context, sr *snowflakeRestful, timeout time.Duration) error { 372 logger.WithContext(ctx).Info("start renew session") 373 params := &url.Values{} 374 params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String()) 375 params.Add(requestGUIDKey, uuid.New().String()) 376 fullURL := sr.getFullURL(tokenRequestPath, params) 377 378 token, masterToken, _ := sr.TokenAccessor.GetTokens() 379 headers := getHeaders() 380 headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, masterToken) 381 382 body := make(map[string]string) 383 body["oldSessionToken"] = token 384 body["requestType"] = "RENEW" 385 386 var reqBody []byte 387 reqBody, err := json.Marshal(body) 388 if err != nil { 389 return err 390 } 391 392 resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, false) 393 if err != nil { 394 return err 395 } 396 defer resp.Body.Close() 397 if resp.StatusCode == http.StatusOK { 398 var respd renewSessionResponse 399 err = json.NewDecoder(resp.Body).Decode(&respd) 400 if err != nil { 401 logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) 402 return err 403 } 404 if !respd.Success { 405 c, err := strconv.Atoi(respd.Code) 406 if err != nil { 407 return err 408 } 409 return &SnowflakeError{ 410 Number: c, 411 Message: respd.Message, 412 } 413 } 414 sr.TokenAccessor.SetTokens(respd.Data.SessionToken, respd.Data.MasterToken, respd.Data.SessionID) 415 return nil 416 } 417 b, err := ioutil.ReadAll(resp.Body) 418 if err != nil { 419 logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err) 420 return err 421 } 422 logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b) 423 logger.WithContext(ctx).Infof("Header: %v", resp.Header) 424 return &SnowflakeError{ 425 Number: ErrFailedToRenewSession, 426 SQLState: SQLStateConnectionFailure, 427 Message: errMsgFailedToRenew, 428 MessageArgs: []interface{}{resp.StatusCode, fullURL}, 429 } 430} 431 432func getCancelRetry(ctx context.Context) int { 433 val := ctx.Value(cancelRetry) 434 if val == nil { 435 return 5 436 } 437 cnt, _ := val.(int) 438 return cnt 439} 440 441func cancelQuery(ctx context.Context, sr *snowflakeRestful, requestID uuid.UUID, timeout time.Duration) error { 442 logger.WithContext(ctx).Info("cancel query") 443 params := &url.Values{} 444 params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String()) 445 params.Add(requestGUIDKey, uuid.New().String()) 446 447 fullURL := sr.getFullURL(abortRequestPath, params) 448 449 headers := getHeaders() 450 token, _, _ := sr.TokenAccessor.GetTokens() 451 headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) 452 453 req := make(map[string]string) 454 req[requestIDKey] = requestID.String() 455 456 reqByte, err := json.Marshal(req) 457 if err != nil { 458 return err 459 } 460 461 resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, false) 462 if err != nil { 463 return err 464 } 465 defer resp.Body.Close() 466 if resp.StatusCode == http.StatusOK { 467 var respd cancelQueryResponse 468 err = json.NewDecoder(resp.Body).Decode(&respd) 469 if err != nil { 470 logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) 471 return err 472 } 473 ctxRetry := getCancelRetry(ctx) 474 if !respd.Success && respd.Code == sessionExpiredCode { 475 err := sr.FuncRenewSession(ctx, sr, timeout) 476 if err != nil { 477 return err 478 } 479 return sr.FuncCancelQuery(ctx, sr, requestID, timeout) 480 } else if !respd.Success && respd.Code == queryNotExecuting && ctxRetry != 0 { 481 return sr.FuncCancelQuery(context.WithValue(ctx, cancelRetry, ctxRetry-1), sr, requestID, timeout) 482 } else if respd.Success { 483 return nil 484 } else { 485 c, err := strconv.Atoi(respd.Code) 486 if err != nil { 487 return err 488 } 489 return &SnowflakeError{ 490 Number: c, 491 Message: respd.Message, 492 } 493 } 494 } 495 b, err := ioutil.ReadAll(resp.Body) 496 if err != nil { 497 logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err) 498 return err 499 } 500 logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b) 501 logger.WithContext(ctx).Infof("Header: %v", resp.Header) 502 return &SnowflakeError{ 503 Number: ErrFailedToCancelQuery, 504 SQLState: SQLStateConnectionFailure, 505 Message: errMsgFailedToCancelQuery, 506 MessageArgs: []interface{}{resp.StatusCode, fullURL}, 507 } 508} 509