1 // Copyright 2016 Mozilla Foundation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 #[cfg(feature = "dist-client")]
15 pub use self::client::Client;
16 #[cfg(feature = "dist-server")]
17 pub use self::server::Server;
18 #[cfg(feature = "dist-server")]
19 pub use self::server::{
20     ClientAuthCheck, ClientVisibleMsg, Scheduler, ServerAuthCheck, HEARTBEAT_TIMEOUT,
21 };
22 
23 mod common {
24     #[cfg(feature = "dist-client")]
25     use futures::{Future, Stream};
26     use hyperx::header;
27     #[cfg(feature = "dist-server")]
28     use std::collections::HashMap;
29     use std::fmt;
30 
31     use crate::dist;
32 
33     use crate::errors::*;
34     use crate::util::RequestExt;
35 
36     // Note that content-length is necessary due to https://github.com/tiny-http/tiny-http/issues/147
37     pub trait ReqwestRequestBuilderExt: Sized {
bincode<T: serde::Serialize + ?Sized>(self, bincode: &T) -> Result<Self>38         fn bincode<T: serde::Serialize + ?Sized>(self, bincode: &T) -> Result<Self>;
bytes(self, bytes: Vec<u8>) -> Self39         fn bytes(self, bytes: Vec<u8>) -> Self;
bearer_auth(self, token: String) -> Self40         fn bearer_auth(self, token: String) -> Self;
41     }
42     impl ReqwestRequestBuilderExt for reqwest::RequestBuilder {
bincode<T: serde::Serialize + ?Sized>(self, bincode: &T) -> Result<Self>43         fn bincode<T: serde::Serialize + ?Sized>(self, bincode: &T) -> Result<Self> {
44             let bytes =
45                 bincode::serialize(bincode).context("Failed to serialize body to bincode")?;
46             Ok(self.bytes(bytes))
47         }
bytes(self, bytes: Vec<u8>) -> Self48         fn bytes(self, bytes: Vec<u8>) -> Self {
49             self.set_header(header::ContentType::octet_stream())
50                 .set_header(header::ContentLength(bytes.len() as u64))
51                 .body(bytes)
52         }
bearer_auth(self, token: String) -> Self53         fn bearer_auth(self, token: String) -> Self {
54             self.set_header(header::Authorization(header::Bearer { token }))
55         }
56     }
57     impl ReqwestRequestBuilderExt for reqwest::r#async::RequestBuilder {
bincode<T: serde::Serialize + ?Sized>(self, bincode: &T) -> Result<Self>58         fn bincode<T: serde::Serialize + ?Sized>(self, bincode: &T) -> Result<Self> {
59             let bytes =
60                 bincode::serialize(bincode).context("Failed to serialize body to bincode")?;
61             Ok(self.bytes(bytes))
62         }
bytes(self, bytes: Vec<u8>) -> Self63         fn bytes(self, bytes: Vec<u8>) -> Self {
64             self.set_header(header::ContentType::octet_stream())
65                 .set_header(header::ContentLength(bytes.len() as u64))
66                 .body(bytes)
67         }
bearer_auth(self, token: String) -> Self68         fn bearer_auth(self, token: String) -> Self {
69             self.set_header(header::Authorization(header::Bearer { token }))
70         }
71     }
72 
bincode_req<T: serde::de::DeserializeOwned + 'static>( req: reqwest::RequestBuilder, ) -> Result<T>73     pub fn bincode_req<T: serde::de::DeserializeOwned + 'static>(
74         req: reqwest::RequestBuilder,
75     ) -> Result<T> {
76         // Work around tiny_http issue #151 by disabling HTTP pipeline with
77         // `Connection: close`.
78         let mut res = req.set_header(header::Connection::close()).send()?;
79         let status = res.status();
80         let mut body = vec![];
81         res.copy_to(&mut body)
82             .context("error reading response body")?;
83         if !status.is_success() {
84             Err(anyhow!(
85                 "Error {} (Headers={:?}): {}",
86                 status.as_u16(),
87                 res.headers(),
88                 String::from_utf8_lossy(&body)
89             ))
90         } else {
91             bincode::deserialize(&body).map_err(Into::into)
92         }
93     }
94     #[cfg(feature = "dist-client")]
bincode_req_fut<T: serde::de::DeserializeOwned + 'static>( req: reqwest::r#async::RequestBuilder, ) -> SFuture<T>95     pub fn bincode_req_fut<T: serde::de::DeserializeOwned + 'static>(
96         req: reqwest::r#async::RequestBuilder,
97     ) -> SFuture<T> {
98         Box::new(
99             // Work around tiny_http issue #151 by disabling HTTP pipeline with
100             // `Connection: close`.
101             req.set_header(header::Connection::close())
102                 .send()
103                 .map_err(Into::into)
104                 .and_then(|res| {
105                     let status = res.status();
106                     res.into_body()
107                         .concat2()
108                         .map(move |b| (status, b))
109                         .map_err(Into::into)
110                 })
111                 .and_then(|(status, body)| {
112                     if !status.is_success() {
113                         let errmsg = format!(
114                             "Error {}: {}",
115                             status.as_u16(),
116                             String::from_utf8_lossy(&body)
117                         );
118                         if status.is_client_error() {
119                             return f_err(HttpClientError(errmsg));
120                         } else {
121                             return f_err(anyhow!(errmsg));
122                         }
123                     }
124                     match bincode::deserialize(&body) {
125                         Ok(r) => f_ok(r),
126                         Err(e) => f_err(e),
127                     }
128                 }),
129         )
130     }
131 
132     #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
133     #[serde(deny_unknown_fields)]
134     pub struct JobJwt {
135         pub job_id: dist::JobId,
136     }
137 
138     #[derive(Clone, Debug, Serialize, Deserialize)]
139     #[serde(deny_unknown_fields)]
140     pub enum AllocJobHttpResponse {
141         Success {
142             job_alloc: dist::JobAlloc,
143             need_toolchain: bool,
144             cert_digest: Vec<u8>,
145         },
146         Fail {
147             msg: String,
148         },
149     }
150     impl AllocJobHttpResponse {
151         #[cfg(feature = "dist-server")]
from_alloc_job_result( res: dist::AllocJobResult, certs: &HashMap<dist::ServerId, (Vec<u8>, Vec<u8>)>, ) -> Self152         pub fn from_alloc_job_result(
153             res: dist::AllocJobResult,
154             certs: &HashMap<dist::ServerId, (Vec<u8>, Vec<u8>)>,
155         ) -> Self {
156             match res {
157                 dist::AllocJobResult::Success {
158                     job_alloc,
159                     need_toolchain,
160                 } => {
161                     if let Some((digest, _)) = certs.get(&job_alloc.server_id) {
162                         AllocJobHttpResponse::Success {
163                             job_alloc,
164                             need_toolchain,
165                             cert_digest: digest.to_owned(),
166                         }
167                     } else {
168                         AllocJobHttpResponse::Fail {
169                             msg: format!(
170                                 "missing certificates for server {}",
171                                 job_alloc.server_id.addr()
172                             ),
173                         }
174                     }
175                 }
176                 dist::AllocJobResult::Fail { msg } => AllocJobHttpResponse::Fail { msg },
177             }
178         }
179     }
180 
181     #[derive(Clone, Debug, Serialize, Deserialize)]
182     #[serde(deny_unknown_fields)]
183     pub struct ServerCertificateHttpResponse {
184         pub cert_digest: Vec<u8>,
185         pub cert_pem: Vec<u8>,
186     }
187 
188     #[derive(Clone, Serialize, Deserialize)]
189     #[serde(deny_unknown_fields)]
190     pub struct HeartbeatServerHttpRequest {
191         pub jwt_key: Vec<u8>,
192         pub num_cpus: usize,
193         pub server_nonce: dist::ServerNonce,
194         pub cert_digest: Vec<u8>,
195         pub cert_pem: Vec<u8>,
196     }
197     // cert_pem is quite long so elide it (you can retrieve it by hitting the server url anyway)
198     impl fmt::Debug for HeartbeatServerHttpRequest {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result199         fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200             let HeartbeatServerHttpRequest {
201                 jwt_key,
202                 num_cpus,
203                 server_nonce,
204                 cert_digest,
205                 cert_pem,
206             } = self;
207             write!(f, "HeartbeatServerHttpRequest {{ jwt_key: {:?}, num_cpus: {:?}, server_nonce: {:?}, cert_digest: {:?}, cert_pem: [...{} bytes...] }}", jwt_key, num_cpus, server_nonce, cert_digest, cert_pem.len())
208         }
209     }
210     #[derive(Clone, Debug, Serialize, Deserialize)]
211     #[serde(deny_unknown_fields)]
212     pub struct RunJobHttpRequest {
213         pub command: dist::CompileCommand,
214         pub outputs: Vec<String>,
215     }
216 }
217 
218 pub mod urls {
219     use crate::dist::{JobId, ServerId};
220 
scheduler_alloc_job(scheduler_url: &reqwest::Url) -> reqwest::Url221     pub fn scheduler_alloc_job(scheduler_url: &reqwest::Url) -> reqwest::Url {
222         scheduler_url
223             .join("/api/v1/scheduler/alloc_job")
224             .expect("failed to create alloc job url")
225     }
scheduler_server_certificate( scheduler_url: &reqwest::Url, server_id: ServerId, ) -> reqwest::Url226     pub fn scheduler_server_certificate(
227         scheduler_url: &reqwest::Url,
228         server_id: ServerId,
229     ) -> reqwest::Url {
230         scheduler_url
231             .join(&format!(
232                 "/api/v1/scheduler/server_certificate/{}",
233                 server_id.addr()
234             ))
235             .expect("failed to create server certificate url")
236     }
scheduler_heartbeat_server(scheduler_url: &reqwest::Url) -> reqwest::Url237     pub fn scheduler_heartbeat_server(scheduler_url: &reqwest::Url) -> reqwest::Url {
238         scheduler_url
239             .join("/api/v1/scheduler/heartbeat_server")
240             .expect("failed to create heartbeat url")
241     }
scheduler_job_state(scheduler_url: &reqwest::Url, job_id: JobId) -> reqwest::Url242     pub fn scheduler_job_state(scheduler_url: &reqwest::Url, job_id: JobId) -> reqwest::Url {
243         scheduler_url
244             .join(&format!("/api/v1/scheduler/job_state/{}", job_id))
245             .expect("failed to create job state url")
246     }
scheduler_status(scheduler_url: &reqwest::Url) -> reqwest::Url247     pub fn scheduler_status(scheduler_url: &reqwest::Url) -> reqwest::Url {
248         scheduler_url
249             .join("/api/v1/scheduler/status")
250             .expect("failed to create alloc job url")
251     }
252 
server_assign_job(server_id: ServerId, job_id: JobId) -> reqwest::Url253     pub fn server_assign_job(server_id: ServerId, job_id: JobId) -> reqwest::Url {
254         let url = format!(
255             "https://{}/api/v1/distserver/assign_job/{}",
256             server_id.addr(),
257             job_id
258         );
259         reqwest::Url::parse(&url).expect("failed to create assign job url")
260     }
server_submit_toolchain(server_id: ServerId, job_id: JobId) -> reqwest::Url261     pub fn server_submit_toolchain(server_id: ServerId, job_id: JobId) -> reqwest::Url {
262         let url = format!(
263             "https://{}/api/v1/distserver/submit_toolchain/{}",
264             server_id.addr(),
265             job_id
266         );
267         reqwest::Url::parse(&url).expect("failed to create submit toolchain url")
268     }
server_run_job(server_id: ServerId, job_id: JobId) -> reqwest::Url269     pub fn server_run_job(server_id: ServerId, job_id: JobId) -> reqwest::Url {
270         let url = format!(
271             "https://{}/api/v1/distserver/run_job/{}",
272             server_id.addr(),
273             job_id
274         );
275         reqwest::Url::parse(&url).expect("failed to create run job url")
276     }
277 }
278 
279 #[cfg(feature = "dist-server")]
280 mod server {
281     use crate::jwt;
282     use byteorder::{BigEndian, ReadBytesExt};
283     use flate2::read::ZlibDecoder as ZlibReadDecoder;
284     use rand::{rngs::OsRng, RngCore};
285     use rouille::accept;
286     use std::collections::HashMap;
287     use std::io::Read;
288     use std::net::SocketAddr;
289     use std::result::Result as StdResult;
290     use std::sync::atomic;
291     use std::sync::Mutex;
292     use std::thread;
293     use std::time::Duration;
294     use void::Void;
295 
296     use super::common::{
297         bincode_req, AllocJobHttpResponse, HeartbeatServerHttpRequest, JobJwt,
298         ReqwestRequestBuilderExt, RunJobHttpRequest, ServerCertificateHttpResponse,
299     };
300     use super::urls;
301     use crate::dist::{
302         self, AllocJobResult, AssignJobResult, HeartbeatServerResult, InputsReader, JobAuthorizer,
303         JobId, JobState, RunJobResult, SchedulerStatusResult, ServerId, ServerNonce,
304         SubmitToolchainResult, Toolchain, ToolchainReader, UpdateJobStateResult,
305     };
306     use crate::errors::*;
307 
308     const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
309     const HEARTBEAT_ERROR_INTERVAL: Duration = Duration::from_secs(10);
310     pub const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(90);
311 
create_https_cert_and_privkey(addr: SocketAddr) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)>312     fn create_https_cert_and_privkey(addr: SocketAddr) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
313         let rsa_key = openssl::rsa::Rsa::<openssl::pkey::Private>::generate(2048)
314             .context("failed to generate rsa privkey")?;
315         let privkey_pem = rsa_key
316             .private_key_to_pem()
317             .context("failed to create pem from rsa privkey")?;
318         let privkey: openssl::pkey::PKey<openssl::pkey::Private> =
319             openssl::pkey::PKey::from_rsa(rsa_key)
320                 .context("failed to create openssl pkey from rsa privkey")?;
321         let mut builder =
322             openssl::x509::X509::builder().context("failed to create x509 builder")?;
323 
324         // Populate the certificate with the necessary parts, mostly from mkcert in openssl
325         builder
326             .set_version(2)
327             .context("failed to set x509 version")?;
328         let serial_number = openssl::bn::BigNum::from_u32(0)
329             .and_then(|bn| bn.to_asn1_integer())
330             .context("failed to create openssl asn1 0")?;
331         builder
332             .set_serial_number(serial_number.as_ref())
333             .context("failed to set x509 serial number")?;
334         let not_before = openssl::asn1::Asn1Time::days_from_now(0)
335             .context("failed to create openssl not before asn1")?;
336         builder
337             .set_not_before(not_before.as_ref())
338             .context("failed to set not before on x509")?;
339         let not_after = openssl::asn1::Asn1Time::days_from_now(365)
340             .context("failed to create openssl not after asn1")?;
341         builder
342             .set_not_after(not_after.as_ref())
343             .context("failed to set not after on x509")?;
344         builder
345             .set_pubkey(privkey.as_ref())
346             .context("failed to set pubkey for x509")?;
347 
348         let mut name = openssl::x509::X509Name::builder()?;
349         name.append_entry_by_nid(openssl::nid::Nid::COMMONNAME, &addr.to_string())?;
350         let name = name.build();
351 
352         builder
353             .set_subject_name(&name)
354             .context("failed to set subject name")?;
355         builder
356             .set_issuer_name(&name)
357             .context("failed to set issuer name")?;
358 
359         // Add the SubjectAlternativeName
360         let extension = openssl::x509::extension::SubjectAlternativeName::new()
361             .ip(&addr.ip().to_string())
362             .build(&builder.x509v3_context(None, None))
363             .context("failed to build SAN extension for x509")?;
364         builder
365             .append_extension(extension)
366             .context("failed to append SAN extension for x509")?;
367 
368         // Add ExtendedKeyUsage
369         let ext_key_usage = openssl::x509::extension::ExtendedKeyUsage::new()
370             .server_auth()
371             .build()
372             .context("failed to build EKU extension for x509")?;
373         builder
374             .append_extension(ext_key_usage)
375             .context("failes to append EKU extension for x509")?;
376 
377         // Finish the certificate
378         builder
379             .sign(&privkey, openssl::hash::MessageDigest::sha1())
380             .context("failed to sign x509 with sha1")?;
381         let cert: openssl::x509::X509 = builder.build();
382         let cert_pem = cert.to_pem().context("failed to create pem from x509")?;
383         let cert_digest = cert
384             .digest(openssl::hash::MessageDigest::sha256())
385             .context("failed to create digest of x509 certificate")?
386             .as_ref()
387             .to_owned();
388 
389         Ok((cert_digest, cert_pem, privkey_pem))
390     }
391 
392     // Messages that are non-sensitive and can be sent to the client
393     #[derive(Debug)]
394     pub struct ClientVisibleMsg(String);
395     impl ClientVisibleMsg {
from_nonsensitive(s: String) -> Self396         pub fn from_nonsensitive(s: String) -> Self {
397             ClientVisibleMsg(s)
398         }
399     }
400 
401     pub trait ClientAuthCheck: Send + Sync {
check(&self, token: &str) -> StdResult<(), ClientVisibleMsg>402         fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg>;
403     }
404     pub type ServerAuthCheck = Box<dyn Fn(&str) -> Option<ServerId> + Send + Sync>;
405 
    // Length in bytes of a JWT signing key: 256 bits.
    const JWT_KEY_LENGTH: usize = 256 / 8;
    lazy_static! {
        // All job JWTs are signed with HMAC-SHA256.
        static ref JWT_HEADER: jwt::Header = jwt::Header::new(jwt::Algorithm::HS256);
        // Validation settings for job JWTs. Only HS256 is accepted, and no
        // time-based (exp/nbf) or aud/iss/sub checks are performed, so a token
        // stays valid for as long as its signing key does. Binding a token to a
        // particular job is done by comparing the decoded claims, not here.
        static ref JWT_VALIDATION: jwt::Validation = jwt::Validation {
            leeway: 0,
            validate_exp: false,
            validate_nbf: false,
            aud: None,
            iss: None,
            sub: None,
            algorithms: vec![jwt::Algorithm::HS256],
        };
    }
