// Copyright 2016 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// HTTP transport for distributed compilation; client/scheduler/server halves
// are compiled in depending on the "dist-client"/"dist-server" cargo features.
#[cfg(feature = "dist-client")]
pub use self::client::Client;
#[cfg(feature = "dist-server")]
pub use self::server::Server;
#[cfg(feature = "dist-server")]
pub use self::server::{
    ClientAuthCheck, ClientVisibleMsg, Scheduler, ServerAuthCheck, HEARTBEAT_TIMEOUT,
};

mod common {
    #[cfg(feature = "dist-client")]
    use futures::{Future, Stream};
    use hyperx::header;
    #[cfg(feature = "dist-server")]
    use std::collections::HashMap;
    use std::fmt;

    use crate::dist;

    use crate::errors::*;
    use crate::util::RequestExt;

    // Note that content-length is necessary due to https://github.com/tiny-http/tiny-http/issues/147
    // Shared helpers for both the sync and async reqwest builders: attach a
    // bincode-serialized body, raw bytes (with content headers), or a bearer token.
    pub trait ReqwestRequestBuilderExt: Sized {
        // Serialize `bincode` with the bincode codec and use it as the request body.
        fn bincode<T: serde::Serialize + ?Sized>(self, bincode: &T) -> Result<Self>;
        // Use `bytes` as the request body, setting Content-Type and Content-Length.
        fn bytes(self, bytes: Vec<u8>) -> Self;
        // Attach `token` as an `Authorization: Bearer` header.
        fn bearer_auth(self, token: String) -> Self;
    }
    impl ReqwestRequestBuilderExt for reqwest::RequestBuilder {
        fn bincode<T: serde::Serialize + ?Sized>(self, bincode: &T) -> Result<Self> {
            let bytes =
                bincode::serialize(bincode).context("Failed to serialize body to bincode")?;
            Ok(self.bytes(bytes))
        }
        fn bytes(self, bytes: Vec<u8>) -> Self {
            // Content-Length set explicitly; see the tiny-http note above.
            self.set_header(header::ContentType::octet_stream())
                .set_header(header::ContentLength(bytes.len() as u64))
                .body(bytes)
        }
        fn bearer_auth(self, token: String) -> Self {
            self.set_header(header::Authorization(header::Bearer { token }))
        }
    }
    // Same as the sync impl above, for the async builder.
    impl ReqwestRequestBuilderExt for reqwest::r#async::RequestBuilder {
        fn bincode<T: serde::Serialize + ?Sized>(self, bincode: &T) -> Result<Self> {
            let bytes =
                bincode::serialize(bincode).context("Failed to serialize body to bincode")?;
            Ok(self.bytes(bytes))
        }
        fn bytes(self, bytes: Vec<u8>) -> Self {
            self.set_header(header::ContentType::octet_stream())
                .set_header(header::ContentLength(bytes.len() as u64))
                .body(bytes)
        }
        fn bearer_auth(self, token: String) -> Self {
            self.set_header(header::Authorization(header::Bearer { token }))
        }
    }

    // Send `req` synchronously and deserialize the bincode response body,
    // turning non-2xx statuses into an error carrying status/headers/body.
    pub fn bincode_req<T: serde::de::DeserializeOwned + 'static>(
        req: reqwest::RequestBuilder,
    ) -> Result<T> {
        // Work around tiny_http issue #151 by disabling HTTP pipeline with
        // `Connection: close`.
78 let mut res = req.set_header(header::Connection::close()).send()?; 79 let status = res.status(); 80 let mut body = vec![]; 81 res.copy_to(&mut body) 82 .context("error reading response body")?; 83 if !status.is_success() { 84 Err(anyhow!( 85 "Error {} (Headers={:?}): {}", 86 status.as_u16(), 87 res.headers(), 88 String::from_utf8_lossy(&body) 89 )) 90 } else { 91 bincode::deserialize(&body).map_err(Into::into) 92 } 93 } 94 #[cfg(feature = "dist-client")] bincode_req_fut<T: serde::de::DeserializeOwned + 'static>( req: reqwest::r#async::RequestBuilder, ) -> SFuture<T>95 pub fn bincode_req_fut<T: serde::de::DeserializeOwned + 'static>( 96 req: reqwest::r#async::RequestBuilder, 97 ) -> SFuture<T> { 98 Box::new( 99 // Work around tiny_http issue #151 by disabling HTTP pipeline with 100 // `Connection: close`. 101 req.set_header(header::Connection::close()) 102 .send() 103 .map_err(Into::into) 104 .and_then(|res| { 105 let status = res.status(); 106 res.into_body() 107 .concat2() 108 .map(move |b| (status, b)) 109 .map_err(Into::into) 110 }) 111 .and_then(|(status, body)| { 112 if !status.is_success() { 113 let errmsg = format!( 114 "Error {}: {}", 115 status.as_u16(), 116 String::from_utf8_lossy(&body) 117 ); 118 if status.is_client_error() { 119 return f_err(HttpClientError(errmsg)); 120 } else { 121 return f_err(anyhow!(errmsg)); 122 } 123 } 124 match bincode::deserialize(&body) { 125 Ok(r) => f_ok(r), 126 Err(e) => f_err(e), 127 } 128 }), 129 ) 130 } 131 132 #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] 133 #[serde(deny_unknown_fields)] 134 pub struct JobJwt { 135 pub job_id: dist::JobId, 136 } 137 138 #[derive(Clone, Debug, Serialize, Deserialize)] 139 #[serde(deny_unknown_fields)] 140 pub enum AllocJobHttpResponse { 141 Success { 142 job_alloc: dist::JobAlloc, 143 need_toolchain: bool, 144 cert_digest: Vec<u8>, 145 }, 146 Fail { 147 msg: String, 148 }, 149 } 150 impl AllocJobHttpResponse { 151 #[cfg(feature = "dist-server")] 
from_alloc_job_result( res: dist::AllocJobResult, certs: &HashMap<dist::ServerId, (Vec<u8>, Vec<u8>)>, ) -> Self152 pub fn from_alloc_job_result( 153 res: dist::AllocJobResult, 154 certs: &HashMap<dist::ServerId, (Vec<u8>, Vec<u8>)>, 155 ) -> Self { 156 match res { 157 dist::AllocJobResult::Success { 158 job_alloc, 159 need_toolchain, 160 } => { 161 if let Some((digest, _)) = certs.get(&job_alloc.server_id) { 162 AllocJobHttpResponse::Success { 163 job_alloc, 164 need_toolchain, 165 cert_digest: digest.to_owned(), 166 } 167 } else { 168 AllocJobHttpResponse::Fail { 169 msg: format!( 170 "missing certificates for server {}", 171 job_alloc.server_id.addr() 172 ), 173 } 174 } 175 } 176 dist::AllocJobResult::Fail { msg } => AllocJobHttpResponse::Fail { msg }, 177 } 178 } 179 } 180 181 #[derive(Clone, Debug, Serialize, Deserialize)] 182 #[serde(deny_unknown_fields)] 183 pub struct ServerCertificateHttpResponse { 184 pub cert_digest: Vec<u8>, 185 pub cert_pem: Vec<u8>, 186 } 187 188 #[derive(Clone, Serialize, Deserialize)] 189 #[serde(deny_unknown_fields)] 190 pub struct HeartbeatServerHttpRequest { 191 pub jwt_key: Vec<u8>, 192 pub num_cpus: usize, 193 pub server_nonce: dist::ServerNonce, 194 pub cert_digest: Vec<u8>, 195 pub cert_pem: Vec<u8>, 196 } 197 // cert_pem is quite long so elide it (you can retrieve it by hitting the server url anyway) 198 impl fmt::Debug for HeartbeatServerHttpRequest { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result199 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 200 let HeartbeatServerHttpRequest { 201 jwt_key, 202 num_cpus, 203 server_nonce, 204 cert_digest, 205 cert_pem, 206 } = self; 207 write!(f, "HeartbeatServerHttpRequest {{ jwt_key: {:?}, num_cpus: {:?}, server_nonce: {:?}, cert_digest: {:?}, cert_pem: [...{} bytes...] 
}}", jwt_key, num_cpus, server_nonce, cert_digest, cert_pem.len()) 208 } 209 } 210 #[derive(Clone, Debug, Serialize, Deserialize)] 211 #[serde(deny_unknown_fields)] 212 pub struct RunJobHttpRequest { 213 pub command: dist::CompileCommand, 214 pub outputs: Vec<String>, 215 } 216 } 217 218 pub mod urls { 219 use crate::dist::{JobId, ServerId}; 220 scheduler_alloc_job(scheduler_url: &reqwest::Url) -> reqwest::Url221 pub fn scheduler_alloc_job(scheduler_url: &reqwest::Url) -> reqwest::Url { 222 scheduler_url 223 .join("/api/v1/scheduler/alloc_job") 224 .expect("failed to create alloc job url") 225 } scheduler_server_certificate( scheduler_url: &reqwest::Url, server_id: ServerId, ) -> reqwest::Url226 pub fn scheduler_server_certificate( 227 scheduler_url: &reqwest::Url, 228 server_id: ServerId, 229 ) -> reqwest::Url { 230 scheduler_url 231 .join(&format!( 232 "/api/v1/scheduler/server_certificate/{}", 233 server_id.addr() 234 )) 235 .expect("failed to create server certificate url") 236 } scheduler_heartbeat_server(scheduler_url: &reqwest::Url) -> reqwest::Url237 pub fn scheduler_heartbeat_server(scheduler_url: &reqwest::Url) -> reqwest::Url { 238 scheduler_url 239 .join("/api/v1/scheduler/heartbeat_server") 240 .expect("failed to create heartbeat url") 241 } scheduler_job_state(scheduler_url: &reqwest::Url, job_id: JobId) -> reqwest::Url242 pub fn scheduler_job_state(scheduler_url: &reqwest::Url, job_id: JobId) -> reqwest::Url { 243 scheduler_url 244 .join(&format!("/api/v1/scheduler/job_state/{}", job_id)) 245 .expect("failed to create job state url") 246 } scheduler_status(scheduler_url: &reqwest::Url) -> reqwest::Url247 pub fn scheduler_status(scheduler_url: &reqwest::Url) -> reqwest::Url { 248 scheduler_url 249 .join("/api/v1/scheduler/status") 250 .expect("failed to create alloc job url") 251 } 252 server_assign_job(server_id: ServerId, job_id: JobId) -> reqwest::Url253 pub fn server_assign_job(server_id: ServerId, job_id: JobId) -> reqwest::Url { 254 let url = 
format!( 255 "https://{}/api/v1/distserver/assign_job/{}", 256 server_id.addr(), 257 job_id 258 ); 259 reqwest::Url::parse(&url).expect("failed to create assign job url") 260 } server_submit_toolchain(server_id: ServerId, job_id: JobId) -> reqwest::Url261 pub fn server_submit_toolchain(server_id: ServerId, job_id: JobId) -> reqwest::Url { 262 let url = format!( 263 "https://{}/api/v1/distserver/submit_toolchain/{}", 264 server_id.addr(), 265 job_id 266 ); 267 reqwest::Url::parse(&url).expect("failed to create submit toolchain url") 268 } server_run_job(server_id: ServerId, job_id: JobId) -> reqwest::Url269 pub fn server_run_job(server_id: ServerId, job_id: JobId) -> reqwest::Url { 270 let url = format!( 271 "https://{}/api/v1/distserver/run_job/{}", 272 server_id.addr(), 273 job_id 274 ); 275 reqwest::Url::parse(&url).expect("failed to create run job url") 276 } 277 } 278 279 #[cfg(feature = "dist-server")] 280 mod server { 281 use crate::jwt; 282 use byteorder::{BigEndian, ReadBytesExt}; 283 use flate2::read::ZlibDecoder as ZlibReadDecoder; 284 use rand::{rngs::OsRng, RngCore}; 285 use rouille::accept; 286 use std::collections::HashMap; 287 use std::io::Read; 288 use std::net::SocketAddr; 289 use std::result::Result as StdResult; 290 use std::sync::atomic; 291 use std::sync::Mutex; 292 use std::thread; 293 use std::time::Duration; 294 use void::Void; 295 296 use super::common::{ 297 bincode_req, AllocJobHttpResponse, HeartbeatServerHttpRequest, JobJwt, 298 ReqwestRequestBuilderExt, RunJobHttpRequest, ServerCertificateHttpResponse, 299 }; 300 use super::urls; 301 use crate::dist::{ 302 self, AllocJobResult, AssignJobResult, HeartbeatServerResult, InputsReader, JobAuthorizer, 303 JobId, JobState, RunJobResult, SchedulerStatusResult, ServerId, ServerNonce, 304 SubmitToolchainResult, Toolchain, ToolchainReader, UpdateJobStateResult, 305 }; 306 use crate::errors::*; 307 308 const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30); 309 const 
HEARTBEAT_ERROR_INTERVAL: Duration = Duration::from_secs(10); 310 pub const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(90); 311 create_https_cert_and_privkey(addr: SocketAddr) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)>312 fn create_https_cert_and_privkey(addr: SocketAddr) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> { 313 let rsa_key = openssl::rsa::Rsa::<openssl::pkey::Private>::generate(2048) 314 .context("failed to generate rsa privkey")?; 315 let privkey_pem = rsa_key 316 .private_key_to_pem() 317 .context("failed to create pem from rsa privkey")?; 318 let privkey: openssl::pkey::PKey<openssl::pkey::Private> = 319 openssl::pkey::PKey::from_rsa(rsa_key) 320 .context("failed to create openssl pkey from rsa privkey")?; 321 let mut builder = 322 openssl::x509::X509::builder().context("failed to create x509 builder")?; 323 324 // Populate the certificate with the necessary parts, mostly from mkcert in openssl 325 builder 326 .set_version(2) 327 .context("failed to set x509 version")?; 328 let serial_number = openssl::bn::BigNum::from_u32(0) 329 .and_then(|bn| bn.to_asn1_integer()) 330 .context("failed to create openssl asn1 0")?; 331 builder 332 .set_serial_number(serial_number.as_ref()) 333 .context("failed to set x509 serial number")?; 334 let not_before = openssl::asn1::Asn1Time::days_from_now(0) 335 .context("failed to create openssl not before asn1")?; 336 builder 337 .set_not_before(not_before.as_ref()) 338 .context("failed to set not before on x509")?; 339 let not_after = openssl::asn1::Asn1Time::days_from_now(365) 340 .context("failed to create openssl not after asn1")?; 341 builder 342 .set_not_after(not_after.as_ref()) 343 .context("failed to set not after on x509")?; 344 builder 345 .set_pubkey(privkey.as_ref()) 346 .context("failed to set pubkey for x509")?; 347 348 let mut name = openssl::x509::X509Name::builder()?; 349 name.append_entry_by_nid(openssl::nid::Nid::COMMONNAME, &addr.to_string())?; 350 let name = name.build(); 351 352 builder 353 
.set_subject_name(&name) 354 .context("failed to set subject name")?; 355 builder 356 .set_issuer_name(&name) 357 .context("failed to set issuer name")?; 358 359 // Add the SubjectAlternativeName 360 let extension = openssl::x509::extension::SubjectAlternativeName::new() 361 .ip(&addr.ip().to_string()) 362 .build(&builder.x509v3_context(None, None)) 363 .context("failed to build SAN extension for x509")?; 364 builder 365 .append_extension(extension) 366 .context("failed to append SAN extension for x509")?; 367 368 // Add ExtendedKeyUsage 369 let ext_key_usage = openssl::x509::extension::ExtendedKeyUsage::new() 370 .server_auth() 371 .build() 372 .context("failed to build EKU extension for x509")?; 373 builder 374 .append_extension(ext_key_usage) 375 .context("failes to append EKU extension for x509")?; 376 377 // Finish the certificate 378 builder 379 .sign(&privkey, openssl::hash::MessageDigest::sha1()) 380 .context("failed to sign x509 with sha1")?; 381 let cert: openssl::x509::X509 = builder.build(); 382 let cert_pem = cert.to_pem().context("failed to create pem from x509")?; 383 let cert_digest = cert 384 .digest(openssl::hash::MessageDigest::sha256()) 385 .context("failed to create digest of x509 certificate")? 386 .as_ref() 387 .to_owned(); 388 389 Ok((cert_digest, cert_pem, privkey_pem)) 390 } 391 392 // Messages that are non-sensitive and can be sent to the client 393 #[derive(Debug)] 394 pub struct ClientVisibleMsg(String); 395 impl ClientVisibleMsg { from_nonsensitive(s: String) -> Self396 pub fn from_nonsensitive(s: String) -> Self { 397 ClientVisibleMsg(s) 398 } 399 } 400 401 pub trait ClientAuthCheck: Send + Sync { check(&self, token: &str) -> StdResult<(), ClientVisibleMsg>402 fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg>; 403 } 404 pub type ServerAuthCheck = Box<dyn Fn(&str) -> Option<ServerId> + Send + Sync>; 405 406 const JWT_KEY_LENGTH: usize = 256 / 8; 407 lazy_static! 
    {
        // HS256 JWTs with no expiry: job tokens are only meaningful for the
        // lifetime of the server's in-memory key anyway.
        static ref JWT_HEADER: jwt::Header = jwt::Header::new(jwt::Algorithm::HS256);
        static ref JWT_VALIDATION: jwt::Validation = jwt::Validation {
            leeway: 0,
            validate_exp: false,
            validate_nbf: false,
            aud: None,
            iss: None,
            sub: None,
            algorithms: vec![jwt::Algorithm::HS256],
        };
    }

    // Based on rouille::input::json::json_input
    #[derive(Debug)]
    pub enum RouilleBincodeError {
        BodyAlreadyExtracted,
        WrongContentType,
        ParseError(bincode::Error),
    }
    impl From<bincode::Error> for RouilleBincodeError {
        fn from(err: bincode::Error) -> RouilleBincodeError {
            RouilleBincodeError::ParseError(err)
        }
    }
    impl std::error::Error for RouilleBincodeError {
        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
            match *self {
                RouilleBincodeError::ParseError(ref e) => Some(e),
                _ => None,
            }
        }
    }
    impl std::fmt::Display for RouilleBincodeError {
        fn fmt(
            &self,
            fmt: &mut std::fmt::Formatter<'_>,
        ) -> std::result::Result<(), std::fmt::Error> {
            write!(
                fmt,
                "{}",
                match *self {
                    RouilleBincodeError::BodyAlreadyExtracted => {
                        "the body of the request was already extracted"
                    }
                    RouilleBincodeError::WrongContentType => {
                        "the request didn't have a binary content type"
                    }
                    RouilleBincodeError::ParseError(_) => "error while parsing the bincode body",
                }
            )
        }
    }
    // Deserialize a rouille request body as bincode, requiring an
    // application/octet-stream Content-Type.
    fn bincode_input<O>(request: &rouille::Request) -> std::result::Result<O, RouilleBincodeError>
    where
        O: serde::de::DeserializeOwned,
    {
        if let Some(header) = request.header("Content-Type") {
            if !header.starts_with("application/octet-stream") {
                return Err(RouilleBincodeError::WrongContentType);
            }
        } else {
            return Err(RouilleBincodeError::WrongContentType);
        }

        if let Some(mut b) = request.data() {
            bincode::deserialize_from::<_, O>(&mut b).map_err(From::from)
        } else {
            Err(RouilleBincodeError::BodyAlreadyExtracted)
        }
    }

    // Based on try_or_400 in rouille, but with logging
    #[derive(Serialize)]
    pub struct ErrJson {
        description: String,
        cause: Option<Box<ErrJson>>,
    }

    impl ErrJson {
        // Build a JSON-serializable error chain by following source() links.
        fn from_err<E: ?Sized + std::error::Error>(err: &E) -> ErrJson {
            let cause = err.source().map(ErrJson::from_err).map(Box::new);
            ErrJson {
                description: err.to_string(),
                cause,
            }
        }

        fn into_data(self) -> String {
            serde_json::to_string(&self).expect("infallible serialization for ErrJson failed")
        }
    }
    // Unwrap $result or log the full error chain and early-return a JSON error
    // response with status $code from the enclosing handler closure.
    macro_rules! try_or_err_and_log {
        ($reqid:expr, $code:expr, $result:expr) => {
            match $result {
                Ok(r) => r,
                Err(err) => {
                    // TODO: would ideally just use error_chain
                    #[allow(unused_imports)]
                    use std::error::Error;
                    let mut err_msg = err.to_string();
                    let mut maybe_cause = err.source();
                    while let Some(cause) = maybe_cause {
                        err_msg.push_str(", caused by: ");
                        err_msg.push_str(&cause.to_string());
                        maybe_cause = cause.source();
                    }

                    warn!("Res {} error: {}", $reqid, err_msg);
                    let err: Box<dyn std::error::Error + 'static> = err.into();
                    let json = ErrJson::from_err(&*err);
                    return rouille::Response::json(&json).with_status_code($code);
                }
            }
        };
    }
    macro_rules! try_or_400_log {
        ($reqid:expr, $result:expr) => {
            try_or_err_and_log!($reqid, 400, $result)
        };
    }
    macro_rules! try_or_500_log {
        ($reqid:expr, $result:expr) => {
            try_or_err_and_log!($reqid, 500, $result)
        };
    }
    // 401 response advertising the failure via WWW-Authenticate, with an
    // optional client-visible body.
    fn make_401_with_body(short_err: &str, body: ClientVisibleMsg) -> rouille::Response {
        rouille::Response {
            status_code: 401,
            headers: vec![(
                "WWW-Authenticate".into(),
                format!("Bearer error=\"{}\"", short_err).into(),
            )],
            data: rouille::ResponseBody::from_data(body.0),
            upgrade: None,
        }
    }
    fn make_401(short_err: &str) -> rouille::Response {
        make_401_with_body(short_err, ClientVisibleMsg(String::new()))
    }
    // Extract the token from an `Authorization: Bearer <token>` header, if present.
    fn bearer_http_auth(request: &rouille::Request) -> Option<&str> {
        let header = request.header("Authorization")?;

        let mut split = header.splitn(2, |c| c == ' ');

        let authtype = split.next()?;
        if authtype != "Bearer" {
            return None;
        }

        split.next()
    }

    /// Return `content` as a bincode-encoded `Response`.
    pub fn bincode_response<T>(content: &T) -> rouille::Response
    where
        T: serde::Serialize,
    {
        let data = bincode::serialize(content).context("Failed to serialize response body");
        let data = try_or_500_log!("bincode body serialization", data);

        rouille::Response {
            status_code: 200,
            headers: vec![("Content-Type".into(), "application/octet-stream".into())],
            data: rouille::ResponseBody::from_data(data),
            upgrade: None,
        }
    }

    /// Return `content` as either a bincode or json encoded `Response`
    /// depending on the Accept header in `request`.
    pub fn prepare_response<T>(request: &rouille::Request, content: &T) -> rouille::Response
    where
        T: serde::Serialize,
    {
        accept!(request,
            "application/octet-stream" => bincode_response(content),
            "application/json" => rouille::Response::json(content),
        )
    }

    // Verification of job auth in a request
    // Early-returns a 401 (with the error chain as a client-visible JSON body)
    // from the enclosing handler if the bearer JWT doesn't authorize $job_id.
    macro_rules! job_auth_or_401 {
        ($request:ident, $job_authorizer:expr, $job_id:expr) => {{
            let verify_result = match bearer_http_auth($request) {
                Some(token) => $job_authorizer.verify_token($job_id, token),
                None => Err(anyhow!("no Authorization header")),
            };
            match verify_result {
                Ok(()) => (),
                Err(err) => {
                    let err: Box<dyn std::error::Error> = err.into();
                    let json = ErrJson::from_err(&*err);
                    return make_401_with_body("invalid_jwt", ClientVisibleMsg(json.into_data()));
                }
            }
        }};
    }
    // Generation and verification of job auth
    // Signs/verifies per-job JWTs with the server's HMAC key (HS256, no expiry).
    struct JWTJobAuthorizer {
        server_key: Vec<u8>,
    }
    impl JWTJobAuthorizer {
        fn new(server_key: Vec<u8>) -> Box<Self> {
            Box::new(Self { server_key })
        }
    }
    impl dist::JobAuthorizer for JWTJobAuthorizer {
        fn generate_token(&self, job_id: JobId) -> Result<String> {
            let claims = JobJwt { job_id };
            let key = jwt::EncodingKey::from_secret(&self.server_key);
            jwt::encode(&JWT_HEADER, &claims, &key)
                .map_err(|e| anyhow!("Failed to create JWT for job: {}", e))
        }
        fn verify_token(&self, job_id: JobId, token: &str) -> Result<()> {
            let valid_claims = JobJwt { job_id };
            let key = jwt::DecodingKey::from_secret(&self.server_key);
            jwt::decode(&token, &key, &JWT_VALIDATION)
                .map_err(|e| anyhow!("JWT decode failed: {}", e))
                .and_then(|res| {
                    // Compile-time check that the decoded claims have the same
                    // type as the expected ones before comparing for equality.
                    fn identical_t<T>(_: &T, _: &T) {}
                    identical_t(&res.claims, &valid_claims);
                    if res.claims == valid_claims {
                        Ok(())
                    } else {
                        Err(anyhow!("mismatched claims"))
                    }
                })
        }
    }

    #[test]
    fn test_job_token_verification() {
        let ja = JWTJobAuthorizer::new(vec![1, 2, 2]);

        let job_id = JobId(55);
        let token = ja.generate_token(job_id).unwrap();

        let job_id2 = JobId(56);
        let token2 = ja.generate_token(job_id2).unwrap();

        let ja2 = JWTJobAuthorizer::new(vec![1, 2, 3]);

        // Check tokens are deterministic
        assert_eq!(token, ja.generate_token(job_id).unwrap());
        // Check token verification works
        assert!(ja.verify_token(job_id, &token).is_ok());
        assert!(ja.verify_token(job_id, &token2).is_err());
        assert!(ja.verify_token(job_id2, &token).is_err());
        assert!(ja.verify_token(job_id2, &token2).is_ok());
        // Check token verification with a different key fails
        assert!(ja2.verify_token(job_id, &token).is_err());
        assert!(ja2.verify_token(job_id2, &token2).is_err());
    }

    // HTTP front-end for the scheduler: authenticates clients and servers and
    // forwards requests to the `SchedulerIncoming` handler.
    pub struct Scheduler<S> {
        public_addr: SocketAddr,
        handler: S,
        // Is this client permitted to use the scheduler?
        check_client_auth: Box<dyn ClientAuthCheck>,
        // Do we believe the server is who they appear to be?
668 check_server_auth: ServerAuthCheck, 669 } 670 671 impl<S: dist::SchedulerIncoming + 'static> Scheduler<S> { new( public_addr: SocketAddr, handler: S, check_client_auth: Box<dyn ClientAuthCheck>, check_server_auth: ServerAuthCheck, ) -> Self672 pub fn new( 673 public_addr: SocketAddr, 674 handler: S, 675 check_client_auth: Box<dyn ClientAuthCheck>, 676 check_server_auth: ServerAuthCheck, 677 ) -> Self { 678 Self { 679 public_addr, 680 handler, 681 check_client_auth, 682 check_server_auth, 683 } 684 } 685 start(self) -> Result<Void>686 pub fn start(self) -> Result<Void> { 687 let Self { 688 public_addr, 689 handler, 690 check_client_auth, 691 check_server_auth, 692 } = self; 693 let requester = SchedulerRequester { 694 client: Mutex::new(reqwest::Client::new()), 695 }; 696 697 macro_rules! check_server_auth_or_err { 698 ($request:ident) => {{ 699 match bearer_http_auth($request).and_then(&*check_server_auth) { 700 Some(server_id) => { 701 let origin_ip = if let Some(header_val) = $request.header("X-Real-IP") { 702 trace!("X-Real-IP: {:?}", header_val); 703 match header_val.parse() { 704 Ok(ip) => ip, 705 Err(err) => { 706 warn!( 707 "X-Real-IP value {:?} could not be parsed: {:?}", 708 header_val, err 709 ); 710 return rouille::Response::empty_400(); 711 } 712 } 713 } else { 714 $request.remote_addr().ip() 715 }; 716 if server_id.addr().ip() != origin_ip { 717 trace!("server ip: {:?}", server_id.addr().ip()); 718 trace!("request ip: {:?}", $request.remote_addr().ip()); 719 return make_401("invalid_bearer_token_mismatched_address"); 720 } else { 721 server_id 722 } 723 } 724 None => return make_401("invalid_bearer_token"), 725 } 726 }}; 727 } 728 729 fn maybe_update_certs( 730 client: &mut reqwest::Client, 731 certs: &mut HashMap<ServerId, (Vec<u8>, Vec<u8>)>, 732 server_id: ServerId, 733 cert_digest: Vec<u8>, 734 cert_pem: Vec<u8>, 735 ) -> Result<()> { 736 if let Some((saved_cert_digest, _)) = certs.get(&server_id) { 737 if saved_cert_digest == &cert_digest { 
738 return Ok(()); 739 } 740 } 741 info!( 742 "Adding new certificate for {} to scheduler", 743 server_id.addr() 744 ); 745 let mut client_builder = reqwest::ClientBuilder::new(); 746 // Add all the certificates we know about 747 client_builder = client_builder.add_root_certificate( 748 reqwest::Certificate::from_pem(&cert_pem) 749 .context("failed to interpret pem as certificate")?, 750 ); 751 for (_, cert_pem) in certs.values() { 752 client_builder = client_builder.add_root_certificate( 753 reqwest::Certificate::from_pem(cert_pem).expect("previously valid cert"), 754 ); 755 } 756 // Finish the clients 757 let new_client = client_builder 758 .build() 759 .context("failed to create a HTTP client")?; 760 // Use the updated certificates 761 *client = new_client; 762 certs.insert(server_id, (cert_digest, cert_pem)); 763 Ok(()) 764 } 765 766 info!("Scheduler listening for clients on {}", public_addr); 767 let request_count = atomic::AtomicUsize::new(0); 768 // From server_id -> cert_digest, cert_pem 769 let server_certificates: Mutex<HashMap<ServerId, (Vec<u8>, Vec<u8>)>> = 770 Default::default(); 771 772 let server = rouille::Server::new(public_addr, move |request| { 773 let req_id = request_count.fetch_add(1, atomic::Ordering::SeqCst); 774 trace!("Req {} ({}): {:?}", req_id, request.remote_addr(), request); 775 let response = (|| router!(request, 776 (POST) (/api/v1/scheduler/alloc_job) => { 777 let bearer_auth = match bearer_http_auth(request) { 778 Some(s) => s, 779 None => return make_401("no_bearer_auth"), 780 }; 781 match check_client_auth.check(bearer_auth) { 782 Ok(()) => (), 783 Err(client_msg) => { 784 warn!("Bearer auth failed: {:?}", client_msg); 785 return make_401_with_body("bearer_auth_failed", client_msg) 786 }, 787 } 788 let toolchain = try_or_400_log!(req_id, bincode_input(request)); 789 trace!("Req {}: alloc_job: {:?}", req_id, toolchain); 790 791 let alloc_job_res: AllocJobResult = try_or_500_log!(req_id, handler.handle_alloc_job(&requester, 
toolchain)); 792 let certs = server_certificates.lock().unwrap(); 793 let res = AllocJobHttpResponse::from_alloc_job_result(alloc_job_res, &certs); 794 prepare_response(&request, &res) 795 }, 796 (GET) (/api/v1/scheduler/server_certificate/{server_id: ServerId}) => { 797 let certs = server_certificates.lock().unwrap(); 798 let (cert_digest, cert_pem) = try_or_500_log!(req_id, certs.get(&server_id) 799 .context("server cert not available")); 800 let res = ServerCertificateHttpResponse { 801 cert_digest: cert_digest.clone(), 802 cert_pem: cert_pem.clone(), 803 }; 804 prepare_response(&request, &res) 805 }, 806 (POST) (/api/v1/scheduler/heartbeat_server) => { 807 let server_id = check_server_auth_or_err!(request); 808 let heartbeat_server = try_or_400_log!(req_id, bincode_input(request)); 809 trace!("Req {}: heartbeat_server: {:?}", req_id, heartbeat_server); 810 811 let HeartbeatServerHttpRequest { num_cpus, jwt_key, server_nonce, cert_digest, cert_pem } = heartbeat_server; 812 try_or_500_log!(req_id, maybe_update_certs( 813 &mut requester.client.lock().unwrap(), 814 &mut server_certificates.lock().unwrap(), 815 server_id, cert_digest, cert_pem 816 )); 817 let job_authorizer = JWTJobAuthorizer::new(jwt_key); 818 let res: HeartbeatServerResult = try_or_500_log!(req_id, handler.handle_heartbeat_server( 819 server_id, server_nonce, 820 num_cpus, 821 job_authorizer 822 )); 823 prepare_response(&request, &res) 824 }, 825 (POST) (/api/v1/scheduler/job_state/{job_id: JobId}) => { 826 let server_id = check_server_auth_or_err!(request); 827 let job_state = try_or_400_log!(req_id, bincode_input(request)); 828 trace!("Req {}: job state: {:?}", req_id, job_state); 829 830 let res: UpdateJobStateResult = try_or_500_log!(req_id, handler.handle_update_job_state( 831 job_id, server_id, job_state 832 )); 833 prepare_response(&request, &res) 834 }, 835 (GET) (/api/v1/scheduler/status) => { 836 let res: SchedulerStatusResult = try_or_500_log!(req_id, handler.handle_status()); 837 
prepare_response(&request, &res) 838 }, 839 _ => { 840 warn!("Unknown request {:?}", request); 841 rouille::Response::empty_404() 842 }, 843 ))(); 844 trace!("Res {}: {:?}", req_id, response); 845 response 846 }).map_err(|e| anyhow!(format!("Failed to start http server for sccache scheduler: {}", e)))?; 847 848 // This limit is rouille's default for `start_server_with_pool`, which 849 // we would use, except that interface doesn't permit any sort of 850 // error handling to be done. 851 let server = server.pool_size(num_cpus::get() * 8); 852 server.run(); 853 854 panic!("Rouille server terminated") 855 } 856 } 857 858 struct SchedulerRequester { 859 client: Mutex<reqwest::Client>, 860 } 861 862 impl dist::SchedulerOutgoing for SchedulerRequester { do_assign_job( &self, server_id: ServerId, job_id: JobId, tc: Toolchain, auth: String, ) -> Result<AssignJobResult>863 fn do_assign_job( 864 &self, 865 server_id: ServerId, 866 job_id: JobId, 867 tc: Toolchain, 868 auth: String, 869 ) -> Result<AssignJobResult> { 870 let url = urls::server_assign_job(server_id, job_id); 871 let req = self.client.lock().unwrap().post(url); 872 bincode_req(req.bearer_auth(auth).bincode(&tc)?) 
873 .context("POST to scheduler assign_job failed") 874 } 875 } 876 877 pub struct Server<S> { 878 public_addr: SocketAddr, 879 scheduler_url: reqwest::Url, 880 scheduler_auth: String, 881 // HTTPS pieces all the builders will use for connection encryption 882 cert_digest: Vec<u8>, 883 cert_pem: Vec<u8>, 884 privkey_pem: Vec<u8>, 885 // Key used to sign any requests relating to jobs 886 jwt_key: Vec<u8>, 887 // Randomly generated nonce to allow the scheduler to detect server restarts 888 server_nonce: ServerNonce, 889 handler: S, 890 } 891 892 impl<S: dist::ServerIncoming + 'static> Server<S> { new( public_addr: SocketAddr, scheduler_url: reqwest::Url, scheduler_auth: String, handler: S, ) -> Result<Self>893 pub fn new( 894 public_addr: SocketAddr, 895 scheduler_url: reqwest::Url, 896 scheduler_auth: String, 897 handler: S, 898 ) -> Result<Self> { 899 let (cert_digest, cert_pem, privkey_pem) = 900 create_https_cert_and_privkey(public_addr) 901 .context("failed to create HTTPS certificate for server")?; 902 let mut jwt_key = vec![0; JWT_KEY_LENGTH]; 903 OsRng.fill_bytes(&mut jwt_key); 904 let server_nonce = ServerNonce::new(); 905 906 Ok(Self { 907 public_addr, 908 scheduler_url, 909 scheduler_auth, 910 cert_digest, 911 cert_pem, 912 privkey_pem, 913 jwt_key, 914 server_nonce, 915 handler, 916 }) 917 } 918 start(self) -> Result<Void>919 pub fn start(self) -> Result<Void> { 920 let Self { 921 public_addr, 922 scheduler_url, 923 scheduler_auth, 924 cert_digest, 925 cert_pem, 926 privkey_pem, 927 jwt_key, 928 server_nonce, 929 handler, 930 } = self; 931 let heartbeat_req = HeartbeatServerHttpRequest { 932 num_cpus: num_cpus::get(), 933 jwt_key: jwt_key.clone(), 934 server_nonce, 935 cert_digest, 936 cert_pem: cert_pem.clone(), 937 }; 938 let job_authorizer = JWTJobAuthorizer::new(jwt_key); 939 let heartbeat_url = urls::scheduler_heartbeat_server(&scheduler_url); 940 let requester = ServerRequester { 941 client: reqwest::Client::new(), 942 scheduler_url, 943 
            scheduler_auth: scheduler_auth.clone(),
        };

        // TODO: detect if this panics
        // Heartbeat loop: repeatedly POST our heartbeat to the scheduler,
        // sleeping HEARTBEAT_INTERVAL on success and backing off to
        // HEARTBEAT_ERROR_INTERVAL on failure.
        thread::spawn(move || {
            let client = reqwest::Client::new();
            loop {
                trace!("Performing heartbeat");
                match bincode_req(
                    client
                        .post(heartbeat_url.clone())
                        .bearer_auth(scheduler_auth.clone())
                        .bincode(&heartbeat_req)
                        .expect("failed to serialize heartbeat"),
                ) {
                    Ok(HeartbeatServerResult { is_new }) => {
                        trace!("Heartbeat success is_new={}", is_new);
                        // TODO: if is_new, terminate all running jobs
                        thread::sleep(HEARTBEAT_INTERVAL)
                    }
                    Err(e) => {
                        // NOTE(review): this heartbeat is POSTed to the
                        // scheduler, but the message says "server" — consider
                        // rewording in a follow-up.
                        error!("Failed to send heartbeat to server: {}", e);
                        thread::sleep(HEARTBEAT_ERROR_INTERVAL)
                    }
                }
            }
        });

        info!("Server listening for clients on {}", public_addr);
        // Monotonic counter used only to correlate request/response log lines.
        let request_count = atomic::AtomicUsize::new(0);

        let server = rouille::Server::new_ssl(public_addr, move |request| {
            let req_id = request_count.fetch_add(1, atomic::Ordering::SeqCst);
            trace!("Req {} ({}): {:?}", req_id, request.remote_addr(), request);
            // The immediately-invoked closure lets the `try_or_*` macros use
            // early `return` without leaving the outer handler.
            let response = (|| router!(request,
                (POST) (/api/v1/distserver/assign_job/{job_id: JobId}) => {
                    // All job routes require a JWT scoped to this job id.
                    job_auth_or_401!(request, &job_authorizer, job_id);
                    let toolchain = try_or_400_log!(req_id, bincode_input(request));
                    trace!("Req {}: assign_job({}): {:?}", req_id, job_id, toolchain);

                    let res: AssignJobResult = try_or_500_log!(req_id, handler.handle_assign_job(job_id, toolchain));
                    prepare_response(&request, &res)
                },
                (POST) (/api/v1/distserver/submit_toolchain/{job_id: JobId}) => {
                    job_auth_or_401!(request, &job_authorizer, job_id);
                    trace!("Req {}: submit_toolchain({})", req_id, job_id);

                    // The toolchain archive is streamed straight from the
                    // request body to the handler.
                    let body = request.data().expect("body was already read in submit_toolchain");
                    let toolchain_rdr = ToolchainReader(Box::new(body));
                    let res: SubmitToolchainResult = try_or_500_log!(req_id, handler.handle_submit_toolchain(&requester, job_id, toolchain_rdr));
                    prepare_response(&request, &res)
                },
                (POST) (/api/v1/distserver/run_job/{job_id: JobId}) => {
                    job_auth_or_401!(request, &job_authorizer, job_id);

                    // Body framing: [u32 BE bincode length][bincode
                    // RunJobHttpRequest][zlib-compressed inputs].
                    let mut body = request.data().expect("body was already read in run_job");
                    let bincode_length = try_or_500_log!(req_id, body.read_u32::<BigEndian>()
                        .context("failed to read run job input length")) as u64;

                    let mut bincode_reader = body.take(bincode_length);
                    let runjob = try_or_500_log!(req_id, bincode::deserialize_from(&mut bincode_reader)
                        .context("failed to deserialize run job request"));
                    trace!("Req {}: run_job({}): {:?}", req_id, job_id, runjob);
                    let RunJobHttpRequest { command, outputs } = runjob;
                    // Whatever remains after the bincode prefix is the
                    // compressed inputs stream.
                    let body = bincode_reader.into_inner();
                    let inputs_rdr = InputsReader(Box::new(ZlibReadDecoder::new(body)));
                    let outputs = outputs.into_iter().collect();

                    let res: RunJobResult = try_or_500_log!(req_id, handler.handle_run_job(&requester, job_id, command, outputs, inputs_rdr));
                    prepare_response(&request, &res)
                },
                _ => {
                    warn!("Unknown request {:?}", request);
                    rouille::Response::empty_404()
                },
            ))();
            trace!("Res {}: {:?}", req_id, response);
            response
        }, cert_pem, privkey_pem).map_err(|e| anyhow!(format!("Failed to start http server for sccache server: {}", e)))?;

        // This limit is rouille's default for `start_server_with_pool`, which
        // we would use, except that interface doesn't permit any sort of
        // error handling to be done.
        let server = server.pool_size(num_cpus::get() * 8);
        server.run();

        // rouille's `run()` should never return.
        panic!("Rouille server terminated")
    }
}

