1// Copyright 2015 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// +build windows
6
7package sspi
8
9import (
10	"fmt"
11	"syscall"
12	"time"
13	"unsafe"
14)
15
16// TODO: add documentation
17
// PackageInfo describes a security package as reported by QueryPackageInfo.
// The numeric fields are copied unchanged from the package's SecPkgInfo, and
// the two strings are copied out of SSPI-owned memory into Go strings.
type PackageInfo struct {
	// Capabilities is the package's capabilities value from SecPkgInfo.
	Capabilities uint32

	// Version is the package version from SecPkgInfo.
	Version uint16

	// RPCID is the RPC identifier from SecPkgInfo.
	RPCID uint16

	// MaxToken is the maximum token size the package reports.
	MaxToken uint32

	// Name is the package name.
	Name string

	// Comment is the package's descriptive comment.
	Comment string
}
26
27func QueryPackageInfo(pkgname string) (*PackageInfo, error) {
28	name, err := syscall.UTF16PtrFromString(pkgname)
29	if err != nil {
30		return nil, err
31	}
32	var pi *SecPkgInfo
33	ret := QuerySecurityPackageInfo(name, &pi)
34	if ret != SEC_E_OK {
35		return nil, ret
36	}
37	defer FreeContextBuffer((*byte)(unsafe.Pointer(pi)))
38
39	return &PackageInfo{
40		Capabilities: pi.Capabilities,
41		Version:      pi.Version,
42		RPCID:        pi.RPCID,
43		MaxToken:     pi.MaxToken,
44		Name:         syscall.UTF16ToString((*[2 << 12]uint16)(unsafe.Pointer(pi.Name))[:]),
45		Comment:      syscall.UTF16ToString((*[2 << 12]uint16)(unsafe.Pointer(pi.Comment))[:]),
46	}, nil
47}
48
49type Credentials struct {
50	Handle CredHandle
51	expiry syscall.Filetime
52}
53
54// AcquireCredentials calls the windows AcquireCredentialsHandle function and
55// returns Credentials containing a security handle that can be used for
56// InitializeSecurityContext or AcceptSecurityContext operations.
57// As a special case, passing an empty string as the principal parameter will
58// pass a null string to the underlying function.
59func AcquireCredentials(principal string, pkgname string, creduse uint32, authdata *byte) (*Credentials, error) {
60	var principalName *uint16
61	if principal != "" {
62		var err error
63		principalName, err = syscall.UTF16PtrFromString(principal)
64		if err != nil {
65			return nil, err
66		}
67	}
68	name, err := syscall.UTF16PtrFromString(pkgname)
69	if err != nil {
70		return nil, err
71	}
72	var c Credentials
73	ret := AcquireCredentialsHandle(principalName, name, creduse, nil, authdata, 0, 0, &c.Handle, &c.expiry)
74	if ret != SEC_E_OK {
75		return nil, ret
76	}
77	return &c, nil
78}
79
80func (c *Credentials) Release() error {
81	if c == nil {
82		return nil
83	}
84	ret := FreeCredentialsHandle(&c.Handle)
85	if ret != SEC_E_OK {
86		return ret
87	}
88	return nil
89}
90
91func (c *Credentials) Expiry() time.Time {
92	return time.Unix(0, c.expiry.Nanoseconds())
93}
94
95// TODO: add functions to display and manage RequestedFlags and EstablishedFlags fields.
96// TODO: maybe get rid of RequestedFlags and EstablishedFlags fields, and replace them with input parameter for New...Context and return value of Update (instead of current bool parameter).
97
98type updateFunc func(c *Context, targname *uint16, h, newh *CtxtHandle, out, in *SecBufferDesc) syscall.Errno
99
100type Context struct {
101	Cred             *Credentials
102	Handle           *CtxtHandle
103	handle           CtxtHandle
104	updFn            updateFunc
105	expiry           syscall.Filetime
106	RequestedFlags   uint32
107	EstablishedFlags uint32
108}
109
110func NewClientContext(cred *Credentials, flags uint32) *Context {
111	return &Context{
112		Cred:           cred,
113		updFn:          initialize,
114		RequestedFlags: flags,
115	}
116}
117
118func NewServerContext(cred *Credentials, flags uint32) *Context {
119	return &Context{
120		Cred:           cred,
121		updFn:          accept,
122		RequestedFlags: flags,
123	}
124}
125
126func initialize(c *Context, targname *uint16, h, newh *CtxtHandle, out, in *SecBufferDesc) syscall.Errno {
127	return InitializeSecurityContext(&c.Cred.Handle, h, targname, c.RequestedFlags,
128		0, SECURITY_NATIVE_DREP, in, 0, newh, out, &c.EstablishedFlags, &c.expiry)
129}
130
131func accept(c *Context, targname *uint16, h, newh *CtxtHandle, out, in *SecBufferDesc) syscall.Errno {
132	return AcceptSecurityContext(&c.Cred.Handle, h, in, c.RequestedFlags,
133		SECURITY_NATIVE_DREP, newh, out, &c.EstablishedFlags, &c.expiry)
134}
135
136func (c *Context) Update(targname *uint16, out, in *SecBufferDesc) syscall.Errno {
137	h := c.Handle
138	if c.Handle == nil {
139		c.Handle = &c.handle
140	}
141	return c.updFn(c, targname, h, c.Handle, out, in)
142}
143
144func (c *Context) Release() error {
145	if c == nil {
146		return nil
147	}
148	ret := DeleteSecurityContext(c.Handle)
149	if ret != SEC_E_OK {
150		return ret
151	}
152	return nil
153}
154
155func (c *Context) Expiry() time.Time {
156	return time.Unix(0, c.expiry.Nanoseconds())
157}
158
159// TODO: add comment to function doco that this "impersonation" is applied to current OS thread.
160func (c *Context) ImpersonateUser() error {
161	ret := ImpersonateSecurityContext(c.Handle)
162	if ret != SEC_E_OK {
163		return ret
164	}
165	return nil
166}
167
168func (c *Context) RevertToSelf() error {
169	ret := RevertSecurityContext(c.Handle)
170	if ret != SEC_E_OK {
171		return ret
172	}
173	return nil
174}
175
176// Sizes queries the context for the sizes used in per-message functions.
177// It returns the maximum token size used in authentication exchanges, the
178// maximum signature size, the preferred integral size of messages, the
179// size of any security trailer, and any error.
180func (c *Context) Sizes() (uint32, uint32, uint32, uint32, error) {
181	var s _SecPkgContext_Sizes
182	ret := QueryContextAttributes(c.Handle, _SECPKG_ATTR_SIZES, (*byte)(unsafe.Pointer(&s)))
183	if ret != SEC_E_OK {
184		return 0, 0, 0, 0, ret
185	}
186	return s.MaxToken, s.MaxSignature, s.BlockSize, s.SecurityTrailer, nil
187}
188
189// VerifyFlags determines if all flags used to construct the context
190// were honored (see NewClientContext).  It should be called after c.Update.
191func (c *Context) VerifyFlags() error {
192	return c.VerifySelectiveFlags(c.RequestedFlags)
193}
194
195// VerifySelectiveFlags determines if the given flags were honored (see NewClientContext).
196// It should be called after c.Update.
197func (c *Context) VerifySelectiveFlags(flags uint32) error {
198	if valid, missing, extra := verifySelectiveFlags(flags, c.RequestedFlags); !valid {
199		return fmt.Errorf("sspi: invalid flags check: desired=%b requested=%b missing=%b extra=%b", flags, c.RequestedFlags, missing, extra)
200	}
201	if valid, missing, extra := verifySelectiveFlags(flags, c.EstablishedFlags); !valid {
202		return fmt.Errorf("sspi: invalid flags: desired=%b established=%b missing=%b extra=%b", flags, c.EstablishedFlags, missing, extra)
203	}
204	return nil
205}
206
// verifySelectiveFlags compares the wanted bits in flags with the bits in
// establishedFlags.
//
//   - missing holds the bits set in flags but not in establishedFlags.
//   - extra holds the bits set in establishedFlags but not in flags.
//   - valid is true exactly when missing is zero, i.e. establishedFlags
//     contains every bit in flags. Extra bits do not affect validity.
func verifySelectiveFlags(flags, establishedFlags uint32) (valid bool, missing, extra uint32) {
	// &^ (AND NOT) clears from the left operand every bit set in the right.
	missing = flags &^ establishedFlags
	extra = establishedFlags &^ flags
	return missing == 0, missing, extra
}
217
218// NewSecBufferDesc returns an initialized SecBufferDesc describing the
219// provided SecBuffer.
220func NewSecBufferDesc(b []SecBuffer) *SecBufferDesc {
221	return &SecBufferDesc{
222		Version:      SECBUFFER_VERSION,
223		BuffersCount: uint32(len(b)),
224		Buffers:      &b[0],
225	}
226}
227