1package mssql
2
3import (
4	"context"
5	"crypto/tls"
6	"crypto/x509"
7	"encoding/binary"
8	"errors"
9	"fmt"
10	"io"
11	"io/ioutil"
12	"net"
13	"net/url"
14	"os"
15	"sort"
16	"strconv"
17	"strings"
18	"time"
19	"unicode"
20	"unicode/utf16"
21	"unicode/utf8"
22)
23
// parseInstances decodes an SQL Server Browser response.
//
// A valid message starts with the byte 5, followed by two more header bytes
// (skipped here; not validated), and then a ';'-separated list of
// alternating key and value tokens. An empty key token ends the current
// record, and a second empty key (";;" after a record) ends the list.
//
// The result maps each record's upper-cased "InstanceName" value to that
// record's key/value pairs. Messages that are too short or have a different
// leading byte yield an empty, non-nil map.
func parseInstances(msg []byte) map[string]map[string]string {
	results := map[string]map[string]string{}
	if len(msg) <= 3 || msg[0] != 5 {
		return results
	}

	record := map[string]string{}
	var key string
	wantValue := false
	for _, tok := range strings.Split(string(msg[3:]), ";") {
		if wantValue {
			// Values may legitimately be empty; only keys terminate records.
			record[key] = tok
			wantValue = false
			continue
		}
		if tok != "" {
			key = tok
			wantValue = true
			continue
		}
		// Empty key: close the current record, or stop if there is none.
		if len(record) == 0 {
			break
		}
		results[strings.ToUpper(record["InstanceName"])] = record
		record = map[string]string{}
	}
	return results
}
52
53func getInstances(ctx context.Context, d Dialer, address string) (map[string]map[string]string, error) {
54	maxTime := 5 * time.Second
55	ctx, cancel := context.WithTimeout(ctx, maxTime)
56	defer cancel()
57	conn, err := d.DialContext(ctx, "udp", address+":1434")
58	if err != nil {
59		return nil, err
60	}
61	defer conn.Close()
62	conn.SetDeadline(time.Now().Add(maxTime))
63	_, err = conn.Write([]byte{3})
64	if err != nil {
65		return nil, err
66	}
67	var resp = make([]byte, 16*1024-1)
68	read, err := conn.Read(resp)
69	if err != nil {
70		return nil, err
71	}
72	return parseInstances(resp[:read]), nil
73}
74
// TDS protocol version numbers (presumably used for login.TDSVersion — the
// assignment is outside this view).
const (
	verTDS70     = 0x70000000
	verTDS71     = 0x71000000
	verTDS71rev1 = 0x71000001
	verTDS72     = 0x72090002
	verTDS73A    = 0x730A0003
	verTDS73B    = 0x730B0003
	verTDS74     = 0x74000004

	// verTDS73 is an alias for revision A of TDS 7.3.
	verTDS73 = verTDS73A
)
86
87// packet types
88// https://msdn.microsoft.com/en-us/library/dd304214.aspx
89const (
90	packSQLBatch   packetType = 1
91	packRPCRequest            = 3
92	packReply                 = 4
93
94	// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
95	// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
96	packAttention = 6
97
98	packBulkLoadBCP = 7
99	packTransMgrReq = 14
100	packNormal      = 15
101	packLogin7      = 16
102	packSSPIMessage = 17
103	packPrelogin    = 18
104)
105
// PRELOGIN option tokens, numbered consecutively from zero, plus the
// terminator that closes the option table.
// http://msdn.microsoft.com/en-us/library/dd357559.aspx
const (
	preloginVERSION = iota
	preloginENCRYPTION
	preloginINSTOPT
	preloginTHREADID
	preloginMARS
	preloginTRACEID

	preloginTERMINATOR = 0xff
)
117
// Values of the PRELOGIN ENCRYPTION option (presumably the payload of
// preloginENCRYPTION — confirm at the use site).
const (
	encryptOff    = iota // Encryption is available but off.
	encryptOn            // Encryption is available and on.
	encryptNotSup        // Encryption is not available.
	encryptReq           // Encryption is required.
)
124
// tdsSession is the per-connection state of a TDS session.
//
// connect initializes buf (the packet buffer wrapping the network
// connection), log and logFlags. The remaining fields are presumably filled
// in while processing server responses (login acknowledgement, ENVCHANGE
// tokens, column metadata, routing). NOTE(review): confirm at their
// assignment sites, which are outside this view.
type tdsSession struct {
	buf          *tdsBuffer
	loginAck     loginAckStruct
	database     string
	partner      string
	columns      []columnStruct
	tranid       uint64
	logFlags     uint64
	log          optionalLogger
	routedServer string
	routedPort   uint16
}
137
// Bits of the logFlags bitmask (the "log" connection parameter); each flag
// is a distinct power of two so they can be OR-ed together.
const (
	logErrors = 1 << iota
	logMessages
	logRows
	logSQL
	logParams
	logTransaction
	logDebug
)
147
// columnStruct describes one result-set column as held in
// tdsSession.columns. NOTE(review): presumably decoded from a COLMETADATA
// token — UserType/Flags mirror its wire fields and ti is the column's type
// info; confirm at the parsing site, which is outside this view.
type columnStruct struct {
	UserType uint32
	Flags    uint16
	ColName  string
	ti       typeInfo
}
154
// keySlice implements sort.Interface over PRELOGIN option tokens, so that
// writePrelogin can emit its options in ascending token order.
type keySlice []uint8

// Len reports the number of tokens.
func (ks keySlice) Len() int { return len(ks) }

// Less orders tokens numerically.
func (ks keySlice) Less(a, b int) bool { return ks[a] < ks[b] }