/// Server-side outgoing requester: reports job state back to the scheduler.
struct ServerRequester {
    client: reqwest::Client,
    scheduler_url: reqwest::Url,
    scheduler_auth: String,
}

impl dist::ServerOutgoing for ServerRequester {
    /// POST a job-state transition to the scheduler's `job_state` endpoint.
    fn do_update_job_state(
        &self,
        job_id: JobId,
        state: JobState,
    ) -> Result<UpdateJobStateResult> {
        let url = urls::scheduler_job_state(&self.scheduler_url, job_id);
        bincode_req(
            self.client
                .post(url)
                .bearer_auth(self.scheduler_auth.clone())
                .bincode(&state)?,
        )
        .context("POST to scheduler job_state failed")
    }
}
}

#[cfg(feature = "dist-client")]
mod client {
    use super::super::cache;
    use crate::config;
    use crate::dist::pkg::{InputsPackager, ToolchainPackager};
    use crate::dist::{
        self, AllocJobResult, CompileCommand, JobAlloc, PathTransformer, RunJobResult,
        SchedulerStatusResult, SubmitToolchainResult, Toolchain,
    };
    use crate::util::SpawnExt;
    use byteorder::{BigEndian, WriteBytesExt};
    use flate2::write::ZlibEncoder as ZlibWriteEncoder;
    use flate2::Compression;
    use futures::Future;
    use futures_03::executor::ThreadPool;
    use std::collections::HashMap;
    use std::io::Write;
    use std::path::{Path, PathBuf};
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    use super::common::{
        bincode_req, bincode_req_fut, AllocJobHttpResponse, ReqwestRequestBuilderExt,
        RunJobHttpRequest, ServerCertificateHttpResponse,
    };
    use super::urls;
    use crate::errors::*;

