1 use crate::report::BenchmarkId as InternalBenchmarkId;
2 use crate::Throughput;
3 use std::cell::RefCell;
4 use std::convert::TryFrom;
5 use std::io::{Read, Write};
6 use std::mem::size_of;
7 use std::net::TcpStream;
8 
/// Errors that can occur while sending or receiving a message over the
/// connection.
#[derive(Debug)]
pub enum MessageError {
    /// A message could not be serialized to, or deserialized from, CBOR.
    SerializationError(serde_cbor::Error),
    /// Reading from or writing to the underlying stream failed.
    IoError(std::io::Error),
}
14 impl From<serde_cbor::Error> for MessageError {
from(other: serde_cbor::Error) -> Self15     fn from(other: serde_cbor::Error) -> Self {
16         MessageError::SerializationError(other)
17     }
18 }
19 impl From<std::io::Error> for MessageError {
from(other: std::io::Error) -> Self20     fn from(other: std::io::Error) -> Self {
21         MessageError::IoError(other)
22     }
23 }
24 impl std::fmt::Display for MessageError {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result25     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26         match self {
27             MessageError::SerializationError(error) => write!(
28                 f,
29                 "Failed to serialize or deserialize message to Criterion.rs benchmark:\n{}",
30                 error
31             ),
32             MessageError::IoError(error) => write!(
33                 f,
34                 "Failed to read or write message to Criterion.rs benchmark:\n{}",
35                 error
36             ),
37         }
38     }
39 }
40 impl std::error::Error for MessageError {
source(&self) -> Option<&(dyn std::error::Error + 'static)>41     fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
42         match self {
43             MessageError::SerializationError(err) => Some(err),
44             MessageError::IoError(err) => Some(err),
45         }
46     }
47 }
48 
// Sizes of the two fixed-length "hello" packets exchanged when a connection is
// set up. The literals 15 and 9 are the byte lengths of the magic strings
// "cargo-criterion" and "Criterion"; they must be kept in sync by hand.
// Use str::len as a const fn once we bump MSRV over 1.39.

/// Magic string expected at the start of the runner-hello.
const RUNNER_MAGIC_NUMBER: &str = "cargo-criterion";
/// Total size in bytes of the runner-hello packet.
const RUNNER_HELLO_SIZE: usize = 15 // == RUNNER_MAGIC_NUMBER.len() // magic number
    + (size_of::<u8>() * 3); // version number (three bytes)

/// Magic string written at the start of the benchmark-hello.
const BENCHMARK_MAGIC_NUMBER: &str = "Criterion";
/// Total size in bytes of the benchmark-hello packet.
const BENCHMARK_HELLO_SIZE: usize = 9 // == BENCHMARK_MAGIC_NUMBER.len() // magic number
    + (size_of::<u8>() * 3) // version number (three bytes)
    + size_of::<u16>() // protocol version
    + size_of::<u16>(); // protocol format
