1 // Licensed under the Apache License, Version 2.0
2 // <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
4 // All files in the project carrying such notice may not be copied, modified, or distributed
5 // except according to those terms.
6 
7 //! Wrap several flavors of Windows error into a `Result`.
8 
9 use std::error::Error;
10 use std::fmt;
11 
12 use winapi::shared::minwindef::DWORD;
13 use winapi::shared::winerror::{
14     ERROR_SUCCESS, FACILITY_WIN32, HRESULT, HRESULT_FROM_WIN32, SUCCEEDED, S_OK,
15 };
16 use winapi::um::errhandlingapi::GetLastError;
17 
/// An error code, optionally with information about the failing call.
///
/// The call-site details are attached after construction with the `function` and
/// `file_line` builder methods, which the `check_succeeded!` and `check_true!` macros
/// call on the caller's behalf.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ErrorAndSource<T: ErrorCode> {
    /// The wrapped error code.
    code: T,
    /// Name of the failing function, if one was recorded.
    function: Option<&'static str>,
    /// Source file name and line number of the failing call, if recorded.
    file_line: Option<FileLine>,
}
25 
/// Abstraction over a raw Windows error code, used as the code type of an
/// [`ErrorAndSource`].
///
/// The `Display` output of an implementor is what appears at the end of a formatted
/// `ErrorAndSource`. [`get`](ErrorCode::get) exposes the raw value underneath.
pub trait ErrorCode: Copy + Eq + fmt::Debug + fmt::Display + Send + Sync + 'static {
    /// The raw code type being wrapped (in this module, `DWORD` or `HRESULT`).
    type InnerT: Copy + Eq;

    /// Return the raw code value.
    fn get(&self) -> Self::InnerT;
}
34 
35 impl<T> ErrorAndSource<T>
36 where
37     T: ErrorCode,
38 {
39     /// Get the underlying error code.
code(&self) -> T::InnerT40     pub fn code(&self) -> T::InnerT {
41         self.code.get()
42     }
43 
44     /// Add the name of the failing function to the error.
function(self, function: &'static str) -> Self45     pub fn function(self, function: &'static str) -> Self {
46         Self {
47             function: Some(function),
48             ..self
49         }
50     }
51 
52     /// Get the name of the failing function, if known.
get_function(&self) -> Option<&'static str>53     pub fn get_function(&self) -> Option<&'static str> {
54         self.function
55     }
56 
57     /// Add the source file name and line number of the call to the error.
file_line(self, file: &'static str, line: u32) -> Self58     pub fn file_line(self, file: &'static str, line: u32) -> Self {
59         Self {
60             file_line: Some(FileLine(file, line)),
61             ..self
62         }
63     }
64 
65     /// Get the source file name and line number of the failing call.
get_file_line(&self) -> &Option<FileLine>66     pub fn get_file_line(&self) -> &Option<FileLine> {
67         &self.file_line
68     }
69 }
70 
71 impl<T> fmt::Display for ErrorAndSource<T>
72 where
73     T: ErrorCode,
74 {
fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error>75     fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
76         if let Some(function) = self.function {
77             if let Some(ref file_line) = self.file_line {
78                 write!(f, "{} ", file_line)?;
79             }
80 
81             write!(f, "{} ", function)?;
82 
83             write!(f, "error: ")?;
84         }
85 
86         write!(f, "{}", self.code)?;
87 
88         Ok(())
89     }
90 }
91 
92 impl<T> Error for ErrorAndSource<T> where T: ErrorCode {}
93 
/// A source location, recorded as a file name and a line number.
///
/// Displays as `file:line`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FileLine(pub &'static str, pub u32);