    // Timeouts applied to the HTTP clients constructed below.
    const REQUEST_TIMEOUT_SECS: u64 = 600;
    const CONNECT_TIMEOUT_SECS: u64 = 5;

    /// Client-side handle for talking to the scheduler and to build servers.
    pub struct Client {
        auth_token: String,
        scheduler_url: reqwest::Url,
        // cert_digest -> cert_pem
        server_certs: Arc<Mutex<HashMap<Vec<u8>, Vec<u8>>>>,
        // TODO: this should really only use the async client, but reqwest async bodies are extremely limited
        // and only support owned bytes, which means the whole toolchain would end up in memory
        client: Arc<Mutex<reqwest::Client>>,
        client_async: Arc<Mutex<reqwest::r#async::Client>>,
        pool: ThreadPool,
        tc_cache: Arc<cache::ClientToolchains>,
        rewrite_includes_only: bool,
    }

    impl Client {
        /// Build a client with sync and async reqwest clients configured with
        /// the module-level timeouts, plus a local toolchain cache rooted at
        /// `cache_dir` (capped at `cache_size` bytes).
        pub fn new(
            pool: &ThreadPool,
            scheduler_url: reqwest::Url,
            cache_dir: &Path,
            cache_size: u64,
            toolchain_configs: &[config::DistToolchainConfig],
            auth_token: String,
            rewrite_includes_only: bool,
        ) -> Result<Self> {
            let timeout = Duration::new(REQUEST_TIMEOUT_SECS, 0);
            let connect_timeout = Duration::new(CONNECT_TIMEOUT_SECS, 0);
            let client = reqwest::ClientBuilder::new()
                .timeout(timeout)
                .connect_timeout(connect_timeout)
                .build()
                .context("failed to create a HTTP client")?;
            let client_async = reqwest::r#async::ClientBuilder::new()
                .timeout(timeout)
                .connect_timeout(connect_timeout)
                .build()
                .context("failed to create an async HTTP client")?;
            let client_toolchains =
                cache::ClientToolchains::new(cache_dir, cache_size, toolchain_configs)
                    .context("failed to initialise client toolchains")?;
            Ok(Self {
                auth_token,
                scheduler_url,
                server_certs: Default::default(),
                client: Arc::new(Mutex::new(client)),
                client_async: Arc::new(Mutex::new(client_async)),
                pool: pool.clone(),
                tc_cache: Arc::new(client_toolchains),
                rewrite_includes_only,
            })
        }