/// Protocol version announced in the benchmark-hello.
const PROTOCOL_VERSION: u16 = 1;
/// Protocol format announced in the benchmark-hello.
const PROTOCOL_FORMAT: u16 = 1;
61 
/// The socket plus the scratch buffers used to frame messages on it.
#[derive(Debug)]
struct InnerConnection {
    /// Stream over which the hello packets and all messages are exchanged.
    socket: TcpStream,
    /// Scratch buffer holding the body of the message currently being received;
    /// kept between calls so its allocation can be reused.
    receive_buffer: Vec<u8>,
    /// Scratch buffer holding the serialized body of the message being sent;
    /// kept between calls so its allocation can be reused.
    send_buffer: Vec<u8>,
    /// The three version bytes that followed the magic number in the
    /// runner-hello.
    runner_version: [u8; 3],
}
69 impl InnerConnection {
new(mut socket: TcpStream) -> Result<Self, std::io::Error>70     pub fn new(mut socket: TcpStream) -> Result<Self, std::io::Error> {
71         // read the runner-hello
72         let mut hello_buf = [0u8; RUNNER_HELLO_SIZE];
73         socket.read_exact(&mut hello_buf)?;
74         if &hello_buf[0..RUNNER_MAGIC_NUMBER.len()] != RUNNER_MAGIC_NUMBER.as_bytes() {
75             panic!("Not connected to cargo-criterion.");
76         }
77         let i = RUNNER_MAGIC_NUMBER.len();
78         let runner_version = [hello_buf[i], hello_buf[i + 1], hello_buf[i + 2]];
79 
80         info!("Runner version: {:?}", runner_version);
81 
82         // now send the benchmark-hello
83         let mut hello_buf = [0u8; BENCHMARK_HELLO_SIZE];
84         hello_buf[0..BENCHMARK_MAGIC_NUMBER.len()]
85             .copy_from_slice(BENCHMARK_MAGIC_NUMBER.as_bytes());
86         let mut i = BENCHMARK_MAGIC_NUMBER.len();
87         hello_buf[i] = env!("CARGO_PKG_VERSION_MAJOR").parse().unwrap();
88         hello_buf[i + 1] = env!("CARGO_PKG_VERSION_MINOR").parse().unwrap();
89         hello_buf[i + 2] = env!("CARGO_PKG_VERSION_PATCH").parse().unwrap();
90         i += 3;
91         hello_buf[i..i + 2].clone_from_slice(&PROTOCOL_VERSION.to_be_bytes());
92         i += 2;
93         hello_buf[i..i + 2].clone_from_slice(&PROTOCOL_FORMAT.to_be_bytes());
94 
95         socket.write_all(&hello_buf)?;
96 
97         Ok(InnerConnection {
98             socket,
99             receive_buffer: vec![],
100             send_buffer: vec![],
101             runner_version,
102         })
103     }
104 
105     #[allow(dead_code)]
recv(&mut self) -> Result<IncomingMessage, MessageError>106     pub fn recv(&mut self) -> Result<IncomingMessage, MessageError> {
107         let mut length_buf = [0u8; 4];
108         self.socket.read_exact(&mut length_buf)?;
109         let length = u32::from_be_bytes(length_buf);
110         self.receive_buffer.resize(length as usize, 0u8);
111         self.socket.read_exact(&mut self.receive_buffer)?;
112         let value = serde_cbor::from_slice(&self.receive_buffer)?;
113         Ok(value)
114     }
115 
send(&mut self, message: &OutgoingMessage) -> Result<(), MessageError>116     pub fn send(&mut self, message: &OutgoingMessage) -> Result<(), MessageError> {
117         self.send_buffer.truncate(0);
118         serde_cbor::to_writer(&mut self.send_buffer, message)?;
119         let size = u32::try_from(self.send_buffer.len()).unwrap();
120         let length_buf = size.to_be_bytes();
121         self.socket.write_all(&length_buf)?;
122         self.socket.write_all(&self.send_buffer)?;
123         Ok(())
124     }
125 }
126 
/// This is really just a holder to allow us to send messages through a shared reference to the
/// connection.
#[derive(Debug)]
pub struct Connection {
    /// The actual connection; the `RefCell` lets `&self` methods obtain the
    /// `&mut` access that sending and receiving require.
    inner: RefCell<InnerConnection>,
}
133 impl Connection {
new(socket: TcpStream) -> Result<Self, std::io::Error>134     pub fn new(socket: TcpStream) -> Result<Self, std::io::Error> {
135         Ok(Connection {
136             inner: RefCell::new(InnerConnection::new(socket)?),
137         })
138     }
139 
140     #[allow(dead_code)]
recv(&self) -> Result<IncomingMessage, MessageError>141     pub fn recv(&self) -> Result<IncomingMessage, MessageError> {
142         self.inner.borrow_mut().recv()
143     }
144 
send(&self, message: &OutgoingMessage) -> Result<(), MessageError>145     pub fn send(&self, message: &OutgoingMessage) -> Result<(), MessageError> {
146         self.inner.borrow_mut().send(message)
147     }
148 
serve_value_formatter( &self, formatter: &dyn crate::measurement::ValueFormatter, ) -> Result<(), MessageError>149     pub fn serve_value_formatter(
150         &self,
151         formatter: &dyn crate::measurement::ValueFormatter,
152     ) -> Result<(), MessageError> {
153         loop {
154             let response = match self.recv()? {
155                 IncomingMessage::FormatValue { value } => OutgoingMessage::FormattedValue {
156                     value: formatter.format_value(value),
157                 },
158                 IncomingMessage::FormatThroughput { value, throughput } => {
159                     OutgoingMessage::FormattedValue {
160                         value: formatter.format_throughput(&throughput, value),
161                     }
162                 }
163                 IncomingMessage::ScaleValues {
164                     typical_value,
165                     mut values,
166                 } => {
167                     let unit = formatter.scale_values(typical_value, &mut values);
168                     OutgoingMessage::ScaledValues {
169                         unit,
170                         scaled_values: values,
171                     }
172                 }
173                 IncomingMessage::ScaleThroughputs {
174                     typical_value,
175                     throughput,
176                     mut values,
177                 } => {
178                     let unit = formatter.scale_throughputs(typical_value, &throughput, &mut values);
179                     OutgoingMessage::ScaledValues {
180                         unit,
181                         scaled_values: values,
182                     }
183                 }
184                 IncomingMessage::ScaleForMachines { mut values } => {
185                     let unit = formatter.scale_for_machines(&mut values);
186                     OutgoingMessage::ScaledValues {
187                         unit,
188                         scaled_values: values,
189                     }
190                 }
191                 IncomingMessage::Continue => break,
192                 _ => panic!(),
193             };
194             self.send(&response)?;
195         }
196         Ok(())
197     }
198 }
199 
/// Enum defining the messages we can receive
#[derive(Debug, Deserialize)]
pub enum IncomingMessage {
    // Value formatter requests
    /// Format a single value; `Connection::serve_value_formatter` answers with
    /// `OutgoingMessage::FormattedValue`.
    FormatValue {
        value: f64,
    },
    /// Format a value as a throughput; answered with
    /// `OutgoingMessage::FormattedValue`.
    FormatThroughput {
        value: f64,
        throughput: Throughput,
    },
    /// Scale `values` in place via the formatter's `scale_values`; answered
    /// with `OutgoingMessage::ScaledValues`.
    ScaleValues {
        typical_value: f64,
        values: Vec<f64>,
    },
    /// Scale `values` in place via the formatter's `scale_throughputs`;
    /// answered with `OutgoingMessage::ScaledValues`.
    ScaleThroughputs {
        typical_value: f64,
        values: Vec<f64>,
        throughput: Throughput,
    },
    /// Scale `values` in place via the formatter's `scale_for_machines`;
    /// answered with `OutgoingMessage::ScaledValues`.
    ScaleForMachines {
        values: Vec<f64>,
    },
    /// Ends the `Connection::serve_value_formatter` loop; no response is sent.
    Continue,

