1// +build windows
2
3package winio
4
5import (
6	"bytes"
7	"encoding/binary"
8	"fmt"
9	"runtime"
10	"sync"
11	"syscall"
12	"unicode/utf16"
13
14	"golang.org/x/sys/windows"
15)
16
17//sys adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) [true] = advapi32.AdjustTokenPrivileges
18//sys impersonateSelf(level uint32) (err error) = advapi32.ImpersonateSelf
19//sys revertToSelf() (err error) = advapi32.RevertToSelf
20//sys openThreadToken(thread syscall.Handle, accessMask uint32, openAsSelf bool, token *windows.Token) (err error) = advapi32.OpenThreadToken
21//sys getCurrentThread() (h syscall.Handle) = GetCurrentThread
22//sys lookupPrivilegeValue(systemName string, name string, luid *uint64) (err error) = advapi32.LookupPrivilegeValueW
23//sys lookupPrivilegeName(systemName string, luid *uint64, buffer *uint16, size *uint32) (err error) = advapi32.LookupPrivilegeNameW
24//sys lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) = advapi32.LookupPrivilegeDisplayNameW
25
const (
	// SE_PRIVILEGE_ENABLED is the attribute value written for each privilege
	// in a TOKEN_PRIVILEGES request to turn that privilege on. Passing 0
	// instead (as DisableProcessPrivileges does) turns it off.
	SE_PRIVILEGE_ENABLED = 0x2

	// ERROR_NOT_ALL_ASSIGNED is the last-error value AdjustTokenPrivileges
	// reports when the call succeeds but not every requested privilege could
	// be adjusted; adjustPrivileges converts it into a *PrivilegeError.
	ERROR_NOT_ALL_ASSIGNED syscall.Errno = 1300

	// SeBackupPrivilege and SeRestorePrivilege are privilege names that can be
	// passed to RunWithPrivilege(s) and Enable/DisableProcessPrivileges.
	SeBackupPrivilege  = "SeBackupPrivilege"
	SeRestorePrivilege = "SeRestorePrivilege"
)

// Impersonation levels, numbered as in the Win32 SECURITY_IMPERSONATION_LEVEL
// enumeration; newThreadToken passes securityImpersonation to ImpersonateSelf.
const (
	securityAnonymous      = 0
	securityIdentification = 1
	securityImpersonation  = 2
	securityDelegation     = 3
)

// privNames memoizes privilege name -> LUID lookups performed by
// mapPrivileges. Every access must hold privNameMutex.
var (
	privNames     = map[string]uint64{}
	privNameMutex sync.Mutex
)
46
47// PrivilegeError represents an error enabling privileges.
48type PrivilegeError struct {
49	privileges []uint64
50}
51
52func (e *PrivilegeError) Error() string {
53	s := ""
54	if len(e.privileges) > 1 {
55		s = "Could not enable privileges "
56	} else {
57		s = "Could not enable privilege "
58	}
59	for i, p := range e.privileges {
60		if i != 0 {
61			s += ", "
62		}
63		s += `"`
64		s += getPrivilegeName(p)
65		s += `"`
66	}
67	return s
68}
69
70// RunWithPrivilege enables a single privilege for a function call.
71func RunWithPrivilege(name string, fn func() error) error {
72	return RunWithPrivileges([]string{name}, fn)
73}
74
75// RunWithPrivileges enables privileges for a function call.
76func RunWithPrivileges(names []string, fn func() error) error {
77	privileges, err := mapPrivileges(names)
78	if err != nil {
79		return err
80	}
81	runtime.LockOSThread()
82	defer runtime.UnlockOSThread()
83	token, err := newThreadToken()
84	if err != nil {
85		return err
86	}
87	defer releaseThreadToken(token)
88	err = adjustPrivileges(token, privileges, SE_PRIVILEGE_ENABLED)
89	if err != nil {
90		return err
91	}
92	return fn()
93}
94
95func mapPrivileges(names []string) ([]uint64, error) {
96	var privileges []uint64
97	privNameMutex.Lock()
98	defer privNameMutex.Unlock()
99	for _, name := range names {
100		p, ok := privNames[name]
101		if !ok {
102			err := lookupPrivilegeValue("", name, &p)
103			if err != nil {
104				return nil, err
105			}
106			privNames[name] = p
107		}
108		privileges = append(privileges, p)
109	}
110	return privileges, nil
111}
112
113// EnableProcessPrivileges enables privileges globally for the process.
114func EnableProcessPrivileges(names []string) error {
115	return enableDisableProcessPrivilege(names, SE_PRIVILEGE_ENABLED)
116}
117
118// DisableProcessPrivileges disables privileges globally for the process.
119func DisableProcessPrivileges(names []string) error {
120	return enableDisableProcessPrivilege(names, 0)
121}
122
123func enableDisableProcessPrivilege(names []string, action uint32) error {
124	privileges, err := mapPrivileges(names)
125	if err != nil {
126		return err
127	}
128
129	p, _ := windows.GetCurrentProcess()
130	var token windows.Token
131	err = windows.OpenProcessToken(p, windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, &token)
132	if err != nil {
133		return err
134	}
135
136	defer token.Close()
137	return adjustPrivileges(token, privileges, action)
138}
139
140func adjustPrivileges(token windows.Token, privileges []uint64, action uint32) error {
141	var b bytes.Buffer
142	binary.Write(&b, binary.LittleEndian, uint32(len(privileges)))
143	for _, p := range privileges {
144		binary.Write(&b, binary.LittleEndian, p)
145		binary.Write(&b, binary.LittleEndian, action)
146	}
147	prevState := make([]byte, b.Len())
148	reqSize := uint32(0)
149	success, err := adjustTokenPrivileges(token, false, &b.Bytes()[0], uint32(len(prevState)), &prevState[0], &reqSize)
150	if !success {
151		return err
152	}
153	if err == ERROR_NOT_ALL_ASSIGNED {
154		return &PrivilegeError{privileges}
155	}
156	return nil
157}
158
159func getPrivilegeName(luid uint64) string {
160	var nameBuffer [256]uint16
161	bufSize := uint32(len(nameBuffer))
162	err := lookupPrivilegeName("", &luid, &nameBuffer[0], &bufSize)
163	if err != nil {
164		return fmt.Sprintf("<unknown privilege %d>", luid)
165	}
166
167	var displayNameBuffer [256]uint16
168	displayBufSize := uint32(len(displayNameBuffer))
169	var langID uint32
170	err = lookupPrivilegeDisplayName("", &nameBuffer[0], &displayNameBuffer[0], &displayBufSize, &langID)
171	if err != nil {
172		return fmt.Sprintf("<unknown privilege %s>", string(utf16.Decode(nameBuffer[:bufSize])))
173	}
174
175	return string(utf16.Decode(displayNameBuffer[:displayBufSize]))
176}
177
178func newThreadToken() (windows.Token, error) {
179	err := impersonateSelf(securityImpersonation)
180	if err != nil {
181		return 0, err
182	}
183
184	var token windows.Token
185	err = openThreadToken(getCurrentThread(), syscall.TOKEN_ADJUST_PRIVILEGES|syscall.TOKEN_QUERY, false, &token)
186	if err != nil {
187		rerr := revertToSelf()
188		if rerr != nil {
189			panic(rerr)
190		}
191		return 0, err
192	}
193	return token, nil
194}
195
196func releaseThreadToken(h windows.Token) {
197	err := revertToSelf()
198	if err != nil {
199		panic(err)
200	}
201	h.Close()
202}
203