// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

import (
	"bytes"
	"crypto/rsa"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"net/url"
	"sort"
	"strconv"
	"strings"
	"time"
)

// Sentinel errors reported for malformed or unsupported DSNs.
var (
	// errInvalidDSNUnescaped is returned by ParseDSN when a "(address)" part
	// is not closed by ')' right before the database-name slash, yet a ')'
	// does occur inside it (as the message suggests, often an unescaped
	// parameter value).
	errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?")
	// errInvalidDSNAddr is returned by ParseDSN when a '(' opens an address
	// that is not terminated by ')' directly before the database-name slash.
	errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")
	// errInvalidDSNNoSlash is returned by ParseDSN for a non-empty DSN that
	// contains no '/' at all.
	errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name")
	// errInvalidDSNUnsafeCollation is returned by Config.normalize when
	// InterpolateParams is combined with a collation listed in
	// unsafeCollations.
	errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
)

// Config is a configuration parsed from a DSN string.
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
// FormatDSN converts a Config back into a DSN string.
type Config struct {
	User             string            // Username
	Passwd           string            // Password (requires User)
	Net              string            // Network type, e.g. "tcp" or "unix"
	Addr             string            // Network address (requires Net)
	DBName           string            // Database name
	Params           map[string]string // DSN parameters without a dedicated field (values unescaped)
	Collation        string            // Connection collation
	Loc              *time.Location    // Location for time.Time values
	MaxAllowedPacket int               // Max packet size allowed
	ServerPubKey     string            // Server public key name
	pubKey           *rsa.PublicKey    // Server public key resolved from ServerPubKey
	TLSConfig        string            // TLS configuration name: "true", "false", "skip-verify" or a custom config name
	tls              *tls.Config       // TLS configuration resolved from TLSConfig (nil if TLS is not used)
	Timeout          time.Duration     // Dial timeout
	ReadTimeout      time.Duration     // I/O read timeout
	WriteTimeout     time.Duration     // I/O write timeout

	AllowAllFiles           bool // Allow all files to be used with LOAD DATA LOCAL INFILE
	AllowCleartextPasswords bool // Allows the cleartext client side plugin
	AllowNativePasswords    bool // Allows the native password authentication method
	AllowOldPasswords       bool // Allows the old insecure password method
	ClientFoundRows         bool // Return number of matching rows instead of rows changed
	ColumnsWithAlias        bool // Prepend table alias to column names
	InterpolateParams       bool // Interpolate placeholders into query string
	MultiStatements         bool // Allow multiple statements in one query
	ParseTime               bool // Parse time values to time.Time
	RejectReadOnly          bool // Reject read-only connections
}