        /// Rebuild both HTTP clients so they trust `cert_pem` in addition to
        /// every previously recorded server certificate, then record the new
        /// cert in `certs` under its digest. reqwest clients cannot gain root
        /// certificates after construction, hence the full rebuild.
        fn update_certs(
            client: &mut reqwest::Client,
            client_async: &mut reqwest::r#async::Client,
            certs: &mut HashMap<Vec<u8>, Vec<u8>>,
            cert_digest: Vec<u8>,
            cert_pem: Vec<u8>,
        ) -> Result<()> {
            let mut client_builder = reqwest::ClientBuilder::new();
            let mut client_async_builder = reqwest::r#async::ClientBuilder::new();
            // Add all the certificates we know about
            client_builder = client_builder.add_root_certificate(
                reqwest::Certificate::from_pem(&cert_pem)
                    .context("failed to interpret pem as certificate")?,
            );
            client_async_builder = client_async_builder.add_root_certificate(
                reqwest::Certificate::from_pem(&cert_pem)
                    .context("failed to interpret pem as certificate")?,
            );
            for cert_pem in certs.values() {
                // These were parsed successfully when first inserted, so a
                // parse failure here would be a bug.
                client_builder = client_builder.add_root_certificate(
                    reqwest::Certificate::from_pem(cert_pem).expect("previously valid cert"),
                );
                client_async_builder = client_async_builder.add_root_certificate(
                    reqwest::Certificate::from_pem(cert_pem).expect("previously valid cert"),
                );
            }
            // Finish the clients
            let timeout = Duration::new(REQUEST_TIMEOUT_SECS, 0);
            // NOTE(review): unlike `Client::new`, no connect_timeout is set on
            // the rebuilt clients — confirm whether that is intentional.
            let new_client = client_builder
                .timeout(timeout)
                .build()
                .context("failed to create a HTTP client")?;
            let new_client_async = client_async_builder
                .timeout(timeout)
                .build()
                .context("failed to create an async HTTP client")?;
            // Use the updated certificates
            *client = new_client;
            *client_async = new_client_async;
            certs.insert(cert_digest, cert_pem);
            Ok(())
        }
    }

