1package mssql 2 3import ( 4 "context" 5 "crypto/tls" 6 "crypto/x509" 7 "encoding/binary" 8 "errors" 9 "fmt" 10 "io" 11 "io/ioutil" 12 "net" 13 "net/url" 14 "os" 15 "sort" 16 "strconv" 17 "strings" 18 "time" 19 "unicode" 20 "unicode/utf16" 21 "unicode/utf8" 22) 23 24func parseInstances(msg []byte) map[string]map[string]string { 25 results := map[string]map[string]string{} 26 if len(msg) > 3 && msg[0] == 5 { 27 out_s := string(msg[3:]) 28 tokens := strings.Split(out_s, ";") 29 instdict := map[string]string{} 30 got_name := false 31 var name string 32 for _, token := range tokens { 33 if got_name { 34 instdict[name] = token 35 got_name = false 36 } else { 37 name = token 38 if len(name) == 0 { 39 if len(instdict) == 0 { 40 break 41 } 42 results[strings.ToUpper(instdict["InstanceName"])] = instdict 43 instdict = map[string]string{} 44 continue 45 } 46 got_name = true 47 } 48 } 49 } 50 return results 51} 52 53func getInstances(ctx context.Context, d Dialer, address string) (map[string]map[string]string, error) { 54 maxTime := 5 * time.Second 55 ctx, cancel := context.WithTimeout(ctx, maxTime) 56 defer cancel() 57 conn, err := d.DialContext(ctx, "udp", address+":1434") 58 if err != nil { 59 return nil, err 60 } 61 defer conn.Close() 62 conn.SetDeadline(time.Now().Add(maxTime)) 63 _, err = conn.Write([]byte{3}) 64 if err != nil { 65 return nil, err 66 } 67 var resp = make([]byte, 16*1024-1) 68 read, err := conn.Read(resp) 69 if err != nil { 70 return nil, err 71 } 72 return parseInstances(resp[:read]), nil 73} 74 75// tds versions 76const ( 77 verTDS70 = 0x70000000 78 verTDS71 = 0x71000000 79 verTDS71rev1 = 0x71000001 80 verTDS72 = 0x72090002 81 verTDS73A = 0x730A0003 82 verTDS73 = verTDS73A 83 verTDS73B = 0x730B0003 84 verTDS74 = 0x74000004 85) 86 87// packet types 88// https://msdn.microsoft.com/en-us/library/dd304214.aspx 89const ( 90 packSQLBatch packetType = 1 91 packRPCRequest = 3 92 packReply = 4 93 94 // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx 95 // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx 96 packAttention = 6 97 98 packBulkLoadBCP = 7 99 packTransMgrReq = 14 100 packNormal = 15 101 packLogin7 = 16 102 packSSPIMessage = 17 103 packPrelogin = 18 104) 105 106// prelogin fields 107// http://msdn.microsoft.com/en-us/library/dd357559.aspx 108const ( 109 preloginVERSION = 0 110 preloginENCRYPTION = 1 111 preloginINSTOPT = 2 112 preloginTHREADID = 3 113 preloginMARS = 4 114 preloginTRACEID = 5 115 preloginTERMINATOR = 0xff 116) 117 118const ( 119 encryptOff = 0 // Encryption is available but off. 120 encryptOn = 1 // Encryption is available and on. 121 encryptNotSup = 2 // Encryption is not available. 122 encryptReq = 3 // Encryption is required. 123) 124 125type tdsSession struct { 126 buf *tdsBuffer 127 loginAck loginAckStruct 128 database string 129 partner string 130 columns []columnStruct 131 tranid uint64 132 logFlags uint64 133 log optionalLogger 134 routedServer string 135 routedPort uint16 136} 137 138const ( 139 logErrors = 1 140 logMessages = 2 141 logRows = 4 142 logSQL = 8 143 logParams = 16 144 logTransaction = 32 145 logDebug = 64 146) 147 148type columnStruct struct { 149 UserType uint32 150 Flags uint16 151 ColName string 152 ti typeInfo 153} 154 155type keySlice []uint8 156 157func (p keySlice) Len() int { return len(p) } 158func (p keySlice) Less(i, j int) bool { return p[i] < p[j] } 159func (p keySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } 160 161// http://msdn.microsoft.com/en-us/library/dd357559.aspx 162func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error { 163 var err error 164 165 w.BeginPacket(packPrelogin, false) 166 offset := uint16(5*len(fields) + 1) 167 keys := make(keySlice, 0, len(fields)) 168 for k, _ := range fields { 169 keys = append(keys, k) 170 } 171 sort.Sort(keys) 172 // writing header 173 for _, k := range keys { 174 err = w.WriteByte(k) 175 if err != nil { 176 return err 177 } 178 err = binary.Write(w, binary.BigEndian, offset) 179 if err != nil { 180 return err 181 } 182 v := fields[k] 183 size := uint16(len(v)) 184 err = binary.Write(w, binary.BigEndian, size) 185 if err != nil { 186 return err 187 } 188 offset += size 189 } 190 err = w.WriteByte(preloginTERMINATOR) 191 if err != nil { 192 return err 193 } 194 // writing values 195 for _, k := range keys { 196 v := fields[k] 197 written, err := w.Write(v) 198 if err != nil { 199 return err 200 } 201 if written != len(v) { 202 return errors.New("Write method didn't write the whole value") 203 } 204 } 205 return w.FinishPacket() 206} 207 208func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) { 209 packet_type, err := r.BeginRead() 210 if err != nil { 211 return nil, err 212 } 213 struct_buf, err := ioutil.ReadAll(r) 214 if err != nil { 215 return nil, err 216 } 217 if packet_type != 4 { 218 return nil, errors.New("Invalid respones, expected packet type 4, PRELOGIN RESPONSE") 219 } 220 offset := 0 221 results := map[uint8][]byte{} 222 for true { 223 rec_type := struct_buf[offset] 224 if rec_type == preloginTERMINATOR { 225 break 226 } 227 228 rec_offset := binary.BigEndian.Uint16(struct_buf[offset+1:]) 229 rec_len := binary.BigEndian.Uint16(struct_buf[offset+3:]) 230 value := struct_buf[rec_offset : rec_offset+rec_len] 231 results[rec_type] = value 232 offset += 5 233 } 234 return results, nil 235} 236 237// OptionFlags2 238// http://msdn.microsoft.com/en-us/library/dd304019.aspx 239const ( 240 fLanguageFatal = 1 241 fODBC = 2 242 fTransBoundary = 4 243 fCacheConnect = 8 244 fIntSecurity = 0x80 245) 246 247// TypeFlags 248const ( 249 // 4 bits for fSQLType 250 // 1 bit for fOLEDB 251 fReadOnlyIntent = 32 252) 253 254type login struct { 255 TDSVersion uint32 256 PacketSize uint32 257 ClientProgVer uint32 258 ClientPID uint32 259 ConnectionID uint32 260 OptionFlags1 uint8 261 OptionFlags2 uint8 262 TypeFlags uint8 263 OptionFlags3 uint8 264 ClientTimeZone int32 265 ClientLCID uint32 266 HostName string 267 UserName string 268 Password string 269 AppName string 270 ServerName string 271 CtlIntName string 272 Language string 273 Database string 274 ClientID [6]byte 275 SSPI []byte 276 AtchDBFile string 277 ChangePassword string 278} 279 280type loginHeader struct { 281 Length uint32 282 TDSVersion uint32 283 PacketSize uint32 284 ClientProgVer uint32 285 ClientPID uint32 286 ConnectionID uint32 287 OptionFlags1 uint8 288 OptionFlags2 uint8 289 TypeFlags uint8 290 OptionFlags3 uint8 291 ClientTimeZone int32 292 ClientLCID uint32 293 HostNameOffset uint16 294 HostNameLength uint16 295 UserNameOffset uint16 296 UserNameLength uint16 297 PasswordOffset uint16 298 PasswordLength uint16 299 AppNameOffset uint16 300 AppNameLength uint16 301 ServerNameOffset uint16 302 ServerNameLength uint16 303 ExtensionOffset uint16 304 ExtensionLenght uint16 305 CtlIntNameOffset uint16 306 CtlIntNameLength uint16 307 LanguageOffset uint16 308 LanguageLength uint16 309 DatabaseOffset uint16 310 DatabaseLength uint16 311 ClientID [6]byte 312 SSPIOffset uint16 313 SSPILength uint16 314 AtchDBFileOffset uint16 315 AtchDBFileLength uint16 316 ChangePasswordOffset uint16 317 ChangePasswordLength uint16 318 SSPILongLength uint32 319} 320 321// convert Go string to UTF-16 encoded []byte (littleEndian) 322// done manually rather than using bytes and binary packages 323// for performance reasons 324func str2ucs2(s string) []byte { 325 res := utf16.Encode([]rune(s)) 326 ucs2 := make([]byte, 2*len(res)) 327 for i := 0; i < len(res); i++ { 328 ucs2[2*i] = byte(res[i]) 329 ucs2[2*i+1] = byte(res[i] >> 8) 330 } 331 return ucs2 332} 333 334func ucs22str(s []byte) (string, error) { 335 if len(s)%2 != 0 { 336 return "", fmt.Errorf("Illegal UCS2 string length: %d", len(s)) 337 } 338 buf := make([]uint16, len(s)/2) 339 for i := 0; i < len(s); i += 2 { 340 buf[i/2] = binary.LittleEndian.Uint16(s[i:]) 341 } 342 return string(utf16.Decode(buf)), nil 343} 344 345func manglePassword(password string) []byte { 346 var ucs2password []byte = str2ucs2(password) 347 for i, ch := range ucs2password { 348 ucs2password[i] = ((ch<<4)&0xff | (ch >> 4)) ^ 0xA5 349 } 350 return ucs2password 351} 352 353// http://msdn.microsoft.com/en-us/library/dd304019.aspx 354func sendLogin(w *tdsBuffer, login login) error { 355 w.BeginPacket(packLogin7, false) 356 hostname := str2ucs2(login.HostName) 357 username := str2ucs2(login.UserName) 358 password := manglePassword(login.Password) 359 appname := str2ucs2(login.AppName) 360 servername := str2ucs2(login.ServerName) 361 ctlintname := str2ucs2(login.CtlIntName) 362 language := str2ucs2(login.Language) 363 database := str2ucs2(login.Database) 364 atchdbfile := str2ucs2(login.AtchDBFile) 365 changepassword := str2ucs2(login.ChangePassword) 366 hdr := loginHeader{ 367 TDSVersion: login.TDSVersion, 368 PacketSize: login.PacketSize, 369 ClientProgVer: login.ClientProgVer, 370 ClientPID: login.ClientPID, 371 ConnectionID: login.ConnectionID, 372 OptionFlags1: login.OptionFlags1, 373 OptionFlags2: login.OptionFlags2, 374 TypeFlags: login.TypeFlags, 375 OptionFlags3: login.OptionFlags3, 376 ClientTimeZone: login.ClientTimeZone, 377 ClientLCID: login.ClientLCID, 378 HostNameLength: uint16(utf8.RuneCountInString(login.HostName)), 379 UserNameLength: uint16(utf8.RuneCountInString(login.UserName)), 380 PasswordLength: uint16(utf8.RuneCountInString(login.Password)), 381 AppNameLength: uint16(utf8.RuneCountInString(login.AppName)), 382 ServerNameLength: uint16(utf8.RuneCountInString(login.ServerName)), 383 CtlIntNameLength: uint16(utf8.RuneCountInString(login.CtlIntName)), 384 LanguageLength: uint16(utf8.RuneCountInString(login.Language)), 385 DatabaseLength: uint16(utf8.RuneCountInString(login.Database)), 386 ClientID: login.ClientID, 387 SSPILength: uint16(len(login.SSPI)), 388 AtchDBFileLength: uint16(utf8.RuneCountInString(login.AtchDBFile)), 389 ChangePasswordLength: uint16(utf8.RuneCountInString(login.ChangePassword)), 390 } 391 offset := uint16(binary.Size(hdr)) 392 hdr.HostNameOffset = offset 393 offset += uint16(len(hostname)) 394 hdr.UserNameOffset = offset 395 offset += uint16(len(username)) 396 hdr.PasswordOffset = offset 397 offset += uint16(len(password)) 398 hdr.AppNameOffset = offset 399 offset += uint16(len(appname)) 400 hdr.ServerNameOffset = offset 401 offset += uint16(len(servername)) 402 hdr.CtlIntNameOffset = offset 403 offset += uint16(len(ctlintname)) 404 hdr.LanguageOffset = offset 405 offset += uint16(len(language)) 406 hdr.DatabaseOffset = offset 407 offset += uint16(len(database)) 408 hdr.SSPIOffset = offset 409 offset += uint16(len(login.SSPI)) 410 hdr.AtchDBFileOffset = offset 411 offset += uint16(len(atchdbfile)) 412 hdr.ChangePasswordOffset = offset 413 offset += uint16(len(changepassword)) 414 hdr.Length = uint32(offset) 415 var err error 416 err = binary.Write(w, binary.LittleEndian, &hdr) 417 if err != nil { 418 return err 419 } 420 _, err = w.Write(hostname) 421 if err != nil { 422 return err 423 } 424 _, err = w.Write(username) 425 if err != nil { 426 return err 427 } 428 _, err = w.Write(password) 429 if err != nil { 430 return err 431 } 432 _, err = w.Write(appname) 433 if err != nil { 434 return err 435 } 436 _, err = w.Write(servername) 437 if err != nil { 438 return err 439 } 440 _, err = w.Write(ctlintname) 441 if err != nil { 442 return err 443 } 444 _, err = w.Write(language) 445 if err != nil { 446 return err 447 } 448 _, err = w.Write(database) 449 if err != nil { 450 return err 451 } 452 _, err = w.Write(login.SSPI) 453 if err != nil { 454 return err 455 } 456 _, err = w.Write(atchdbfile) 457 if err != nil { 458 return err 459 } 460 _, err = w.Write(changepassword) 461 if err != nil { 462 return err 463 } 464 return w.FinishPacket() 465} 466 467func readUcs2(r io.Reader, numchars int) (res string, err error) { 468 buf := make([]byte, numchars*2) 469 _, err = io.ReadFull(r, buf) 470 if err != nil { 471 return "", err 472 } 473 return ucs22str(buf) 474} 475 476func readUsVarChar(r io.Reader) (res string, err error) { 477 var numchars uint16 478 err = binary.Read(r, binary.LittleEndian, &numchars) 479 if err != nil { 480 return "", err 481 } 482 return readUcs2(r, int(numchars)) 483} 484 485func writeUsVarChar(w io.Writer, s string) (err error) { 486 buf := str2ucs2(s) 487 var numchars int = len(buf) / 2 488 if numchars > 0xffff { 489 panic("invalid size for US_VARCHAR") 490 } 491 err = binary.Write(w, binary.LittleEndian, uint16(numchars)) 492 if err != nil { 493 return 494 } 495 _, err = w.Write(buf) 496 return 497} 498 499func readBVarChar(r io.Reader) (res string, err error) { 500 var numchars uint8 501 err = binary.Read(r, binary.LittleEndian, &numchars) 502 if err != nil { 503 return "", err 504 } 505 506 // A zero length could be returned, return an empty string 507 if numchars == 0 { 508 return "", nil 509 } 510 return readUcs2(r, int(numchars)) 511} 512 513func writeBVarChar(w io.Writer, s string) (err error) { 514 buf := str2ucs2(s) 515 var numchars int = len(buf) / 2 516 if numchars > 0xff { 517 panic("invalid size for B_VARCHAR") 518 } 519 err = binary.Write(w, binary.LittleEndian, uint8(numchars)) 520 if err != nil { 521 return 522 } 523 _, err = w.Write(buf) 524 return 525} 526 527func readBVarByte(r io.Reader) (res []byte, err error) { 528 var length uint8 529 err = binary.Read(r, binary.LittleEndian, &length) 530 if err != nil { 531 return 532 } 533 res = make([]byte, length) 534 _, err = io.ReadFull(r, res) 535 return 536} 537 538func readUshort(r io.Reader) (res uint16, err error) { 539 err = binary.Read(r, binary.LittleEndian, &res) 540 return 541} 542 543func readByte(r io.Reader) (res byte, err error) { 544 var b [1]byte 545 _, err = r.Read(b[:]) 546 res = b[0] 547 return 548} 549 550// Packet Data Stream Headers 551// http://msdn.microsoft.com/en-us/library/dd304953.aspx 552type headerStruct struct { 553 hdrtype uint16 554 data []byte 555} 556 557const ( 558 dataStmHdrQueryNotif = 1 // query notifications 559 dataStmHdrTransDescr = 2 // MARS transaction descriptor (required) 560 dataStmHdrTraceActivity = 3 561) 562 563// Query Notifications Header 564// http://msdn.microsoft.com/en-us/library/dd304949.aspx 565type queryNotifHdr struct { 566 notifyId string 567 ssbDeployment string 568 notifyTimeout uint32 569} 570 571func (hdr queryNotifHdr) pack() (res []byte) { 572 notifyId := str2ucs2(hdr.notifyId) 573 ssbDeployment := str2ucs2(hdr.ssbDeployment) 574 575 res = make([]byte, 2+len(notifyId)+2+len(ssbDeployment)+4) 576 b := res 577 578 binary.LittleEndian.PutUint16(b, uint16(len(notifyId))) 579 b = b[2:] 580 copy(b, notifyId) 581 b = b[len(notifyId):] 582 583 binary.LittleEndian.PutUint16(b, uint16(len(ssbDeployment))) 584 b = b[2:] 585 copy(b, ssbDeployment) 586 b = b[len(ssbDeployment):] 587 588 binary.LittleEndian.PutUint32(b, hdr.notifyTimeout) 589 590 return res 591} 592 593// MARS Transaction Descriptor Header 594// http://msdn.microsoft.com/en-us/library/dd340515.aspx 595type transDescrHdr struct { 596 transDescr uint64 // transaction descriptor returned from ENVCHANGE 597 outstandingReqCnt uint32 // outstanding request count 598} 599 600func (hdr transDescrHdr) pack() (res []byte) { 601 res = make([]byte, 8+4) 602 binary.LittleEndian.PutUint64(res, hdr.transDescr) 603 binary.LittleEndian.PutUint32(res[8:], hdr.outstandingReqCnt) 604 return res 605} 606 607func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) { 608 // Calculating total length. 609 var totallen uint32 = 4 610 for _, hdr := range headers { 611 totallen += 4 + 2 + uint32(len(hdr.data)) 612 } 613 // writing 614 err = binary.Write(w, binary.LittleEndian, totallen) 615 if err != nil { 616 return err 617 } 618 for _, hdr := range headers { 619 var headerlen uint32 = 4 + 2 + uint32(len(hdr.data)) 620 err = binary.Write(w, binary.LittleEndian, headerlen) 621 if err != nil { 622 return err 623 } 624 err = binary.Write(w, binary.LittleEndian, hdr.hdrtype) 625 if err != nil { 626 return err 627 } 628 _, err = w.Write(hdr.data) 629 if err != nil { 630 return err 631 } 632 } 633 return nil 634} 635 636func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct, resetSession bool) (err error) { 637 buf.BeginPacket(packSQLBatch, resetSession) 638 639 if err = writeAllHeaders(buf, headers); err != nil { 640 return 641 } 642 643 _, err = buf.Write(str2ucs2(sqltext)) 644 if err != nil { 645 return 646 } 647 return buf.FinishPacket() 648} 649 650// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx 651// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx 652func sendAttention(buf *tdsBuffer) error { 653 buf.BeginPacket(packAttention, false) 654 return buf.FinishPacket() 655} 656 657type connectParams struct { 658 logFlags uint64 659 port uint64 660 host string 661 instance string 662 database string 663 user string 664 password string 665 dial_timeout time.Duration 666 conn_timeout time.Duration 667 keepAlive time.Duration 668 encrypt bool 669 disableEncryption bool 670 trustServerCertificate bool 671 certificate string 672 hostInCertificate string 673 serverSPN string 674 workstation string 675 appname string 676 typeFlags uint8 677 failOverPartner string 678 failOverPort uint64 679 packetSize uint16 680} 681 682func splitConnectionString(dsn string) (res map[string]string) { 683 res = map[string]string{} 684 parts := strings.Split(dsn, ";") 685 for _, part := range parts { 686 if len(part) == 0 { 687 continue 688 } 689 lst := strings.SplitN(part, "=", 2) 690 name := strings.TrimSpace(strings.ToLower(lst[0])) 691 if len(name) == 0 { 692 continue 693 } 694 var value string = "" 695 if len(lst) > 1 { 696 value = strings.TrimSpace(lst[1]) 697 } 698 res[name] = value 699 } 700 return res 701} 702 703// Splits a URL in the ODBC format 704func splitConnectionStringOdbc(dsn string) (map[string]string, error) { 705 res := map[string]string{} 706 707 type parserState int 708 const ( 709 // Before the start of a key 710 parserStateBeforeKey parserState = iota 711 712 // Inside a key 713 parserStateKey 714 715 // Beginning of a value. May be bare or braced 716 parserStateBeginValue 717 718 // Inside a bare value 719 parserStateBareValue 720 721 // Inside a braced value 722 parserStateBracedValue 723 724 // A closing brace inside a braced value. 725 // May be the end of the value or an escaped closing brace, depending on the next character 726 parserStateBracedValueClosingBrace 727 728 // After a value. Next character should be a semicolon or whitespace. 729 parserStateEndValue 730 ) 731 732 var state = parserStateBeforeKey 733 734 var key string 735 var value string 736 737 for i, c := range dsn { 738 switch state { 739 case parserStateBeforeKey: 740 switch { 741 case c == '=': 742 return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i) 743 case !unicode.IsSpace(c) && c != ';': 744 state = parserStateKey 745 key += string(c) 746 } 747 748 case parserStateKey: 749 switch c { 750 case '=': 751 key = normalizeOdbcKey(key) 752 if len(key) == 0 { 753 return res, fmt.Errorf("Unexpected end of key at index %d.", i) 754 } 755 756 state = parserStateBeginValue 757 758 case ';': 759 // Key without value 760 key = normalizeOdbcKey(key) 761 if len(key) == 0 { 762 return res, fmt.Errorf("Unexpected end of key at index %d.", i) 763 } 764 765 res[key] = value 766 key = "" 767 value = "" 768 state = parserStateBeforeKey 769 770 default: 771 key += string(c) 772 } 773 774 case parserStateBeginValue: 775 switch { 776 case c == '{': 777 state = parserStateBracedValue 778 case c == ';': 779 // Empty value 780 res[key] = value 781 key = "" 782 state = parserStateBeforeKey 783 case unicode.IsSpace(c): 784 // Ignore whitespace 785 default: 786 state = parserStateBareValue 787 value += string(c) 788 } 789 790 case parserStateBareValue: 791 if c == ';' { 792 res[key] = strings.TrimRightFunc(value, unicode.IsSpace) 793 key = "" 794 value = "" 795 state = parserStateBeforeKey 796 } else { 797 value += string(c) 798 } 799 800 case parserStateBracedValue: 801 if c == '}' { 802 state = parserStateBracedValueClosingBrace 803 } else { 804 value += string(c) 805 } 806 807 case parserStateBracedValueClosingBrace: 808 if c == '}' { 809 // Escaped closing brace 810 value += string(c) 811 state = parserStateBracedValue 812 continue 813 } 814 815 // End of braced value 816 res[key] = value 817 key = "" 818 value = "" 819 820 // This character is the first character past the end, 821 // so it needs to be parsed like the parserStateEndValue state. 822 state = parserStateEndValue 823 switch { 824 case c == ';': 825 state = parserStateBeforeKey 826 case unicode.IsSpace(c): 827 // Ignore whitespace 828 default: 829 return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i) 830 } 831 832 case parserStateEndValue: 833 switch { 834 case c == ';': 835 state = parserStateBeforeKey 836 case unicode.IsSpace(c): 837 // Ignore whitespace 838 default: 839 return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i) 840 } 841 } 842 } 843 844 switch state { 845 case parserStateBeforeKey: // Okay 846 case parserStateKey: // Unfinished key. Treat as key without value. 847 key = normalizeOdbcKey(key) 848 if len(key) == 0 { 849 return res, fmt.Errorf("Unexpected end of key at index %d.", len(dsn)) 850 } 851 res[key] = value 852 case parserStateBeginValue: // Empty value 853 res[key] = value 854 case parserStateBareValue: 855 res[key] = strings.TrimRightFunc(value, unicode.IsSpace) 856 case parserStateBracedValue: 857 return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn)) 858 case parserStateBracedValueClosingBrace: // End of braced value 859 res[key] = value 860 case parserStateEndValue: // Okay 861 } 862 863 return res, nil 864} 865 866// Normalizes the given string as an ODBC-format key 867func normalizeOdbcKey(s string) string { 868 return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace)) 869} 870 871// Splits a URL of the form sqlserver://username:password@host/instance?param1=value¶m2=value 872func splitConnectionStringURL(dsn string) (map[string]string, error) { 873 res := map[string]string{} 874 875 u, err := url.Parse(dsn) 876 if err != nil { 877 return res, err 878 } 879 880 if u.Scheme != "sqlserver" { 881 return res, fmt.Errorf("scheme %s is not recognized", u.Scheme) 882 } 883 884 if u.User != nil { 885 res["user id"] = u.User.Username() 886 p, exists := u.User.Password() 887 if exists { 888 res["password"] = p 889 } 890 } 891 892 host, port, err := net.SplitHostPort(u.Host) 893 if err != nil { 894 host = u.Host 895 } 896 897 if len(u.Path) > 0 { 898 res["server"] = host + "\\" + u.Path[1:] 899 } else { 900 res["server"] = host 901 } 902 903 if len(port) > 0 { 904 res["port"] = port 905 } 906 907 query := u.Query() 908 for k, v := range query { 909 if len(v) > 1 { 910 return res, fmt.Errorf("key %s provided more than once", k) 911 } 912 res[strings.ToLower(k)] = v[0] 913 } 914 915 return res, nil 916} 917 918func parseConnectParams(dsn string) (connectParams, error) { 919 var p connectParams 920 921 var params map[string]string 922 if strings.HasPrefix(dsn, "odbc:") { 923 parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):]) 924 if err != nil { 925 return p, err 926 } 927 params = parameters 928 } else if strings.HasPrefix(dsn, "sqlserver://") { 929 parameters, err := splitConnectionStringURL(dsn) 930 if err != nil { 931 return p, err 932 } 933 params = parameters 934 } else { 935 params = splitConnectionString(dsn) 936 } 937 938 strlog, ok := params["log"] 939 if ok { 940 var err error 941 p.logFlags, err = strconv.ParseUint(strlog, 10, 64) 942 if err != nil { 943 return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error()) 944 } 945 } 946 server := params["server"] 947 parts := strings.SplitN(server, `\`, 2) 948 p.host = parts[0] 949 if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" { 950 p.host = "localhost" 951 } 952 if len(parts) > 1 { 953 p.instance = parts[1] 954 } 955 p.database = params["database"] 956 p.user = params["user id"] 957 p.password = params["password"] 958 959 p.port = 1433 960 strport, ok := params["port"] 961 if ok { 962 var err error 963 p.port, err = strconv.ParseUint(strport, 10, 16) 964 if err != nil { 965 f := "Invalid tcp port '%v': %v" 966 return p, fmt.Errorf(f, strport, err.Error()) 967 } 968 } 969 970 // https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option 971 // Default packet size remains at 4096 bytes 972 p.packetSize = 4096 973 strpsize, ok := params["packet size"] 974 if ok { 975 var err error 976 psize, err := strconv.ParseUint(strpsize, 0, 16) 977 if err != nil { 978 f := "Invalid packet size '%v': %v" 979 return p, fmt.Errorf(f, strpsize, err.Error()) 980 } 981 982 // Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes 983 // NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request 984 // a higher packet size, the server will respond with an ENVCHANGE request to 985 // alter the packet size to 16383 bytes. 986 p.packetSize = uint16(psize) 987 if p.packetSize < 512 { 988 p.packetSize = 512 989 } else if p.packetSize > 32767 { 990 p.packetSize = 32767 991 } 992 } 993 994 // https://msdn.microsoft.com/en-us/library/dd341108.aspx 995 // 996 // Do not set a connection timeout. Use Context to manage such things. 997 // Default to zero, but still allow it to be set. 998 if strconntimeout, ok := params["connection timeout"]; ok { 999 timeout, err := strconv.ParseUint(strconntimeout, 10, 64) 1000 if err != nil { 1001 f := "Invalid connection timeout '%v': %v" 1002 return p, fmt.Errorf(f, strconntimeout, err.Error()) 1003 } 1004 p.conn_timeout = time.Duration(timeout) * time.Second 1005 } 1006 p.dial_timeout = 15 * time.Second 1007 if strdialtimeout, ok := params["dial timeout"]; ok { 1008 timeout, err := strconv.ParseUint(strdialtimeout, 10, 64) 1009 if err != nil { 1010 f := "Invalid dial timeout '%v': %v" 1011 return p, fmt.Errorf(f, strdialtimeout, err.Error()) 1012 } 1013 p.dial_timeout = time.Duration(timeout) * time.Second 1014 } 1015 1016 // default keep alive should be 30 seconds according to spec: 1017 // https://msdn.microsoft.com/en-us/library/dd341108.aspx 1018 p.keepAlive = 30 * time.Second 1019 if keepAlive, ok := params["keepalive"]; ok { 1020 timeout, err := strconv.ParseUint(keepAlive, 10, 64) 1021 if err != nil { 1022 f := "Invalid keepAlive value '%s': %s" 1023 return p, fmt.Errorf(f, keepAlive, err.Error()) 1024 } 1025 p.keepAlive = time.Duration(timeout) * time.Second 1026 } 1027 encrypt, ok := params["encrypt"] 1028 if ok { 1029 if strings.EqualFold(encrypt, "DISABLE") { 1030 p.disableEncryption = true 1031 } else { 1032 var err error 1033 p.encrypt, err = strconv.ParseBool(encrypt) 1034 if err != nil { 1035 f := "Invalid encrypt '%s': %s" 1036 return p, fmt.Errorf(f, encrypt, err.Error()) 1037 } 1038 } 1039 } else { 1040 p.trustServerCertificate = true 1041 } 1042 trust, ok := params["trustservercertificate"] 1043 if ok { 1044 var err error 1045 p.trustServerCertificate, err = strconv.ParseBool(trust) 1046 if err != nil { 1047 f := "Invalid trust server certificate '%s': %s" 1048 return p, fmt.Errorf(f, trust, err.Error()) 1049 } 1050 } 1051 p.certificate = params["certificate"] 1052 p.hostInCertificate, ok = params["hostnameincertificate"] 1053 if !ok { 1054 p.hostInCertificate = p.host 1055 } 1056 1057 serverSPN, ok := params["serverspn"] 1058 if ok { 1059 p.serverSPN = serverSPN 1060 } else { 1061 p.serverSPN = fmt.Sprintf("MSSQLSvc/%s:%d", p.host, p.port) 1062 } 1063 1064 workstation, ok := params["workstation id"] 1065 if ok { 1066 p.workstation = workstation 1067 } else { 1068 workstation, err := os.Hostname() 1069 if err == nil { 1070 p.workstation = workstation 1071 } 1072 } 1073 1074 appname, ok := params["app name"] 1075 if !ok { 1076 appname = "go-mssqldb" 1077 } 1078 p.appname = appname 1079 1080 appintent, ok := params["applicationintent"] 1081 if ok { 1082 if appintent == "ReadOnly" { 1083 p.typeFlags |= fReadOnlyIntent 1084 } 1085 } 1086 1087 failOverPartner, ok := params["failoverpartner"] 1088 if ok { 1089 p.failOverPartner = failOverPartner 1090 } 1091 1092 failOverPort, ok := params["failoverport"] 1093 if ok { 1094 var err error 1095 p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16) 1096 if err != nil { 1097 f := "Invalid tcp port '%v': %v" 1098 return p, fmt.Errorf(f, failOverPort, err.Error()) 1099 } 1100 } 1101 1102 return p, nil 1103} 1104 1105type auth interface { 1106 InitialBytes() ([]byte, error) 1107 NextBytes([]byte) ([]byte, error) 1108 Free() 1109} 1110 1111// SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a 1112// list of IP addresses. So if there is more than one, try them all and 1113// use the first one that allows a connection. 1114func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn net.Conn, err error) { 1115 var ips []net.IP 1116 ips, err = net.LookupIP(p.host) 1117 if err != nil { 1118 ip := net.ParseIP(p.host) 1119 if ip == nil { 1120 return nil, err 1121 } 1122 ips = []net.IP{ip} 1123 } 1124 if len(ips) == 1 { 1125 d := c.getDialer(&p) 1126 addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port))) 1127 conn, err = d.DialContext(ctx, "tcp", addr) 1128 1129 } else { 1130 //Try Dials in parallel to avoid waiting for timeouts. 1131 connChan := make(chan net.Conn, len(ips)) 1132 errChan := make(chan error, len(ips)) 1133 portStr := strconv.Itoa(int(p.port)) 1134 for _, ip := range ips { 1135 go func(ip net.IP) { 1136 d := c.getDialer(&p) 1137 addr := net.JoinHostPort(ip.String(), portStr) 1138 conn, err := d.DialContext(ctx, "tcp", addr) 1139 if err == nil { 1140 connChan <- conn 1141 } else { 1142 errChan <- err 1143 } 1144 }(ip) 1145 } 1146 // Wait for either the *first* successful connection, or all the errors 1147 wait_loop: 1148 for i, _ := range ips { 1149 select { 1150 case conn = <-connChan: 1151 // Got a connection to use, close any others 1152 go func(n int) { 1153 for i := 0; i < n; i++ { 1154 select { 1155 case conn := <-connChan: 1156 conn.Close() 1157 case <-errChan: 1158 } 1159 } 1160 }(len(ips) - i - 1) 1161 // Remove any earlier errors we may have collected 1162 err = nil 1163 break wait_loop 1164 case err = <-errChan: 1165 } 1166 } 1167 } 1168 // Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection 1169 if conn == nil { 1170 f := "Unable to open tcp connection with host '%v:%v': %v" 1171 return nil, fmt.Errorf(f, p.host, p.port, err.Error()) 1172 } 1173 return conn, err 1174} 1175 1176func connect(ctx context.Context, c *Connector, log optionalLogger, p connectParams) (res *tdsSession, err error) { 1177 dialCtx := ctx 1178 if p.dial_timeout > 0 { 1179 var cancel func() 1180 dialCtx, cancel = context.WithTimeout(ctx, p.dial_timeout) 1181 defer cancel() 1182 } 1183 // if instance is specified use instance resolution service 1184 if p.instance != "" { 1185 p.instance = strings.ToUpper(p.instance) 1186 d := c.getDialer(&p) 1187 instances, err := getInstances(dialCtx, d, p.host) 1188 if err != nil { 1189 f := "Unable to get instances from Sql Server Browser on host %v: %v" 1190 return nil, fmt.Errorf(f, p.host, err.Error()) 1191 } 1192 strport, ok := instances[p.instance]["tcp"] 1193 if !ok { 1194 f := "No instance matching '%v' returned from host '%v'" 1195 return nil, fmt.Errorf(f, p.instance, p.host) 1196 } 1197 p.port, err = strconv.ParseUint(strport, 0, 16) 1198 if err != nil { 1199 f := "Invalid tcp port returned from Sql Server Browser '%v': %v" 1200 return nil, fmt.Errorf(f, strport, err.Error()) 1201 } 1202 } 1203 1204initiate_connection: 1205 conn, err := dialConnection(dialCtx, c, p) 1206 if err != nil { 1207 return nil, err 1208 } 1209 1210 toconn := newTimeoutConn(conn, p.conn_timeout) 1211 1212 outbuf := newTdsBuffer(p.packetSize, toconn) 1213 sess := tdsSession{ 1214 buf: outbuf, 1215 log: log, 1216 logFlags: p.logFlags, 1217 } 1218 1219 instance_buf := []byte(p.instance) 1220 instance_buf = append(instance_buf, 0) // zero terminate instance name 1221 var encrypt byte 1222 if p.disableEncryption { 1223 encrypt = encryptNotSup 1224 } else if p.encrypt { 1225 encrypt = encryptOn 1226 } else { 1227 encrypt = encryptOff 1228 } 1229 fields := map[uint8][]byte{ 1230 preloginVERSION: {0, 0, 0, 0, 0, 0}, 1231 preloginENCRYPTION: {encrypt}, 1232 preloginINSTOPT: instance_buf, 1233 preloginTHREADID: {0, 0, 0, 0}, 1234 preloginMARS: {0}, // MARS disabled 1235 } 1236 1237 err = writePrelogin(outbuf, fields) 1238 if err != nil { 1239 return nil, err 1240 } 1241 1242 fields, err = readPrelogin(outbuf) 1243 if err != nil { 1244 return nil, err 1245 } 1246 1247 encryptBytes, ok := fields[preloginENCRYPTION] 1248 if !ok { 1249 return nil, fmt.Errorf("Encrypt negotiation failed") 1250 } 1251 encrypt = encryptBytes[0] 1252 if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) { 1253 return nil, fmt.Errorf("Server does not support encryption") 1254 } 1255 1256 if encrypt != encryptNotSup { 1257 var config tls.Config 1258 if p.certificate != "" { 1259 pem, err := ioutil.ReadFile(p.certificate) 1260 if err != nil { 1261 return nil, fmt.Errorf("Cannot read certificate %q: %v", p.certificate, err) 1262 } 1263 certs := x509.NewCertPool() 1264 certs.AppendCertsFromPEM(pem) 1265 config.RootCAs = certs 1266 } 1267 if p.trustServerCertificate { 1268 config.InsecureSkipVerify = true 1269 } 1270 config.ServerName = p.hostInCertificate 1271 // fix for https://github.com/denisenkom/go-mssqldb/issues/166 1272 // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, 1273 // while SQL Server seems to expect one TCP segment per encrypted TDS package. 1274 // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package 1275 config.DynamicRecordSizingDisabled = true 1276 outbuf.transport = conn 1277 toconn.buf = outbuf 1278 tlsConn := tls.Client(toconn, &config) 1279 err = tlsConn.Handshake() 1280 1281 toconn.buf = nil 1282 outbuf.transport = tlsConn 1283 if err != nil { 1284 return nil, fmt.Errorf("TLS Handshake failed: %v", err) 1285 } 1286 if encrypt == encryptOff { 1287 outbuf.afterFirst = func() { 1288 outbuf.transport = toconn 1289 } 1290 } 1291 } 1292 1293 login := login{ 1294 TDSVersion: verTDS74, 1295 PacketSize: uint32(outbuf.PackageSize()), 1296 Database: p.database, 1297 OptionFlags2: fODBC, // to get unlimited TEXTSIZE 1298 HostName: p.workstation, 1299 ServerName: p.host, 1300 AppName: p.appname, 1301 TypeFlags: p.typeFlags, 1302 } 1303 auth, auth_ok := getAuth(p.user, p.password, p.serverSPN, p.workstation) 1304 if auth_ok { 1305 login.SSPI, err = auth.InitialBytes() 1306 if err != nil { 1307 return nil, err 1308 } 1309 login.OptionFlags2 |= fIntSecurity 1310 defer auth.Free() 1311 } else { 1312 login.UserName = p.user 1313 login.Password = p.password 1314 } 1315 err = sendLogin(outbuf, login) 1316 if err != nil { 1317 return nil, err 1318 } 1319 1320 // processing login response 1321 var sspi_msg []byte 1322continue_login: 1323 tokchan := make(chan tokenStruct, 5) 1324 go processResponse(context.Background(), &sess, tokchan, nil) 1325 success := false 1326 for tok := range tokchan { 1327 switch token := tok.(type) { 1328 case sspiMsg: 1329 sspi_msg, err = auth.NextBytes(token) 1330 if err != nil { 1331 return nil, err 1332 } 1333 case loginAckStruct: 1334 success = true 1335 sess.loginAck = token 1336 case error: 1337 return nil, fmt.Errorf("Login error: %s", token.Error()) 1338 case doneStruct: 1339 if token.isError() { 1340 return nil, fmt.Errorf("Login error: %s", token.getError()) 1341 } 1342 } 1343 } 1344 if sspi_msg != nil { 1345 outbuf.BeginPacket(packSSPIMessage, false) 1346 _, err = outbuf.Write(sspi_msg) 1347 if err != nil { 1348 return nil, err 1349 } 1350 err = outbuf.FinishPacket() 1351 if err != nil { 1352 return nil, err 1353 } 1354 sspi_msg = nil 1355 goto continue_login 1356 } 1357 if !success { 1358 return nil, fmt.Errorf("Login failed") 1359 } 1360 if sess.routedServer != "" { 1361 toconn.Close() 1362 p.host = sess.routedServer 1363 p.port = uint64(sess.routedPort) 1364 goto initiate_connection 1365 } 1366 return &sess, nil 1367} 1368