1package mssql
2
3import (
4	"context"
5	"crypto/tls"
6	"crypto/x509"
7	"encoding/binary"
8	"errors"
9	"fmt"
10	"io"
11	"io/ioutil"
12	"net"
13	"sort"
14	"strconv"
15	"strings"
16	"unicode/utf16"
17	"unicode/utf8"
18)
19
// parseInstances decodes a SQL Server Browser reply into a map keyed by the
// upper-cased instance name, where each value maps property names to
// property values for that instance.
//
// The message is only parsed when it is longer than three bytes and its
// first byte is 5; the payload starts at byte 3 and is a ';'-separated list
// of alternating property names and values. An empty token in name position
// ends the current instance record, and a second consecutive one ends the
// whole message. A trailing record that is not terminated is discarded.
// The returned map is never nil.
func parseInstances(msg []byte) map[string]map[string]string {
	instances := map[string]map[string]string{}
	if len(msg) <= 3 || msg[0] != 5 {
		return instances
	}

	current := map[string]string{}
	key := ""
	expectValue := false
	for _, field := range strings.Split(string(msg[3:]), ";") {
		switch {
		case expectValue:
			// Values may legitimately be empty, so no emptiness check here.
			current[key] = field
			expectValue = false
		case field != "":
			key = field
			expectValue = true
		case len(current) == 0:
			// Empty name with nothing collected: end of the message.
			return instances
		default:
			// Empty name after at least one property: end of this record.
			instances[strings.ToUpper(current["InstanceName"])] = current
			current = map[string]string{}
		}
	}
	return instances
}
48
49func getInstances(ctx context.Context, d Dialer, address string) (map[string]map[string]string, error) {
50	conn, err := d.DialContext(ctx, "udp", address+":1434")
51	if err != nil {
52		return nil, err
53	}
54	defer conn.Close()
55	deadline, _ := ctx.Deadline()
56	conn.SetDeadline(deadline)
57	_, err = conn.Write([]byte{3})
58	if err != nil {
59		return nil, err
60	}
61	var resp = make([]byte, 16*1024-1)
62	read, err := conn.Read(resp)
63	if err != nil {
64		return nil, err
65	}
66	return parseInstances(resp[:read]), nil
67}
68
// TDS protocol version identifiers.
// connect puts verTDS74 into the TDSVersion field of the login request.
const (
	verTDS70     = 0x70000000
	verTDS71     = 0x71000000
	verTDS71rev1 = 0x71000001
	verTDS72     = 0x72090002
	verTDS73A    = 0x730A0003
	verTDS73     = verTDS73A // alias for the 7.3 revision A value
	verTDS73B    = 0x730B0003
	verTDS74     = 0x74000004
)
80
// TDS packet types, passed to tdsBuffer.BeginPacket to tag an outgoing
// packet.
// https://msdn.microsoft.com/en-us/library/dd304214.aspx
//
// Only packSQLBatch is declared with the packetType type; every other
// constant has its own expression and is therefore an untyped integer
// constant that converts implicitly where it is used.
const (
	packSQLBatch   packetType = 1
	packRPCRequest            = 3
	packReply                 = 4 // readPrelogin expects this type for the server's prelogin response

	// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
	// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
	packAttention = 6

	packBulkLoadBCP = 7
	packTransMgrReq = 14
	packNormal      = 15
	packLogin7      = 16
	packSSPIMessage = 17
	packPrelogin    = 18
)
99
// Prelogin option tokens. They are the keys of the field maps written by
// writePrelogin and returned by readPrelogin.
// http://msdn.microsoft.com/en-us/library/dd357559.aspx
const (
	preloginVERSION         = 0
	preloginENCRYPTION      = 1
	preloginINSTOPT         = 2
	preloginTHREADID        = 3
	preloginMARS            = 4
	preloginTRACEID         = 5
	preloginFEDAUTHREQUIRED = 6
	preloginNONCEOPT        = 7
	preloginTERMINATOR      = 0xff // marks the end of the option table; never used as a map key
)
113
// Values of the single-byte preloginENCRYPTION option. connect sends one of
// encryptOff, encryptOn or encryptNotSup and inspects the value the server
// answers with.
const (
	encryptOff    = 0 // Encryption is available but off.
	encryptOn     = 1 // Encryption is available and on.
	encryptNotSup = 2 // Encryption is not available.
	encryptReq    = 3 // Encryption is required.
)
120
// tdsSession holds the state of one TDS connection. connect creates it with
// buf, log and logFlags set, and stores the login acknowledgement token in
// loginAck once the server accepts the login.
type tdsSession struct {
	// buf is the packet buffer used for all traffic of this session.
	buf *tdsBuffer
	// loginAck is the login acknowledgement token received during connect.
	loginAck loginAckStruct
	// database, partner, columns and tranid are not written in this file;
	// presumably they are maintained while responses are processed — TODO confirm.
	database string
	partner  string
	columns  []columnStruct
	tranid   uint64
	// logFlags is copied from the connection parameters; presumably a
	// combination of the log* constants — TODO confirm.
	logFlags uint64
	log      optionalLogger
	// routedServer and routedPort, when routedServer is non-empty after
	// login, make connect close the connection and dial this address instead.
	routedServer string
	routedPort   uint16
}
133
// Logging categories. Each value is a distinct power of two, so several can
// be combined in one integer; presumably they are tested against
// tdsSession.logFlags — TODO confirm against the logging code.
const (
	logErrors      = 1
	logMessages    = 2
	logRows        = 4
	logSQL         = 8
	logParams      = 16
	logTransaction = 32
	logDebug       = 64
)
143
// columnStruct describes one result column: its user type, flags, name and
// type information. tdsSession keeps a slice of these in its columns field.
// NOTE(review): presumably filled from the server's column metadata token —
// the code that populates it is not in this file.
type columnStruct struct {
	UserType uint32
	Flags    uint16
	ColName  string
	ti       typeInfo
}
150
// keySlice is a list of prelogin option tokens that implements
// sort.Interface, ordering the tokens by ascending numeric value.
type keySlice []uint8

