1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package windows
6
7import (
8	"sync"
9	"sync/atomic"
10	"syscall"
11	"unsafe"
12)
13
// We need to use LoadLibrary and GetProcAddress from the Go runtime, because
// these symbols are loaded by the system linker and are required to
// dynamically load additional symbols. Note that in the Go runtime, these
// return syscall.Handle and syscall.Errno, but these are the same, in fact,
// as windows.Handle and windows.Errno, and we intend to keep these the same.
19
// syscall_loadlibrary is the syscall package's loadlibrary, bound to this
// name via go:linkname; it has no Go body in this package. LoadDLL passes it
// the result of UTF16PtrFromString and treats any non-zero err as failure.
//
//go:linkname syscall_loadlibrary syscall.loadlibrary
func syscall_loadlibrary(filename *uint16) (handle Handle, err Errno)
22
// syscall_getprocaddress is the syscall package's getprocaddress, bound to
// this name via go:linkname; it has no Go body in this package. FindProc
// passes it a module handle and the result of BytePtrFromString, and treats
// any non-zero err as failure.
//
//go:linkname syscall_getprocaddress syscall.getprocaddress
func syscall_getprocaddress(handle Handle, procname *uint8) (proc uintptr, err Errno)
25
// DLLError describes reasons for DLL load failures.
//
// LoadDLL and (*DLL).FindProc return a *DLLError when the underlying loader
// call reports a non-zero Errno.
type DLLError struct {
	// Err is the underlying error; LoadDLL and FindProc set it to the
	// loader's Errno.
	Err error

	// ObjName names the object that could not be loaded or found: the DLL
	// name for LoadDLL, the procedure name for FindProc.
	ObjName string

	// Msg is the human-readable message returned by Error.
	Msg string
}

// Error returns e.Msg.
func (e *DLLError) Error() string { return e.Msg }