    impl dist::Client for Client {
        /// Ask the scheduler to allocate a job for `tc`. If the allocation
        /// names a server whose certificate digest we have not yet recorded,
        /// fetch that certificate from the scheduler and install it before
        /// resolving the allocation result.
        fn do_alloc_job(&self, tc: Toolchain) -> SFuture<AllocJobResult> {
            let scheduler_url = self.scheduler_url.clone();
            let url = urls::scheduler_alloc_job(&scheduler_url);
            let mut req = self.client_async.lock().unwrap().post(url);
            req = ftry!(req.bearer_auth(self.auth_token.clone()).bincode(&tc));

            let client = self.client.clone();
            let client_async = self.client_async.clone();
            let server_certs = self.server_certs.clone();
            Box::new(bincode_req_fut(req).and_then(move |res| match res {
                AllocJobHttpResponse::Success {
                    job_alloc,
                    need_toolchain,
                    cert_digest,
                } => {
                    let server_id = job_alloc.server_id;
                    let alloc_job_res = f_ok(AllocJobResult::Success {
                        job_alloc,
                        need_toolchain,
                    });
                    // Certificate already known — no extra round trip needed.
                    if server_certs.lock().unwrap().contains_key(&cert_digest) {
                        return alloc_job_res;
                    }
                    info!(
                        "Need to request new certificate for server {}",
                        server_id.addr()
                    );
                    let url = urls::scheduler_server_certificate(&scheduler_url, server_id);
                    let req = client_async.lock().unwrap().get(url);
                    Box::new(
                        bincode_req_fut(req)
                            .map_err(|e| e.context("GET to scheduler server_certificate failed"))
                            .and_then(move |res: ServerCertificateHttpResponse| {
                                ftry!(Self::update_certs(
                                    &mut client.lock().unwrap(),
                                    &mut client_async.lock().unwrap(),
                                    &mut server_certs.lock().unwrap(),
                                    res.cert_digest,
                                    res.cert_pem,
                                ));
                                alloc_job_res
                            }),
                    )
                }
                AllocJobHttpResponse::Fail { msg } => f_ok(AllocJobResult::Fail { msg }),
            }))
        }

