1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package connstring // import "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" 8 9import ( 10 "errors" 11 "fmt" 12 "net" 13 "net/url" 14 "strconv" 15 "strings" 16 "time" 17 18 "go.mongodb.org/mongo-driver/internal" 19 "go.mongodb.org/mongo-driver/mongo/writeconcern" 20 "go.mongodb.org/mongo-driver/x/mongo/driver/dns" 21 "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" 22) 23 24// Parse parses the provided uri and returns a URI object. 25func Parse(s string) (ConnString, error) { 26 p := parser{dnsResolver: dns.DefaultResolver} 27 err := p.parse(s) 28 if err != nil { 29 err = internal.WrapErrorf(err, "error parsing uri") 30 } 31 return p.ConnString, err 32} 33 34// ConnString represents a connection string to mongodb. 
//
// Most options come in pairs: a value field and a companion "<Name>Set" flag
// that the parser switches on when the option was present in the URI, so
// callers can tell an explicit zero value apart from an absent option.
type ConnString struct {
	Original                           string
	AppName                            string
	AuthMechanism                      string
	AuthMechanismProperties            map[string]string
	AuthSource                         string
	Compressors                        []string
	Connect                            ConnectMode
	ConnectSet                         bool
	ConnectTimeout                     time.Duration
	ConnectTimeoutSet                  bool
	Database                           string
	HeartbeatInterval                  time.Duration
	HeartbeatIntervalSet               bool
	Hosts                              []string
	J                                  bool
	JSet                               bool
	LocalThreshold                     time.Duration
	LocalThresholdSet                  bool
	MaxConnIdleTime                    time.Duration
	MaxConnIdleTimeSet                 bool
	MaxPoolSize                        uint64
	MaxPoolSizeSet                     bool
	MinPoolSize                        uint64
	MinPoolSizeSet                     bool
	Password                           string
	PasswordSet                        bool
	ReadConcernLevel                   string
	ReadPreference                     string
	ReadPreferenceTagSets              []map[string]string
	RetryWrites                        bool
	RetryWritesSet                     bool
	RetryReads                         bool
	RetryReadsSet                      bool
	MaxStaleness                       time.Duration
	MaxStalenessSet                    bool
	ReplicaSet                         string
	Scheme                             string
	ServerSelectionTimeout             time.Duration
	ServerSelectionTimeoutSet          bool
	SocketTimeout                      time.Duration
	SocketTimeoutSet                   bool
	SSL                                bool
	SSLSet                             bool
	SSLClientCertificateKeyFile        string
	SSLClientCertificateKeyFileSet     bool
	SSLClientCertificateKeyPassword    func() string
	SSLClientCertificateKeyPasswordSet bool
	SSLInsecure                        bool
	SSLInsecureSet                     bool
	SSLCaFile                          string
	SSLCaFileSet                       bool
	WString                            string
	WNumber                            int
	WNumberSet                         bool
	Username                           string
	ZlibLevel                          int
	ZlibLevelSet                       bool
	ZstdLevel                          int
	ZstdLevelSet                       bool

	WTimeout              time.Duration
	WTimeoutSet           bool
	WTimeoutSetFromOption bool

	Options        map[string][]string
	UnknownOptions map[string][]string
}

// String returns the connection string exactly as it was handed to the
// parser (the Original field).
func (cs *ConnString) String() string {
	return cs.Original
}

// ConnectMode informs the driver on how to connect
// to the server.
type ConnectMode uint8

