1package mssql
2
3import (
4	"fmt"
5	"strings"
6	"syscall"
7	"unsafe"
8)
9
// SSPI entry points. secur32.dll is loaded lazily; init calls
// InitSecurityInterfaceW and stores the returned dispatch table in sec_fn.
// The SSPI calls shown in this file all go through sec_fn rather than
// resolving each function individually.
var (
	secur32_dll           = syscall.NewLazyDLL("secur32.dll")
	initSecurityInterface = secur32_dll.NewProc("InitSecurityInterfaceW")
	sec_fn                *SecurityFunctionTable
)
15
16func init() {
17	ptr, _, _ := initSecurityInterface.Call()
18	sec_fn = (*SecurityFunctionTable)(unsafe.Pointer(ptr))
19}
20
// Win32 SSPI constants used by the calls in this file. The names mirror the
// C identifiers from the Windows SDK so the values can be cross-checked there.
const (
	// Success status returned by SSPI functions.
	SEC_E_OK                        = 0
	// fCredentialUse argument to AcquireCredentialsHandle (client side).
	SECPKG_CRED_OUTBOUND            = 2
	// Flags value for SEC_WINNT_AUTH_IDENTITY: strings are UTF-16.
	SEC_WINNT_AUTH_IDENTITY_UNICODE = 2
	// Context requirement bits, combined into ISC_REQ for
	// InitializeSecurityContext.
	ISC_REQ_DELEGATE                = 0x00000001
	ISC_REQ_REPLAY_DETECT           = 0x00000004
	ISC_REQ_SEQUENCE_DETECT         = 0x00000008
	ISC_REQ_CONFIDENTIALITY         = 0x00000010
	ISC_REQ_CONNECTION              = 0x00000800
	// TargetDataRep argument to InitializeSecurityContext.
	SECURITY_NETWORK_DREP           = 0
	// Non-error InitializeSecurityContext results. The two *COMPLETE* codes
	// make InitialBytes/NextBytes call CompleteAuthToken.
	SEC_I_CONTINUE_NEEDED           = 0x00090312
	SEC_I_COMPLETE_NEEDED           = 0x00090313
	SEC_I_COMPLETE_AND_CONTINUE     = 0x00090314
	// SecBufferDesc version and SecBuffer type used for token buffers.
	SECBUFFER_VERSION               = 0
	SECBUFFER_TOKEN                 = 2
	// Size in bytes of the output token buffer allocated by InitialBytes
	// and NextBytes.
	NTLMBUF_LEN                     = 12000
)
38
// ISC_REQ is the fContextReq bit set passed to every InitializeSecurityContext
// call in InitialBytes and NextBytes: confidentiality, replay and sequence
// detection, connection-oriented context, and delegation.
const ISC_REQ = ISC_REQ_CONFIDENTIALITY |
	ISC_REQ_REPLAY_DETECT |
	ISC_REQ_SEQUENCE_DETECT |
	ISC_REQ_CONNECTION |
	ISC_REQ_DELEGATE
