1 /*
2  * Copyright (c) 2011-2021, The DART development contributors
3  * All rights reserved.
4  *
5  * The list of contributors can be found at:
6  *   https://github.com/dartsim/dart/blob/master/LICENSE
7  *
8  * This file is provided under the following "BSD-style" License:
9  *   Redistribution and use in source and binary forms, with or
10  *   without modification, are permitted provided that the following
11  *   conditions are met:
12  *   * Redistributions of source code must retain the above copyright
13  *     notice, this list of conditions and the following disclaimer.
14  *   * Redistributions in binary form must reproduce the above
15  *     copyright notice, this list of conditions and the following
16  *     disclaimer in the documentation and/or other materials provided
17  *     with the distribution.
18  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
19  *   CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
20  *   INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
21  *   MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22  *   DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
23  *   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
26  *   USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
27  *   AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
28  *   LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
29  *   ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
30  *   POSSIBILITY OF SUCH DAMAGE.
31  */
32 
#include "dart/math/Random.hpp"

#include <cmath>
#include <limits>
34 
35 //==============================================================================
36 // This workaround is necessary for old Eigen (< 3.3). See the details here:
37 // http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1286
38 #if !EIGEN_VERSION_AT_LEAST(3, 3, 0)
39 
40 namespace dart {
41 namespace math {
42 namespace detail {
43 
44 template <typename Derived>
45 struct UniformScalarFromMatrixFunctor
46 {
47   using S = typename Derived::Scalar;
48 
UniformScalarFromMatrixFunctordart::math::detail::UniformScalarFromMatrixFunctor49   UniformScalarFromMatrixFunctor(
50       const Eigen::MatrixBase<Derived>& min,
51       const Eigen::MatrixBase<Derived>& max)
52     : mMin(min), mMax(max)
53   {
54     // Do nothing
55   }
56 
operator ()dart::math::detail::UniformScalarFromMatrixFunctor57   S operator()(int i, int j) const
58   {
59     return Random::uniform<S>(mMin(i, j), mMax(i, j));
60   }
61 
62   const Eigen::MatrixBase<Derived>& mMin;
63   const Eigen::MatrixBase<Derived>& mMax;
64 };
65 
66 template <typename Derived>
67 struct UniformScalarFromVectorFunctor
68 {
69   using S = typename Derived::Scalar;
70 
UniformScalarFromVectorFunctordart::math::detail::UniformScalarFromVectorFunctor71   UniformScalarFromVectorFunctor(
72       const Eigen::MatrixBase<Derived>& min,
73       const Eigen::MatrixBase<Derived>& max)
74     : mMin(min), mMax(max)
75   {
76     // Do nothing
77   }
78 
operator ()dart::math::detail::UniformScalarFromVectorFunctor79   S operator()(int i) const
80   {
81     return Random::uniform<S>(mMin[i], mMax[i]);
82   }
83 
84   const Eigen::MatrixBase<Derived>& mMin;
85   const Eigen::MatrixBase<Derived>& mMax;
86 };
87 
88 } // namespace detail
89 } // namespace math
90 } // namespace dart
91 
92 namespace Eigen {
93 namespace internal {
94 
95 template <typename T>
96 struct functor_has_linear_access<
97     dart::math::detail::UniformScalarFromMatrixFunctor<T>>
98 {
99   enum
100   {
101     ret = false
102   };
103 };
104 
105 template <typename T>
106 struct functor_has_linear_access<
107     dart::math::detail::UniformScalarFromVectorFunctor<T>>
108 {
109   enum
110   {
111     ret = true
112   };
113 };
114 
115 } // namespace internal
116 } // namespace Eigen
117 
118 #endif // !EIGEN_VERSION_AT_LEAST(3,3,0)
119 
120 namespace dart {
121 namespace math {
122 
123 namespace detail {
124 
//==============================================================================
/// Chosen when a \c T* converts to a pointer to some specialization
/// C<Args...>. Declared only; it is used solely inside decltype.
template <template <typename...> class C, typename... Args>
std::true_type is_base_of_template_impl(const C<Args...>*);

/// Fallback chosen when no C<...> base can be deduced from the argument.
template <template <typename...> class C>
std::false_type is_base_of_template_impl(...);

/// std::true_type if T is a specialization of the class template C or derives
/// from one; std::false_type otherwise.
template <template <typename...> class C, typename T>
using is_base_of_template
    = decltype(is_base_of_template_impl<C>(std::declval<T*>()));
135 
/// std::true_type if T is, or derives from, a specialization of
/// Eigen::MatrixBase; std::false_type otherwise. Used below to route such
/// types to UniformImpl's matrix specialization.
template <typename T>
using is_base_of_matrix = is_base_of_template<Eigen::MatrixBase, T>;
138 
//==============================================================================
/// std::true_type exactly when \c T, after stripping top-level const/volatile,
/// is one of the IntType choices accepted by std::uniform_int_distribution:
/// short, int, long, long long, or one of their unsigned counterparts.
/// Everything else (bool, the character types, floating-point types,
/// references) yields std::false_type.
/// Reference:
/// https://en.cppreference.com/w/cpp/numeric/random/uniform_int_distribution
template <typename T, typename Enable = void>
struct is_compatible_to_uniform_int_distribution : std::false_type {};

// clang-format off

template <typename T>
struct is_compatible_to_uniform_int_distribution<
    T, typename std::enable_if<
           std::is_same<typename std::remove_cv<T>::type, short>::value
        || std::is_same<typename std::remove_cv<T>::type, unsigned short>::value
        || std::is_same<typename std::remove_cv<T>::type, int>::value
        || std::is_same<typename std::remove_cv<T>::type, unsigned int>::value
        || std::is_same<typename std::remove_cv<T>::type, long>::value
        || std::is_same<typename std::remove_cv<T>::type, unsigned long>::value
        || std::is_same<typename std::remove_cv<T>::type, long long>::value
        || std::is_same<typename std::remove_cv<T>::type, unsigned long long>::value
       >::type>
  : std::true_type {};

// clang-format on
169 
170 //==============================================================================
171 template <typename S, typename Enable = void>
172 struct UniformScalarImpl
173 {
174   // Define nothing
175 };
176 
177 //==============================================================================
178 // Floating-point case
179 template <typename S>
180 struct UniformScalarImpl<
181     S,
182     typename std::enable_if<std::is_floating_point<S>::value>::type>
183 {
rundart::math::detail::UniformScalarImpl184   static S run(S min, S max)
185   {
186     // Distribution objects are lightweight so we simply construct a new
187     // distribution for each random number generation.
188     Random::UniformRealDist<S> d(min, max);
189     return d(Random::getGenerator());
190   }
191 };
192 
193 //==============================================================================
194 // Floating-point case
195 template <typename S>
196 struct UniformScalarImpl<
197     S,
198     typename std::enable_if<
199         is_compatible_to_uniform_int_distribution<S>::value>::type>
200 {
rundart::math::detail::UniformScalarImpl201   static S run(S min, S max)
202   {
203     // Distribution objects are lightweight so we simply construct a new
204     // distribution for each random number generation.
205     Random::UniformIntDist<S> d(min, max);
206     return d(Random::getGenerator());
207   }
208 };
209 
210 //==============================================================================
211 template <typename Derived, typename Enable = void>
212 struct UniformMatrixImpl
213 {
214   // Define nothing
215 };
216 
217 //==============================================================================
218 // Dynamic matrix case
219 template <typename Derived>
220 struct UniformMatrixImpl<
221     Derived,
222     typename std::enable_if<
223         !Derived::IsVectorAtCompileTime
224         && Derived::SizeAtCompileTime == Eigen::Dynamic>::type>
225 {
rundart::math::detail::UniformMatrixImpl226   static typename Derived::PlainObject run(
227       const Eigen::MatrixBase<Derived>& min,
228       const Eigen::MatrixBase<Derived>& max)
229   {
230 #if EIGEN_VERSION_AT_LEAST(3, 3, 0)
231     const auto uniformFunc = [&](int i, int j) {
232       return Random::uniform<typename Derived::Scalar>(min(i, j), max(i, j));
233     };
234     return Derived::PlainObject::NullaryExpr(
235         min.rows(), min.cols(), uniformFunc);
236 #else
237     return Derived::PlainObject::NullaryExpr(
238         min.rows(),
239         min.cols(),
240         detail::UniformScalarFromMatrixFunctor<Derived>(min, max));
241 #endif
242   }
243 };
244 
245 //==============================================================================
246 // Dynamic vector case
247 template <typename Derived>
248 struct UniformMatrixImpl<
249     Derived,
250     typename std::enable_if<
251         Derived::IsVectorAtCompileTime
252         && Derived::SizeAtCompileTime == Eigen::Dynamic>::type>
253 {
rundart::math::detail::UniformMatrixImpl254   static typename Derived::PlainObject run(
255       const Eigen::MatrixBase<Derived>& min,
256       const Eigen::MatrixBase<Derived>& max)
257   {
258 #if EIGEN_VERSION_AT_LEAST(3, 3, 0)
259     const auto uniformFunc = [&](int i) {
260       return Random::uniform<typename Derived::Scalar>(min[i], max[i]);
261     };
262     return Derived::PlainObject::NullaryExpr(min.size(), uniformFunc);
263 #else
264     return Derived::PlainObject::NullaryExpr(
265         min.size(), detail::UniformScalarFromVectorFunctor<Derived>(min, max));
266 #endif
267   }
268 };
269 
270 //==============================================================================
271 // Fixed matrix case
272 template <typename Derived>
273 struct UniformMatrixImpl<
274     Derived,
275     typename std::enable_if<
276         !Derived::IsVectorAtCompileTime
277         && Derived::SizeAtCompileTime != Eigen::Dynamic>::type>
278 {
rundart::math::detail::UniformMatrixImpl279   static typename Derived::PlainObject run(
280       const Eigen::MatrixBase<Derived>& min,
281       const Eigen::MatrixBase<Derived>& max)
282   {
283 #if EIGEN_VERSION_AT_LEAST(3, 3, 0)
284     const auto uniformFunc = [&](int i, int j) {
285       return Random::uniform<typename Derived::Scalar>(min(i, j), max(i, j));
286     };
287     return Derived::PlainObject::NullaryExpr(uniformFunc);
288 #else
289     return Derived::PlainObject::NullaryExpr(
290         detail::UniformScalarFromMatrixFunctor<Derived>(min, max));
291 #endif
292   }
293 };
294 
295 //==============================================================================
296 // Fixed vector case
297 template <typename Derived>
298 struct UniformMatrixImpl<
299     Derived,
300     typename std::enable_if<
301         Derived::IsVectorAtCompileTime
302         && Derived::SizeAtCompileTime != Eigen::Dynamic>::type>
303 {
rundart::math::detail::UniformMatrixImpl304   static typename Derived::PlainObject run(
305       const Eigen::MatrixBase<Derived>& min,
306       const Eigen::MatrixBase<Derived>& max)
307   {
308 #if EIGEN_VERSION_AT_LEAST(3, 3, 0)
309     const auto uniformFunc = [&](int i) {
310       return Random::uniform<typename Derived::Scalar>(min[i], max[i]);
311     };
312     return Derived::PlainObject::NullaryExpr(uniformFunc);
313 #else
314     return Derived::PlainObject::NullaryExpr(
315         detail::UniformScalarFromVectorFunctor<Derived>(min, max));
316 #endif
317   }
318 };
319 
320 //==============================================================================
321 template <typename T, typename Enable = void>
322 struct UniformImpl
323 {
324   // Define nothing
325 };
326 
327 //==============================================================================
328 template <typename T>
329 struct UniformImpl<
330     T,
331     typename std::enable_if<std::is_arithmetic<T>::value>::type>
332 {
rundart::math::detail::UniformImpl333   static T run(T min, T max)
334   {
335     return UniformScalarImpl<T>::run(min, max);
336   }
337 };
338 
339 //==============================================================================
340 template <typename T>
341 struct UniformImpl<
342     T,
343     typename std::enable_if<is_base_of_matrix<T>::value>::type>
344 {
rundart::math::detail::UniformImpl345   static T run(const Eigen::MatrixBase<T>& min, const Eigen::MatrixBase<T>& max)
346   {
347     return UniformMatrixImpl<T>::run(min, max);
348   }
349 };
350 
351 //==============================================================================
352 template <typename S, typename Enable = void>
353 struct NormalScalarImpl
354 {
355   // Define nothing
356 };
357 
358 //==============================================================================
359 // Floating-point case
360 template <typename S>
361 struct NormalScalarImpl<
362     S,
363     typename std::enable_if<std::is_floating_point<S>::value>::type>
364 {
rundart::math::detail::NormalScalarImpl365   static S run(S mean, S sigma)
366   {
367     Random::NormalRealDist<S> d(mean, sigma);
368     return d(Random::getGenerator());
369   }
370 };
371 
372 //==============================================================================
373 // Floating-point case
374 template <typename S>
375 struct NormalScalarImpl<
376     S,
377     typename std::enable_if<
378         is_compatible_to_uniform_int_distribution<S>::value>::type>
379 {
rundart::math::detail::NormalScalarImpl380   static S run(S mean, S sigma)
381   {
382     using DefaultFloatType = float;
383     const DefaultFloatType realNormal = Random::normal(
384         static_cast<DefaultFloatType>(mean),
385         static_cast<DefaultFloatType>(sigma));
386     return static_cast<S>(std::round(realNormal));
387   }
388 };
389 
390 //==============================================================================
391 template <typename T, typename Enable = void>
392 struct NormalImpl
393 {
394   // Define nothing
395 };
396 
397 //==============================================================================
398 template <typename T>
399 struct NormalImpl<
400     T,
401     typename std::enable_if<std::is_arithmetic<T>::value>::type>
402 {
rundart::math::detail::NormalImpl403   static T run(T min, T max)
404   {
405     return NormalScalarImpl<T>::run(min, max);
406   }
407 };
408 
409 } // namespace detail
410 
411 //==============================================================================
412 template <typename S>
uniform(S min,S max)413 S Random::uniform(S min, S max)
414 {
415   return detail::UniformImpl<S>::run(min, max);
416 }
417 
418 //==============================================================================
419 template <typename FixedSizeT>
uniform(typename FixedSizeT::Scalar min,typename FixedSizeT::Scalar max)420 FixedSizeT Random::uniform(
421     typename FixedSizeT::Scalar min, typename FixedSizeT::Scalar max)
422 {
423   return uniform<FixedSizeT>(
424       FixedSizeT::Constant(min), FixedSizeT::Constant(max));
425 }
426 
427 //==============================================================================
428 template <typename DynamicSizeVectorT>
uniform(int size,typename DynamicSizeVectorT::Scalar min,typename DynamicSizeVectorT::Scalar max)429 DynamicSizeVectorT Random::uniform(
430     int size,
431     typename DynamicSizeVectorT::Scalar min,
432     typename DynamicSizeVectorT::Scalar max)
433 {
434   return uniform<DynamicSizeVectorT>(
435       DynamicSizeVectorT::Constant(size, min),
436       DynamicSizeVectorT::Constant(size, max));
437 }
438 
439 //==============================================================================
440 template <typename DynamicSizeMatrixT>
uniform(int rows,int cols,typename DynamicSizeMatrixT::Scalar min,typename DynamicSizeMatrixT::Scalar max)441 DynamicSizeMatrixT Random::uniform(
442     int rows,
443     int cols,
444     typename DynamicSizeMatrixT::Scalar min,
445     typename DynamicSizeMatrixT::Scalar max)
446 {
447   return uniform<DynamicSizeMatrixT>(
448       DynamicSizeMatrixT::Constant(rows, cols, min),
449       DynamicSizeMatrixT::Constant(rows, cols, max));
450 }
451 
452 //==============================================================================
453 template <typename S>
normal(S min,S max)454 S Random::normal(S min, S max)
455 {
456   return detail::NormalImpl<S>::run(min, max);
457 }
458 
459 } // namespace math
460 } // namespace dart
461