1// Copyright (c) 2017-2019 Snowflake Computing Inc. All right reserved.
2
3package gosnowflake
4
5import (
6	"context"
7	"crypto/sha256"
8	"crypto/x509"
9	"encoding/base64"
10	"encoding/json"
11	"fmt"
12	"io/ioutil"
13	"net/http"
14	"net/url"
15	"runtime"
16	"strconv"
17	"strings"
18	"time"
19
20	"github.com/form3tech-oss/jwt-go"
21	"github.com/google/uuid"
22)
23
// clientType is the client application ID reported to Snowflake (CLIENT_APP_ID
// in the login request) and the product token that starts the User-Agent header.
const clientType = "Go"
27
// AuthType indicates the type of authentication in Snowflake.
//
// NOTE: the numeric values are assigned by iota, so declaration order is
// load-bearing. Reordering would change the value of every constant after the
// moved one; add new authenticators at the end of the list.
type AuthType int

const (
	// AuthTypeSnowflake is the general username password authentication
	AuthTypeSnowflake AuthType = iota
	// AuthTypeOAuth is the OAuth authentication
	AuthTypeOAuth
	// AuthTypeExternalBrowser uses a web browser to reach the federated identity
	// provider and perform SSO authentication
	AuthTypeExternalBrowser
	// AuthTypeOkta is to use a native okta URL to perform SSO authentication on Okta
	AuthTypeOkta
	// AuthTypeJwt is to use Jwt to perform authentication
	AuthTypeJwt
	// AuthTypeTokenAccessor is to use the provided token accessor and bypass authentication
	AuthTypeTokenAccessor
)
45
46func determineAuthenticatorType(cfg *Config, value string) error {
47	upperCaseValue := strings.ToUpper(value)
48	lowerCaseValue := strings.ToLower(value)
49	if strings.Trim(value, " ") == "" || upperCaseValue == AuthTypeSnowflake.String() {
50		cfg.Authenticator = AuthTypeSnowflake
51		return nil
52	} else if upperCaseValue == AuthTypeOAuth.String() {
53		cfg.Authenticator = AuthTypeOAuth
54		return nil
55	} else if upperCaseValue == AuthTypeJwt.String() {
56		cfg.Authenticator = AuthTypeJwt
57		return nil
58	} else if upperCaseValue == AuthTypeExternalBrowser.String() {
59		cfg.Authenticator = AuthTypeExternalBrowser
60		return nil
61	} else {
62		// possibly Okta case
63		oktaURLString, err := url.QueryUnescape(lowerCaseValue)
64		if err != nil {
65			return &SnowflakeError{
66				Number:      ErrCodeFailedToParseAuthenticator,
67				Message:     errMsgFailedToParseAuthenticator,
68				MessageArgs: []interface{}{lowerCaseValue},
69			}
70		}
71
72		oktaURL, err := url.Parse(oktaURLString)
73		if err != nil {
74			return &SnowflakeError{
75				Number:      ErrCodeFailedToParseAuthenticator,
76				Message:     errMsgFailedToParseAuthenticator,
77				MessageArgs: []interface{}{oktaURLString},
78			}
79		}
80
81		if oktaURL.Scheme != "https" || !strings.HasSuffix(oktaURL.Host, "okta.com") {
82			return &SnowflakeError{
83				Number:      ErrCodeFailedToParseAuthenticator,
84				Message:     errMsgFailedToParseAuthenticator,
85				MessageArgs: []interface{}{oktaURLString},
86			}
87		}
88		cfg.OktaURL = oktaURL
89		cfg.Authenticator = AuthTypeOkta
90	}
91	return nil
92}
93
94func (authType AuthType) String() string {
95	switch authType {
96	case AuthTypeSnowflake:
97		return "SNOWFLAKE"
98	case AuthTypeOAuth:
99		return "OAUTH"
100	case AuthTypeExternalBrowser:
101		return "EXTERNALBROWSER"
102	case AuthTypeOkta:
103		return "OKTA"
104	case AuthTypeJwt:
105		return "SNOWFLAKE_JWT"
106	case AuthTypeTokenAccessor:
107		return "TOKENACCESSOR"
108	default:
109		return "UNKNOWN"
110	}
111}
112
113// platform consists of compiler and architecture type in string
114var platform = fmt.Sprintf("%v-%v", runtime.Compiler, runtime.GOARCH)
115
116// operatingSystem is the runtime operating system.
117var operatingSystem = runtime.GOOS
118
119// userAgent shows up in User-Agent HTTP header
120var userAgent = fmt.Sprintf("%v/%v (%v-%v) %v/%v",
121	clientType,
122	SnowflakeGoDriverVersion,
123	operatingSystem,
124	runtime.GOARCH,
125	runtime.Compiler,
126	runtime.Version())
127
// authRequestClientEnvironment describes the client. It is sent as
// CLIENT_ENVIRONMENT inside the login request body.
type authRequestClientEnvironment struct {
	Application string `json:"APPLICATION"`
	Os          string `json:"OS"`
	OsVersion   string `json:"OS_VERSION"`
	OCSPMode    string `json:"OCSP_MODE"`
}

