1 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
4 // option. This file may not be copied, modified, or distributed
5 // except according to those terms.
6 
// `clippy::upper_case_acronyms` presumably fires on the all-caps NSPR/NSS-style
// names this module pulls in (e.g. `PRTime`, `SSLTimeFunc`). The other three
// allowances keep toolchains that do not yet know that lint from warning about
// the allow attribute itself.
#![allow(
    unknown_lints,
    renamed_and_removed_lints,
    clippy::unknown_clippy_lints,
    clippy::upper_case_acronyms
)] // Until we require rust 1.51.
13 
14 use crate::agentio::as_c_void;
15 use crate::err::{Error, Res};
16 use crate::once::OnceResult;
17 use crate::ssl::{PRFileDesc, SSLTimeFunc};
18 
19 use std::boxed::Box;
20 use std::convert::{TryFrom, TryInto};
21 use std::ops::Deref;
22 use std::os::raw::c_void;
23 use std::pin::Pin;
24 use std::time::{Duration, Instant};
25 
// Pull in the NSPR time declarations generated at build time into `OUT_DIR`.
// `PRTime` and `PR_Now`, used throughout this file without any other import,
// presumably come from here.
include!(concat!(env!("OUT_DIR"), "/nspr_time.rs"));

// Declare `SSL_SetTimeFunc`, presumably the NSS hook that installs a callback
// (`cb`, invoked with `arg`) for reading the current time on `fd`.
// NOTE(review): `experimental_api!` is defined elsewhere; judging by
// `TimeHolder::bind`, it yields an `unsafe` function returning `Res<()>` —
// confirm against the macro definition.
experimental_api!(SSL_SetTimeFunc(
    fd: *mut PRFileDesc,
    cb: SSLTimeFunc,
    arg: *mut c_void,
));
33 
34 /// This struct holds the zero time used for converting between `Instant` and `PRTime`.
35 #[derive(Debug)]
36 struct TimeZero {
37     instant: Instant,
38     prtime: PRTime,
39 }
40 
41 impl TimeZero {
42     /// This function sets a baseline from an instance of `Instant`.
43     /// This allows for the possibility that code that uses these APIs will create
44     /// instances of `Instant` before any of this code is run.  If `Instant`s older than
45     /// `BASE_TIME` are used with these conversion functions, they will fail.
46     /// To avoid that, we make sure that this sets the base time using the first value
47     /// it sees if it is in the past.  If it is not, then use `Instant::now()` instead.
baseline(t: Instant) -> Self48     pub fn baseline(t: Instant) -> Self {
49         let now = Instant::now();
50         let prnow = unsafe { PR_Now() };
51 
52         if now <= t {
53             // `t` is in the future, just use `now`.
54             Self {
55                 instant: now,
56                 prtime: prnow,
57             }
58         } else {
59             let elapsed = Interval::from(now.duration_since(now));
60             // An error from these unwrap functions would require
61             // ridiculously long application running time.
62             let prelapsed: PRTime = elapsed.try_into().unwrap();
63             Self {
64                 instant: t,
65                 prtime: prnow.checked_sub(prelapsed).unwrap(),
66             }
67         }
68     }
69 }
70 
/// Process-wide time baseline shared by all conversions in this module.
///
/// It is populated through `call_once` either by `get_base` (using
/// `Instant::now()`) or by `Time::from(Instant)` (using `TimeZero::baseline`),
/// presumably whichever reaches it first.
// NOTE(review): this is a `static mut`; every access in this file goes through
// `OnceResult::call_once`, so soundness relies on that type synchronizing
// initialization — confirm in `crate::once`.
static mut BASE_TIME: OnceResult<TimeZero> = OnceResult::new();
72 
get_base() -> &'static TimeZero73 fn get_base() -> &'static TimeZero {
74     let f = || TimeZero {
75         instant: Instant::now(),
76         prtime: unsafe { PR_Now() },
77     };
78     unsafe { BASE_TIME.call_once(f) }
79 }
80 
init()81 pub(crate) fn init() {
82     let _ = get_base();
83 }
84 
/// A point in time that can be converted to and from an NSPR `PRTime`.
///
/// This is a thin wrapper around `Instant`; conversions to and from `PRTime`
/// are computed relative to the module's shared time baseline.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Time {
    /// The wrapped monotonic instant.
    t: Instant,
}
90 
91 impl Deref for Time {
92     type Target = Instant;
deref(&self) -> &Self::Target93     fn deref(&self) -> &Self::Target {
94         &self.t
95     }
96 }
97 
98 impl From<Instant> for Time {
99     /// Convert from an Instant into a Time.
from(t: Instant) -> Self100     fn from(t: Instant) -> Self {
101         // Call `TimeZero::baseline(t)` so that time zero can be set.
102         let f = || TimeZero::baseline(t);
103         let _ = unsafe { BASE_TIME.call_once(f) };
104         Self { t }
105     }
106 }
107 
108 impl TryFrom<PRTime> for Time {
109     type Error = Error;
try_from(prtime: PRTime) -> Res<Self>110     fn try_from(prtime: PRTime) -> Res<Self> {
111         let base = get_base();
112         if let Some(delta) = prtime.checked_sub(base.prtime) {
113             let d = Duration::from_micros(delta.try_into()?);
114             base.instant
115                 .checked_add(d)
116                 .map_or(Err(Error::TimeTravelError), |t| Ok(Self { t }))
117         } else {
118             Err(Error::TimeTravelError)
119         }
120     }
121 }
122 
123 impl TryInto<PRTime> for Time {
124     type Error = Error;
try_into(self) -> Res<PRTime>125     fn try_into(self) -> Res<PRTime> {
126         let base = get_base();
127         // TODO(mt) use checked_duration_since when that is available.
128         let delta = self.t.duration_since(base.instant);
129         if let Ok(d) = PRTime::try_from(delta.as_micros()) {
130             d.checked_add(base.prtime).ok_or(Error::TimeTravelError)
131         } else {
132             Err(Error::TimeTravelError)
133         }
134     }
135 }
136 
137 impl From<Time> for Instant {
138     #[must_use]
from(t: Time) -> Self139     fn from(t: Time) -> Self {
140         t.t
141     }
142 }
143 
/// A span of time that can be converted to and from an NSPR `PRTime` interval,
/// measured in whole microseconds.
///
/// This is a thin wrapper around `Duration`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Interval {
    /// The wrapped duration.
    d: Duration,
}
149 
150 impl Deref for Interval {
151     type Target = Duration;
deref(&self) -> &Self::Target152     fn deref(&self) -> &Self::Target {
153         &self.d
154     }
155 }
156 
157 impl TryFrom<PRTime> for Interval {
158     type Error = Error;
try_from(prtime: PRTime) -> Res<Self>159     fn try_from(prtime: PRTime) -> Res<Self> {
160         Ok(Self {
161             d: Duration::from_micros(u64::try_from(prtime)?),
162         })
163     }
164 }
165 
166 impl From<Duration> for Interval {
from(d: Duration) -> Self167     fn from(d: Duration) -> Self {
168         Self { d }
169     }
170 }
171 
172 impl TryInto<PRTime> for Interval {
173     type Error = Error;
try_into(self) -> Res<PRTime>174     fn try_into(self) -> Res<PRTime> {
175         Ok(PRTime::try_from(self.d.as_micros())?)
176     }
177 }
178 
179 /// `TimeHolder` maintains a `PRTime` value in a form that is accessible to the TLS stack.
180 #[derive(Debug)]
181 pub struct TimeHolder {
182     t: Pin<Box<PRTime>>,
183 }
184 
185 impl TimeHolder {
time_func(arg: *mut c_void) -> PRTime186     unsafe extern "C" fn time_func(arg: *mut c_void) -> PRTime {
187         let p = arg as *const PRTime;
188         *p.as_ref().unwrap()
189     }
190 
bind(&mut self, fd: *mut PRFileDesc) -> Res<()>191     pub fn bind(&mut self, fd: *mut PRFileDesc) -> Res<()> {
192         unsafe { SSL_SetTimeFunc(fd, Some(Self::time_func), as_c_void(&mut self.t)) }
193     }
194 
set(&mut self, t: Instant) -> Res<()>195     pub fn set(&mut self, t: Instant) -> Res<()> {
196         *self.t = Time::from(t).try_into()?;
197         Ok(())
198     }
199 }
200 
201 impl Default for TimeHolder {
default() -> Self202     fn default() -> Self {
203         TimeHolder { t: Box::pin(0) }
204     }
205 }
206 
#[cfg(test)]
mod test {
    use super::{get_base, init, Interval, PRTime, Time};
    use crate::err::Res;
    use std::convert::{TryFrom, TryInto};
    use std::time::{Duration, Instant};

