1 #ifndef STAN_MATH_PRIM_CORE_COMPLEX_BASE_HPP
2 #define STAN_MATH_PRIM_CORE_COMPLEX_BASE_HPP
3 
4 #include <stan/math/prim/fun/square.hpp>
5 #include <stan/math/prim/meta.hpp>
6 #include <complex>
7 
8 namespace stan {
9 namespace math {
10 
11 /**
12  * Base class for complex numbers.  Extending classes must be of
13  * of the form `complex<ValueType>`.
14  *
15  * @tparam ValueType type of real and imaginary parts
16  */
17 template <typename ValueType>
18 class complex_base {
19  public:
20   /**
21    * Type of real and imaginary parts
22    */
23   using value_type = ValueType;
24 
25   /**
26    * Derived complex type used for function return types
27    */
28   using complex_type = std::complex<value_type>;
29 
30   /**
31    * Construct a complex base with zero real and imaginary parts.
32    */
33   complex_base() = default;
34 
35   /**
36    * Construct a complex base with the specified real part and a zero
37    * imaginary part.
38    *
39    * @tparam U real type (assignable to `value_type`)
40    * @param[in] re real part
41    */
42   template <typename U>  // , typename = require_stan_scalar_t<U>>
complex_base(const U & re)43   complex_base(const U& re) : re_(re) {}  // NOLINT(runtime/explicit)
44 
45   /**
46    * Construct a complex base with the specified real and imaginary
47    * parts.
48    *
49    * @tparam U real type (assignable to `value_type`)
50    * @tparam V imaginary type (assignable to `value_type`)
51    * @param[in] re real part
52    * @param[in] im imaginary part
53    */
54   template <typename U, typename V>
complex_base(const U & re,const V & im)55   complex_base(const U& re, const V& im) : re_(re), im_(im) {}
56 
57   /**
58    * Return the real part.
59    *
60    * @return real part
61    */
real() const62   value_type real() const { return re_; }
63 
64   /**
65    * Set the real part to the specified value.
66    *
67    * @param[in] re real part
68    */
real(const value_type & re)69   void real(const value_type& re) { re_ = re; }
70 
71   /**
72    * Return the imaginary part.
73    *
74    * @return imaginary part
75    */
imag() const76   value_type imag() const { return im_; }
77 
78   /**
79    * Set the imaginary part to the specified value.
80    *
81    * @param[in] im imaginary part
82    */
imag(const value_type & im)83   void imag(const value_type& im) { im_ = im; }
84 
85   /**
86    * Assign the specified value to the real part of this complex number
87    * and set imaginary part to zero.
88    *
89    * @tparam U argument type (assignable to `value_type`)
90    * @param[in] re real part
91    * @return this
92    */
93   template <typename U, typename = require_stan_scalar_t<U>>
operator =(U && re)94   complex_type& operator=(U&& re) {
95     re_ = re;
96     im_ = 0;
97     return derived();
98   }
99 
100   /**
101    * Add specified real value to real part.
102    *
103    * @tparam U argument type (assignable to `value_type`)
104    * @param[in] x real number to add
105    * @return this
106    */
107   template <typename U>
operator +=(const U & x)108   complex_type& operator+=(const U& x) {
109     re_ += x;
110     return derived();
111   }
112 
113   /**
114    * Adds specified complex number to this.
115    *
116    * @tparam U value type of argument (assignable to `value_type`)
117    * @param[in] other complex number to add
118    * @return this
119    */
120   template <typename U>
operator +=(const std::complex<U> & other)121   complex_type& operator+=(const std::complex<U>& other) {
122     re_ += other.real();
123     im_ += other.imag();
124     return derived();
125   }
126 
127   /**
128    * Subtracts specified real number from real part.
129    *
130    * @tparam U argument type (assignable to `value_type`)
131    * @param[in] x real number to subtract
132    * @return this
133    */
134   template <typename U>
operator -=(const U & x)135   complex_type& operator-=(const U& x) {
136     re_ -= x;
137     return derived();
138   }
139 
140   /**
141    * Subtracts specified complex number from this.
142    *
143    * @tparam U value type of argument (assignable to `value_type`)
144    * @param[in] other complex number to subtract
145    * @return this
146    */
147   template <typename U>
operator -=(const std::complex<U> & other)148   complex_type& operator-=(const std::complex<U>& other) {
149     re_ -= other.real();
150     im_ -= other.imag();
151     return derived();
152   }
153 
154   /**
155    * Multiplies this by the specified real number.
156    *
157    * @tparam U type of argument (assignable to `value_type`)
158    * @param[in] x real number to multiply
159    * @return this
160    */
161   template <typename U>
operator *=(const U & x)162   complex_type& operator*=(const U& x) {
163     re_ *= x;
164     im_ *= x;
165     return derived();
166   }
167 
168   /**
169    * Multiplies this by specified complex number.
170    *
171    * @tparam U value type of argument (assignable to `value_type`)
172    * @param[in] other complex number to multiply
173    * @return this
174    */
175   template <typename U>
operator *=(const std::complex<U> & other)176   complex_type& operator*=(const std::complex<U>& other) {
177     value_type re_temp = re_ * other.real() - im_ * other.imag();
178     im_ = re_ * other.imag() + other.real() * im_;
179     re_ = re_temp;
180     return derived();
181   }
182 
183   /**
184    * Divides this by the specified real number.
185    *
186    * @tparam U type of argument (assignable to `value_type`)
187    * @param[in] x real number to divide by
188    * @return this
189    */
190   template <typename U>
operator /=(const U & x)191   complex_type& operator/=(const U& x) {
192     re_ /= x;
193     im_ /= x;
194     return derived();
195   }
196 
197   /**
198    * Divides this by the specified complex number.
199    *
200    * @tparam U value type of argument (assignable to `value_type`)
201    * @param[in] other number to divide by
202    * @return this
203    */
204   template <typename U>
operator /=(const std::complex<U> & other)205   complex_type& operator/=(const std::complex<U>& other) {
206     using stan::math::square;
207     value_type sum_sq_im
208         = (other.real() * other.real()) + (other.imag() * other.imag());
209     value_type re_temp = (re_ * other.real() + im_ * other.imag()) / sum_sq_im;
210     im_ = (im_ * other.real() - re_ * other.imag()) / sum_sq_im;
211     re_ = re_temp;
212     return derived();
213   }
214 
215  protected:
216   /**
217    * Real part
218    */
219   value_type re_{0};
220 
221   /**
222    * Imaginary part
223    */
224   value_type im_{0};
225 
226   /**
227    * Return this complex base cast to the complex type.
228    *
229    * @return this complex base cast to the complex type
230    */
derived()231   complex_type& derived() { return static_cast<complex_type&>(*this); }
232 };
233 
234 }  // namespace math
235 }  // namespace stan
236 
237 #endif
238