1 #ifndef STAN_MATH_PRIM_FUN_SCALAR_SEQ_VIEW_HPP
2 #define STAN_MATH_PRIM_FUN_SCALAR_SEQ_VIEW_HPP
3 
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/fun/value_of.hpp>
6 #include <type_traits>
7 #include <utility>
8 
9 namespace stan {
/**
 * scalar_seq_view gives uniform, sequence-style access to either a single
 * scalar or a sequence of scalars.
 *
 * Only the primary template is declared here; the behavior is supplied by
 * partial specializations selected through the second template parameter.
 *
 * @tparam C the container type; this is the scalar type itself when a
 *   scalar is being wrapped
 * @tparam Enable defaulted to void; used by the partial specializations to
 *   select themselves
 */
template <typename C, typename Enable = void>
class scalar_seq_view;
19 
20 template <typename C>
21 class scalar_seq_view<C, require_eigen_vector_t<C>> {
22  public:
23   template <typename T,
24             typename = require_same_t<plain_type_t<T>, plain_type_t<C>>>
scalar_seq_view(T && c)25   explicit scalar_seq_view(T&& c) : c_(std::forward<T>(c)) {}
26 
27   /**
28    * Segfaults if out of bounds.
29    * @param i index
30    * @return the element at the specified position in the container
31    */
operator [](size_t i) const32   inline auto operator[](size_t i) const { return c_.coeffRef(i); }
33 
size() const34   inline auto size() const noexcept { return c_.size(); }
35 
data() const36   inline const value_type_t<C>* data() const noexcept { return c_.data(); }
data()37   inline value_type_t<C>* data() noexcept { return c_.data(); }
38 
39   template <typename T = C, require_st_arithmetic<T>* = nullptr>
val(size_t i) const40   inline decltype(auto) val(size_t i) const {
41     return c_.coeffRef(i);
42   }
43 
44   template <typename T = C, require_st_autodiff<T>* = nullptr>
val(size_t i) const45   inline decltype(auto) val(size_t i) const {
46     return c_.coeffRef(i).val();
47   }
48 
49  private:
50   ref_type_t<C> c_;
51 };
52 
53 template <typename C>
54 class scalar_seq_view<C, require_var_matrix_t<C>> {
55  public:
56   template <typename T,
57             typename = require_same_t<plain_type_t<T>, plain_type_t<C>>>
scalar_seq_view(T && c)58   explicit scalar_seq_view(T&& c) : c_(std::forward<T>(c)) {}
59 
60   /**
61    * Segfaults if out of bounds.
62    * @param i index
63    * @return the element at the specified position in the container
64    */
operator [](size_t i) const65   inline auto operator[](size_t i) const { return c_.coeff(i); }
data() const66   inline const auto* data() const noexcept { return c_.vi_; }
data()67   inline auto* data() noexcept { return c_.vi_; }
68 
size() const69   inline auto size() const noexcept { return c_.size(); }
70 
71   template <typename T = C, require_st_autodiff<T>* = nullptr>
val(size_t i) const72   inline auto val(size_t i) const {
73     return c_.val().coeff(i);
74   }
75 
76   template <typename T = C, require_st_autodiff<T>* = nullptr>
val(size_t i)77   inline auto& val(size_t i) {
78     return c_.val().coeffRef(i);
79   }
80 
81  private:
82   std::decay_t<C> c_;
83 };
84 
85 template <typename C>
86 class scalar_seq_view<C, require_std_vector_t<C>> {
87  public:
88   template <typename T,
89             typename = require_same_t<plain_type_t<T>, plain_type_t<C>>>
scalar_seq_view(T && c)90   explicit scalar_seq_view(T&& c) : c_(std::forward<T>(c)) {}
91 
92   /**
93    * Segfaults if out of bounds.
94    * @param i index
95    * @return the element at the specified position in the container
96    */
operator [](size_t i) const97   inline auto operator[](size_t i) const { return c_[i]; }
size() const98   inline auto size() const noexcept { return c_.size(); }
data() const99   inline const auto* data() const noexcept { return c_.data(); }
100 
101   template <typename T = C, require_st_arithmetic<T>* = nullptr>
val(size_t i) const102   inline decltype(auto) val(size_t i) const {
103     return c_[i];
104   }
105 
106   template <typename T = C, require_st_autodiff<T>* = nullptr>
val(size_t i) const107   inline decltype(auto) val(size_t i) const {
108     return c_[i].val();
109   }
110 
111  private:
112   const C& c_;
113 };
114 
115 template <typename C>
116 class scalar_seq_view<C, require_t<std::is_pointer<C>>> {
117  public:
118   template <typename T,
119             typename = require_same_t<plain_type_t<T>, plain_type_t<C>>>
scalar_seq_view(const T & c)120   explicit scalar_seq_view(const T& c) : c_(c) {}
121 
122   /**
123    * Segfaults if out of bounds.
124    * @param i index
125    * @return the element at the specified position in the container
126    */
operator [](size_t i) const127   inline auto operator[](size_t i) const { return c_[i]; }
size() const128   inline auto size() const noexcept {
129     static_assert(1, "Cannot Return Size of scalar_seq_view with pointer type");
130   }
data() const131   inline const auto* data() const noexcept { return &c_[0]; }
132 
133   template <typename T = C, require_st_arithmetic<T>* = nullptr>
val(size_t i) const134   inline decltype(auto) val(size_t i) const {
135     return c_[i];
136   }
137 
138   template <typename T = C, require_st_autodiff<T>* = nullptr>
val(size_t i) const139   inline decltype(auto) val(size_t i) const {
140     return c_[i].val();
141   }
142 
143  private:
144   const C& c_;
145 };
146 
147 /**
148  * This specialization handles wrapping a scalar as if it were a sequence.
149  *
150  * @tparam T the scalar type
151  */
152 template <typename C>
153 class scalar_seq_view<C, require_stan_scalar_t<C>> {
154  public:
scalar_seq_view(const C & t)155   explicit scalar_seq_view(const C& t) noexcept : t_(t) {}
156 
operator [](int) const157   inline decltype(auto) operator[](int /* i */) const noexcept { return t_; }
158 
159   template <typename T = C, require_st_arithmetic<T>* = nullptr>
val(int) const160   inline decltype(auto) val(int /* i */) const noexcept {
161     return t_;
162   }
163 
164   template <typename T = C, require_st_autodiff<T>* = nullptr>
val(int) const165   inline decltype(auto) val(int /* i */) const noexcept {
166     return t_.val();
167   }
168 
size()169   static constexpr auto size() { return 1; }
data() const170   inline const auto* data() const noexcept { return &t_; }
data()171   inline auto* data() noexcept { return &t_; }
172 
173  private:
174   std::decay_t<C> t_;
175 };
176 }  // namespace stan
177 #endif
178