// Copyright (c) 2017-2021 Snowflake Computing Inc. All rights reserved.
2
3package gosnowflake
4
5import (
6	"context"
7	"encoding/json"
8	"fmt"
9	"io/ioutil"
10	"net/http"
11	"net/url"
12	"strconv"
13	"time"
14
15	"github.com/google/uuid"
16)
17
// HTTP header names, values and value templates used by the REST helpers.
const (
	// headerSnowflakeToken is a fmt template for the value of the header named
	// by headerAuthorizationKey; its single verb is filled with a session or
	// master token.
	headerSnowflakeToken   = "Snowflake Token=\"%v\""
	headerAuthorizationKey = "Authorization"

	// Media type strings. They are not referenced in this part of the file.
	headerContentTypeApplicationJSON     = "application/json"
	headerAcceptTypeApplicationSnowflake = "application/snowflake"
)
26
// Snowflake server response codes. They are compared against the "code" field
// of decoded JSON responses:
//   - queryInProgressCode and queryInProgressAsyncCode: postRestfulQueryHelper
//     keeps polling the result URL while a response carries one of these.
//   - sessionExpiredCode: postRestfulQueryHelper and cancelQuery renew the
//     session token and retry; closeSession does not treat it as a failure.
//   - queryNotExecuting: cancelQuery re-sends the abort request while its retry
//     counter is non-zero.
const (
	queryInProgressCode      = "333333"
	queryInProgressAsyncCode = "333334"
	sessionExpiredCode       = "390112"
	queryNotExecuting        = "000605"
)
34
// URL paths of the Snowflake server endpoints. In this part of the file
// queryRequestPath (submit a query), tokenRequestPath (renew the session
// token), abortRequestPath (cancel a query) and sessionRequestPath (close the
// session) are expanded into full URLs with snowflakeRestful.getFullURL; the
// remaining paths are not referenced here.
const (
	loginRequestPath         = "/session/v1/login-request"
	queryRequestPath         = "/queries/v1/query-request"
	tokenRequestPath         = "/session/token-request"
	abortRequestPath         = "/queries/v1/abort-request"
	authenticatorRequestPath = "/session/authenticator-request"
	sessionRequestPath       = "/session"
	heartBeatPath            = "/session/heartbeat"
)
45
// FuncGetType is the signature of the function used to issue an HTTP GET and
// return the raw http.Response. getRestful has this signature; its arguments
// are, in order, the request context, the REST client, the target URL, the
// request headers and the timeout.
type FuncGetType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, time.Duration) (*http.Response, error)
48
// FuncPostType is the signature of the function used to issue an HTTP POST and
// return the raw http.Response. postRestful has this signature; its arguments
// are, in order, the request context, the REST client, the target URL, the
// request headers, the request body, the timeout and the raise4XX flag that
// postRestful forwards to the retrying HTTP client.
type FuncPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, bool) (*http.Response, error)
51
// snowflakeRestful holds the endpoint settings, HTTP client, token store and
// request functions used to talk to the Snowflake REST API. The code in this
// file reaches the Func* fields through the struct instead of calling the
// implementations directly.
type snowflakeRestful struct {
	// Host, Port and Protocol are combined by getURL and getFullURL into the
	// scheme and host:port of request URLs.
	Host           string
	Port           int
	Protocol       string
	LoginTimeout   time.Duration // Login timeout
	RequestTimeout time.Duration // request timeout

	// Client is the HTTP client passed to newRetryHTTP by postRestful and
	// getRestful. TokenAccessor is used here through its Lock, Unlock,
	// GetTokens and SetTokens methods. HeartBeat is not used in this part of
	// the file.
	Client        *http.Client
	TokenAccessor TokenAccessor
	HeartBeat     *heartbeat

	Connection *snowflakeConn

	// Request functions. Within this part of the file:
	//   - FuncPostQueryHelper is called by postRestfulQuery.
	//   - FuncPostQuery is called by postRestfulQueryHelper to resubmit a query
	//     after the session token was renewed.
	//   - FuncPost and FuncGet perform the HTTP POST and GET requests.
	//   - FuncRenewSession is called by renewExpiredSessionToken and cancelQuery.
	//   - FuncCancelQuery is called by postRestfulQuery and by cancelQuery itself.
	//   - FuncPostAuth and FuncCloseSession are not invoked here.
	FuncPostQuery       func(context.Context, *snowflakeRestful, *url.Values, map[string]string, []byte, time.Duration, uuid.UUID, *Config) (*execResponse, error)
	FuncPostQueryHelper func(context.Context, *snowflakeRestful, *url.Values, map[string]string, []byte, time.Duration, uuid.UUID, *Config) (*execResponse, error)
	FuncPost            FuncPostType
	FuncGet             FuncGetType
	FuncRenewSession    func(context.Context, *snowflakeRestful, time.Duration) error
	FuncPostAuth        func(context.Context, *snowflakeRestful, *url.Values, map[string]string, []byte, time.Duration) (*authResponse, error)
	FuncCloseSession    func(context.Context, *snowflakeRestful, time.Duration) error
	FuncCancelQuery     func(context.Context, *snowflakeRestful, uuid.UUID, time.Duration) error

	// Authentication related request functions; none of them is invoked in
	// this part of the file.
	FuncPostAuthSAML func(context.Context, *snowflakeRestful, map[string]string, []byte, time.Duration) (*authResponse, error)
	FuncPostAuthOKTA func(context.Context, *snowflakeRestful, map[string]string, []byte, string, time.Duration) (*authOKTAResponse, error)
	FuncGetSSO       func(context.Context, *snowflakeRestful, *url.Values, map[string]string, string, time.Duration) ([]byte, error)
}
78
79func (sr *snowflakeRestful) getURL() *url.URL {
80	return &url.URL{
81		Scheme: sr.Protocol,
82		Host:   sr.Host + ":" + strconv.Itoa(sr.Port),
83	}
84}
85
86func (sr *snowflakeRestful) getFullURL(path string, params *url.Values) *url.URL {
87	ret := &url.URL{
88		Scheme: sr.Protocol,
89		Host:   sr.Host + ":" + strconv.Itoa(sr.Port),
90		Path:   path,
91	}
92	if params != nil {
93		ret.RawQuery = params.Encode()
94	}
95	return ret
96}
97
98// Renew the snowflake session if the current token is still the stale token specified
99func (sr *snowflakeRestful) renewExpiredSessionToken(ctx context.Context, timeout time.Duration, expiredToken string) error {
100	err := sr.TokenAccessor.Lock()
101	if err != nil {
102		return err
103	}
104	defer sr.TokenAccessor.Unlock()
105	currentToken, _, _ := sr.TokenAccessor.GetTokens()
106	if expiredToken == currentToken || currentToken == "" {
107		// Only renew the session if the current token is still the expired token or current token is empty
108		return sr.FuncRenewSession(ctx, sr, timeout)
109	}
110	return nil
111}
112
// renewSessionResponse is the JSON envelope decoded from the server's reply
// in renewRestfulSession and closeSession. Success reports whether the request
// succeeded; on failure Code (a numeric string) and Message describe the error.
type renewSessionResponse struct {
	Data    renewSessionResponseMain `json:"data"`
	Message string                   `json:"message"`
	Code    string                   `json:"code"`
	Success bool                     `json:"success"`
}
119
// renewSessionResponseMain is the "data" payload of a renewSessionResponse.
// renewRestfulSession stores SessionToken, MasterToken and SessionID in the
// TokenAccessor; the two validity fields are not read in this part of the file.
//
// NOTE(review): the validity fields are typed time.Duration, so encoding/json
// stores the raw JSON number as nanoseconds, while the field names suggest the
// server sends seconds. Confirm the unit before using them as durations.
type renewSessionResponseMain struct {
	SessionToken        string        `json:"sessionToken"`
	ValidityInSecondsST time.Duration `json:"validityInSecondsST"`
	MasterToken         string        `json:"masterToken"`
	ValidityInSecondsMT time.Duration `json:"validityInSecondsMT"`
	SessionID           int64         `json:"sessionId"`
}
127
// cancelQueryResponse is the JSON envelope decoded from the reply to an abort
// request in cancelQuery. Only Success, Code and Message are inspected there;
// Data is decoded but not read.
type cancelQueryResponse struct {
	Data    interface{} `json:"data"`
	Message string      `json:"message"`
	Code    string      `json:"code"`
	Success bool        `json:"success"`
}
134
// telemetryResponse is a JSON response envelope with the same message, code
// and success fields as the other responses in this file, plus optional data
// and headers. It is not used in this part of the file; presumably it is
// decoded from replies to telemetry requests — confirm against its users.
type telemetryResponse struct {
	Data    interface{}       `json:"data,omitempty"`
	Message string            `json:"message"`
	Code    string            `json:"code"`
	Success bool              `json:"success"`
	Headers map[string]string `json:"headers,omitempty"`
}
142
143func postRestful(
144	ctx context.Context,
145	sr *snowflakeRestful,
146	fullURL *url.URL,
147	headers map[string]string,
148	body []byte,
149	timeout time.Duration,
150	raise4XX bool) (
151	*http.Response, error) {
152	return newRetryHTTP(
153		ctx, sr.Client, http.NewRequest, fullURL, headers, timeout).doPost().setBody(body).doRaise4XX(raise4XX).execute()
154}
155
156func getRestful(
157	ctx context.Context,
158	sr *snowflakeRestful,
159	fullURL *url.URL,
160	headers map[string]string,
161	timeout time.Duration) (
162	*http.Response, error) {
163	return newRetryHTTP(
164		ctx, sr.Client, http.NewRequest, fullURL, headers, timeout).execute()
165}
166
167func postRestfulQuery(
168	ctx context.Context,
169	sr *snowflakeRestful,
170	params *url.Values,
171	headers map[string]string,
172	body []byte,
173	timeout time.Duration,
174	requestID uuid.UUID,
175	cfg *Config) (
176	data *execResponse, err error) {
177
178	data, err = sr.FuncPostQueryHelper(ctx, sr, params, headers, body, timeout, requestID, cfg)
179
180	// errors other than context timeout and cancel would be returned to upper layers
181	if err != context.Canceled && err != context.DeadlineExceeded {
182		return data, err
183	}
184
185	err = sr.FuncCancelQuery(context.TODO(), sr, requestID, timeout)
186	if err != nil {
187		return nil, err
188	}
189	return nil, ctx.Err()
190}
191
192func postRestfulQueryHelper(
193	ctx context.Context,
194	sr *snowflakeRestful,
195	params *url.Values,
196	headers map[string]string,
197	body []byte,
198	timeout time.Duration,
199	requestID uuid.UUID,
200	cfg *Config) (
201	data *execResponse, err error) {
202	logger.Infof("params: %v", params)
203	params.Add(requestIDKey, requestID.String())
204	params.Add("clientStartTime", strconv.FormatInt(time.Now().Unix(), 10))
205	params.Add(requestGUIDKey, uuid.New().String())
206	token, _, _ := sr.TokenAccessor.GetTokens()
207	if token != "" {
208		headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
209	}
210
211	var resp *http.Response
212	fullURL := sr.getFullURL(queryRequestPath, params)
213	resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true)
214
215	if err != nil {
216		return nil, err
217	}
218	defer resp.Body.Close()
219
220	if resp.StatusCode == http.StatusOK {
221		logger.WithContext(ctx).Infof("postQuery: resp: %v", resp)
222		var respd execResponse
223		err = json.NewDecoder(resp.Body).Decode(&respd)
224		if err != nil {
225			logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
226			return nil, err
227		}
228		if respd.Code == sessionExpiredCode {
229			err = sr.renewExpiredSessionToken(ctx, timeout, token)
230			if err != nil {
231				return nil, err
232			}
233			return sr.FuncPostQuery(ctx, sr, params, headers, body, timeout, requestID, cfg)
234		}
235
236		if queryIDChan := getQueryIDChan(ctx); queryIDChan != nil {
237			queryIDChan <- respd.Data.QueryID
238			close(queryIDChan)
239			ctx = WithQueryIDChan(ctx, nil)
240		}
241
242		var resultURL string
243		isSessionRenewed := false
244		noResult := isAsyncMode(ctx)
245
246		// if asynchronous query in progress, kick off retrieval but return object
247		if respd.Code == queryInProgressAsyncCode && noResult {
248			// placeholder object to return to user while retrieving results
249			rows := new(snowflakeRows)
250			res := new(snowflakeResult)
251			switch resType := getResultType(ctx); resType {
252			case execResultType:
253				res.queryID = respd.Data.QueryID
254				res.status = QueryStatusInProgress
255				res.errChannel = make(chan error)
256				respd.Data.AsyncResult = res
257			case queryResultType:
258				rows.queryID = respd.Data.QueryID
259				rows.status = QueryStatusInProgress
260				rows.errChannel = make(chan error)
261				respd.Data.AsyncRows = rows
262			default:
263				return &respd, nil
264			}
265
266			// spawn goroutine to retrieve asynchronous results
267			go getAsync(ctx, sr, headers, sr.getFullURL(respd.Data.GetResultURL, nil), timeout, res, rows, cfg)
268			return &respd, nil
269		}
270		for isSessionRenewed || respd.Code == queryInProgressCode ||
271			respd.Code == queryInProgressAsyncCode {
272			if !isSessionRenewed {
273				resultURL = respd.Data.GetResultURL
274			}
275
276			logger.Info("ping pong")
277			token, _, _ := sr.TokenAccessor.GetTokens()
278			headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
279			fullURL := sr.getFullURL(resultURL, nil)
280
281			resp, err = sr.FuncGet(ctx, sr, fullURL, headers, timeout)
282			if err != nil {
283				logger.WithContext(ctx).Errorf("failed to get response. err: %v", err)
284				return nil, err
285			}
286			respd = execResponse{} // reset the response
287			err = json.NewDecoder(resp.Body).Decode(&respd)
288			resp.Body.Close()
289			if err != nil {
290				logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
291				return nil, err
292			}
293			if respd.Code == sessionExpiredCode {
294				err = sr.renewExpiredSessionToken(ctx, timeout, token)
295				if err != nil {
296					return nil, err
297				}
298				isSessionRenewed = true
299			} else {
300				isSessionRenewed = false
301			}
302		}
303		return &respd, nil
304	}
305	b, err := ioutil.ReadAll(resp.Body)
306	if err != nil {
307		logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err)
308		return nil, err
309	}
310	logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b)
311	logger.WithContext(ctx).Infof("Header: %v", resp.Header)
312	return nil, &SnowflakeError{
313		Number:      ErrFailedToPostQuery,
314		SQLState:    SQLStateConnectionFailure,
315		Message:     errMsgFailedToPostQuery,
316		MessageArgs: []interface{}{resp.StatusCode, fullURL},
317	}
318}
319
320func closeSession(ctx context.Context, sr *snowflakeRestful, timeout time.Duration) error {
321	logger.WithContext(ctx).Info("close session")
322	params := &url.Values{}
323	params.Add("delete", "true")
324	params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String())
325	params.Add(requestGUIDKey, uuid.New().String())
326	fullURL := sr.getFullURL(sessionRequestPath, params)
327
328	headers := getHeaders()
329	token, _, _ := sr.TokenAccessor.GetTokens()
330	headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
331
332	resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, false)
333	if err != nil {
334		return err
335	}
336	defer resp.Body.Close()
337	if resp.StatusCode == http.StatusOK {
338		var respd renewSessionResponse
339		err = json.NewDecoder(resp.Body).Decode(&respd)
340		if err != nil {
341			logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
342			return err
343		}
344		if !respd.Success && respd.Code != sessionExpiredCode {
345			c, err := strconv.Atoi(respd.Code)
346			if err != nil {
347				return err
348			}
349			return &SnowflakeError{
350				Number:  c,
351				Message: respd.Message,
352			}
353		}
354		return nil
355	}
356	b, err := ioutil.ReadAll(resp.Body)
357	if err != nil {
358		logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err)
359		return err
360	}
361	logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b)
362	logger.WithContext(ctx).Infof("Header: %v", resp.Header)
363	return &SnowflakeError{
364		Number:      ErrFailedToCloseSession,
365		SQLState:    SQLStateConnectionFailure,
366		Message:     errMsgFailedToCloseSession,
367		MessageArgs: []interface{}{resp.StatusCode, fullURL},
368	}
369}
370
371func renewRestfulSession(ctx context.Context, sr *snowflakeRestful, timeout time.Duration) error {
372	logger.WithContext(ctx).Info("start renew session")
373	params := &url.Values{}
374	params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String())
375	params.Add(requestGUIDKey, uuid.New().String())
376	fullURL := sr.getFullURL(tokenRequestPath, params)
377
378	token, masterToken, _ := sr.TokenAccessor.GetTokens()
379	headers := getHeaders()
380	headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, masterToken)
381
382	body := make(map[string]string)
383	body["oldSessionToken"] = token
384	body["requestType"] = "RENEW"
385
386	var reqBody []byte
387	reqBody, err := json.Marshal(body)
388	if err != nil {
389		return err
390	}
391
392	resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, false)
393	if err != nil {
394		return err
395	}
396	defer resp.Body.Close()
397	if resp.StatusCode == http.StatusOK {
398		var respd renewSessionResponse
399		err = json.NewDecoder(resp.Body).Decode(&respd)
400		if err != nil {
401			logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
402			return err
403		}
404		if !respd.Success {
405			c, err := strconv.Atoi(respd.Code)
406			if err != nil {
407				return err
408			}
409			return &SnowflakeError{
410				Number:  c,
411				Message: respd.Message,
412			}
413		}
414		sr.TokenAccessor.SetTokens(respd.Data.SessionToken, respd.Data.MasterToken, respd.Data.SessionID)
415		return nil
416	}
417	b, err := ioutil.ReadAll(resp.Body)
418	if err != nil {
419		logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err)
420		return err
421	}
422	logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b)
423	logger.WithContext(ctx).Infof("Header: %v", resp.Header)
424	return &SnowflakeError{
425		Number:      ErrFailedToRenewSession,
426		SQLState:    SQLStateConnectionFailure,
427		Message:     errMsgFailedToRenew,
428		MessageArgs: []interface{}{resp.StatusCode, fullURL},
429	}
430}
431
432func getCancelRetry(ctx context.Context) int {
433	val := ctx.Value(cancelRetry)
434	if val == nil {
435		return 5
436	}
437	cnt, _ := val.(int)
438	return cnt
439}
440
441func cancelQuery(ctx context.Context, sr *snowflakeRestful, requestID uuid.UUID, timeout time.Duration) error {
442	logger.WithContext(ctx).Info("cancel query")
443	params := &url.Values{}
444	params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String())
445	params.Add(requestGUIDKey, uuid.New().String())
446
447	fullURL := sr.getFullURL(abortRequestPath, params)
448
449	headers := getHeaders()
450	token, _, _ := sr.TokenAccessor.GetTokens()
451	headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
452
453	req := make(map[string]string)
454	req[requestIDKey] = requestID.String()
455
456	reqByte, err := json.Marshal(req)
457	if err != nil {
458		return err
459	}
460
461	resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, false)
462	if err != nil {
463		return err
464	}
465	defer resp.Body.Close()
466	if resp.StatusCode == http.StatusOK {
467		var respd cancelQueryResponse
468		err = json.NewDecoder(resp.Body).Decode(&respd)
469		if err != nil {
470			logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
471			return err
472		}
473		ctxRetry := getCancelRetry(ctx)
474		if !respd.Success && respd.Code == sessionExpiredCode {
475			err := sr.FuncRenewSession(ctx, sr, timeout)
476			if err != nil {
477				return err
478			}
479			return sr.FuncCancelQuery(ctx, sr, requestID, timeout)
480		} else if !respd.Success && respd.Code == queryNotExecuting && ctxRetry != 0 {
481			return sr.FuncCancelQuery(context.WithValue(ctx, cancelRetry, ctxRetry-1), sr, requestID, timeout)
482		} else if respd.Success {
483			return nil
484		} else {
485			c, err := strconv.Atoi(respd.Code)
486			if err != nil {
487				return err
488			}
489			return &SnowflakeError{
490				Number:  c,
491				Message: respd.Message,
492			}
493		}
494	}
495	b, err := ioutil.ReadAll(resp.Body)
496	if err != nil {
497		logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err)
498		return err
499	}
500	logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b)
501	logger.WithContext(ctx).Infof("Header: %v", resp.Header)
502	return &SnowflakeError{
503		Number:      ErrFailedToCancelQuery,
504		SQLState:    SQLStateConnectionFailure,
505		Message:     errMsgFailedToCancelQuery,
506		MessageArgs: []interface{}{resp.StatusCode, fullURL},
507	}
508}
509