// Swap exchanges two tokens in place.
func (ks keySlice) Swap(a, b int) { ks[a], ks[b] = ks[b], ks[a] }
160
161// http://msdn.microsoft.com/en-us/library/dd357559.aspx
162func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error {
163	var err error
164
165	w.BeginPacket(packPrelogin, false)
166	offset := uint16(5*len(fields) + 1)
167	keys := make(keySlice, 0, len(fields))
168	for k, _ := range fields {
169		keys = append(keys, k)
170	}
171	sort.Sort(keys)
172	// writing header
173	for _, k := range keys {
174		err = w.WriteByte(k)
175		if err != nil {
176			return err
177		}
178		err = binary.Write(w, binary.BigEndian, offset)
179		if err != nil {
180			return err
181		}
182		v := fields[k]
183		size := uint16(len(v))
184		err = binary.Write(w, binary.BigEndian, size)
185		if err != nil {
186			return err
187		}
188		offset += size
189	}
190	err = w.WriteByte(preloginTERMINATOR)
191	if err != nil {
192		return err
193	}
194	// writing values
195	for _, k := range keys {
196		v := fields[k]
197		written, err := w.Write(v)
198		if err != nil {
199			return err
200		}
201		if written != len(v) {
202			return errors.New("Write method didn't write the whole value")
203		}
204	}
205	return w.FinishPacket()
206}
207
208func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) {
209	packet_type, err := r.BeginRead()
210	if err != nil {
211		return nil, err
212	}
213	struct_buf, err := ioutil.ReadAll(r)
214	if err != nil {
215		return nil, err
216	}
217	if packet_type != 4 {
218		return nil, errors.New("Invalid respones, expected packet type 4, PRELOGIN RESPONSE")
219	}
220	offset := 0
221	results := map[uint8][]byte{}
222	for true {
223		rec_type := struct_buf[offset]
224		if rec_type == preloginTERMINATOR {
225			break
226		}
227
228		rec_offset := binary.BigEndian.Uint16(struct_buf[offset+1:])
229		rec_len := binary.BigEndian.Uint16(struct_buf[offset+3:])
230		value := struct_buf[rec_offset : rec_offset+rec_len]
231		results[rec_type] = value
232		offset += 5
233	}
234	return results, nil
235}
236
// Bits of the LOGIN7 OptionFlags2 byte.
// http://msdn.microsoft.com/en-us/library/dd304019.aspx
const (
	fLanguageFatal = 1 << 0
	fODBC          = 1 << 1
	fTransBoundary = 1 << 2
	fCacheConnect  = 1 << 3
	fIntSecurity   = 1 << 7 // bits 4-6 are not defined here
)
246
// Bits of the LOGIN7 TypeFlags byte.
const (
	// Bits 0-3 hold fSQLType and bit 4 holds fOLEDB (neither is defined
	// here), so the read-only application intent flag is bit 5.
	fReadOnlyIntent = 1 << 5
)
253
// login holds the values sendLogin encodes into a LOGIN7 packet.
//
// The numeric and flag fields are copied unchanged into loginHeader. String
// fields are sent as UTF-16LE (via str2ucs2). Password is obfuscated with
// manglePassword before being written. SSPI is written as raw bytes.
// ChangePassword is written as plain UTF-16LE by sendLogin (it is not passed
// through manglePassword there).
type login struct {
	TDSVersion     uint32
	PacketSize     uint32
	ClientProgVer  uint32
	ClientPID      uint32
	ConnectionID   uint32
	OptionFlags1   uint8
	OptionFlags2   uint8
	TypeFlags      uint8
	OptionFlags3   uint8
	ClientTimeZone int32
	ClientLCID     uint32
	HostName       string
	UserName       string
	Password       string
	AppName        string
	ServerName     string
	CtlIntName     string
	Language       string
	Database       string
	ClientID       [6]byte
	SSPI           []byte
	AtchDBFile     string
	ChangePassword string
}
279
// loginHeader is the fixed-size portion of a LOGIN7 packet.
//
// sendLogin serializes it with binary.Write in little-endian order, so the
// field order and sizes ARE the wire layout. Do not reorder, resize or rename
// fields in ways that change binary.Size. (ExtensionLenght keeps its
// historical misspelling for the same reason: renaming is harmless on the
// wire but may break other references.)
//
// Each *Offset is a byte offset from the start of the LOGIN7 payload
// (the data section starts at binary.Size(loginHeader)). The string *Length
// fields count characters, while SSPILength counts bytes. sendLogin never
// sets ExtensionOffset, ExtensionLenght or SSPILongLength, so they are sent
// as zero.
type loginHeader struct {
	Length               uint32
	TDSVersion           uint32
	PacketSize           uint32
	ClientProgVer        uint32
	ClientPID            uint32
	ConnectionID         uint32
	OptionFlags1         uint8
	OptionFlags2         uint8
	TypeFlags            uint8
	OptionFlags3         uint8
	ClientTimeZone       int32
	ClientLCID           uint32
	HostNameOffset       uint16
	HostNameLength       uint16
	UserNameOffset       uint16
	UserNameLength       uint16
	PasswordOffset       uint16
	PasswordLength       uint16
	AppNameOffset        uint16
	AppNameLength        uint16
	ServerNameOffset     uint16
	ServerNameLength     uint16
	ExtensionOffset      uint16
	ExtensionLenght      uint16
	CtlIntNameOffset     uint16
	CtlIntNameLength     uint16
	LanguageOffset       uint16
	LanguageLength       uint16
	DatabaseOffset       uint16
	DatabaseLength       uint16
	ClientID             [6]byte
	SSPIOffset           uint16
	SSPILength           uint16
	AtchDBFileOffset     uint16
	AtchDBFileLength     uint16
	ChangePasswordOffset uint16
	ChangePasswordLength uint16
	SSPILongLength       uint32
}
320
// str2ucs2 encodes s as UTF-16 little-endian bytes. Characters outside the
// BMP become surrogate pairs (four bytes). Each code unit is split into its
// low and high byte by hand rather than through encoding/binary.
// The result is never nil, even for an empty string.
func str2ucs2(s string) []byte {
	units := utf16.Encode([]rune(s))
	out := make([]byte, 0, 2*len(units))
	for _, u := range units {
		out = append(out, byte(u), byte(u>>8))
	}
	return out
}
333
// ucs22str decodes UTF-16 little-endian bytes into a Go string. Surrogate
// pairs are combined; unpaired surrogates become U+FFFD (utf16.Decode
// semantics). An odd byte count is rejected with an error.
func ucs22str(s []byte) (string, error) {
	if n := len(s); n%2 == 1 {
		return "", fmt.Errorf("Illegal UCS2 string length: %d", n)
	}
	units := make([]uint16, 0, len(s)/2)
	for rest := s; len(rest) >= 2; rest = rest[2:] {
		units = append(units, binary.LittleEndian.Uint16(rest))
	}
	return string(utf16.Decode(units)), nil
}
344
345func manglePassword(password string) []byte {
346	var ucs2password []byte = str2ucs2(password)
347	for i, ch := range ucs2password {
348		ucs2password[i] = ((ch<<4)&0xff | (ch >> 4)) ^ 0xA5
349	}
350	return ucs2password
351}
352
353// http://msdn.microsoft.com/en-us/library/dd304019.aspx
354func sendLogin(w *tdsBuffer, login login) error {
355	w.BeginPacket(packLogin7, false)
356	hostname := str2ucs2(login.HostName)
357	username := str2ucs2(login.UserName)
358	password := manglePassword(login.Password)
359	appname := str2ucs2(login.AppName)
360	servername := str2ucs2(login.ServerName)
361	ctlintname := str2ucs2(login.CtlIntName)
362	language := str2ucs2(login.Language)
363	database := str2ucs2(login.Database)
364	atchdbfile := str2ucs2(login.AtchDBFile)
365	changepassword := str2ucs2(login.ChangePassword)
366	hdr := loginHeader{
367		TDSVersion:           login.TDSVersion,
368		PacketSize:           login.PacketSize,
369		ClientProgVer:        login.ClientProgVer,
370		ClientPID:            login.ClientPID,
371		ConnectionID:         login.ConnectionID,
372		OptionFlags1:         login.OptionFlags1,
373		OptionFlags2:         login.OptionFlags2,
374		TypeFlags:            login.TypeFlags,
375		OptionFlags3:         login.OptionFlags3,
376		ClientTimeZone:       login.ClientTimeZone,
377		ClientLCID:           login.ClientLCID,
378		HostNameLength:       uint16(utf8.RuneCountInString(login.HostName)),
379		UserNameLength:       uint16(utf8.RuneCountInString(login.UserName)),
380		PasswordLength:       uint16(utf8.RuneCountInString(login.Password)),
381		AppNameLength:        uint16(utf8.RuneCountInString(login.AppName)),
382		ServerNameLength:     uint16(utf8.RuneCountInString(login.ServerName)),
383		CtlIntNameLength:     uint16(utf8.RuneCountInString(login.CtlIntName)),
384		LanguageLength:       uint16(utf8.RuneCountInString(login.Language)),
385		DatabaseLength:       uint16(utf8.RuneCountInString(login.Database)),
386		ClientID:             login.ClientID,
387		SSPILength:           uint16(len(login.SSPI)),
388		AtchDBFileLength:     uint16(utf8.RuneCountInString(login.AtchDBFile)),
389		ChangePasswordLength: uint16(utf8.RuneCountInString(login.ChangePassword)),
390	}
391	offset := uint16(binary.Size(hdr))
392	hdr.HostNameOffset = offset
393	offset += uint16(len(hostname))
394	hdr.UserNameOffset = offset
395	offset += uint16(len(username))
396	hdr.PasswordOffset = offset
397	offset += uint16(len(password))
398	hdr.AppNameOffset = offset
399	offset += uint16(len(appname))
400	hdr.ServerNameOffset = offset
401	offset += uint16(len(servername))
402	hdr.CtlIntNameOffset = offset
403	offset += uint16(len(ctlintname))
404	hdr.LanguageOffset = offset
405	offset += uint16(len(language))
406	hdr.DatabaseOffset = offset
407	offset += uint16(len(database))
408	hdr.SSPIOffset = offset
409	offset += uint16(len(login.SSPI))
410	hdr.AtchDBFileOffset = offset
411	offset += uint16(len(atchdbfile))
412	hdr.ChangePasswordOffset = offset
413	offset += uint16(len(changepassword))
414	hdr.Length = uint32(offset)
415	var err error
416	err = binary.Write(w, binary.LittleEndian, &hdr)
417	if err != nil {
418		return err
419	}
420	_, err = w.Write(hostname)
421	if err != nil {
422		return err
423	}
424	_, err = w.Write(username)
425	if err != nil {
426		return err
427	}
428	_, err = w.Write(password)
429	if err != nil {
430		return err
431	}
432	_, err = w.Write(appname)
433	if err != nil {
434		return err
435	}
436	_, err = w.Write(servername)
437	if err != nil {
438		return err
439	}
440	_, err = w.Write(ctlintname)
441	if err != nil {
442		return err
443	}
444	_, err = w.Write(language)
445	if err != nil {
446		return err
447	}
448	_, err = w.Write(database)
449	if err != nil {
450		return err
451	}
452	_, err = w.Write(login.SSPI)
453	if err != nil {
454		return err
455	}
456	_, err = w.Write(atchdbfile)
457	if err != nil {
458		return err
459	}
460	_, err = w.Write(changepassword)
461	if err != nil {
462		return err
463	}
464	return w.FinishPacket()
465}
466
467func readUcs2(r io.Reader, numchars int) (res string, err error) {
468	buf := make([]byte, numchars*2)
469	_, err = io.ReadFull(r, buf)
470	if err != nil {
471		return "", err
472	}
473	return ucs22str(buf)
474}
475
476func readUsVarChar(r io.Reader) (res string, err error) {
477	var numchars uint16
478	err = binary.Read(r, binary.LittleEndian, &numchars)
479	if err != nil {
480		return "", err
481	}
482	return readUcs2(r, int(numchars))
483}
484
485func writeUsVarChar(w io.Writer, s string) (err error) {
486	buf := str2ucs2(s)
487	var numchars int = len(buf) / 2
488	if numchars > 0xffff {
489		panic("invalid size for US_VARCHAR")
490	}
491	err = binary.Write(w, binary.LittleEndian, uint16(numchars))
492	if err != nil {
493		return
494	}
495	_, err = w.Write(buf)
496	return
497}
498
499func readBVarChar(r io.Reader) (res string, err error) {
500	var numchars uint8
501	err = binary.Read(r, binary.LittleEndian, &numchars)
502	if err != nil {
503		return "", err
504	}
505
506	// A zero length could be returned, return an empty string
507	if numchars == 0 {
508		return "", nil
509	}
510	return readUcs2(r, int(numchars))
511}
512
513func writeBVarChar(w io.Writer, s string) (err error) {
514	buf := str2ucs2(s)
515	var numchars int = len(buf) / 2
516	if numchars > 0xff {
517		panic("invalid size for B_VARCHAR")
518	}
519	err = binary.Write(w, binary.LittleEndian, uint8(numchars))
520	if err != nil {
521		return
522	}
523	_, err = w.Write(buf)
524	return
525}
526
// readBVarByte reads a B_VARBYTE: a one-byte length followed by that many raw
// bytes. On a short read the partially filled buffer is returned along with
// the io.ReadFull error.
func readBVarByte(r io.Reader) (res []byte, err error) {
	var n uint8
	if err = binary.Read(r, binary.LittleEndian, &n); err != nil {
		return nil, err
	}
	res = make([]byte, n)
	_, err = io.ReadFull(r, res)
	return res, err
}
537
// readUshort reads a little-endian uint16 from r. The errors match
// io.ReadFull: io.EOF if nothing was read, io.ErrUnexpectedEOF after a single
// byte. On error the value is 0.
func readUshort(r io.Reader) (res uint16, err error) {
	var raw [2]byte
	if _, err = io.ReadFull(r, raw[:]); err != nil {
		return 0, err
	}
	return binary.LittleEndian.Uint16(raw[:]), nil
}
542
// readByte reads exactly one byte from r.
//
// io.ReadFull is used instead of a single r.Read call because an io.Reader
// may legally return (0, nil), or return the final byte together with io.EOF.
// A bare Read would then yield a bogus zero byte or report an error alongside
// valid data. The error is io.EOF only when no byte was available, and the
// returned byte is 0 whenever err is non-nil.
func readByte(r io.Reader) (res byte, err error) {
	var b [1]byte
	if _, err = io.ReadFull(r, b[:]); err != nil {
		return 0, err
	}
	return b[0], nil
}
549
// Packet Data Stream Headers
// http://msdn.microsoft.com/en-us/library/dd304953.aspx
//
// headerStruct is one entry of the ALL_HEADERS rule written by
// writeAllHeaders: hdrtype is presumably one of the dataStmHdr* constants,
// and data is the already-serialized header body (e.g. the output of
// queryNotifHdr.pack or transDescrHdr.pack — confirm at the call sites).
type headerStruct struct {
	hdrtype uint16
	data    []byte
}
556
// Header types for the ALL_HEADERS rule (headerStruct.hdrtype), numbered
// from 1.
const (
	dataStmHdrQueryNotif    = iota + 1 // query notifications
	dataStmHdrTransDescr               // MARS transaction descriptor (required)
	dataStmHdrTraceActivity            // trace activity
)
562
// Query Notifications Header
// http://msdn.microsoft.com/en-us/library/dd304949.aspx
//
// queryNotifHdr is serialized by pack: notifyId and ssbDeployment are sent as
// UTF-16LE strings, each preceded by its uint16 byte length. notifyTimeout
// follows as a little-endian uint32.
type queryNotifHdr struct {
	notifyId      string
	ssbDeployment string
	notifyTimeout uint32
}
570
571func (hdr queryNotifHdr) pack() (res []byte) {
572	notifyId := str2ucs2(hdr.notifyId)
573	ssbDeployment := str2ucs2(hdr.ssbDeployment)
574
575	res = make([]byte, 2+len(notifyId)+2+len(ssbDeployment)+4)
576	b := res
577
578	binary.LittleEndian.PutUint16(b, uint16(len(notifyId)))
579	b = b[2:]
580	copy(b, notifyId)
581	b = b[len(notifyId):]
582
583	binary.LittleEndian.PutUint16(b, uint16(len(ssbDeployment)))
584	b = b[2:]
585	copy(b, ssbDeployment)
586	b = b[len(ssbDeployment):]
587
588	binary.LittleEndian.PutUint32(b, hdr.notifyTimeout)
589
590	return res
591}
592
// MARS Transaction Descriptor Header
// http://msdn.microsoft.com/en-us/library/dd340515.aspx
type transDescrHdr struct {
	transDescr        uint64 // transaction descriptor returned from ENVCHANGE
	outstandingReqCnt uint32 // outstanding request count
}

