1 #ifndef STAN_MATH_REV_CORE_SCOPED_CHAINABLESTACK_HPP 2 #define STAN_MATH_REV_CORE_SCOPED_CHAINABLESTACK_HPP 3 4 #include <stan/math/rev/core/chainablestack.hpp> 5 6 #include <mutex> 7 #include <stdexcept> 8 #include <thread> 9 10 namespace stan { 11 namespace math { 12 13 /** 14 * The AD tape of reverse mode AD is by default stored globally within the 15 * process (or thread). With the ScopedChainableStack class one may execute a 16 * nullary functor with reference to an AD tape which is stored with the 17 * instance of ScopedChainableStack. Example: 18 * 19 * ScopedChainableStack scoped_stack; 20 * 21 * double cgrad_a = scoped_stack.execute([] { 22 * var a = 2.0; 23 * var b = 4.0; 24 * var c = a*a + b; 25 * c.grad(); 26 * return a.adj(); 27 * }); 28 * 29 * Doing so will not interfere with the process (or thread) AD tape. 30 */ 31 class ScopedChainableStack { 32 ChainableStack::AutodiffStackStorage local_stack_; 33 std::mutex local_stack_mutex_; 34 35 struct activate_scope { 36 ChainableStack::AutodiffStackStorage* parent_stack_; 37 ScopedChainableStack& scoped_stack_; 38 activate_scopestan::math::ScopedChainableStack::activate_scope39 explicit activate_scope(ScopedChainableStack& scoped_stack) 40 : parent_stack_(ChainableStack::instance_), 41 scoped_stack_(scoped_stack) { 42 if (!scoped_stack_.local_stack_mutex_.try_lock()) { 43 throw std::logic_error{"Cannot recurse same instance scoped stacks."}; 44 } 45 ChainableStack::instance_ = &scoped_stack.local_stack_; 46 } 47 ~activate_scopestan::math::ScopedChainableStack::activate_scope48 ~activate_scope() { 49 ChainableStack::instance_ = parent_stack_; 50 scoped_stack_.local_stack_mutex_.unlock(); 51 } 52 }; 53 54 public: 55 ScopedChainableStack() = default; 56 57 /** 58 * Execute in the current thread a function and write the AD 59 * tape to local_stack_ of this instance. The function may return 60 * any type. 61 * 62 * @tparam F functor to evaluate 63 * @param f instance of functor 64 * @param args arguments passed to functor 65 * @return Result of evaluated functor 66 */ 67 template <typename F, typename... Args> execute(F && f,Args &&...args)68 decltype(auto) execute(F&& f, Args&&... args) { 69 const activate_scope active_scope(*this); 70 return std::forward<F>(f)(std::forward<Args>(args)...); 71 } 72 73 // Prevent undesirable operations 74 ScopedChainableStack(const ScopedChainableStack&) = delete; 75 ScopedChainableStack& operator=(const ScopedChainableStack&) = delete; 76 }; 77 78 } // namespace math 79 } // namespace stan 80 #endif 81