1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package connstring // import "go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
8
9import (
10	"errors"
11	"fmt"
12	"net"
13	"net/url"
14	"strconv"
15	"strings"
16	"time"
17
18	"go.mongodb.org/mongo-driver/internal"
19	"go.mongodb.org/mongo-driver/mongo/writeconcern"
20	"go.mongodb.org/mongo-driver/x/mongo/driver/dns"
21	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
22)
23
24// Parse parses the provided uri and returns a URI object.
25func Parse(s string) (ConnString, error) {
26	p := parser{dnsResolver: dns.DefaultResolver}
27	err := p.parse(s)
28	if err != nil {
29		err = internal.WrapErrorf(err, "error parsing uri")
30	}
31	return p.ConnString, err
32}
33
// ConnString represents a connection string to mongodb.
//
// Parse populates it from a URI. Many options have a companion boolean field
// with a "Set" suffix (for example ConnectTimeout / ConnectTimeoutSet). The
// parser sets that flag to true when the option appears, so callers can tell
// an explicitly supplied zero value apart from an absent option. PasswordSet
// is true whenever the user info contains a ':' separator, even if the
// password after it is empty.
type ConnString struct {
	Original                           string
	AppName                            string
	AuthMechanism                      string
	AuthMechanismProperties            map[string]string
	AuthSource                         string
	Compressors                        []string
	Connect                            ConnectMode
	ConnectSet                         bool
	ConnectTimeout                     time.Duration
	ConnectTimeoutSet                  bool
	Database                           string
	HeartbeatInterval                  time.Duration
	HeartbeatIntervalSet               bool
	Hosts                              []string
	J                                  bool
	JSet                               bool
	LocalThreshold                     time.Duration
	LocalThresholdSet                  bool
	MaxConnIdleTime                    time.Duration
	MaxConnIdleTimeSet                 bool
	MaxPoolSize                        uint64
	MaxPoolSizeSet                     bool
	MinPoolSize                        uint64
	MinPoolSizeSet                     bool
	Password                           string
	PasswordSet                        bool
	ReadConcernLevel                   string
	ReadPreference                     string
	ReadPreferenceTagSets              []map[string]string
	RetryWrites                        bool
	RetryWritesSet                     bool
	RetryReads                         bool
	RetryReadsSet                      bool
	MaxStaleness                       time.Duration
	MaxStalenessSet                    bool
	ReplicaSet                         string
	Scheme                             string
	ServerSelectionTimeout             time.Duration
	ServerSelectionTimeoutSet          bool
	SocketTimeout                      time.Duration
	SocketTimeoutSet                   bool
	SSL                                bool
	SSLSet                             bool
	SSLClientCertificateKeyFile        string
	SSLClientCertificateKeyFileSet     bool
	SSLClientCertificateKeyPassword    func() string
	SSLClientCertificateKeyPasswordSet bool
	SSLInsecure                        bool
	SSLInsecureSet                     bool
	SSLCaFile                          string
	SSLCaFileSet                       bool
	WString                            string
	WNumber                            int
	WNumberSet                         bool
	Username                           string
	ZlibLevel                          int
	ZlibLevelSet                       bool
	ZstdLevel                          int
	ZstdLevelSet                       bool

	// WTimeout is set from the wtimeoutMS option (or the legacy wtimeout
	// option when wtimeoutMS was not seen first). WTimeoutSetFromOption is
	// only read by the parser: when it is true, parsing also forces
	// WTimeoutSet to true. NOTE(review): nothing in the parser sets it;
	// presumably callers populate it when merging manual options — confirm.
	WTimeout              time.Duration
	WTimeoutSet           bool
	WTimeoutSetFromOption bool

	// Options records every option seen, recognized or not, keyed by its
	// lower-cased name, with one value per occurrence. UnknownOptions holds,
	// in the same shape, only the options the parser did not recognize.
	Options        map[string][]string
	UnknownOptions map[string][]string
}
103
104func (u *ConnString) String() string {
105	return u.Original
106}
107
// ConnectMode tells the driver how to connect to the server(s) named in a
// connection string. It is controlled by the "connect" URI option.
type ConnectMode uint8