// ConnectMode constants.
113const ( 114 AutoConnect ConnectMode = iota 115 SingleConnect 116) 117 118// Scheme constants 119const ( 120 SchemeMongoDB = "mongodb" 121 SchemeMongoDBSRV = "mongodb+srv" 122) 123 124type parser struct { 125 ConnString 126 127 dnsResolver *dns.Resolver 128 tlsssl *bool // used to determine if tls and ssl options are both specified and set differently. 129} 130 131func (p *parser) parse(original string) error { 132 p.Original = original 133 uri := original 134 135 var err error 136 if strings.HasPrefix(uri, SchemeMongoDBSRV+"://") { 137 p.Scheme = SchemeMongoDBSRV 138 // remove the scheme 139 uri = uri[len(SchemeMongoDBSRV)+3:] 140 } else if strings.HasPrefix(uri, SchemeMongoDB+"://") { 141 p.Scheme = SchemeMongoDB 142 // remove the scheme 143 uri = uri[len(SchemeMongoDB)+3:] 144 } else { 145 return fmt.Errorf("scheme must be \"mongodb\" or \"mongodb+srv\"") 146 } 147 148 if idx := strings.Index(uri, "@"); idx != -1 { 149 userInfo := uri[:idx] 150 uri = uri[idx+1:] 151 152 username := userInfo 153 var password string 154 155 if idx := strings.Index(userInfo, ":"); idx != -1 { 156 username = userInfo[:idx] 157 password = userInfo[idx+1:] 158 p.PasswordSet = true 159 } 160 161 if len(username) > 1 { 162 if strings.Contains(username, "/") { 163 return fmt.Errorf("unescaped slash in username") 164 } 165 } 166 167 p.Username, err = url.QueryUnescape(username) 168 if err != nil { 169 return internal.WrapErrorf(err, "invalid username") 170 } 171 if len(password) > 1 { 172 if strings.Contains(password, ":") { 173 return fmt.Errorf("unescaped colon in password") 174 } 175 if strings.Contains(password, "/") { 176 return fmt.Errorf("unescaped slash in password") 177 } 178 p.Password, err = url.QueryUnescape(password) 179 if err != nil { 180 return internal.WrapErrorf(err, "invalid password") 181 } 182 } 183 } 184 185 // fetch the hosts field 186 hosts := uri 187 if idx := strings.IndexAny(uri, "/?@"); idx != -1 { 188 if uri[idx] == '@' { 189 return fmt.Errorf("unescaped @ 
sign in user info") 190 } 191 if uri[idx] == '?' { 192 return fmt.Errorf("must have a / before the query ?") 193 } 194 hosts = uri[:idx] 195 } 196 197 var connectionArgsFromTXT []string 198 parsedHosts := strings.Split(hosts, ",") 199 200 if p.Scheme == SchemeMongoDBSRV { 201 parsedHosts, err = p.dnsResolver.ParseHosts(hosts, true) 202 if err != nil { 203 return err 204 } 205 connectionArgsFromTXT, err = p.dnsResolver.GetConnectionArgsFromTXT(hosts) 206 if err != nil { 207 return err 208 } 209 210 // SSL is enabled by default for SRV, but can be manually disabled with "ssl=false". 211 p.SSL = true 212 p.SSLSet = true 213 } 214 215 for _, host := range parsedHosts { 216 err = p.addHost(host) 217 if err != nil { 218 return internal.WrapErrorf(err, "invalid host \"%s\"", host) 219 } 220 } 221 if len(p.Hosts) == 0 { 222 return fmt.Errorf("must have at least 1 host") 223 } 224 225 uri = uri[len(hosts):] 226 227 extractedDatabase, err := extractDatabaseFromURI(uri) 228 if err != nil { 229 return err 230 } 231 232 uri = extractedDatabase.uri 233 p.Database = extractedDatabase.db 234 235 connectionArgsFromQueryString, err := extractQueryArgsFromURI(uri) 236 connectionArgPairs := append(connectionArgsFromTXT, connectionArgsFromQueryString...) 237 238 for _, pair := range connectionArgPairs { 239 err = p.addOption(pair) 240 if err != nil { 241 return err 242 } 243 } 244 245 err = p.setDefaultAuthParams(extractedDatabase.db) 246 if err != nil { 247 return err 248 } 249 250 err = p.validateAuth() 251 if err != nil { 252 return err 253 } 254 255 // Check for invalid write concern (i.e. w=0 and j=true) 256 if p.WNumberSet && p.WNumber == 0 && p.JSet && p.J { 257 return writeconcern.ErrInconsistent 258 } 259 260 // If WTimeout was set from manual options passed in, set WTImeoutSet to true. 
261 if p.WTimeoutSetFromOption { 262 p.WTimeoutSet = true 263 } 264 265 return nil 266} 267 268func (p *parser) setDefaultAuthParams(dbName string) error { 269 switch strings.ToLower(p.AuthMechanism) { 270 case "plain": 271 if p.AuthSource == "" { 272 p.AuthSource = dbName 273 if p.AuthSource == "" { 274 p.AuthSource = "$external" 275 } 276 } 277 case "gssapi": 278 if p.AuthMechanismProperties == nil { 279 p.AuthMechanismProperties = map[string]string{ 280 "SERVICE_NAME": "mongodb", 281 } 282 } else if v, ok := p.AuthMechanismProperties["SERVICE_NAME"]; !ok || v == "" { 283 p.AuthMechanismProperties["SERVICE_NAME"] = "mongodb" 284 } 285 fallthrough 286 case "mongodb-x509": 287 if p.AuthSource == "" { 288 p.AuthSource = "$external" 289 } else if p.AuthSource != "$external" { 290 return fmt.Errorf("auth source must be $external") 291 } 292 case "mongodb-cr": 293 fallthrough 294 case "scram-sha-1": 295 fallthrough 296 case "scram-sha-256": 297 if p.AuthSource == "" { 298 p.AuthSource = dbName 299 if p.AuthSource == "" { 300 p.AuthSource = "admin" 301 } 302 } 303 case "": 304 if p.AuthSource == "" && (p.AuthMechanismProperties != nil || p.Username != "" || p.PasswordSet) { 305 p.AuthSource = dbName 306 if p.AuthSource == "" { 307 p.AuthSource = "admin" 308 } 309 } 310 default: 311 return fmt.Errorf("invalid auth mechanism") 312 } 313 return nil 314} 315 316func (p *parser) validateAuth() error { 317 switch strings.ToLower(p.AuthMechanism) { 318 case "mongodb-cr": 319 if p.Username == "" { 320 return fmt.Errorf("username required for MONGO-CR") 321 } 322 if p.Password == "" { 323 return fmt.Errorf("password required for MONGO-CR") 324 } 325 if p.AuthMechanismProperties != nil { 326 return fmt.Errorf("MONGO-CR cannot have mechanism properties") 327 } 328 case "mongodb-x509": 329 if p.Password != "" { 330 return fmt.Errorf("password cannot be specified for MONGO-X509") 331 } 332 if p.AuthMechanismProperties != nil { 333 return fmt.Errorf("MONGO-X509 cannot have mechanism 
properties") 334 } 335 case "gssapi": 336 if p.Username == "" { 337 return fmt.Errorf("username required for GSSAPI") 338 } 339 for k := range p.AuthMechanismProperties { 340 if k != "SERVICE_NAME" && k != "CANONICALIZE_HOST_NAME" && k != "SERVICE_REALM" { 341 return fmt.Errorf("invalid auth property for GSSAPI") 342 } 343 } 344 case "plain": 345 if p.Username == "" { 346 return fmt.Errorf("username required for PLAIN") 347 } 348 if p.Password == "" { 349 return fmt.Errorf("password required for PLAIN") 350 } 351 if p.AuthMechanismProperties != nil { 352 return fmt.Errorf("PLAIN cannot have mechanism properties") 353 } 354 case "scram-sha-1": 355 if p.Username == "" { 356 return fmt.Errorf("username required for SCRAM-SHA-1") 357 } 358 if p.Password == "" { 359 return fmt.Errorf("password required for SCRAM-SHA-1") 360 } 361 if p.AuthMechanismProperties != nil { 362 return fmt.Errorf("SCRAM-SHA-1 cannot have mechanism properties") 363 } 364 case "scram-sha-256": 365 if p.Username == "" { 366 return fmt.Errorf("username required for SCRAM-SHA-256") 367 } 368 if p.Password == "" { 369 return fmt.Errorf("password required for SCRAM-SHA-256") 370 } 371 if p.AuthMechanismProperties != nil { 372 return fmt.Errorf("SCRAM-SHA-256 cannot have mechanism properties") 373 } 374 case "": 375 if p.Username == "" && p.AuthSource != "" { 376 return fmt.Errorf("authsource without username is invalid") 377 } 378 default: 379 return fmt.Errorf("invalid auth mechanism") 380 } 381 return nil 382} 383 384func (p *parser) addHost(host string) error { 385 if host == "" { 386 return nil 387 } 388 host, err := url.QueryUnescape(host) 389 if err != nil { 390 return internal.WrapErrorf(err, "invalid host \"%s\"", host) 391 } 392 393 _, port, err := net.SplitHostPort(host) 394 // this is unfortunate that SplitHostPort actually requires 395 // a port to exist. 
396 if err != nil { 397 if addrError, ok := err.(*net.AddrError); !ok || addrError.Err != "missing port in address" { 398 return err 399 } 400 } 401 402 if port != "" { 403 d, err := strconv.Atoi(port) 404 if err != nil { 405 return internal.WrapErrorf(err, "port must be an integer") 406 } 407 if d <= 0 || d >= 65536 { 408 return fmt.Errorf("port must be in the range [1, 65535]") 409 } 410 } 411 p.Hosts = append(p.Hosts, host) 412 return nil 413} 414 415func (p *parser) addOption(pair string) error { 416 kv := strings.SplitN(pair, "=", 2) 417 if len(kv) != 2 || kv[0] == "" { 418 return fmt.Errorf("invalid option") 419 } 420 421 key, err := url.QueryUnescape(kv[0]) 422 if err != nil { 423 return internal.WrapErrorf(err, "invalid option key \"%s\"", kv[0]) 424 } 425 426 value, err := url.QueryUnescape(kv[1]) 427 if err != nil { 428 return internal.WrapErrorf(err, "invalid option value \"%s\"", kv[1]) 429 } 430 431 lowerKey := strings.ToLower(key) 432 switch lowerKey { 433 case "appname": 434 p.AppName = value 435 case "authmechanism": 436 p.AuthMechanism = value 437 case "authmechanismproperties": 438 p.AuthMechanismProperties = make(map[string]string) 439 pairs := strings.Split(value, ",") 440 for _, pair := range pairs { 441 kv := strings.SplitN(pair, ":", 2) 442 if len(kv) != 2 || kv[0] == "" { 443 return fmt.Errorf("invalid authMechanism property") 444 } 445 p.AuthMechanismProperties[kv[0]] = kv[1] 446 } 447 case "authsource": 448 p.AuthSource = value 449 case "compressors": 450 compressors := strings.Split(value, ",") 451 if len(compressors) < 1 { 452 return fmt.Errorf("must have at least 1 compressor") 453 } 454 p.Compressors = compressors 455 case "connect": 456 switch strings.ToLower(value) { 457 case "automatic": 458 case "direct": 459 p.Connect = SingleConnect 460 default: 461 return fmt.Errorf("invalid 'connect' value: %s", value) 462 } 463 464 p.ConnectSet = true 465 case "connecttimeoutms": 466 n, err := strconv.Atoi(value) 467 if err != nil || n < 0 { 
468 return fmt.Errorf("invalid value for %s: %s", key, value) 469 } 470 p.ConnectTimeout = time.Duration(n) * time.Millisecond 471 p.ConnectTimeoutSet = true 472 case "heartbeatintervalms", "heartbeatfrequencyms": 473 n, err := strconv.Atoi(value) 474 if err != nil || n < 0 { 475 return fmt.Errorf("invalid value for %s: %s", key, value) 476 } 477 p.HeartbeatInterval = time.Duration(n) * time.Millisecond 478 p.HeartbeatIntervalSet = true 479 case "journal": 480 switch value { 481 case "true": 482 p.J = true 483 case "false": 484 p.J = false 485 default: 486 return fmt.Errorf("invalid value for %s: %s", key, value) 487 } 488 489 p.JSet = true 490 case "localthresholdms": 491 n, err := strconv.Atoi(value) 492 if err != nil || n < 0 { 493 return fmt.Errorf("invalid value for %s: %s", key, value) 494 } 495 p.LocalThreshold = time.Duration(n) * time.Millisecond 496 p.LocalThresholdSet = true 497 case "maxidletimems": 498 n, err := strconv.Atoi(value) 499 if err != nil || n < 0 { 500 return fmt.Errorf("invalid value for %s: %s", key, value) 501 } 502 p.MaxConnIdleTime = time.Duration(n) * time.Millisecond 503 p.MaxConnIdleTimeSet = true 504 case "maxpoolsize": 505 n, err := strconv.Atoi(value) 506 if err != nil || n < 0 { 507 return fmt.Errorf("invalid value for %s: %s", key, value) 508 } 509 p.MaxPoolSize = uint64(n) 510 p.MaxPoolSizeSet = true 511 case "minpoolsize": 512 n, err := strconv.Atoi(value) 513 if err != nil || n < 0 { 514 return fmt.Errorf("invalid value for %s: %s", key, value) 515 } 516 p.MinPoolSize = uint64(n) 517 p.MinPoolSizeSet = true 518 case "readconcernlevel": 519 p.ReadConcernLevel = value 520 case "readpreference": 521 p.ReadPreference = value 522 case "readpreferencetags": 523 if value == "" { 524 // for when readPreferenceTags= at end of URI 525 break 526 } 527 528 tags := make(map[string]string) 529 items := strings.Split(value, ",") 530 for _, item := range items { 531 parts := strings.Split(item, ":") 532 if len(parts) != 2 { 533 return 
fmt.Errorf("invalid value for %s: %s", key, value) 534 } 535 tags[parts[0]] = parts[1] 536 } 537 p.ReadPreferenceTagSets = append(p.ReadPreferenceTagSets, tags) 538 case "maxstaleness", "maxstalenessseconds": 539 n, err := strconv.Atoi(value) 540 if err != nil || n < 0 { 541 return fmt.Errorf("invalid value for %s: %s", key, value) 542 } 543 p.MaxStaleness = time.Duration(n) * time.Second 544 p.MaxStalenessSet = true 545 case "replicaset": 546 p.ReplicaSet = value 547 case "retrywrites": 548 switch value { 549 case "true": 550 p.RetryWrites = true 551 case "false": 552 p.RetryWrites = false 553 default: 554 return fmt.Errorf("invalid value for %s: %s", key, value) 555 } 556 557 p.RetryWritesSet = true 558 case "retryreads": 559 switch value { 560 case "true": 561 p.RetryReads = true 562 case "false": 563 p.RetryReads = false 564 default: 565 return fmt.Errorf("invalid value for %s: %s", key, value) 566 } 567 568 p.RetryReadsSet = true 569 case "serverselectiontimeoutms": 570 n, err := strconv.Atoi(value) 571 if err != nil || n < 0 { 572 return fmt.Errorf("invalid value for %s: %s", key, value) 573 } 574 p.ServerSelectionTimeout = time.Duration(n) * time.Millisecond 575 p.ServerSelectionTimeoutSet = true 576 case "sockettimeoutms": 577 n, err := strconv.Atoi(value) 578 if err != nil || n < 0 { 579 return fmt.Errorf("invalid value for %s: %s", key, value) 580 } 581 p.SocketTimeout = time.Duration(n) * time.Millisecond 582 p.SocketTimeoutSet = true 583 case "ssl", "tls": 584 switch value { 585 case "true": 586 p.SSL = true 587 case "false": 588 p.SSL = false 589 default: 590 return fmt.Errorf("invalid value for %s: %s", key, value) 591 } 592 if p.tlsssl != nil && *p.tlsssl != p.SSL { 593 return errors.New("tls and ssl options, when both specified, must be equivalent") 594 } 595 596 p.tlsssl = new(bool) 597 *p.tlsssl = p.SSL 598 599 p.SSLSet = true 600 case "sslclientcertificatekeyfile", "tlscertificatekeyfile": 601 p.SSL = true 602 p.SSLSet = true 603 
p.SSLClientCertificateKeyFile = value 604 p.SSLClientCertificateKeyFileSet = true 605 case "sslclientcertificatekeypassword", "tlscertificatekeyfilepassword": 606 p.SSLClientCertificateKeyPassword = func() string { return value } 607 p.SSLClientCertificateKeyPasswordSet = true 608 case "sslinsecure", "tlsinsecure": 609 switch value { 610 case "true": 611 p.SSLInsecure = true 612 case "false": 613 p.SSLInsecure = false 614 default: 615 return fmt.Errorf("invalid value for %s: %s", key, value) 616 } 617 618 p.SSLInsecureSet = true 619 case "sslcertificateauthorityfile", "tlscafile": 620 p.SSL = true 621 p.SSLSet = true 622 p.SSLCaFile = value 623 p.SSLCaFileSet = true 624 case "w": 625 if w, err := strconv.Atoi(value); err == nil { 626 if w < 0 { 627 return fmt.Errorf("invalid value for %s: %s", key, value) 628 } 629 630 p.WNumber = w 631 p.WNumberSet = true 632 p.WString = "" 633 break 634 } 635 636 p.WString = value 637 p.WNumberSet = false 638 639 case "wtimeoutms": 640 n, err := strconv.Atoi(value) 641 if err != nil || n < 0 { 642 return fmt.Errorf("invalid value for %s: %s", key, value) 643 } 644 p.WTimeout = time.Duration(n) * time.Millisecond 645 p.WTimeoutSet = true 646 case "wtimeout": 647 // Defer to wtimeoutms, but not to a manually-set option. 
648 if p.WTimeoutSet { 649 break 650 } 651 n, err := strconv.Atoi(value) 652 if err != nil || n < 0 { 653 return fmt.Errorf("invalid value for %s: %s", key, value) 654 } 655 p.WTimeout = time.Duration(n) * time.Millisecond 656 case "zlibcompressionlevel": 657 level, err := strconv.Atoi(value) 658 if err != nil || (level < -1 || level > 9) { 659 return fmt.Errorf("invalid value for %s: %s", key, value) 660 } 661 662 if level == -1 { 663 level = wiremessage.DefaultZlibLevel 664 } 665 p.ZlibLevel = level 666 p.ZlibLevelSet = true 667 case "zstdcompressionlevel": 668 const maxZstdLevel = 22 // https://github.com/facebook/zstd/blob/a880ca239b447968493dd2fed3850e766d6305cc/contrib/linux-kernel/lib/zstd/compress.c#L3291 669 level, err := strconv.Atoi(value) 670 if err != nil || (level < -1 || level > maxZstdLevel) { 671 return fmt.Errorf("invalid value for %s: %s", key, value) 672 } 673 674 if level == -1 { 675 level = wiremessage.DefaultZstdLevel 676 } 677 p.ZstdLevel = level 678 p.ZstdLevelSet = true 679 default: 680 if p.UnknownOptions == nil { 681 p.UnknownOptions = make(map[string][]string) 682 } 683 p.UnknownOptions[lowerKey] = append(p.UnknownOptions[lowerKey], value) 684 } 685 686 if p.Options == nil { 687 p.Options = make(map[string][]string) 688 } 689 p.Options[lowerKey] = append(p.Options[lowerKey], value) 690 691 return nil 692} 693 694func extractQueryArgsFromURI(uri string) ([]string, error) { 695 if len(uri) == 0 { 696 return nil, nil 697 } 698 699 if uri[0] != '?' { 700 return nil, errors.New("must have a ? separator between path and query") 701 } 702 703 uri = uri[1:] 704 if len(uri) == 0 { 705 return nil, nil 706 } 707 return strings.FieldsFunc(uri, func(r rune) bool { return r == ';' || r == '&' }), nil 708 709} 710 711type extractedDatabase struct { 712 uri string 713 db string 714} 715 716// extractDatabaseFromURI is a helper function to retrieve information about 717// the database from the passed in URI. 
It accepts as an argument the currently 718// parsed URI and returns the remainder of the uri, the database it found, 719// and any error it encounters while parsing. 720func extractDatabaseFromURI(uri string) (extractedDatabase, error) { 721 if len(uri) == 0 { 722 return extractedDatabase{}, nil 723 } 724 725 if uri[0] != '/' { 726 return extractedDatabase{}, errors.New("must have a / separator between hosts and path") 727 } 728 729 uri = uri[1:] 730 if len(uri) == 0 { 731 return extractedDatabase{}, nil 732 } 733 734 database := uri 735 if idx := strings.IndexRune(uri, '?'); idx != -1 { 736 database = uri[:idx] 737 } 738 739 escapedDatabase, err := url.QueryUnescape(database) 740 if err != nil { 741 return extractedDatabase{}, internal.WrapErrorf(err, "invalid database \"%s\"", database) 742 } 743 744 uri = uri[len(database):] 745 746 return extractedDatabase{ 747 uri: uri, 748 db: escapedDatabase, 749 }, nil 750} 751