1 extern crate winapi;
2 use self::winapi::shared::minwindef::{WORD, DWORD, HMODULE, FARPROC};
3 use self::winapi::shared::ntdef::WCHAR;
4 use self::winapi::shared::winerror;
5 use self::winapi::um::{errhandlingapi, libloaderapi};
6 
7 use util::{ensure_compatible_types, cstr_cow_from_bytes};
8 
9 use std::ffi::{OsStr, OsString};
10 use std::{fmt, io, marker, mem, ptr};
11 use std::os::windows::ffi::{OsStrExt, OsStringExt};
12 use std::sync::atomic::{AtomicBool, ATOMIC_BOOL_INIT, Ordering};
13 
14 
/// A platform-specific equivalent of the cross-platform `Library`.
///
/// Wraps the module handle (`HMODULE`) obtained from `LoadLibraryW` (or adopted through
/// `Library::from_raw`). Dropping the value releases the handle with `FreeLibrary`.
pub struct Library(HMODULE);

// NOTE(review): assumes a module handle is not tied to the thread that obtained it, so the
// `Library` may be moved to (and dropped on) another thread.
unsafe impl Send for Library {}
// Now, this is sort-of-tricky. MSDN documentation does not really make any claims as to the
// thread-safety of the Win32 APIs. Sadly, whoever I asked, even current and former Microsoft
// employees, couldn’t say for sure whether the Win32 APIs used to implement `Library` are
// thread-safe or not.
//
// My investigation ended up with a question about thread-safety properties of the API involved
// being sent to an internal (to MS) general question mailing-list. The conclusion of the mail is
// as such:
//
// * Nobody inside MS (at least out of all the people who have seen the question) knows for
//   sure either;
// * However, the general consensus between MS developers is that one can rely on the API being
//   thread-safe. In case it is not thread-safe it should be considered a bug on the Windows
//   part. (NB: bugs filed at https://connect.microsoft.com/ against Windows Server)
unsafe impl Sync for Library {}
33 
34 impl Library {
35     /// Find and load a shared library (module).
36     ///
37     /// Corresponds to `LoadLibraryW(filename)`.
38     #[inline]
new<P: AsRef<OsStr>>(filename: P) -> ::Result<Library>39     pub fn new<P: AsRef<OsStr>>(filename: P) -> ::Result<Library> {
40         let wide_filename: Vec<u16> = filename.as_ref().encode_wide().chain(Some(0)).collect();
41         let _guard = ErrorModeGuard::new();
42 
43         let ret = with_get_last_error(|| {
44             // Make sure no winapi calls as a result of drop happen inside this closure, because
45             // otherwise that might change the return value of the GetLastError.
46             let handle = unsafe { libloaderapi::LoadLibraryW(wide_filename.as_ptr()) };
47             if handle.is_null()  {
48                 None
49             } else {
50                 Some(Library(handle))
51             }
52         }).map_err(|e| e.unwrap_or_else(||
53             panic!("LoadLibraryW failed but GetLastError did not report the error")
54         ));
55 
56         drop(wide_filename); // Drop wide_filename here to ensure it doesn’t get moved and dropped
57                              // inside the closure by mistake. See comment inside the closure.
58         ret
59     }
60 
61     /// Get a pointer to function or static variable by symbol name.
62     ///
63     /// The `symbol` may not contain any null bytes, with an exception of last byte. A null
64     /// terminated `symbol` may avoid a string allocation in some cases.
65     ///
66     /// Symbol is interpreted as-is; no mangling is done. This means that symbols like `x::y` are
67     /// most likely invalid.
68     ///
69     /// ## Unsafety
70     ///
71     /// Pointer to a value of arbitrary type is returned. Using a value with wrong type is
72     /// undefined.
get<T>(&self, symbol: &[u8]) -> ::Result<Symbol<T>>73     pub unsafe fn get<T>(&self, symbol: &[u8]) -> ::Result<Symbol<T>> {
74         ensure_compatible_types::<T, FARPROC>();
75         let symbol = try!(cstr_cow_from_bytes(symbol));
76         with_get_last_error(|| {
77             let symbol = libloaderapi::GetProcAddress(self.0, symbol.as_ptr());
78             if symbol.is_null() {
79                 None
80             } else {
81                 Some(Symbol {
82                     pointer: symbol,
83                     pd: marker::PhantomData
84                 })
85             }
86         }).map_err(|e| e.unwrap_or_else(||
87             panic!("GetProcAddress failed but GetLastError did not report the error")
88         ))
89     }
90 
91     /// Get a pointer to function or static variable by ordinal number.
92     ///
93     /// ## Unsafety
94     ///
95     /// Pointer to a value of arbitrary type is returned. Using a value with wrong type is
96     /// undefined.
get_ordinal<T>(&self, ordinal: WORD) -> ::Result<Symbol<T>>97     pub unsafe fn get_ordinal<T>(&self, ordinal: WORD) -> ::Result<Symbol<T>> {
98         ensure_compatible_types::<T, FARPROC>();
99         with_get_last_error(|| {
100             let ordinal = ordinal as usize as *mut _;
101             let symbol = libloaderapi::GetProcAddress(self.0, ordinal);
102             if symbol.is_null() {
103                 None
104             } else {
105                 Some(Symbol {
106                     pointer: symbol,
107                     pd: marker::PhantomData
108                 })
109             }
110         }).map_err(|e| e.unwrap_or_else(||
111             panic!("GetProcAddress failed but GetLastError did not report the error")
112         ))
113     }
114 
115     /// Convert the `Library` to a raw handle.
into_raw(self) -> HMODULE116     pub fn into_raw(self) -> HMODULE {
117         let handle = self.0;
118         mem::forget(self);
119         handle
120     }
121 
122     /// Convert a raw handle to a `Library`.
123     ///
124     /// ## Unsafety
125     ///
126     /// The handle shall be a result of a successful call of `LoadLibraryW` or a
127     /// handle previously returned by the `Library::into_raw` call.
from_raw(handle: HMODULE) -> Library128     pub unsafe fn from_raw(handle: HMODULE) -> Library {
129         Library(handle)
130     }
131 }
132 
133 impl Drop for Library {
drop(&mut self)134     fn drop(&mut self) {
135         with_get_last_error(|| {
136             if unsafe { libloaderapi::FreeLibrary(self.0) == 0 } {
137                 None
138             } else {
139                 Some(())
140             }
141         }).unwrap()
142     }
143 }
144 
145 impl fmt::Debug for Library {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result146     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
147         unsafe {
148             let mut buf: [WCHAR; 1024] = mem::uninitialized();
149             let len = libloaderapi::GetModuleFileNameW(self.0,
150                                                    (&mut buf[..]).as_mut_ptr(), 1024) as usize;
151             if len == 0 {
152                 f.write_str(&format!("Library@{:p}", self.0))
153             } else {
154                 let string: OsString = OsString::from_wide(&buf[..len]);
155                 f.write_str(&format!("Library@{:p} from {:?}", self.0, string))
156             }
157         }
158     }
159 }
160 
/// Symbol from a library.
///
/// A major difference compared to the cross-platform `Symbol` is that this does not ensure the
/// `Symbol` does not outlive `Library` it comes from.
pub struct Symbol<T> {
    /// Raw address of the symbol, as returned by `GetProcAddress`.
    pointer: FARPROC,
    /// Records the type `T` the address is interpreted as; holds no data.
    pd: marker::PhantomData<T>
}
169 
170 impl<T> Symbol<T> {
171     /// Convert the loaded Symbol into a handle.
into_raw(self) -> FARPROC172     pub fn into_raw(self) -> FARPROC {
173         let pointer = self.pointer;
174         mem::forget(self);
175         pointer
176     }
177 }
178 
179 impl<T> Symbol<Option<T>> {
180     /// Lift Option out of the symbol.
lift_option(self) -> Option<Symbol<T>>181     pub fn lift_option(self) -> Option<Symbol<T>> {
182         if self.pointer.is_null() {
183             None
184         } else {
185             Some(Symbol {
186                 pointer: self.pointer,
187                 pd: marker::PhantomData,
188             })
189         }
190     }
191 }
192 
// `Symbol` holds a raw pointer, which is neither `Send` nor `Sync` by default. These impls make
// the thread-safety of a `Symbol<T>` follow that of the type `T` it is interpreted as.
unsafe impl<T: Send> Send for Symbol<T> {}
unsafe impl<T: Sync> Sync for Symbol<T> {}
195 
196 impl<T> Clone for Symbol<T> {
clone(&self) -> Symbol<T>197     fn clone(&self) -> Symbol<T> {
198         Symbol { ..*self }
199     }
200 }
201 
202 impl<T> ::std::ops::Deref for Symbol<T> {
203     type Target = T;
deref(&self) -> &T204     fn deref(&self) -> &T {
205         unsafe {
206             // Additional reference level for a dereference on `deref` return value.
207             mem::transmute(&self.pointer)
208         }
209     }
210 }
211 
212 impl<T> fmt::Debug for Symbol<T> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result213     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
214         f.write_str(&format!("Symbol@{:p}", self.pointer))
215     }
216 }
217 
218 
/// Starts out `false`. It is set to `true` once `SetThreadErrorMode` has failed with
/// `ERROR_CALL_NOT_IMPLEMENTED`. From then on, `ErrorModeGuard` uses `SetErrorMode` instead, both
/// when switching the mode and when restoring it.
static USE_ERRORMODE: AtomicBool = ATOMIC_BOOL_INIT;
/// Guard returned by `ErrorModeGuard::new`. It holds the previous error mode, which its `Drop`
/// implementation restores.
struct ErrorModeGuard(DWORD);
221 
impl ErrorModeGuard {
    /// Switch the error mode to `SEM_FAILCE` for the duration of a library load.
    ///
    /// Returns `Some(guard)` holding the previous mode when it has to be restored later; the
    /// guard's `Drop` restores it. Returns `None` when there is nothing to restore:
    ///
    /// * the previous mode already equalled `SEM_FAILCE`, or
    /// * `SetThreadErrorMode` failed with an error other than `ERROR_CALL_NOT_IMPLEMENTED`.
    ///
    /// `SetThreadErrorMode` is tried first. Once it reports `ERROR_CALL_NOT_IMPLEMENTED`,
    /// `USE_ERRORMODE` is set, and this call and all later ones use `SetErrorMode` instead.
    fn new() -> Option<ErrorModeGuard> {
        // NOTE(review): presumably an abbreviation of the Win32 `SEM_FAILCRITICALERRORS` flag
        // (value 0x0001); confirm against the SetErrorMode documentation.
        const SEM_FAILCE: DWORD = 1;
        unsafe {
            if !USE_ERRORMODE.load(Ordering::Acquire) {
                let mut previous_mode = 0;
                let success = errhandlingapi::SetThreadErrorMode(SEM_FAILCE, &mut previous_mode) != 0;
                if !success && errhandlingapi::GetLastError() == winerror::ERROR_CALL_NOT_IMPLEMENTED {
                    // The per-thread API is unavailable. Remember that, then fall through to the
                    // `SetErrorMode` path below.
                    USE_ERRORMODE.store(true, Ordering::Release);
                } else if !success {
                    // SetThreadErrorMode failed with some other error? How in the world is it
                    // possible for what is essentially a simple variable swap to fail?
                    // For now we just ignore the error -- the worst that can happen here is
                    // the previous mode staying on and user seeing a dialog error on older Windows
                    // machines.
                    return None;
                } else if previous_mode == SEM_FAILCE {
                    // Already in the desired mode; nothing to restore on drop.
                    return None;
                } else {
                    return Some(ErrorModeGuard(previous_mode));
                }
            }
            // The value returned here is treated as the previous mode: it is what the guard
            // restores on drop.
            match errhandlingapi::SetErrorMode(SEM_FAILCE) {
                SEM_FAILCE => {
                    // This is important to reduce raciness when this library is used on multiple
                    // threads. In particular this helps with following race condition:
                    //
                    // T1: SetErrorMode(SEM_FAILCE)
                    // T2: SetErrorMode(SEM_FAILCE)
                    // T1: SetErrorMode(old_mode) # not SEM_FAILCE
                    // T2: SetErrorMode(SEM_FAILCE) # restores to SEM_FAILCE on drop
                    //
                    // Without this check, T2's guard would restore SEM_FAILCE last, leaving it set
                    // after both loads have finished. Returning `None` makes T2 skip the restore.
                    //
                    // This is still somewhat racy in a sense that T1 might restore the error
                    // mode before T2 finishes loading the library, but that is less of a
                    // concern – it will only end up in end user seeing a dialog.
                    //
                    // Also, SetErrorMode itself is probably not an atomic operation.
                    None
                }
                a => Some(ErrorModeGuard(a))
            }
        }
    }
}
266 
267 impl Drop for ErrorModeGuard {
drop(&mut self)268     fn drop(&mut self) {
269         unsafe {
270             if !USE_ERRORMODE.load(Ordering::Relaxed) {
271                 errhandlingapi::SetThreadErrorMode(self.0, ptr::null_mut());
272             } else {
273                 errhandlingapi::SetErrorMode(self.0);
274             }
275         }
276     }
277 }
278 
with_get_last_error<T, F>(closure: F) -> Result<T, Option<io::Error>> where F: FnOnce() -> Option<T>279 fn with_get_last_error<T, F>(closure: F) -> Result<T, Option<io::Error>>
280 where F: FnOnce() -> Option<T> {
281     closure().ok_or_else(|| {
282         let error = unsafe { errhandlingapi::GetLastError() };
283         if error == 0 {
284             None
285         } else {
286             Some(io::Error::from_raw_os_error(error as i32))
287         }
288     })
289 }
290 
#[test]
fn works_getlasterror() {
    // Look up `GetLastError` by a name without a trailing NUL and check that calling it through
    // the `Symbol` agrees with the direct binding.
    type GetLastErrorFn = unsafe extern "system" fn() -> DWORD;
    let kernel32 = Library::new("kernel32.dll").unwrap();
    let get_last_error = unsafe { kernel32.get::<GetLastErrorFn>(b"GetLastError") }.unwrap();
    unsafe {
        errhandlingapi::SetLastError(42);
        let direct = errhandlingapi::GetLastError();
        assert_eq!(direct, get_last_error());
    }
}
302 
#[test]
fn works_getlasterror0() {
    // Same check as `works_getlasterror`, but the symbol name carries its own NUL terminator.
    type GetLastErrorFn = unsafe extern "system" fn() -> DWORD;
    let kernel32 = Library::new("kernel32.dll").unwrap();
    let get_last_error = unsafe { kernel32.get::<GetLastErrorFn>(b"GetLastError\0") }.unwrap();
    unsafe {
        errhandlingapi::SetLastError(42);
        let direct = errhandlingapi::GetLastError();
        assert_eq!(direct, get_last_error());
    }
}
314