1// Copyright (C) MongoDB, Inc. 2014-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package connstring
8
9import (
10	"errors"
11	"fmt"
12	"net"
13	"net/url"
14	"runtime"
15	"strconv"
16	"strings"
17	"time"
18)
19
20// Parse parses the provided uri and returns a URI object.
21func ParseURIConnectionString(s string) (ConnString, error) {
22	var p parser
23	err := p.parse(s)
24	if err != nil {
25		err = fmt.Errorf("error parsing uri (%s): %s", s, err)
26	}
27	return p.ConnString, err
28}
29
// ConnString represents a connection string to mongodb.
//
// Values are filled in by the parser from the connection string. Fields the
// string does not mention keep their zero values. Duration fields come from
// the corresponding "...MS" options, interpreted as milliseconds.
type ConnString struct {
	// Original is the connection string exactly as given to the parser; it
	// may include user credentials.
	Original                string
	AppName                 string
	AuthMechanism           string
	// AuthMechanismProperties holds the "key:value" pairs of the
	// authMechanismProperties option; a repeated option replaces the map.
	AuthMechanismProperties map[string]string
	AuthSource              string
	// Connect is SingleConnect for connect=direct/single and AutoConnect
	// otherwise (AutoConnect is the zero value).
	Connect                 ConnectMode
	ConnectTimeout          time.Duration
	Database                string
	FSync                   bool
	// HeartbeatInterval is set by heartbeatIntervalMS or heartbeatFrequencyMS.
	HeartbeatInterval       time.Duration
	// Hosts holds the unescaped "host[:port]" entries; for mongodb+srv URIs
	// they are the targets returned by the SRV lookup.
	Hosts                   []string
	Journal                 bool
	KerberosService         string
	KerberosServiceHost     string
	MaxConnIdleTime         time.Duration
	MaxConnLifeTime         time.Duration
	// The *Set flags below record that the preceding limit was given
	// explicitly, either by its own option or by maxPoolSize, which sets both.
	MaxConnsPerHost         uint16
	MaxConnsPerHostSet      bool
	MaxIdleConnsPerHost     uint16
	MaxIdleConnsPerHostSet  bool
	Password                string
	// PasswordSet is true whenever the user info contains ':', even if the
	// password after it is empty.
	PasswordSet             bool
	ReadPreference          string
	// ReadPreferenceTagSets has one entry per readPreferenceTags option, in
	// order of appearance.
	ReadPreferenceTagSets   []map[string]string
	ReplicaSet              string
	ServerSelectionTimeout  time.Duration
	SocketTimeout           time.Duration
	Username                string
	// UseSSLSeen records that SSL was configured, either by the ssl option
	// or implicitly by a mongodb+srv URI (which turns SSL on).
	UseSSL                  bool
	UseSSLSeen              bool
	// W is the raw value of the w option; it is not validated here.
	W                       string
	// WTimeout comes from wtimeoutMS, which takes precedence over the legacy
	// wtimeout option.
	WTimeout                time.Duration

	// UsingSRV is true for "mongodb+srv://" connection strings.
	UsingSRV bool

	// Options holds every successfully applied option (TXT record options
	// first, then query-string options), keyed by lower-cased name, with
	// unescaped values in order of appearance.
	Options        map[string][]string
	// UnknownOptions holds the subset of Options the parser does not
	// recognize.
	UnknownOptions map[string][]string
}
70
71func (u *ConnString) String() string {
72	return u.Original
73}
74
// ConnectMode informs the driver on how to connect to the server. It is
// selected by the "connect" option of a connection string.
type ConnectMode uint8