419 
420     // Based on rouille::input::json::json_input
421     #[derive(Debug)]
422     pub enum RouilleBincodeError {
423         BodyAlreadyExtracted,
424         WrongContentType,
425         ParseError(bincode::Error),
426     }
427     impl From<bincode::Error> for RouilleBincodeError {
from(err: bincode::Error) -> RouilleBincodeError428         fn from(err: bincode::Error) -> RouilleBincodeError {
429             RouilleBincodeError::ParseError(err)
430         }
431     }
432     impl std::error::Error for RouilleBincodeError {
source(&self) -> Option<&(dyn std::error::Error + 'static)>433         fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
434             match *self {
435                 RouilleBincodeError::ParseError(ref e) => Some(e),
436                 _ => None,
437             }
438         }
439     }
440     impl std::fmt::Display for RouilleBincodeError {
fmt( &self, fmt: &mut std::fmt::Formatter<'_>, ) -> std::result::Result<(), std::fmt::Error>441         fn fmt(
442             &self,
443             fmt: &mut std::fmt::Formatter<'_>,
444         ) -> std::result::Result<(), std::fmt::Error> {
445             write!(
446                 fmt,
447                 "{}",
448                 match *self {
449                     RouilleBincodeError::BodyAlreadyExtracted => {
450                         "the body of the request was already extracted"
451                     }
452                     RouilleBincodeError::WrongContentType => {
453                         "the request didn't have a binary content type"
454                     }
455                     RouilleBincodeError::ParseError(_) => "error while parsing the bincode body",
456                 }
457             )
458         }
459     }
bincode_input<O>(request: &rouille::Request) -> std::result::Result<O, RouilleBincodeError> where O: serde::de::DeserializeOwned,460     fn bincode_input<O>(request: &rouille::Request) -> std::result::Result<O, RouilleBincodeError>
461     where
462         O: serde::de::DeserializeOwned,
463     {
464         if let Some(header) = request.header("Content-Type") {
465             if !header.starts_with("application/octet-stream") {
466                 return Err(RouilleBincodeError::WrongContentType);
467             }
468         } else {
469             return Err(RouilleBincodeError::WrongContentType);
470         }
471 
472         if let Some(mut b) = request.data() {
473             bincode::deserialize_from::<_, O>(&mut b).map_err(From::from)
474         } else {
475             Err(RouilleBincodeError::BodyAlreadyExtracted)
476         }
477     }
478 
479     // Based on try_or_400 in rouille, but with logging
480     #[derive(Serialize)]
481     pub struct ErrJson {
482         description: String,
483         cause: Option<Box<ErrJson>>,
484     }
485 
486     impl ErrJson {
from_err<E: ?Sized + std::error::Error>(err: &E) -> ErrJson487         fn from_err<E: ?Sized + std::error::Error>(err: &E) -> ErrJson {
488             let cause = err.source().map(ErrJson::from_err).map(Box::new);
489             ErrJson {
490                 description: err.to_string(),
491                 cause,
492             }
493         }
494 
into_data(self) -> String495         fn into_data(self) -> String {
496             serde_json::to_string(&self).expect("infallible serialization for ErrJson failed")
497         }
498     }
    // Evaluate `$result`, yielding the `Ok` value. On `Err`, log the error and
    // its full `source()` chain against request id `$reqid`, then return early
    // from the enclosing function with a JSON `ErrJson` body and HTTP status
    // `$code`.
    //
    // NOTE: this expands to a `return` of a `rouille::Response`, so it can only
    // be used inside functions that return `rouille::Response`.
    macro_rules! try_or_err_and_log {
        ($reqid:expr, $code:expr, $result:expr) => {
            match $result {
                Ok(r) => r,
                Err(err) => {
                    // TODO: would ideally just use error_chain
                    // Bring `Error` into scope so `.source()` resolves below; the
                    // allow silences the lint in expansions where it is unneeded.
                    #[allow(unused_imports)]
                    use std::error::Error;
                    // Flatten the error and its causes into a single log line.
                    let mut err_msg = err.to_string();
                    let mut maybe_cause = err.source();
                    while let Some(cause) = maybe_cause {
                        err_msg.push_str(", caused by: ");
                        err_msg.push_str(&cause.to_string());
                        maybe_cause = cause.source();
                    }

                    warn!("Res {} error: {}", $reqid, err_msg);
                    // Box as a std error so `ErrJson::from_err` can walk the sources.
                    let err: Box<dyn std::error::Error + 'static> = err.into();
                    let json = ErrJson::from_err(&*err);
                    return rouille::Response::json(&json).with_status_code($code);
                }
            }
        };
    }
    // `try_or_err_and_log!` responding with 400 Bad Request on error.
    macro_rules! try_or_400_log {
        ($reqid:expr, $result:expr) => {
            try_or_err_and_log!($reqid, 400, $result)
        };
    }
    // `try_or_err_and_log!` responding with 500 Internal Server Error on error.
    macro_rules! try_or_500_log {
        ($reqid:expr, $result:expr) => {
            try_or_err_and_log!($reqid, 500, $result)
        };
    }