// Len returns the number of tokens.
func (ks keySlice) Len() int {
	return len(ks)
}

// Less reports whether the token at index a is smaller than the one at b.
func (ks keySlice) Less(a, b int) bool {
	return ks[a] < ks[b]
}

// Swap exchanges the tokens at indexes a and b.
func (ks keySlice) Swap(a, b int) {
	ks[b], ks[a] = ks[a], ks[b]
}
156
157// http://msdn.microsoft.com/en-us/library/dd357559.aspx
158func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error {
159	var err error
160
161	w.BeginPacket(packPrelogin, false)
162	offset := uint16(5*len(fields) + 1)
163	keys := make(keySlice, 0, len(fields))
164	for k, _ := range fields {
165		keys = append(keys, k)
166	}
167	sort.Sort(keys)
168	// writing header
169	for _, k := range keys {
170		err = w.WriteByte(k)
171		if err != nil {
172			return err
173		}
174		err = binary.Write(w, binary.BigEndian, offset)
175		if err != nil {
176			return err
177		}
178		v := fields[k]
179		size := uint16(len(v))
180		err = binary.Write(w, binary.BigEndian, size)
181		if err != nil {
182			return err
183		}
184		offset += size
185	}
186	err = w.WriteByte(preloginTERMINATOR)
187	if err != nil {
188		return err
189	}
190	// writing values
191	for _, k := range keys {
192		v := fields[k]
193		written, err := w.Write(v)
194		if err != nil {
195			return err
196		}
197		if written != len(v) {
198			return errors.New("Write method didn't write the whole value")
199		}
200	}
201	return w.FinishPacket()
202}
203
204func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) {
205	packet_type, err := r.BeginRead()
206	if err != nil {
207		return nil, err
208	}
209	struct_buf, err := ioutil.ReadAll(r)
210	if err != nil {
211		return nil, err
212	}
213	if packet_type != 4 {
214		return nil, errors.New("Invalid respones, expected packet type 4, PRELOGIN RESPONSE")
215	}
216	offset := 0
217	results := map[uint8][]byte{}
218	for true {
219		rec_type := struct_buf[offset]
220		if rec_type == preloginTERMINATOR {
221			break
222		}
223
224		rec_offset := binary.BigEndian.Uint16(struct_buf[offset+1:])
225		rec_len := binary.BigEndian.Uint16(struct_buf[offset+3:])
226		value := struct_buf[rec_offset : rec_offset+rec_len]
227		results[rec_type] = value
228		offset += 5
229	}
230	return results, nil
231}
232
// Bits of the OptionFlags2 byte of the login request.
// http://msdn.microsoft.com/en-us/library/dd304019.aspx
const (
	fLanguageFatal = 1
	fODBC          = 2 // always set by connect, "to get unlimited TEXTSIZE"
	fTransBoundary = 4
	fCacheConnect  = 8
	fIntSecurity   = 0x80 // set by connect when SSPI authentication is used
)
242
// Bits of the TypeFlags byte of the login request. Only the read-only
// intent bit is named here.
const (
	// 4 bits for fSQLType
	// 1 bit for fOLEDB
	fReadOnlyIntent = 32
)
249
// Bits of the OptionFlags3 byte of the login request.
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac
const (
	// fExtension is set by sendLogin when a feature extension block is
	// appended to the login packet.
	fExtension = 0x10
)
255
// login carries the values that sendLogin encodes into a LOGIN7 packet.
// The numeric fields and ClientID are copied into the fixed header
// unchanged. The string fields are converted with str2ucs2 and written
// after the header; Password is additionally obfuscated by manglePassword.
type login struct {
	TDSVersion     uint32
	PacketSize     uint32
	ClientProgVer  uint32
	ClientPID      uint32
	ConnectionID   uint32
	OptionFlags1   uint8
	OptionFlags2   uint8
	TypeFlags      uint8
	OptionFlags3   uint8
	ClientTimeZone int32
	ClientLCID     uint32
	HostName       string
	UserName       string
	Password       string
	AppName        string
	ServerName     string
	CtlIntName     string
	Language       string
	Database       string
	ClientID       [6]byte
	// SSPI is written to the packet as raw bytes, without conversion.
	SSPI           []byte
	AtchDBFile     string
	ChangePassword string
	// FeatureExt, when it contains at least one feature, is serialized at
	// the end of the packet.
	FeatureExt featureExts
}
282
// featureExts is the set of feature extensions attached to a login request,
// holding at most one extension per feature ID. The zero value is an empty,
// ready-to-use set.
type featureExts struct {
	features map[byte]featureExt
}