// ConnectMode constants. AutoConnect is the zero value, so it applies when no
// "connect" option is given.
const (
	// AutoConnect is selected by connect=auto or connect=automatic.
	AutoConnect ConnectMode = 0
	// SingleConnect is selected by connect=direct or connect=single.
	SingleConnect ConnectMode = 1
)
84
// parser accumulates the result of parsing a connection string in the
// embedded ConnString, plus state needed only while parsing.
type parser struct {
	ConnString

	// haveWTimeoutMS records that a wtimeoutMS option has been applied, so a
	// later legacy wtimeout option is ignored.
	haveWTimeoutMS bool
}
90
91func (p *parser) parse(original string) error {
92	p.Original = original
93	uri := original
94
95	var err error
96	var isSRV bool
97	if strings.HasPrefix(uri, "mongodb+srv://") {
98		isSRV = true
99
100		p.UsingSRV = true
101
102		// SSL should be turned on by default when retrieving hosts from SRV
103		p.UseSSL = true
104		p.UseSSLSeen = true
105
106		// remove the scheme
107		uri = uri[14:]
108	} else if strings.HasPrefix(uri, "mongodb://") {
109		// remove the scheme
110		uri = uri[10:]
111	} else {
112		return fmt.Errorf("scheme must be \"mongodb\" or \"mongodb+srv\"")
113	}
114
115	if idx := strings.Index(uri, "@"); idx != -1 {
116		userInfo := uri[:idx]
117		uri = uri[idx+1:]
118
119		username := userInfo
120		var password string
121
122		if idx := strings.Index(userInfo, ":"); idx != -1 {
123			username = userInfo[:idx]
124			password = userInfo[idx+1:]
125			p.PasswordSet = true
126		}
127
128		if len(username) > 1 {
129			if strings.Contains(username, "/") {
130				return fmt.Errorf("unescaped slash in username")
131			}
132		}
133
134		p.Username, err = url.QueryUnescape(username)
135		if err != nil {
136			return fmt.Errorf("invalid username: %s", err)
137		}
138		if len(password) > 1 {
139			if strings.Contains(password, ":") {
140				return fmt.Errorf("unescaped colon in password")
141			}
142			if strings.Contains(password, "/") {
143				return fmt.Errorf("unescaped slash in password")
144			}
145			p.Password, err = url.QueryUnescape(password)
146			if err != nil {
147				return fmt.Errorf("invalid password: %s", err)
148			}
149		}
150	}
151
152	// fetch the hosts field
153	hosts := uri
154	if idx := strings.IndexAny(uri, "/?@"); idx != -1 {
155		if uri[idx] == '@' {
156			return fmt.Errorf("unescaped @ sign in user info")
157		}
158		if uri[idx] == '?' {
159			return fmt.Errorf("must have a / before the query ?")
160		}
161		hosts = uri[:idx]
162	}
163
164	var connectionArgsFromTXT []string
165	parsedHosts := strings.Split(hosts, ",")
166
167	if isSRV {
168		parsedHosts = strings.Split(hosts, ",")
169		if len(parsedHosts) != 1 {
170			return fmt.Errorf("URI with SRV must include one and only one hostname")
171		}
172		parsedHosts, err = fetchSeedlistFromSRV(parsedHosts[0])
173		if err != nil {
174			return err
175		}
176
177		// error ignored because finding a TXT record should not be
178		// considered an error.
179		recordsFromTXT, _ := net.LookupTXT(hosts)
180
181		// This is a temporary fix to get around bug https://github.com/golang/go/issues/21472.
182		// It will currently incorrectly concatenate multiple TXT records to one
183		// on windows.
184		if runtime.GOOS == "windows" {
185			recordsFromTXT = []string{strings.Join(recordsFromTXT, "")}
186		}
187
188		if len(recordsFromTXT) > 1 {
189			return errors.New("multiple records from TXT not supported")
190		}
191		if len(recordsFromTXT) > 0 {
192			connectionArgsFromTXT = strings.FieldsFunc(recordsFromTXT[0], func(r rune) bool { return r == ';' || r == '&' })
193
194			err := validateTXTResult(connectionArgsFromTXT)
195			if err != nil {
196				return err
197			}
198
199		}
200	}
201
202	for _, host := range parsedHosts {
203		err = p.addHost(host)
204		if err != nil {
205			return fmt.Errorf("invalid host \"%s\": %s", host, err)
206		}
207	}
208	if len(p.Hosts) == 0 {
209		return fmt.Errorf("must have at least 1 host")
210	}
211
212	uri = uri[len(hosts):]
213
214	extractedDatabase, err := extractDatabaseFromURI(uri)
215	if err != nil {
216		return err
217	}
218
219	uri = extractedDatabase.uri
220	p.Database = extractedDatabase.db
221
222	connectionArgsFromQueryString, err := extractQueryArgsFromURI(uri)
223	connectionArgPairs := append(connectionArgsFromTXT, connectionArgsFromQueryString...)
224
225	for _, pair := range connectionArgPairs {
226		err = p.addOption(pair)
227		if err != nil {
228			return err
229		}
230	}
231
232	return nil
233}
234
235func fetchSeedlistFromSRV(host string) ([]string, error) {
236	var err error
237
238	_, _, err = net.SplitHostPort(host)
239
240	if err == nil {
241		// we were able to successfully extract a port from the host,
242		// but should not be able to when using SRV
243		return nil, fmt.Errorf("URI with srv must not include a port number")
244	}
245
246	_, addresses, err := net.LookupSRV("mongodb", "tcp", host)
247	if err != nil {
248		return nil, err
249	}
250	parsedHosts := make([]string, len(addresses))
251	for i, address := range addresses {
252		trimmedAddressTarget := strings.TrimSuffix(address.Target, ".")
253		err := validateSRVResult(trimmedAddressTarget, host)
254		if err != nil {
255			return nil, err
256		}
257		parsedHosts[i] = fmt.Sprintf("%s:%d", trimmedAddressTarget, address.Port)
258	}
259
260	return parsedHosts, nil
261}
262
263func (p *parser) addHost(host string) error {
264	if host == "" {
265		return nil
266	}
267	host, err := url.QueryUnescape(host)
268	if err != nil {
269		return fmt.Errorf("invalid host \"%s\": %s", host, err)
270	}
271
272	_, port, err := net.SplitHostPort(host)
273	// this is unfortunate that SplitHostPort actually requires
274	// a port to exist.
275	if err != nil {
276		if addrError, ok := err.(*net.AddrError); !ok || addrError.Err != "missing port in address" {
277			return err
278		}
279	}
280
281	if port != "" {
282		d, err := strconv.Atoi(port)
283		if err != nil {
284			return fmt.Errorf("port must be an integer: %s", err)
285		}
286		if d <= 0 || d >= 65536 {
287			return fmt.Errorf("port must be in the range [1, 65535]")
288		}
289	}
290	p.Hosts = append(p.Hosts, host)
291	return nil
292}
293
294func (p *parser) addOption(pair string) error {
295	kv := strings.SplitN(pair, "=", 2)
296	if len(kv) != 2 || kv[0] == "" {
297		return fmt.Errorf("invalid option")
298	}
299
300	key, err := url.QueryUnescape(kv[0])
301	if err != nil {
302		return fmt.Errorf("invalid option key \"%s\": %s", kv[0], err)
303	}
304
305	value, err := url.QueryUnescape(kv[1])
306	if err != nil {
307		return fmt.Errorf("invalid option value \"%s\": %s", kv[1], err)
308	}
309
310	lowerKey := strings.ToLower(key)
311	switch lowerKey {
312	case "appname":
313		p.AppName = value
314	case "authmechanism":
315		p.AuthMechanism = value
316	case "authmechanismproperties":
317		p.AuthMechanismProperties = make(map[string]string)
318		pairs := strings.Split(value, ",")
319		for _, pair := range pairs {
320			kv := strings.SplitN(pair, ":", 2)
321			if len(kv) != 2 || kv[0] == "" {
322				return fmt.Errorf("invalid authMechanism property")
323			}
324			p.AuthMechanismProperties[kv[0]] = kv[1]
325		}
326	case "authsource":
327		p.AuthSource = value
328	case "connect":
329		switch strings.ToLower(value) {
330		case "auto", "automatic":
331			p.Connect = AutoConnect
332		case "direct", "single":
333			p.Connect = SingleConnect
334		default:
335			return fmt.Errorf("invalid 'connect' value: %s", value)
336		}
337	case "connecttimeoutms":
338		n, err := strconv.Atoi(value)
339		if err != nil || n < 0 {
340			return fmt.Errorf("invalid value for %s: %s", key, value)
341		}
342		p.ConnectTimeout = time.Duration(n) * time.Millisecond
343	case "heartbeatintervalms", "heartbeatfrequencyms":
344		n, err := strconv.Atoi(value)
345		if err != nil || n < 0 {
346			return fmt.Errorf("invalid value for %s: %s", key, value)
347		}
348		p.HeartbeatInterval = time.Duration(n) * time.Millisecond
349	case "fsync":
350		f, err := strconv.ParseBool(value)
351		if err != nil {
352			return fmt.Errorf("invalid value for %s: %s", key, value)
353		}
354		p.FSync = f
355	case "j":
356		j, err := strconv.ParseBool(value)
357		if err != nil {
358			return fmt.Errorf("invalid value for %s: %s", key, value)
359		}
360		p.Journal = j
361	case "gssapiservicename":
362		p.KerberosService = value
363	case "gssapihostname":
364		p.KerberosServiceHost = value
365	case "maxconnsperhost":
366		n, err := strconv.Atoi(value)
367		if err != nil || n < 0 {
368			return fmt.Errorf("invalid value for %s: %s", key, value)
369		}
370		p.MaxConnsPerHost = uint16(n)
371		p.MaxConnsPerHostSet = true
372	case "maxidleconnsperhost":
373		n, err := strconv.Atoi(value)
374		if err != nil || n < 0 {
375			return fmt.Errorf("invalid value for %s: %s", key, value)
376		}
377		p.MaxIdleConnsPerHost = uint16(n)
378		p.MaxIdleConnsPerHostSet = true
379	case "maxidletimems":
380		n, err := strconv.Atoi(value)
381		if err != nil || n < 0 {
382			return fmt.Errorf("invalid value for %s: %s", key, value)
383		}
384		p.MaxConnIdleTime = time.Duration(n) * time.Millisecond
385	case "maxlifetimems":
386		n, err := strconv.Atoi(value)
387		if err != nil || n < 0 {
388			return fmt.Errorf("invalid value for %s: %s", key, value)
389		}
390		p.MaxConnLifeTime = time.Duration(n) * time.Millisecond
391	case "maxpoolsize":
392		n, err := strconv.Atoi(value)
393		if err != nil || n < 0 {
394			return fmt.Errorf("invalid value for %s: %s", key, value)
395		}
396		p.MaxConnsPerHost = uint16(n)
397		p.MaxConnsPerHostSet = true
398		p.MaxIdleConnsPerHost = uint16(n)
399		p.MaxIdleConnsPerHostSet = true
400	case "readpreference":
401		p.ReadPreference = value
402	case "readpreferencetags":
403		tags := make(map[string]string)
404		items := strings.Split(value, ",")
405		for _, item := range items {
406			parts := strings.Split(item, ":")
407			if len(parts) != 2 {
408				return fmt.Errorf("invalid value for %s: %s", key, value)
409			}
410			tags[parts[0]] = parts[1]
411		}
412		p.ReadPreferenceTagSets = append(p.ReadPreferenceTagSets, tags)
413	case "replicaset":
414		p.ReplicaSet = value
415	case "serverselectiontimeoutms":
416		n, err := strconv.Atoi(value)
417		if err != nil || n < 0 {
418			return fmt.Errorf("invalid value for %s: %s", key, value)
419		}
420		p.ServerSelectionTimeout = time.Duration(n) * time.Millisecond
421	case "sockettimeoutms":
422		n, err := strconv.Atoi(value)
423		if err != nil || n < 0 {
424			return fmt.Errorf("invalid value for %s: %s", key, value)
425		}
426		p.SocketTimeout = time.Duration(n) * time.Millisecond
427	case "ssl":
428		b, err := strconv.ParseBool(value)
429		if err != nil {
430			return fmt.Errorf("invalid value for %s: %s", key, value)
431		}
432		p.UseSSL = b
433		p.UseSSLSeen = true
434	case "w":
435		p.W = value
436	case "wtimeoutms":
437		n, err := strconv.Atoi(value)
438		if err != nil || n < 0 {
439			return fmt.Errorf("invalid value for %s: %s", key, value)
440		}
441		p.WTimeout = time.Duration(n) * time.Millisecond
442		p.haveWTimeoutMS = true
443	case "wtimeout":
444		if p.haveWTimeoutMS {
445			// use wtimeoutMS if it exists
446			break
447		}
448		n, err := strconv.Atoi(value)
449		if err != nil || n < 0 {
450			return fmt.Errorf("invalid value for %s: %s", key, value)
451		}
452		p.WTimeout = time.Duration(n) * time.Millisecond
453	default:
454		if p.UnknownOptions == nil {
455			p.UnknownOptions = make(map[string][]string)
456		}
457		p.UnknownOptions[lowerKey] = append(p.UnknownOptions[lowerKey], value)
458	}
459
460	if p.Options == nil {
461		p.Options = make(map[string][]string)
462	}
463	p.Options[lowerKey] = append(p.Options[lowerKey], value)
464
465	return nil
466}
467
// validateSRVResult checks that recordFromSRV, a target host name returned by
// an SRV lookup of inputHostName, may be used as a seed host.
//
// The record must have at least two '.'-separated labels and must end with
// the parent domain of inputHostName (inputHostName minus its first label).
// Labels are compared case-insensitively because DNS names are
// case-insensitive, so a record that differs from the user's input only in
// letter case is accepted.
func validateSRVResult(recordFromSRV, inputHostName string) error {
	inputLabels := strings.Split(inputHostName, ".")
	recordLabels := strings.Split(recordFromSRV, ".")
	if len(recordLabels) < 2 {
		return errors.New("DNS name must contain at least 2 labels")
	}
	if len(recordLabels) < len(inputLabels) {
		return errors.New("Domain suffix from SRV record not matched input domain")
	}

	// Compare the record's trailing labels against the input's parent domain.
	parentDomain := inputLabels[1:]
	recordSuffix := recordLabels[len(recordLabels)-len(parentDomain):]
	for i, label := range parentDomain {
		if !strings.EqualFold(label, recordSuffix[i]) {
			return errors.New("Domain suffix from SRV record not matched input domain")
		}
	}
	return nil
}
489
// allowedTXTOptions is the set of lower-cased option names that may appear in
// the DNS TXT record consulted for a mongodb+srv connection string.
var allowedTXTOptions = map[string]struct{}{
	"authsource": {},
	"replicaset": {},
}