impl fmt::Display for FileLine {
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        let FileLine(file, line) = self;
        write!(f, "{}:{}", file, line)
    }
}
102 
/// A [Win32 error code](https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-erref/18d8fbe8-a967-4f1c-ae50-99ca8e491d2d),
/// usually from `GetLastError()`.
///
/// Includes optional function name, source file name, and line number. Construct one with
/// `Win32Error::new` or `Win32Error::get_last_error`. See [`ErrorAndSource`] for additional
/// methods.
pub type Win32Error = ErrorAndSource<Win32ErrorInner>;
109 
110 impl Win32Error {
111     /// Create from an error code.
new(code: DWORD) -> Self112     pub fn new(code: DWORD) -> Self {
113         Win32Error {
114             code: Win32ErrorInner(code),
115             function: None,
116             file_line: None,
117         }
118     }
119 
120     /// Create from `GetLastError()`
get_last_error() -> Self121     pub fn get_last_error() -> Self {
122         Win32Error::new(unsafe { GetLastError() })
123     }
124 }
125 
126 #[doc(hidden)]
127 #[repr(transparent)]
128 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
129 pub struct Win32ErrorInner(DWORD);
130 
131 impl ErrorCode for Win32ErrorInner {
132     type InnerT = DWORD;
133 
get(&self) -> DWORD134     fn get(&self) -> DWORD {
135         self.0
136     }
137 }
138 
139 impl fmt::Display for Win32ErrorInner {
fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error>140     fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
141         write!(f, "{:#010x}", self.0)
142     }
143 }
144 
/// An [HRESULT error code](https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-erref/0642cb2f-2075-4469-918c-4441e69c548a).
/// These usually come from COM APIs.
///
/// Includes optional function name, source file name, and line number. A [`Win32Error`]
/// converts into an `HResult` via `From`; use `HResult::try_into_win32_err` to go the other
/// way. See [`ErrorAndSource`] for additional methods.
pub type HResult = ErrorAndSource<HResultInner>;
151 
152 impl HResult {
153     /// Create from an `HRESULT`.
new(hr: HRESULT) -> Self154     pub fn new(hr: HRESULT) -> Self {
155         HResult {
156             code: HResultInner(hr),
157             function: None,
158             file_line: None,
159         }
160     }
161 
162     /// Get the result code portion of the `HRESULT`
extract_code(&self) -> HRESULT163     pub fn extract_code(&self) -> HRESULT {
164         // from winerror.h HRESULT_CODE macro
165         self.code.0 & 0xFFFF
166     }
167 
168     /// Get the facility portion of the `HRESULT`
extract_facility(&self) -> HRESULT169     pub fn extract_facility(&self) -> HRESULT {
170         // from winerror.h HRESULT_FACILITY macro
171         (self.code.0 >> 16) & 0x1fff
172     }
173 
174     /// If the `HResult` corresponds to a Win32 error, convert.
175     ///
176     /// Returns the original `HResult` as an error on failure.
try_into_win32_err(self) -> Result<Win32Error, Self>177     pub fn try_into_win32_err(self) -> Result<Win32Error, Self> {
178         let code = if self.code() == S_OK {
179             // Special case, facility is not set.
180             ERROR_SUCCESS
181         } else if self.extract_facility() == FACILITY_WIN32 {
182             self.extract_code() as DWORD
183         } else {
184             return Err(self);
185         };
186 
187         Ok(Win32Error {
188             code: Win32ErrorInner(code),
189             function: self.function,
190             file_line: self.file_line,
191         })
192     }
193 }
194 
195 #[doc(hidden)]
196 #[repr(transparent)]
197 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
198 pub struct HResultInner(HRESULT);
199 
200 impl ErrorCode for HResultInner {
201     type InnerT = HRESULT;
202 
get(&self) -> HRESULT203     fn get(&self) -> HRESULT {
204         self.0
205     }
206 }
207 
208 impl fmt::Display for HResultInner {
fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error>209     fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
210         write!(f, "HRESULT {:#010x}", self.0)
211     }
212 }
213 
/// Extra functions to work with a `Result<T, ErrorAndSource>`.
///
/// `T` is the success type of the extended `Result` and `E` its error type.
pub trait ResultExt<T, E> {
    /// Raw error code type that `allow_err` and `allow_err_with` compare against.
    type Code;