// authRequestData is the payload of a login request.
//
// Which fields are populated depends on the configured authenticator (see
// authenticate). Fields tagged omitempty are left out of the JSON when empty.
type authRequestData struct {
	ClientAppID             string                       `json:"CLIENT_APP_ID"`
	ClientAppVersion        string                       `json:"CLIENT_APP_VERSION"`
	SvnRevision             string                       `json:"SVN_REVISION"`
	AccountName             string                       `json:"ACCOUNT_NAME"`
	LoginName               string                       `json:"LOGIN_NAME,omitempty"`
	Password                string                       `json:"PASSWORD,omitempty"`
	// RawSAMLResponse carries the SAML assertion obtained for native Okta logins.
	RawSAMLResponse         string                       `json:"RAW_SAML_RESPONSE,omitempty"`
	// ExtAuthnDuoMethod is set to "passcode" when MFA passcode login is used.
	ExtAuthnDuoMethod       string                       `json:"EXT_AUTHN_DUO_METHOD,omitempty"`
	Passcode                string                       `json:"PASSCODE,omitempty"`
	Authenticator           string                       `json:"AUTHENTICATOR,omitempty"`
	SessionParameters       map[string]interface{}       `json:"SESSION_PARAMETERS,omitempty"`
	ClientEnvironment       authRequestClientEnvironment `json:"CLIENT_ENVIRONMENT"`
	// NOTE(review): BrowserModeRedirectPort is not populated by authenticate;
	// confirm whether any caller still sets it.
	BrowserModeRedirectPort string                       `json:"BROWSER_MODE_REDIRECT_PORT,omitempty"`
	// ProofKey is only used by the external browser flow.
	ProofKey                string                       `json:"PROOF_KEY,omitempty"`
	// Token holds the OAuth token, the signed JWT, or the token returned by the
	// external browser flow, depending on the authenticator.
	Token                   string                       `json:"TOKEN,omitempty"`
}

// authRequest is the top-level JSON envelope of a login request.
type authRequest struct {
	Data authRequestData `json:"data"`
}
154
// nameValueParameter is a single name/value pair. authResponseMain.Parameters
// holds a list of them, which authenticateWithConfig passes to
// populateSessionParameters.
type nameValueParameter struct {
	Name  string      `json:"name"`
	Value interface{} `json:"value"`
}

// authResponseSessionInfo reports the session's database, schema, warehouse
// and role. For AuthTypeTokenAccessor, authenticate fills it from the config
// instead of from a server response.
type authResponseSessionInfo struct {
	DatabaseName  string `json:"databaseName"`
	SchemaName    string `json:"schemaName"`
	WarehouseName string `json:"warehouseName"`
	RoleName      string `json:"roleName"`
}

