1// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2//
3// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
4//
5// This Source Code Form is subject to the terms of the Mozilla Public
6// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7// You can obtain one at http://mozilla.org/MPL/2.0/.
8
9package mysql
10
11import (
12	"bytes"
13	"crypto/rsa"
14	"crypto/tls"
15	"errors"
16	"fmt"
17	"net"
18	"net/url"
19	"sort"
20	"strconv"
21	"strings"
22	"time"
23)
24
// Errors reported while parsing (ParseDSN) or validating (normalize) a DSN.
var (
	// errInvalidDSNAddr is returned when a '(' follows the network type but
	// the character right before the database slash is not ')', and no ')'
	// appears anywhere in between.
	errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")

	// errInvalidDSNNoSlash is returned for a non-empty DSN without any '/'.
	errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name")

	// errInvalidDSNUnescaped is returned when the address is not closed right
	// before the slash but a ')' does appear after the '('; this happens when
	// the slash that was found lies inside an unescaped parameter value.
	errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?")

	// errInvalidDSNUnsafeCollation is returned when interpolateParams is
	// combined with a collation listed in unsafeCollations.
	errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
)
31
// Config is a configuration parsed from a DSN string.
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
//
// ParseDSN builds a Config from a DSN string and FormatDSN converts a Config
// back into one. Each boolean option at the end of the struct corresponds to
// the DSN parameter of the same name with a lower-case first letter
// (e.g. ParseTime <-> parseTime).
type Config struct {
	User             string            // Username
	Passwd           string            // Password (requires User)
	Net              string            // Network type; ParseDSN defaults an empty value to "tcp"
	Addr             string            // Network address (requires Net); ParseDSN supplies a default for "tcp" and "unix"
	DBName           string            // Database name
	Params           map[string]string // Connection parameters without a dedicated field (ParseDSN stores the values unescaped)
	Collation        string            // Connection collation (NewConfig: defaultCollation)
	Loc              *time.Location    // Location for time.Time values (NewConfig: time.UTC)
	MaxAllowedPacket int               // Max packet size allowed (NewConfig: defaultMaxAllowedPacket)
	ServerPubKey     string            // Server public key name
	pubKey           *rsa.PublicKey    // Server public key, resolved from ServerPubKey while parsing the DSN
	TLSConfig        string            // TLS configuration name: "true", "false", "skip-verify" or a name resolved via getTLSConfigClone
	tls              *tls.Config       // TLS configuration derived from TLSConfig while parsing; normalize may fill in ServerName
	Timeout          time.Duration     // Dial timeout
	ReadTimeout      time.Duration     // I/O read timeout
	WriteTimeout     time.Duration     // I/O write timeout

	AllowAllFiles           bool // Allow all files to be used with LOAD DATA LOCAL INFILE
	AllowCleartextPasswords bool // Allows the cleartext client side plugin
	AllowNativePasswords    bool // Allows the native password authentication method (NewConfig: true)
	AllowOldPasswords       bool // Allows the old insecure password method
	ClientFoundRows         bool // Return number of matching rows instead of rows changed
	ColumnsWithAlias        bool // Prepend table alias to column names
	InterpolateParams       bool // Interpolate placeholders into query string
	MultiStatements         bool // Allow multiple statements in one query
	ParseTime               bool // Parse time values to time.Time
	RejectReadOnly          bool // Reject read-only connections
}
64
65// NewConfig creates a new Config and sets default values.
66func NewConfig() *Config {
67	return &Config{
68		Collation:            defaultCollation,
69		Loc:                  time.UTC,
70		MaxAllowedPacket:     defaultMaxAllowedPacket,
71		AllowNativePasswords: true,
72	}
73}
74
75func (cfg *Config) normalize() error {
76	if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
77		return errInvalidDSNUnsafeCollation
78	}
79
80	// Set default network if empty
81	if cfg.Net == "" {
82		cfg.Net = "tcp"
83	}
84
85	// Set default address if empty
86	if cfg.Addr == "" {
87		switch cfg.Net {
88		case "tcp":
89			cfg.Addr = "127.0.0.1:3306"
90		case "unix":
91			cfg.Addr = "/tmp/mysql.sock"
92		default:
93			return errors.New("default addr for network '" + cfg.Net + "' unknown")
94		}
95
96	} else if cfg.Net == "tcp" {
97		cfg.Addr = ensureHavePort(cfg.Addr)
98	}
99
100	if cfg.tls != nil {
101		if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
102			host, _, err := net.SplitHostPort(cfg.Addr)
103			if err == nil {
104				cfg.tls.ServerName = host
105			}
106		}
107	}
108
109	return nil
110}
111
112// FormatDSN formats the given Config into a DSN string which can be passed to
113// the driver.
114func (cfg *Config) FormatDSN() string {
115	var buf bytes.Buffer
116
117	// [username[:password]@]
118	if len(cfg.User) > 0 {
119		buf.WriteString(cfg.User)
120		if len(cfg.Passwd) > 0 {
121			buf.WriteByte(':')
122			buf.WriteString(cfg.Passwd)
123		}
124		buf.WriteByte('@')
125	}
126
127	// [protocol[(address)]]
128	if len(cfg.Net) > 0 {
129		buf.WriteString(cfg.Net)
130		if len(cfg.Addr) > 0 {
131			buf.WriteByte('(')
132			buf.WriteString(cfg.Addr)
133			buf.WriteByte(')')
134		}
135	}
136
137	// /dbname
138	buf.WriteByte('/')
139	buf.WriteString(cfg.DBName)
140
141	// [?param1=value1&...&paramN=valueN]
142	hasParam := false
143
144	if cfg.AllowAllFiles {
145		hasParam = true
146		buf.WriteString("?allowAllFiles=true")
147	}
148
149	if cfg.AllowCleartextPasswords {
150		if hasParam {
151			buf.WriteString("&allowCleartextPasswords=true")
152		} else {
153			hasParam = true
154			buf.WriteString("?allowCleartextPasswords=true")
155		}
156	}
157
158	if !cfg.AllowNativePasswords {
159		if hasParam {
160			buf.WriteString("&allowNativePasswords=false")
161		} else {
162			hasParam = true
163			buf.WriteString("?allowNativePasswords=false")
164		}
165	}
166
167	if cfg.AllowOldPasswords {
168		if hasParam {
169			buf.WriteString("&allowOldPasswords=true")
170		} else {
171			hasParam = true
172			buf.WriteString("?allowOldPasswords=true")
173		}
174	}
175
176	if cfg.ClientFoundRows {
177		if hasParam {
178			buf.WriteString("&clientFoundRows=true")
179		} else {
180			hasParam = true
181			buf.WriteString("?clientFoundRows=true")
182		}
183	}
184
185	if col := cfg.Collation; col != defaultCollation && len(col) > 0 {
186		if hasParam {
187			buf.WriteString("&collation=")
188		} else {
189			hasParam = true
190			buf.WriteString("?collation=")
191		}
192		buf.WriteString(col)
193	}
194
195	if cfg.ColumnsWithAlias {
196		if hasParam {
197			buf.WriteString("&columnsWithAlias=true")
198		} else {
199			hasParam = true
200			buf.WriteString("?columnsWithAlias=true")
201		}
202	}
203
204	if cfg.InterpolateParams {
205		if hasParam {
206			buf.WriteString("&interpolateParams=true")
207		} else {
208			hasParam = true
209			buf.WriteString("?interpolateParams=true")
210		}
211	}
212
213	if cfg.Loc != time.UTC && cfg.Loc != nil {
214		if hasParam {
215			buf.WriteString("&loc=")
216		} else {
217			hasParam = true
218			buf.WriteString("?loc=")
219		}
220		buf.WriteString(url.QueryEscape(cfg.Loc.String()))
221	}
222
223	if cfg.MultiStatements {
224		if hasParam {
225			buf.WriteString("&multiStatements=true")
226		} else {
227			hasParam = true
228			buf.WriteString("?multiStatements=true")
229		}
230	}
231
232	if cfg.ParseTime {
233		if hasParam {
234			buf.WriteString("&parseTime=true")
235		} else {
236			hasParam = true
237			buf.WriteString("?parseTime=true")
238		}
239	}
240
241	if cfg.ReadTimeout > 0 {
242		if hasParam {
243			buf.WriteString("&readTimeout=")
244		} else {
245			hasParam = true
246			buf.WriteString("?readTimeout=")
247		}
248		buf.WriteString(cfg.ReadTimeout.String())
249	}
250
251	if cfg.RejectReadOnly {
252		if hasParam {
253			buf.WriteString("&rejectReadOnly=true")
254		} else {
255			hasParam = true
256			buf.WriteString("?rejectReadOnly=true")
257		}
258	}
259
260	if len(cfg.ServerPubKey) > 0 {
261		if hasParam {
262			buf.WriteString("&serverPubKey=")
263		} else {
264			hasParam = true
265			buf.WriteString("?serverPubKey=")
266		}
267		buf.WriteString(url.QueryEscape(cfg.ServerPubKey))
268	}
269
270	if cfg.Timeout > 0 {
271		if hasParam {
272			buf.WriteString("&timeout=")
273		} else {
274			hasParam = true
275			buf.WriteString("?timeout=")
276		}
277		buf.WriteString(cfg.Timeout.String())
278	}
279
280	if len(cfg.TLSConfig) > 0 {
281		if hasParam {
282			buf.WriteString("&tls=")
283		} else {
284			hasParam = true
285			buf.WriteString("?tls=")
286		}
287		buf.WriteString(url.QueryEscape(cfg.TLSConfig))
288	}
289
290	if cfg.WriteTimeout > 0 {
291		if hasParam {
292			buf.WriteString("&writeTimeout=")
293		} else {
294			hasParam = true
295			buf.WriteString("?writeTimeout=")
296		}
297		buf.WriteString(cfg.WriteTimeout.String())
298	}
299
300	if cfg.MaxAllowedPacket != defaultMaxAllowedPacket {
301		if hasParam {
302			buf.WriteString("&maxAllowedPacket=")
303		} else {
304			hasParam = true
305			buf.WriteString("?maxAllowedPacket=")
306		}
307		buf.WriteString(strconv.Itoa(cfg.MaxAllowedPacket))
308
309	}
310
311	// other params
312	if cfg.Params != nil {
313		var params []string
314		for param := range cfg.Params {
315			params = append(params, param)
316		}
317		sort.Strings(params)
318		for _, param := range params {
319			if hasParam {
320				buf.WriteByte('&')
321			} else {
322				hasParam = true
323				buf.WriteByte('?')
324			}
325
326			buf.WriteString(param)
327			buf.WriteByte('=')
328			buf.WriteString(url.QueryEscape(cfg.Params[param]))
329		}
330	}
331
332	return buf.String()
333}
334
335// ParseDSN parses the DSN string to a Config
336func ParseDSN(dsn string) (cfg *Config, err error) {
337	// New config with some default values
338	cfg = NewConfig()
339
340	// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
341	// Find the last '/' (since the password or the net addr might contain a '/')
342	foundSlash := false
343	for i := len(dsn) - 1; i >= 0; i-- {
344		if dsn[i] == '/' {
345			foundSlash = true
346			var j, k int
347
348			// left part is empty if i <= 0
349			if i > 0 {
350				// [username[:password]@][protocol[(address)]]
351				// Find the last '@' in dsn[:i]
352				for j = i; j >= 0; j-- {
353					if dsn[j] == '@' {
354						// username[:password]
355						// Find the first ':' in dsn[:j]
356						for k = 0; k < j; k++ {
357							if dsn[k] == ':' {
358								cfg.Passwd = dsn[k+1 : j]
359								break
360							}
361						}
362						cfg.User = dsn[:k]
363
364						break
365					}
366				}
367
368				// [protocol[(address)]]
369				// Find the first '(' in dsn[j+1:i]
370				for k = j + 1; k < i; k++ {
371					if dsn[k] == '(' {
372						// dsn[i-1] must be == ')' if an address is specified
373						if dsn[i-1] != ')' {
374							if strings.ContainsRune(dsn[k+1:i], ')') {
375								return nil, errInvalidDSNUnescaped
376							}
377							return nil, errInvalidDSNAddr
378						}
379						cfg.Addr = dsn[k+1 : i-1]
380						break
381					}
382				}
383				cfg.Net = dsn[j+1 : k]
384			}
385
386			// dbname[?param1=value1&...&paramN=valueN]
387			// Find the first '?' in dsn[i+1:]
388			for j = i + 1; j < len(dsn); j++ {
389				if dsn[j] == '?' {
390					if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
391						return
392					}
393					break
394				}
395			}
396			cfg.DBName = dsn[i+1 : j]
397
398			break
399		}
400	}
401
402	if !foundSlash && len(dsn) > 0 {
403		return nil, errInvalidDSNNoSlash
404	}
405
406	if err = cfg.normalize(); err != nil {
407		return nil, err
408	}
409	return
410}
411
412// parseDSNParams parses the DSN "query string"
413// Values must be url.QueryEscape'ed
414func parseDSNParams(cfg *Config, params string) (err error) {
415	for _, v := range strings.Split(params, "&") {
416		param := strings.SplitN(v, "=", 2)
417		if len(param) != 2 {
418			continue
419		}
420
421		// cfg params
422		switch value := param[1]; param[0] {
423		// Disable INFILE whitelist / enable all files
424		case "allowAllFiles":
425			var isBool bool
426			cfg.AllowAllFiles, isBool = readBool(value)
427			if !isBool {
428				return errors.New("invalid bool value: " + value)
429			}
430
431		// Use cleartext authentication mode (MySQL 5.5.10+)
432		case "allowCleartextPasswords":
433			var isBool bool
434			cfg.AllowCleartextPasswords, isBool = readBool(value)
435			if !isBool {
436				return errors.New("invalid bool value: " + value)
437			}
438
439		// Use native password authentication
440		case "allowNativePasswords":
441			var isBool bool
442			cfg.AllowNativePasswords, isBool = readBool(value)
443			if !isBool {
444				return errors.New("invalid bool value: " + value)
445			}
446
447		// Use old authentication mode (pre MySQL 4.1)
448		case "allowOldPasswords":
449			var isBool bool
450			cfg.AllowOldPasswords, isBool = readBool(value)
451			if !isBool {
452				return errors.New("invalid bool value: " + value)
453			}
454
455		// Switch "rowsAffected" mode
456		case "clientFoundRows":
457			var isBool bool
458			cfg.ClientFoundRows, isBool = readBool(value)
459			if !isBool {
460				return errors.New("invalid bool value: " + value)
461			}
462
463		// Collation
464		case "collation":
465			cfg.Collation = value
466			break
467
468		case "columnsWithAlias":
469			var isBool bool
470			cfg.ColumnsWithAlias, isBool = readBool(value)
471			if !isBool {
472				return errors.New("invalid bool value: " + value)
473			}
474
475		// Compression
476		case "compress":
477			return errors.New("compression not implemented yet")
478
479		// Enable client side placeholder substitution
480		case "interpolateParams":
481			var isBool bool
482			cfg.InterpolateParams, isBool = readBool(value)
483			if !isBool {
484				return errors.New("invalid bool value: " + value)
485			}
486
487		// Time Location
488		case "loc":
489			if value, err = url.QueryUnescape(value); err != nil {
490				return
491			}
492			cfg.Loc, err = time.LoadLocation(value)
493			if err != nil {
494				return
495			}
496
497		// multiple statements in one query
498		case "multiStatements":
499			var isBool bool
500			cfg.MultiStatements, isBool = readBool(value)
501			if !isBool {
502				return errors.New("invalid bool value: " + value)
503			}
504
505		// time.Time parsing
506		case "parseTime":
507			var isBool bool
508			cfg.ParseTime, isBool = readBool(value)
509			if !isBool {
510				return errors.New("invalid bool value: " + value)
511			}
512
513		// I/O read Timeout
514		case "readTimeout":
515			cfg.ReadTimeout, err = time.ParseDuration(value)
516			if err != nil {
517				return
518			}
519
520		// Reject read-only connections
521		case "rejectReadOnly":
522			var isBool bool
523			cfg.RejectReadOnly, isBool = readBool(value)
524			if !isBool {
525				return errors.New("invalid bool value: " + value)
526			}
527
528		// Server public key
529		case "serverPubKey":
530			name, err := url.QueryUnescape(value)
531			if err != nil {
532				return fmt.Errorf("invalid value for server pub key name: %v", err)
533			}
534
535			if pubKey := getServerPubKey(name); pubKey != nil {
536				cfg.ServerPubKey = name
537				cfg.pubKey = pubKey
538			} else {
539				return errors.New("invalid value / unknown server pub key name: " + name)
540			}
541
542		// Strict mode
543		case "strict":
544			panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")
545
546		// Dial Timeout
547		case "timeout":
548			cfg.Timeout, err = time.ParseDuration(value)
549			if err != nil {
550				return
551			}
552
553		// TLS-Encryption
554		case "tls":
555			boolValue, isBool := readBool(value)
556			if isBool {
557				if boolValue {
558					cfg.TLSConfig = "true"
559					cfg.tls = &tls.Config{}
560				} else {
561					cfg.TLSConfig = "false"
562				}
563			} else if vl := strings.ToLower(value); vl == "skip-verify" {
564				cfg.TLSConfig = vl
565				cfg.tls = &tls.Config{InsecureSkipVerify: true}
566			} else {
567				name, err := url.QueryUnescape(value)
568				if err != nil {
569					return fmt.Errorf("invalid value for TLS config name: %v", err)
570				}
571
572				if tlsConfig := getTLSConfigClone(name); tlsConfig != nil {
573					cfg.TLSConfig = name
574					cfg.tls = tlsConfig
575				} else {
576					return errors.New("invalid value / unknown config name: " + name)
577				}
578			}
579
580		// I/O write Timeout
581		case "writeTimeout":
582			cfg.WriteTimeout, err = time.ParseDuration(value)
583			if err != nil {
584				return
585			}
586		case "maxAllowedPacket":
587			cfg.MaxAllowedPacket, err = strconv.Atoi(value)
588			if err != nil {
589				return
590			}
591		default:
592			// lazy init
593			if cfg.Params == nil {
594				cfg.Params = make(map[string]string)
595			}
596
597			if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil {
598				return
599			}
600		}
601	}
602
603	return
604}
605
// ensureHavePort returns addr unchanged if net.SplitHostPort accepts it.
// Otherwise it appends port 3306.
//
// A bracketed IPv6 literal without a port (e.g. "[::1]") is unwrapped before
// joining. Passing it to net.JoinHostPort as-is would add a second pair of
// brackets and produce the invalid address "[[::1]]:3306".
func ensureHavePort(addr string) string {
	if _, _, err := net.SplitHostPort(addr); err == nil {
		return addr
	}

	host := addr
	if len(host) > 2 && strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}
	return net.JoinHostPort(host, "3306")
}
612