1 use libc::{c_char, c_int, c_uint, c_void, size_t};
2 use std::ffi::{CStr, CString};
3 use std::mem;
4 use std::ptr;
5 use std::slice;
6 use std::str;
7 
8 use crate::cert::Cert;
9 use crate::util::Binding;
10 use crate::{
11     panic, raw, Cred, CredentialType, Error, IndexerProgress, Oid, PackBuilderStage, Progress,
12 };
13 
/// A structure to contain the callbacks which are invoked when a repository is
/// being updated or downloaded.
///
/// These callbacks are used to manage facilities such as authentication,
/// transfer progress, etc.
///
/// Every callback is optional and starts out unset; each one is registered
/// through the builder-style method of the same purpose on this type. The
/// lifetime `'a` bounds whatever the registered closures borrow.
pub struct RemoteCallbacks<'a> {
    // Set by `push_transfer_progress`.
    push_progress: Option<Box<PushTransferProgress<'a>>>,
    // Set by `transfer_progress`.
    progress: Option<Box<IndexerProgress<'a>>>,
    // Set by `pack_progress`.
    pack_progress: Option<Box<PackProgress<'a>>>,
    // Set by `credentials`.
    credentials: Option<Box<Credentials<'a>>>,
    // Set by `sideband_progress`.
    sideband_progress: Option<Box<TransportMessage<'a>>>,
    // Set by `update_tips`.
    update_tips: Option<Box<UpdateTips<'a>>>,
    // Set by `certificate_check`.
    certificate_check: Option<Box<CertificateCheck<'a>>>,
    // Set by `push_update_reference`.
    push_update_reference: Option<Box<PushUpdateReference<'a>>>,
}
29 
/// Callback used to acquire credentials for when a remote is fetched.
///
/// * `url` - the resource for which the credentials are required.
/// * `username_from_url` - the username that was embedded in the url, or `None`
///                         if it was not included.
/// * `allowed_types` - a bitmask stating which cred types are ok to return.
///
/// On success the returned `Cred` is only handed to libgit2 when its type is
/// one of `allowed_types`; otherwise the request is passed through. On error,
/// the error's message and raw code are reported back to libgit2.
pub type Credentials<'a> =
    dyn FnMut(&str, Option<&str>, CredentialType) -> Result<Cred, Error> + 'a;