make_401_with_body(short_err: &str, body: ClientVisibleMsg) -> rouille::Response533     fn make_401_with_body(short_err: &str, body: ClientVisibleMsg) -> rouille::Response {
534         rouille::Response {
535             status_code: 401,
536             headers: vec![(
537                 "WWW-Authenticate".into(),
538                 format!("Bearer error=\"{}\"", short_err).into(),
539             )],
540             data: rouille::ResponseBody::from_data(body.0),
541             upgrade: None,
542         }
543     }
make_401(short_err: &str) -> rouille::Response544     fn make_401(short_err: &str) -> rouille::Response {
545         make_401_with_body(short_err, ClientVisibleMsg(String::new()))
546     }
bearer_http_auth(request: &rouille::Request) -> Option<&str>547     fn bearer_http_auth(request: &rouille::Request) -> Option<&str> {
548         let header = request.header("Authorization")?;
549 
550         let mut split = header.splitn(2, |c| c == ' ');
551 
552         let authtype = split.next()?;
553         if authtype != "Bearer" {
554             return None;
555         }
556 
557         split.next()
558     }
559 
560     /// Return `content` as a bincode-encoded `Response`.
bincode_response<T>(content: &T) -> rouille::Response where T: serde::Serialize,561     pub fn bincode_response<T>(content: &T) -> rouille::Response
562     where
563         T: serde::Serialize,
564     {
565         let data = bincode::serialize(content).context("Failed to serialize response body");
566         let data = try_or_500_log!("bincode body serialization", data);
567 
568         rouille::Response {
569             status_code: 200,
570             headers: vec![("Content-Type".into(), "application/octet-stream".into())],
571             data: rouille::ResponseBody::from_data(data),
572             upgrade: None,
573         }
574     }
575 
    /// Return `content` as either a bincode or json encoded `Response`
    /// depending on the Accept header in `request`.
    ///
    /// Bincode is produced via `bincode_response`, JSON via
    /// `rouille::Response::json`.
    // NOTE(review): the branch order given to `accept!` may decide which
    // encoding wins when the client accepts both; keep it as-is unless that
    // has been checked against rouille's `accept!` semantics.
    pub fn prepare_response<T>(request: &rouille::Request, content: &T) -> rouille::Response
    where
        T: serde::Serialize,
    {
        accept!(request,
        "application/octet-stream" => bincode_response(content),
        "application/json" => rouille::Response::json(content),
        )
    }