// pack serializes the header body as 12 little-endian bytes: the 8-byte
// transaction descriptor followed by the 4-byte outstanding request count.
func (hdr transDescrHdr) pack() (res []byte) {
	var out [12]byte
	binary.LittleEndian.PutUint64(out[0:8], hdr.transDescr)
	binary.LittleEndian.PutUint32(out[8:12], hdr.outstandingReqCnt)
	return out[:]
}
606
607func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
608	// Calculating total length.
609	var totallen uint32 = 4
610	for _, hdr := range headers {
611		totallen += 4 + 2 + uint32(len(hdr.data))
612	}
613	// writing
614	err = binary.Write(w, binary.LittleEndian, totallen)
615	if err != nil {
616		return err
617	}
618	for _, hdr := range headers {
619		var headerlen uint32 = 4 + 2 + uint32(len(hdr.data))
620		err = binary.Write(w, binary.LittleEndian, headerlen)
621		if err != nil {
622			return err
623		}
624		err = binary.Write(w, binary.LittleEndian, hdr.hdrtype)
625		if err != nil {
626			return err
627		}
628		_, err = w.Write(hdr.data)
629		if err != nil {
630			return err
631		}
632	}
633	return nil
634}
635
636func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct, resetSession bool) (err error) {
637	buf.BeginPacket(packSQLBatch, resetSession)
638
639	if err = writeAllHeaders(buf, headers); err != nil {
640		return
641	}
642
643	_, err = buf.Write(str2ucs2(sqltext))
644	if err != nil {
645		return
646	}
647	return buf.FinishPacket()
648}
649
// sendAttention sends an empty ATTENTION packet on buf. Per the MS-TDS
// sections below, this is the client's signal to cancel the request
// currently being processed.
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
func sendAttention(buf *tdsBuffer) error {
	buf.BeginPacket(packAttention, false)
	return buf.FinishPacket()
}
656
// connectParams holds the settings parsed from a DSN by parseConnectParams.
//
// Notable defaults applied there: port 1433, packetSize 4096 (clamped to
// 512..32767), dial_timeout 15s, keepAlive 30s, conn_timeout 0 (none),
// appname "go-mssqldb", workstation from os.Hostname, hostInCertificate = host
// and serverSPN "MSSQLSvc/host:port". When "encrypt" is absent,
// trustServerCertificate defaults to true. typeFlags may carry
// fReadOnlyIntent. A non-empty instance triggers SQL Server Browser port
// resolution in connect.
type connectParams struct {
	logFlags               uint64
	port                   uint64
	host                   string
	instance               string
	database               string
	user                   string
	password               string
	dial_timeout           time.Duration
	conn_timeout           time.Duration
	keepAlive              time.Duration
	encrypt                bool
	disableEncryption      bool
	trustServerCertificate bool
	certificate            string
	hostInCertificate      string
	serverSPN              string
	workstation            string
	appname                string
	typeFlags              uint8
	failOverPartner        string
	failOverPort           uint64
	packetSize             uint16
}
681
// splitConnectionString parses the plain "key=value;key=value" DSN syntax.
//
// Parts are split on ';' and each part is split on its first '='. Keys are
// lower-cased and trimmed, and values are trimmed. A part without '=' maps to
// "". Empty parts and parts with an empty key are skipped, and later
// duplicates overwrite earlier ones. No quoting or escaping is supported.
func splitConnectionString(dsn string) (res map[string]string) {
	res = map[string]string{}
	for _, part := range strings.Split(dsn, ";") {
		if part == "" {
			continue
		}
		rawName, value := part, ""
		if eq := strings.Index(part, "="); eq >= 0 {
			rawName, value = part[:eq], strings.TrimSpace(part[eq+1:])
		}
		name := strings.TrimSpace(strings.ToLower(rawName))
		if name == "" {
			continue
		}
		res[name] = value
	}
	return res
}
702
// splitConnectionStringOdbc splits an ODBC-style connection string (the part
// after the "odbc:" prefix) into a map of normalized keys to values, e.g.
// "server=host;password={a;b}}c}" yields server="host", password="a;b}c".
//
// Grammar, as implemented by the state machine below:
//   - pairs are separated by ';'; whitespace and empty pairs between them are
//     skipped, and '=' may not start a key;
//   - a key runs up to '=' or ';' and is normalized with normalizeOdbcKey;
//     a key with no '=' gets the empty value;
//   - a bare value runs up to the next ';' with surrounding whitespace removed;
//   - a value starting with '{' is braced: it may contain ';' and '=', "}}"
//     encodes a literal '}', and only whitespace or ';' may follow the closing
//     '}'.
//
// Later occurrences of a key overwrite earlier ones. On error, the pairs
// parsed so far are returned along with an error naming the byte index of
// the offending character.
func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
	res := map[string]string{}

	type parserState int
	const (
		// Before the start of a key
		parserStateBeforeKey parserState = iota

		// Inside a key
		parserStateKey

		// Beginning of a value. May be bare or braced
		parserStateBeginValue

		// Inside a bare value
		parserStateBareValue

		// Inside a braced value
		parserStateBracedValue

		// A closing brace inside a braced value.
		// May be the end of the value or an escaped closing brace, depending on the next character
		parserStateBracedValueClosingBrace

		// After a value. Next character should be a semicolon or whitespace.
		parserStateEndValue
	)

	var state = parserStateBeforeKey

	// key and value accumulate the pair currently being parsed; both are
	// reset to "" whenever a pair is stored.
	var key string
	var value string

	// i is a byte index into dsn (range over a string yields byte offsets).
	for i, c := range dsn {
		switch state {
		case parserStateBeforeKey:
			switch {
			case c == '=':
				return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i)
			case !unicode.IsSpace(c) && c != ';':
				state = parserStateKey
				key += string(c)
			}

		case parserStateKey:
			switch c {
			case '=':
				key = normalizeOdbcKey(key)
				if len(key) == 0 {
					return res, fmt.Errorf("Unexpected end of key at index %d.", i)
				}

				state = parserStateBeginValue

			case ';':
				// Key without value
				key = normalizeOdbcKey(key)
				if len(key) == 0 {
					return res, fmt.Errorf("Unexpected end of key at index %d.", i)
				}

				res[key] = value
				key = ""
				value = ""
				state = parserStateBeforeKey

			default:
				key += string(c)
			}

		case parserStateBeginValue:
			switch {
			case c == '{':
				state = parserStateBracedValue
			case c == ';':
				// Empty value
				res[key] = value
				key = ""
				state = parserStateBeforeKey
			case unicode.IsSpace(c):
				// Ignore whitespace
			default:
				state = parserStateBareValue
				value += string(c)
			}

		case parserStateBareValue:
			if c == ';' {
				res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
				key = ""
				value = ""
				state = parserStateBeforeKey
			} else {
				value += string(c)
			}

		case parserStateBracedValue:
			if c == '}' {
				state = parserStateBracedValueClosingBrace
			} else {
				value += string(c)
			}

		case parserStateBracedValueClosingBrace:
			if c == '}' {
				// Escaped closing brace
				value += string(c)
				state = parserStateBracedValue
				continue
			}

			// End of braced value
			res[key] = value
			key = ""
			value = ""

			// This character is the first character past the end,
			// so it needs to be parsed like the parserStateEndValue state.
			state = parserStateEndValue
			switch {
			case c == ';':
				state = parserStateBeforeKey
			case unicode.IsSpace(c):
				// Ignore whitespace
			default:
				return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
			}

		case parserStateEndValue:
			switch {
			case c == ';':
				state = parserStateBeforeKey
			case unicode.IsSpace(c):
				// Ignore whitespace
			default:
				return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
			}
		}
	}

	// End of input: flush whatever pair is still pending, or report an
	// unterminated braced value.
	switch state {
	case parserStateBeforeKey: // Okay
	case parserStateKey: // Unfinished key. Treat as key without value.
		key = normalizeOdbcKey(key)
		if len(key) == 0 {
			return res, fmt.Errorf("Unexpected end of key at index %d.", len(dsn))
		}
		res[key] = value
	case parserStateBeginValue: // Empty value
		res[key] = value
	case parserStateBareValue:
		res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
	case parserStateBracedValue:
		return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn))
	case parserStateBracedValueClosingBrace: // End of braced value
		res[key] = value
	case parserStateEndValue: // Okay
	}

	return res, nil
}
865
// normalizeOdbcKey canonicalizes an ODBC connection-string key. Trailing
// Unicode whitespace is removed and the result is lower-cased. Leading
// whitespace is left alone (the ODBC parser never accumulates it).
func normalizeOdbcKey(s string) string {
	trimmed := strings.TrimRightFunc(s, unicode.IsSpace)
	return strings.ToLower(trimmed)
}
870
// splitConnectionStringURL splits a URL of the form
// sqlserver://username:password@host:port/instance?param1=value&param2=value
// into connection parameters, using the same keys as the other DSN syntaxes:
// "user id", "password", "server" (host, or host\instance), "port", and every
// query parameter with its key lower-cased.
//
// It fails if the URL does not parse, the scheme is not "sqlserver", or a
// query key is repeated. On failure the parameters gathered so far are
// returned alongside the error.
func splitConnectionStringURL(dsn string) (map[string]string, error) {
	res := map[string]string{}

	u, err := url.Parse(dsn)
	if err != nil {
		return res, err
	}

	if u.Scheme != "sqlserver" {
		return res, fmt.Errorf("scheme %s is not recognized", u.Scheme)
	}

	if u.User != nil {
		res["user id"] = u.User.Username()
		if p, exists := u.User.Password(); exists {
			res["password"] = p
		}
	}

	// Hostname/Port strip the brackets from IPv6 literals even when no port
	// is given. net.SplitHostPort fails on "[::1]", which previously leaked
	// the brackets into "server".
	host := u.Hostname()
	if instance := strings.TrimPrefix(u.Path, "/"); instance != "" {
		res["server"] = host + `\` + instance
	} else {
		res["server"] = host
	}

	if port := u.Port(); port != "" {
		res["port"] = port
	}

	for k, v := range u.Query() {
		if len(v) > 1 {
			return res, fmt.Errorf("key %s provided more than once", k)
		}
		res[strings.ToLower(k)] = v[0]
	}

	return res, nil
}
917
// parseConnectParams parses dsn into connectParams.
//
// Three syntaxes are accepted: "odbc:"-prefixed ODBC strings
// (splitConnectionStringOdbc), "sqlserver://" URLs
// (splitConnectionStringURL), and otherwise plain semicolon-separated
// key=value pairs (splitConnectionString). All three produce lower-cased
// keys, which are looked up below. Unspecified settings receive the defaults
// assigned in this function.
//
// An error is returned for the first numeric or boolean value that fails to
// parse. The partially populated connectParams is returned alongside it.
func parseConnectParams(dsn string) (connectParams, error) {
	var p connectParams

	var params map[string]string
	if strings.HasPrefix(dsn, "odbc:") {
		parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):])
		if err != nil {
			return p, err
		}
		params = parameters
	} else if strings.HasPrefix(dsn, "sqlserver://") {
		parameters, err := splitConnectionStringURL(dsn)
		if err != nil {
			return p, err
		}
		params = parameters
	} else {
		params = splitConnectionString(dsn)
	}

	// "log" is a decimal bitmask (presumably of the log* constants).
	strlog, ok := params["log"]
	if ok {
		var err error
		p.logFlags, err = strconv.ParseUint(strlog, 10, 64)
		if err != nil {
			return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error())
		}
	}
	// "server" is host or host\instance; ".", "(local)" (any case) and an
	// empty host all mean localhost.
	server := params["server"]
	parts := strings.SplitN(server, `\`, 2)
	p.host = parts[0]
	if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" {
		p.host = "localhost"
	}
	if len(parts) > 1 {
		p.instance = parts[1]
	}
	p.database = params["database"]
	p.user = params["user id"]
	p.password = params["password"]

	// Default port 1433; an explicit port must be a decimal 16-bit value.
	p.port = 1433
	strport, ok := params["port"]
	if ok {
		var err error
		p.port, err = strconv.ParseUint(strport, 10, 16)
		if err != nil {
			f := "Invalid tcp port '%v': %v"
			return p, fmt.Errorf(f, strport, err.Error())
		}
	}

	// https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
	// Default packet size remains at 4096 bytes
	p.packetSize = 4096
	strpsize, ok := params["packet size"]
	if ok {
		var err error
		// Base 0: decimal, or 0x/0o/0b-prefixed forms are accepted.
		psize, err := strconv.ParseUint(strpsize, 0, 16)
		if err != nil {
			f := "Invalid packet size '%v': %v"
			return p, fmt.Errorf(f, strpsize, err.Error())
		}

		// Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
		// NOTE: Encrypted connections have a maximum size of 16383 bytes.  If you request
		// a higher packet size, the server will respond with an ENVCHANGE request to
		// alter the packet size to 16383 bytes.
		p.packetSize = uint16(psize)
		if p.packetSize < 512 {
			p.packetSize = 512
		} else if p.packetSize > 32767 {
			p.packetSize = 32767
		}
	}

	// https://msdn.microsoft.com/en-us/library/dd341108.aspx
	//
	// Do not set a connection timeout. Use Context to manage such things.
	// Default to zero, but still allow it to be set.
	if strconntimeout, ok := params["connection timeout"]; ok {
		timeout, err := strconv.ParseUint(strconntimeout, 10, 64)
		if err != nil {
			f := "Invalid connection timeout '%v': %v"
			return p, fmt.Errorf(f, strconntimeout, err.Error())
		}
		p.conn_timeout = time.Duration(timeout) * time.Second
	}
	// "dial timeout" is in whole seconds; default 15s.
	p.dial_timeout = 15 * time.Second
	if strdialtimeout, ok := params["dial timeout"]; ok {
		timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
		if err != nil {
			f := "Invalid dial timeout '%v': %v"
			return p, fmt.Errorf(f, strdialtimeout, err.Error())
		}
		p.dial_timeout = time.Duration(timeout) * time.Second
	}

	// default keep alive should be 30 seconds according to spec:
	// https://msdn.microsoft.com/en-us/library/dd341108.aspx
	p.keepAlive = 30 * time.Second
	if keepAlive, ok := params["keepalive"]; ok {
		timeout, err := strconv.ParseUint(keepAlive, 10, 64)
		if err != nil {
			f := "Invalid keepAlive value '%s': %s"
			return p, fmt.Errorf(f, keepAlive, err.Error())
		}
		p.keepAlive = time.Duration(timeout) * time.Second
	}
	// "encrypt" is either "disable" (any case) or a strconv.ParseBool value.
	// When it is absent the server certificate is trusted by default; an
	// explicit "trustservercertificate" below can still override that.
	encrypt, ok := params["encrypt"]
	if ok {
		if strings.EqualFold(encrypt, "DISABLE") {
			p.disableEncryption = true
		} else {
			var err error
			p.encrypt, err = strconv.ParseBool(encrypt)
			if err != nil {
				f := "Invalid encrypt '%s': %s"
				return p, fmt.Errorf(f, encrypt, err.Error())
			}
		}
	} else {
		p.trustServerCertificate = true
	}
	trust, ok := params["trustservercertificate"]
	if ok {
		var err error
		p.trustServerCertificate, err = strconv.ParseBool(trust)
		if err != nil {
			f := "Invalid trust server certificate '%s': %s"
			return p, fmt.Errorf(f, trust, err.Error())
		}
	}
	p.certificate = params["certificate"]
	// The name expected in the server certificate defaults to the host.
	p.hostInCertificate, ok = params["hostnameincertificate"]
	if !ok {
		p.hostInCertificate = p.host
	}

	// Default SPN uses the host and port resolved so far (before any
	// SQL Server Browser instance lookup).
	serverSPN, ok := params["serverspn"]
	if ok {
		p.serverSPN = serverSPN
	} else {
		p.serverSPN = fmt.Sprintf("MSSQLSvc/%s:%d", p.host, p.port)
	}

	// Default workstation is the local host name; left empty if that fails.
	workstation, ok := params["workstation id"]
	if ok {
		p.workstation = workstation
	} else {
		workstation, err := os.Hostname()
		if err == nil {
			p.workstation = workstation
		}
	}

	appname, ok := params["app name"]
	if !ok {
		appname = "go-mssqldb"
	}
	p.appname = appname

	// Only the exact, case-sensitive value "ReadOnly" sets the read-only
	// intent flag; any other value is silently ignored.
	appintent, ok := params["applicationintent"]
	if ok {
		if appintent == "ReadOnly" {
			p.typeFlags |= fReadOnlyIntent
		}
	}

	failOverPartner, ok := params["failoverpartner"]
	if ok {
		p.failOverPartner = failOverPartner
	}

	// Unlike "port", the failover port is parsed with base 0 (prefixes allowed).
	failOverPort, ok := params["failoverport"]
	if ok {
		var err error
		p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
		if err != nil {
			f := "Invalid tcp port '%v': %v"
			return p, fmt.Errorf(f, failOverPort, err.Error())
		}
	}

	return p, nil
}
1104
// auth is an authentication exchange that produces tokens for the login
// handshake. NOTE(review): the following is inferred from the method names
// and should be confirmed against the implementations outside this view.
// InitialBytes presumably yields the first token to send, NextBytes answers
// each server challenge with the next token, and Free releases any
// resources held by the exchange.
type auth interface {
	InitialBytes() ([]byte, error)
	NextBytes([]byte) ([]byte, error)
	Free()
}
1110
1111// SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
1112// list of IP addresses.  So if there is more than one, try them all and
1113// use the first one that allows a connection.
1114func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn net.Conn, err error) {
1115	var ips []net.IP
1116	ips, err = net.LookupIP(p.host)
1117	if err != nil {
1118		ip := net.ParseIP(p.host)
1119		if ip == nil {
1120			return nil, err
1121		}
1122		ips = []net.IP{ip}
1123	}
1124	if len(ips) == 1 {
1125		d := c.getDialer(&p)
1126		addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port)))
1127		conn, err = d.DialContext(ctx, "tcp", addr)
1128
1129	} else {
1130		//Try Dials in parallel to avoid waiting for timeouts.
1131		connChan := make(chan net.Conn, len(ips))
1132		errChan := make(chan error, len(ips))
1133		portStr := strconv.Itoa(int(p.port))
1134		for _, ip := range ips {
1135			go func(ip net.IP) {
1136				d := c.getDialer(&p)
1137				addr := net.JoinHostPort(ip.String(), portStr)
1138				conn, err := d.DialContext(ctx, "tcp", addr)
1139				if err == nil {
1140					connChan <- conn
1141				} else {
1142					errChan <- err
1143				}
1144			}(ip)
1145		}
1146		// Wait for either the *first* successful connection, or all the errors
1147	wait_loop:
1148		for i, _ := range ips {
1149			select {
1150			case conn = <-connChan:
1151				// Got a connection to use, close any others
1152				go func(n int) {
1153					for i := 0; i < n; i++ {
1154						select {
1155						case conn := <-connChan:
1156							conn.Close()
1157						case <-errChan:
1158						}
1159					}
1160				}(len(ips) - i - 1)
1161				// Remove any earlier errors we may have collected
1162				err = nil
1163				break wait_loop
1164			case err = <-errChan:
1165			}
1166		}
1167	}
1168	// Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection
1169	if conn == nil {
1170		f := "Unable to open tcp connection with host '%v:%v': %v"
1171		return nil, fmt.Errorf(f, p.host, p.port, err.Error())
1172	}
1173	return conn, err
1174}
1175
1176func connect(ctx context.Context, c *Connector, log optionalLogger, p connectParams) (res *tdsSession, err error) {
1177	dialCtx := ctx
1178	if p.dial_timeout > 0 {
1179		var cancel func()
1180		dialCtx, cancel = context.WithTimeout(ctx, p.dial_timeout)
1181		defer cancel()
1182	}
1183	// if instance is specified use instance resolution service
1184	if p.instance != "" {
1185		p.instance = strings.ToUpper(p.instance)
1186		d := c.getDialer(&p)
1187		instances, err := getInstances(dialCtx, d, p.host)
1188		if err != nil {
1189			f := "Unable to get instances from Sql Server Browser on host %v: %v"
1190			return nil, fmt.Errorf(f, p.host, err.Error())
1191		}
1192		strport, ok := instances[p.instance]["tcp"]
1193		if !ok {
1194			f := "No instance matching '%v' returned from host '%v'"
1195			return nil, fmt.Errorf(f, p.instance, p.host)
1196		}
1197		p.port, err = strconv.ParseUint(strport, 0, 16)
1198		if err != nil {
1199			f := "Invalid tcp port returned from Sql Server Browser '%v': %v"
1200			return nil, fmt.Errorf(f, strport, err.Error())
1201		}
1202	}
1203
1204initiate_connection:
1205	conn, err := dialConnection(dialCtx, c, p)
1206	if err != nil {
1207		return nil, err
1208	}
1209
1210	toconn := newTimeoutConn(conn, p.conn_timeout)
1211
1212	outbuf := newTdsBuffer(p.packetSize, toconn)
1213	sess := tdsSession{
1214		buf:      outbuf,
1215		log:      log,
1216		logFlags: p.logFlags,
1217	}
1218
1219	instance_buf := []byte(p.instance)
1220	instance_buf = append(instance_buf, 0) // zero terminate instance name
1221	var encrypt byte
1222	if p.disableEncryption {
1223		encrypt = encryptNotSup
1224	} else if p.encrypt {
1225		encrypt = encryptOn
1226	} else {
1227		encrypt = encryptOff
1228	}
1229	fields := map[uint8][]byte{
1230		preloginVERSION:    {0, 0, 0, 0, 0, 0},
1231		preloginENCRYPTION: {encrypt},
1232		preloginINSTOPT:    instance_buf,
1233		preloginTHREADID:   {0, 0, 0, 0},
1234		preloginMARS:       {0}, // MARS disabled
1235	}
1236
1237	err = writePrelogin(outbuf, fields)
1238	if err != nil {
1239		return nil, err
1240	}
1241
1242	fields, err = readPrelogin(outbuf)
1243	if err != nil {
1244		return nil, err
1245	}
1246
1247	encryptBytes, ok := fields[preloginENCRYPTION]
1248	if !ok {
1249		return nil, fmt.Errorf("Encrypt negotiation failed")
1250	}
1251	encrypt = encryptBytes[0]
1252	if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) {
1253		return nil, fmt.Errorf("Server does not support encryption")
1254	}
1255
1256	if encrypt != encryptNotSup {
1257		var config tls.Config
1258		if p.certificate != "" {
1259			pem, err := ioutil.ReadFile(p.certificate)
1260			if err != nil {
1261				return nil, fmt.Errorf("Cannot read certificate %q: %v", p.certificate, err)
1262			}
1263			certs := x509.NewCertPool()
1264			certs.AppendCertsFromPEM(pem)
1265			config.RootCAs = certs
1266		}
1267		if p.trustServerCertificate {
1268			config.InsecureSkipVerify = true
1269		}
1270		config.ServerName = p.hostInCertificate
1271		// fix for https://github.com/denisenkom/go-mssqldb/issues/166
1272		// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
1273		// while SQL Server seems to expect one TCP segment per encrypted TDS package.
1274		// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
1275		config.DynamicRecordSizingDisabled = true
1276		outbuf.transport = conn
1277		toconn.buf = outbuf
1278		tlsConn := tls.Client(toconn, &config)
1279		err = tlsConn.Handshake()
1280
1281		toconn.buf = nil
1282		outbuf.transport = tlsConn
1283		if err != nil {
1284			return nil, fmt.Errorf("TLS Handshake failed: %v", err)
1285		}
1286		if encrypt == encryptOff {
1287			outbuf.afterFirst = func() {
1288				outbuf.transport = toconn
1289			}
1290		}
1291	}
1292
1293	login := login{
1294		TDSVersion:   verTDS74,
1295		PacketSize:   uint32(outbuf.PackageSize()),
1296		Database:     p.database,
1297		OptionFlags2: fODBC, // to get unlimited TEXTSIZE
1298		HostName:     p.workstation,
1299		ServerName:   p.host,
1300		AppName:      p.appname,
1301		TypeFlags:    p.typeFlags,
1302	}
1303	auth, auth_ok := getAuth(p.user, p.password, p.serverSPN, p.workstation)
1304	if auth_ok {
1305		login.SSPI, err = auth.InitialBytes()
1306		if err != nil {
1307			return nil, err
1308		}
1309		login.OptionFlags2 |= fIntSecurity
1310		defer auth.Free()
1311	} else {
1312		login.UserName = p.user
1313		login.Password = p.password
1314	}
1315	err = sendLogin(outbuf, login)
1316	if err != nil {
1317		return nil, err
1318	}
1319
1320	// processing login response
1321	var sspi_msg []byte
1322continue_login:
1323	tokchan := make(chan tokenStruct, 5)
1324	go processResponse(context.Background(), &sess, tokchan, nil)
1325	success := false
1326	for tok := range tokchan {
1327		switch token := tok.(type) {
1328		case sspiMsg:
1329			sspi_msg, err = auth.NextBytes(token)
1330			if err != nil {
1331				return nil, err
1332			}
1333		case loginAckStruct:
1334			success = true
1335			sess.loginAck = token
1336		case error:
1337			return nil, fmt.Errorf("Login error: %s", token.Error())
1338		case doneStruct:
1339			if token.isError() {
1340				return nil, fmt.Errorf("Login error: %s", token.getError())
1341			}
1342		}
1343	}
1344	if sspi_msg != nil {
1345		outbuf.BeginPacket(packSSPIMessage, false)
1346		_, err = outbuf.Write(sspi_msg)
1347		if err != nil {
1348			return nil, err
1349		}
1350		err = outbuf.FinishPacket()
1351		if err != nil {
1352			return nil, err
1353		}
1354		sspi_msg = nil
1355		goto continue_login
1356	}
1357	if !success {
1358		return nil, fmt.Errorf("Login failed")
1359	}
1360	if sess.routedServer != "" {
1361		toconn.Close()
1362		p.host = sess.routedServer
1363		p.port = uint64(sess.routedPort)
1364		goto initiate_connection
1365	}
1366	return &sess, nil
1367}
1368