1 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
4 // option. This file may not be copied, modified, or distributed
5 // except according to those terms.
6 
7 #![allow(clippy::upper_case_acronyms)]
8 
9 use crate::agentio::as_c_void;
10 use crate::err::{Error, Res};
11 use crate::once::OnceResult;
12 use crate::ssl::{PRFileDesc, SSLTimeFunc};
13 
14 use std::boxed::Box;
15 use std::convert::{TryFrom, TryInto};
16 use std::ops::Deref;
17 use std::os::raw::c_void;
18 use std::pin::Pin;
19 use std::time::{Duration, Instant};
20 
21 include!(concat!(env!("OUT_DIR"), "/nspr_time.rs"));
22 
// Binding for NSS's experimental `SSL_SetTimeFunc`, which `TimeHolder::bind`
// uses to install a callback (`cb`, invoked with `arg`) that supplies `PRTime`
// values for `fd`.
// NOTE(review): `experimental_api!` is defined elsewhere; from its use in
// `TimeHolder::bind` it appears to generate an `unsafe fn` returning `Res<()>`
// — confirm against the macro definition.
experimental_api!(SSL_SetTimeFunc(
    fd: *mut PRFileDesc,
    cb: SSLTimeFunc,
    arg: *mut c_void,
));
28 
29 /// This struct holds the zero time used for converting between `Instant` and `PRTime`.
30 #[derive(Debug)]
31 struct TimeZero {
32     instant: Instant,
33     prtime: PRTime,
34 }
35 
36 impl TimeZero {
37     /// This function sets a baseline from an instance of `Instant`.
38     /// This allows for the possibility that code that uses these APIs will create
39     /// instances of `Instant` before any of this code is run.  If `Instant`s older than
40     /// `BASE_TIME` are used with these conversion functions, they will fail.
41     /// To avoid that, we make sure that this sets the base time using the first value
42     /// it sees if it is in the past.  If it is not, then use `Instant::now()` instead.
baseline(t: Instant) -> Self43     pub fn baseline(t: Instant) -> Self {
44         let now = Instant::now();
45         let prnow = unsafe { PR_Now() };
46 
47         if now <= t {
48             // `t` is in the future, just use `now`.
49             Self {
50                 instant: now,
51                 prtime: prnow,
52             }
53         } else {
54             let elapsed = Interval::from(now.duration_since(now));
55             // An error from these unwrap functions would require
56             // ridiculously long application running time.
57             let prelapsed: PRTime = elapsed.try_into().unwrap();
58             Self {
59                 instant: t,
60                 prtime: prnow.checked_sub(prelapsed).unwrap(),
61             }
62         }
63     }
64 }
65 
/// Lazily initialized base time shared by every conversion in this module.
///
/// It is filled in through `call_once`, either by `get_base` (from the current
/// time) or by `Time::from(Instant)` (via `TimeZero::baseline`); presumably the
/// first caller wins.  NOTE(review): this is a `static mut`, so soundness relies
/// on `OnceResult::call_once` synchronizing internally — confirm in `crate::once`.
static mut BASE_TIME: OnceResult<TimeZero> = OnceResult::new();
67 
get_base() -> &'static TimeZero68 fn get_base() -> &'static TimeZero {
69     let f = || TimeZero {
70         instant: Instant::now(),
71         prtime: unsafe { PR_Now() },
72     };
73     unsafe { BASE_TIME.call_once(f) }
74 }
75 
init()76 pub(crate) fn init() {
77     let _ = get_base();
78 }
79 
/// A point in time that can be converted to and from `PRTime`.
///
/// It wraps an [`Instant`] and dereferences to it, so `Instant` methods can be
/// called on a `Time` directly.
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct Time {
    t: Instant,
}

impl Deref for Time {
    type Target = Instant;

    fn deref(&self) -> &Instant {
        let Self { t } = self;
        t
    }
}
92 
93 impl From<Instant> for Time {
94     /// Convert from an Instant into a Time.
from(t: Instant) -> Self95     fn from(t: Instant) -> Self {
96         // Call `TimeZero::baseline(t)` so that time zero can be set.
97         let f = || TimeZero::baseline(t);
98         let _ = unsafe { BASE_TIME.call_once(f) };
99         Self { t }
100     }
101 }
102 
103 impl TryFrom<PRTime> for Time {
104     type Error = Error;
try_from(prtime: PRTime) -> Res<Self>105     fn try_from(prtime: PRTime) -> Res<Self> {
106         let base = get_base();
107         if let Some(delta) = prtime.checked_sub(base.prtime) {
108             let d = Duration::from_micros(delta.try_into()?);
109             base.instant
110                 .checked_add(d)
111                 .map_or(Err(Error::TimeTravelError), |t| Ok(Self { t }))
112         } else {
113             Err(Error::TimeTravelError)
114         }
115     }
116 }
117 
118 impl TryInto<PRTime> for Time {
119     type Error = Error;
try_into(self) -> Res<PRTime>120     fn try_into(self) -> Res<PRTime> {
121         let base = get_base();
122         // TODO(mt) use checked_duration_since when that is available.
123         let delta = self.t.duration_since(base.instant);
124         if let Ok(d) = PRTime::try_from(delta.as_micros()) {
125             d.checked_add(base.prtime).ok_or(Error::TimeTravelError)
126         } else {
127             Err(Error::TimeTravelError)
128         }
129     }
130 }
131 
132 impl From<Time> for Instant {
133     #[must_use]
from(t: Time) -> Self134     fn from(t: Time) -> Self {
135         t.t
136     }
137 }
138 
/// A span of time that can be converted to and from a `PRTime` microsecond count.
///
/// It wraps a [`Duration`] and dereferences to it.
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct Interval {
    d: Duration,
}

