1// Copyright (c) 2017-2021 Snowflake Computing Inc. All right reserved. 2 3package gosnowflake 4 5import ( 6 "crypto/rsa" 7 "encoding/base64" 8 "fmt" 9 "net" 10 "net/http" 11 "net/url" 12 "strconv" 13 "strings" 14 "time" 15) 16 17const ( 18 defaultClientTimeout = 900 * time.Second // Timeout for network round trip + read out http response 19 defaultLoginTimeout = 60 * time.Second // Timeout for retry for login EXCLUDING clientTimeout 20 defaultRequestTimeout = 0 * time.Second // Timeout for retry for request EXCLUDING clientTimeout 21 defaultJWTTimeout = 60 * time.Second 22 defaultDomain = ".snowflakecomputing.com" 23) 24 25// ConfigBool is a type to represent true or false in the Config 26type ConfigBool uint8 27 28const ( 29 configBoolNotSet ConfigBool = iota // Reserved for unset to let default value fall into this category 30 // ConfigBoolTrue represents true for the config field 31 ConfigBoolTrue 32 // ConfigBoolFalse represents false for the config field 33 ConfigBoolFalse 34) 35 36// Config is a set of configuration parameters 37type Config struct { 38 Account string // Account name 39 User string // Username 40 Password string // Password (requires User) 41 Database string // Database name 42 Schema string // Schema 43 Warehouse string // Warehouse 44 Role string // Role 45 Region string // Region 46 47 // ValidateDefaultParameters disable the validation checks for Database, Schema, Warehouse and Role 48 // at the time a connection is established 49 ValidateDefaultParameters ConfigBool 50 51 Params map[string]*string // other connection parameters 52 53 ClientIP net.IP // IP address for network check 54 Protocol string // http or https (optional) 55 Host string // hostname (optional) 56 Port int // port (optional) 57 58 Authenticator AuthType // The authenticator type 59 60 Passcode string 61 PasscodeInPassword bool 62 63 OktaURL *url.URL 64 65 LoginTimeout time.Duration // Login retry timeout EXCLUDING network roundtrip and read out http response 66 RequestTimeout time.Duration // request retry timeout EXCLUDING network roundtrip and read out http response 67 JWTExpireTimeout time.Duration // JWT expire after timeout 68 ClientTimeout time.Duration // Timeout for network round trip + read out http response 69 70 Application string // application name. 71 InsecureMode bool // driver doesn't check certificate revocation status 72 OCSPFailOpen OCSPFailOpenMode // OCSP Fail Open 73 74 Token string // Token to use for OAuth other forms of token based auth 75 TokenAccessor TokenAccessor // Optional token accessor to use 76 KeepSessionAlive bool // Enables the session to persist even after the connection is closed 77 78 PrivateKey *rsa.PrivateKey // Private key used to sign JWT 79 80 Transporter http.RoundTripper // RoundTripper to intercept HTTP requests and responses 81 82 DisableTelemetry bool // indicates whether to disable telemetry 83} 84 85// ocspMode returns the OCSP mode in string INSECURE, FAIL_OPEN, FAIL_CLOSED 86func (c *Config) ocspMode() string { 87 if c.InsecureMode { 88 return ocspModeInsecure 89 } else if c.OCSPFailOpen == ocspFailOpenNotSet || c.OCSPFailOpen == OCSPFailOpenTrue { 90 // by default or set to true 91 return ocspModeFailOpen 92 } 93 return ocspModeFailClosed 94} 95 96// DSN constructs a DSN for Snowflake db. 97func DSN(cfg *Config) (dsn string, err error) { 98 hasHost := true 99 if cfg.Host == "" { 100 hasHost = false 101 if cfg.Region == "us-west-2" { 102 cfg.Region = "" 103 } 104 if cfg.Region == "" { 105 cfg.Host = cfg.Account + defaultDomain 106 } else { 107 cfg.Host = cfg.Account + "." + cfg.Region + defaultDomain 108 } 109 } 110 // in case account includes region 111 posDot := strings.Index(cfg.Account, ".") 112 if posDot > 0 { 113 if cfg.Region != "" { 114 return "", ErrInvalidRegion 115 } 116 cfg.Region = cfg.Account[posDot+1:] 117 cfg.Account = cfg.Account[:posDot] 118 } 119 err = fillMissingConfigParameters(cfg) 120 if err != nil { 121 return "", err 122 } 123 params := &url.Values{} 124 if hasHost && cfg.Account != "" { 125 // account may not be included in a Host string 126 params.Add("account", cfg.Account) 127 } 128 if cfg.Database != "" { 129 params.Add("database", cfg.Database) 130 } 131 if cfg.Schema != "" { 132 params.Add("schema", cfg.Schema) 133 } 134 if cfg.Warehouse != "" { 135 params.Add("warehouse", cfg.Warehouse) 136 } 137 if cfg.Role != "" { 138 params.Add("role", cfg.Role) 139 } 140 if cfg.Region != "" { 141 params.Add("region", cfg.Region) 142 } 143 if cfg.Authenticator != AuthTypeSnowflake { 144 if cfg.Authenticator == AuthTypeOkta { 145 params.Add("authenticator", strings.ToLower(cfg.OktaURL.String())) 146 } else { 147 params.Add("authenticator", strings.ToLower(cfg.Authenticator.String())) 148 } 149 } 150 if cfg.Passcode != "" { 151 params.Add("passcode", cfg.Passcode) 152 } 153 if cfg.PasscodeInPassword { 154 params.Add("passcodeInPassword", strconv.FormatBool(cfg.PasscodeInPassword)) 155 } 156 if cfg.LoginTimeout != defaultLoginTimeout { 157 params.Add("loginTimeout", strconv.FormatInt(int64(cfg.LoginTimeout/time.Second), 10)) 158 } 159 if cfg.RequestTimeout != defaultRequestTimeout { 160 params.Add("requestTimeout", strconv.FormatInt(int64(cfg.RequestTimeout/time.Second), 10)) 161 } 162 if cfg.JWTExpireTimeout != defaultJWTTimeout { 163 params.Add("jwtTimeout", strconv.FormatInt(int64(cfg.JWTExpireTimeout/time.Second), 10)) 164 } 165 if cfg.Application != clientType { 166 params.Add("application", cfg.Application) 167 } 168 if cfg.Protocol != "" && cfg.Protocol != "https" { 169 params.Add("protocol", cfg.Protocol) 170 } 171 if cfg.Token != "" { 172 params.Add("token", cfg.Token) 173 } 174 if cfg.Params != nil { 175 for k, v := range cfg.Params { 176 params.Add(k, *v) 177 } 178 } 179 if cfg.PrivateKey != nil { 180 privateKeyInBytes, err := marshalPKCS8PrivateKey(cfg.PrivateKey) 181 if err != nil { 182 return "", err 183 } 184 keyBase64 := base64.URLEncoding.EncodeToString(privateKeyInBytes) 185 params.Add("privateKey", keyBase64) 186 } 187 if cfg.InsecureMode { 188 params.Add("insecureMode", strconv.FormatBool(cfg.InsecureMode)) 189 } 190 191 params.Add("ocspFailOpen", strconv.FormatBool(cfg.OCSPFailOpen != OCSPFailOpenFalse)) 192 193 params.Add("validateDefaultParameters", strconv.FormatBool(cfg.ValidateDefaultParameters != ConfigBoolFalse)) 194 195 dsn = fmt.Sprintf("%v:%v@%v:%v", url.QueryEscape(cfg.User), url.QueryEscape(cfg.Password), cfg.Host, cfg.Port) 196 if params.Encode() != "" { 197 dsn += "?" + params.Encode() 198 } 199 return 200} 201 202// ParseDSN parses the DSN string to a Config. 203func ParseDSN(dsn string) (cfg *Config, err error) { 204 // New config with some default values 205 cfg = &Config{ 206 Params: make(map[string]*string), 207 Authenticator: AuthTypeSnowflake, // Default to snowflake 208 } 209 210 // user[:password]@account/database/schema[?param1=value1¶mN=valueN] 211 // or 212 // user[:password]@account/database[?param1=value1¶mN=valueN] 213 // or 214 // user[:password]@host:port/database/schema?account=user_account[?param1=value1¶mN=valueN] 215 // or 216 // host:port/database/schema?account=user_account[?param1=value1¶mN=valueN] 217 218 foundSlash := false 219 secondSlash := false 220 done := false 221 var i int 222 posQuestion := len(dsn) 223 for i = len(dsn) - 1; i >= 0; i-- { 224 switch { 225 case dsn[i] == '/': 226 foundSlash = true 227 228 // left part is empty if i <= 0 229 var j int 230 posSecondSlash := i 231 if i > 0 { 232 for j = i - 1; j >= 0; j-- { 233 switch { 234 case dsn[j] == '/': 235 // second slash 236 secondSlash = true 237 posSecondSlash = j 238 case dsn[j] == '@': 239 // username[:password]@... 240 cfg.User, cfg.Password = parseUserPassword(j, dsn) 241 } 242 if dsn[j] == '@' { 243 break 244 } 245 } 246 247 // account or host:port 248 err = parseAccountHostPort(cfg, j, posSecondSlash, dsn) 249 if err != nil { 250 return nil, err 251 } 252 } 253 // [?param1=value1&...¶mN=valueN] 254 // Find the first '?' in dsn[i+1:] 255 err = parseParams(cfg, i, dsn) 256 if err != nil { 257 return 258 } 259 if secondSlash { 260 cfg.Database = dsn[posSecondSlash+1 : i] 261 cfg.Schema = dsn[i+1 : posQuestion] 262 } else { 263 cfg.Database = dsn[posSecondSlash+1 : posQuestion] 264 } 265 done = true 266 case dsn[i] == '?': 267 posQuestion = i 268 } 269 if done { 270 break 271 } 272 } 273 if !foundSlash { 274 // no db or schema is specified 275 var j int 276 for j = len(dsn) - 1; j >= 0; j-- { 277 switch { 278 case dsn[j] == '@': 279 cfg.User, cfg.Password = parseUserPassword(j, dsn) 280 case dsn[j] == '?': 281 posQuestion = j 282 } 283 if dsn[j] == '@' { 284 break 285 } 286 } 287 err = parseAccountHostPort(cfg, j, posQuestion, dsn) 288 if err != nil { 289 return nil, err 290 } 291 err = parseParams(cfg, posQuestion-1, dsn) 292 if err != nil { 293 return 294 } 295 } 296 if cfg.Account == "" && strings.HasSuffix(cfg.Host, defaultDomain) { 297 posDot := strings.Index(cfg.Host, ".") 298 if posDot > 0 { 299 cfg.Account = cfg.Host[:posDot] 300 } 301 } 302 posDot := strings.Index(cfg.Account, ".") 303 if posDot >= 0 { 304 cfg.Account = cfg.Account[:posDot] 305 } 306 307 err = fillMissingConfigParameters(cfg) 308 if err != nil { 309 return nil, err 310 } 311 312 // unescape parameters 313 var s string 314 s, err = url.QueryUnescape(cfg.User) 315 if err != nil { 316 return nil, err 317 } 318 cfg.User = s 319 s, err = url.QueryUnescape(cfg.Password) 320 if err != nil { 321 return nil, err 322 } 323 cfg.Password = s 324 s, err = url.QueryUnescape(cfg.Database) 325 if err != nil { 326 return nil, err 327 } 328 cfg.Database = s 329 s, err = url.QueryUnescape(cfg.Schema) 330 if err != nil { 331 return nil, err 332 } 333 cfg.Schema = s 334 s, err = url.QueryUnescape(cfg.Role) 335 if err != nil { 336 return nil, err 337 } 338 cfg.Role = s 339 s, err = url.QueryUnescape(cfg.Warehouse) 340 if err != nil { 341 return nil, err 342 } 343 cfg.Warehouse = s 344 return cfg, nil 345} 346 347func fillMissingConfigParameters(cfg *Config) error { 348 posDash := strings.LastIndex(cfg.Account, "-") 349 if posDash > 0 { 350 if strings.Contains(cfg.Host, ".global.") { 351 cfg.Account = cfg.Account[:posDash] 352 } 353 } 354 if strings.Trim(cfg.Account, " ") == "" { 355 return ErrEmptyAccount 356 } 357 358 if cfg.Authenticator != AuthTypeOAuth && strings.Trim(cfg.User, " ") == "" { 359 // oauth does not require a username 360 return ErrEmptyUsername 361 } 362 363 if cfg.Authenticator != AuthTypeExternalBrowser && 364 cfg.Authenticator != AuthTypeOAuth && 365 cfg.Authenticator != AuthTypeJwt && 366 strings.Trim(cfg.Password, " ") == "" { 367 // no password parameter is required for EXTERNALBROWSER, OAUTH or JWT. 368 return ErrEmptyPassword 369 } 370 if strings.Trim(cfg.Protocol, " ") == "" { 371 cfg.Protocol = "https" 372 } 373 if cfg.Port == 0 { 374 cfg.Port = 443 375 } 376 377 cfg.Region = strings.Trim(cfg.Region, " ") 378 if cfg.Region != "" { 379 // region is specified but not included in Host 380 i := strings.Index(cfg.Host, defaultDomain) 381 if i >= 1 { 382 hostPrefix := cfg.Host[0:i] 383 if !strings.HasSuffix(hostPrefix, cfg.Region) { 384 cfg.Host = hostPrefix + "." + cfg.Region + defaultDomain 385 } 386 } 387 } 388 if cfg.Host == "" { 389 if cfg.Region != "" { 390 cfg.Host = cfg.Account + "." + cfg.Region + defaultDomain 391 } else { 392 cfg.Host = cfg.Account + defaultDomain 393 } 394 } 395 if cfg.LoginTimeout == 0 { 396 cfg.LoginTimeout = defaultLoginTimeout 397 } 398 if cfg.RequestTimeout == 0 { 399 cfg.RequestTimeout = defaultRequestTimeout 400 } 401 if cfg.JWTExpireTimeout == 0 { 402 cfg.JWTExpireTimeout = defaultJWTTimeout 403 } 404 if cfg.ClientTimeout == 0 { 405 cfg.ClientTimeout = defaultClientTimeout 406 } 407 if strings.Trim(cfg.Application, " ") == "" { 408 cfg.Application = clientType 409 } 410 411 if cfg.OCSPFailOpen == ocspFailOpenNotSet { 412 cfg.OCSPFailOpen = OCSPFailOpenTrue 413 } 414 415 if cfg.ValidateDefaultParameters == configBoolNotSet { 416 cfg.ValidateDefaultParameters = ConfigBoolTrue 417 } 418 419 if strings.HasSuffix(cfg.Host, defaultDomain) && len(cfg.Host) == len(defaultDomain) { 420 return &SnowflakeError{ 421 Number: ErrCodeFailedToParseHost, 422 Message: errMsgFailedToParseHost, 423 MessageArgs: []interface{}{cfg.Host}, 424 } 425 } 426 return nil 427} 428 429// transformAccountToHost transforms host to account name 430func transformAccountToHost(cfg *Config) (err error) { 431 if cfg.Port == 0 && !strings.HasSuffix(cfg.Host, defaultDomain) && cfg.Host != "" { 432 // account name is specified instead of host:port 433 cfg.Account = cfg.Host 434 cfg.Host = cfg.Account + defaultDomain 435 cfg.Port = 443 436 posDot := strings.Index(cfg.Account, ".") 437 if posDot > 0 { 438 cfg.Region = cfg.Account[posDot+1:] 439 cfg.Account = cfg.Account[:posDot] 440 } 441 } 442 return nil 443} 444 445// parseAccountHostPort parses the DSN string to attempt to get account or host and port. 446func parseAccountHostPort(cfg *Config, posAt, posSlash int, dsn string) (err error) { 447 // account or host:port 448 var k int 449 for k = posAt + 1; k < posSlash; k++ { 450 if dsn[k] == ':' { 451 cfg.Port, err = strconv.Atoi(dsn[k+1 : posSlash]) 452 if err != nil { 453 err = &SnowflakeError{ 454 Number: ErrCodeFailedToParsePort, 455 Message: errMsgFailedToParsePort, 456 MessageArgs: []interface{}{dsn[k+1 : posSlash]}, 457 } 458 return 459 } 460 break 461 } 462 } 463 cfg.Host = dsn[posAt+1 : k] 464 return transformAccountToHost(cfg) 465} 466 467// parseUserPassword parses the DSN string for username and password 468func parseUserPassword(posAt int, dsn string) (user, password string) { 469 var k int 470 for k = 0; k < posAt; k++ { 471 if dsn[k] == ':' { 472 password = dsn[k+1 : posAt] 473 break 474 } 475 } 476 user = dsn[:k] 477 return 478} 479 480// parseParams parse parameters 481func parseParams(cfg *Config, posQuestion int, dsn string) (err error) { 482 for j := posQuestion + 1; j < len(dsn); j++ { 483 if dsn[j] == '?' { 484 if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { 485 return 486 } 487 break 488 } 489 } 490 return 491} 492 493// parseDSNParams parses the DSN "query string". Values must be url.QueryEscape'ed 494func parseDSNParams(cfg *Config, params string) (err error) { 495 logger.Infof("Query String: %v\n", params) 496 for _, v := range strings.Split(params, "&") { 497 param := strings.SplitN(v, "=", 2) 498 if len(param) != 2 { 499 continue 500 } 501 var value string 502 value, err = url.QueryUnescape(param[1]) 503 if err != nil { 504 return err 505 } 506 switch param[0] { 507 // Disable INFILE whitelist / enable all files 508 case "account": 509 cfg.Account = value 510 case "warehouse": 511 cfg.Warehouse = value 512 case "database": 513 cfg.Database = value 514 case "schema": 515 cfg.Schema = value 516 case "role": 517 cfg.Role = value 518 case "region": 519 cfg.Region = value 520 case "protocol": 521 cfg.Protocol = value 522 case "passcode": 523 cfg.Passcode = value 524 case "passcodeInPassword": 525 var vv bool 526 vv, err = strconv.ParseBool(value) 527 if err != nil { 528 return 529 } 530 cfg.PasscodeInPassword = vv 531 case "loginTimeout": 532 cfg.LoginTimeout, err = parseTimeout(value) 533 if err != nil { 534 return 535 } 536 case "requestTimeout": 537 cfg.RequestTimeout, err = parseTimeout(value) 538 if err != nil { 539 return 540 } 541 case "jwtTimeout": 542 cfg.JWTExpireTimeout, err = parseTimeout(value) 543 if err != nil { 544 return err 545 } 546 case "application": 547 cfg.Application = value 548 case "authenticator": 549 err := determineAuthenticatorType(cfg, value) 550 if err != nil { 551 return err 552 } 553 case "insecureMode": 554 var vv bool 555 vv, err = strconv.ParseBool(value) 556 if err != nil { 557 return 558 } 559 cfg.InsecureMode = vv 560 case "ocspFailOpen": 561 var vv bool 562 vv, err = strconv.ParseBool(value) 563 if err != nil { 564 return 565 } 566 if vv { 567 cfg.OCSPFailOpen = OCSPFailOpenTrue 568 } else { 569 cfg.OCSPFailOpen = OCSPFailOpenFalse 570 } 571 572 case "token": 573 cfg.Token = value 574 case "privateKey": 575 var decodeErr error 576 block, decodeErr := base64.URLEncoding.DecodeString(value) 577 if decodeErr != nil { 578 err = &SnowflakeError{ 579 Number: ErrCodePrivateKeyParseError, 580 Message: "Base64 decode failed", 581 } 582 return 583 } 584 cfg.PrivateKey, err = parsePKCS8PrivateKey(block) 585 if err != nil { 586 return err 587 } 588 case "validateDefaultParameters": 589 var vv bool 590 vv, err = strconv.ParseBool(value) 591 if err != nil { 592 return 593 } 594 if vv { 595 cfg.ValidateDefaultParameters = ConfigBoolTrue 596 } else { 597 cfg.ValidateDefaultParameters = ConfigBoolFalse 598 } 599 default: 600 if cfg.Params == nil { 601 cfg.Params = make(map[string]*string) 602 } 603 cfg.Params[param[0]] = &value 604 } 605 } 606 return 607} 608 609func parseTimeout(value string) (time.Duration, error) { 610 var vv int64 611 var err error 612 vv, err = strconv.ParseInt(value, 10, 64) 613 if err != nil { 614 return time.Duration(0), err 615 } 616 return time.Duration(vv * int64(time.Second)), nil 617} 618