    /// Not handled explicitly anywhere in this file;
    /// `Connection::serve_value_formatter` panics if it receives this.
    /// NOTE(review): presumably a placeholder so the enum is never matched
    /// exhaustively — confirm.
    __Other,
}
227 
/// Enum defining the messages we can send
///
/// Borrowed fields (`&'a str`, `&'a [f64]`) let a message be serialized
/// without first copying the caller's data.
#[derive(Debug, Serialize)]
pub enum OutgoingMessage<'a> {
    // Benchmark progress notifications. None of these are constructed in this
    // file; when each one is sent is decided by the callers.
    BeginningBenchmarkGroup {
        group: &'a str,
    },
    FinishedBenchmarkGroup {
        group: &'a str,
    },
    BeginningBenchmark {
        id: RawBenchmarkId,
    },
    SkippingBenchmark {
        id: RawBenchmarkId,
    },
    Warmup {
        id: RawBenchmarkId,
        nanos: f64,
    },
    MeasurementStart {
        id: RawBenchmarkId,
        sample_count: u64,
        estimate_ns: f64,
        iter_count: u64,
    },
    MeasurementComplete {
        id: RawBenchmarkId,
        iters: &'a [f64],
        times: &'a [f64],
        plot_config: PlotConfiguration,
        sampling_method: SamplingMethod,
        benchmark_config: BenchmarkConfig,
    },
    // value formatter responses
    /// Response to `IncomingMessage::FormatValue` and
    /// `IncomingMessage::FormatThroughput`.
    FormattedValue {
        value: String,
    },
    /// Response to the `IncomingMessage::Scale*` requests: the values after
    /// scaling and the unit string returned by the formatter.
    ScaledValues {
        scaled_values: Vec<f64>,
        unit: &'a str,
    },
}
270 
271 // Also define serializable variants of certain things, either to avoid leaking
272 // serializability into the public interface or because the serialized form
273 // is a bit different from the regular one.
274 
/// Serializable copy of the crate-internal benchmark ID
/// (`crate::report::BenchmarkId`); each field is cloned from the field of the
/// same name on the internal type.
#[derive(Debug, Serialize)]
pub struct RawBenchmarkId {
    group_id: String,
    function_id: Option<String>,
    value_str: Option<String>,
    // Built by collecting an iterator over the internal ID's `throughput`.
    // NOTE(review): presumably holds zero or one element — confirm against the
    // internal type.
    throughput: Vec<Throughput>,
}
282 impl From<&InternalBenchmarkId> for RawBenchmarkId {
from(other: &InternalBenchmarkId) -> RawBenchmarkId283     fn from(other: &InternalBenchmarkId) -> RawBenchmarkId {
284         RawBenchmarkId {
285             group_id: other.group_id.clone(),
286             function_id: other.function_id.clone(),
287             value_str: other.value_str.clone(),
288             throughput: other.throughput.iter().cloned().collect(),
289         }
290     }
291 }
292 
/// Serializable counterpart of `crate::AxisScale`, with one variant per
/// variant of the public enum.
#[derive(Debug, Serialize)]
pub enum AxisScale {
    Linear,
    Logarithmic,
}
298 impl From<crate::AxisScale> for AxisScale {
from(other: crate::AxisScale) -> Self299     fn from(other: crate::AxisScale) -> Self {
300         match other {
301             crate::AxisScale::Linear => AxisScale::Linear,
302             crate::AxisScale::Logarithmic => AxisScale::Logarithmic,
303         }
304     }
305 }
306 
/// Serializable counterpart of `crate::PlotConfiguration`.
#[derive(Debug, Serialize)]
pub struct PlotConfiguration {
    // Converted from the `summary_scale` field of the public configuration.
    summary_scale: AxisScale,
}
311 impl From<&crate::PlotConfiguration> for PlotConfiguration {
from(other: &crate::PlotConfiguration) -> Self312     fn from(other: &crate::PlotConfiguration) -> Self {
313         PlotConfiguration {
314             summary_scale: other.summary_scale.into(),
315         }
316     }
317 }
318 
/// Serializable representation of a `std::time::Duration`.
#[derive(Debug, Serialize)]
struct Duration {
    // Whole seconds, from `std::time::Duration::as_secs`.
    secs: u64,
    // Sub-second part in nanoseconds, from `std::time::Duration::subsec_nanos`.
    nanos: u32,
}
324 impl From<std::time::Duration> for Duration {
from(other: std::time::Duration) -> Self325     fn from(other: std::time::Duration) -> Self {
326         Duration {
327             secs: other.as_secs(),
328             nanos: other.subsec_nanos(),
329         }
330     }
331 }
332 
/// Serializable counterpart of `crate::benchmark::BenchmarkConfig`.
///
/// Every field mirrors the field of the same name on the crate-internal
/// configuration; the two time fields are converted to this module's
/// serializable `Duration`.
#[derive(Debug, Serialize)]
pub struct BenchmarkConfig {
    confidence_level: f64,
    measurement_time: Duration,
    noise_threshold: f64,
    nresamples: usize,
    sample_size: usize,
    significance_level: f64,
    warm_up_time: Duration,
}
343 impl From<&crate::benchmark::BenchmarkConfig> for BenchmarkConfig {
from(other: &crate::benchmark::BenchmarkConfig) -> Self344     fn from(other: &crate::benchmark::BenchmarkConfig) -> Self {
345         BenchmarkConfig {
346             confidence_level: other.confidence_level,
347             measurement_time: other.measurement_time.into(),
348             noise_threshold: other.noise_threshold,
349             nresamples: other.nresamples,
350             sample_size: other.sample_size,
351             significance_level: other.significance_level,
352             warm_up_time: other.warm_up_time.into(),
353         }
354     }
355 }
356 
/// Currently not used; defined for forwards compatibility with cargo-criterion.
///
/// Serializable counterpart of `crate::ActualSamplingMode`.
// NOTE(review): `OutgoingMessage::MeasurementComplete` has a `sampling_method`
// field of this type, so "not used" may be out of date — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum SamplingMethod {
    Linear,
    Flat,
}
363 impl From<crate::ActualSamplingMode> for SamplingMethod {
from(other: crate::ActualSamplingMode) -> Self364     fn from(other: crate::ActualSamplingMode) -> Self {
365         match other {
366             crate::ActualSamplingMode::Flat => SamplingMethod::Flat,
367             crate::ActualSamplingMode::Linear => SamplingMethod::Linear,
368         }
369     }
370 }
371