// NewConfig creates a new Config and sets default values: Collation is
// defaultCollation, Loc is time.UTC, MaxAllowedPacket is
// defaultMaxAllowedPacket and AllowNativePasswords is true. All other
// fields keep their zero values.
66func NewConfig() *Config { 67 return &Config{ 68 Collation: defaultCollation, 69 Loc: time.UTC, 70 MaxAllowedPacket: defaultMaxAllowedPacket, 71 AllowNativePasswords: true, 72 } 73} 74 75func (cfg *Config) normalize() error { 76 if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { 77 return errInvalidDSNUnsafeCollation 78 } 79 80 // Set default network if empty 81 if cfg.Net == "" { 82 cfg.Net = "tcp" 83 } 84 85 // Set default address if empty 86 if cfg.Addr == "" { 87 switch cfg.Net { 88 case "tcp": 89 cfg.Addr = "127.0.0.1:3306" 90 case "unix": 91 cfg.Addr = "/tmp/mysql.sock" 92 default: 93 return errors.New("default addr for network '" + cfg.Net + "' unknown") 94 } 95 96 } else if cfg.Net == "tcp" { 97 cfg.Addr = ensureHavePort(cfg.Addr) 98 } 99 100 if cfg.tls != nil { 101 if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { 102 host, _, err := net.SplitHostPort(cfg.Addr) 103 if err == nil { 104 cfg.tls.ServerName = host 105 } 106 } 107 } 108 109 return nil 110} 111 112// FormatDSN formats the given Config into a DSN string which can be passed to 113// the driver. 
114func (cfg *Config) FormatDSN() string { 115 var buf bytes.Buffer 116 117 // [username[:password]@] 118 if len(cfg.User) > 0 { 119 buf.WriteString(cfg.User) 120 if len(cfg.Passwd) > 0 { 121 buf.WriteByte(':') 122 buf.WriteString(cfg.Passwd) 123 } 124 buf.WriteByte('@') 125 } 126 127 // [protocol[(address)]] 128 if len(cfg.Net) > 0 { 129 buf.WriteString(cfg.Net) 130 if len(cfg.Addr) > 0 { 131 buf.WriteByte('(') 132 buf.WriteString(cfg.Addr) 133 buf.WriteByte(')') 134 } 135 } 136 137 // /dbname 138 buf.WriteByte('/') 139 buf.WriteString(cfg.DBName) 140 141 // [?param1=value1&...¶mN=valueN] 142 hasParam := false 143 144 if cfg.AllowAllFiles { 145 hasParam = true 146 buf.WriteString("?allowAllFiles=true") 147 } 148 149 if cfg.AllowCleartextPasswords { 150 if hasParam { 151 buf.WriteString("&allowCleartextPasswords=true") 152 } else { 153 hasParam = true 154 buf.WriteString("?allowCleartextPasswords=true") 155 } 156 } 157 158 if !cfg.AllowNativePasswords { 159 if hasParam { 160 buf.WriteString("&allowNativePasswords=false") 161 } else { 162 hasParam = true 163 buf.WriteString("?allowNativePasswords=false") 164 } 165 } 166 167 if cfg.AllowOldPasswords { 168 if hasParam { 169 buf.WriteString("&allowOldPasswords=true") 170 } else { 171 hasParam = true 172 buf.WriteString("?allowOldPasswords=true") 173 } 174 } 175 176 if cfg.ClientFoundRows { 177 if hasParam { 178 buf.WriteString("&clientFoundRows=true") 179 } else { 180 hasParam = true 181 buf.WriteString("?clientFoundRows=true") 182 } 183 } 184 185 if col := cfg.Collation; col != defaultCollation && len(col) > 0 { 186 if hasParam { 187 buf.WriteString("&collation=") 188 } else { 189 hasParam = true 190 buf.WriteString("?collation=") 191 } 192 buf.WriteString(col) 193 } 194 195 if cfg.ColumnsWithAlias { 196 if hasParam { 197 buf.WriteString("&columnsWithAlias=true") 198 } else { 199 hasParam = true 200 buf.WriteString("?columnsWithAlias=true") 201 } 202 } 203 204 if cfg.InterpolateParams { 205 if hasParam { 206 
buf.WriteString("&interpolateParams=true") 207 } else { 208 hasParam = true 209 buf.WriteString("?interpolateParams=true") 210 } 211 } 212 213 if cfg.Loc != time.UTC && cfg.Loc != nil { 214 if hasParam { 215 buf.WriteString("&loc=") 216 } else { 217 hasParam = true 218 buf.WriteString("?loc=") 219 } 220 buf.WriteString(url.QueryEscape(cfg.Loc.String())) 221 } 222 223 if cfg.MultiStatements { 224 if hasParam { 225 buf.WriteString("&multiStatements=true") 226 } else { 227 hasParam = true 228 buf.WriteString("?multiStatements=true") 229 } 230 } 231 232 if cfg.ParseTime { 233 if hasParam { 234 buf.WriteString("&parseTime=true") 235 } else { 236 hasParam = true 237 buf.WriteString("?parseTime=true") 238 } 239 } 240 241 if cfg.ReadTimeout > 0 { 242 if hasParam { 243 buf.WriteString("&readTimeout=") 244 } else { 245 hasParam = true 246 buf.WriteString("?readTimeout=") 247 } 248 buf.WriteString(cfg.ReadTimeout.String()) 249 } 250 251 if cfg.RejectReadOnly { 252 if hasParam { 253 buf.WriteString("&rejectReadOnly=true") 254 } else { 255 hasParam = true 256 buf.WriteString("?rejectReadOnly=true") 257 } 258 } 259 260 if len(cfg.ServerPubKey) > 0 { 261 if hasParam { 262 buf.WriteString("&serverPubKey=") 263 } else { 264 hasParam = true 265 buf.WriteString("?serverPubKey=") 266 } 267 buf.WriteString(url.QueryEscape(cfg.ServerPubKey)) 268 } 269 270 if cfg.Timeout > 0 { 271 if hasParam { 272 buf.WriteString("&timeout=") 273 } else { 274 hasParam = true 275 buf.WriteString("?timeout=") 276 } 277 buf.WriteString(cfg.Timeout.String()) 278 } 279 280 if len(cfg.TLSConfig) > 0 { 281 if hasParam { 282 buf.WriteString("&tls=") 283 } else { 284 hasParam = true 285 buf.WriteString("?tls=") 286 } 287 buf.WriteString(url.QueryEscape(cfg.TLSConfig)) 288 } 289 290 if cfg.WriteTimeout > 0 { 291 if hasParam { 292 buf.WriteString("&writeTimeout=") 293 } else { 294 hasParam = true 295 buf.WriteString("?writeTimeout=") 296 } 297 buf.WriteString(cfg.WriteTimeout.String()) 298 } 299 300 if 
cfg.MaxAllowedPacket != defaultMaxAllowedPacket { 301 if hasParam { 302 buf.WriteString("&maxAllowedPacket=") 303 } else { 304 hasParam = true 305 buf.WriteString("?maxAllowedPacket=") 306 } 307 buf.WriteString(strconv.Itoa(cfg.MaxAllowedPacket)) 308 309 } 310 311 // other params 312 if cfg.Params != nil { 313 var params []string 314 for param := range cfg.Params { 315 params = append(params, param) 316 } 317 sort.Strings(params) 318 for _, param := range params { 319 if hasParam { 320 buf.WriteByte('&') 321 } else { 322 hasParam = true 323 buf.WriteByte('?') 324 } 325 326 buf.WriteString(param) 327 buf.WriteByte('=') 328 buf.WriteString(url.QueryEscape(cfg.Params[param])) 329 } 330 } 331 332 return buf.String() 333} 334 335// ParseDSN parses the DSN string to a Config 336func ParseDSN(dsn string) (cfg *Config, err error) { 337 // New config with some default values 338 cfg = NewConfig() 339 340 // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] 341 // Find the last '/' (since the password or the net addr might contain a '/') 342 foundSlash := false 343 for i := len(dsn) - 1; i >= 0; i-- { 344 if dsn[i] == '/' { 345 foundSlash = true 346 var j, k int 347 348 // left part is empty if i <= 0 349 if i > 0 { 350 // [username[:password]@][protocol[(address)]] 351 // Find the last '@' in dsn[:i] 352 for j = i; j >= 0; j-- { 353 if dsn[j] == '@' { 354 // username[:password] 355 // Find the first ':' in dsn[:j] 356 for k = 0; k < j; k++ { 357 if dsn[k] == ':' { 358 cfg.Passwd = dsn[k+1 : j] 359 break 360 } 361 } 362 cfg.User = dsn[:k] 363 364 break 365 } 366 } 367 368 // [protocol[(address)]] 369 // Find the first '(' in dsn[j+1:i] 370 for k = j + 1; k < i; k++ { 371 if dsn[k] == '(' { 372 // dsn[i-1] must be == ')' if an address is specified 373 if dsn[i-1] != ')' { 374 if strings.ContainsRune(dsn[k+1:i], ')') { 375 return nil, errInvalidDSNUnescaped 376 } 377 return nil, errInvalidDSNAddr 378 } 379 cfg.Addr = dsn[k+1 : i-1] 380 break 381 } 382 } 383 
cfg.Net = dsn[j+1 : k] 384 } 385 386 // dbname[?param1=value1&...¶mN=valueN] 387 // Find the first '?' in dsn[i+1:] 388 for j = i + 1; j < len(dsn); j++ { 389 if dsn[j] == '?' { 390 if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { 391 return 392 } 393 break 394 } 395 } 396 cfg.DBName = dsn[i+1 : j] 397 398 break 399 } 400 } 401 402 if !foundSlash && len(dsn) > 0 { 403 return nil, errInvalidDSNNoSlash 404 } 405 406 if err = cfg.normalize(); err != nil { 407 return nil, err 408 } 409 return 410} 411 412// parseDSNParams parses the DSN "query string" 413// Values must be url.QueryEscape'ed 414func parseDSNParams(cfg *Config, params string) (err error) { 415 for _, v := range strings.Split(params, "&") { 416 param := strings.SplitN(v, "=", 2) 417 if len(param) != 2 { 418 continue 419 } 420 421 // cfg params 422 switch value := param[1]; param[0] { 423 // Disable INFILE whitelist / enable all files 424 case "allowAllFiles": 425 var isBool bool 426 cfg.AllowAllFiles, isBool = readBool(value) 427 if !isBool { 428 return errors.New("invalid bool value: " + value) 429 } 430 431 // Use cleartext authentication mode (MySQL 5.5.10+) 432 case "allowCleartextPasswords": 433 var isBool bool 434 cfg.AllowCleartextPasswords, isBool = readBool(value) 435 if !isBool { 436 return errors.New("invalid bool value: " + value) 437 } 438 439 // Use native password authentication 440 case "allowNativePasswords": 441 var isBool bool 442 cfg.AllowNativePasswords, isBool = readBool(value) 443 if !isBool { 444 return errors.New("invalid bool value: " + value) 445 } 446 447 // Use old authentication mode (pre MySQL 4.1) 448 case "allowOldPasswords": 449 var isBool bool 450 cfg.AllowOldPasswords, isBool = readBool(value) 451 if !isBool { 452 return errors.New("invalid bool value: " + value) 453 } 454 455 // Switch "rowsAffected" mode 456 case "clientFoundRows": 457 var isBool bool 458 cfg.ClientFoundRows, isBool = readBool(value) 459 if !isBool { 460 return errors.New("invalid bool value: " + 
value) 461 } 462 463 // Collation 464 case "collation": 465 cfg.Collation = value 466 break 467 468 case "columnsWithAlias": 469 var isBool bool 470 cfg.ColumnsWithAlias, isBool = readBool(value) 471 if !isBool { 472 return errors.New("invalid bool value: " + value) 473 } 474 475 // Compression 476 case "compress": 477 return errors.New("compression not implemented yet") 478 479 // Enable client side placeholder substitution 480 case "interpolateParams": 481 var isBool bool 482 cfg.InterpolateParams, isBool = readBool(value) 483 if !isBool { 484 return errors.New("invalid bool value: " + value) 485 } 486 487 // Time Location 488 case "loc": 489 if value, err = url.QueryUnescape(value); err != nil { 490 return 491 } 492 cfg.Loc, err = time.LoadLocation(value) 493 if err != nil { 494 return 495 } 496 497 // multiple statements in one query 498 case "multiStatements": 499 var isBool bool 500 cfg.MultiStatements, isBool = readBool(value) 501 if !isBool { 502 return errors.New("invalid bool value: " + value) 503 } 504 505 // time.Time parsing 506 case "parseTime": 507 var isBool bool 508 cfg.ParseTime, isBool = readBool(value) 509 if !isBool { 510 return errors.New("invalid bool value: " + value) 511 } 512 513 // I/O read Timeout 514 case "readTimeout": 515 cfg.ReadTimeout, err = time.ParseDuration(value) 516 if err != nil { 517 return 518 } 519 520 // Reject read-only connections 521 case "rejectReadOnly": 522 var isBool bool 523 cfg.RejectReadOnly, isBool = readBool(value) 524 if !isBool { 525 return errors.New("invalid bool value: " + value) 526 } 527 528 // Server public key 529 case "serverPubKey": 530 name, err := url.QueryUnescape(value) 531 if err != nil { 532 return fmt.Errorf("invalid value for server pub key name: %v", err) 533 } 534 535 if pubKey := getServerPubKey(name); pubKey != nil { 536 cfg.ServerPubKey = name 537 cfg.pubKey = pubKey 538 } else { 539 return errors.New("invalid value / unknown server pub key name: " + name) 540 } 541 542 // Strict mode 
543 case "strict": 544 panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode") 545 546 // Dial Timeout 547 case "timeout": 548 cfg.Timeout, err = time.ParseDuration(value) 549 if err != nil { 550 return 551 } 552 553 // TLS-Encryption 554 case "tls": 555 boolValue, isBool := readBool(value) 556 if isBool { 557 if boolValue { 558 cfg.TLSConfig = "true" 559 cfg.tls = &tls.Config{} 560 } else { 561 cfg.TLSConfig = "false" 562 } 563 } else if vl := strings.ToLower(value); vl == "skip-verify" { 564 cfg.TLSConfig = vl 565 cfg.tls = &tls.Config{InsecureSkipVerify: true} 566 } else { 567 name, err := url.QueryUnescape(value) 568 if err != nil { 569 return fmt.Errorf("invalid value for TLS config name: %v", err) 570 } 571 572 if tlsConfig := getTLSConfigClone(name); tlsConfig != nil { 573 cfg.TLSConfig = name 574 cfg.tls = tlsConfig 575 } else { 576 return errors.New("invalid value / unknown config name: " + name) 577 } 578 } 579 580 // I/O write Timeout 581 case "writeTimeout": 582 cfg.WriteTimeout, err = time.ParseDuration(value) 583 if err != nil { 584 return 585 } 586 case "maxAllowedPacket": 587 cfg.MaxAllowedPacket, err = strconv.Atoi(value) 588 if err != nil { 589 return 590 } 591 default: 592 // lazy init 593 if cfg.Params == nil { 594 cfg.Params = make(map[string]string) 595 } 596 597 if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil { 598 return 599 } 600 } 601 } 602 603 return 604} 605 606func ensureHavePort(addr string) string { 607 if _, _, err := net.SplitHostPort(addr); err != nil { 608 return net.JoinHostPort(addr, "3306") 609 } 610 return addr 611} 612