587 
    // Verification of job auth in a request
    //
    // Checks the request's `Authorization: Bearer <token>` token with
    // `$job_authorizer.verify_token` for job `$job_id`. On failure this returns
    // early from the enclosing handler with a 401 (`invalid_jwt`) whose body is
    // the JSON-encoded error, so it may only be used in functions returning
    // `rouille::Response`.
    macro_rules! job_auth_or_401 {
        ($request:ident, $job_authorizer:expr, $job_id:expr) => {{
            let verify_result = match bearer_http_auth($request) {
                Some(token) => $job_authorizer.verify_token($job_id, token),
                // A missing header or a non-Bearer scheme counts as a failed verification.
                None => Err(anyhow!("no Authorization header")),
            };
            match verify_result {
                Ok(()) => (),
                Err(err) => {
                    // Box as a std error so `ErrJson::from_err` can walk the sources.
                    let err: Box<dyn std::error::Error> = err.into();
                    let json = ErrJson::from_err(&*err);
                    return make_401_with_body("invalid_jwt", ClientVisibleMsg(json.into_data()));
                }
            }
        }};
    }
605     // Generation and verification of job auth
606     struct JWTJobAuthorizer {
607         server_key: Vec<u8>,
608     }
609     impl JWTJobAuthorizer {
new(server_key: Vec<u8>) -> Box<Self>610         fn new(server_key: Vec<u8>) -> Box<Self> {
611             Box::new(Self { server_key })
612         }
613     }
614     impl dist::JobAuthorizer for JWTJobAuthorizer {
generate_token(&self, job_id: JobId) -> Result<String>615         fn generate_token(&self, job_id: JobId) -> Result<String> {
616             let claims = JobJwt { job_id };
617             let key = jwt::EncodingKey::from_secret(&self.server_key);
618             jwt::encode(&JWT_HEADER, &claims, &key)
619                 .map_err(|e| anyhow!("Failed to create JWT for job: {}", e))
620         }
verify_token(&self, job_id: JobId, token: &str) -> Result<()>621         fn verify_token(&self, job_id: JobId, token: &str) -> Result<()> {
622             let valid_claims = JobJwt { job_id };
623             let key = jwt::DecodingKey::from_secret(&self.server_key);
624             jwt::decode(&token, &key, &JWT_VALIDATION)
625                 .map_err(|e| anyhow!("JWT decode failed: {}", e))
626                 .and_then(|res| {
627                     fn identical_t<T>(_: &T, _: &T) {}
628                     identical_t(&res.claims, &valid_claims);
629                     if res.claims == valid_claims {
630                         Ok(())
631                     } else {
632                         Err(anyhow!("mismatched claims"))
633                     }
634                 })
635         }
636     }
637 
638     #[test]
test_job_token_verification()639     fn test_job_token_verification() {
640         let ja = JWTJobAuthorizer::new(vec![1, 2, 2]);
641 
642         let job_id = JobId(55);
643         let token = ja.generate_token(job_id).unwrap();
644 
645         let job_id2 = JobId(56);
646         let token2 = ja.generate_token(job_id2).unwrap();
647 
648         let ja2 = JWTJobAuthorizer::new(vec![1, 2, 3]);
649 
650         // Check tokens are deterministic
651         assert_eq!(token, ja.generate_token(job_id).unwrap());
652         // Check token verification works
653         assert!(ja.verify_token(job_id, &token).is_ok());
654         assert!(ja.verify_token(job_id, &token2).is_err());
655         assert!(ja.verify_token(job_id2, &token).is_err());
656         assert!(ja.verify_token(job_id2, &token2).is_ok());
657         // Check token verification with a different key fails
658         assert!(ja2.verify_token(job_id, &token).is_err());
659         assert!(ja2.verify_token(job_id2, &token2).is_err());
660     }
661 
662     pub struct Scheduler<S> {
663         public_addr: SocketAddr,
664         handler: S,
665         // Is this client permitted to use the scheduler?
666         check_client_auth: Box<dyn ClientAuthCheck>,
667         // Do we believe the server is who they appear to be?
668         check_server_auth: ServerAuthCheck,
669     }
670 
671     impl<S: dist::SchedulerIncoming + 'static> Scheduler<S> {
new( public_addr: SocketAddr, handler: S, check_client_auth: Box<dyn ClientAuthCheck>, check_server_auth: ServerAuthCheck, ) -> Self672         pub fn new(
673             public_addr: SocketAddr,
674             handler: S,
675             check_client_auth: Box<dyn ClientAuthCheck>,
676             check_server_auth: ServerAuthCheck,
677         ) -> Self {
678             Self {
679                 public_addr,
680                 handler,
681                 check_client_auth,
682                 check_server_auth,
683             }
684         }
685 
start(self) -> Result<Void>686         pub fn start(self) -> Result<Void> {
687             let Self {
688                 public_addr,
689                 handler,
690                 check_client_auth,
691                 check_server_auth,
692             } = self;
693             let requester = SchedulerRequester {
694                 client: Mutex::new(reqwest::Client::new()),
695             };
696 
697             macro_rules! check_server_auth_or_err {
698                 ($request:ident) => {{
699                     match bearer_http_auth($request).and_then(&*check_server_auth) {
700                         Some(server_id) => {
701                             let origin_ip = if let Some(header_val) = $request.header("X-Real-IP") {
702                                 trace!("X-Real-IP: {:?}", header_val);
703                                 match header_val.parse() {
704                                     Ok(ip) => ip,
705                                     Err(err) => {
706                                         warn!(
707                                             "X-Real-IP value {:?} could not be parsed: {:?}",
708                                             header_val, err
709                                         );
710                                         return rouille::Response::empty_400();
711                                     }
712                                 }
713                             } else {
714                                 $request.remote_addr().ip()
715                             };
716                             if server_id.addr().ip() != origin_ip {
717                                 trace!("server ip: {:?}", server_id.addr().ip());
718                                 trace!("request ip: {:?}", $request.remote_addr().ip());
719                                 return make_401("invalid_bearer_token_mismatched_address");
720                             } else {
721                                 server_id
722                             }
723                         }
724                         None => return make_401("invalid_bearer_token"),
725                     }
726                 }};
727             }
728 
729             fn maybe_update_certs(
730                 client: &mut reqwest::Client,
731                 certs: &mut HashMap<ServerId, (Vec<u8>, Vec<u8>)>,
732                 server_id: ServerId,
733                 cert_digest: Vec<u8>,
734                 cert_pem: Vec<u8>,
735             ) -> Result<()> {
736                 if let Some((saved_cert_digest, _)) = certs.get(&server_id) {
737                     if saved_cert_digest == &cert_digest {
738                         return Ok(());
739                     }
740                 }
741                 info!(
742                     "Adding new certificate for {} to scheduler",
743                     server_id.addr()
744                 );
745                 let mut client_builder = reqwest::ClientBuilder::new();
746                 // Add all the certificates we know about
747                 client_builder = client_builder.add_root_certificate(
748                     reqwest::Certificate::from_pem(&cert_pem)
749                         .context("failed to interpret pem as certificate")?,
750                 );
751                 for (_, cert_pem) in certs.values() {
752                     client_builder = client_builder.add_root_certificate(
753                         reqwest::Certificate::from_pem(cert_pem).expect("previously valid cert"),
754                     );
755                 }
756                 // Finish the clients
757                 let new_client = client_builder
758                     .build()
759                     .context("failed to create a HTTP client")?;
760                 // Use the updated certificates
761                 *client = new_client;
762                 certs.insert(server_id, (cert_digest, cert_pem));
763                 Ok(())
764             }
765 
766             info!("Scheduler listening for clients on {}", public_addr);
767             let request_count = atomic::AtomicUsize::new(0);
768             // From server_id -> cert_digest, cert_pem
769             let server_certificates: Mutex<HashMap<ServerId, (Vec<u8>, Vec<u8>)>> =
770                 Default::default();
771 
772             let server = rouille::Server::new(public_addr, move |request| {
773                 let req_id = request_count.fetch_add(1, atomic::Ordering::SeqCst);
774                 trace!("Req {} ({}): {:?}", req_id, request.remote_addr(), request);
775                 let response = (|| router!(request,
776                     (POST) (/api/v1/scheduler/alloc_job) => {
777                         let bearer_auth = match bearer_http_auth(request) {
778                             Some(s) => s,
779                             None => return make_401("no_bearer_auth"),
780                         };
781                         match check_client_auth.check(bearer_auth) {
782                             Ok(()) => (),
783                             Err(client_msg) => {
784                                 warn!("Bearer auth failed: {:?}", client_msg);
785                                 return make_401_with_body("bearer_auth_failed", client_msg)
786                             },
787                         }
788                         let toolchain = try_or_400_log!(req_id, bincode_input(request));
789                         trace!("Req {}: alloc_job: {:?}", req_id, toolchain);
790 
791                         let alloc_job_res: AllocJobResult = try_or_500_log!(req_id, handler.handle_alloc_job(&requester, toolchain));
792                         let certs = server_certificates.lock().unwrap();
793                         let res = AllocJobHttpResponse::from_alloc_job_result(alloc_job_res, &certs);
794                         prepare_response(&request, &res)
795                     },
796                     (GET) (/api/v1/scheduler/server_certificate/{server_id: ServerId}) => {
797                         let certs = server_certificates.lock().unwrap();
798                         let (cert_digest, cert_pem) = try_or_500_log!(req_id, certs.get(&server_id)
799                             .context("server cert not available"));
800                         let res = ServerCertificateHttpResponse {
801                             cert_digest: cert_digest.clone(),
802                             cert_pem: cert_pem.clone(),
803                         };
804                         prepare_response(&request, &res)
805                     },
806                     (POST) (/api/v1/scheduler/heartbeat_server) => {
807                         let server_id = check_server_auth_or_err!(request);
808                         let heartbeat_server = try_or_400_log!(req_id, bincode_input(request));
809                         trace!("Req {}: heartbeat_server: {:?}", req_id, heartbeat_server);
810 
811                         let HeartbeatServerHttpRequest { num_cpus, jwt_key, server_nonce, cert_digest, cert_pem } = heartbeat_server;
812                         try_or_500_log!(req_id, maybe_update_certs(
813                             &mut requester.client.lock().unwrap(),
814                             &mut server_certificates.lock().unwrap(),
815                             server_id, cert_digest, cert_pem
816                         ));
817                         let job_authorizer = JWTJobAuthorizer::new(jwt_key);
818                         let res: HeartbeatServerResult = try_or_500_log!(req_id, handler.handle_heartbeat_server(
819                             server_id, server_nonce,
820                             num_cpus,
821                             job_authorizer
822                         ));
823                         prepare_response(&request, &res)
824                     },
825                     (POST) (/api/v1/scheduler/job_state/{job_id: JobId}) => {
826                         let server_id = check_server_auth_or_err!(request);
827                         let job_state = try_or_400_log!(req_id, bincode_input(request));
828                         trace!("Req {}: job state: {:?}", req_id, job_state);
829 
830                         let res: UpdateJobStateResult = try_or_500_log!(req_id, handler.handle_update_job_state(
831                             job_id, server_id, job_state
832                         ));
833                         prepare_response(&request, &res)
834                     },
835                     (GET) (/api/v1/scheduler/status) => {
836                         let res: SchedulerStatusResult = try_or_500_log!(req_id, handler.handle_status());
837                         prepare_response(&request, &res)
838                     },
839                     _ => {
840                         warn!("Unknown request {:?}", request);
841                         rouille::Response::empty_404()
842                     },
843                 ))();
844                 trace!("Res {}: {:?}", req_id, response);
845                 response
846             }).map_err(|e| anyhow!(format!("Failed to start http server for sccache scheduler: {}", e)))?;
847 
848             // This limit is rouille's default for `start_server_with_pool`, which
849             // we would use, except that interface doesn't permit any sort of
850             // error handling to be done.
851             let server = server.pool_size(num_cpus::get() * 8);
852             server.run();
853 
854             panic!("Rouille server terminated")
855         }
856     }
857 
    /// Outgoing requests from the scheduler to build servers.
    struct SchedulerRequester {
        // Behind a Mutex because heartbeat handling replaces it with a rebuilt
        // client that trusts newly learned server certificates.
        client: Mutex<reqwest::Client>,
    }