impl Deref for Interval {
    type Target = Duration;

    fn deref(&self) -> &Duration {
        let Self { d } = self;
        d
    }
}
151 
152 impl TryFrom<PRTime> for Interval {
153     type Error = Error;
try_from(prtime: PRTime) -> Res<Self>154     fn try_from(prtime: PRTime) -> Res<Self> {
155         Ok(Self {
156             d: Duration::from_micros(u64::try_from(prtime)?),
157         })
158     }
159 }
160 
161 impl From<Duration> for Interval {
from(d: Duration) -> Self162     fn from(d: Duration) -> Self {
163         Self { d }
164     }
165 }
166 
167 impl TryInto<PRTime> for Interval {
168     type Error = Error;
try_into(self) -> Res<PRTime>169     fn try_into(self) -> Res<PRTime> {
170         Ok(PRTime::try_from(self.d.as_micros())?)
171     }
172 }
173 
/// `TimeHolder` maintains a `PRTime` value in a form that is accessible to the TLS stack.
///
/// The value is kept in a `Pin<Box<_>>`, so its heap address stays fixed even if
/// the `TimeHolder` itself is moved; that address is what `bind` registers.
#[derive(Debug)]
pub struct TimeHolder {
    t: Pin<Box<PRTime>>,
}

impl TimeHolder {
    /// Time callback registered through `SSL_SetTimeFunc`; returns the `PRTime`
    /// that `arg` points at.
    ///
    /// # Safety
    ///
    /// `arg` must point to a live, aligned `PRTime`.  A null `arg` reaches the
    /// `unwrap` and panics inside an `extern "C"` function.
    unsafe extern "C" fn time_func(arg: *mut c_void) -> PRTime {
        let p = arg as *const PRTime;
        *p.as_ref().unwrap()
    }

    /// Install `time_func` as the time source for `fd`, passing it a pointer
    /// derived from `self.t`.
    ///
    /// NOTE(review): presumably `as_c_void` yields a pointer to the boxed
    /// `PRTime` (confirm in `crate::agentio`).  Nothing here ties the
    /// registration to `self`'s lifetime, so the caller must keep this
    /// `TimeHolder` alive while `fd` can still invoke the callback.
    pub fn bind(&mut self, fd: *mut PRFileDesc) -> Res<()> {
        unsafe { SSL_SetTimeFunc(fd, Some(Self::time_func), as_c_void(&mut self.t)) }
    }

    /// Update the time value that the bound callback reports.
    ///
    /// # Errors
    ///
    /// Returns an error if `t` cannot be converted to a `PRTime`; the stored value
    /// is left unchanged in that case because `?` returns before the assignment.
    pub fn set(&mut self, t: Instant) -> Res<()> {
        *self.t = Time::from(t).try_into()?;
        Ok(())
    }
}
195 
196 impl Default for TimeHolder {
default() -> Self197     fn default() -> Self {
198         TimeHolder { t: Box::pin(0) }
199     }
200 }
201 
#[cfg(test)]
mod test {
    use super::{get_base, init, Interval, PRTime, Time};
    use crate::err::Res;
    use std::convert::{TryFrom, TryInto};
    use std::time::{Duration, Instant};

    /// After the first (microsecond-truncating) conversion, round trips between
    /// `Time` and `PRTime` are stable.
    #[test]
    fn convert_stable() {
        init();
        let now = Time::from(Instant::now());
        let pr: PRTime = now.try_into().expect("convert to PRTime with truncation");
        let t2 = Time::try_from(pr).expect("convert to Instant");
        let pr2: PRTime = t2.try_into().expect("convert to PRTime again");
        assert_eq!(pr, pr2);
        let t3 = Time::try_from(pr2).expect("convert to Instant again");
        assert_eq!(t2, t3);
    }

    /// A `PRTime` just before the base time cannot be converted.
    #[test]
    fn past_time() {
        init();
        let base = get_base();
        assert!(Time::try_from(base.prtime - 1).is_err());
    }

    /// A negative `PRTime` cannot be converted.
    #[test]
    fn negative_time() {
        init();
        assert!(Time::try_from(-1).is_err());
    }

    /// A negative `PRTime` is not a valid interval.
    #[test]
    fn negative_interval() {
        init();
        assert!(Interval::try_from(-1).is_err());
    }

    /// A whole number of microseconds survives `PRTime` -> `Interval` -> `PRTime`.
    #[test]
    fn interval_round_trip() {
        init();
        let micros: PRTime = 1_500_000;
        let interval = Interval::try_from(micros).expect("convert to Interval");
        assert_eq!(*interval, Duration::from_micros(1_500_000));
        let back: PRTime = interval.try_into().expect("convert back to PRTime");
        assert_eq!(back, micros);
    }

    /// A microsecond count too large for `PRTime` is rejected.
    #[test]
    // `u64::max_value()` is used rather than the `u64::MAX` associated constant,
    // presumably so this builds on older toolchains that lack the constant.
    fn overflow_interval() {
        init();
        let interval = Interval::from(Duration::from_micros(u64::max_value()));
        let res: Res<PRTime> = interval.try_into();
        assert!(res.is_err());
    }
}
251