        /// GET the scheduler's status summary. Uses the blocking client, so
        /// the request is run on the thread pool.
        fn do_get_status(&self) -> SFuture<SchedulerStatusResult> {
            let scheduler_url = self.scheduler_url.clone();
            let url = urls::scheduler_status(&scheduler_url);
            let req = self.client.lock().unwrap().get(url);
            Box::new(self.pool.spawn_fn(move || bincode_req(req)))
        }

        /// Upload the locally cached toolchain archive for `tc` to the
        /// allocated server; fails if the toolchain is not in the local cache.
        fn do_submit_toolchain(
            &self,
            job_alloc: JobAlloc,
            tc: Toolchain,
        ) -> SFuture<SubmitToolchainResult> {
            match self.tc_cache.get_toolchain(&tc) {
                Ok(Some(toolchain_file)) => {
                    let url = urls::server_submit_toolchain(job_alloc.server_id, job_alloc.job_id);
                    let req = self.client.lock().unwrap().post(url);

                    // Blocking upload, run on the thread pool.
                    Box::new(self.pool.spawn_fn(move || {
                        let req = req.bearer_auth(job_alloc.auth.clone()).body(toolchain_file);
                        bincode_req(req)
                    }))
                }
                Ok(None) => f_err(anyhow!("couldn't find toolchain locally")),
                Err(e) => f_err(e),
            }
        }