861 
862     impl dist::SchedulerOutgoing for SchedulerRequester {
do_assign_job( &self, server_id: ServerId, job_id: JobId, tc: Toolchain, auth: String, ) -> Result<AssignJobResult>863         fn do_assign_job(
864             &self,
865             server_id: ServerId,
866             job_id: JobId,
867             tc: Toolchain,
868             auth: String,
869         ) -> Result<AssignJobResult> {
870             let url = urls::server_assign_job(server_id, job_id);
871             let req = self.client.lock().unwrap().post(url);
872             bincode_req(req.bearer_auth(auth).bincode(&tc)?)
873                 .context("POST to scheduler assign_job failed")
874         }
875     }
876 
    /// HTTPS front end for a dist build server, plus its scheduler heartbeat.
    pub struct Server<S> {
        // Address the server's HTTPS listener binds to
        public_addr: SocketAddr,
        // Base URL of the scheduler that heartbeats and job state updates go to
        scheduler_url: reqwest::Url,
        // Bearer token presented to the scheduler
        scheduler_auth: String,
        // HTTPS pieces all the builders will use for connection encryption
        cert_digest: Vec<u8>,
        cert_pem: Vec<u8>,
        privkey_pem: Vec<u8>,
        // Key used to sign any requests relating to jobs
        jwt_key: Vec<u8>,
        // Randomly generated nonce to allow the scheduler to detect server restarts
        server_nonce: ServerNonce,
        // Receives the decoded assign_job / submit_toolchain / run_job requests
        handler: S,
    }
