1// Copyright (c) 2017-2021 Snowflake Computing Inc. All right reserved.
2
3package gosnowflake
4
5import (
6	"crypto/rsa"
7	"encoding/base64"
8	"fmt"
9	"net"
10	"net/http"
11	"net/url"
12	"strconv"
13	"strings"
14	"time"
15)
16
// Defaults applied by fillMissingConfigParameters when the matching Config
// field is left at its zero value, plus the Snowflake host-name suffix.
const (
	// defaultClientTimeout bounds a network round trip plus reading out the HTTP response.
	defaultClientTimeout = 15 * time.Minute
	// defaultLoginTimeout bounds login retries, EXCLUDING clientTimeout.
	defaultLoginTimeout = time.Minute
	// defaultRequestTimeout bounds request retries, EXCLUDING clientTimeout.
	// NOTE(review): zero presumably means "no retry deadline" — confirm in the retry logic.
	defaultRequestTimeout = time.Duration(0)
	// defaultJWTTimeout is the default JWT expiry.
	defaultJWTTimeout = time.Minute
	// defaultDomain is appended to the account (and region) to form the host name.
	defaultDomain = ".snowflakecomputing.com"
)
24
// ConfigBool is a tri-state boolean for Config fields: its zero value means
// "not set" (so a default can be applied), distinct from explicit true and
// explicit false.
type ConfigBool uint8