        /// Run the job on the allocated server. The request body is framed as
        /// [u32 BE bincode length][bincode RunJobHttpRequest][zlib-compressed
        /// inputs], mirroring the framing the server's run_job route reads.
        fn do_run_job(
            &self,
            job_alloc: JobAlloc,
            command: CompileCommand,
            outputs: Vec<String>,
            inputs_packager: Box<dyn InputsPackager>,
        ) -> SFuture<(RunJobResult, PathTransformer)> {
            let url = urls::server_run_job(job_alloc.server_id, job_alloc.job_id);
            let mut req = self.client.lock().unwrap().post(url);

            Box::new(self.pool.spawn_fn(move || {
                let bincode = bincode::serialize(&RunJobHttpRequest { command, outputs })
                    .context("failed to serialize run job request")?;
                let bincode_length = bincode.len();

                // Length-prefixed bincode header, then the compressed inputs
                // are appended directly onto the same buffer.
                let mut body = vec![];
                body.write_u32::<BigEndian>(bincode_length as u32)
                    .expect("Infallible write of bincode length to vec failed");
                body.write_all(&bincode)
                    .expect("Infallible write of bincode body to vec failed");
                let path_transformer;
                {
                    let mut compressor = ZlibWriteEncoder::new(&mut body, Compression::fast());
                    path_transformer = inputs_packager
                        .write_inputs(&mut compressor)
                        .context("Could not write inputs for compilation")?;
                    compressor.flush().context("failed to flush compressor")?;
                    trace!(
                        "Compressed inputs from {} -> {}",
                        compressor.total_in(),
                        compressor.total_out()
                    );
                    compressor.finish().context("failed to finish compressor")?;
                }

                req = req.bearer_auth(job_alloc.auth.clone()).bytes(body);
                bincode_req(req).map(|res| (res, path_transformer))
            }))
        }

        /// Package and cache the toolchain for `compiler_path` on the thread
        /// pool, keyed by `weak_key`.
        fn put_toolchain(
            &self,
            compiler_path: &Path,
            weak_key: &str,
            toolchain_packager: Box<dyn ToolchainPackager>,
        ) -> SFuture<(Toolchain, Option<(String, PathBuf)>)> {
            // Own the borrowed arguments so they can move into the closure.
            let compiler_path = compiler_path.to_owned();
            let weak_key = weak_key.to_owned();
            let tc_cache = self.tc_cache.clone();
            Box::new(self.pool.spawn_fn(move || {
                tc_cache.put_toolchain(&compiler_path, &weak_key, toolchain_packager)
            }))
        }

        /// Whether compilations should rewrite include paths only (set at
        /// construction time).
        fn rewrite_includes_only(&self) -> bool {
            self.rewrite_includes_only
        }

        /// Look up a configured custom toolchain for `exe`, returning its path
        /// when present; any cache error is collapsed to `None`.
        fn get_custom_toolchain(&self, exe: &PathBuf) -> Option<PathBuf> {
            match self.tc_cache.get_custom_toolchain(exe) {
                Some(Ok((_, _, path))) => Some(path),
                _ => None,
            }
        }
    }
}