891 
892     impl<S: dist::ServerIncoming + 'static> Server<S> {
new( public_addr: SocketAddr, scheduler_url: reqwest::Url, scheduler_auth: String, handler: S, ) -> Result<Self>893         pub fn new(
894             public_addr: SocketAddr,
895             scheduler_url: reqwest::Url,
896             scheduler_auth: String,
897             handler: S,
898         ) -> Result<Self> {
899             let (cert_digest, cert_pem, privkey_pem) =
900                 create_https_cert_and_privkey(public_addr)
901                     .context("failed to create HTTPS certificate for server")?;
902             let mut jwt_key = vec![0; JWT_KEY_LENGTH];
903             OsRng.fill_bytes(&mut jwt_key);
904             let server_nonce = ServerNonce::new();
905 
906             Ok(Self {
907                 public_addr,
908                 scheduler_url,
909                 scheduler_auth,
910                 cert_digest,
911                 cert_pem,
912                 privkey_pem,
913                 jwt_key,
914                 server_nonce,
915                 handler,
916             })
917         }
918 
start(self) -> Result<Void>919         pub fn start(self) -> Result<Void> {
920             let Self {
921                 public_addr,
922                 scheduler_url,
923                 scheduler_auth,
924                 cert_digest,
925                 cert_pem,
926                 privkey_pem,
927                 jwt_key,
928                 server_nonce,
929                 handler,
930             } = self;
931             let heartbeat_req = HeartbeatServerHttpRequest {
932                 num_cpus: num_cpus::get(),
933                 jwt_key: jwt_key.clone(),
934                 server_nonce,
935                 cert_digest,
936                 cert_pem: cert_pem.clone(),
937             };
938             let job_authorizer = JWTJobAuthorizer::new(jwt_key);
939             let heartbeat_url = urls::scheduler_heartbeat_server(&scheduler_url);
940             let requester = ServerRequester {
941                 client: reqwest::Client::new(),
942                 scheduler_url,
943                 scheduler_auth: scheduler_auth.clone(),
944             };
945 
946             // TODO: detect if this panics
947             thread::spawn(move || {
948                 let client = reqwest::Client::new();
949                 loop {
950                     trace!("Performing heartbeat");
951                     match bincode_req(
952                         client
953                             .post(heartbeat_url.clone())
954                             .bearer_auth(scheduler_auth.clone())
955                             .bincode(&heartbeat_req)
956                             .expect("failed to serialize heartbeat"),
957                     ) {
958                         Ok(HeartbeatServerResult { is_new }) => {
959                             trace!("Heartbeat success is_new={}", is_new);
960                             // TODO: if is_new, terminate all running jobs
961                             thread::sleep(HEARTBEAT_INTERVAL)
962                         }
963                         Err(e) => {
964                             error!("Failed to send heartbeat to server: {}", e);
965                             thread::sleep(HEARTBEAT_ERROR_INTERVAL)
966                         }
967                     }
968                 }
969             });
970 
971             info!("Server listening for clients on {}", public_addr);
972             let request_count = atomic::AtomicUsize::new(0);
973 
974             let server = rouille::Server::new_ssl(public_addr, move |request| {
975                 let req_id = request_count.fetch_add(1, atomic::Ordering::SeqCst);
976                 trace!("Req {} ({}): {:?}", req_id, request.remote_addr(), request);
977                 let response = (|| router!(request,
978                     (POST) (/api/v1/distserver/assign_job/{job_id: JobId}) => {
979                         job_auth_or_401!(request, &job_authorizer, job_id);
980                         let toolchain = try_or_400_log!(req_id, bincode_input(request));
981                         trace!("Req {}: assign_job({}): {:?}", req_id, job_id, toolchain);
982 
983                         let res: AssignJobResult = try_or_500_log!(req_id, handler.handle_assign_job(job_id, toolchain));
984                         prepare_response(&request, &res)
985                     },
986                     (POST) (/api/v1/distserver/submit_toolchain/{job_id: JobId}) => {
987                         job_auth_or_401!(request, &job_authorizer, job_id);
988                         trace!("Req {}: submit_toolchain({})", req_id, job_id);
989 
990                         let body = request.data().expect("body was already read in submit_toolchain");
991                         let toolchain_rdr = ToolchainReader(Box::new(body));
992                         let res: SubmitToolchainResult = try_or_500_log!(req_id, handler.handle_submit_toolchain(&requester, job_id, toolchain_rdr));
993                         prepare_response(&request, &res)
994                     },
995                     (POST) (/api/v1/distserver/run_job/{job_id: JobId}) => {
996                         job_auth_or_401!(request, &job_authorizer, job_id);
997 
998                         let mut body = request.data().expect("body was already read in run_job");
999                         let bincode_length = try_or_500_log!(req_id, body.read_u32::<BigEndian>()
1000                             .context("failed to read run job input length")) as u64;
1001 
1002                         let mut bincode_reader = body.take(bincode_length);
1003                         let runjob = try_or_500_log!(req_id, bincode::deserialize_from(&mut bincode_reader)
1004                             .context("failed to deserialize run job request"));
1005                         trace!("Req {}: run_job({}): {:?}", req_id, job_id, runjob);
1006                         let RunJobHttpRequest { command, outputs } = runjob;
1007                         let body = bincode_reader.into_inner();
1008                         let inputs_rdr = InputsReader(Box::new(ZlibReadDecoder::new(body)));
1009                         let outputs = outputs.into_iter().collect();
1010 
1011                         let res: RunJobResult = try_or_500_log!(req_id, handler.handle_run_job(&requester, job_id, command, outputs, inputs_rdr));
1012                         prepare_response(&request, &res)
1013                     },
1014                     _ => {
1015                         warn!("Unknown request {:?}", request);
1016                         rouille::Response::empty_404()
1017                     },
1018                 ))();
1019                 trace!("Res {}: {:?}", req_id, response);
1020                 response
1021             }, cert_pem, privkey_pem).map_err(|e| anyhow!(format!("Failed to start http server for sccache server: {}", e)))?;
1022 
1023             // This limit is rouille's default for `start_server_with_pool`, which
1024             // we would use, except that interface doesn't permit any sort of
1025             // error handling to be done.
1026             let server = server.pool_size(num_cpus::get() * 8);
1027             server.run();
1028 
1029             panic!("Rouille server terminated")
1030         }
1031     }
1032 
    /// Outgoing requests from a build server to the scheduler.
    struct ServerRequester {
        client: reqwest::Client,
        // Base URL of the scheduler; endpoint URLs are derived from it via `urls::*`
        scheduler_url: reqwest::Url,
        // Bearer token sent with every request to the scheduler
        scheduler_auth: String,
    }
