1 extern crate winapi;
2 use self::winapi::shared::minwindef::{WORD, DWORD, HMODULE, FARPROC};
3 use self::winapi::shared::ntdef::WCHAR;
4 use self::winapi::shared::winerror;
5 use self::winapi::um::{errhandlingapi, libloaderapi};
6 
7 use util::{ensure_compatible_types, cstr_cow_from_bytes};
8 
9 use std::ffi::{OsStr, OsString};
10 use std::{fmt, io, marker, mem, ptr};
11 use std::os::windows::ffi::{OsStrExt, OsStringExt};
12 use std::sync::atomic::{AtomicBool, ATOMIC_BOOL_INIT, Ordering};
13 
14 
/// A platform-specific equivalent of the cross-platform `Library`.
///
/// Wraps the raw module handle (`HMODULE`) obtained from `LoadLibraryW`. The handle is released
/// with `FreeLibrary` when the value is dropped.
pub struct Library(HMODULE);
17 
// The wrapped module handle is a raw pointer, so `Send` has to be asserted manually.
unsafe impl Send for Library {}
// Now, this is sort-of-tricky. MSDN documentation does not really make any claims as to the
// safety of the Win32 APIs. Sadly, whoever I asked, even current and former Microsoft employees,
// couldn’t say for sure whether the Win32 APIs used to implement `Library` are thread-safe or
// not.
//
// My investigation ended up with a question about the thread-safety properties of the API
// involved being sent to an internal (to MS) general question mailing-list. The conclusion of
// the mail is as such:
//
// * Nobody inside MS (at least out of all the people who have seen the question) knows for
//   sure either;
// * However, the general consensus between MS developers is that one can rely on the API being
//   thread-safe. In case it is not thread-safe it should be considered a bug on the Windows
//   part. (NB: bugs filed at https://connect.microsoft.com/ against Windows Server)
unsafe impl Sync for Library {}
33 
34 impl Library {
35     /// Find and load a shared library (module).
36     ///
37     /// Corresponds to `LoadLibraryW(filename)`.
38     #[inline]
new<P: AsRef<OsStr>>(filename: P) -> ::Result<Library>39     pub fn new<P: AsRef<OsStr>>(filename: P) -> ::Result<Library> {
40         let wide_filename: Vec<u16> = filename.as_ref().encode_wide().chain(Some(0)).collect();
41         let _guard = ErrorModeGuard::new();
42 
43         let ret = with_get_last_error(|| {
44             // Make sure no winapi calls as a result of drop happen inside this closure, because
45             // otherwise that might change the return value of the GetLastError.
46             let handle = unsafe { libloaderapi::LoadLibraryW(wide_filename.as_ptr()) };
47             if handle.is_null()  {
48                 None
49             } else {
50                 Some(Library(handle))
51             }
52         }).map_err(|e| e.unwrap_or_else(||
53             panic!("LoadLibraryW failed but GetLastError did not report the error")
54         ));
55 
56         drop(wide_filename); // Drop wide_filename here to ensure it doesn’t get moved and dropped
57                              // inside the closure by mistake. See comment inside the closure.
58         ret
59     }
60 
61     /// Get a pointer to function or static variable by symbol name.
62     ///
63     /// The `symbol` may not contain any null bytes, with an exception of last byte. A null
64     /// terminated `symbol` may avoid a string allocation in some cases.
65     ///
66     /// Symbol is interpreted as-is; no mangling is done. This means that symbols like `x::y` are
67     /// most likely invalid.
68     ///
69     /// ## Unsafety
70     ///
71     /// Pointer to a value of arbitrary type is returned. Using a value with wrong type is
72     /// undefined.
get<T>(&self, symbol: &[u8]) -> ::Result<Symbol<T>>73     pub unsafe fn get<T>(&self, symbol: &[u8]) -> ::Result<Symbol<T>> {
74         ensure_compatible_types::<T, FARPROC>();
75         let symbol = try!(cstr_cow_from_bytes(symbol));
76         with_get_last_error(|| {
77             let symbol = libloaderapi::GetProcAddress(self.0, symbol.as_ptr());
78             if symbol.is_null() {
79                 None
80             } else {
81                 Some(Symbol {
82                     pointer: symbol,
83                     pd: marker::PhantomData
84                 })
85             }
86         }).map_err(|e| e.unwrap_or_else(||
87             panic!("GetProcAddress failed but GetLastError did not report the error")
88         ))
89     }
90 
91     /// Get a pointer to function or static variable by ordinal number.
92     ///
93     /// ## Unsafety
94     ///
95     /// Pointer to a value of arbitrary type is returned. Using a value with wrong type is
96     /// undefined.
get_ordinal<T>(&self, ordinal: WORD) -> ::Result<Symbol<T>>97     pub unsafe fn get_ordinal<T>(&self, ordinal: WORD) -> ::Result<Symbol<T>> {
98         ensure_compatible_types::<T, FARPROC>();
99         with_get_last_error(|| {
100             let ordinal = ordinal as usize as *mut _;
101             let symbol = libloaderapi::GetProcAddress(self.0, ordinal);
102             if symbol.is_null() {
103                 None
104             } else {
105                 Some(Symbol {
106                     pointer: symbol,
107                     pd: marker::PhantomData
108                 })
109             }
110         }).map_err(|e| e.unwrap_or_else(||
111             panic!("GetProcAddress failed but GetLastError did not report the error")
112         ))
113     }
114 }
115 
116 impl Drop for Library {
drop(&mut self)117     fn drop(&mut self) {
118         with_get_last_error(|| {
119             if unsafe { libloaderapi::FreeLibrary(self.0) == 0 } {
120                 None
121             } else {
122                 Some(())
123             }
124         }).unwrap()
125     }
126 }
127 
128 impl fmt::Debug for Library {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result129     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
130         unsafe {
131             let mut buf: [WCHAR; 1024] = mem::uninitialized();
132             let len = libloaderapi::GetModuleFileNameW(self.0,
133                                                    (&mut buf[..]).as_mut_ptr(), 1024) as usize;
134             if len == 0 {
135                 f.write_str(&format!("Library@{:p}", self.0))
136             } else {
137                 let string: OsString = OsString::from_wide(&buf[..len]);
138                 f.write_str(&format!("Library@{:p} from {:?}", self.0, string))
139             }
140         }
141     }
142 }
143 
/// Symbol from a library.
///
/// A major difference compared to the cross-platform `Symbol` is that this does not ensure the
/// `Symbol` does not outlive `Library` it comes from.
pub struct Symbol<T> {
    /// Raw address of the symbol as returned by `GetProcAddress`.
    pointer: FARPROC,
    /// Ties the symbol to the type `T` it is exposed as; no `T` is actually stored.
    pd: marker::PhantomData<T>
}
152 
153 impl<T> Symbol<Option<T>> {
154     /// Lift Option out of the symbol.
lift_option(self) -> Option<Symbol<T>>155     pub fn lift_option(self) -> Option<Symbol<T>> {
156         if self.pointer.is_null() {
157             None
158         } else {
159             Some(Symbol {
160                 pointer: self.pointer,
161                 pd: marker::PhantomData,
162             })
163         }
164     }
165 }
166 
// `Symbol<T>` stores a raw pointer, which suppresses the automatic `Send`/`Sync` impls; they are
// re-enabled here whenever the exposed type `T` itself is `Send`/`Sync`.
unsafe impl<T: Send> Send for Symbol<T> {}
unsafe impl<T: Sync> Sync for Symbol<T> {}
169 
170 impl<T> Clone for Symbol<T> {
clone(&self) -> Symbol<T>171     fn clone(&self) -> Symbol<T> {
172         Symbol { ..*self }
173     }
174 }
175 
176 impl<T> ::std::ops::Deref for Symbol<T> {
177     type Target = T;
deref(&self) -> &T178     fn deref(&self) -> &T {
179         unsafe {
180             // Additional reference level for a dereference on `deref` return value.
181             mem::transmute(&self.pointer)
182         }
183     }
184 }
185 
186 impl<T> fmt::Debug for Symbol<T> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result187     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
188         f.write_str(&format!("Symbol@{:p}", self.pointer))
189     }
190 }
191 
192 
// Starts out `false`. Flipped to `true` once `SetThreadErrorMode` reports
// `ERROR_CALL_NOT_IMPLEMENTED`; from then on the guard uses `SetErrorMode` instead.
static USE_ERRORMODE: AtomicBool = ATOMIC_BOOL_INIT;
// Holds the error mode that was in effect before the guard was created, so that it can be put
// back when the guard is dropped.
struct ErrorModeGuard(DWORD);
195 
impl ErrorModeGuard {
    /// Switches the error mode to `SEM_FAILCE` and returns a guard that restores the previous
    /// mode on drop.
    ///
    /// Returns `None` when there is nothing to restore: either the mode already was
    /// `SEM_FAILCE`, or `SetThreadErrorMode` failed for a reason other than being
    /// unimplemented.
    fn new() -> Option<ErrorModeGuard> {
        // NOTE(review): value 1 presumably corresponds to `SEM_FAILCRITICALERRORS` -- confirm
        // against the Windows headers.
        const SEM_FAILCE: DWORD = 1;
        unsafe {
            // Prefer the per-thread API for as long as it has not been found to be missing.
            if !USE_ERRORMODE.load(Ordering::Acquire) {
                let mut previous_mode = 0;
                let success = errhandlingapi::SetThreadErrorMode(SEM_FAILCE, &mut previous_mode) != 0;
                if !success && errhandlingapi::GetLastError() == winerror::ERROR_CALL_NOT_IMPLEMENTED {
                    // Remember that the per-thread API is unavailable, then fall through to the
                    // `SetErrorMode` path below.
                    USE_ERRORMODE.store(true, Ordering::Release);
                } else if !success {
                    // SetThreadErrorMode failed with some other error? How in the world is it
                    // possible for what is essentially a simple variable swap to fail?
                    // For now we just ignore the error -- the worst that can happen here is
                    // the previous mode staying on and user seeing a dialog error on older Windows
                    // machines.
                    return None;
                } else if previous_mode == SEM_FAILCE {
                    // The mode was already what we want; nothing to restore later.
                    return None;
                } else {
                    return Some(ErrorModeGuard(previous_mode));
                }
            }
            match errhandlingapi::SetErrorMode(SEM_FAILCE) {
                SEM_FAILCE => {
                    // This is important to reduce racy-ness when this library is used on multiple
                    // threads. In particular this helps with following race condition:
                    //
                    // T1: SetErrorMode(SEM_FAILCE)
                    // T2: SetErrorMode(SEM_FAILCE)
                    // T1: SetErrorMode(old_mode) # not SEM_FAILCE
                    // T2: SetErrorMode(SEM_FAILCE) # restores to SEM_FAILCE on drop
                    //
                    // This is still somewhat racy in a sense that T1 might restore the error
                    // mode before T2 finishes loading the library, but that is less of a
                    // concern – it will only end up in end user seeing a dialog.
                    //
                    // Also, SetErrorMode itself is probably not an atomic operation.
                    None
                }
                a => Some(ErrorModeGuard(a))
            }
        }
    }
}
240 
241 impl Drop for ErrorModeGuard {
drop(&mut self)242     fn drop(&mut self) {
243         unsafe {
244             if !USE_ERRORMODE.load(Ordering::Relaxed) {
245                 errhandlingapi::SetThreadErrorMode(self.0, ptr::null_mut());
246             } else {
247                 errhandlingapi::SetErrorMode(self.0);
248             }
249         }
250     }
251 }
252 
with_get_last_error<T, F>(closure: F) -> Result<T, Option<io::Error>> where F: FnOnce() -> Option<T>253 fn with_get_last_error<T, F>(closure: F) -> Result<T, Option<io::Error>>
254 where F: FnOnce() -> Option<T> {
255     closure().ok_or_else(|| {
256         let error = unsafe { errhandlingapi::GetLastError() };
257         if error == 0 {
258             None
259         } else {
260             Some(io::Error::from_raw_os_error(error as i32))
261         }
262     })
263 }
264 
// Looks up `GetLastError` in kernel32.dll by a name without a trailing NUL and checks that the
// dynamically resolved function agrees with the statically linked one.
#[test]
fn works_getlasterror() {
    let kernel32 = Library::new("kernel32.dll").unwrap();
    unsafe {
        let dynamic_gle: Symbol<unsafe extern "system" fn() -> DWORD> =
            kernel32.get(b"GetLastError").unwrap();
        errhandlingapi::SetLastError(42);
        assert_eq!(errhandlingapi::GetLastError(), dynamic_gle())
    }
}
276 
// Same check as `works_getlasterror`, but with an explicitly NUL-terminated symbol name.
#[test]
fn works_getlasterror0() {
    let kernel32 = Library::new("kernel32.dll").unwrap();
    unsafe {
        let dynamic_gle: Symbol<unsafe extern "system" fn() -> DWORD> =
            kernel32.get(b"GetLastError\0").unwrap();
        errhandlingapi::SetLastError(42);
        assert_eq!(errhandlingapi::GetLastError(), dynamic_gle())
    }
}
288