// Copyright (c) 2017-2019 Snowflake Computing Inc. All rights reserved.
2
3package gosnowflake
4
5import (
6	"bytes"
7	"context"
8	"encoding/json"
9	"fmt"
10	"html"
11	"io/ioutil"
12	"net/http"
13	"net/url"
14	"strconv"
15	"time"
16)
17
type (
	// authOKTARequest is the JSON payload carrying the user's credentials
	// that is marshalled and posted to the IDP token URL.
	authOKTARequest struct {
		Username string `json:"username"`
		Password string `json:"password"`
	}

	// authOKTAResponse is the JSON payload decoded from the IDP token URL
	// response; CookieToken is later sent as the "onetimetoken" query
	// parameter of the SSO request.
	authOKTAResponse struct {
		CookieToken string `json:"cookieToken"`
	}
)
26
27/*
28authenticateBySAML authenticates a user by SAML
29SAML Authentication
301.  query GS to obtain IDP token and SSO url
312.  IMPORTANT Client side validation:
32	validate both token url and sso url contains same prefix
33	(protocol + host + port) as the given authenticator url.
34	Explanation:
35	This provides a way for the user to 'authenticate' the IDP it is
36	sending his/her credentials to.  Without such a check, the user could
37	be coerced to provide credentials to an IDP impersonator.
383.  query IDP token url to authenticate and retrieve access token
394.  given access token, query IDP URL snowflake app to get SAML response
405.  IMPORTANT Client side validation:
41	validate the post back url come back with the SAML response
42	contains the same prefix as the Snowflake's server url, which is the
43	intended destination url to Snowflake.
44Explanation:
45	This emulates the behavior of IDP initiated login flow in the user
46	browser where the IDP instructs the browser to POST the SAML
47	assertion to the specific SP endpoint.  This is critical in
48	preventing a SAML assertion issued to one SP from being sent to
49	another SP.
50*/
51func authenticateBySAML(
52	ctx context.Context,
53	sr *snowflakeRestful,
54	oktaURL *url.URL,
55	application string,
56	account string,
57	user string,
58	password string,
59) (samlResponse []byte, err error) {
60	logger.WithContext(ctx).Info("step 1: query GS to obtain IDP token and SSO url")
61	headers := make(map[string]string)
62	headers[httpHeaderContentType] = headerContentTypeApplicationJSON
63	headers[httpHeaderAccept] = headerContentTypeApplicationJSON
64	headers[httpHeaderUserAgent] = userAgent
65
66	clientEnvironment := authRequestClientEnvironment{
67		Application: application,
68		Os:          operatingSystem,
69		OsVersion:   platform,
70	}
71	requestMain := authRequestData{
72		ClientAppID:       clientType,
73		ClientAppVersion:  SnowflakeGoDriverVersion,
74		AccountName:       account,
75		ClientEnvironment: clientEnvironment,
76		Authenticator:     oktaURL.String(),
77	}
78	authRequest := authRequest{
79		Data: requestMain,
80	}
81	params := &url.Values{}
82	jsonBody, err := json.Marshal(authRequest)
83	if err != nil {
84		return nil, err
85	}
86	logger.WithContext(ctx).Infof("PARAMS for Auth: %v, %v", params, sr)
87	respd, err := sr.FuncPostAuthSAML(ctx, sr, headers, jsonBody, sr.LoginTimeout)
88	if err != nil {
89		return nil, err
90	}
91	if !respd.Success {
92		logger.Errorln("Authentication FAILED")
93		sr.TokenAccessor.SetTokens("", "", -1)
94		code, err := strconv.Atoi(respd.Code)
95		if err != nil {
96			code = -1
97			return nil, err
98		}
99		return nil, &SnowflakeError{
100			Number:   code,
101			SQLState: SQLStateConnectionRejected,
102			Message:  respd.Message,
103		}
104	}
105	logger.WithContext(ctx).Info("step 2: validate Token and SSO URL has the same prefix as oktaURL")
106	var tokenURL *url.URL
107	var ssoURL *url.URL
108	if tokenURL, err = url.Parse(respd.Data.TokenURL); err != nil {
109		return nil, fmt.Errorf("failed to parse token URL. %v", respd.Data.TokenURL)
110	}
111	if ssoURL, err = url.Parse(respd.Data.TokenURL); err != nil {
112		return nil, fmt.Errorf("failed to parse ssoURL URL. %v", respd.Data.SSOURL)
113	}
114	if !isPrefixEqual(oktaURL, ssoURL) || !isPrefixEqual(oktaURL, tokenURL) {
115		return nil, &SnowflakeError{
116			Number:      ErrCodeIdpConnectionError,
117			SQLState:    SQLStateConnectionRejected,
118			Message:     errMsgIdpConnectionError,
119			MessageArgs: []interface{}{oktaURL, respd.Data.TokenURL, respd.Data.SSOURL},
120		}
121	}
122	logger.WithContext(ctx).Info("step 3: query IDP token url to authenticate and retrieve access token")
123	jsonBody, err = json.Marshal(authOKTARequest{
124		Username: user,
125		Password: password,
126	})
127	if err != nil {
128		return nil, err
129	}
130	respa, err := sr.FuncPostAuthOKTA(ctx, sr, headers, jsonBody, respd.Data.TokenURL, sr.LoginTimeout)
131	if err != nil {
132		return nil, err
133	}
134
135	logger.WithContext(ctx).Info("step 4: query IDP URL snowflake app to get SAML response")
136	params = &url.Values{}
137	params.Add("RelayState", "/some/deep/link")
138	params.Add("onetimetoken", respa.CookieToken)
139
140	headers = make(map[string]string)
141	headers[httpHeaderAccept] = "*/*"
142	bd, err := sr.FuncGetSSO(ctx, sr, params, headers, respd.Data.SSOURL, sr.LoginTimeout)
143	if err != nil {
144		return nil, err
145	}
146	logger.WithContext(ctx).Info("step 5: validate post_back_url matches Snowflake URL")
147	tgtURL, err := postBackURL(bd)
148	if err != nil {
149		return nil, err
150	}
151
152	fullURL := sr.getURL()
153	logger.WithContext(ctx).Infof("tgtURL: %v, origURL: %v", tgtURL, fullURL)
154	if !isPrefixEqual(tgtURL, fullURL) {
155		return nil, &SnowflakeError{
156			Number:      ErrCodeSSOURLNotMatch,
157			SQLState:    SQLStateConnectionRejected,
158			Message:     errMsgSSOURLNotMatch,
159			MessageArgs: []interface{}{tgtURL, fullURL},
160		}
161	}
162	return bd, nil
163}
164
// postBackURL extracts the post back URL from an HTML page: it locates the
// first "<form" in htmlData, takes the value of the first `action="..."`
// attribute that follows it, HTML-unescapes that value and parses it as a URL.
//
// An error is returned when no form tag, no action attribute, or no closing
// quote of the action value is found, or when the value is not a parsable URL.
//
// The results are deliberately unnamed: a result named "url" would shadow the
// net/url package, turning url.Parse into a method call on a nil *url.URL,
// which panics for any action value that has no scheme.
func postBackURL(htmlData []byte) (*url.URL, error) {
	const actionAttr = `action="`

	formIdx := bytes.Index(htmlData, []byte("<form"))
	if formIdx < 0 {
		return nil, fmt.Errorf("failed to find a form tag in HTML response: %s", htmlData)
	}
	form := htmlData[formIdx:]

	actionIdx := bytes.Index(form, []byte(actionAttr))
	if actionIdx < 0 {
		return nil, fmt.Errorf("failed to find action field in HTML response: %s", form)
	}
	value := form[actionIdx+len(actionAttr):]

	endIdx := bytes.IndexByte(value, '"')
	if endIdx < 0 {
		return nil, fmt.Errorf("failed to find the end of action field: %s", value)
	}
	return url.Parse(html.UnescapeString(string(value[:endIdx])))
}
182
// isPrefixEqual reports whether u1 and u2 share the same URL prefix, that is
// the same scheme, hostname and port. A missing port on an https URL is
// treated as 443, so "https://host" and "https://host:443" compare equal.
// It returns false when either URL is nil.
func isPrefixEqual(u1 *url.URL, u2 *url.URL) bool {
	if u1 == nil || u2 == nil {
		return false
	}
	p1 := u1.Port()
	if p1 == "" && u1.Scheme == "https" {
		p1 = "443"
	}
	// The second port must come from u2; deriving it from u1 would make the
	// port comparison below always succeed.
	p2 := u2.Port()
	if p2 == "" && u2.Scheme == "https" {
		p2 = "443"
	}
	return u1.Hostname() == u2.Hostname() && p1 == p2 && u1.Scheme == u2.Scheme
}
194
195// Makes a request to /session/authenticator-request to get SAML Information,
196// such as the IDP Url and Proof Key, depending on the authenticator
197func postAuthSAML(
198	ctx context.Context,
199	sr *snowflakeRestful,
200	headers map[string]string,
201	body []byte,
202	timeout time.Duration) (
203	data *authResponse, err error) {
204
205	params := &url.Values{}
206	params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String())
207	fullURL := sr.getFullURL(authenticatorRequestPath, params)
208
209	logger.Infof("fullURL: %v", fullURL)
210	resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true)
211	if err != nil {
212		return nil, err
213	}
214	defer resp.Body.Close()
215	if resp.StatusCode == http.StatusOK {
216		var respd authResponse
217		err = json.NewDecoder(resp.Body).Decode(&respd)
218		if err != nil {
219			logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
220			return nil, err
221		}
222		return &respd, nil
223	}
224	switch resp.StatusCode {
225	case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
226		// service availability or connectivity issue. Most likely server side issue.
227		return nil, &SnowflakeError{
228			Number:      ErrCodeServiceUnavailable,
229			SQLState:    SQLStateConnectionWasNotEstablished,
230			Message:     errMsgServiceUnavailable,
231			MessageArgs: []interface{}{resp.StatusCode, fullURL},
232		}
233	case http.StatusUnauthorized, http.StatusForbidden:
234		// failed to connect to db. account name may be wrong
235		return nil, &SnowflakeError{
236			Number:      ErrCodeFailedToConnect,
237			SQLState:    SQLStateConnectionRejected,
238			Message:     errMsgFailedToConnect,
239			MessageArgs: []interface{}{resp.StatusCode, fullURL},
240		}
241	}
242	_, err = ioutil.ReadAll(resp.Body)
243	if err != nil {
244		logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err)
245		return nil, err
246	}
247	return nil, &SnowflakeError{
248		Number:      ErrFailedToAuthSAML,
249		SQLState:    SQLStateConnectionRejected,
250		Message:     errMsgFailedToAuthSAML,
251		MessageArgs: []interface{}{resp.StatusCode, fullURL},
252	}
253}
254
255func postAuthOKTA(
256	ctx context.Context,
257	sr *snowflakeRestful,
258	headers map[string]string,
259	body []byte,
260	fullURL string,
261	timeout time.Duration) (
262	data *authOKTAResponse, err error) {
263	logger.Infof("fullURL: %v", fullURL)
264	targetURL, err := url.Parse(fullURL)
265	if err != nil {
266		return nil, err
267	}
268	resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, false)
269	if err != nil {
270		return nil, err
271	}
272	defer resp.Body.Close()
273	if resp.StatusCode == http.StatusOK {
274		var respd authOKTAResponse
275		err = json.NewDecoder(resp.Body).Decode(&respd)
276		if err != nil {
277			logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
278			return nil, err
279		}
280		return &respd, nil
281	}
282	_, err = ioutil.ReadAll(resp.Body)
283	if err != nil {
284		logger.Errorf("failed to extract HTTP response body. err: %v", err)
285		return nil, err
286	}
287	logger.WithContext(ctx).Infof("HTTP: %v, URL: %v", resp.StatusCode, fullURL)
288	logger.WithContext(ctx).Infof("Header: %v", resp.Header)
289	return nil, &SnowflakeError{
290		Number:      ErrFailedToAuthOKTA,
291		SQLState:    SQLStateConnectionRejected,
292		Message:     errMsgFailedToAuthOKTA,
293		MessageArgs: []interface{}{resp.StatusCode, fullURL},
294	}
295}
296
297func getSSO(
298	ctx context.Context,
299	sr *snowflakeRestful,
300	params *url.Values,
301	headers map[string]string,
302	ssoURL string,
303	timeout time.Duration) (
304	bd []byte, err error) {
305	fullURL, err := url.Parse(ssoURL)
306	if err != nil {
307		return nil, err
308	}
309	fullURL.RawQuery = params.Encode()
310	logger.WithContext(ctx).Infof("fullURL: %v", fullURL)
311	resp, err := sr.FuncGet(ctx, sr, fullURL, headers, timeout)
312	if err != nil {
313		return nil, err
314	}
315	defer resp.Body.Close()
316	b, err := ioutil.ReadAll(resp.Body)
317	if err != nil {
318		logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err)
319		return nil, err
320	}
321	if resp.StatusCode == http.StatusOK {
322		return b, nil
323	}
324	logger.WithContext(ctx).Infof("HTTP: %v, URL: %v ", resp.StatusCode, fullURL)
325	logger.WithContext(ctx).Infof("Header: %v", resp.Header)
326	return nil, &SnowflakeError{
327		Number:      ErrFailedToGetSSO,
328		SQLState:    SQLStateConnectionRejected,
329		Message:     errMsgFailedToGetSSO,
330		MessageArgs: []interface{}{resp.StatusCode, fullURL},
331	}
332}
333