1 #include <atomic>
2 #include <functional>
3 #include <memory>
4 #include <vector>
5 
6 #include "callback_registry.h"
7 #include "debug.h"
8 
// Source of callback IDs. The callback constructors in this file read the
// current value and post-increment it, so the first ID handed out is 1.
std::atomic<uint64_t> nextCallbackId{1};
10 
11 // ============================================================================
12 // Invoke functions
13 // ============================================================================
14 
// Outcome of a callback invocation, as recorded by invoke_c() and possibly
// overridden afterwards by Callback::invoke_wrapped().
enum InvokeResult {
  INVOKE_IN_PROGRESS = 0,  // invoke_c() started but has not recorded an outcome
  INVOKE_INTERRUPTED = 1,  // a user interrupt was detected
  INVOKE_ERROR       = 2,  // an R error or Rcpp exception occurred
  INVOKE_CPP_ERROR   = 3,  // a non-Rcpp C++ exception or other object was thrown
  INVOKE_COMPLETED   = 4   // invoke() returned without throwing
};
22 
23 // This is set by invoke_c(). I
24 InvokeResult last_invoke_result;
25 std::string last_invoke_message;
26 
27 // A wrapper for calling R_CheckUserInterrupt via R_ToplevelExec.
checkInterruptFn(void *)28 void checkInterruptFn(void*) {
29   R_CheckUserInterrupt();
30 }
31 
32 // The purpose of this function is to provide a plain C function to be called
33 // by R_ToplevelExec. Because it's called as a C function, it must not throw
34 // exceptions. Because this function returns void, the way for it to report
35 // the result to its caller is by setting last_invoke_result.
36 //
37 // This code needs to be able to handle interrupts, R errors, and C++
38 // exceptions. There are many ways these things can happen.
39 //
40 // * If the Callback object is a RcppFunctionCallback, then in the case of an
41 //   interrupt or an R error, it will throw a C++ exception. These exceptions
42 //   are the ones defined by Rcpp, and they will be caught by the try-catch in
43 //   this function.
44 // * It could be a StdFunctionCallback with C or C++ code.
45 //   * If the function invokes an Rcpp::Function and an interrupt or R error
46 //     happens within the Rcpp::Function, it will throw exceptions just like
47 //     the RcppFunctionCallback case, and they will be caught.
48 //   * If some other C++ exception occurs, it will be caught.
//   * If an interrupt (Ctrl-C, or Esc in RStudio) is received (outside of an
//     Rcpp::Function), this function will continue through to the end (and
//     set the state to INVOKE_COMPLETED). Later, when the invoke_wrapped()
//     function (which called this one) checks to see if the interrupt
//     happened, it will set the state to INVOKE_INTERRUPTED. (Note that it is
//     potentially possible for an interrupt and an exception to occur, in
//     which case we set the state to INVOKE_ERROR.)
//   * If the function calls R code with Rf_eval(), an interrupt or R error
//     could occur. If it's an interrupt, then it will be detected as in the
//     previous case. If an error occurs, then that error will be detected by
//     the invoke_wrapped() function (which called this one) and the state
//     will be set to INVOKE_ERROR.
61 //
// Note that the last case has one potentially problematic issue. If an error
// occurs in R code, then it will longjmp out of this function, back to its
// caller, invoke_wrapped(). This will longjmp out of a try statement, which
// is generally not a good idea. We don't know ahead of time whether the
// Callback may longjmp or throw an exception -- some Callbacks could
// potentially do both.
68 //
69 // The alternative is to move the try-catch out of this function and into
70 // invoke_wrapped(), surrounding the `R_ToplevelExec(invoke_c, ...)`. However,
71 // if we do this, then exceptions would pass through the R_ToplevelExec, which
72 // is dangerous because it is plain C code. The current way of doing it is
73 // imperfect, but less dangerous.
74 //
75 // There does not seem to be a 100% safe way to call functions which could
76 // either longjmp or throw exceptions. If we do figure out a way to do that,
77 // it should be used here.
invoke_c(void * callback_p)78 extern "C" void invoke_c(void* callback_p) {
79   ASSERT_MAIN_THREAD()
80   last_invoke_result = INVOKE_IN_PROGRESS;
81   last_invoke_message = "";
82 
83   Callback* cb_p = (Callback*)callback_p;
84 
85   try {
86     cb_p->invoke();
87   }
88   catch(Rcpp::internal::InterruptedException &e) {
89     // Reaches here if the callback is in Rcpp code and an interrupt occurs.
90     DEBUG_LOG("invoke_c: caught Rcpp::internal::InterruptedException", LOG_INFO);
91     last_invoke_result = INVOKE_INTERRUPTED;
92     return;
93   }
94   catch(Rcpp::eval_error &e) {
95     // Reaches here if an R-level error happens in an Rcpp::Function.
96     DEBUG_LOG("invoke_c: caught Rcpp::eval_error", LOG_INFO);
97     last_invoke_result = INVOKE_ERROR;
98     last_invoke_message = e.what();
99     return;
100   }
101   catch(Rcpp::exception& e) {
102     // Reaches here if an R-level error happens in an Rcpp::Function.
103     DEBUG_LOG("invoke_c: caught Rcpp::exception", LOG_INFO);
104     last_invoke_result = INVOKE_ERROR;
105     last_invoke_message = e.what();
106     return;
107   }
108   catch(std::exception& e) {
109     // Reaches here if some other (non-Rcpp) C++ exception is thrown.
110     DEBUG_LOG(std::string("invoke_c: caught std::exception: ") + typeid(e).name(),
111               LOG_INFO);
112     last_invoke_result = INVOKE_CPP_ERROR;
113     last_invoke_message = e.what();
114     return;
115   }
116   catch( ... ) {
117     // Reaches here if a non-exception C++ object is thrown.
118     DEBUG_LOG(std::string("invoke_c: caught unknown object: ") + typeid(std::current_exception()).name(),
119               LOG_INFO);
120     last_invoke_result = INVOKE_CPP_ERROR;
121     return;
122   }
123 
124   // Reaches here if no exceptions are thrown. It's possible to get here if an
125   // interrupt was received outside of Rcpp code, or if an R error happened
126   // using Rf_eval().
127   DEBUG_LOG("invoke_c: COMPLETED", LOG_DEBUG);
128   last_invoke_result = INVOKE_COMPLETED;
129 }
130 
131 // Wrapper method for invoking a callback. The Callback object has an invoke()
132 // method, but instead of invoking it directly, this method should be used
133 // instead. The purpose of this method is to call invoke(), but wrap it in a
134 // R_ToplevelExec, so that any LONGJMPs (due to errors in R functions) won't
135 // cross that barrier in the call stack. If interrupts, exceptions, or
136 // LONGJMPs do occur, this function throws a C++ exception.
invoke_wrapped() const137 void Callback::invoke_wrapped() const {
138   ASSERT_MAIN_THREAD()
139   Rboolean result = R_ToplevelExec(invoke_c, (void*)this);
140 
141   if (!result) {
142     DEBUG_LOG("invoke_wrapped: R_ToplevelExec return is FALSE; error or interrupt occurred in R code", LOG_INFO);
143     last_invoke_result = INVOKE_ERROR;
144   }
145 
146   if (R_ToplevelExec(checkInterruptFn, NULL) == FALSE) {
147     // Reaches here if the callback is C/C++ code and an interrupt occurs.
148     DEBUG_LOG("invoke_wrapped: interrupt (outside of R code) detected by R_CheckUserInterrupt", LOG_INFO);
149     last_invoke_result = INVOKE_INTERRUPTED;
150   }
151 
152   switch (last_invoke_result) {
153   case INVOKE_INTERRUPTED:
154     DEBUG_LOG("invoke_wrapped: throwing Rcpp::internal::InterruptedException", LOG_INFO);
155     throw Rcpp::internal::InterruptedException();
156   case INVOKE_ERROR:
157     DEBUG_LOG("invoke_wrapped: throwing Rcpp::exception", LOG_INFO);
158     throw Rcpp::exception(last_invoke_message.c_str());
159   case INVOKE_CPP_ERROR:
160     throw std::runtime_error("invoke_wrapped: throwing std::runtime_error");
161   default:
162     return;
163   }
164 }
165 
166 
167 // ============================================================================
168 // StdFunctionCallback
169 // ============================================================================
170 
StdFunctionCallback(Timestamp when,std::function<void (void)> func)171 StdFunctionCallback::StdFunctionCallback(Timestamp when, std::function<void(void)> func) :
172   Callback(when),
173   func(func)
174 {
175   this->callbackId = nextCallbackId++;
176 }
177 
rRepresentation() const178 Rcpp::RObject StdFunctionCallback::rRepresentation() const {
179   using namespace Rcpp;
180   ASSERT_MAIN_THREAD()
181 
182   return List::create(
183     _["id"]       = callbackId,
184     _["when"]     = when.diff_secs(Timestamp()),
185     _["callback"] = Rcpp::CharacterVector::create("C/C++ function")
186   );
187 }
188 
189 
190 // ============================================================================
191 // RcppFunctionCallback
192 // ============================================================================
193 
RcppFunctionCallback(Timestamp when,Rcpp::Function func)194 RcppFunctionCallback::RcppFunctionCallback(Timestamp when, Rcpp::Function func) :
195   Callback(when),
196   func(func)
197 {
198   ASSERT_MAIN_THREAD()
199   this->callbackId = nextCallbackId++;
200 }
201 
rRepresentation() const202 Rcpp::RObject RcppFunctionCallback::rRepresentation() const {
203   using namespace Rcpp;
204   ASSERT_MAIN_THREAD()
205 
206   return List::create(
207     _["id"]       = callbackId,
208     _["when"]     = when.diff_secs(Timestamp()),
209     _["callback"] = func
210   );
211 }
212 
213 
214 // ============================================================================
215 // CallbackRegistry
216 // ============================================================================
217 
218 // [[Rcpp::export]]
testCallbackOrdering()219 void testCallbackOrdering() {
220   std::vector<StdFunctionCallback> callbacks;
221   Timestamp ts;
222   std::function<void(void)> func;
223   for (size_t i = 0; i < 100; i++) {
224     callbacks.push_back(StdFunctionCallback(ts, func));
225   }
226   for (size_t i = 1; i < 100; i++) {
227     if (callbacks[i] < callbacks[i-1]) {
228       ::Rf_error("Callback ordering is broken [1]");
229     }
230     if (!(callbacks[i] > callbacks[i-1])) {
231       ::Rf_error("Callback ordering is broken [2]");
232     }
233     if (callbacks[i-1] > callbacks[i]) {
234       ::Rf_error("Callback ordering is broken [3]");
235     }
236     if (!(callbacks[i-1] < callbacks[i])) {
237       ::Rf_error("Callback ordering is broken [4]");
238     }
239   }
240   for (size_t i = 100; i > 1; i--) {
241     if (callbacks[i-1] < callbacks[i-2]) {
242       ::Rf_error("Callback ordering is broken [2]");
243     }
244   }
245 }
246 
CallbackRegistry(int id,Mutex * mutex,ConditionVariable * condvar)247 CallbackRegistry::CallbackRegistry(int id, Mutex* mutex, ConditionVariable* condvar)
248   : id(id), mutex(mutex), condvar(condvar)
249 {
250   ASSERT_MAIN_THREAD()
251 }
252 
~CallbackRegistry()253 CallbackRegistry::~CallbackRegistry() {
254   ASSERT_MAIN_THREAD()
255 }
256 
getId() const257 int CallbackRegistry::getId() const {
258   return id;
259 }
260 
add(Rcpp::Function func,double secs)261 uint64_t CallbackRegistry::add(Rcpp::Function func, double secs) {
262   // Copies of the Rcpp::Function should only be made on the main thread.
263   ASSERT_MAIN_THREAD()
264   Timestamp when(secs);
265   Callback_sp cb = std::make_shared<RcppFunctionCallback>(when, func);
266   Guard guard(mutex);
267   queue.insert(cb);
268   condvar->signal();
269 
270   return cb->getCallbackId();
271 }
272 
add(void (* func)(void *),void * data,double secs)273 uint64_t CallbackRegistry::add(void (*func)(void*), void* data, double secs) {
274   Timestamp when(secs);
275   Callback_sp cb = std::make_shared<StdFunctionCallback>(when, std::bind(func, data));
276   Guard guard(mutex);
277   queue.insert(cb);
278   condvar->signal();
279 
280   return cb->getCallbackId();
281 }
282 
cancel(uint64_t id)283 bool CallbackRegistry::cancel(uint64_t id) {
284   Guard guard(mutex);
285 
286   cbSet::const_iterator it;
287   for (it = queue.begin(); it != queue.end(); ++it) {
288     if ((*it)->getCallbackId() == id) {
289       queue.erase(it);
290       return true;
291     }
292   }
293 
294   return false;
295 }
296 
297 // The smallest timestamp present in the registry, if any.
298 // Use this to determine the next time we need to pump events.
nextTimestamp(bool recursive) const299 Optional<Timestamp> CallbackRegistry::nextTimestamp(bool recursive) const {
300   Guard guard(mutex);
301 
302   Optional<Timestamp> minTimestamp;
303 
304   if (! this->queue.empty()) {
305     cbSet::const_iterator it = queue.begin();
306     minTimestamp = Optional<Timestamp>((*it)->when);
307   }
308 
309   // Now check children
310   if (recursive) {
311     for (std::vector<std::shared_ptr<CallbackRegistry> >::const_iterator it = children.begin();
312          it != children.end();
313          ++it)
314     {
315       Optional<Timestamp> childNextTimestamp = (*it)->nextTimestamp(recursive);
316 
317       if (childNextTimestamp.has_value()) {
318         if (minTimestamp.has_value()) {
319           if (*childNextTimestamp < *minTimestamp) {
320             minTimestamp = childNextTimestamp;
321           }
322         } else {
323           minTimestamp = childNextTimestamp;
324         }
325       }
326     }
327   }
328 
329   return minTimestamp;
330 }
331 
empty() const332 bool CallbackRegistry::empty() const {
333   Guard guard(mutex);
334   return this->queue.empty();
335 }
336 
337 // Returns true if the smallest timestamp exists and is not in the future.
due(const Timestamp & time,bool recursive) const338 bool CallbackRegistry::due(const Timestamp& time, bool recursive) const {
339   ASSERT_MAIN_THREAD()
340   Guard guard(mutex);
341   cbSet::const_iterator it = queue.begin();
342   if (!this->queue.empty() && !((*it)->when > time)) {
343     return true;
344   }
345 
346   // Now check children
347   if (recursive) {
348     for (std::vector<std::shared_ptr<CallbackRegistry> >::const_iterator it = children.begin();
349          it != children.end();
350          ++it)
351     {
352       if ((*it)->due(time, true)) {
353         return true;
354       }
355     }
356   }
357 
358   return false;
359 }
360 
take(size_t max,const Timestamp & time)361 std::vector<Callback_sp> CallbackRegistry::take(size_t max, const Timestamp& time) {
362   ASSERT_MAIN_THREAD()
363   Guard guard(mutex);
364   std::vector<Callback_sp> results;
365   while (this->due(time, false) && (max <= 0 || results.size() < max)) {
366     cbSet::iterator it = queue.begin();
367     results.push_back(*it);
368     this->queue.erase(it);
369   }
370   return results;
371 }
372 
wait(double timeoutSecs,bool recursive) const373 bool CallbackRegistry::wait(double timeoutSecs, bool recursive) const {
374   ASSERT_MAIN_THREAD()
375   if (timeoutSecs < 0) {
376     // "1000 years ought to be enough for anybody" --Bill Gates
377     timeoutSecs = 3e10;
378   }
379 
380   Timestamp expireTime(timeoutSecs);
381 
382   Guard guard(mutex);
383   while (true) {
384     Timestamp end = expireTime;
385     Optional<Timestamp> next = nextTimestamp(recursive);
386     if (next.has_value() && *next < expireTime) {
387       end = *next;
388     }
389     double waitFor = end.diff_secs(Timestamp());
390     if (waitFor <= 0)
391       break;
392     // Don't wait for more than 2 seconds at a time, in order to keep us
393     // at least somewhat responsive to user interrupts
394     if (waitFor > 2) {
395       waitFor = 2;
396     }
397     condvar->timedwait(waitFor);
398     Rcpp::checkUserInterrupt();
399   }
400 
401   return due();
402 }
403 
404 
list() const405 Rcpp::List CallbackRegistry::list() const {
406   ASSERT_MAIN_THREAD()
407   Guard guard(mutex);
408 
409   Rcpp::List results;
410 
411   cbSet::const_iterator it;
412 
413   for (it = queue.begin(); it != queue.end(); it++) {
414     results.push_back((*it)->rRepresentation());
415   }
416 
417   return results;
418 }
419