1package mssql 2 3import ( 4 "context" 5 "crypto/tls" 6 "crypto/x509" 7 "encoding/binary" 8 "errors" 9 "fmt" 10 "io" 11 "io/ioutil" 12 "net" 13 "sort" 14 "strconv" 15 "strings" 16 "unicode/utf16" 17 "unicode/utf8" 18) 19 20func parseInstances(msg []byte) map[string]map[string]string { 21 results := map[string]map[string]string{} 22 if len(msg) > 3 && msg[0] == 5 { 23 out_s := string(msg[3:]) 24 tokens := strings.Split(out_s, ";") 25 instdict := map[string]string{} 26 got_name := false 27 var name string 28 for _, token := range tokens { 29 if got_name { 30 instdict[name] = token 31 got_name = false 32 } else { 33 name = token 34 if len(name) == 0 { 35 if len(instdict) == 0 { 36 break 37 } 38 results[strings.ToUpper(instdict["InstanceName"])] = instdict 39 instdict = map[string]string{} 40 continue 41 } 42 got_name = true 43 } 44 } 45 } 46 return results 47} 48 49func getInstances(ctx context.Context, d Dialer, address string) (map[string]map[string]string, error) { 50 conn, err := d.DialContext(ctx, "udp", address+":1434") 51 if err != nil { 52 return nil, err 53 } 54 defer conn.Close() 55 deadline, _ := ctx.Deadline() 56 conn.SetDeadline(deadline) 57 _, err = conn.Write([]byte{3}) 58 if err != nil { 59 return nil, err 60 } 61 var resp = make([]byte, 16*1024-1) 62 read, err := conn.Read(resp) 63 if err != nil { 64 return nil, err 65 } 66 return parseInstances(resp[:read]), nil 67} 68 69// tds versions 70const ( 71 verTDS70 = 0x70000000 72 verTDS71 = 0x71000000 73 verTDS71rev1 = 0x71000001 74 verTDS72 = 0x72090002 75 verTDS73A = 0x730A0003 76 verTDS73 = verTDS73A 77 verTDS73B = 0x730B0003 78 verTDS74 = 0x74000004 79) 80 81// packet types 82// https://msdn.microsoft.com/en-us/library/dd304214.aspx 83const ( 84 packSQLBatch packetType = 1 85 packRPCRequest = 3 86 packReply = 4 87 88 // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx 89 // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx 90 packAttention = 6 91 92 packBulkLoadBCP = 7 93 packTransMgrReq = 14 94 packNormal = 15 95 packLogin7 = 16 96 packSSPIMessage = 17 97 packPrelogin = 18 98) 99 100// prelogin fields 101// http://msdn.microsoft.com/en-us/library/dd357559.aspx 102const ( 103 preloginVERSION = 0 104 preloginENCRYPTION = 1 105 preloginINSTOPT = 2 106 preloginTHREADID = 3 107 preloginMARS = 4 108 preloginTRACEID = 5 109 preloginFEDAUTHREQUIRED = 6 110 preloginNONCEOPT = 7 111 preloginTERMINATOR = 0xff 112) 113 114const ( 115 encryptOff = 0 // Encryption is available but off. 116 encryptOn = 1 // Encryption is available and on. 117 encryptNotSup = 2 // Encryption is not available. 118 encryptReq = 3 // Encryption is required. 119) 120 121type tdsSession struct { 122 buf *tdsBuffer 123 loginAck loginAckStruct 124 database string 125 partner string 126 columns []columnStruct 127 tranid uint64 128 logFlags uint64 129 log optionalLogger 130 routedServer string 131 routedPort uint16 132} 133 134const ( 135 logErrors = 1 136 logMessages = 2 137 logRows = 4 138 logSQL = 8 139 logParams = 16 140 logTransaction = 32 141 logDebug = 64 142) 143 144type columnStruct struct { 145 UserType uint32 146 Flags uint16 147 ColName string 148 ti typeInfo 149} 150 151type keySlice []uint8 152 153func (p keySlice) Len() int { return len(p) } 154func (p keySlice) Less(i, j int) bool { return p[i] < p[j] } 155func (p keySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } 156 157// http://msdn.microsoft.com/en-us/library/dd357559.aspx 158func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error { 159 var err error 160 161 w.BeginPacket(packPrelogin, false) 162 offset := uint16(5*len(fields) + 1) 163 keys := make(keySlice, 0, len(fields)) 164 for k, _ := range fields { 165 keys = append(keys, k) 166 } 167 sort.Sort(keys) 168 // writing header 169 for _, k := range keys { 170 err = w.WriteByte(k) 171 if err != nil { 172 return err 173 } 174 err = binary.Write(w, binary.BigEndian, offset) 175 if err != nil { 176 return err 177 } 178 v := fields[k] 179 size := uint16(len(v)) 180 err = binary.Write(w, binary.BigEndian, size) 181 if err != nil { 182 return err 183 } 184 offset += size 185 } 186 err = w.WriteByte(preloginTERMINATOR) 187 if err != nil { 188 return err 189 } 190 // writing values 191 for _, k := range keys { 192 v := fields[k] 193 written, err := w.Write(v) 194 if err != nil { 195 return err 196 } 197 if written != len(v) { 198 return errors.New("Write method didn't write the whole value") 199 } 200 } 201 return w.FinishPacket() 202} 203 204func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) { 205 packet_type, err := r.BeginRead() 206 if err != nil { 207 return nil, err 208 } 209 struct_buf, err := ioutil.ReadAll(r) 210 if err != nil { 211 return nil, err 212 } 213 if packet_type != 4 { 214 return nil, errors.New("Invalid respones, expected packet type 4, PRELOGIN RESPONSE") 215 } 216 offset := 0 217 results := map[uint8][]byte{} 218 for true { 219 rec_type := struct_buf[offset] 220 if rec_type == preloginTERMINATOR { 221 break 222 } 223 224 rec_offset := binary.BigEndian.Uint16(struct_buf[offset+1:]) 225 rec_len := binary.BigEndian.Uint16(struct_buf[offset+3:]) 226 value := struct_buf[rec_offset : rec_offset+rec_len] 227 results[rec_type] = value 228 offset += 5 229 } 230 return results, nil 231} 232 233// OptionFlags2 234// http://msdn.microsoft.com/en-us/library/dd304019.aspx 235const ( 236 fLanguageFatal = 1 237 fODBC = 2 238 fTransBoundary = 4 239 fCacheConnect = 8 240 fIntSecurity = 0x80 241) 242 243// TypeFlags 244const ( 245 // 4 bits for fSQLType 246 // 1 bit for fOLEDB 247 fReadOnlyIntent = 32 248) 249 250// OptionFlags3 251// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac 252const ( 253 fExtension = 0x10 254) 255 256type login struct { 257 TDSVersion uint32 258 PacketSize uint32 259 ClientProgVer uint32 260 ClientPID uint32 261 ConnectionID uint32 262 OptionFlags1 uint8 263 OptionFlags2 uint8 264 TypeFlags uint8 265 OptionFlags3 uint8 266 ClientTimeZone int32 267 ClientLCID uint32 268 HostName string 269 UserName string 270 Password string 271 AppName string 272 ServerName string 273 CtlIntName string 274 Language string 275 Database string 276 ClientID [6]byte 277 SSPI []byte 278 AtchDBFile string 279 ChangePassword string 280 FeatureExt featureExts 281} 282 283type featureExts struct { 284 features map[byte]featureExt 285} 286 287type featureExt interface { 288 featureID() byte 289 toBytes() []byte 290} 291 292func (e *featureExts) Add(f featureExt) error { 293 if f == nil { 294 return nil 295 } 296 id := f.featureID() 297 if _, exists := e.features[id]; exists { 298 f := "Login error: Feature with ID '%v' is already present in FeatureExt block." 299 return fmt.Errorf(f, id) 300 } 301 if e.features == nil { 302 e.features = make(map[byte]featureExt) 303 } 304 e.features[id] = f 305 return nil 306} 307 308func (e featureExts) toBytes() []byte { 309 if len(e.features) == 0 { 310 return nil 311 } 312 var d []byte 313 for featureID, f := range e.features { 314 featureData := f.toBytes() 315 316 hdr := make([]byte, 5) 317 hdr[0] = featureID // FedAuth feature extension BYTE 318 binary.LittleEndian.PutUint32(hdr[1:], uint32(len(featureData))) // FeatureDataLen DWORD 319 d = append(d, hdr...) 320 321 d = append(d, featureData...) // FeatureData *BYTE 322 } 323 if d != nil { 324 d = append(d, 0xff) // Terminator 325 } 326 return d 327} 328 329type featureExtFedAuthSTS struct { 330 FedAuthEcho bool 331 FedAuthToken string 332 Nonce []byte 333} 334 335func (e *featureExtFedAuthSTS) featureID() byte { 336 return 0x02 337} 338 339func (e *featureExtFedAuthSTS) toBytes() []byte { 340 if e == nil { 341 return nil 342 } 343 344 options := byte(0x01) << 1 // 0x01 => STS bFedAuthLibrary 7BIT 345 if e.FedAuthEcho { 346 options |= 1 // fFedAuthEcho 347 } 348 349 d := make([]byte, 5) 350 d[0] = options 351 352 // looks like string in 353 // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/f88b63bb-b479-49e1-a87b-deda521da508 354 tokenBytes := str2ucs2(e.FedAuthToken) 355 binary.LittleEndian.PutUint32(d[1:], uint32(len(tokenBytes))) // Should be a signed int32, but since the length is relatively small, this should work 356 d = append(d, tokenBytes...) 357 358 if len(e.Nonce) == 32 { 359 d = append(d, e.Nonce...) 360 } 361 362 return d 363} 364 365type loginHeader struct { 366 Length uint32 367 TDSVersion uint32 368 PacketSize uint32 369 ClientProgVer uint32 370 ClientPID uint32 371 ConnectionID uint32 372 OptionFlags1 uint8 373 OptionFlags2 uint8 374 TypeFlags uint8 375 OptionFlags3 uint8 376 ClientTimeZone int32 377 ClientLCID uint32 378 HostNameOffset uint16 379 HostNameLength uint16 380 UserNameOffset uint16 381 UserNameLength uint16 382 PasswordOffset uint16 383 PasswordLength uint16 384 AppNameOffset uint16 385 AppNameLength uint16 386 ServerNameOffset uint16 387 ServerNameLength uint16 388 ExtensionOffset uint16 389 ExtensionLength uint16 390 CtlIntNameOffset uint16 391 CtlIntNameLength uint16 392 LanguageOffset uint16 393 LanguageLength uint16 394 DatabaseOffset uint16 395 DatabaseLength uint16 396 ClientID [6]byte 397 SSPIOffset uint16 398 SSPILength uint16 399 AtchDBFileOffset uint16 400 AtchDBFileLength uint16 401 ChangePasswordOffset uint16 402 ChangePasswordLength uint16 403 SSPILongLength uint32 404} 405 406// convert Go string to UTF-16 encoded []byte (littleEndian) 407// done manually rather than using bytes and binary packages 408// for performance reasons 409func str2ucs2(s string) []byte { 410 res := utf16.Encode([]rune(s)) 411 ucs2 := make([]byte, 2*len(res)) 412 for i := 0; i < len(res); i++ { 413 ucs2[2*i] = byte(res[i]) 414 ucs2[2*i+1] = byte(res[i] >> 8) 415 } 416 return ucs2 417} 418 419func ucs22str(s []byte) (string, error) { 420 if len(s)%2 != 0 { 421 return "", fmt.Errorf("Illegal UCS2 string length: %d", len(s)) 422 } 423 buf := make([]uint16, len(s)/2) 424 for i := 0; i < len(s); i += 2 { 425 buf[i/2] = binary.LittleEndian.Uint16(s[i:]) 426 } 427 return string(utf16.Decode(buf)), nil 428} 429 430func manglePassword(password string) []byte { 431 var ucs2password []byte = str2ucs2(password) 432 for i, ch := range ucs2password { 433 ucs2password[i] = ((ch<<4)&0xff | (ch >> 4)) ^ 0xA5 434 } 435 return ucs2password 436} 437 438// http://msdn.microsoft.com/en-us/library/dd304019.aspx 439func sendLogin(w *tdsBuffer, login login) error { 440 w.BeginPacket(packLogin7, false) 441 hostname := str2ucs2(login.HostName) 442 username := str2ucs2(login.UserName) 443 password := manglePassword(login.Password) 444 appname := str2ucs2(login.AppName) 445 servername := str2ucs2(login.ServerName) 446 ctlintname := str2ucs2(login.CtlIntName) 447 language := str2ucs2(login.Language) 448 database := str2ucs2(login.Database) 449 atchdbfile := str2ucs2(login.AtchDBFile) 450 changepassword := str2ucs2(login.ChangePassword) 451 featureExt := login.FeatureExt.toBytes() 452 453 hdr := loginHeader{ 454 TDSVersion: login.TDSVersion, 455 PacketSize: login.PacketSize, 456 ClientProgVer: login.ClientProgVer, 457 ClientPID: login.ClientPID, 458 ConnectionID: login.ConnectionID, 459 OptionFlags1: login.OptionFlags1, 460 OptionFlags2: login.OptionFlags2, 461 TypeFlags: login.TypeFlags, 462 OptionFlags3: login.OptionFlags3, 463 ClientTimeZone: login.ClientTimeZone, 464 ClientLCID: login.ClientLCID, 465 HostNameLength: uint16(utf8.RuneCountInString(login.HostName)), 466 UserNameLength: uint16(utf8.RuneCountInString(login.UserName)), 467 PasswordLength: uint16(utf8.RuneCountInString(login.Password)), 468 AppNameLength: uint16(utf8.RuneCountInString(login.AppName)), 469 ServerNameLength: uint16(utf8.RuneCountInString(login.ServerName)), 470 CtlIntNameLength: uint16(utf8.RuneCountInString(login.CtlIntName)), 471 LanguageLength: uint16(utf8.RuneCountInString(login.Language)), 472 DatabaseLength: uint16(utf8.RuneCountInString(login.Database)), 473 ClientID: login.ClientID, 474 SSPILength: uint16(len(login.SSPI)), 475 AtchDBFileLength: uint16(utf8.RuneCountInString(login.AtchDBFile)), 476 ChangePasswordLength: uint16(utf8.RuneCountInString(login.ChangePassword)), 477 } 478 offset := uint16(binary.Size(hdr)) 479 hdr.HostNameOffset = offset 480 offset += uint16(len(hostname)) 481 hdr.UserNameOffset = offset 482 offset += uint16(len(username)) 483 hdr.PasswordOffset = offset 484 offset += uint16(len(password)) 485 hdr.AppNameOffset = offset 486 offset += uint16(len(appname)) 487 hdr.ServerNameOffset = offset 488 offset += uint16(len(servername)) 489 hdr.CtlIntNameOffset = offset 490 offset += uint16(len(ctlintname)) 491 hdr.LanguageOffset = offset 492 offset += uint16(len(language)) 493 hdr.DatabaseOffset = offset 494 offset += uint16(len(database)) 495 hdr.SSPIOffset = offset 496 offset += uint16(len(login.SSPI)) 497 hdr.AtchDBFileOffset = offset 498 offset += uint16(len(atchdbfile)) 499 hdr.ChangePasswordOffset = offset 500 offset += uint16(len(changepassword)) 501 502 featureExtOffset := uint32(0) 503 featureExtLen := len(featureExt) 504 if featureExtLen > 0 { 505 hdr.OptionFlags3 |= fExtension 506 hdr.ExtensionOffset = offset 507 hdr.ExtensionLength = 4 508 offset += hdr.ExtensionLength // DWORD 509 featureExtOffset = uint32(offset) 510 } 511 hdr.Length = uint32(offset) + uint32(featureExtLen) 512 513 var err error 514 err = binary.Write(w, binary.LittleEndian, &hdr) 515 if err != nil { 516 return err 517 } 518 _, err = w.Write(hostname) 519 if err != nil { 520 return err 521 } 522 _, err = w.Write(username) 523 if err != nil { 524 return err 525 } 526 _, err = w.Write(password) 527 if err != nil { 528 return err 529 } 530 _, err = w.Write(appname) 531 if err != nil { 532 return err 533 } 534 _, err = w.Write(servername) 535 if err != nil { 536 return err 537 } 538 _, err = w.Write(ctlintname) 539 if err != nil { 540 return err 541 } 542 _, err = w.Write(language) 543 if err != nil { 544 return err 545 } 546 _, err = w.Write(database) 547 if err != nil { 548 return err 549 } 550 _, err = w.Write(login.SSPI) 551 if err != nil { 552 return err 553 } 554 _, err = w.Write(atchdbfile) 555 if err != nil { 556 return err 557 } 558 _, err = w.Write(changepassword) 559 if err != nil { 560 return err 561 } 562 if featureExtOffset > 0 { 563 err = binary.Write(w, binary.LittleEndian, featureExtOffset) 564 if err != nil { 565 return err 566 } 567 _, err = w.Write(featureExt) 568 if err != nil { 569 return err 570 } 571 } 572 return w.FinishPacket() 573} 574 575func readUcs2(r io.Reader, numchars int) (res string, err error) { 576 buf := make([]byte, numchars*2) 577 _, err = io.ReadFull(r, buf) 578 if err != nil { 579 return "", err 580 } 581 return ucs22str(buf) 582} 583 584func readUsVarChar(r io.Reader) (res string, err error) { 585 numchars, err := readUshort(r) 586 if err != nil { 587 return 588 } 589 return readUcs2(r, int(numchars)) 590} 591 592func writeUsVarChar(w io.Writer, s string) (err error) { 593 buf := str2ucs2(s) 594 var numchars int = len(buf) / 2 595 if numchars > 0xffff { 596 panic("invalid size for US_VARCHAR") 597 } 598 err = binary.Write(w, binary.LittleEndian, uint16(numchars)) 599 if err != nil { 600 return 601 } 602 _, err = w.Write(buf) 603 return 604} 605 606func readBVarChar(r io.Reader) (res string, err error) { 607 numchars, err := readByte(r) 608 if err != nil { 609 return "", err 610 } 611 612 // A zero length could be returned, return an empty string 613 if numchars == 0 { 614 return "", nil 615 } 616 return readUcs2(r, int(numchars)) 617} 618 619func writeBVarChar(w io.Writer, s string) (err error) { 620 buf := str2ucs2(s) 621 var numchars int = len(buf) / 2 622 if numchars > 0xff { 623 panic("invalid size for B_VARCHAR") 624 } 625 err = binary.Write(w, binary.LittleEndian, uint8(numchars)) 626 if err != nil { 627 return 628 } 629 _, err = w.Write(buf) 630 return 631} 632 633func readBVarByte(r io.Reader) (res []byte, err error) { 634 length, err := readByte(r) 635 if err != nil { 636 return 637 } 638 res = make([]byte, length) 639 _, err = io.ReadFull(r, res) 640 return 641} 642 643func readUshort(r io.Reader) (res uint16, err error) { 644 err = binary.Read(r, binary.LittleEndian, &res) 645 return 646} 647 648func readByte(r io.Reader) (res byte, err error) { 649 var b [1]byte 650 _, err = r.Read(b[:]) 651 res = b[0] 652 return 653} 654 655// Packet Data Stream Headers 656// http://msdn.microsoft.com/en-us/library/dd304953.aspx 657type headerStruct struct { 658 hdrtype uint16 659 data []byte 660} 661 662const ( 663 dataStmHdrQueryNotif = 1 // query notifications 664 dataStmHdrTransDescr = 2 // MARS transaction descriptor (required) 665 dataStmHdrTraceActivity = 3 666) 667 668// Query Notifications Header 669// http://msdn.microsoft.com/en-us/library/dd304949.aspx 670type queryNotifHdr struct { 671 notifyId string 672 ssbDeployment string 673 notifyTimeout uint32 674} 675 676func (hdr queryNotifHdr) pack() (res []byte) { 677 notifyId := str2ucs2(hdr.notifyId) 678 ssbDeployment := str2ucs2(hdr.ssbDeployment) 679 680 res = make([]byte, 2+len(notifyId)+2+len(ssbDeployment)+4) 681 b := res 682 683 binary.LittleEndian.PutUint16(b, uint16(len(notifyId))) 684 b = b[2:] 685 copy(b, notifyId) 686 b = b[len(notifyId):] 687 688 binary.LittleEndian.PutUint16(b, uint16(len(ssbDeployment))) 689 b = b[2:] 690 copy(b, ssbDeployment) 691 b = b[len(ssbDeployment):] 692 693 binary.LittleEndian.PutUint32(b, hdr.notifyTimeout) 694 695 return res 696} 697 698// MARS Transaction Descriptor Header 699// http://msdn.microsoft.com/en-us/library/dd340515.aspx 700type transDescrHdr struct { 701 transDescr uint64 // transaction descriptor returned from ENVCHANGE 702 outstandingReqCnt uint32 // outstanding request count 703} 704 705func (hdr transDescrHdr) pack() (res []byte) { 706 res = make([]byte, 8+4) 707 binary.LittleEndian.PutUint64(res, hdr.transDescr) 708 binary.LittleEndian.PutUint32(res[8:], hdr.outstandingReqCnt) 709 return res 710} 711 712func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) { 713 // Calculating total length. 714 var totallen uint32 = 4 715 for _, hdr := range headers { 716 totallen += 4 + 2 + uint32(len(hdr.data)) 717 } 718 // writing 719 err = binary.Write(w, binary.LittleEndian, totallen) 720 if err != nil { 721 return err 722 } 723 for _, hdr := range headers { 724 var headerlen uint32 = 4 + 2 + uint32(len(hdr.data)) 725 err = binary.Write(w, binary.LittleEndian, headerlen) 726 if err != nil { 727 return err 728 } 729 err = binary.Write(w, binary.LittleEndian, hdr.hdrtype) 730 if err != nil { 731 return err 732 } 733 _, err = w.Write(hdr.data) 734 if err != nil { 735 return err 736 } 737 } 738 return nil 739} 740 741func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct, resetSession bool) (err error) { 742 buf.BeginPacket(packSQLBatch, resetSession) 743 744 if err = writeAllHeaders(buf, headers); err != nil { 745 return 746 } 747 748 _, err = buf.Write(str2ucs2(sqltext)) 749 if err != nil { 750 return 751 } 752 return buf.FinishPacket() 753} 754 755// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx 756// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx 757func sendAttention(buf *tdsBuffer) error { 758 buf.BeginPacket(packAttention, false) 759 return buf.FinishPacket() 760} 761 762type auth interface { 763 InitialBytes() ([]byte, error) 764 NextBytes([]byte) ([]byte, error) 765 Free() 766} 767 768// SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a 769// list of IP addresses. So if there is more than one, try them all and 770// use the first one that allows a connection. 771func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn net.Conn, err error) { 772 var ips []net.IP 773 ips, err = net.LookupIP(p.host) 774 if err != nil { 775 ip := net.ParseIP(p.host) 776 if ip == nil { 777 return nil, err 778 } 779 ips = []net.IP{ip} 780 } 781 if len(ips) == 1 { 782 d := c.getDialer(&p) 783 addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(resolveServerPort(p.port)))) 784 conn, err = d.DialContext(ctx, "tcp", addr) 785 786 } else { 787 //Try Dials in parallel to avoid waiting for timeouts. 788 connChan := make(chan net.Conn, len(ips)) 789 errChan := make(chan error, len(ips)) 790 portStr := strconv.Itoa(int(resolveServerPort(p.port))) 791 for _, ip := range ips { 792 go func(ip net.IP) { 793 d := c.getDialer(&p) 794 addr := net.JoinHostPort(ip.String(), portStr) 795 conn, err := d.DialContext(ctx, "tcp", addr) 796 if err == nil { 797 connChan <- conn 798 } else { 799 errChan <- err 800 } 801 }(ip) 802 } 803 // Wait for either the *first* successful connection, or all the errors 804 wait_loop: 805 for i, _ := range ips { 806 select { 807 case conn = <-connChan: 808 // Got a connection to use, close any others 809 go func(n int) { 810 for i := 0; i < n; i++ { 811 select { 812 case conn := <-connChan: 813 conn.Close() 814 case <-errChan: 815 } 816 } 817 }(len(ips) - i - 1) 818 // Remove any earlier errors we may have collected 819 err = nil 820 break wait_loop 821 case err = <-errChan: 822 } 823 } 824 } 825 // Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection 826 if conn == nil { 827 f := "Unable to open tcp connection with host '%v:%v': %v" 828 return nil, fmt.Errorf(f, p.host, resolveServerPort(p.port), err.Error()) 829 } 830 return conn, err 831} 832 833func connect(ctx context.Context, c *Connector, log optionalLogger, p connectParams) (res *tdsSession, err error) { 834 dialCtx := ctx 835 if p.dial_timeout > 0 { 836 var cancel func() 837 dialCtx, cancel = context.WithTimeout(ctx, p.dial_timeout) 838 defer cancel() 839 } 840 // if instance is specified use instance resolution service 841 if p.instance != "" && p.port == 0 { 842 p.instance = strings.ToUpper(p.instance) 843 d := c.getDialer(&p) 844 instances, err := getInstances(dialCtx, d, p.host) 845 if err != nil { 846 f := "Unable to get instances from Sql Server Browser on host %v: %v" 847 return nil, fmt.Errorf(f, p.host, err.Error()) 848 } 849 strport, ok := instances[p.instance]["tcp"] 850 if !ok { 851 f := "No instance matching '%v' returned from host '%v'" 852 return nil, fmt.Errorf(f, p.instance, p.host) 853 } 854 port, err := strconv.ParseUint(strport, 0, 16) 855 if err != nil { 856 f := "Invalid tcp port returned from Sql Server Browser '%v': %v" 857 return nil, fmt.Errorf(f, strport, err.Error()) 858 } 859 p.port = port 860 } 861 862initiate_connection: 863 conn, err := dialConnection(dialCtx, c, p) 864 if err != nil { 865 return nil, err 866 } 867 868 toconn := newTimeoutConn(conn, p.conn_timeout) 869 870 outbuf := newTdsBuffer(p.packetSize, toconn) 871 sess := tdsSession{ 872 buf: outbuf, 873 log: log, 874 logFlags: p.logFlags, 875 } 876 877 instance_buf := []byte(p.instance) 878 instance_buf = append(instance_buf, 0) // zero terminate instance name 879 var encrypt byte 880 if p.disableEncryption { 881 encrypt = encryptNotSup 882 } else if p.encrypt { 883 encrypt = encryptOn 884 } else { 885 encrypt = encryptOff 886 } 887 fields := map[uint8][]byte{ 888 preloginVERSION: {0, 0, 0, 0, 0, 0}, 889 preloginENCRYPTION: {encrypt}, 890 preloginINSTOPT: instance_buf, 891 preloginTHREADID: {0, 0, 0, 0}, 892 preloginMARS: {0}, // MARS disabled 893 } 894 895 err = writePrelogin(outbuf, fields) 896 if err != nil { 897 return nil, err 898 } 899 900 fields, err = readPrelogin(outbuf) 901 if err != nil { 902 return nil, err 903 } 904 905 encryptBytes, ok := fields[preloginENCRYPTION] 906 if !ok { 907 return nil, fmt.Errorf("Encrypt negotiation failed") 908 } 909 encrypt = encryptBytes[0] 910 if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) { 911 return nil, fmt.Errorf("Server does not support encryption") 912 } 913 914 if encrypt != encryptNotSup { 915 var config tls.Config 916 if p.certificate != "" { 917 pem, err := ioutil.ReadFile(p.certificate) 918 if err != nil { 919 return nil, fmt.Errorf("Cannot read certificate %q: %v", p.certificate, err) 920 } 921 certs := x509.NewCertPool() 922 certs.AppendCertsFromPEM(pem) 923 config.RootCAs = certs 924 } 925 if p.trustServerCertificate { 926 config.InsecureSkipVerify = true 927 } 928 config.ServerName = p.hostInCertificate 929 // fix for https://github.com/denisenkom/go-mssqldb/issues/166 930 // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, 931 // while SQL Server seems to expect one TCP segment per encrypted TDS package. 932 // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package 933 config.DynamicRecordSizingDisabled = true 934 // setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream 935 handshakeConn := tlsHandshakeConn{buf: outbuf} 936 passthrough := passthroughConn{c: &handshakeConn} 937 tlsConn := tls.Client(&passthrough, &config) 938 err = tlsConn.Handshake() 939 passthrough.c = toconn 940 outbuf.transport = tlsConn 941 if err != nil { 942 return nil, fmt.Errorf("TLS Handshake failed: %v", err) 943 } 944 if encrypt == encryptOff { 945 outbuf.afterFirst = func() { 946 outbuf.transport = toconn 947 } 948 } 949 } 950 951 login := login{ 952 TDSVersion: verTDS74, 953 PacketSize: uint32(outbuf.PackageSize()), 954 Database: p.database, 955 OptionFlags2: fODBC, // to get unlimited TEXTSIZE 956 HostName: p.workstation, 957 ServerName: p.host, 958 AppName: p.appname, 959 TypeFlags: p.typeFlags, 960 } 961 auth, authOk := getAuth(p.user, p.password, p.serverSPN, p.workstation) 962 switch { 963 case p.fedAuthAccessToken != "": // accesstoken ignores user/password 964 featurext := &featureExtFedAuthSTS{ 965 FedAuthEcho: len(fields[preloginFEDAUTHREQUIRED]) > 0 && fields[preloginFEDAUTHREQUIRED][0] == 1, 966 FedAuthToken: p.fedAuthAccessToken, 967 Nonce: fields[preloginNONCEOPT], 968 } 969 login.FeatureExt.Add(featurext) 970 case authOk: 971 login.SSPI, err = auth.InitialBytes() 972 if err != nil { 973 return nil, err 974 } 975 login.OptionFlags2 |= fIntSecurity 976 defer auth.Free() 977 default: 978 login.UserName = p.user 979 login.Password = p.password 980 } 981 err = sendLogin(outbuf, login) 982 if err != nil { 983 return nil, err 984 } 985 986 // processing login response 987 success := false 988 for { 989 tokchan := make(chan tokenStruct, 5) 990 go processResponse(context.Background(), &sess, tokchan, nil) 991 for tok := range tokchan { 992 switch token := tok.(type) { 993 case sspiMsg: 994 sspi_msg, err := auth.NextBytes(token) 995 if err != nil { 996 return nil, err 997 } 998 if sspi_msg != nil && len(sspi_msg) > 0 { 999 outbuf.BeginPacket(packSSPIMessage, false) 1000 _, err = outbuf.Write(sspi_msg) 1001 if err != nil { 1002 return nil, err 1003 } 1004 err = outbuf.FinishPacket() 1005 if err != nil { 1006 return nil, err 1007 } 1008 sspi_msg = nil 1009 } 1010 case loginAckStruct: 1011 success = true 1012 sess.loginAck = token 1013 case error: 1014 return nil, fmt.Errorf("Login error: %s", token.Error()) 1015 case doneStruct: 1016 if token.isError() { 1017 return nil, fmt.Errorf("Login error: %s", token.getError()) 1018 } 1019 goto loginEnd 1020 } 1021 } 1022 } 1023loginEnd: 1024 if !success { 1025 return nil, fmt.Errorf("Login failed") 1026 } 1027 if sess.routedServer != "" { 1028 toconn.Close() 1029 p.host = sess.routedServer 1030 p.port = uint64(sess.routedPort) 1031 if !p.hostInCertificateProvided { 1032 p.hostInCertificate = sess.routedServer 1033 } 1034 goto initiate_connection 1035 } 1036 return &sess, nil 1037} 1038