// featureExt is one feature extension: an ID plus its serialized data.
type featureExt interface {
	featureID() byte
	toBytes() []byte
}

// Add registers f in the set. A nil f is ignored. An error is returned, and
// the set is left unchanged, when an extension with the same feature ID is
// already present.
func (e *featureExts) Add(f featureExt) error {
	if f == nil {
		return nil
	}
	id := f.featureID()
	if _, exists := e.features[id]; exists {
		return fmt.Errorf("Login error: Feature with ID '%v' is already present in FeatureExt block.", id)
	}
	if e.features == nil {
		e.features = make(map[byte]featureExt)
	}
	e.features[id] = f
	return nil
}

// toBytes serializes the set as a FeatureExt block: for every feature the ID
// byte, the little-endian uint32 length of its data and the data itself,
// followed by a single 0xff terminator. It returns nil for an empty set.
//
// Features are emitted in ascending ID order so that the same set always
// produces the same bytes; iterating the map directly would make the order
// vary from call to call.
func (e featureExts) toBytes() []byte {
	if len(e.features) == 0 {
		return nil
	}

	ids := make([]int, 0, len(e.features))
	for id := range e.features {
		ids = append(ids, int(id))
	}
	sort.Ints(ids)

	var d []byte
	for _, id := range ids {
		featureData := e.features[byte(id)].toBytes()

		var hdr [5]byte
		hdr[0] = byte(id)                                                // FeatureId BYTE
		binary.LittleEndian.PutUint32(hdr[1:], uint32(len(featureData))) // FeatureDataLen DWORD
		d = append(d, hdr[:]...)

		d = append(d, featureData...) // FeatureData *BYTE
	}
	return append(d, 0xff) // Terminator
}
328
329type featureExtFedAuthSTS struct {
330	FedAuthEcho  bool
331	FedAuthToken string
332	Nonce        []byte
333}
334
335func (e *featureExtFedAuthSTS) featureID() byte {
336	return 0x02
337}
338
339func (e *featureExtFedAuthSTS) toBytes() []byte {
340	if e == nil {
341		return nil
342	}
343
344	options := byte(0x01) << 1 // 0x01 => STS bFedAuthLibrary 7BIT
345	if e.FedAuthEcho {
346		options |= 1 // fFedAuthEcho
347	}
348
349	d := make([]byte, 5)
350	d[0] = options
351
352	// looks like string in
353	// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/f88b63bb-b479-49e1-a87b-deda521da508
354	tokenBytes := str2ucs2(e.FedAuthToken)
355	binary.LittleEndian.PutUint32(d[1:], uint32(len(tokenBytes))) // Should be a signed int32, but since the length is relatively small, this should work
356	d = append(d, tokenBytes...)
357
358	if len(e.Nonce) == 32 {
359		d = append(d, e.Nonce...)
360	}
361
362	return d
363}
364
// loginHeader is the fixed-size leading part of a LOGIN7 packet. sendLogin
// writes it with binary.Write in little-endian order, so the field order
// and field sizes define the wire layout and must not be changed.
//
// Each Offset field is a byte position counted from the start of this
// structure; sendLogin places the first variable-length value at
// binary.Size of the header. Length is the size of the whole login message,
// including any feature extension data.
type loginHeader struct {
	Length               uint32
	TDSVersion           uint32
	PacketSize           uint32
	ClientProgVer        uint32
	ClientPID            uint32
	ConnectionID         uint32
	OptionFlags1         uint8
	OptionFlags2         uint8
	TypeFlags            uint8
	OptionFlags3         uint8
	ClientTimeZone       int32
	ClientLCID           uint32
	HostNameOffset       uint16
	HostNameLength       uint16
	UserNameOffset       uint16
	UserNameLength       uint16
	PasswordOffset       uint16
	PasswordLength       uint16
	AppNameOffset        uint16
	AppNameLength        uint16
	ServerNameOffset     uint16
	ServerNameLength     uint16
	ExtensionOffset      uint16
	ExtensionLength      uint16
	CtlIntNameOffset     uint16
	CtlIntNameLength     uint16
	LanguageOffset       uint16
	LanguageLength       uint16
	DatabaseOffset       uint16
	DatabaseLength       uint16
	ClientID             [6]byte
	SSPIOffset           uint16
	SSPILength           uint16
	AtchDBFileOffset     uint16
	AtchDBFileLength     uint16
	ChangePasswordOffset uint16
	ChangePasswordLength uint16
	SSPILongLength       uint32
}
405
// str2ucs2 encodes s as UTF-16 and returns the code units as little-endian
// byte pairs. Characters outside the Basic Multilingual Plane become
// surrogate pairs and therefore occupy four bytes. The result is never nil.
//
// The bytes are assembled by hand instead of going through the bytes and
// binary packages, for performance reasons.
func str2ucs2(s string) []byte {
	units := utf16.Encode([]rune(s))
	out := make([]byte, 0, 2*len(units))
	for _, u := range units {
		out = append(out, byte(u), byte(u>>8))
	}
	return out
}
418
// ucs22str decodes s, a sequence of little-endian UTF-16 code units, into a
// Go string. It returns an error when s has an odd number of bytes.
func ucs22str(s []byte) (string, error) {
	if len(s)&1 == 1 {
		return "", fmt.Errorf("Illegal UCS2 string length: %d", len(s))
	}
	units := make([]uint16, 0, len(s)/2)
	// The length is even, so consuming two bytes at a time ends exactly at
	// the end of the input.
	for rest := s; len(rest) > 0; rest = rest[2:] {
		units = append(units, binary.LittleEndian.Uint16(rest))
	}
	return string(utf16.Decode(units)), nil
}
429
430func manglePassword(password string) []byte {
431	var ucs2password []byte = str2ucs2(password)
432	for i, ch := range ucs2password {
433		ucs2password[i] = ((ch<<4)&0xff | (ch >> 4)) ^ 0xA5
434	}
435	return ucs2password
436}
437
438// http://msdn.microsoft.com/en-us/library/dd304019.aspx
439func sendLogin(w *tdsBuffer, login login) error {
440	w.BeginPacket(packLogin7, false)
441	hostname := str2ucs2(login.HostName)
442	username := str2ucs2(login.UserName)
443	password := manglePassword(login.Password)
444	appname := str2ucs2(login.AppName)
445	servername := str2ucs2(login.ServerName)
446	ctlintname := str2ucs2(login.CtlIntName)
447	language := str2ucs2(login.Language)
448	database := str2ucs2(login.Database)
449	atchdbfile := str2ucs2(login.AtchDBFile)
450	changepassword := str2ucs2(login.ChangePassword)
451	featureExt := login.FeatureExt.toBytes()
452
453	hdr := loginHeader{
454		TDSVersion:           login.TDSVersion,
455		PacketSize:           login.PacketSize,
456		ClientProgVer:        login.ClientProgVer,
457		ClientPID:            login.ClientPID,
458		ConnectionID:         login.ConnectionID,
459		OptionFlags1:         login.OptionFlags1,
460		OptionFlags2:         login.OptionFlags2,
461		TypeFlags:            login.TypeFlags,
462		OptionFlags3:         login.OptionFlags3,
463		ClientTimeZone:       login.ClientTimeZone,
464		ClientLCID:           login.ClientLCID,
465		HostNameLength:       uint16(utf8.RuneCountInString(login.HostName)),
466		UserNameLength:       uint16(utf8.RuneCountInString(login.UserName)),
467		PasswordLength:       uint16(utf8.RuneCountInString(login.Password)),
468		AppNameLength:        uint16(utf8.RuneCountInString(login.AppName)),
469		ServerNameLength:     uint16(utf8.RuneCountInString(login.ServerName)),
470		CtlIntNameLength:     uint16(utf8.RuneCountInString(login.CtlIntName)),
471		LanguageLength:       uint16(utf8.RuneCountInString(login.Language)),
472		DatabaseLength:       uint16(utf8.RuneCountInString(login.Database)),
473		ClientID:             login.ClientID,
474		SSPILength:           uint16(len(login.SSPI)),
475		AtchDBFileLength:     uint16(utf8.RuneCountInString(login.AtchDBFile)),
476		ChangePasswordLength: uint16(utf8.RuneCountInString(login.ChangePassword)),
477	}
478	offset := uint16(binary.Size(hdr))
479	hdr.HostNameOffset = offset
480	offset += uint16(len(hostname))
481	hdr.UserNameOffset = offset
482	offset += uint16(len(username))
483	hdr.PasswordOffset = offset
484	offset += uint16(len(password))
485	hdr.AppNameOffset = offset
486	offset += uint16(len(appname))
487	hdr.ServerNameOffset = offset
488	offset += uint16(len(servername))
489	hdr.CtlIntNameOffset = offset
490	offset += uint16(len(ctlintname))
491	hdr.LanguageOffset = offset
492	offset += uint16(len(language))
493	hdr.DatabaseOffset = offset
494	offset += uint16(len(database))
495	hdr.SSPIOffset = offset
496	offset += uint16(len(login.SSPI))
497	hdr.AtchDBFileOffset = offset
498	offset += uint16(len(atchdbfile))
499	hdr.ChangePasswordOffset = offset
500	offset += uint16(len(changepassword))
501
502	featureExtOffset := uint32(0)
503	featureExtLen := len(featureExt)
504	if featureExtLen > 0 {
505		hdr.OptionFlags3 |= fExtension
506		hdr.ExtensionOffset = offset
507		hdr.ExtensionLength = 4
508		offset += hdr.ExtensionLength // DWORD
509		featureExtOffset = uint32(offset)
510	}
511	hdr.Length = uint32(offset) + uint32(featureExtLen)
512
513	var err error
514	err = binary.Write(w, binary.LittleEndian, &hdr)
515	if err != nil {
516		return err
517	}
518	_, err = w.Write(hostname)
519	if err != nil {
520		return err
521	}
522	_, err = w.Write(username)
523	if err != nil {
524		return err
525	}
526	_, err = w.Write(password)
527	if err != nil {
528		return err
529	}
530	_, err = w.Write(appname)
531	if err != nil {
532		return err
533	}
534	_, err = w.Write(servername)
535	if err != nil {
536		return err
537	}
538	_, err = w.Write(ctlintname)
539	if err != nil {
540		return err
541	}
542	_, err = w.Write(language)
543	if err != nil {
544		return err
545	}
546	_, err = w.Write(database)
547	if err != nil {
548		return err
549	}
550	_, err = w.Write(login.SSPI)
551	if err != nil {
552		return err
553	}
554	_, err = w.Write(atchdbfile)
555	if err != nil {
556		return err
557	}
558	_, err = w.Write(changepassword)
559	if err != nil {
560		return err
561	}
562	if featureExtOffset > 0 {
563		err = binary.Write(w, binary.LittleEndian, featureExtOffset)
564		if err != nil {
565			return err
566		}
567		_, err = w.Write(featureExt)
568		if err != nil {
569			return err
570		}
571	}
572	return w.FinishPacket()
573}
574
575func readUcs2(r io.Reader, numchars int) (res string, err error) {
576	buf := make([]byte, numchars*2)
577	_, err = io.ReadFull(r, buf)
578	if err != nil {
579		return "", err
580	}
581	return ucs22str(buf)
582}
583
584func readUsVarChar(r io.Reader) (res string, err error) {
585	numchars, err := readUshort(r)
586	if err != nil {
587		return
588	}
589	return readUcs2(r, int(numchars))
590}
591
592func writeUsVarChar(w io.Writer, s string) (err error) {
593	buf := str2ucs2(s)
594	var numchars int = len(buf) / 2
595	if numchars > 0xffff {
596		panic("invalid size for US_VARCHAR")
597	}
598	err = binary.Write(w, binary.LittleEndian, uint16(numchars))
599	if err != nil {
600		return
601	}
602	_, err = w.Write(buf)
603	return
604}
605
606func readBVarChar(r io.Reader) (res string, err error) {
607	numchars, err := readByte(r)
608	if err != nil {
609		return "", err
610	}
611
612	// A zero length could be returned, return an empty string
613	if numchars == 0 {
614		return "", nil
615	}
616	return readUcs2(r, int(numchars))
617}
618
619func writeBVarChar(w io.Writer, s string) (err error) {
620	buf := str2ucs2(s)
621	var numchars int = len(buf) / 2
622	if numchars > 0xff {
623		panic("invalid size for B_VARCHAR")
624	}
625	err = binary.Write(w, binary.LittleEndian, uint8(numchars))
626	if err != nil {
627		return
628	}
629	_, err = w.Write(buf)
630	return
631}
632
633func readBVarByte(r io.Reader) (res []byte, err error) {
634	length, err := readByte(r)
635	if err != nil {
636		return
637	}
638	res = make([]byte, length)
639	_, err = io.ReadFull(r, res)
640	return
641}
642
// readUshort reads a little-endian uint16 from r. It returns io.EOF when no
// bytes are available and io.ErrUnexpectedEOF when only one byte could be
// read; in both cases the value is 0.
func readUshort(r io.Reader) (res uint16, err error) {
	var raw [2]byte
	if _, err = io.ReadFull(r, raw[:]); err != nil {
		return 0, err
	}
	return binary.LittleEndian.Uint16(raw[:]), nil
}
647
// readByte reads exactly one byte from r.
//
// io.ReadFull is used instead of a single Read call because an io.Reader is
// allowed to return zero bytes with a nil error, or to return the final
// byte together with io.EOF; a bare Read would report a bogus zero byte in
// the first case and a spurious error in the second. When no byte is
// available the result is 0 and the error is io.EOF.
func readByte(r io.Reader) (res byte, err error) {
	var b [1]byte
	if _, err = io.ReadFull(r, b[:]); err != nil {
		return 0, err
	}
	return b[0], nil
}
654
// Packet Data Stream Headers
// http://msdn.microsoft.com/en-us/library/dd304953.aspx
//
// headerStruct is one data stream header as written by writeAllHeaders: a
// header type followed by its already-serialized payload.
type headerStruct struct {
	hdrtype uint16 // written as a little-endian uint16; presumably one of the dataStmHdr* constants
	data    []byte // payload bytes, written as is after the type
}
661
// Data stream header types.
// NOTE(review): presumably the values stored in headerStruct.hdrtype —
// confirm against the callers that build the headers.
const (
	dataStmHdrQueryNotif    = 1 // query notifications
	dataStmHdrTransDescr    = 2 // MARS transaction descriptor (required)
	dataStmHdrTraceActivity = 3
)
667
668// Query Notifications Header
669// http://msdn.microsoft.com/en-us/library/dd304949.aspx
670type queryNotifHdr struct {
671	notifyId      string
672	ssbDeployment string
673	notifyTimeout uint32
674}
675
676func (hdr queryNotifHdr) pack() (res []byte) {
677	notifyId := str2ucs2(hdr.notifyId)
678	ssbDeployment := str2ucs2(hdr.ssbDeployment)
679
680	res = make([]byte, 2+len(notifyId)+2+len(ssbDeployment)+4)
681	b := res
682
683	binary.LittleEndian.PutUint16(b, uint16(len(notifyId)))
684	b = b[2:]
685	copy(b, notifyId)
686	b = b[len(notifyId):]
687
688	binary.LittleEndian.PutUint16(b, uint16(len(ssbDeployment)))
689	b = b[2:]
690	copy(b, ssbDeployment)
691	b = b[len(ssbDeployment):]
692
693	binary.LittleEndian.PutUint32(b, hdr.notifyTimeout)
694
695	return res
696}
697
// MARS Transaction Descriptor Header
// http://msdn.microsoft.com/en-us/library/dd340515.aspx
//
// transDescrHdr holds the values of a transaction descriptor header.
type transDescrHdr struct {
	transDescr        uint64 // transaction descriptor returned from ENVCHANGE
	outstandingReqCnt uint32 // outstanding request count
}

