1 #ifndef STAN_MATH_REV_CORE_SCOPED_CHAINABLESTACK_HPP
2 #define STAN_MATH_REV_CORE_SCOPED_CHAINABLESTACK_HPP
3 
4 #include <stan/math/rev/core/chainablestack.hpp>
5 
6 #include <mutex>
7 #include <stdexcept>
8 #include <thread>
9 
10 namespace stan {
11 namespace math {
12 
13 /**
14  * The AD tape of reverse mode AD is by default stored globally within the
15  * process (or thread). With the ScopedChainableStack class one may execute a
16  * nullary functor with reference to an AD tape which is stored with the
17  * instance of ScopedChainableStack. Example:
18  *
19  * ScopedChainableStack scoped_stack;
20  *
21  * double cgrad_a = scoped_stack.execute([] {
22  *   var a = 2.0;
23  *   var b = 4.0;
24  *   var c = a*a + b;
25  *   c.grad();
26  *   return a.adj();
27  * });
28  *
29  * Doing so will not interfere with the process (or thread) AD tape.
30  */
31 class ScopedChainableStack {
32   ChainableStack::AutodiffStackStorage local_stack_;
33   std::mutex local_stack_mutex_;
34 
35   struct activate_scope {
36     ChainableStack::AutodiffStackStorage* parent_stack_;
37     ScopedChainableStack& scoped_stack_;
38 
activate_scopestan::math::ScopedChainableStack::activate_scope39     explicit activate_scope(ScopedChainableStack& scoped_stack)
40         : parent_stack_(ChainableStack::instance_),
41           scoped_stack_(scoped_stack) {
42       if (!scoped_stack_.local_stack_mutex_.try_lock()) {
43         throw std::logic_error{"Cannot recurse same instance scoped stacks."};
44       }
45       ChainableStack::instance_ = &scoped_stack.local_stack_;
46     }
47 
~activate_scopestan::math::ScopedChainableStack::activate_scope48     ~activate_scope() {
49       ChainableStack::instance_ = parent_stack_;
50       scoped_stack_.local_stack_mutex_.unlock();
51     }
52   };
53 
54  public:
55   ScopedChainableStack() = default;
56 
57   /**
58    * Execute in the current thread a function and write the AD
59    * tape to local_stack_ of this instance. The function may return
60    * any type.
61    *
62    * @tparam F functor to evaluate
63    * @param f instance of functor
64    * @param args arguments passed to functor
65    * @return Result of evaluated functor
66    */
67   template <typename F, typename... Args>
execute(F && f,Args &&...args)68   decltype(auto) execute(F&& f, Args&&... args) {
69     const activate_scope active_scope(*this);
70     return std::forward<F>(f)(std::forward<Args>(args)...);
71   }
72 
73   // Prevent undesirable operations
74   ScopedChainableStack(const ScopedChainableStack&) = delete;
75   ScopedChainableStack& operator=(const ScopedChainableStack&) = delete;
76 };
77 
78 }  // namespace math
79 }  // namespace stan
80 #endif
81