// Unwrap returns the underlying error, so that errors.Is and errors.As can
// match the Errno carried in Err (for example, a specific ERROR_* code).
func (e *DLLError) Unwrap() error { return e.Err }
34
// A DLL implements access to a single DLL.
//
// Values are produced by LoadDLL and loadLibraryEx, which both record the
// caller-supplied name (not any resolved path) in Name. Handle is the module
// handle from the loader; Release passes it to FreeLibrary and FindProc
// passes it to the procedure lookup.
type DLL struct {
	Name   string
	Handle Handle
}
40
41// LoadDLL loads DLL file into memory.
42//
43// Warning: using LoadDLL without an absolute path name is subject to
44// DLL preloading attacks. To safely load a system DLL, use LazyDLL
45// with System set to true, or use LoadLibraryEx directly.
46func LoadDLL(name string) (dll *DLL, err error) {
47	namep, err := UTF16PtrFromString(name)
48	if err != nil {
49		return nil, err
50	}
51	h, e := syscall_loadlibrary(namep)
52	if e != 0 {
53		return nil, &DLLError{
54			Err:     e,
55			ObjName: name,
56			Msg:     "Failed to load " + name + ": " + e.Error(),
57		}
58	}
59	d := &DLL{
60		Name:   name,
61		Handle: h,
62	}
63	return d, nil
64}
65
66// MustLoadDLL is like LoadDLL but panics if load operation failes.
67func MustLoadDLL(name string) *DLL {
68	d, e := LoadDLL(name)
69	if e != nil {
70		panic(e)
71	}
72	return d
73}
74
75// FindProc searches DLL d for procedure named name and returns *Proc
76// if found. It returns an error if search fails.
77func (d *DLL) FindProc(name string) (proc *Proc, err error) {
78	namep, err := BytePtrFromString(name)
79	if err != nil {
80		return nil, err
81	}
82	a, e := syscall_getprocaddress(d.Handle, namep)
83	if e != 0 {
84		return nil, &DLLError{
85			Err:     e,
86			ObjName: name,
87			Msg:     "Failed to find " + name + " procedure in " + d.Name + ": " + e.Error(),
88		}
89	}
90	p := &Proc{
91		Dll:  d,
92		Name: name,
93		addr: a,
94	}
95	return p, nil
96}
97
98// MustFindProc is like FindProc but panics if search fails.
99func (d *DLL) MustFindProc(name string) *Proc {
100	p, e := d.FindProc(name)
101	if e != nil {
102		panic(e)
103	}
104	return p
105}
106
107// Release unloads DLL d from memory.
108func (d *DLL) Release() (err error) {
109	return FreeLibrary(d.Handle)
110}
111
// A Proc implements access to a procedure inside a DLL.
// (*DLL).FindProc constructs Procs.
type Proc struct {
	// Dll is the DLL the procedure was found in, Name is the procedure
	// name that was looked up, and addr is the address returned by the
	// loader (exposed through Addr).
	Dll  *DLL
	Name string
	addr uintptr
}
118
119// Addr returns the address of the procedure represented by p.
120// The return value can be passed to Syscall to run the procedure.
121func (p *Proc) Addr() uintptr {
122	return p.addr
123}
124
125//go:uintptrescapes
126
127// Call executes procedure p with arguments a. It will panic, if more than 15 arguments
128// are supplied.
129//
130// The returned error is always non-nil, constructed from the result of GetLastError.
131// Callers must inspect the primary return value to decide whether an error occurred
132// (according to the semantics of the specific function being called) before consulting
133// the error. The error will be guaranteed to contain windows.Errno.
134func (p *Proc) Call(a ...uintptr) (r1, r2 uintptr, lastErr error) {
135	switch len(a) {
136	case 0:
137		return syscall.Syscall(p.Addr(), uintptr(len(a)), 0, 0, 0)
138	case 1:
139		return syscall.Syscall(p.Addr(), uintptr(len(a)), a[0], 0, 0)
140	case 2:
141		return syscall.Syscall(p.Addr(), uintptr(len(a)), a[0], a[1], 0)
142	case 3:
143		return syscall.Syscall(p.Addr(), uintptr(len(a)), a[0], a[1], a[2])
144	case 4:
145		return syscall.Syscall6(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], 0, 0)
146	case 5:
147		return syscall.Syscall6(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], 0)
148	case 6:
149		return syscall.Syscall6(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5])
150	case 7:
151		return syscall.Syscall9(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5], a[6], 0, 0)
152	case 8:
153		return syscall.Syscall9(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], 0)
154	case 9:
155		return syscall.Syscall9(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8])
156	case 10:
157		return syscall.Syscall12(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], 0, 0)
158	case 11:
159		return syscall.Syscall12(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], 0)
160	case 12:
161		return syscall.Syscall12(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11])
162	case 13:
163		return syscall.Syscall15(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], 0, 0)
164	case 14:
165		return syscall.Syscall15(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], 0)
166	case 15:
167		return syscall.Syscall15(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14])
168	default:
169		panic("Call " + p.Name + " with too many arguments " + itoa(len(a)) + ".")
170	}
171}
172
// A LazyDLL implements access to a single DLL.
// It will delay the load of the DLL until the first call to its Load or
// Handle method, or to the Find, Addr, or Call method of one of its
// LazyProcs.
type LazyDLL struct {
	// Name is the DLL file name that Load hands to the loader.
	Name string

	// System determines whether the DLL must be loaded from the
	// Windows System directory, bypassing the normal DLL search
	// path.
	System bool

	// mu serializes the slow path of Load. dll stays nil until Load
	// succeeds; Load publishes it with atomic.StorePointer and Load's
	// fast path reads it with atomic.LoadPointer.
	mu  sync.Mutex
	dll *DLL // non nil once DLL is loaded
}
188
189// Load loads DLL file d.Name into memory. It returns an error if fails.
190// Load will not try to load DLL, if it is already loaded into memory.
191func (d *LazyDLL) Load() error {
192	// Non-racy version of:
193	// if d.dll != nil {
194	if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.dll))) != nil {
195		return nil
196	}
197	d.mu.Lock()
198	defer d.mu.Unlock()
199	if d.dll != nil {
200		return nil
201	}
202
203	// kernel32.dll is special, since it's where LoadLibraryEx comes from.
204	// The kernel already special-cases its name, so it's always
205	// loaded from system32.
206	var dll *DLL
207	var err error
208	if d.Name == "kernel32.dll" {
209		dll, err = LoadDLL(d.Name)
210	} else {
211		dll, err = loadLibraryEx(d.Name, d.System)
212	}
213	if err != nil {
214		return err
215	}
216
217	// Non-racy version of:
218	// d.dll = dll
219	atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.dll)), unsafe.Pointer(dll))
220	return nil
221}
222
223// mustLoad is like Load but panics if search fails.
224func (d *LazyDLL) mustLoad() {
225	e := d.Load()
226	if e != nil {
227		panic(e)
228	}
229}
230
231// Handle returns d's module handle.
232func (d *LazyDLL) Handle() uintptr {
233	d.mustLoad()
234	return uintptr(d.dll.Handle)
235}
236
237// NewProc returns a LazyProc for accessing the named procedure in the DLL d.
238func (d *LazyDLL) NewProc(name string) *LazyProc {
239	return &LazyProc{l: d, Name: name}
240}
241
242// NewLazyDLL creates new LazyDLL associated with DLL file.
243func NewLazyDLL(name string) *LazyDLL {
244	return &LazyDLL{Name: name}
245}
246
247// NewLazySystemDLL is like NewLazyDLL, but will only
248// search Windows System directory for the DLL if name is
249// a base name (like "advapi32.dll").
250func NewLazySystemDLL(name string) *LazyDLL {
251	return &LazyDLL{Name: name, System: true}
252}
253
// A LazyProc implements access to a procedure inside a LazyDLL.
// It delays the lookup until its Find, Addr, or Call method is first called.
type LazyProc struct {
	// Name is the procedure name looked up in the DLL.
	Name string

	// mu serializes the slow path of Find; l is the LazyDLL the procedure
	// belongs to (set by NewProc); proc stays nil until Find succeeds, is
	// published with atomic.StorePointer, and is read on Find's fast path
	// with atomic.LoadPointer.
	mu   sync.Mutex
	l    *LazyDLL
	proc *Proc
}
263
264// Find searches DLL for procedure named p.Name. It returns
265// an error if search fails. Find will not search procedure,
266// if it is already found and loaded into memory.
267func (p *LazyProc) Find() error {
268	// Non-racy version of:
269	// if p.proc == nil {
270	if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&p.proc))) == nil {
271		p.mu.Lock()
272		defer p.mu.Unlock()
273		if p.proc == nil {
274			e := p.l.Load()
275			if e != nil {
276				return e
277			}
278			proc, e := p.l.dll.FindProc(p.Name)
279			if e != nil {
280				return e
281			}
282			// Non-racy version of:
283			// p.proc = proc
284			atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.proc)), unsafe.Pointer(proc))
285		}
286	}
287	return nil
288}
289
290// mustFind is like Find but panics if search fails.
291func (p *LazyProc) mustFind() {
292	e := p.Find()
293	if e != nil {
294		panic(e)
295	}
296}
297
298// Addr returns the address of the procedure represented by p.
299// The return value can be passed to Syscall to run the procedure.
300// It will panic if the procedure cannot be found.
301func (p *LazyProc) Addr() uintptr {
302	p.mustFind()
303	return p.proc.Addr()
304}
305
306//go:uintptrescapes
307
308// Call executes procedure p with arguments a. It will panic, if more than 15 arguments
309// are supplied. It will also panic if the procedure cannot be found.
310//
311// The returned error is always non-nil, constructed from the result of GetLastError.
312// Callers must inspect the primary return value to decide whether an error occurred
313// (according to the semantics of the specific function being called) before consulting
314// the error. The error will be guaranteed to contain windows.Errno.
315func (p *LazyProc) Call(a ...uintptr) (r1, r2 uintptr, lastErr error) {
316	p.mustFind()
317	return p.proc.Call(a...)
318}
319
// canDoSearchSystem32Once caches the probe performed by
// initCanDoSearchSystem32. The embedded Once guards it (see
// canDoSearchSystem32), and v records whether the LOAD_LIBRARY_SEARCH_*
// flags may be passed to LoadLibraryEx.
var canDoSearchSystem32Once struct {
	sync.Once
	v bool
}
324
325func initCanDoSearchSystem32() {
326	// https://msdn.microsoft.com/en-us/library/ms684179(v=vs.85).aspx says:
327	// "Windows 7, Windows Server 2008 R2, Windows Vista, and Windows
328	// Server 2008: The LOAD_LIBRARY_SEARCH_* flags are available on
329	// systems that have KB2533623 installed. To determine whether the
330	// flags are available, use GetProcAddress to get the address of the
331	// AddDllDirectory, RemoveDllDirectory, or SetDefaultDllDirectories
332	// function. If GetProcAddress succeeds, the LOAD_LIBRARY_SEARCH_*
333	// flags can be used with LoadLibraryEx."
334	canDoSearchSystem32Once.v = (modkernel32.NewProc("AddDllDirectory").Find() == nil)
335}
336
337func canDoSearchSystem32() bool {
338	canDoSearchSystem32Once.Do(initCanDoSearchSystem32)
339	return canDoSearchSystem32Once.v
340}
341
// isBaseName reports whether name contains none of ':', '/' or '\\', that
// is, whether it has no drive or directory component.
func isBaseName(name string) bool {
	// The separators are ASCII, and in UTF-8 an ASCII byte never appears
	// inside a multi-byte sequence, so a byte scan matches a rune scan.
	for i := 0; i < len(name); i++ {
		switch name[i] {
		case ':', '/', '\\':
			return false
		}
	}
	return true
}
350
351// loadLibraryEx wraps the Windows LoadLibraryEx function.
352//
353// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms684179(v=vs.85).aspx
354//
355// If name is not an absolute path, LoadLibraryEx searches for the DLL
356// in a variety of automatic locations unless constrained by flags.
357// See: https://msdn.microsoft.com/en-us/library/ff919712%28VS.85%29.aspx
358func loadLibraryEx(name string, system bool) (*DLL, error) {
359	loadDLL := name
360	var flags uintptr
361	if system {
362		if canDoSearchSystem32() {
363			const LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
364			flags = LOAD_LIBRARY_SEARCH_SYSTEM32
365		} else if isBaseName(name) {
366			// WindowsXP or unpatched Windows machine
367			// trying to load "foo.dll" out of the system
368			// folder, but LoadLibraryEx doesn't support
369			// that yet on their system, so emulate it.
370			systemdir, err := GetSystemDirectory()
371			if err != nil {
372				return nil, err
373			}
374			loadDLL = systemdir + "\\" + name
375		}
376	}
377	h, err := LoadLibraryEx(loadDLL, 0, flags)
378	if err != nil {
379		return nil, err
380	}
381	return &DLL{Name: name, Handle: h}, nil
382}
383
// errString is a minimal error implementation whose message is the string
// value itself.
type errString string

// Error returns s as a plain string.
func (s errString) Error() string {
	msg := string(s)
	return msg
}
387