1 #include <atomic>
2 #include <functional>
3 #include <memory>
4 #include <vector>
5
6 #include "callback_registry.h"
7 #include "debug.h"
8
// Process-wide source of callback ids. Each Callback constructor in this file
// takes its id with a post-increment, so the first id handed out is 1.
std::atomic<uint64_t> nextCallbackId{1};
10
11 // ============================================================================
12 // Invoke functions
13 // ============================================================================
14
// Outcome of the most recent invoke_c() call, as stored in
// last_invoke_result.
enum InvokeResult {
  INVOKE_IN_PROGRESS = 0,  // set on entry to invoke_c(), before the callback runs
  INVOKE_INTERRUPTED = 1,  // a user interrupt was detected
  INVOKE_ERROR       = 2,  // an R-level error was reported
  INVOKE_CPP_ERROR   = 3,  // a non-Rcpp C++ exception (or other object) was thrown
  INVOKE_COMPLETED   = 4   // the callback returned without throwing
};
22
23 // This is set by invoke_c(). I
24 InvokeResult last_invoke_result;
25 std::string last_invoke_message;
26
27 // A wrapper for calling R_CheckUserInterrupt via R_ToplevelExec.
checkInterruptFn(void *)28 void checkInterruptFn(void*) {
29 R_CheckUserInterrupt();
30 }
31
32 // The purpose of this function is to provide a plain C function to be called
33 // by R_ToplevelExec. Because it's called as a C function, it must not throw
34 // exceptions. Because this function returns void, the way for it to report
35 // the result to its caller is by setting last_invoke_result.
36 //
37 // This code needs to be able to handle interrupts, R errors, and C++
38 // exceptions. There are many ways these things can happen.
39 //
40 // * If the Callback object is a RcppFunctionCallback, then in the case of an
41 // interrupt or an R error, it will throw a C++ exception. These exceptions
42 // are the ones defined by Rcpp, and they will be caught by the try-catch in
43 // this function.
44 // * It could be a StdFunctionCallback with C or C++ code.
45 // * If the function invokes an Rcpp::Function and an interrupt or R error
46 // happens within the Rcpp::Function, it will throw exceptions just like
47 // the RcppFunctionCallback case, and they will be caught.
48 // * If some other C++ exception occurs, it will be caught.
// * If an interrupt (Ctrl-C, or Esc in RStudio) is received (outside of an
// Rcpp::Function), this function will continue through to the end (and
// set the state to INVOKE_COMPLETED). Later, when the invoke_wrapped()
// function (which called this one) checks to see if the interrupt
// happened, it will set the state to INVOKE_INTERRUPTED. (Note that it is
// potentially possible for an interrupt and an exception to occur, in
// which case we set the state to INVOKE_ERROR.)
// * If the function calls R code with Rf_eval(), an interrupt or R error
// could occur. If it's an interrupt, then it will be detected as in the
// previous case. If an error occurs, then that error will be detected by
// the invoke_wrapped() function (which called this one) and the state
// will be set to INVOKE_ERROR.
61 //
// Note that the last case has one potentially problematic issue. If an error
// occurs in R code, then it will longjmp out of this function, back to its
// caller, invoke_wrapped(). This will longjmp out of a try statement, which
// is generally not a good idea. We don't know ahead of time whether the
// Callback may longjmp or throw an exception -- some Callbacks could
// potentially do both.
68 //
69 // The alternative is to move the try-catch out of this function and into
70 // invoke_wrapped(), surrounding the `R_ToplevelExec(invoke_c, ...)`. However,
71 // if we do this, then exceptions would pass through the R_ToplevelExec, which
72 // is dangerous because it is plain C code. The current way of doing it is
73 // imperfect, but less dangerous.
74 //
75 // There does not seem to be a 100% safe way to call functions which could
76 // either longjmp or throw exceptions. If we do figure out a way to do that,
77 // it should be used here.
invoke_c(void * callback_p)78 extern "C" void invoke_c(void* callback_p) {
79 ASSERT_MAIN_THREAD()
80 last_invoke_result = INVOKE_IN_PROGRESS;
81 last_invoke_message = "";
82
83 Callback* cb_p = (Callback*)callback_p;
84
85 try {
86 cb_p->invoke();
87 }
88 catch(Rcpp::internal::InterruptedException &e) {
89 // Reaches here if the callback is in Rcpp code and an interrupt occurs.
90 DEBUG_LOG("invoke_c: caught Rcpp::internal::InterruptedException", LOG_INFO);
91 last_invoke_result = INVOKE_INTERRUPTED;
92 return;
93 }
94 catch(Rcpp::eval_error &e) {
95 // Reaches here if an R-level error happens in an Rcpp::Function.
96 DEBUG_LOG("invoke_c: caught Rcpp::eval_error", LOG_INFO);
97 last_invoke_result = INVOKE_ERROR;
98 last_invoke_message = e.what();
99 return;
100 }
101 catch(Rcpp::exception& e) {
102 // Reaches here if an R-level error happens in an Rcpp::Function.
103 DEBUG_LOG("invoke_c: caught Rcpp::exception", LOG_INFO);
104 last_invoke_result = INVOKE_ERROR;
105 last_invoke_message = e.what();
106 return;
107 }
108 catch(std::exception& e) {
109 // Reaches here if some other (non-Rcpp) C++ exception is thrown.
110 DEBUG_LOG(std::string("invoke_c: caught std::exception: ") + typeid(e).name(),
111 LOG_INFO);
112 last_invoke_result = INVOKE_CPP_ERROR;
113 last_invoke_message = e.what();
114 return;
115 }
116 catch( ... ) {
117 // Reaches here if a non-exception C++ object is thrown.
118 DEBUG_LOG(std::string("invoke_c: caught unknown object: ") + typeid(std::current_exception()).name(),
119 LOG_INFO);
120 last_invoke_result = INVOKE_CPP_ERROR;
121 return;
122 }
123
124 // Reaches here if no exceptions are thrown. It's possible to get here if an
125 // interrupt was received outside of Rcpp code, or if an R error happened
126 // using Rf_eval().
127 DEBUG_LOG("invoke_c: COMPLETED", LOG_DEBUG);
128 last_invoke_result = INVOKE_COMPLETED;
129 }
130
131 // Wrapper method for invoking a callback. The Callback object has an invoke()
132 // method, but instead of invoking it directly, this method should be used
133 // instead. The purpose of this method is to call invoke(), but wrap it in a
134 // R_ToplevelExec, so that any LONGJMPs (due to errors in R functions) won't
135 // cross that barrier in the call stack. If interrupts, exceptions, or
136 // LONGJMPs do occur, this function throws a C++ exception.
invoke_wrapped() const137 void Callback::invoke_wrapped() const {
138 ASSERT_MAIN_THREAD()
139 Rboolean result = R_ToplevelExec(invoke_c, (void*)this);
140
141 if (!result) {
142 DEBUG_LOG("invoke_wrapped: R_ToplevelExec return is FALSE; error or interrupt occurred in R code", LOG_INFO);
143 last_invoke_result = INVOKE_ERROR;
144 }
145
146 if (R_ToplevelExec(checkInterruptFn, NULL) == FALSE) {
147 // Reaches here if the callback is C/C++ code and an interrupt occurs.
148 DEBUG_LOG("invoke_wrapped: interrupt (outside of R code) detected by R_CheckUserInterrupt", LOG_INFO);
149 last_invoke_result = INVOKE_INTERRUPTED;
150 }
151
152 switch (last_invoke_result) {
153 case INVOKE_INTERRUPTED:
154 DEBUG_LOG("invoke_wrapped: throwing Rcpp::internal::InterruptedException", LOG_INFO);
155 throw Rcpp::internal::InterruptedException();
156 case INVOKE_ERROR:
157 DEBUG_LOG("invoke_wrapped: throwing Rcpp::exception", LOG_INFO);
158 throw Rcpp::exception(last_invoke_message.c_str());
159 case INVOKE_CPP_ERROR:
160 throw std::runtime_error("invoke_wrapped: throwing std::runtime_error");
161 default:
162 return;
163 }
164 }
165
166
167 // ============================================================================
168 // StdFunctionCallback
169 // ============================================================================
170
StdFunctionCallback(Timestamp when,std::function<void (void)> func)171 StdFunctionCallback::StdFunctionCallback(Timestamp when, std::function<void(void)> func) :
172 Callback(when),
173 func(func)
174 {
175 this->callbackId = nextCallbackId++;
176 }
177
rRepresentation() const178 Rcpp::RObject StdFunctionCallback::rRepresentation() const {
179 using namespace Rcpp;
180 ASSERT_MAIN_THREAD()
181
182 return List::create(
183 _["id"] = callbackId,
184 _["when"] = when.diff_secs(Timestamp()),
185 _["callback"] = Rcpp::CharacterVector::create("C/C++ function")
186 );
187 }
188
189
190 // ============================================================================
191 // RcppFunctionCallback
192 // ============================================================================
193
RcppFunctionCallback(Timestamp when,Rcpp::Function func)194 RcppFunctionCallback::RcppFunctionCallback(Timestamp when, Rcpp::Function func) :
195 Callback(when),
196 func(func)
197 {
198 ASSERT_MAIN_THREAD()
199 this->callbackId = nextCallbackId++;
200 }
201
rRepresentation() const202 Rcpp::RObject RcppFunctionCallback::rRepresentation() const {
203 using namespace Rcpp;
204 ASSERT_MAIN_THREAD()
205
206 return List::create(
207 _["id"] = callbackId,
208 _["when"] = when.diff_secs(Timestamp()),
209 _["callback"] = func
210 );
211 }
212
213
214 // ============================================================================
215 // CallbackRegistry
216 // ============================================================================
217
218 // [[Rcpp::export]]
testCallbackOrdering()219 void testCallbackOrdering() {
220 std::vector<StdFunctionCallback> callbacks;
221 Timestamp ts;
222 std::function<void(void)> func;
223 for (size_t i = 0; i < 100; i++) {
224 callbacks.push_back(StdFunctionCallback(ts, func));
225 }
226 for (size_t i = 1; i < 100; i++) {
227 if (callbacks[i] < callbacks[i-1]) {
228 ::Rf_error("Callback ordering is broken [1]");
229 }
230 if (!(callbacks[i] > callbacks[i-1])) {
231 ::Rf_error("Callback ordering is broken [2]");
232 }
233 if (callbacks[i-1] > callbacks[i]) {
234 ::Rf_error("Callback ordering is broken [3]");
235 }
236 if (!(callbacks[i-1] < callbacks[i])) {
237 ::Rf_error("Callback ordering is broken [4]");
238 }
239 }
240 for (size_t i = 100; i > 1; i--) {
241 if (callbacks[i-1] < callbacks[i-2]) {
242 ::Rf_error("Callback ordering is broken [2]");
243 }
244 }
245 }
246
CallbackRegistry(int id,Mutex * mutex,ConditionVariable * condvar)247 CallbackRegistry::CallbackRegistry(int id, Mutex* mutex, ConditionVariable* condvar)
248 : id(id), mutex(mutex), condvar(condvar)
249 {
250 ASSERT_MAIN_THREAD()
251 }
252
~CallbackRegistry()253 CallbackRegistry::~CallbackRegistry() {
254 ASSERT_MAIN_THREAD()
255 }
256
getId() const257 int CallbackRegistry::getId() const {
258 return id;
259 }
260
add(Rcpp::Function func,double secs)261 uint64_t CallbackRegistry::add(Rcpp::Function func, double secs) {
262 // Copies of the Rcpp::Function should only be made on the main thread.
263 ASSERT_MAIN_THREAD()
264 Timestamp when(secs);
265 Callback_sp cb = std::make_shared<RcppFunctionCallback>(when, func);
266 Guard guard(mutex);
267 queue.insert(cb);
268 condvar->signal();
269
270 return cb->getCallbackId();
271 }
272
add(void (* func)(void *),void * data,double secs)273 uint64_t CallbackRegistry::add(void (*func)(void*), void* data, double secs) {
274 Timestamp when(secs);
275 Callback_sp cb = std::make_shared<StdFunctionCallback>(when, std::bind(func, data));
276 Guard guard(mutex);
277 queue.insert(cb);
278 condvar->signal();
279
280 return cb->getCallbackId();
281 }
282
cancel(uint64_t id)283 bool CallbackRegistry::cancel(uint64_t id) {
284 Guard guard(mutex);
285
286 cbSet::const_iterator it;
287 for (it = queue.begin(); it != queue.end(); ++it) {
288 if ((*it)->getCallbackId() == id) {
289 queue.erase(it);
290 return true;
291 }
292 }
293
294 return false;
295 }
296
297 // The smallest timestamp present in the registry, if any.
298 // Use this to determine the next time we need to pump events.
nextTimestamp(bool recursive) const299 Optional<Timestamp> CallbackRegistry::nextTimestamp(bool recursive) const {
300 Guard guard(mutex);
301
302 Optional<Timestamp> minTimestamp;
303
304 if (! this->queue.empty()) {
305 cbSet::const_iterator it = queue.begin();
306 minTimestamp = Optional<Timestamp>((*it)->when);
307 }
308
309 // Now check children
310 if (recursive) {
311 for (std::vector<std::shared_ptr<CallbackRegistry> >::const_iterator it = children.begin();
312 it != children.end();
313 ++it)
314 {
315 Optional<Timestamp> childNextTimestamp = (*it)->nextTimestamp(recursive);
316
317 if (childNextTimestamp.has_value()) {
318 if (minTimestamp.has_value()) {
319 if (*childNextTimestamp < *minTimestamp) {
320 minTimestamp = childNextTimestamp;
321 }
322 } else {
323 minTimestamp = childNextTimestamp;
324 }
325 }
326 }
327 }
328
329 return minTimestamp;
330 }
331
empty() const332 bool CallbackRegistry::empty() const {
333 Guard guard(mutex);
334 return this->queue.empty();
335 }
336
337 // Returns true if the smallest timestamp exists and is not in the future.
due(const Timestamp & time,bool recursive) const338 bool CallbackRegistry::due(const Timestamp& time, bool recursive) const {
339 ASSERT_MAIN_THREAD()
340 Guard guard(mutex);
341 cbSet::const_iterator it = queue.begin();
342 if (!this->queue.empty() && !((*it)->when > time)) {
343 return true;
344 }
345
346 // Now check children
347 if (recursive) {
348 for (std::vector<std::shared_ptr<CallbackRegistry> >::const_iterator it = children.begin();
349 it != children.end();
350 ++it)
351 {
352 if ((*it)->due(time, true)) {
353 return true;
354 }
355 }
356 }
357
358 return false;
359 }
360
take(size_t max,const Timestamp & time)361 std::vector<Callback_sp> CallbackRegistry::take(size_t max, const Timestamp& time) {
362 ASSERT_MAIN_THREAD()
363 Guard guard(mutex);
364 std::vector<Callback_sp> results;
365 while (this->due(time, false) && (max <= 0 || results.size() < max)) {
366 cbSet::iterator it = queue.begin();
367 results.push_back(*it);
368 this->queue.erase(it);
369 }
370 return results;
371 }
372
wait(double timeoutSecs,bool recursive) const373 bool CallbackRegistry::wait(double timeoutSecs, bool recursive) const {
374 ASSERT_MAIN_THREAD()
375 if (timeoutSecs < 0) {
376 // "1000 years ought to be enough for anybody" --Bill Gates
377 timeoutSecs = 3e10;
378 }
379
380 Timestamp expireTime(timeoutSecs);
381
382 Guard guard(mutex);
383 while (true) {
384 Timestamp end = expireTime;
385 Optional<Timestamp> next = nextTimestamp(recursive);
386 if (next.has_value() && *next < expireTime) {
387 end = *next;
388 }
389 double waitFor = end.diff_secs(Timestamp());
390 if (waitFor <= 0)
391 break;
392 // Don't wait for more than 2 seconds at a time, in order to keep us
393 // at least somewhat responsive to user interrupts
394 if (waitFor > 2) {
395 waitFor = 2;
396 }
397 condvar->timedwait(waitFor);
398 Rcpp::checkUserInterrupt();
399 }
400
401 return due();
402 }
403
404
list() const405 Rcpp::List CallbackRegistry::list() const {
406 ASSERT_MAIN_THREAD()
407 Guard guard(mutex);
408
409 Rcpp::List results;
410
411 cbSet::const_iterator it;
412
413 for (it = queue.begin(); it != queue.end(); it++) {
414 results.push_back((*it)->rRepresentation());
415 }
416
417 return results;
418 }
419