// pack serializes the header into 12 bytes: the transaction descriptor as a
// little-endian uint64 followed by the outstanding request count as a
// little-endian uint32.
func (hdr transDescrHdr) pack() (res []byte) {
	var packed [8 + 4]byte
	binary.LittleEndian.PutUint64(packed[:8], hdr.transDescr)
	binary.LittleEndian.PutUint32(packed[8:], hdr.outstandingReqCnt)
	return packed[:]
}
711
712func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
713	// Calculating total length.
714	var totallen uint32 = 4
715	for _, hdr := range headers {
716		totallen += 4 + 2 + uint32(len(hdr.data))
717	}
718	// writing
719	err = binary.Write(w, binary.LittleEndian, totallen)
720	if err != nil {
721		return err
722	}
723	for _, hdr := range headers {
724		var headerlen uint32 = 4 + 2 + uint32(len(hdr.data))
725		err = binary.Write(w, binary.LittleEndian, headerlen)
726		if err != nil {
727			return err
728		}
729		err = binary.Write(w, binary.LittleEndian, hdr.hdrtype)
730		if err != nil {
731			return err
732		}
733		_, err = w.Write(hdr.data)
734		if err != nil {
735			return err
736		}
737	}
738	return nil
739}
740
741func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct, resetSession bool) (err error) {
742	buf.BeginPacket(packSQLBatch, resetSession)
743
744	if err = writeAllHeaders(buf, headers); err != nil {
745		return
746	}
747
748	_, err = buf.Write(str2ucs2(sqltext))
749	if err != nil {
750		return
751	}
752	return buf.FinishPacket()
753}
754
// sendAttention sends an attention packet, which has the packAttention type
// and an empty body, on buf and returns the error from finishing the packet.
//
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
func sendAttention(buf *tdsBuffer) error {
	buf.BeginPacket(packAttention, false)
	return buf.FinishPacket()
}
761
// auth is the SSPI-style authentication provider used by connect during
// login.
type auth interface {
	// InitialBytes returns the data that connect places in the SSPI field
	// of the login request.
	InitialBytes() ([]byte, error)
	// NextBytes is given an SSPI message received from the server and
	// returns the reply; connect sends the reply only when it is non-empty.
	NextBytes([]byte) ([]byte, error)
	// Free is deferred by connect once InitialBytes has succeeded;
	// presumably it releases resources held by the provider.
	Free()
}
767
768// SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
769// list of IP addresses.  So if there is more than one, try them all and
770// use the first one that allows a connection.
771func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn net.Conn, err error) {
772	var ips []net.IP
773	ips, err = net.LookupIP(p.host)
774	if err != nil {
775		ip := net.ParseIP(p.host)
776		if ip == nil {
777			return nil, err
778		}
779		ips = []net.IP{ip}
780	}
781	if len(ips) == 1 {
782		d := c.getDialer(&p)
783		addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(resolveServerPort(p.port))))
784		conn, err = d.DialContext(ctx, "tcp", addr)
785
786	} else {
787		//Try Dials in parallel to avoid waiting for timeouts.
788		connChan := make(chan net.Conn, len(ips))
789		errChan := make(chan error, len(ips))
790		portStr := strconv.Itoa(int(resolveServerPort(p.port)))
791		for _, ip := range ips {
792			go func(ip net.IP) {
793				d := c.getDialer(&p)
794				addr := net.JoinHostPort(ip.String(), portStr)
795				conn, err := d.DialContext(ctx, "tcp", addr)
796				if err == nil {
797					connChan <- conn
798				} else {
799					errChan <- err
800				}
801			}(ip)
802		}
803		// Wait for either the *first* successful connection, or all the errors
804	wait_loop:
805		for i, _ := range ips {
806			select {
807			case conn = <-connChan:
808				// Got a connection to use, close any others
809				go func(n int) {
810					for i := 0; i < n; i++ {
811						select {
812						case conn := <-connChan:
813							conn.Close()
814						case <-errChan:
815						}
816					}
817				}(len(ips) - i - 1)
818				// Remove any earlier errors we may have collected
819				err = nil
820				break wait_loop
821			case err = <-errChan:
822			}
823		}
824	}
825	// Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection
826	if conn == nil {
827		f := "Unable to open tcp connection with host '%v:%v': %v"
828		return nil, fmt.Errorf(f, p.host, resolveServerPort(p.port), err.Error())
829	}
830	return conn, err
831}
832
833func connect(ctx context.Context, c *Connector, log optionalLogger, p connectParams) (res *tdsSession, err error) {
834	dialCtx := ctx
835	if p.dial_timeout > 0 {
836		var cancel func()
837		dialCtx, cancel = context.WithTimeout(ctx, p.dial_timeout)
838		defer cancel()
839	}
840	// if instance is specified use instance resolution service
841	if p.instance != "" && p.port == 0 {
842		p.instance = strings.ToUpper(p.instance)
843		d := c.getDialer(&p)
844		instances, err := getInstances(dialCtx, d, p.host)
845		if err != nil {
846			f := "Unable to get instances from Sql Server Browser on host %v: %v"
847			return nil, fmt.Errorf(f, p.host, err.Error())
848		}
849		strport, ok := instances[p.instance]["tcp"]
850		if !ok {
851			f := "No instance matching '%v' returned from host '%v'"
852			return nil, fmt.Errorf(f, p.instance, p.host)
853		}
854		port, err := strconv.ParseUint(strport, 0, 16)
855		if err != nil {
856			f := "Invalid tcp port returned from Sql Server Browser '%v': %v"
857			return nil, fmt.Errorf(f, strport, err.Error())
858		}
859		p.port = port
860	}
861
862initiate_connection:
863	conn, err := dialConnection(dialCtx, c, p)
864	if err != nil {
865		return nil, err
866	}
867
868	toconn := newTimeoutConn(conn, p.conn_timeout)
869
870	outbuf := newTdsBuffer(p.packetSize, toconn)
871	sess := tdsSession{
872		buf:      outbuf,
873		log:      log,
874		logFlags: p.logFlags,
875	}
876
877	instance_buf := []byte(p.instance)
878	instance_buf = append(instance_buf, 0) // zero terminate instance name
879	var encrypt byte
880	if p.disableEncryption {
881		encrypt = encryptNotSup
882	} else if p.encrypt {
883		encrypt = encryptOn
884	} else {
885		encrypt = encryptOff
886	}
887	fields := map[uint8][]byte{
888		preloginVERSION:    {0, 0, 0, 0, 0, 0},
889		preloginENCRYPTION: {encrypt},
890		preloginINSTOPT:    instance_buf,
891		preloginTHREADID:   {0, 0, 0, 0},
892		preloginMARS:       {0}, // MARS disabled
893	}
894
895	err = writePrelogin(outbuf, fields)
896	if err != nil {
897		return nil, err
898	}
899
900	fields, err = readPrelogin(outbuf)
901	if err != nil {
902		return nil, err
903	}
904
905	encryptBytes, ok := fields[preloginENCRYPTION]
906	if !ok {
907		return nil, fmt.Errorf("Encrypt negotiation failed")
908	}
909	encrypt = encryptBytes[0]
910	if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) {
911		return nil, fmt.Errorf("Server does not support encryption")
912	}
913
914	if encrypt != encryptNotSup {
915		var config tls.Config
916		if p.certificate != "" {
917			pem, err := ioutil.ReadFile(p.certificate)
918			if err != nil {
919				return nil, fmt.Errorf("Cannot read certificate %q: %v", p.certificate, err)
920			}
921			certs := x509.NewCertPool()
922			certs.AppendCertsFromPEM(pem)
923			config.RootCAs = certs
924		}
925		if p.trustServerCertificate {
926			config.InsecureSkipVerify = true
927		}
928		config.ServerName = p.hostInCertificate
929		// fix for https://github.com/denisenkom/go-mssqldb/issues/166
930		// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
931		// while SQL Server seems to expect one TCP segment per encrypted TDS package.
932		// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
933		config.DynamicRecordSizingDisabled = true
934		// setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream
935		handshakeConn := tlsHandshakeConn{buf: outbuf}
936		passthrough := passthroughConn{c: &handshakeConn}
937		tlsConn := tls.Client(&passthrough, &config)
938		err = tlsConn.Handshake()
939		passthrough.c = toconn
940		outbuf.transport = tlsConn
941		if err != nil {
942			return nil, fmt.Errorf("TLS Handshake failed: %v", err)
943		}
944		if encrypt == encryptOff {
945			outbuf.afterFirst = func() {
946				outbuf.transport = toconn
947			}
948		}
949	}
950
951	login := login{
952		TDSVersion:   verTDS74,
953		PacketSize:   uint32(outbuf.PackageSize()),
954		Database:     p.database,
955		OptionFlags2: fODBC, // to get unlimited TEXTSIZE
956		HostName:     p.workstation,
957		ServerName:   p.host,
958		AppName:      p.appname,
959		TypeFlags:    p.typeFlags,
960	}
961	auth, authOk := getAuth(p.user, p.password, p.serverSPN, p.workstation)
962	switch {
963	case p.fedAuthAccessToken != "": // accesstoken ignores user/password
964		featurext := &featureExtFedAuthSTS{
965			FedAuthEcho:  len(fields[preloginFEDAUTHREQUIRED]) > 0 && fields[preloginFEDAUTHREQUIRED][0] == 1,
966			FedAuthToken: p.fedAuthAccessToken,
967			Nonce:        fields[preloginNONCEOPT],
968		}
969		login.FeatureExt.Add(featurext)
970	case authOk:
971		login.SSPI, err = auth.InitialBytes()
972		if err != nil {
973			return nil, err
974		}
975		login.OptionFlags2 |= fIntSecurity
976		defer auth.Free()
977	default:
978		login.UserName = p.user
979		login.Password = p.password
980	}
981	err = sendLogin(outbuf, login)
982	if err != nil {
983		return nil, err
984	}
985
986	// processing login response
987	success := false
988	for {
989		tokchan := make(chan tokenStruct, 5)
990		go processResponse(context.Background(), &sess, tokchan, nil)
991		for tok := range tokchan {
992			switch token := tok.(type) {
993			case sspiMsg:
994				sspi_msg, err := auth.NextBytes(token)
995				if err != nil {
996					return nil, err
997				}
998				if sspi_msg != nil && len(sspi_msg) > 0 {
999					outbuf.BeginPacket(packSSPIMessage, false)
1000					_, err = outbuf.Write(sspi_msg)
1001					if err != nil {
1002						return nil, err
1003					}
1004					err = outbuf.FinishPacket()
1005					if err != nil {
1006						return nil, err
1007					}
1008					sspi_msg = nil
1009				}
1010			case loginAckStruct:
1011				success = true
1012				sess.loginAck = token
1013			case error:
1014				return nil, fmt.Errorf("Login error: %s", token.Error())
1015			case doneStruct:
1016				if token.isError() {
1017					return nil, fmt.Errorf("Login error: %s", token.getError())
1018				}
1019				goto loginEnd
1020			}
1021		}
1022	}
1023loginEnd:
1024	if !success {
1025		return nil, fmt.Errorf("Login failed")
1026	}
1027	if sess.routedServer != "" {
1028		toconn.Close()
1029		p.host = sess.routedServer
1030		p.port = uint64(sess.routedPort)
1031		if !p.hostInCertificateProvided {
1032			p.hostInCertificate = sess.routedServer
1033		}
1034		goto initiate_connection
1035	}
1036	return &sess, nil
1037}
1038