38 
/// Callback for receiving messages delivered by the transport.
///
/// The argument is the message as raw bytes; no UTF-8 validation is performed
/// before the callback is invoked.
///
/// The return value indicates whether the network operation should continue:
/// returning `false` causes an error (`-1`) to be reported to libgit2.
pub type TransportMessage<'a> = dyn FnMut(&[u8]) -> bool + 'a;
43 
/// Callback for whenever a reference is updated locally.
///
/// The first argument is the name of the reference. The two `Oid` arguments
/// are presumably the reference's old and new targets, in that order
/// (NOTE(review): confirm against libgit2's `update_tips` documentation).
///
/// Returning `false` causes an error (`-1`) to be reported to libgit2.
pub type UpdateTips<'a> = dyn FnMut(&str, Oid, Oid) -> bool + 'a;
46 
/// Callback for a custom certificate check.
///
/// The first argument is the certificate received on the connection.
/// Certificates are typically either an SSH or X509 certificate.
///
/// The second argument is the hostname for the connection.
///
/// Return `true` to let the connection proceed; returning `false` causes an
/// error (`-1`) to be reported to libgit2.
pub type CertificateCheck<'a> = dyn FnMut(&Cert<'_>, &str) -> bool + 'a;
55 
/// Callback for each updated reference on push.
///
/// The first argument here is the `refname` of the reference, and the second is
/// the status message sent by a server. If the status is `Some` then the update
/// was rejected by the remote server with a reason why.
///
/// Returning `Err` reports that error's raw code back to libgit2; returning
/// `Ok(())` reports success.
pub type PushUpdateReference<'a> = dyn FnMut(&str, Option<&str>) -> Result<(), Error> + 'a;
62 
/// Callback for push transfer progress
///
/// Parameters, in order:
///
/// * `current` - progress so far (presumably the number of objects already
///   pushed; NOTE(review): confirm against libgit2).
/// * `total` - the total to be reached by `current`.
/// * `bytes` - a byte count (presumably the bytes transferred so far).
///
/// The callback has no return value, so it cannot cancel the push.
pub type PushTransferProgress<'a> = dyn FnMut(usize, usize, usize) + 'a;
70 
/// Callback for pack progress
///
/// Parameters, in order:
///
/// * `stage` - the pack building stage being reported.
/// * `current` - progress so far within that stage.
/// * `total` - the total to be reached by `current`.
///
/// The callback has no return value, so it cannot cancel pack building.
pub type PackProgress<'a> = dyn FnMut(PackBuilderStage, usize, usize) + 'a;
78 
79 impl<'a> Default for RemoteCallbacks<'a> {
default() -> Self80     fn default() -> Self {
81         Self::new()
82     }
83 }
84 
85 impl<'a> RemoteCallbacks<'a> {
86     /// Creates a new set of empty callbacks
new() -> RemoteCallbacks<'a>87     pub fn new() -> RemoteCallbacks<'a> {
88         RemoteCallbacks {
89             credentials: None,
90             progress: None,
91             pack_progress: None,
92             sideband_progress: None,
93             update_tips: None,
94             certificate_check: None,
95             push_update_reference: None,
96             push_progress: None,
97         }
98     }
99 
100     /// The callback through which to fetch credentials if required.
101     ///
102     /// # Example
103     ///
104     /// Prepare a callback to authenticate using the `$HOME/.ssh/id_rsa` SSH key, and
105     /// extracting the username from the URL (i.e. git@github.com:rust-lang/git2-rs.git):
106     ///
107     /// ```no_run
108     /// use git2::{Cred, RemoteCallbacks};
109     /// use std::env;
110     ///
111     /// let mut callbacks = RemoteCallbacks::new();
112     /// callbacks.credentials(|_url, username_from_url, _allowed_types| {
113     ///   Cred::ssh_key(
114     ///     username_from_url.unwrap(),
115     ///     None,
116     ///     std::path::Path::new(&format!("{}/.ssh/id_rsa", env::var("HOME").unwrap())),
117     ///     None,
118     ///   )
119     /// });
120     /// ```
credentials<F>(&mut self, cb: F) -> &mut RemoteCallbacks<'a> where F: FnMut(&str, Option<&str>, CredentialType) -> Result<Cred, Error> + 'a,121     pub fn credentials<F>(&mut self, cb: F) -> &mut RemoteCallbacks<'a>
122     where
123         F: FnMut(&str, Option<&str>, CredentialType) -> Result<Cred, Error> + 'a,
124     {
125         self.credentials = Some(Box::new(cb) as Box<Credentials<'a>>);
126         self
127     }
128 
129     /// The callback through which progress is monitored.
transfer_progress<F>(&mut self, cb: F) -> &mut RemoteCallbacks<'a> where F: FnMut(Progress<'_>) -> bool + 'a,130     pub fn transfer_progress<F>(&mut self, cb: F) -> &mut RemoteCallbacks<'a>
131     where
132         F: FnMut(Progress<'_>) -> bool + 'a,
133     {
134         self.progress = Some(Box::new(cb) as Box<IndexerProgress<'a>>);
135         self
136     }
137 
138     /// Textual progress from the remote.
139     ///
140     /// Text sent over the progress side-band will be passed to this function
141     /// (this is the 'counting objects' output).
sideband_progress<F>(&mut self, cb: F) -> &mut RemoteCallbacks<'a> where F: FnMut(&[u8]) -> bool + 'a,142     pub fn sideband_progress<F>(&mut self, cb: F) -> &mut RemoteCallbacks<'a>
143     where
144         F: FnMut(&[u8]) -> bool + 'a,
145     {
146         self.sideband_progress = Some(Box::new(cb) as Box<TransportMessage<'a>>);
147         self
148     }
149 
150     /// Each time a reference is updated locally, the callback will be called
151     /// with information about it.
update_tips<F>(&mut self, cb: F) -> &mut RemoteCallbacks<'a> where F: FnMut(&str, Oid, Oid) -> bool + 'a,152     pub fn update_tips<F>(&mut self, cb: F) -> &mut RemoteCallbacks<'a>
153     where
154         F: FnMut(&str, Oid, Oid) -> bool + 'a,
155     {
156         self.update_tips = Some(Box::new(cb) as Box<UpdateTips<'a>>);
157         self
158     }
159 
160     /// If certificate verification fails, then this callback will be invoked to
161     /// let the caller make the final decision of whether to allow the
162     /// connection to proceed.
certificate_check<F>(&mut self, cb: F) -> &mut RemoteCallbacks<'a> where F: FnMut(&Cert<'_>, &str) -> bool + 'a,163     pub fn certificate_check<F>(&mut self, cb: F) -> &mut RemoteCallbacks<'a>
164     where
165         F: FnMut(&Cert<'_>, &str) -> bool + 'a,
166     {
167         self.certificate_check = Some(Box::new(cb) as Box<CertificateCheck<'a>>);
168         self
169     }
170 
171     /// Set a callback to get invoked for each updated reference on a push.
172     ///
173     /// The first argument to the callback is the name of the reference and the
174     /// second is a status message sent by the server. If the status is `Some`
175     /// then the push was rejected.
push_update_reference<F>(&mut self, cb: F) -> &mut RemoteCallbacks<'a> where F: FnMut(&str, Option<&str>) -> Result<(), Error> + 'a,176     pub fn push_update_reference<F>(&mut self, cb: F) -> &mut RemoteCallbacks<'a>
177     where
178         F: FnMut(&str, Option<&str>) -> Result<(), Error> + 'a,
179     {
180         self.push_update_reference = Some(Box::new(cb) as Box<PushUpdateReference<'a>>);
181         self
182     }
183 
184     /// The callback through which progress of push transfer is monitored
push_transfer_progress<F>(&mut self, cb: F) -> &mut RemoteCallbacks<'a> where F: FnMut(usize, usize, usize) + 'a,185     pub fn push_transfer_progress<F>(&mut self, cb: F) -> &mut RemoteCallbacks<'a>
186     where
187         F: FnMut(usize, usize, usize) + 'a,
188     {
189         self.push_progress = Some(Box::new(cb) as Box<PushTransferProgress<'a>>);
190         self
191     }
192 
193     /// Function to call with progress information during pack building.
194     /// Be aware that this is called inline with pack building operations,
195     /// so performance may be affected.
pack_progress<F>(&mut self, cb: F) -> &mut RemoteCallbacks<'a> where F: FnMut(PackBuilderStage, usize, usize) + 'a,196     pub fn pack_progress<F>(&mut self, cb: F) -> &mut RemoteCallbacks<'a>
197     where
198         F: FnMut(PackBuilderStage, usize, usize) + 'a,
199     {
200         self.pack_progress = Some(Box::new(cb) as Box<PackProgress<'a>>);
201         self
202     }
203 }
204 
205 impl<'a> Binding for RemoteCallbacks<'a> {
206     type Raw = raw::git_remote_callbacks;
from_raw(_raw: raw::git_remote_callbacks) -> RemoteCallbacks<'a>207     unsafe fn from_raw(_raw: raw::git_remote_callbacks) -> RemoteCallbacks<'a> {
208         panic!("unimplemented");
209     }
210 
raw(&self) -> raw::git_remote_callbacks211     fn raw(&self) -> raw::git_remote_callbacks {
212         unsafe {
213             let mut callbacks: raw::git_remote_callbacks = mem::zeroed();
214             assert_eq!(
215                 raw::git_remote_init_callbacks(&mut callbacks, raw::GIT_REMOTE_CALLBACKS_VERSION),
216                 0
217             );
218             if self.progress.is_some() {
219                 callbacks.transfer_progress = Some(transfer_progress_cb);
220             }
221             if self.credentials.is_some() {
222                 callbacks.credentials = Some(credentials_cb);
223             }
224             if self.sideband_progress.is_some() {
225                 callbacks.sideband_progress = Some(sideband_progress_cb);
226             }
227             if self.certificate_check.is_some() {
228                 callbacks.certificate_check = Some(certificate_check_cb);
229             }
230             if self.push_update_reference.is_some() {
231                 callbacks.push_update_reference = Some(push_update_reference_cb);
232             }
233             if self.push_progress.is_some() {
234                 callbacks.push_transfer_progress = Some(push_transfer_progress_cb);
235             }
236             if self.pack_progress.is_some() {
237                 callbacks.pack_progress = Some(pack_progress_cb);
238             }
239             if self.update_tips.is_some() {
240                 let f: extern "C" fn(
241                     *const c_char,
242                     *const raw::git_oid,
243                     *const raw::git_oid,
244                     *mut c_void,
245                 ) -> c_int = update_tips_cb;
246                 callbacks.update_tips = Some(f);
247             }
248             callbacks.payload = self as *const _ as *mut _;
249             callbacks
250         }
251     }
252 }
253 
credentials_cb( ret: *mut *mut raw::git_cred, url: *const c_char, username_from_url: *const c_char, allowed_types: c_uint, payload: *mut c_void, ) -> c_int254 extern "C" fn credentials_cb(
255     ret: *mut *mut raw::git_cred,
256     url: *const c_char,
257     username_from_url: *const c_char,
258     allowed_types: c_uint,
259     payload: *mut c_void,
260 ) -> c_int {
261     unsafe {
262         let ok = panic::wrap(|| {
263             let payload = &mut *(payload as *mut RemoteCallbacks<'_>);
264             let callback = payload
265                 .credentials
266                 .as_mut()
267                 .ok_or(raw::GIT_PASSTHROUGH as c_int)?;
268             *ret = ptr::null_mut();
269             let url = str::from_utf8(CStr::from_ptr(url).to_bytes())
270                 .map_err(|_| raw::GIT_PASSTHROUGH as c_int)?;
271             let username_from_url = match crate::opt_bytes(&url, username_from_url) {
272                 Some(username) => {
273                     Some(str::from_utf8(username).map_err(|_| raw::GIT_PASSTHROUGH as c_int)?)
274                 }
275                 None => None,
276             };
277 
278             let cred_type = CredentialType::from_bits_truncate(allowed_types as u32);
279 
280             callback(url, username_from_url, cred_type).map_err(|e| {
281                 let s = CString::new(e.to_string()).unwrap();
282                 raw::git_error_set_str(e.raw_code() as c_int, s.as_ptr());
283                 e.raw_code() as c_int
284             })
285         });
286         match ok {
287             Some(Ok(cred)) => {
288                 // Turns out it's a memory safety issue if we pass through any
289                 // and all credentials into libgit2
290                 if allowed_types & (cred.credtype() as c_uint) != 0 {
291                     *ret = cred.unwrap();
292                     0
293                 } else {
294                     raw::GIT_PASSTHROUGH as c_int
295                 }
296             }
297             Some(Err(e)) => e,
298             None => -1,
299         }
300     }
301 }
302 
transfer_progress_cb( stats: *const raw::git_indexer_progress, payload: *mut c_void, ) -> c_int303 extern "C" fn transfer_progress_cb(
304     stats: *const raw::git_indexer_progress,
305     payload: *mut c_void,
306 ) -> c_int {
307     let ok = panic::wrap(|| unsafe {
308         let payload = &mut *(payload as *mut RemoteCallbacks<'_>);
309         let callback = match payload.progress {
310             Some(ref mut c) => c,
311             None => return true,
312         };
313         let progress = Binding::from_raw(stats);
314         callback(progress)
315     });
316     if ok == Some(true) {
317         0
318     } else {
319         -1
320     }
321 }
322 
sideband_progress_cb(str: *const c_char, len: c_int, payload: *mut c_void) -> c_int323 extern "C" fn sideband_progress_cb(str: *const c_char, len: c_int, payload: *mut c_void) -> c_int {
324     let ok = panic::wrap(|| unsafe {
325         let payload = &mut *(payload as *mut RemoteCallbacks<'_>);
326         let callback = match payload.sideband_progress {
327             Some(ref mut c) => c,
328             None => return true,
329         };
330         let buf = slice::from_raw_parts(str as *const u8, len as usize);
331         callback(buf)
332     });
333     if ok == Some(true) {
334         0
335     } else {
336         -1
337     }
338 }
339 
update_tips_cb( refname: *const c_char, a: *const raw::git_oid, b: *const raw::git_oid, data: *mut c_void, ) -> c_int340 extern "C" fn update_tips_cb(
341     refname: *const c_char,
342     a: *const raw::git_oid,
343     b: *const raw::git_oid,
344     data: *mut c_void,
345 ) -> c_int {
346     let ok = panic::wrap(|| unsafe {
347         let payload = &mut *(data as *mut RemoteCallbacks<'_>);
348         let callback = match payload.update_tips {
349             Some(ref mut c) => c,
350             None => return true,
351         };
352         let refname = str::from_utf8(CStr::from_ptr(refname).to_bytes()).unwrap();
353         let a = Binding::from_raw(a);
354         let b = Binding::from_raw(b);
355         callback(refname, a, b)
356     });
357     if ok == Some(true) {
358         0
359     } else {
360         -1
361     }
362 }
363 
certificate_check_cb( cert: *mut raw::git_cert, _valid: c_int, hostname: *const c_char, data: *mut c_void, ) -> c_int364 extern "C" fn certificate_check_cb(
365     cert: *mut raw::git_cert,
366     _valid: c_int,
367     hostname: *const c_char,
368     data: *mut c_void,
369 ) -> c_int {
370     let ok = panic::wrap(|| unsafe {
371         let payload = &mut *(data as *mut RemoteCallbacks<'_>);
372         let callback = match payload.certificate_check {
373             Some(ref mut c) => c,
374             None => return true,
375         };
376         let cert = Binding::from_raw(cert);
377         let hostname = str::from_utf8(CStr::from_ptr(hostname).to_bytes()).unwrap();
378         callback(&cert, hostname)
379     });
380     if ok == Some(true) {
381         0
382     } else {
383         -1
384     }
385 }
386 
push_update_reference_cb( refname: *const c_char, status: *const c_char, data: *mut c_void, ) -> c_int387 extern "C" fn push_update_reference_cb(
388     refname: *const c_char,
389     status: *const c_char,
390     data: *mut c_void,
391 ) -> c_int {
392     panic::wrap(|| unsafe {
393         let payload = &mut *(data as *mut RemoteCallbacks<'_>);
394         let callback = match payload.push_update_reference {
395             Some(ref mut c) => c,
396             None => return 0,
397         };
398         let refname = str::from_utf8(CStr::from_ptr(refname).to_bytes()).unwrap();
399         let status = if status.is_null() {
400             None
401         } else {
402             Some(str::from_utf8(CStr::from_ptr(status).to_bytes()).unwrap())
403         };
404         match callback(refname, status) {
405             Ok(()) => 0,
406             Err(e) => e.raw_code(),
407         }
408     })
409     .unwrap_or(-1)
410 }
411 
push_transfer_progress_cb( progress: c_uint, total: c_uint, bytes: size_t, data: *mut c_void, ) -> c_int412 extern "C" fn push_transfer_progress_cb(
413     progress: c_uint,
414     total: c_uint,
415     bytes: size_t,
416     data: *mut c_void,
417 ) -> c_int {
418     panic::wrap(|| unsafe {
419         let payload = &mut *(data as *mut RemoteCallbacks<'_>);
420         let callback = match payload.push_progress {
421             Some(ref mut c) => c,
422             None => return 0,
423         };
424 
425         callback(progress as usize, total as usize, bytes as usize);
426 
427         0
428     })
429     .unwrap_or(-1)
430 }
431 
pack_progress_cb( stage: raw::git_packbuilder_stage_t, current: c_uint, total: c_uint, data: *mut c_void, ) -> c_int432 extern "C" fn pack_progress_cb(
433     stage: raw::git_packbuilder_stage_t,
434     current: c_uint,
435     total: c_uint,
436     data: *mut c_void,
437 ) -> c_int {
438     panic::wrap(|| unsafe {
439         let payload = &mut *(data as *mut RemoteCallbacks<'_>);
440         let callback = match payload.pack_progress {
441             Some(ref mut c) => c,
442             None => return 0,
443         };
444 
445         let stage = Binding::from_raw(stage);
446 
447         callback(stage, current as usize, total as usize);
448 
449         0
450     })
451     .unwrap_or(-1)
452 }
453