1// Copyright (C) MongoDB, Inc. 2014-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package connstring 8 9import ( 10 "errors" 11 "fmt" 12 "net" 13 "net/url" 14 "runtime" 15 "strconv" 16 "strings" 17 "time" 18) 19 20// Parse parses the provided uri and returns a URI object. 21func ParseURIConnectionString(s string) (ConnString, error) { 22 var p parser 23 err := p.parse(s) 24 if err != nil { 25 err = fmt.Errorf("error parsing uri (%s): %s", s, err) 26 } 27 return p.ConnString, err 28} 29 30// ConnString represents a connection string to mongodb. 31type ConnString struct { 32 Original string 33 AppName string 34 AuthMechanism string 35 AuthMechanismProperties map[string]string 36 AuthSource string 37 Connect ConnectMode 38 ConnectTimeout time.Duration 39 Database string 40 FSync bool 41 HeartbeatInterval time.Duration 42 Hosts []string 43 Journal bool 44 KerberosService string 45 KerberosServiceHost string 46 MaxConnIdleTime time.Duration 47 MaxConnLifeTime time.Duration 48 MaxConnsPerHost uint16 49 MaxConnsPerHostSet bool 50 MaxIdleConnsPerHost uint16 51 MaxIdleConnsPerHostSet bool 52 Password string 53 PasswordSet bool 54 ReadPreference string 55 ReadPreferenceTagSets []map[string]string 56 ReplicaSet string 57 ServerSelectionTimeout time.Duration 58 SocketTimeout time.Duration 59 Username string 60 UseSSL bool 61 UseSSLSeen bool 62 W string 63 WTimeout time.Duration 64 65 UsingSRV bool 66 67 Options map[string][]string 68 UnknownOptions map[string][]string 69} 70 71func (u *ConnString) String() string { 72 return u.Original 73} 74 75// ConnectMode informs the driver on how to connect 76// to the server. 77type ConnectMode uint8 78 79// ConnectMode constants. 80const ( 81 AutoConnect ConnectMode = iota 82 SingleConnect 83) 84 85type parser struct { 86 ConnString 87 88 haveWTimeoutMS bool 89} 90 91func (p *parser) parse(original string) error { 92 p.Original = original 93 uri := original 94 95 var err error 96 var isSRV bool 97 if strings.HasPrefix(uri, "mongodb+srv://") { 98 isSRV = true 99 100 p.UsingSRV = true 101 102 // SSL should be turned on by default when retrieving hosts from SRV 103 p.UseSSL = true 104 p.UseSSLSeen = true 105 106 // remove the scheme 107 uri = uri[14:] 108 } else if strings.HasPrefix(uri, "mongodb://") { 109 // remove the scheme 110 uri = uri[10:] 111 } else { 112 return fmt.Errorf("scheme must be \"mongodb\" or \"mongodb+srv\"") 113 } 114 115 if idx := strings.Index(uri, "@"); idx != -1 { 116 userInfo := uri[:idx] 117 uri = uri[idx+1:] 118 119 username := userInfo 120 var password string 121 122 if idx := strings.Index(userInfo, ":"); idx != -1 { 123 username = userInfo[:idx] 124 password = userInfo[idx+1:] 125 p.PasswordSet = true 126 } 127 128 if len(username) > 1 { 129 if strings.Contains(username, "/") { 130 return fmt.Errorf("unescaped slash in username") 131 } 132 } 133 134 p.Username, err = url.QueryUnescape(username) 135 if err != nil { 136 return fmt.Errorf("invalid username: %s", err) 137 } 138 if len(password) > 1 { 139 if strings.Contains(password, ":") { 140 return fmt.Errorf("unescaped colon in password") 141 } 142 if strings.Contains(password, "/") { 143 return fmt.Errorf("unescaped slash in password") 144 } 145 p.Password, err = url.QueryUnescape(password) 146 if err != nil { 147 return fmt.Errorf("invalid password: %s", err) 148 } 149 } 150 } 151 152 // fetch the hosts field 153 hosts := uri 154 if idx := strings.IndexAny(uri, "/?@"); idx != -1 { 155 if uri[idx] == '@' { 156 return fmt.Errorf("unescaped @ sign in user info") 157 } 158 if uri[idx] == '?' { 159 return fmt.Errorf("must have a / before the query ?") 160 } 161 hosts = uri[:idx] 162 } 163 164 var connectionArgsFromTXT []string 165 parsedHosts := strings.Split(hosts, ",") 166 167 if isSRV { 168 parsedHosts = strings.Split(hosts, ",") 169 if len(parsedHosts) != 1 { 170 return fmt.Errorf("URI with SRV must include one and only one hostname") 171 } 172 parsedHosts, err = fetchSeedlistFromSRV(parsedHosts[0]) 173 if err != nil { 174 return err 175 } 176 177 // error ignored because finding a TXT record should not be 178 // considered an error. 179 recordsFromTXT, _ := net.LookupTXT(hosts) 180 181 // This is a temporary fix to get around bug https://github.com/golang/go/issues/21472. 182 // It will currently incorrectly concatenate multiple TXT records to one 183 // on windows. 184 if runtime.GOOS == "windows" { 185 recordsFromTXT = []string{strings.Join(recordsFromTXT, "")} 186 } 187 188 if len(recordsFromTXT) > 1 { 189 return errors.New("multiple records from TXT not supported") 190 } 191 if len(recordsFromTXT) > 0 { 192 connectionArgsFromTXT = strings.FieldsFunc(recordsFromTXT[0], func(r rune) bool { return r == ';' || r == '&' }) 193 194 err := validateTXTResult(connectionArgsFromTXT) 195 if err != nil { 196 return err 197 } 198 199 } 200 } 201 202 for _, host := range parsedHosts { 203 err = p.addHost(host) 204 if err != nil { 205 return fmt.Errorf("invalid host \"%s\": %s", host, err) 206 } 207 } 208 if len(p.Hosts) == 0 { 209 return fmt.Errorf("must have at least 1 host") 210 } 211 212 uri = uri[len(hosts):] 213 214 extractedDatabase, err := extractDatabaseFromURI(uri) 215 if err != nil { 216 return err 217 } 218 219 uri = extractedDatabase.uri 220 p.Database = extractedDatabase.db 221 222 connectionArgsFromQueryString, err := extractQueryArgsFromURI(uri) 223 connectionArgPairs := append(connectionArgsFromTXT, connectionArgsFromQueryString...) 224 225 for _, pair := range connectionArgPairs { 226 err = p.addOption(pair) 227 if err != nil { 228 return err 229 } 230 } 231 232 return nil 233} 234 235func fetchSeedlistFromSRV(host string) ([]string, error) { 236 var err error 237 238 _, _, err = net.SplitHostPort(host) 239 240 if err == nil { 241 // we were able to successfully extract a port from the host, 242 // but should not be able to when using SRV 243 return nil, fmt.Errorf("URI with srv must not include a port number") 244 } 245 246 _, addresses, err := net.LookupSRV("mongodb", "tcp", host) 247 if err != nil { 248 return nil, err 249 } 250 parsedHosts := make([]string, len(addresses)) 251 for i, address := range addresses { 252 trimmedAddressTarget := strings.TrimSuffix(address.Target, ".") 253 err := validateSRVResult(trimmedAddressTarget, host) 254 if err != nil { 255 return nil, err 256 } 257 parsedHosts[i] = fmt.Sprintf("%s:%d", trimmedAddressTarget, address.Port) 258 } 259 260 return parsedHosts, nil 261} 262 263func (p *parser) addHost(host string) error { 264 if host == "" { 265 return nil 266 } 267 host, err := url.QueryUnescape(host) 268 if err != nil { 269 return fmt.Errorf("invalid host \"%s\": %s", host, err) 270 } 271 272 _, port, err := net.SplitHostPort(host) 273 // this is unfortunate that SplitHostPort actually requires 274 // a port to exist. 275 if err != nil { 276 if addrError, ok := err.(*net.AddrError); !ok || addrError.Err != "missing port in address" { 277 return err 278 } 279 } 280 281 if port != "" { 282 d, err := strconv.Atoi(port) 283 if err != nil { 284 return fmt.Errorf("port must be an integer: %s", err) 285 } 286 if d <= 0 || d >= 65536 { 287 return fmt.Errorf("port must be in the range [1, 65535]") 288 } 289 } 290 p.Hosts = append(p.Hosts, host) 291 return nil 292} 293 294func (p *parser) addOption(pair string) error { 295 kv := strings.SplitN(pair, "=", 2) 296 if len(kv) != 2 || kv[0] == "" { 297 return fmt.Errorf("invalid option") 298 } 299 300 key, err := url.QueryUnescape(kv[0]) 301 if err != nil { 302 return fmt.Errorf("invalid option key \"%s\": %s", kv[0], err) 303 } 304 305 value, err := url.QueryUnescape(kv[1]) 306 if err != nil { 307 return fmt.Errorf("invalid option value \"%s\": %s", kv[1], err) 308 } 309 310 lowerKey := strings.ToLower(key) 311 switch lowerKey { 312 case "appname": 313 p.AppName = value 314 case "authmechanism": 315 p.AuthMechanism = value 316 case "authmechanismproperties": 317 p.AuthMechanismProperties = make(map[string]string) 318 pairs := strings.Split(value, ",") 319 for _, pair := range pairs { 320 kv := strings.SplitN(pair, ":", 2) 321 if len(kv) != 2 || kv[0] == "" { 322 return fmt.Errorf("invalid authMechanism property") 323 } 324 p.AuthMechanismProperties[kv[0]] = kv[1] 325 } 326 case "authsource": 327 p.AuthSource = value 328 case "connect": 329 switch strings.ToLower(value) { 330 case "auto", "automatic": 331 p.Connect = AutoConnect 332 case "direct", "single": 333 p.Connect = SingleConnect 334 default: 335 return fmt.Errorf("invalid 'connect' value: %s", value) 336 } 337 case "connecttimeoutms": 338 n, err := strconv.Atoi(value) 339 if err != nil || n < 0 { 340 return fmt.Errorf("invalid value for %s: %s", key, value) 341 } 342 p.ConnectTimeout = time.Duration(n) * time.Millisecond 343 case "heartbeatintervalms", "heartbeatfrequencyms": 344 n, err := strconv.Atoi(value) 345 if err != nil || n < 0 { 346 return fmt.Errorf("invalid value for %s: %s", key, value) 347 } 348 p.HeartbeatInterval = time.Duration(n) * time.Millisecond 349 case "fsync": 350 f, err := strconv.ParseBool(value) 351 if err != nil { 352 return fmt.Errorf("invalid value for %s: %s", key, value) 353 } 354 p.FSync = f 355 case "j": 356 j, err := strconv.ParseBool(value) 357 if err != nil { 358 return fmt.Errorf("invalid value for %s: %s", key, value) 359 } 360 p.Journal = j 361 case "gssapiservicename": 362 p.KerberosService = value 363 case "gssapihostname": 364 p.KerberosServiceHost = value 365 case "maxconnsperhost": 366 n, err := strconv.Atoi(value) 367 if err != nil || n < 0 { 368 return fmt.Errorf("invalid value for %s: %s", key, value) 369 } 370 p.MaxConnsPerHost = uint16(n) 371 p.MaxConnsPerHostSet = true 372 case "maxidleconnsperhost": 373 n, err := strconv.Atoi(value) 374 if err != nil || n < 0 { 375 return fmt.Errorf("invalid value for %s: %s", key, value) 376 } 377 p.MaxIdleConnsPerHost = uint16(n) 378 p.MaxIdleConnsPerHostSet = true 379 case "maxidletimems": 380 n, err := strconv.Atoi(value) 381 if err != nil || n < 0 { 382 return fmt.Errorf("invalid value for %s: %s", key, value) 383 } 384 p.MaxConnIdleTime = time.Duration(n) * time.Millisecond 385 case "maxlifetimems": 386 n, err := strconv.Atoi(value) 387 if err != nil || n < 0 { 388 return fmt.Errorf("invalid value for %s: %s", key, value) 389 } 390 p.MaxConnLifeTime = time.Duration(n) * time.Millisecond 391 case "maxpoolsize": 392 n, err := strconv.Atoi(value) 393 if err != nil || n < 0 { 394 return fmt.Errorf("invalid value for %s: %s", key, value) 395 } 396 p.MaxConnsPerHost = uint16(n) 397 p.MaxConnsPerHostSet = true 398 p.MaxIdleConnsPerHost = uint16(n) 399 p.MaxIdleConnsPerHostSet = true 400 case "readpreference": 401 p.ReadPreference = value 402 case "readpreferencetags": 403 tags := make(map[string]string) 404 items := strings.Split(value, ",") 405 for _, item := range items { 406 parts := strings.Split(item, ":") 407 if len(parts) != 2 { 408 return fmt.Errorf("invalid value for %s: %s", key, value) 409 } 410 tags[parts[0]] = parts[1] 411 } 412 p.ReadPreferenceTagSets = append(p.ReadPreferenceTagSets, tags) 413 case "replicaset": 414 p.ReplicaSet = value 415 case "serverselectiontimeoutms": 416 n, err := strconv.Atoi(value) 417 if err != nil || n < 0 { 418 return fmt.Errorf("invalid value for %s: %s", key, value) 419 } 420 p.ServerSelectionTimeout = time.Duration(n) * time.Millisecond 421 case "sockettimeoutms": 422 n, err := strconv.Atoi(value) 423 if err != nil || n < 0 { 424 return fmt.Errorf("invalid value for %s: %s", key, value) 425 } 426 p.SocketTimeout = time.Duration(n) * time.Millisecond 427 case "ssl": 428 b, err := strconv.ParseBool(value) 429 if err != nil { 430 return fmt.Errorf("invalid value for %s: %s", key, value) 431 } 432 p.UseSSL = b 433 p.UseSSLSeen = true 434 case "w": 435 p.W = value 436 case "wtimeoutms": 437 n, err := strconv.Atoi(value) 438 if err != nil || n < 0 { 439 return fmt.Errorf("invalid value for %s: %s", key, value) 440 } 441 p.WTimeout = time.Duration(n) * time.Millisecond 442 p.haveWTimeoutMS = true 443 case "wtimeout": 444 if p.haveWTimeoutMS { 445 // use wtimeoutMS if it exists 446 break 447 } 448 n, err := strconv.Atoi(value) 449 if err != nil || n < 0 { 450 return fmt.Errorf("invalid value for %s: %s", key, value) 451 } 452 p.WTimeout = time.Duration(n) * time.Millisecond 453 default: 454 if p.UnknownOptions == nil { 455 p.UnknownOptions = make(map[string][]string) 456 } 457 p.UnknownOptions[lowerKey] = append(p.UnknownOptions[lowerKey], value) 458 } 459 460 if p.Options == nil { 461 p.Options = make(map[string][]string) 462 } 463 p.Options[lowerKey] = append(p.Options[lowerKey], value) 464 465 return nil 466} 467 468func validateSRVResult(recordFromSRV, inputHostName string) error { 469 separatedInputDomain := strings.Split(inputHostName, ".") 470 separatedRecord := strings.Split(recordFromSRV, ".") 471 if len(separatedRecord) < 2 { 472 return errors.New("DNS name must contain at least 2 labels") 473 } 474 if len(separatedRecord) < len(separatedInputDomain) { 475 return errors.New("Domain suffix from SRV record not matched input domain") 476 } 477 478 inputDomainSuffix := separatedInputDomain[1:] 479 domainSuffixOffset := len(separatedRecord) - (len(separatedInputDomain) - 1) 480 481 recordDomainSuffix := separatedRecord[domainSuffixOffset:] 482 for ix, label := range inputDomainSuffix { 483 if label != recordDomainSuffix[ix] { 484 return errors.New("Domain suffix from SRV record not matched input domain") 485 } 486 } 487 return nil 488} 489 490var allowedTXTOptions = map[string]struct{}{ 491 "authsource": {}, 492 "replicaset": {}, 493} 494 495func validateTXTResult(paramsFromTXT []string) error { 496 for _, param := range paramsFromTXT { 497 kv := strings.SplitN(param, "=", 2) 498 if len(kv) != 2 { 499 return errors.New("Invalid TXT record") 500 } 501 key := strings.ToLower(kv[0]) 502 if _, ok := allowedTXTOptions[key]; !ok { 503 return fmt.Errorf("Cannot specify option '%s' in TXT record", kv[0]) 504 } 505 } 506 return nil 507} 508 509func extractQueryArgsFromURI(uri string) ([]string, error) { 510 if len(uri) == 0 { 511 return nil, nil 512 } 513 514 if uri[0] != '?' { 515 return nil, errors.New("must have a ? separator between path and query") 516 } 517 518 uri = uri[1:] 519 if len(uri) == 0 { 520 return nil, nil 521 } 522 return strings.FieldsFunc(uri, func(r rune) bool { return r == ';' || r == '&' }), nil 523 524} 525 526type extractedDatabase struct { 527 uri string 528 db string 529} 530 531// extractDatabaseFromURI is a helper function to retrieve information about 532// the database from the passed in URI. It accepts as an argument the currently 533// parsed URI and returns the remainder of the uri, the database it found, 534// and any error it encounters while parsing. 535func extractDatabaseFromURI(uri string) (extractedDatabase, error) { 536 if len(uri) == 0 { 537 return extractedDatabase{}, nil 538 } 539 540 if uri[0] != '/' { 541 return extractedDatabase{}, errors.New("must have a / separator between hosts and path") 542 } 543 544 uri = uri[1:] 545 if len(uri) == 0 { 546 return extractedDatabase{}, nil 547 } 548 549 database := uri 550 if idx := strings.IndexRune(uri, '?'); idx != -1 { 551 database = uri[:idx] 552 } 553 554 escapedDatabase, err := url.QueryUnescape(database) 555 if err != nil { 556 return extractedDatabase{}, fmt.Errorf("invalid database \"%s\": %s", database, err) 557 } 558 559 uri = uri[len(database):] 560 561 return extractedDatabase{ 562 uri: uri, 563 db: escapedDatabase, 564 }, nil 565} 566