// ConfigBool values. configBoolNotSet must remain the zero value so that an
// untouched Config field reads as "not set".
const (
	configBoolNotSet ConfigBool = 0 // Reserved for unset to let default value fall into this category
	// ConfigBoolTrue represents true for the config field
	ConfigBoolTrue ConfigBool = 1
	// ConfigBoolFalse represents false for the config field
	ConfigBoolFalse ConfigBool = 2
)
35
// Config is a set of configuration parameters for a Snowflake connection.
//
// DSN serializes a Config into a DSN string and ParseDSN builds one from a
// DSN string. ClientIP, ClientTimeout, TokenAccessor, KeepSessionAlive,
// Transporter and DisableTelemetry have no DSN parameter, so they do not
// survive that round trip and must be set directly on the Config.
type Config struct {
	Account   string // Account name
	User      string // Username
	Password  string // Password (requires User)
	Database  string // Database name
	Schema    string // Schema
	Warehouse string // Warehouse
	Role      string // Role
	Region    string // Region; when Host is empty, DSN treats "us-west-2" as no region

	// ValidateDefaultParameters controls the validation checks for Database, Schema,
	// Warehouse and Role at the time a connection is established. Set it to
	// ConfigBoolFalse to disable them; the unset value is treated as ConfigBoolTrue.
	ValidateDefaultParameters ConfigBool

	Params map[string]*string // other connection parameters; unrecognized DSN query keys land here

	ClientIP net.IP // IP address for network check
	Protocol string // http or https (optional; defaults to "https")
	Host     string // hostname (optional; derived from Account and Region when empty)
	Port     int    // port (optional; defaults to 443)

	Authenticator AuthType // The authenticator type

	// Passcode and PasscodeInPassword travel as the "passcode" and
	// "passcodeInPassword" DSN parameters.
	// NOTE(review): presumably used for MFA authentication — confirm.
	Passcode           string
	PasscodeInPassword bool

	// OktaURL is written (lowercased) as the "authenticator" DSN parameter
	// when Authenticator is AuthTypeOkta, so it must be non-nil in that case.
	OktaURL *url.URL

	// Zero timeouts are replaced by defaults: LoginTimeout 60s, RequestTimeout 0,
	// JWTExpireTimeout 60s, ClientTimeout 900s. DSN writes the first three in
	// whole seconds (sub-second parts are truncated); ClientTimeout is not written.
	LoginTimeout     time.Duration // Login retry timeout EXCLUDING network roundtrip and read out http response
	RequestTimeout   time.Duration // request retry timeout EXCLUDING network roundtrip and read out http response
	JWTExpireTimeout time.Duration // JWT expire after timeout
	ClientTimeout    time.Duration // Timeout for network round trip + read out http response

	Application  string           // application name; defaults to clientType when blank
	InsecureMode bool             // driver doesn't check certificate revocation status; takes precedence over OCSPFailOpen
	OCSPFailOpen OCSPFailOpenMode // OCSP Fail Open; unset is treated as fail-open

	Token            string        // Token to use for OAuth or other forms of token based auth
	TokenAccessor    TokenAccessor // Optional token accessor to use
	KeepSessionAlive bool          // Enables the session to persist even after the connection is closed

	PrivateKey *rsa.PrivateKey // Private key used to sign JWT; DSN carries it base64url-encoded via marshalPKCS8PrivateKey

	Transporter http.RoundTripper // RoundTripper to intercept HTTP requests and responses

	DisableTelemetry bool // indicates whether to disable telemetry
}
84
85// ocspMode returns the OCSP mode in string INSECURE, FAIL_OPEN, FAIL_CLOSED
86func (c *Config) ocspMode() string {
87	if c.InsecureMode {
88		return ocspModeInsecure
89	} else if c.OCSPFailOpen == ocspFailOpenNotSet || c.OCSPFailOpen == OCSPFailOpenTrue {
90		// by default or set to true
91		return ocspModeFailOpen
92	}
93	return ocspModeFailClosed
94}
95
96// DSN constructs a DSN for Snowflake db.
97func DSN(cfg *Config) (dsn string, err error) {
98	hasHost := true
99	if cfg.Host == "" {
100		hasHost = false
101		if cfg.Region == "us-west-2" {
102			cfg.Region = ""
103		}
104		if cfg.Region == "" {
105			cfg.Host = cfg.Account + defaultDomain
106		} else {
107			cfg.Host = cfg.Account + "." + cfg.Region + defaultDomain
108		}
109	}
110	// in case account includes region
111	posDot := strings.Index(cfg.Account, ".")
112	if posDot > 0 {
113		if cfg.Region != "" {
114			return "", ErrInvalidRegion
115		}
116		cfg.Region = cfg.Account[posDot+1:]
117		cfg.Account = cfg.Account[:posDot]
118	}
119	err = fillMissingConfigParameters(cfg)
120	if err != nil {
121		return "", err
122	}
123	params := &url.Values{}
124	if hasHost && cfg.Account != "" {
125		// account may not be included in a Host string
126		params.Add("account", cfg.Account)
127	}
128	if cfg.Database != "" {
129		params.Add("database", cfg.Database)
130	}
131	if cfg.Schema != "" {
132		params.Add("schema", cfg.Schema)
133	}
134	if cfg.Warehouse != "" {
135		params.Add("warehouse", cfg.Warehouse)
136	}
137	if cfg.Role != "" {
138		params.Add("role", cfg.Role)
139	}
140	if cfg.Region != "" {
141		params.Add("region", cfg.Region)
142	}
143	if cfg.Authenticator != AuthTypeSnowflake {
144		if cfg.Authenticator == AuthTypeOkta {
145			params.Add("authenticator", strings.ToLower(cfg.OktaURL.String()))
146		} else {
147			params.Add("authenticator", strings.ToLower(cfg.Authenticator.String()))
148		}
149	}
150	if cfg.Passcode != "" {
151		params.Add("passcode", cfg.Passcode)
152	}
153	if cfg.PasscodeInPassword {
154		params.Add("passcodeInPassword", strconv.FormatBool(cfg.PasscodeInPassword))
155	}
156	if cfg.LoginTimeout != defaultLoginTimeout {
157		params.Add("loginTimeout", strconv.FormatInt(int64(cfg.LoginTimeout/time.Second), 10))
158	}
159	if cfg.RequestTimeout != defaultRequestTimeout {
160		params.Add("requestTimeout", strconv.FormatInt(int64(cfg.RequestTimeout/time.Second), 10))
161	}
162	if cfg.JWTExpireTimeout != defaultJWTTimeout {
163		params.Add("jwtTimeout", strconv.FormatInt(int64(cfg.JWTExpireTimeout/time.Second), 10))
164	}
165	if cfg.Application != clientType {
166		params.Add("application", cfg.Application)
167	}
168	if cfg.Protocol != "" && cfg.Protocol != "https" {
169		params.Add("protocol", cfg.Protocol)
170	}
171	if cfg.Token != "" {
172		params.Add("token", cfg.Token)
173	}
174	if cfg.Params != nil {
175		for k, v := range cfg.Params {
176			params.Add(k, *v)
177		}
178	}
179	if cfg.PrivateKey != nil {
180		privateKeyInBytes, err := marshalPKCS8PrivateKey(cfg.PrivateKey)
181		if err != nil {
182			return "", err
183		}
184		keyBase64 := base64.URLEncoding.EncodeToString(privateKeyInBytes)
185		params.Add("privateKey", keyBase64)
186	}
187	if cfg.InsecureMode {
188		params.Add("insecureMode", strconv.FormatBool(cfg.InsecureMode))
189	}
190
191	params.Add("ocspFailOpen", strconv.FormatBool(cfg.OCSPFailOpen != OCSPFailOpenFalse))
192
193	params.Add("validateDefaultParameters", strconv.FormatBool(cfg.ValidateDefaultParameters != ConfigBoolFalse))
194
195	dsn = fmt.Sprintf("%v:%v@%v:%v", url.QueryEscape(cfg.User), url.QueryEscape(cfg.Password), cfg.Host, cfg.Port)
196	if params.Encode() != "" {
197		dsn += "?" + params.Encode()
198	}
199	return
200}
201
202// ParseDSN parses the DSN string to a Config.
203func ParseDSN(dsn string) (cfg *Config, err error) {
204	// New config with some default values
205	cfg = &Config{
206		Params:        make(map[string]*string),
207		Authenticator: AuthTypeSnowflake, // Default to snowflake
208	}
209
210	// user[:password]@account/database/schema[?param1=value1&paramN=valueN]
211	// or
212	// user[:password]@account/database[?param1=value1&paramN=valueN]
213	// or
214	// user[:password]@host:port/database/schema?account=user_account[?param1=value1&paramN=valueN]
215	// or
216	// host:port/database/schema?account=user_account[?param1=value1&paramN=valueN]
217
218	foundSlash := false
219	secondSlash := false
220	done := false
221	var i int
222	posQuestion := len(dsn)
223	for i = len(dsn) - 1; i >= 0; i-- {
224		switch {
225		case dsn[i] == '/':
226			foundSlash = true
227
228			// left part is empty if i <= 0
229			var j int
230			posSecondSlash := i
231			if i > 0 {
232				for j = i - 1; j >= 0; j-- {
233					switch {
234					case dsn[j] == '/':
235						// second slash
236						secondSlash = true
237						posSecondSlash = j
238					case dsn[j] == '@':
239						// username[:password]@...
240						cfg.User, cfg.Password = parseUserPassword(j, dsn)
241					}
242					if dsn[j] == '@' {
243						break
244					}
245				}
246
247				// account or host:port
248				err = parseAccountHostPort(cfg, j, posSecondSlash, dsn)
249				if err != nil {
250					return nil, err
251				}
252			}
253			// [?param1=value1&...&paramN=valueN]
254			// Find the first '?' in dsn[i+1:]
255			err = parseParams(cfg, i, dsn)
256			if err != nil {
257				return
258			}
259			if secondSlash {
260				cfg.Database = dsn[posSecondSlash+1 : i]
261				cfg.Schema = dsn[i+1 : posQuestion]
262			} else {
263				cfg.Database = dsn[posSecondSlash+1 : posQuestion]
264			}
265			done = true
266		case dsn[i] == '?':
267			posQuestion = i
268		}
269		if done {
270			break
271		}
272	}
273	if !foundSlash {
274		// no db or schema is specified
275		var j int
276		for j = len(dsn) - 1; j >= 0; j-- {
277			switch {
278			case dsn[j] == '@':
279				cfg.User, cfg.Password = parseUserPassword(j, dsn)
280			case dsn[j] == '?':
281				posQuestion = j
282			}
283			if dsn[j] == '@' {
284				break
285			}
286		}
287		err = parseAccountHostPort(cfg, j, posQuestion, dsn)
288		if err != nil {
289			return nil, err
290		}
291		err = parseParams(cfg, posQuestion-1, dsn)
292		if err != nil {
293			return
294		}
295	}
296	if cfg.Account == "" && strings.HasSuffix(cfg.Host, defaultDomain) {
297		posDot := strings.Index(cfg.Host, ".")
298		if posDot > 0 {
299			cfg.Account = cfg.Host[:posDot]
300		}
301	}
302	posDot := strings.Index(cfg.Account, ".")
303	if posDot >= 0 {
304		cfg.Account = cfg.Account[:posDot]
305	}
306
307	err = fillMissingConfigParameters(cfg)
308	if err != nil {
309		return nil, err
310	}
311
312	// unescape parameters
313	var s string
314	s, err = url.QueryUnescape(cfg.User)
315	if err != nil {
316		return nil, err
317	}
318	cfg.User = s
319	s, err = url.QueryUnescape(cfg.Password)
320	if err != nil {
321		return nil, err
322	}
323	cfg.Password = s
324	s, err = url.QueryUnescape(cfg.Database)
325	if err != nil {
326		return nil, err
327	}
328	cfg.Database = s
329	s, err = url.QueryUnescape(cfg.Schema)
330	if err != nil {
331		return nil, err
332	}
333	cfg.Schema = s
334	s, err = url.QueryUnescape(cfg.Role)
335	if err != nil {
336		return nil, err
337	}
338	cfg.Role = s
339	s, err = url.QueryUnescape(cfg.Warehouse)
340	if err != nil {
341		return nil, err
342	}
343	cfg.Warehouse = s
344	return cfg, nil
345}
346
347func fillMissingConfigParameters(cfg *Config) error {
348	posDash := strings.LastIndex(cfg.Account, "-")
349	if posDash > 0 {
350		if strings.Contains(cfg.Host, ".global.") {
351			cfg.Account = cfg.Account[:posDash]
352		}
353	}
354	if strings.Trim(cfg.Account, " ") == "" {
355		return ErrEmptyAccount
356	}
357
358	if cfg.Authenticator != AuthTypeOAuth && strings.Trim(cfg.User, " ") == "" {
359		// oauth does not require a username
360		return ErrEmptyUsername
361	}
362
363	if cfg.Authenticator != AuthTypeExternalBrowser &&
364		cfg.Authenticator != AuthTypeOAuth &&
365		cfg.Authenticator != AuthTypeJwt &&
366		strings.Trim(cfg.Password, " ") == "" {
367		// no password parameter is required for EXTERNALBROWSER, OAUTH or JWT.
368		return ErrEmptyPassword
369	}
370	if strings.Trim(cfg.Protocol, " ") == "" {
371		cfg.Protocol = "https"
372	}
373	if cfg.Port == 0 {
374		cfg.Port = 443
375	}
376
377	cfg.Region = strings.Trim(cfg.Region, " ")
378	if cfg.Region != "" {
379		// region is specified but not included in Host
380		i := strings.Index(cfg.Host, defaultDomain)
381		if i >= 1 {
382			hostPrefix := cfg.Host[0:i]
383			if !strings.HasSuffix(hostPrefix, cfg.Region) {
384				cfg.Host = hostPrefix + "." + cfg.Region + defaultDomain
385			}
386		}
387	}
388	if cfg.Host == "" {
389		if cfg.Region != "" {
390			cfg.Host = cfg.Account + "." + cfg.Region + defaultDomain
391		} else {
392			cfg.Host = cfg.Account + defaultDomain
393		}
394	}
395	if cfg.LoginTimeout == 0 {
396		cfg.LoginTimeout = defaultLoginTimeout
397	}
398	if cfg.RequestTimeout == 0 {
399		cfg.RequestTimeout = defaultRequestTimeout
400	}
401	if cfg.JWTExpireTimeout == 0 {
402		cfg.JWTExpireTimeout = defaultJWTTimeout
403	}
404	if cfg.ClientTimeout == 0 {
405		cfg.ClientTimeout = defaultClientTimeout
406	}
407	if strings.Trim(cfg.Application, " ") == "" {
408		cfg.Application = clientType
409	}
410
411	if cfg.OCSPFailOpen == ocspFailOpenNotSet {
412		cfg.OCSPFailOpen = OCSPFailOpenTrue
413	}
414
415	if cfg.ValidateDefaultParameters == configBoolNotSet {
416		cfg.ValidateDefaultParameters = ConfigBoolTrue
417	}
418
419	if strings.HasSuffix(cfg.Host, defaultDomain) && len(cfg.Host) == len(defaultDomain) {
420		return &SnowflakeError{
421			Number:      ErrCodeFailedToParseHost,
422			Message:     errMsgFailedToParseHost,
423			MessageArgs: []interface{}{cfg.Host},
424		}
425	}
426	return nil
427}
428
429// transformAccountToHost transforms host to account name
430func transformAccountToHost(cfg *Config) (err error) {
431	if cfg.Port == 0 && !strings.HasSuffix(cfg.Host, defaultDomain) && cfg.Host != "" {
432		// account name is specified instead of host:port
433		cfg.Account = cfg.Host
434		cfg.Host = cfg.Account + defaultDomain
435		cfg.Port = 443
436		posDot := strings.Index(cfg.Account, ".")
437		if posDot > 0 {
438			cfg.Region = cfg.Account[posDot+1:]
439			cfg.Account = cfg.Account[:posDot]
440		}
441	}
442	return nil
443}
444
445// parseAccountHostPort parses the DSN string to attempt to get account or host and port.
446func parseAccountHostPort(cfg *Config, posAt, posSlash int, dsn string) (err error) {
447	// account or host:port
448	var k int
449	for k = posAt + 1; k < posSlash; k++ {
450		if dsn[k] == ':' {
451			cfg.Port, err = strconv.Atoi(dsn[k+1 : posSlash])
452			if err != nil {
453				err = &SnowflakeError{
454					Number:      ErrCodeFailedToParsePort,
455					Message:     errMsgFailedToParsePort,
456					MessageArgs: []interface{}{dsn[k+1 : posSlash]},
457				}
458				return
459			}
460			break
461		}
462	}
463	cfg.Host = dsn[posAt+1 : k]
464	return transformAccountToHost(cfg)
465}
466
// parseUserPassword splits dsn[:posAt] (the text before '@') at its first ':'
// into user and password. Without a ':' the whole prefix is the user and the
// password is empty. Values are returned still escaped.
func parseUserPassword(posAt int, dsn string) (user, password string) {
	credentials := dsn[:posAt]
	for idx := 0; idx < len(credentials); idx++ {
		if credentials[idx] == ':' {
			return credentials[:idx], credentials[idx+1:]
		}
	}
	return credentials, ""
}
479
480// parseParams parse parameters
481func parseParams(cfg *Config, posQuestion int, dsn string) (err error) {
482	for j := posQuestion + 1; j < len(dsn); j++ {
483		if dsn[j] == '?' {
484			if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
485				return
486			}
487			break
488		}
489	}
490	return
491}
492
493// parseDSNParams parses the DSN "query string". Values must be url.QueryEscape'ed
494func parseDSNParams(cfg *Config, params string) (err error) {
495	logger.Infof("Query String: %v\n", params)
496	for _, v := range strings.Split(params, "&") {
497		param := strings.SplitN(v, "=", 2)
498		if len(param) != 2 {
499			continue
500		}
501		var value string
502		value, err = url.QueryUnescape(param[1])
503		if err != nil {
504			return err
505		}
506		switch param[0] {
507		// Disable INFILE whitelist / enable all files
508		case "account":
509			cfg.Account = value
510		case "warehouse":
511			cfg.Warehouse = value
512		case "database":
513			cfg.Database = value
514		case "schema":
515			cfg.Schema = value
516		case "role":
517			cfg.Role = value
518		case "region":
519			cfg.Region = value
520		case "protocol":
521			cfg.Protocol = value
522		case "passcode":
523			cfg.Passcode = value
524		case "passcodeInPassword":
525			var vv bool
526			vv, err = strconv.ParseBool(value)
527			if err != nil {
528				return
529			}
530			cfg.PasscodeInPassword = vv
531		case "loginTimeout":
532			cfg.LoginTimeout, err = parseTimeout(value)
533			if err != nil {
534				return
535			}
536		case "requestTimeout":
537			cfg.RequestTimeout, err = parseTimeout(value)
538			if err != nil {
539				return
540			}
541		case "jwtTimeout":
542			cfg.JWTExpireTimeout, err = parseTimeout(value)
543			if err != nil {
544				return err
545			}
546		case "application":
547			cfg.Application = value
548		case "authenticator":
549			err := determineAuthenticatorType(cfg, value)
550			if err != nil {
551				return err
552			}
553		case "insecureMode":
554			var vv bool
555			vv, err = strconv.ParseBool(value)
556			if err != nil {
557				return
558			}
559			cfg.InsecureMode = vv
560		case "ocspFailOpen":
561			var vv bool
562			vv, err = strconv.ParseBool(value)
563			if err != nil {
564				return
565			}
566			if vv {
567				cfg.OCSPFailOpen = OCSPFailOpenTrue
568			} else {
569				cfg.OCSPFailOpen = OCSPFailOpenFalse
570			}
571
572		case "token":
573			cfg.Token = value
574		case "privateKey":
575			var decodeErr error
576			block, decodeErr := base64.URLEncoding.DecodeString(value)
577			if decodeErr != nil {
578				err = &SnowflakeError{
579					Number:  ErrCodePrivateKeyParseError,
580					Message: "Base64 decode failed",
581				}
582				return
583			}
584			cfg.PrivateKey, err = parsePKCS8PrivateKey(block)
585			if err != nil {
586				return err
587			}
588		case "validateDefaultParameters":
589			var vv bool
590			vv, err = strconv.ParseBool(value)
591			if err != nil {
592				return
593			}
594			if vv {
595				cfg.ValidateDefaultParameters = ConfigBoolTrue
596			} else {
597				cfg.ValidateDefaultParameters = ConfigBoolFalse
598			}
599		default:
600			if cfg.Params == nil {
601				cfg.Params = make(map[string]*string)
602			}
603			cfg.Params[param[0]] = &value
604		}
605	}
606	return
607}
608
// parseTimeout parses value as a whole number of seconds and returns it as a
// time.Duration. Negative values are accepted as-is. A non-integer value
// yields the strconv.ParseInt error. A number of seconds that does not fit in
// a time.Duration yields a *strconv.NumError wrapping strconv.ErrRange;
// without that check the multiplication would silently wrap around.
func parseTimeout(value string) (time.Duration, error) {
	// Bounds, in seconds, of what an int64 count of nanoseconds can hold.
	const (
		maxSeconds = int64(1<<63-1) / int64(time.Second)
		minSeconds = int64(-1<<63) / int64(time.Second)
	)
	seconds, err := strconv.ParseInt(value, 10, 64)
	if err != nil {
		return 0, err
	}
	if seconds > maxSeconds || seconds < minSeconds {
		return 0, &strconv.NumError{Func: "parseTimeout", Num: value, Err: strconv.ErrRange}
	}
	return time.Duration(seconds) * time.Second, nil
}
618