// ConnectMode constants.
const (
	// AutoConnect is the zero value and the default. It corresponds to
	// connect=automatic.
	AutoConnect ConnectMode = 0
	// SingleConnect is selected by connect=direct.
	SingleConnect ConnectMode = 1
)
117
// Scheme constants for the two connection string formats the parser accepts.
const (
	// SchemeMongoDB prefixes a connection string that lists its hosts
	// directly, e.g. "mongodb://host1,host2/".
	SchemeMongoDB = "mongodb"
	// SchemeMongoDBSRV prefixes a connection string whose single host name
	// is resolved through DNS into the actual host list.
	SchemeMongoDBSRV = SchemeMongoDB + "+srv"
)
123
// parser holds the state accumulated while parsing one connection string.
// It embeds ConnString so the option handlers can assign the result fields
// directly on the parser; Parse returns the embedded ConnString when done.
type parser struct {
	ConnString

	// dnsResolver is consulted only for mongodb+srv URIs. It supplies the
	// host list and the TXT-derived connection arguments, which are then
	// applied as options ahead of the query string.
	dnsResolver *dns.Resolver
	// tlsssl is nil until the first ssl/tls option is seen. After that it
	// holds that option's value, so a later conflicting ssl/tls option can be
	// rejected.
	tlsssl      *bool // used to determine if tls and ssl options are both specified and set differently.
}
130
131func (p *parser) parse(original string) error {
132	p.Original = original
133	uri := original
134
135	var err error
136	if strings.HasPrefix(uri, SchemeMongoDBSRV+"://") {
137		p.Scheme = SchemeMongoDBSRV
138		// remove the scheme
139		uri = uri[len(SchemeMongoDBSRV)+3:]
140	} else if strings.HasPrefix(uri, SchemeMongoDB+"://") {
141		p.Scheme = SchemeMongoDB
142		// remove the scheme
143		uri = uri[len(SchemeMongoDB)+3:]
144	} else {
145		return fmt.Errorf("scheme must be \"mongodb\" or \"mongodb+srv\"")
146	}
147
148	if idx := strings.Index(uri, "@"); idx != -1 {
149		userInfo := uri[:idx]
150		uri = uri[idx+1:]
151
152		username := userInfo
153		var password string
154
155		if idx := strings.Index(userInfo, ":"); idx != -1 {
156			username = userInfo[:idx]
157			password = userInfo[idx+1:]
158			p.PasswordSet = true
159		}
160
161		if len(username) > 1 {
162			if strings.Contains(username, "/") {
163				return fmt.Errorf("unescaped slash in username")
164			}
165		}
166
167		p.Username, err = url.QueryUnescape(username)
168		if err != nil {
169			return internal.WrapErrorf(err, "invalid username")
170		}
171		if len(password) > 1 {
172			if strings.Contains(password, ":") {
173				return fmt.Errorf("unescaped colon in password")
174			}
175			if strings.Contains(password, "/") {
176				return fmt.Errorf("unescaped slash in password")
177			}
178			p.Password, err = url.QueryUnescape(password)
179			if err != nil {
180				return internal.WrapErrorf(err, "invalid password")
181			}
182		}
183	}
184
185	// fetch the hosts field
186	hosts := uri
187	if idx := strings.IndexAny(uri, "/?@"); idx != -1 {
188		if uri[idx] == '@' {
189			return fmt.Errorf("unescaped @ sign in user info")
190		}
191		if uri[idx] == '?' {
192			return fmt.Errorf("must have a / before the query ?")
193		}
194		hosts = uri[:idx]
195	}
196
197	var connectionArgsFromTXT []string
198	parsedHosts := strings.Split(hosts, ",")
199
200	if p.Scheme == SchemeMongoDBSRV {
201		parsedHosts, err = p.dnsResolver.ParseHosts(hosts, true)
202		if err != nil {
203			return err
204		}
205		connectionArgsFromTXT, err = p.dnsResolver.GetConnectionArgsFromTXT(hosts)
206		if err != nil {
207			return err
208		}
209
210		// SSL is enabled by default for SRV, but can be manually disabled with "ssl=false".
211		p.SSL = true
212		p.SSLSet = true
213	}
214
215	for _, host := range parsedHosts {
216		err = p.addHost(host)
217		if err != nil {
218			return internal.WrapErrorf(err, "invalid host \"%s\"", host)
219		}
220	}
221	if len(p.Hosts) == 0 {
222		return fmt.Errorf("must have at least 1 host")
223	}
224
225	uri = uri[len(hosts):]
226
227	extractedDatabase, err := extractDatabaseFromURI(uri)
228	if err != nil {
229		return err
230	}
231
232	uri = extractedDatabase.uri
233	p.Database = extractedDatabase.db
234
235	connectionArgsFromQueryString, err := extractQueryArgsFromURI(uri)
236	connectionArgPairs := append(connectionArgsFromTXT, connectionArgsFromQueryString...)
237
238	for _, pair := range connectionArgPairs {
239		err = p.addOption(pair)
240		if err != nil {
241			return err
242		}
243	}
244
245	err = p.setDefaultAuthParams(extractedDatabase.db)
246	if err != nil {
247		return err
248	}
249
250	err = p.validateAuth()
251	if err != nil {
252		return err
253	}
254
255	// Check for invalid write concern (i.e. w=0 and j=true)
256	if p.WNumberSet && p.WNumber == 0 && p.JSet && p.J {
257		return writeconcern.ErrInconsistent
258	}
259
260	// If WTimeout was set from manual options passed in, set WTImeoutSet to true.
261	if p.WTimeoutSetFromOption {
262		p.WTimeoutSet = true
263	}
264
265	return nil
266}
267
268func (p *parser) setDefaultAuthParams(dbName string) error {
269	switch strings.ToLower(p.AuthMechanism) {
270	case "plain":
271		if p.AuthSource == "" {
272			p.AuthSource = dbName
273			if p.AuthSource == "" {
274				p.AuthSource = "$external"
275			}
276		}
277	case "gssapi":
278		if p.AuthMechanismProperties == nil {
279			p.AuthMechanismProperties = map[string]string{
280				"SERVICE_NAME": "mongodb",
281			}
282		} else if v, ok := p.AuthMechanismProperties["SERVICE_NAME"]; !ok || v == "" {
283			p.AuthMechanismProperties["SERVICE_NAME"] = "mongodb"
284		}
285		fallthrough
286	case "mongodb-x509":
287		if p.AuthSource == "" {
288			p.AuthSource = "$external"
289		} else if p.AuthSource != "$external" {
290			return fmt.Errorf("auth source must be $external")
291		}
292	case "mongodb-cr":
293		fallthrough
294	case "scram-sha-1":
295		fallthrough
296	case "scram-sha-256":
297		if p.AuthSource == "" {
298			p.AuthSource = dbName
299			if p.AuthSource == "" {
300				p.AuthSource = "admin"
301			}
302		}
303	case "":
304		if p.AuthSource == "" && (p.AuthMechanismProperties != nil || p.Username != "" || p.PasswordSet) {
305			p.AuthSource = dbName
306			if p.AuthSource == "" {
307				p.AuthSource = "admin"
308			}
309		}
310	default:
311		return fmt.Errorf("invalid auth mechanism")
312	}
313	return nil
314}
315
316func (p *parser) validateAuth() error {
317	switch strings.ToLower(p.AuthMechanism) {
318	case "mongodb-cr":
319		if p.Username == "" {
320			return fmt.Errorf("username required for MONGO-CR")
321		}
322		if p.Password == "" {
323			return fmt.Errorf("password required for MONGO-CR")
324		}
325		if p.AuthMechanismProperties != nil {
326			return fmt.Errorf("MONGO-CR cannot have mechanism properties")
327		}
328	case "mongodb-x509":
329		if p.Password != "" {
330			return fmt.Errorf("password cannot be specified for MONGO-X509")
331		}
332		if p.AuthMechanismProperties != nil {
333			return fmt.Errorf("MONGO-X509 cannot have mechanism properties")
334		}
335	case "gssapi":
336		if p.Username == "" {
337			return fmt.Errorf("username required for GSSAPI")
338		}
339		for k := range p.AuthMechanismProperties {
340			if k != "SERVICE_NAME" && k != "CANONICALIZE_HOST_NAME" && k != "SERVICE_REALM" {
341				return fmt.Errorf("invalid auth property for GSSAPI")
342			}
343		}
344	case "plain":
345		if p.Username == "" {
346			return fmt.Errorf("username required for PLAIN")
347		}
348		if p.Password == "" {
349			return fmt.Errorf("password required for PLAIN")
350		}
351		if p.AuthMechanismProperties != nil {
352			return fmt.Errorf("PLAIN cannot have mechanism properties")
353		}
354	case "scram-sha-1":
355		if p.Username == "" {
356			return fmt.Errorf("username required for SCRAM-SHA-1")
357		}
358		if p.Password == "" {
359			return fmt.Errorf("password required for SCRAM-SHA-1")
360		}
361		if p.AuthMechanismProperties != nil {
362			return fmt.Errorf("SCRAM-SHA-1 cannot have mechanism properties")
363		}
364	case "scram-sha-256":
365		if p.Username == "" {
366			return fmt.Errorf("username required for SCRAM-SHA-256")
367		}
368		if p.Password == "" {
369			return fmt.Errorf("password required for SCRAM-SHA-256")
370		}
371		if p.AuthMechanismProperties != nil {
372			return fmt.Errorf("SCRAM-SHA-256 cannot have mechanism properties")
373		}
374	case "":
375		if p.Username == "" && p.AuthSource != "" {
376			return fmt.Errorf("authsource without username is invalid")
377		}
378	default:
379		return fmt.Errorf("invalid auth mechanism")
380	}
381	return nil
382}
383
384func (p *parser) addHost(host string) error {
385	if host == "" {
386		return nil
387	}
388	host, err := url.QueryUnescape(host)
389	if err != nil {
390		return internal.WrapErrorf(err, "invalid host \"%s\"", host)
391	}
392
393	_, port, err := net.SplitHostPort(host)
394	// this is unfortunate that SplitHostPort actually requires
395	// a port to exist.
396	if err != nil {
397		if addrError, ok := err.(*net.AddrError); !ok || addrError.Err != "missing port in address" {
398			return err
399		}
400	}
401
402	if port != "" {
403		d, err := strconv.Atoi(port)
404		if err != nil {
405			return internal.WrapErrorf(err, "port must be an integer")
406		}
407		if d <= 0 || d >= 65536 {
408			return fmt.Errorf("port must be in the range [1, 65535]")
409		}
410	}
411	p.Hosts = append(p.Hosts, host)
412	return nil
413}
414
415func (p *parser) addOption(pair string) error {
416	kv := strings.SplitN(pair, "=", 2)
417	if len(kv) != 2 || kv[0] == "" {
418		return fmt.Errorf("invalid option")
419	}
420
421	key, err := url.QueryUnescape(kv[0])
422	if err != nil {
423		return internal.WrapErrorf(err, "invalid option key \"%s\"", kv[0])
424	}
425
426	value, err := url.QueryUnescape(kv[1])
427	if err != nil {
428		return internal.WrapErrorf(err, "invalid option value \"%s\"", kv[1])
429	}
430
431	lowerKey := strings.ToLower(key)
432	switch lowerKey {
433	case "appname":
434		p.AppName = value
435	case "authmechanism":
436		p.AuthMechanism = value
437	case "authmechanismproperties":
438		p.AuthMechanismProperties = make(map[string]string)
439		pairs := strings.Split(value, ",")
440		for _, pair := range pairs {
441			kv := strings.SplitN(pair, ":", 2)
442			if len(kv) != 2 || kv[0] == "" {
443				return fmt.Errorf("invalid authMechanism property")
444			}
445			p.AuthMechanismProperties[kv[0]] = kv[1]
446		}
447	case "authsource":
448		p.AuthSource = value
449	case "compressors":
450		compressors := strings.Split(value, ",")
451		if len(compressors) < 1 {
452			return fmt.Errorf("must have at least 1 compressor")
453		}
454		p.Compressors = compressors
455	case "connect":
456		switch strings.ToLower(value) {
457		case "automatic":
458		case "direct":
459			p.Connect = SingleConnect
460		default:
461			return fmt.Errorf("invalid 'connect' value: %s", value)
462		}
463
464		p.ConnectSet = true
465	case "connecttimeoutms":
466		n, err := strconv.Atoi(value)
467		if err != nil || n < 0 {
468			return fmt.Errorf("invalid value for %s: %s", key, value)
469		}
470		p.ConnectTimeout = time.Duration(n) * time.Millisecond
471		p.ConnectTimeoutSet = true
472	case "heartbeatintervalms", "heartbeatfrequencyms":
473		n, err := strconv.Atoi(value)
474		if err != nil || n < 0 {
475			return fmt.Errorf("invalid value for %s: %s", key, value)
476		}
477		p.HeartbeatInterval = time.Duration(n) * time.Millisecond
478		p.HeartbeatIntervalSet = true
479	case "journal":
480		switch value {
481		case "true":
482			p.J = true
483		case "false":
484			p.J = false
485		default:
486			return fmt.Errorf("invalid value for %s: %s", key, value)
487		}
488
489		p.JSet = true
490	case "localthresholdms":
491		n, err := strconv.Atoi(value)
492		if err != nil || n < 0 {
493			return fmt.Errorf("invalid value for %s: %s", key, value)
494		}
495		p.LocalThreshold = time.Duration(n) * time.Millisecond
496		p.LocalThresholdSet = true
497	case "maxidletimems":
498		n, err := strconv.Atoi(value)
499		if err != nil || n < 0 {
500			return fmt.Errorf("invalid value for %s: %s", key, value)
501		}
502		p.MaxConnIdleTime = time.Duration(n) * time.Millisecond
503		p.MaxConnIdleTimeSet = true
504	case "maxpoolsize":
505		n, err := strconv.Atoi(value)
506		if err != nil || n < 0 {
507			return fmt.Errorf("invalid value for %s: %s", key, value)
508		}
509		p.MaxPoolSize = uint64(n)
510		p.MaxPoolSizeSet = true
511	case "minpoolsize":
512		n, err := strconv.Atoi(value)
513		if err != nil || n < 0 {
514			return fmt.Errorf("invalid value for %s: %s", key, value)
515		}
516		p.MinPoolSize = uint64(n)
517		p.MinPoolSizeSet = true
518	case "readconcernlevel":
519		p.ReadConcernLevel = value
520	case "readpreference":
521		p.ReadPreference = value
522	case "readpreferencetags":
523		if value == "" {
524			// for when readPreferenceTags= at end of URI
525			break
526		}
527
528		tags := make(map[string]string)
529		items := strings.Split(value, ",")
530		for _, item := range items {
531			parts := strings.Split(item, ":")
532			if len(parts) != 2 {
533				return fmt.Errorf("invalid value for %s: %s", key, value)
534			}
535			tags[parts[0]] = parts[1]
536		}
537		p.ReadPreferenceTagSets = append(p.ReadPreferenceTagSets, tags)
538	case "maxstaleness", "maxstalenessseconds":
539		n, err := strconv.Atoi(value)
540		if err != nil || n < 0 {
541			return fmt.Errorf("invalid value for %s: %s", key, value)
542		}
543		p.MaxStaleness = time.Duration(n) * time.Second
544		p.MaxStalenessSet = true
545	case "replicaset":
546		p.ReplicaSet = value
547	case "retrywrites":
548		switch value {
549		case "true":
550			p.RetryWrites = true
551		case "false":
552			p.RetryWrites = false
553		default:
554			return fmt.Errorf("invalid value for %s: %s", key, value)
555		}
556
557		p.RetryWritesSet = true
558	case "retryreads":
559		switch value {
560		case "true":
561			p.RetryReads = true
562		case "false":
563			p.RetryReads = false
564		default:
565			return fmt.Errorf("invalid value for %s: %s", key, value)
566		}
567
568		p.RetryReadsSet = true
569	case "serverselectiontimeoutms":
570		n, err := strconv.Atoi(value)
571		if err != nil || n < 0 {
572			return fmt.Errorf("invalid value for %s: %s", key, value)
573		}
574		p.ServerSelectionTimeout = time.Duration(n) * time.Millisecond
575		p.ServerSelectionTimeoutSet = true
576	case "sockettimeoutms":
577		n, err := strconv.Atoi(value)
578		if err != nil || n < 0 {
579			return fmt.Errorf("invalid value for %s: %s", key, value)
580		}
581		p.SocketTimeout = time.Duration(n) * time.Millisecond
582		p.SocketTimeoutSet = true
583	case "ssl", "tls":
584		switch value {
585		case "true":
586			p.SSL = true
587		case "false":
588			p.SSL = false
589		default:
590			return fmt.Errorf("invalid value for %s: %s", key, value)
591		}
592		if p.tlsssl != nil && *p.tlsssl != p.SSL {
593			return errors.New("tls and ssl options, when both specified, must be equivalent")
594		}
595
596		p.tlsssl = new(bool)
597		*p.tlsssl = p.SSL
598
599		p.SSLSet = true
600	case "sslclientcertificatekeyfile", "tlscertificatekeyfile":
601		p.SSL = true
602		p.SSLSet = true
603		p.SSLClientCertificateKeyFile = value
604		p.SSLClientCertificateKeyFileSet = true
605	case "sslclientcertificatekeypassword", "tlscertificatekeyfilepassword":
606		p.SSLClientCertificateKeyPassword = func() string { return value }
607		p.SSLClientCertificateKeyPasswordSet = true
608	case "sslinsecure", "tlsinsecure":
609		switch value {
610		case "true":
611			p.SSLInsecure = true
612		case "false":
613			p.SSLInsecure = false
614		default:
615			return fmt.Errorf("invalid value for %s: %s", key, value)
616		}
617
618		p.SSLInsecureSet = true
619	case "sslcertificateauthorityfile", "tlscafile":
620		p.SSL = true
621		p.SSLSet = true
622		p.SSLCaFile = value
623		p.SSLCaFileSet = true
624	case "w":
625		if w, err := strconv.Atoi(value); err == nil {
626			if w < 0 {
627				return fmt.Errorf("invalid value for %s: %s", key, value)
628			}
629
630			p.WNumber = w
631			p.WNumberSet = true
632			p.WString = ""
633			break
634		}
635
636		p.WString = value
637		p.WNumberSet = false
638
639	case "wtimeoutms":
640		n, err := strconv.Atoi(value)
641		if err != nil || n < 0 {
642			return fmt.Errorf("invalid value for %s: %s", key, value)
643		}
644		p.WTimeout = time.Duration(n) * time.Millisecond
645		p.WTimeoutSet = true
646	case "wtimeout":
647		// Defer to wtimeoutms, but not to a manually-set option.
648		if p.WTimeoutSet {
649			break
650		}
651		n, err := strconv.Atoi(value)
652		if err != nil || n < 0 {
653			return fmt.Errorf("invalid value for %s: %s", key, value)
654		}
655		p.WTimeout = time.Duration(n) * time.Millisecond
656	case "zlibcompressionlevel":
657		level, err := strconv.Atoi(value)
658		if err != nil || (level < -1 || level > 9) {
659			return fmt.Errorf("invalid value for %s: %s", key, value)
660		}
661
662		if level == -1 {
663			level = wiremessage.DefaultZlibLevel
664		}
665		p.ZlibLevel = level
666		p.ZlibLevelSet = true
667	case "zstdcompressionlevel":
668		const maxZstdLevel = 22 // https://github.com/facebook/zstd/blob/a880ca239b447968493dd2fed3850e766d6305cc/contrib/linux-kernel/lib/zstd/compress.c#L3291
669		level, err := strconv.Atoi(value)
670		if err != nil || (level < -1 || level > maxZstdLevel) {
671			return fmt.Errorf("invalid value for %s: %s", key, value)
672		}
673
674		if level == -1 {
675			level = wiremessage.DefaultZstdLevel
676		}
677		p.ZstdLevel = level
678		p.ZstdLevelSet = true
679	default:
680		if p.UnknownOptions == nil {
681			p.UnknownOptions = make(map[string][]string)
682		}
683		p.UnknownOptions[lowerKey] = append(p.UnknownOptions[lowerKey], value)
684	}
685
686	if p.Options == nil {
687		p.Options = make(map[string][]string)
688	}
689	p.Options[lowerKey] = append(p.Options[lowerKey], value)
690
691	return nil
692}
693
// extractQueryArgsFromURI splits the query portion of a connection string into
// its raw option pairs.
//
// uri must be empty or start with '?'; anything else is an error. Pairs may be
// separated by '&' or ';', and empty segments are dropped. When there is no
// query, or the query is empty, it returns a nil slice and no error.
func extractQueryArgsFromURI(uri string) ([]string, error) {
	if uri == "" {
		return nil, nil
	}

	query := strings.TrimPrefix(uri, "?")
	if len(query) == len(uri) {
		// TrimPrefix removed nothing, so the leading '?' is missing.
		return nil, errors.New("must have a ? separator between path and query")
	}
	if query == "" {
		return nil, nil
	}

	isSeparator := func(r rune) bool {
		switch r {
		case '&', ';':
			return true
		}
		return false
	}
	return strings.FieldsFunc(query, isSeparator), nil
}
710
// extractedDatabase is the result of extractDatabaseFromURI.
type extractedDatabase struct {
	uri string // rest of the URI after the database name: empty or starting with '?'
	db  string // URL-unescaped database name; empty when the URI names none
}
715
716// extractDatabaseFromURI is a helper function to retrieve information about
717// the database from the passed in URI. It accepts as an argument the currently
718// parsed URI and returns the remainder of the uri, the database it found,
719// and any error it encounters while parsing.
720func extractDatabaseFromURI(uri string) (extractedDatabase, error) {
721	if len(uri) == 0 {
722		return extractedDatabase{}, nil
723	}
724
725	if uri[0] != '/' {
726		return extractedDatabase{}, errors.New("must have a / separator between hosts and path")
727	}
728
729	uri = uri[1:]
730	if len(uri) == 0 {
731		return extractedDatabase{}, nil
732	}
733
734	database := uri
735	if idx := strings.IndexRune(uri, '?'); idx != -1 {
736		database = uri[:idx]
737	}
738
739	escapedDatabase, err := url.QueryUnescape(database)
740	if err != nil {
741		return extractedDatabase{}, internal.WrapErrorf(err, "invalid database \"%s\"", database)
742	}
743
744	uri = uri[len(database):]
745
746	return extractedDatabase{
747		uri: uri,
748		db:  escapedDatabase,
749	}, nil
750}
751