44
// SecurityFunctionTable is the Go view of the dispatch table returned by
// InitSecurityInterfaceW. init casts the returned pointer directly to this
// type, so the field order and sizes must match the C structure exactly;
// do not reorder, remove, or insert fields. Each uintptr field holds a
// function pointer that is invoked via syscall.SyscallN-style helpers.
type SecurityFunctionTable struct {
	dwVersion                  uint32
	EnumerateSecurityPackages  uintptr
	QueryCredentialsAttributes uintptr
	AcquireCredentialsHandle   uintptr
	FreeCredentialsHandle      uintptr
	Reserved2                  uintptr
	InitializeSecurityContext  uintptr
	AcceptSecurityContext      uintptr
	CompleteAuthToken          uintptr
	DeleteSecurityContext      uintptr
	ApplyControlToken          uintptr
	QueryContextAttributes     uintptr
	ImpersonateSecurityContext uintptr
	RevertSecurityContext      uintptr
	MakeSignature              uintptr
	VerifySignature            uintptr
	FreeContextBuffer          uintptr
	QuerySecurityPackageInfo   uintptr
	Reserved3                  uintptr
	Reserved4                  uintptr
	// NOTE(review): in the SDK's SecurityFunctionTableW these three slots
	// appear to be ExportSecurityContext, ImportSecurityContextW and
	// AddCredentialsW — confirm against sspi.h before calling through them.
	Reserved5                  uintptr
	Reserved6                  uintptr
	Reserved7                  uintptr
	Reserved8                  uintptr
	QuerySecurityContextToken  uintptr
	EncryptMessage             uintptr
	DecryptMessage             uintptr
}
74
// SEC_WINNT_AUTH_IDENTITY is the Go mirror of the Win32 structure of the same
// name. InitialBytes passes it to AcquireCredentialsHandle to supply explicit
// credentials. Field order and types must match the C layout. With Flags set
// to SEC_WINNT_AUTH_IDENTITY_UNICODE the string pointers refer to UTF-16 data.
type SEC_WINNT_AUTH_IDENTITY struct {
	User           *uint16
	UserLength     uint32
	Domain         *uint16
	DomainLength   uint32
	Password       *uint16
	PasswordLength uint32
	Flags          uint32
}
84
// TimeStamp mirrors the Win32 TimeStamp (a 64-bit value split into low and
// high halves). AcquireCredentialsHandle and InitializeSecurityContext write
// an expiry into it; the code shown here never reads the result.
type TimeStamp struct {
	LowPart  uint32
	HighPart int32
}
89
// SecHandle mirrors the Win32 SecHandle (used for both CredHandle and
// CtxtHandle): two pointer-sized words that SSPI fills in and that are passed
// back by address on later calls. Callers treat it as opaque.
type SecHandle struct {
	dwLower uintptr
	dwUpper uintptr
}
94
// SecBuffer mirrors the Win32 SecBuffer. cbBuffer is the buffer size in bytes
// on input; for output buffers InitialBytes and NextBytes read it back after
// InitializeSecurityContext as the length of the produced token.
type SecBuffer struct {
	cbBuffer   uint32
	BufferType uint32
	pvBuffer   *byte
}
100
// SecBufferDesc mirrors the Win32 SecBufferDesc: a versioned descriptor
// pointing at cBuffers consecutive SecBuffer values. This file always uses a
// single buffer (cBuffers = 1).
type SecBufferDesc struct {
	ulVersion uint32
	cBuffers  uint32
	pBuffers  *SecBuffer
}
106
// SSPIAuth holds the state of one SSPI "Negotiate" authentication exchange.
// Domain, UserName and Password are optional explicit credentials
// (InitialBytes only builds an identity when UserName is non-empty); Service
// is the target name passed to InitializeSecurityContext.
type SSPIAuth struct {
	Domain   string
	UserName string
	Password string
	Service  string
	// Credentials handle filled by AcquireCredentialsHandle in InitialBytes
	// and released by Free.
	cred     SecHandle
	// Security context handle filled by InitializeSecurityContext and
	// released by Free.
	ctxt     SecHandle
}
115
116func getAuth(user, password, service, workstation string) (auth, bool) {
117	if user == "" {
118		return &SSPIAuth{Service: service}, true
119	}
120	if !strings.ContainsRune(user, '\\') {
121		return nil, false
122	}
123	domain_user := strings.SplitN(user, "\\", 2)
124	return &SSPIAuth{
125		Domain:   domain_user[0],
126		UserName: domain_user[1],
127		Password: password,
128		Service:  service,
129	}, true
130}
131
132func (auth *SSPIAuth) InitialBytes() ([]byte, error) {
133	var identity *SEC_WINNT_AUTH_IDENTITY
134	if auth.UserName != "" {
135		identity = &SEC_WINNT_AUTH_IDENTITY{
136			Flags:          SEC_WINNT_AUTH_IDENTITY_UNICODE,
137			Password:       syscall.StringToUTF16Ptr(auth.Password),
138			PasswordLength: uint32(len(auth.Password)),
139			Domain:         syscall.StringToUTF16Ptr(auth.Domain),
140			DomainLength:   uint32(len(auth.Domain)),
141			User:           syscall.StringToUTF16Ptr(auth.UserName),
142			UserLength:     uint32(len(auth.UserName)),
143		}
144	}
145	var ts TimeStamp
146	sec_ok, _, _ := syscall.Syscall9(sec_fn.AcquireCredentialsHandle,
147		9,
148		0,
149		uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr("Negotiate"))),
150		SECPKG_CRED_OUTBOUND,
151		0,
152		uintptr(unsafe.Pointer(identity)),
153		0,
154		0,
155		uintptr(unsafe.Pointer(&auth.cred)),
156		uintptr(unsafe.Pointer(&ts)))
157	if sec_ok != SEC_E_OK {
158		return nil, fmt.Errorf("AcquireCredentialsHandle failed %x", sec_ok)
159	}
160
161	var buf SecBuffer
162	var desc SecBufferDesc
163	desc.ulVersion = SECBUFFER_VERSION
164	desc.cBuffers = 1
165	desc.pBuffers = &buf
166
167	outbuf := make([]byte, NTLMBUF_LEN)
168	buf.cbBuffer = NTLMBUF_LEN
169	buf.BufferType = SECBUFFER_TOKEN
170	buf.pvBuffer = &outbuf[0]
171
172	var attrs uint32
173	sec_ok, _, _ = syscall.Syscall12(sec_fn.InitializeSecurityContext,
174		12,
175		uintptr(unsafe.Pointer(&auth.cred)),
176		0,
177		uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(auth.Service))),
178		ISC_REQ,
179		0,
180		SECURITY_NETWORK_DREP,
181		0,
182		0,
183		uintptr(unsafe.Pointer(&auth.ctxt)),
184		uintptr(unsafe.Pointer(&desc)),
185		uintptr(unsafe.Pointer(&attrs)),
186		uintptr(unsafe.Pointer(&ts)))
187	if sec_ok == SEC_I_COMPLETE_AND_CONTINUE ||
188		sec_ok == SEC_I_COMPLETE_NEEDED {
189		syscall.Syscall6(sec_fn.CompleteAuthToken,
190			2,
191			uintptr(unsafe.Pointer(&auth.ctxt)),
192			uintptr(unsafe.Pointer(&desc)),
193			0, 0, 0, 0)
194	} else if sec_ok != SEC_E_OK &&
195		sec_ok != SEC_I_CONTINUE_NEEDED {
196		syscall.Syscall6(sec_fn.FreeCredentialsHandle,
197			1,
198			uintptr(unsafe.Pointer(&auth.cred)),
199			0, 0, 0, 0, 0)
200		return nil, fmt.Errorf("InitialBytes InitializeSecurityContext failed %x", sec_ok)
201	}
202	return outbuf[:buf.cbBuffer], nil
203}
204
205func (auth *SSPIAuth) NextBytes(bytes []byte) ([]byte, error) {
206	var in_buf, out_buf SecBuffer
207	var in_desc, out_desc SecBufferDesc
208
209	in_desc.ulVersion = SECBUFFER_VERSION
210	in_desc.cBuffers = 1
211	in_desc.pBuffers = &in_buf
212
213	out_desc.ulVersion = SECBUFFER_VERSION
214	out_desc.cBuffers = 1
215	out_desc.pBuffers = &out_buf
216
217	in_buf.BufferType = SECBUFFER_TOKEN
218	in_buf.pvBuffer = &bytes[0]
219	in_buf.cbBuffer = uint32(len(bytes))
220
221	outbuf := make([]byte, NTLMBUF_LEN)
222	out_buf.BufferType = SECBUFFER_TOKEN
223	out_buf.pvBuffer = &outbuf[0]
224	out_buf.cbBuffer = NTLMBUF_LEN
225
226	var attrs uint32
227	var ts TimeStamp
228	sec_ok, _, _ := syscall.Syscall12(sec_fn.InitializeSecurityContext,
229		12,
230		uintptr(unsafe.Pointer(&auth.cred)),
231		uintptr(unsafe.Pointer(&auth.ctxt)),
232		uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(auth.Service))),
233		ISC_REQ,
234		0,
235		SECURITY_NETWORK_DREP,
236		uintptr(unsafe.Pointer(&in_desc)),
237		0,
238		uintptr(unsafe.Pointer(&auth.ctxt)),
239		uintptr(unsafe.Pointer(&out_desc)),
240		uintptr(unsafe.Pointer(&attrs)),
241		uintptr(unsafe.Pointer(&ts)))
242	if sec_ok == SEC_I_COMPLETE_AND_CONTINUE ||
243		sec_ok == SEC_I_COMPLETE_NEEDED {
244		syscall.Syscall6(sec_fn.CompleteAuthToken,
245			2,
246			uintptr(unsafe.Pointer(&auth.ctxt)),
247			uintptr(unsafe.Pointer(&out_desc)),
248			0, 0, 0, 0)
249	} else if sec_ok != SEC_E_OK &&
250		sec_ok != SEC_I_CONTINUE_NEEDED {
251		return nil, fmt.Errorf("NextBytes InitializeSecurityContext failed %x", sec_ok)
252	}
253
254	return outbuf[:out_buf.cbBuffer], nil
255}
256
257func (auth *SSPIAuth) Free() {
258	syscall.Syscall6(sec_fn.DeleteSecurityContext,
259		1,
260		uintptr(unsafe.Pointer(&auth.ctxt)),
261		0, 0, 0, 0, 0)
262	syscall.Syscall6(sec_fn.FreeCredentialsHandle,
263		1,
264		uintptr(unsafe.Pointer(&auth.cred)),
265		0, 0, 0, 0, 0)
266}
267