    /// After the first (possibly truncating) conversion, a `Time` survives a
    /// round trip through `PRTime` unchanged.
    #[test]
    fn convert_stable() {
        init();
        let now = Time::from(Instant::now());
        let pr: PRTime = now.try_into().expect("convert to PRTime with truncation");
        let t2 = Time::try_from(pr).expect("convert to Instant");
        let pr2: PRTime = t2.try_into().expect("convert to PRTime again");
        assert_eq!(pr, pr2);
        let t3 = Time::try_from(pr2).expect("convert to Instant again");
        assert_eq!(t2, t3);
    }

    /// A `PRTime` earlier than the baseline cannot become a `Time`.
    #[test]
    fn past_time() {
        init();
        let base = get_base();
        assert!(Time::try_from(base.prtime - 1).is_err());
    }

    /// Negative `PRTime` values cannot become a `Time`.
    #[test]
    fn negative_time() {
        init();
        assert!(Time::try_from(-1).is_err());
    }

    /// Negative `PRTime` values cannot become an `Interval`.
    #[test]
    fn negative_interval() {
        init();
        assert!(Interval::try_from(-1).is_err());
    }

    /// A whole number of microseconds converts to an `Interval` and back exactly.
    #[test]
    fn interval_round_trip() {
        init();
        let micros: PRTime = 1_234_567;
        let interval = Interval::try_from(micros).expect("convert to Interval");
        assert_eq!(*interval, Duration::from_micros(1_234_567));
        let back: PRTime = interval.try_into().expect("convert back to PRTime");
        assert_eq!(back, micros);
    }

    /// Durations too large for a `PRTime` are rejected rather than truncated.
    #[test]
    fn overflow_interval() {
        init();
        let interval = Interval::from(Duration::from_micros(u64::max_value()));
        let res: Res<PRTime> = interval.try_into();
        assert!(res.is_err());
    }
}
256