// validateTXTResult checks the option pairs taken from a TXT record. Every
// pair must contain '=', and every key, compared case-insensitively, must be
// listed in allowedTXTOptions. The first offending pair determines the error.
func validateTXTResult(paramsFromTXT []string) error {
	for i := range paramsFromTXT {
		param := paramsFromTXT[i]
		eq := strings.Index(param, "=")
		if eq < 0 {
			return errors.New("Invalid TXT record")
		}
		name := param[:eq]
		if _, allowed := allowedTXTOptions[strings.ToLower(name)]; !allowed {
			return fmt.Errorf("Cannot specify option '%s' in TXT record", name)
		}
	}
	return nil
}
508
// extractQueryArgsFromURI splits the query part of a connection string into
// its "key=value" pairs. uri is whatever follows the database name. It must be
// empty or start with '?'. Pairs may be separated by '&' or ';', and empty
// pairs are dropped. An empty uri, or a bare "?", yields nil.
func extractQueryArgsFromURI(uri string) ([]string, error) {
	switch {
	case uri == "":
		return nil, nil
	case !strings.HasPrefix(uri, "?"):
		return nil, errors.New("must have a ? separator between path and query")
	}

	query := strings.TrimPrefix(uri, "?")
	if query == "" {
		return nil, nil
	}

	isSeparator := func(r rune) bool { return r == '&' || r == ';' }
	return strings.FieldsFunc(query, isSeparator), nil
}
525
// extractedDatabase is the result of extractDatabaseFromURI.
type extractedDatabase struct {
	uri string // what remains after the database name: "" or text starting with '?'
	db  string // percent-unescaped database name; "" when none was given
}

// extractDatabaseFromURI is a helper function to retrieve the database from
// the part of a connection string that follows the host list.
//
// uri must be empty or start with '/'. The database name runs from after the
// '/' up to the first '?' and is unescaped with url.QueryUnescape. The result
// carries the unescaped name and the unparsed remainder of uri. An error is
// returned for a missing '/' or an invalid escape in the name.
func extractDatabaseFromURI(uri string) (extractedDatabase, error) {
	var result extractedDatabase
	if uri == "" {
		return result, nil
	}
	if !strings.HasPrefix(uri, "/") {
		return result, errors.New("must have a / separator between hosts and path")
	}

	path := uri[1:]
	if path == "" {
		return result, nil
	}

	rawDB := strings.SplitN(path, "?", 2)[0]
	db, err := url.QueryUnescape(rawDB)
	if err != nil {
		return extractedDatabase{}, fmt.Errorf("invalid database \"%s\": %s", rawDB, err)
	}

	result.uri = path[len(rawDB):]
	result.db = db
	return result, nil
}
566