    /// Add the name of the failing function to the error.
    fn function(self, function: &'static str) -> Self;

    /// Add the source file name and line number of the call to the error.
    fn file_line(self, file: &'static str, line: u32) -> Self;

    /// Replace `Err(code)` with `Ok(replacement)`.
    fn allow_err(self, code: Self::Code, replacement: T) -> Self;

    /// Replace `Err(code)` with `Ok(replacement())`. `replacement` is only called when
    /// the error code matches.
    fn allow_err_with<F: FnOnce() -> T>(self, code: Self::Code, replacement: F) -> Self;
}
232 
233 impl<T, EC> ResultExt<T, ErrorAndSource<EC>> for Result<T, ErrorAndSource<EC>>
234 where
235     EC: ErrorCode,
236 {
237     type Code = EC::InnerT;
238 
function(self, function: &'static str) -> Self239     fn function(self, function: &'static str) -> Self {
240         self.map_err(|e| e.function(function))
241     }
242 
file_line(self, file: &'static str, line: u32) -> Self243     fn file_line(self, file: &'static str, line: u32) -> Self {
244         self.map_err(|e| e.file_line(file, line))
245     }
246 
allow_err(self, code: Self::Code, replacement: T) -> Self247     fn allow_err(self, code: Self::Code, replacement: T) -> Self {
248         self.or_else(|e| {
249             if e.code() == code {
250                 Ok(replacement)
251             } else {
252                 Err(e)
253             }
254         })
255     }
256 
allow_err_with<F>(self, code: Self::Code, replacement: F) -> Self where F: FnOnce() -> T,257     fn allow_err_with<F>(self, code: Self::Code, replacement: F) -> Self
258     where
259         F: FnOnce() -> T,
260     {
261         self.or_else(|e| {
262             if e.code() == code {
263                 Ok(replacement())
264             } else {
265                 Err(e)
266             }
267         })
268     }
269 }
270 
271 impl From<Win32Error> for HResult {
from(win32_error: Win32Error) -> Self272     fn from(win32_error: Win32Error) -> Self {
273         HResult {
274             code: HResultInner(HRESULT_FROM_WIN32(win32_error.code())),
275             function: win32_error.function,
276             file_line: win32_error.file_line,
277         }
278     }
279 }
280 
281 /// Convert an `HRESULT` into a `Result`.
succeeded_or_err(hr: HRESULT) -> Result<HRESULT, HResult>282 pub fn succeeded_or_err(hr: HRESULT) -> Result<HRESULT, HResult> {
283     if !SUCCEEDED(hr) {
284         Err(HResult::new(hr))
285     } else {
286         Ok(hr)
287     }
288 }
289 
/// Call a function that returns an `HRESULT`, convert to a `Result`.
///
/// The error will be augmented with the name of the function and the file and line number of
/// the macro usage.
///
/// The function must be named by a bare identifier. The matcher uses an `ident` fragment, so a
/// path such as `module::Function` is not accepted. A trailing comma after the last argument
/// is allowed.
///
/// # Example
/// ```no_run
/// # extern crate winapi;
/// # use std::ptr;
/// # use winapi::um::combaseapi::CoUninitialize;
/// # use winapi::um::objbase::CoInitialize;
/// # use comedy::{check_succeeded, HResult};
/// #
/// fn coinit() -> Result<(), HResult> {
///     unsafe {
///         check_succeeded!(CoInitialize(ptr::null_mut()))?;
///
///         CoUninitialize();
///     }
///     Ok(())
/// }
/// ```
#[macro_export]
macro_rules! check_succeeded {
    // Main form: `f(arg, ...)` with zero or more arguments and no trailing comma.
    ($f:ident ( $($arg:expr),* )) => {
        {
            // Brings the `function` / `file_line` extension methods into scope for the
            // `Result` produced by `succeeded_or_err`.
            use $crate::error::ResultExt;
            $crate::error::succeeded_or_err($f($($arg),*))
                .function(stringify!($f))
                .file_line(file!(), line!())
        }
    };

    // Trailing-comma form: drop the final comma and re-invoke the main form.
    ($f:ident ( $($arg:expr),+ , )) => {
        $crate::check_succeeded!($f($($arg),+))
    };
}
328 
329 /// Convert an integer return value into a `Result`, using `GetLastError()` if zero.
true_or_last_err<T>(rv: T) -> Result<T, Win32Error> where T: Eq, T: From<bool>,330 pub fn true_or_last_err<T>(rv: T) -> Result<T, Win32Error>
331 where
332     T: Eq,
333     T: From<bool>,
334 {
335     if rv == T::from(false) {
336         Err(Win32Error::get_last_error())
337     } else {
338         Ok(rv)
339     }
340 }
341 
/// Call a function that returns an integer, convert to a `Result`, using `GetLastError()` if
/// zero.
///
/// The error will be augmented with the name of the function and the file and line number of
/// the macro usage.
///
/// The function must be named by a bare identifier. The matcher uses an `ident` fragment, so a
/// path such as `module::Function` is not accepted. A trailing comma after the last argument
/// is allowed.
///
/// # Example
/// ```no_run
/// # extern crate winapi;
/// # use winapi::shared::minwindef::BOOL;
/// # use winapi::um::fileapi::FlushFileBuffers;
/// # use winapi::um::winnt::HANDLE;
/// # use comedy::{check_true, Win32Error};
/// #
/// fn flush(file: HANDLE) -> Result<(), Win32Error> {
///     unsafe {
///         check_true!(FlushFileBuffers(file))?;
///     }
///     Ok(())
/// }
/// ```
#[macro_export]
macro_rules! check_true {
    // Main form: `f(arg, ...)` with zero or more arguments and no trailing comma.
    ($f:ident ( $($arg:expr),* )) => {
        {
            // Brings the `function` / `file_line` extension methods into scope for the
            // `Result` produced by `true_or_last_err`.
            use $crate::error::ResultExt;
            $crate::error::true_or_last_err($f($($arg),*))
                .function(stringify!($f))
                .file_line(file!(), line!())
        }
    };

    // Trailing-comma form: drop the final comma and re-invoke the main form.
    ($f:ident ( $($arg:expr),+ , )) => {
        $crate::check_true!($f($($arg),+))
    };
}
378