1 #ifndef STAN_MATH_PRIM_FUN_SCALAR_SEQ_VIEW_HPP 2 #define STAN_MATH_PRIM_FUN_SCALAR_SEQ_VIEW_HPP 3 4 #include <stan/math/prim/meta.hpp> 5 #include <stan/math/prim/fun/value_of.hpp> 6 #include <type_traits> 7 #include <utility> 8 9 namespace stan { 10 /** 11 * scalar_seq_view provides a uniform sequence-like wrapper around either a 12 * scalar or a sequence of scalars. 13 * 14 * @tparam C the container type; will be the scalar type if wrapping a scalar 15 * @tparam T the scalar type 16 */ 17 template <typename C, typename = void> 18 class scalar_seq_view; 19 20 template <typename C> 21 class scalar_seq_view<C, require_eigen_vector_t<C>> { 22 public: 23 template <typename T, 24 typename = require_same_t<plain_type_t<T>, plain_type_t<C>>> scalar_seq_view(T && c)25 explicit scalar_seq_view(T&& c) : c_(std::forward<T>(c)) {} 26 27 /** 28 * Segfaults if out of bounds. 29 * @param i index 30 * @return the element at the specified position in the container 31 */ operator [](size_t i) const32 inline auto operator[](size_t i) const { return c_.coeffRef(i); } 33 size() const34 inline auto size() const noexcept { return c_.size(); } 35 data() const36 inline const value_type_t<C>* data() const noexcept { return c_.data(); } data()37 inline value_type_t<C>* data() noexcept { return c_.data(); } 38 39 template <typename T = C, require_st_arithmetic<T>* = nullptr> val(size_t i) const40 inline decltype(auto) val(size_t i) const { 41 return c_.coeffRef(i); 42 } 43 44 template <typename T = C, require_st_autodiff<T>* = nullptr> val(size_t i) const45 inline decltype(auto) val(size_t i) const { 46 return c_.coeffRef(i).val(); 47 } 48 49 private: 50 ref_type_t<C> c_; 51 }; 52 53 template <typename C> 54 class scalar_seq_view<C, require_var_matrix_t<C>> { 55 public: 56 template <typename T, 57 typename = require_same_t<plain_type_t<T>, plain_type_t<C>>> scalar_seq_view(T && c)58 explicit scalar_seq_view(T&& c) : c_(std::forward<T>(c)) {} 59 60 /** 61 * Segfaults if out of bounds. 62 * @param i index 63 * @return the element at the specified position in the container 64 */ operator [](size_t i) const65 inline auto operator[](size_t i) const { return c_.coeff(i); } data() const66 inline const auto* data() const noexcept { return c_.vi_; } data()67 inline auto* data() noexcept { return c_.vi_; } 68 size() const69 inline auto size() const noexcept { return c_.size(); } 70 71 template <typename T = C, require_st_autodiff<T>* = nullptr> val(size_t i) const72 inline auto val(size_t i) const { 73 return c_.val().coeff(i); 74 } 75 76 template <typename T = C, require_st_autodiff<T>* = nullptr> val(size_t i)77 inline auto& val(size_t i) { 78 return c_.val().coeffRef(i); 79 } 80 81 private: 82 std::decay_t<C> c_; 83 }; 84 85 template <typename C> 86 class scalar_seq_view<C, require_std_vector_t<C>> { 87 public: 88 template <typename T, 89 typename = require_same_t<plain_type_t<T>, plain_type_t<C>>> scalar_seq_view(T && c)90 explicit scalar_seq_view(T&& c) : c_(std::forward<T>(c)) {} 91 92 /** 93 * Segfaults if out of bounds. 94 * @param i index 95 * @return the element at the specified position in the container 96 */ operator [](size_t i) const97 inline auto operator[](size_t i) const { return c_[i]; } size() const98 inline auto size() const noexcept { return c_.size(); } data() const99 inline const auto* data() const noexcept { return c_.data(); } 100 101 template <typename T = C, require_st_arithmetic<T>* = nullptr> val(size_t i) const102 inline decltype(auto) val(size_t i) const { 103 return c_[i]; 104 } 105 106 template <typename T = C, require_st_autodiff<T>* = nullptr> val(size_t i) const107 inline decltype(auto) val(size_t i) const { 108 return c_[i].val(); 109 } 110 111 private: 112 const C& c_; 113 }; 114 115 template <typename C> 116 class scalar_seq_view<C, require_t<std::is_pointer<C>>> { 117 public: 118 template <typename T, 119 typename = require_same_t<plain_type_t<T>, plain_type_t<C>>> scalar_seq_view(const T & c)120 explicit scalar_seq_view(const T& c) : c_(c) {} 121 122 /** 123 * Segfaults if out of bounds. 124 * @param i index 125 * @return the element at the specified position in the container 126 */ operator [](size_t i) const127 inline auto operator[](size_t i) const { return c_[i]; } size() const128 inline auto size() const noexcept { 129 static_assert(1, "Cannot Return Size of scalar_seq_view with pointer type"); 130 } data() const131 inline const auto* data() const noexcept { return &c_[0]; } 132 133 template <typename T = C, require_st_arithmetic<T>* = nullptr> val(size_t i) const134 inline decltype(auto) val(size_t i) const { 135 return c_[i]; 136 } 137 138 template <typename T = C, require_st_autodiff<T>* = nullptr> val(size_t i) const139 inline decltype(auto) val(size_t i) const { 140 return c_[i].val(); 141 } 142 143 private: 144 const C& c_; 145 }; 146 147 /** 148 * This specialization handles wrapping a scalar as if it were a sequence. 149 * 150 * @tparam T the scalar type 151 */ 152 template <typename C> 153 class scalar_seq_view<C, require_stan_scalar_t<C>> { 154 public: scalar_seq_view(const C & t)155 explicit scalar_seq_view(const C& t) noexcept : t_(t) {} 156 operator [](int) const157 inline decltype(auto) operator[](int /* i */) const noexcept { return t_; } 158 159 template <typename T = C, require_st_arithmetic<T>* = nullptr> val(int) const160 inline decltype(auto) val(int /* i */) const noexcept { 161 return t_; 162 } 163 164 template <typename T = C, require_st_autodiff<T>* = nullptr> val(int) const165 inline decltype(auto) val(int /* i */) const noexcept { 166 return t_.val(); 167 } 168 size()169 static constexpr auto size() { return 1; } data() const170 inline const auto* data() const noexcept { return &t_; } data()171 inline auto* data() noexcept { return &t_; } 172 173 private: 174 std::decay_t<C> t_; 175 }; 176 } // namespace stan 177 #endif 178