// authResponseMain is the "data" object of a login response.
//
// NOTE(review): Validity, MasterValidity, RemMeValidity and HealthCheckInterval
// are declared as time.Duration. encoding/json stores the raw JSON number in
// them, so validityInSeconds=3600 decodes to 3600ns, not one hour. Consumers
// must scale by time.Second (or the correct unit); confirm that they do.
type authResponseMain struct {
	Token               string                  `json:"token,omitempty"`
	Validity            time.Duration           `json:"validityInSeconds,omitempty"`
	MasterToken         string                  `json:"masterToken,omitempty"`
	MasterValidity      time.Duration           `json:"masterValidityInSeconds"`
	DisplayUserName     string                  `json:"displayUserName"`
	ServerVersion       string                  `json:"serverVersion"`
	FirstLogin          bool                    `json:"firstLogin"`
	RemMeToken          string                  `json:"remMeToken"`
	RemMeValidity       time.Duration           `json:"remMeValidityInSeconds"`
	HealthCheckInterval time.Duration           `json:"healthCheckInterval"`
	NewClientForUpgrade string                  `json:"newClientForUpgrade"`
	SessionID           int64                   `json:"sessionId"`
	Parameters          []nameValueParameter    `json:"parameters"`
	SessionInfo         authResponseSessionInfo `json:"sessionInfo"`
	// TokenURL, SSOURL and ProofKey are not read in this file; presumably they
	// are consumed by the SAML / external browser flows.
	TokenURL            string                  `json:"tokenUrl,omitempty"`
	SSOURL              string                  `json:"ssoUrl,omitempty"`
	ProofKey            string                  `json:"proofKey,omitempty"`
}