1038 
1039     impl dist::ServerOutgoing for ServerRequester {
do_update_job_state( &self, job_id: JobId, state: JobState, ) -> Result<UpdateJobStateResult>1040         fn do_update_job_state(
1041             &self,
1042             job_id: JobId,
1043             state: JobState,
1044         ) -> Result<UpdateJobStateResult> {
1045             let url = urls::scheduler_job_state(&self.scheduler_url, job_id);
1046             bincode_req(
1047                 self.client
1048                     .post(url)
1049                     .bearer_auth(self.scheduler_auth.clone())
1050                     .bincode(&state)?,
1051             )
1052             .context("POST to scheduler job_state failed")
1053         }
1054     }
1055 }
1056 
1057 #[cfg(feature = "dist-client")]
1058 mod client {
1059     use super::super::cache;
1060     use crate::config;
1061     use crate::dist::pkg::{InputsPackager, ToolchainPackager};
1062     use crate::dist::{
1063         self, AllocJobResult, CompileCommand, JobAlloc, PathTransformer, RunJobResult,
1064         SchedulerStatusResult, SubmitToolchainResult, Toolchain,
1065     };
1066     use crate::util::SpawnExt;
1067     use byteorder::{BigEndian, WriteBytesExt};
1068     use flate2::write::ZlibEncoder as ZlibWriteEncoder;
1069     use flate2::Compression;
1070     use futures::Future;
1071     use futures_03::executor::ThreadPool;
1072     use std::collections::HashMap;
1073     use std::io::Write;
1074     use std::path::{Path, PathBuf};
1075     use std::sync::{Arc, Mutex};
1076     use std::time::Duration;
1077 
1078     use super::common::{
1079         bincode_req, bincode_req_fut, AllocJobHttpResponse, ReqwestRequestBuilderExt,
1080         RunJobHttpRequest, ServerCertificateHttpResponse,
1081     };
1082     use super::urls;
1083     use crate::errors::*;
1084 
1085     const REQUEST_TIMEOUT_SECS: u64 = 600;
1086     const CONNECT_TIMEOUT_SECS: u64 = 5;
1087 
    /// HTTP implementation of `dist::Client`: talks to the scheduler and to build servers.
    pub struct Client {
        // Bearer token sent to the scheduler with alloc_job requests
        auth_token: String,
        scheduler_url: reqwest::Url,
        // cert_digest -> cert_pem
        server_certs: Arc<Mutex<HashMap<Vec<u8>, Vec<u8>>>>,
        // TODO: this should really only use the async client, but reqwest async bodies are extremely limited
        // and only support owned bytes, which means the whole toolchain would end up in memory
        client: Arc<Mutex<reqwest::Client>>,
        client_async: Arc<Mutex<reqwest::r#async::Client>>,
        // Pool that the blocking requests and toolchain packaging are run on
        pool: ThreadPool,
        tc_cache: Arc<cache::ClientToolchains>,
        // Returned as-is by `dist::Client::rewrite_includes_only`
        rewrite_includes_only: bool,
    }
1101 
1102     impl Client {
new( pool: &ThreadPool, scheduler_url: reqwest::Url, cache_dir: &Path, cache_size: u64, toolchain_configs: &[config::DistToolchainConfig], auth_token: String, rewrite_includes_only: bool, ) -> Result<Self>1103         pub fn new(
1104             pool: &ThreadPool,
1105             scheduler_url: reqwest::Url,
1106             cache_dir: &Path,
1107             cache_size: u64,
1108             toolchain_configs: &[config::DistToolchainConfig],
1109             auth_token: String,
1110             rewrite_includes_only: bool,
1111         ) -> Result<Self> {
1112             let timeout = Duration::new(REQUEST_TIMEOUT_SECS, 0);
1113             let connect_timeout = Duration::new(CONNECT_TIMEOUT_SECS, 0);
1114             let client = reqwest::ClientBuilder::new()
1115                 .timeout(timeout)
1116                 .connect_timeout(connect_timeout)
1117                 .build()
1118                 .context("failed to create a HTTP client")?;
1119             let client_async = reqwest::r#async::ClientBuilder::new()
1120                 .timeout(timeout)
1121                 .connect_timeout(connect_timeout)
1122                 .build()
1123                 .context("failed to create an async HTTP client")?;
1124             let client_toolchains =
1125                 cache::ClientToolchains::new(cache_dir, cache_size, toolchain_configs)
1126                     .context("failed to initialise client toolchains")?;
1127             Ok(Self {
1128                 auth_token,
1129                 scheduler_url,
1130                 server_certs: Default::default(),
1131                 client: Arc::new(Mutex::new(client)),
1132                 client_async: Arc::new(Mutex::new(client_async)),
1133                 pool: pool.clone(),
1134                 tc_cache: Arc::new(client_toolchains),
1135                 rewrite_includes_only,
1136             })
1137         }
1138 
update_certs( client: &mut reqwest::Client, client_async: &mut reqwest::r#async::Client, certs: &mut HashMap<Vec<u8>, Vec<u8>>, cert_digest: Vec<u8>, cert_pem: Vec<u8>, ) -> Result<()>1139         fn update_certs(
1140             client: &mut reqwest::Client,
1141             client_async: &mut reqwest::r#async::Client,
1142             certs: &mut HashMap<Vec<u8>, Vec<u8>>,
1143             cert_digest: Vec<u8>,
1144             cert_pem: Vec<u8>,
1145         ) -> Result<()> {
1146             let mut client_builder = reqwest::ClientBuilder::new();
1147             let mut client_async_builder = reqwest::r#async::ClientBuilder::new();
1148             // Add all the certificates we know about
1149             client_builder = client_builder.add_root_certificate(
1150                 reqwest::Certificate::from_pem(&cert_pem)
1151                     .context("failed to interpret pem as certificate")?,
1152             );
1153             client_async_builder = client_async_builder.add_root_certificate(
1154                 reqwest::Certificate::from_pem(&cert_pem)
1155                     .context("failed to interpret pem as certificate")?,
1156             );
1157             for cert_pem in certs.values() {
1158                 client_builder = client_builder.add_root_certificate(
1159                     reqwest::Certificate::from_pem(cert_pem).expect("previously valid cert"),
1160                 );
1161                 client_async_builder = client_async_builder.add_root_certificate(
1162                     reqwest::Certificate::from_pem(cert_pem).expect("previously valid cert"),
1163                 );
1164             }
1165             // Finish the clients
1166             let timeout = Duration::new(REQUEST_TIMEOUT_SECS, 0);
1167             let new_client = client_builder
1168                 .timeout(timeout)
1169                 .build()
1170                 .context("failed to create a HTTP client")?;
1171             let new_client_async = client_async_builder
1172                 .timeout(timeout)
1173                 .build()
1174                 .context("failed to create an async HTTP client")?;
1175             // Use the updated certificates
1176             *client = new_client;
1177             *client_async = new_client_async;
1178             certs.insert(cert_digest, cert_pem);
1179             Ok(())
1180         }
1181     }
1182 
1183     impl dist::Client for Client {
do_alloc_job(&self, tc: Toolchain) -> SFuture<AllocJobResult>1184         fn do_alloc_job(&self, tc: Toolchain) -> SFuture<AllocJobResult> {
1185             let scheduler_url = self.scheduler_url.clone();
1186             let url = urls::scheduler_alloc_job(&scheduler_url);
1187             let mut req = self.client_async.lock().unwrap().post(url);
1188             req = ftry!(req.bearer_auth(self.auth_token.clone()).bincode(&tc));
1189 
1190             let client = self.client.clone();
1191             let client_async = self.client_async.clone();
1192             let server_certs = self.server_certs.clone();
1193             Box::new(bincode_req_fut(req).and_then(move |res| match res {
1194                 AllocJobHttpResponse::Success {
1195                     job_alloc,
1196                     need_toolchain,
1197                     cert_digest,
1198                 } => {
1199                     let server_id = job_alloc.server_id;
1200                     let alloc_job_res = f_ok(AllocJobResult::Success {
1201                         job_alloc,
1202                         need_toolchain,
1203                     });
1204                     if server_certs.lock().unwrap().contains_key(&cert_digest) {
1205                         return alloc_job_res;
1206                     }
1207                     info!(
1208                         "Need to request new certificate for server {}",
1209                         server_id.addr()
1210                     );
1211                     let url = urls::scheduler_server_certificate(&scheduler_url, server_id);
1212                     let req = client_async.lock().unwrap().get(url);
1213                     Box::new(
1214                         bincode_req_fut(req)
1215                             .map_err(|e| e.context("GET to scheduler server_certificate failed"))
1216                             .and_then(move |res: ServerCertificateHttpResponse| {
1217                                 ftry!(Self::update_certs(
1218                                     &mut client.lock().unwrap(),
1219                                     &mut client_async.lock().unwrap(),
1220                                     &mut server_certs.lock().unwrap(),
1221                                     res.cert_digest,
1222                                     res.cert_pem,
1223                                 ));
1224                                 alloc_job_res
1225                             }),
1226                     )
1227                 }
1228                 AllocJobHttpResponse::Fail { msg } => f_ok(AllocJobResult::Fail { msg }),
1229             }))
1230         }
do_get_status(&self) -> SFuture<SchedulerStatusResult>1231         fn do_get_status(&self) -> SFuture<SchedulerStatusResult> {
1232             let scheduler_url = self.scheduler_url.clone();
1233             let url = urls::scheduler_status(&scheduler_url);
1234             let req = self.client.lock().unwrap().get(url);
1235             Box::new(self.pool.spawn_fn(move || bincode_req(req)))
1236         }
do_submit_toolchain( &self, job_alloc: JobAlloc, tc: Toolchain, ) -> SFuture<SubmitToolchainResult>1237         fn do_submit_toolchain(
1238             &self,
1239             job_alloc: JobAlloc,
1240             tc: Toolchain,
1241         ) -> SFuture<SubmitToolchainResult> {
1242             match self.tc_cache.get_toolchain(&tc) {
1243                 Ok(Some(toolchain_file)) => {
1244                     let url = urls::server_submit_toolchain(job_alloc.server_id, job_alloc.job_id);
1245                     let req = self.client.lock().unwrap().post(url);
1246 
1247                     Box::new(self.pool.spawn_fn(move || {
1248                         let req = req.bearer_auth(job_alloc.auth.clone()).body(toolchain_file);
1249                         bincode_req(req)
1250                     }))
1251                 }
1252                 Ok(None) => f_err(anyhow!("couldn't find toolchain locally")),
1253                 Err(e) => f_err(e),
1254             }
1255         }
do_run_job( &self, job_alloc: JobAlloc, command: CompileCommand, outputs: Vec<String>, inputs_packager: Box<dyn InputsPackager>, ) -> SFuture<(RunJobResult, PathTransformer)>1256         fn do_run_job(
1257             &self,
1258             job_alloc: JobAlloc,
1259             command: CompileCommand,
1260             outputs: Vec<String>,
1261             inputs_packager: Box<dyn InputsPackager>,
1262         ) -> SFuture<(RunJobResult, PathTransformer)> {
1263             let url = urls::server_run_job(job_alloc.server_id, job_alloc.job_id);
1264             let mut req = self.client.lock().unwrap().post(url);
1265 
1266             Box::new(self.pool.spawn_fn(move || {
1267                 let bincode = bincode::serialize(&RunJobHttpRequest { command, outputs })
1268                     .context("failed to serialize run job request")?;
1269                 let bincode_length = bincode.len();
1270 
1271                 let mut body = vec![];
1272                 body.write_u32::<BigEndian>(bincode_length as u32)
1273                     .expect("Infallible write of bincode length to vec failed");
1274                 body.write_all(&bincode)
1275                     .expect("Infallible write of bincode body to vec failed");
1276                 let path_transformer;
1277                 {
1278                     let mut compressor = ZlibWriteEncoder::new(&mut body, Compression::fast());
1279                     path_transformer = inputs_packager
1280                         .write_inputs(&mut compressor)
1281                         .context("Could not write inputs for compilation")?;
1282                     compressor.flush().context("failed to flush compressor")?;
1283                     trace!(
1284                         "Compressed inputs from {} -> {}",
1285                         compressor.total_in(),
1286                         compressor.total_out()
1287                     );
1288                     compressor.finish().context("failed to finish compressor")?;
1289                 }
1290 
1291                 req = req.bearer_auth(job_alloc.auth.clone()).bytes(body);
1292                 bincode_req(req).map(|res| (res, path_transformer))
1293             }))
1294         }
1295 
put_toolchain( &self, compiler_path: &Path, weak_key: &str, toolchain_packager: Box<dyn ToolchainPackager>, ) -> SFuture<(Toolchain, Option<(String, PathBuf)>)>1296         fn put_toolchain(
1297             &self,
1298             compiler_path: &Path,
1299             weak_key: &str,
1300             toolchain_packager: Box<dyn ToolchainPackager>,
1301         ) -> SFuture<(Toolchain, Option<(String, PathBuf)>)> {
1302             let compiler_path = compiler_path.to_owned();
1303             let weak_key = weak_key.to_owned();
1304             let tc_cache = self.tc_cache.clone();
1305             Box::new(self.pool.spawn_fn(move || {
1306                 tc_cache.put_toolchain(&compiler_path, &weak_key, toolchain_packager)
1307             }))
1308         }
1309 
rewrite_includes_only(&self) -> bool1310         fn rewrite_includes_only(&self) -> bool {
1311             self.rewrite_includes_only
1312         }
get_custom_toolchain(&self, exe: &PathBuf) -> Option<PathBuf>1313         fn get_custom_toolchain(&self, exe: &PathBuf) -> Option<PathBuf> {
1314             match self.tc_cache.get_custom_toolchain(exe) {
1315                 Some(Ok((_, _, path))) => Some(path),
1316                 _ => None,
1317             }
1318         }
1319     }
1320 }
1321