// authResponse is the envelope of a login response. When Success is false the
// login was rejected, and Code and Message describe the failure.
type authResponse struct {
	Data    authResponseMain `json:"data"`
	Message string           `json:"message"`
	Code    string           `json:"code"`
	Success bool             `json:"success"`
}
192
193func postAuth(
194	ctx context.Context,
195	sr *snowflakeRestful,
196	params *url.Values,
197	headers map[string]string,
198	body []byte,
199	timeout time.Duration) (
200	data *authResponse, err error) {
201	params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String())
202	params.Add(requestGUIDKey, uuid.New().String())
203
204	fullURL := sr.getFullURL(loginRequestPath, params)
205	logger.Infof("full URL: %v", fullURL)
206	resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true)
207	if err != nil {
208		return nil, err
209	}
210	defer resp.Body.Close()
211	if resp.StatusCode == http.StatusOK {
212		var respd authResponse
213		err = json.NewDecoder(resp.Body).Decode(&respd)
214		if err != nil {
215			logger.Error("failed to decode JSON. err: %v", err)
216			return nil, err
217		}
218		return &respd, nil
219	}
220	switch resp.StatusCode {
221	case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
222		// service availability or connectivity issue. Most likely server side issue.
223		return nil, &SnowflakeError{
224			Number:      ErrCodeServiceUnavailable,
225			SQLState:    SQLStateConnectionWasNotEstablished,
226			Message:     errMsgServiceUnavailable,
227			MessageArgs: []interface{}{resp.StatusCode, fullURL},
228		}
229	case http.StatusUnauthorized, http.StatusForbidden:
230		// failed to connect to db. account name may be wrong
231		return nil, &SnowflakeError{
232			Number:      ErrCodeFailedToConnect,
233			SQLState:    SQLStateConnectionRejected,
234			Message:     errMsgFailedToConnect,
235			MessageArgs: []interface{}{resp.StatusCode, fullURL},
236		}
237	}
238	b, err := ioutil.ReadAll(resp.Body)
239	if err != nil {
240		logger.Errorf("failed to extract HTTP response body. err: %v", err)
241		return nil, err
242	}
243	logger.Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b)
244	logger.Infof("Header: %v", resp.Header)
245	return nil, &SnowflakeError{
246		Number:      ErrFailedToAuth,
247		SQLState:    SQLStateConnectionRejected,
248		Message:     errMsgFailedToAuth,
249		MessageArgs: []interface{}{resp.StatusCode, fullURL},
250	}
251}
252
253// Generates a map of headers needed to authenticate
254// with Snowflake.
255func getHeaders() map[string]string {
256	headers := make(map[string]string)
257	headers[httpHeaderContentType] = headerContentTypeApplicationJSON
258	headers[httpHeaderAccept] = headerAcceptTypeApplicationSnowflake
259	headers[httpHeaderUserAgent] = userAgent
260	return headers
261}
262
263// Used to authenticate the user with Snowflake.
264func authenticate(
265	ctx context.Context,
266	sc *snowflakeConn,
267	samlResponse []byte,
268	proofKey []byte,
269) (resp *authResponseMain, err error) {
270	headers := getHeaders()
271	clientEnvironment := authRequestClientEnvironment{
272		Application: sc.cfg.Application,
273		Os:          operatingSystem,
274		OsVersion:   platform,
275		OCSPMode:    sc.cfg.ocspMode(),
276	}
277
278	sessionParameters := make(map[string]interface{})
279	for k, v := range sc.cfg.Params {
280		// upper casing to normalize keys
281		sessionParameters[strings.ToUpper(k)] = *v
282	}
283
284	sessionParameters[sessionClientValidateDefaultParameters] = sc.cfg.ValidateDefaultParameters != ConfigBoolFalse
285
286	requestMain := authRequestData{
287		ClientAppID:       clientType,
288		ClientAppVersion:  SnowflakeGoDriverVersion,
289		AccountName:       sc.cfg.Account,
290		SessionParameters: sessionParameters,
291		ClientEnvironment: clientEnvironment,
292	}
293
294	switch sc.cfg.Authenticator {
295	case AuthTypeExternalBrowser:
296		requestMain.ProofKey = string(proofKey)
297		requestMain.Token = string(samlResponse)
298		requestMain.LoginName = sc.cfg.User
299		requestMain.Authenticator = AuthTypeExternalBrowser.String()
300	case AuthTypeOAuth:
301		requestMain.LoginName = sc.cfg.User
302		requestMain.Authenticator = AuthTypeOAuth.String()
303		requestMain.Token = sc.cfg.Token
304	case AuthTypeOkta:
305		requestMain.RawSAMLResponse = string(samlResponse)
306	case AuthTypeJwt:
307		requestMain.Authenticator = AuthTypeJwt.String()
308
309		jwtTokenString, err := prepareJWTToken(sc.cfg)
310		if err != nil {
311			return nil, err
312		}
313		requestMain.Token = jwtTokenString
314	case AuthTypeSnowflake:
315		logger.Info("Username and password")
316		requestMain.LoginName = sc.cfg.User
317		requestMain.Password = sc.cfg.Password
318		switch {
319		case sc.cfg.PasscodeInPassword:
320			requestMain.ExtAuthnDuoMethod = "passcode"
321		case sc.cfg.Passcode != "":
322			requestMain.Passcode = sc.cfg.Passcode
323			requestMain.ExtAuthnDuoMethod = "passcode"
324		}
325	case AuthTypeTokenAccessor:
326		logger.Info("Bypass authentication using existing token from token accessor")
327		sessionInfo := authResponseSessionInfo{
328			DatabaseName:  sc.cfg.Database,
329			SchemaName:    sc.cfg.Schema,
330			WarehouseName: sc.cfg.Warehouse,
331			RoleName:      sc.cfg.Role,
332		}
333		token, masterToken, sessionID := sc.cfg.TokenAccessor.GetTokens()
334		return &authResponseMain{
335			Token:       token,
336			MasterToken: masterToken,
337			SessionID:   sessionID,
338			SessionInfo: sessionInfo,
339		}, nil
340	}
341
342	authRequest := authRequest{
343		Data: requestMain,
344	}
345	params := &url.Values{}
346	if sc.cfg.Database != "" {
347		params.Add("databaseName", sc.cfg.Database)
348	}
349	if sc.cfg.Schema != "" {
350		params.Add("schemaName", sc.cfg.Schema)
351	}
352	if sc.cfg.Warehouse != "" {
353		params.Add("warehouse", sc.cfg.Warehouse)
354	}
355	if sc.cfg.Role != "" {
356		params.Add("roleName", sc.cfg.Role)
357	}
358
359	jsonBody, err := json.Marshal(authRequest)
360	if err != nil {
361		return
362	}
363
364	logger.WithContext(sc.ctx).Infof("PARAMS for Auth: %v, %v, %v, %v, %v, %v",
365		params, sc.rest.Protocol, sc.rest.Host, sc.rest.Port, sc.rest.LoginTimeout, sc.cfg.Authenticator.String())
366
367	respd, err := sc.rest.FuncPostAuth(ctx, sc.rest, params, headers, jsonBody, sc.rest.LoginTimeout)
368	if err != nil {
369		return nil, err
370	}
371	if !respd.Success {
372		logger.Errorln("Authentication FAILED")
373		sc.rest.TokenAccessor.SetTokens("", "", -1)
374		code, err := strconv.Atoi(respd.Code)
375		if err != nil {
376			code = -1
377			return nil, err
378		}
379		return nil, &SnowflakeError{
380			Number:   code,
381			SQLState: SQLStateConnectionRejected,
382			Message:  respd.Message,
383		}
384	}
385	logger.Info("Authentication SUCCESS")
386	sc.rest.TokenAccessor.SetTokens(respd.Data.Token, respd.Data.MasterToken, respd.Data.SessionID)
387	return &respd.Data, nil
388}
389
390// Generate a JWT token in string given the configuration
391func prepareJWTToken(config *Config) (string, error) {
392	pubBytes, err := x509.MarshalPKIXPublicKey(config.PrivateKey.Public())
393	if err != nil {
394		return "", err
395	}
396	hash := sha256.Sum256(pubBytes)
397
398	accountName := strings.ToUpper(config.Account)
399	userName := strings.ToUpper(config.User)
400
401	issueAtTime := time.Now().UTC()
402	token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
403		"iss": fmt.Sprintf("%s.%s.%s", accountName, userName, "SHA256:"+base64.StdEncoding.EncodeToString(hash[:])),
404		"sub": fmt.Sprintf("%s.%s", accountName, userName),
405		"iat": issueAtTime.Unix(),
406		"nbf": time.Date(2015, 10, 10, 12, 0, 0, 0, time.UTC).Unix(),
407		"exp": issueAtTime.Add(config.JWTExpireTimeout).Unix(),
408	})
409
410	tokenString, err := token.SignedString(config.PrivateKey)
411
412	if err != nil {
413		return "", err
414	}
415
416	return tokenString, err
417}
418
419// Authenticate with sc.cfg
420func authenticateWithConfig(sc *snowflakeConn) error {
421	var authData *authResponseMain
422	var samlResponse []byte
423	var proofKey []byte
424	var err error
425	logger.Infof("Authenticating via %v", sc.cfg.Authenticator.String())
426	switch sc.cfg.Authenticator {
427	case AuthTypeExternalBrowser:
428		samlResponse, proofKey, err = authenticateByExternalBrowser(
429			sc.ctx,
430			sc.rest,
431			sc.cfg.Authenticator.String(),
432			sc.cfg.Application,
433			sc.cfg.Account,
434			sc.cfg.User,
435			sc.cfg.Password)
436		if err != nil {
437			sc.cleanup()
438			return err
439		}
440	case AuthTypeOkta:
441		samlResponse, err = authenticateBySAML(
442			sc.ctx,
443			sc.rest,
444			sc.cfg.OktaURL,
445			sc.cfg.Application,
446			sc.cfg.Account,
447			sc.cfg.User,
448			sc.cfg.Password)
449		if err != nil {
450			sc.cleanup()
451			return err
452		}
453	}
454	authData, err = authenticate(
455		sc.ctx,
456		sc,
457		samlResponse,
458		proofKey)
459	if err != nil {
460		sc.cleanup()
461		return err
462	}
463	sc.populateSessionParameters(authData.Parameters)
464	sc.ctx = context.WithValue(sc.ctx, SFSessionIDKey, authData.SessionID)
465	return nil
466}
467