1 #ifndef STAN_MATH_PRIM_FUN_SORT_INDICES_HPP
2 #define STAN_MATH_PRIM_FUN_SORT_INDICES_HPP
3 
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <algorithm>
#include <numeric>
#include <type_traits>
#include <vector>
9 
10 namespace stan {
11 namespace math {
12 
namespace internal {

/**
 * Function object that orders 1-based indices by the values they
 * refer to in a container supporting the brackets operator.
 *
 * Only a reference to the container is kept, so the container must
 * outlive the comparator.
 *
 * @tparam ascending true to order indices by increasing value,
 *   false to order them by decreasing value
 * @tparam C container type
 */
template <bool ascending, typename C>
class index_comparator {
  /** Container whose elements are compared (referenced, not owned). */
  const C& xs_;

 public:
  /**
   * Construct a comparator that refers to the given container.
   *
   * @param xs Container; it is stored by reference, not copied
   */
  explicit index_comparator(const C& xs) : xs_(xs) {}

  /**
   * Decide whether index i should be placed in front of index j.
   *
   * Both indices are 1-based and are shifted down by one before
   * being used to access the container.
   *
   * @param i 1-based index of the first value
   * @param j 1-based index of the second value
   * @return when ascending, true if xs[i - 1] is less than
   *   xs[j - 1]; otherwise true if xs[i - 1] is greater than xs[j - 1]
   */
  bool operator()(int i, int j) const {
    const auto& lhs = xs_[i - 1];
    const auto& rhs = xs_[j - 1];
    if (!ascending) {
      return lhs > rhs;
    }
    return lhs < rhs;
  }
};

}  // namespace internal
52 
53 /**
54  * Return an integer array of indices of the specified container
55  * sorting the values in ascending or descending order based on
56  * the value of the first template parameter.
57  *
58  * @tparam ascending true if sort is in ascending order
59  * @tparam C type of container
60  * @param xs Container to sort
61  * @return sorted version of container
62  */
63 template <bool ascending, typename C>
sort_indices(const C & xs)64 std::vector<int> sort_indices(const C& xs) {
65   using idx_t = index_type_t<C>;
66   idx_t xs_size = xs.size();
67   std::vector<int> idxs;
68   idxs.resize(xs_size);
69   for (idx_t i = 0; i < xs_size; ++i) {
70     idxs[i] = i + 1;
71   }
72   internal::index_comparator<ascending, ref_type_t<C>> comparator(to_ref(xs));
73   std::sort(idxs.begin(), idxs.end(), comparator);
74   return idxs;
75 }
76 
77 }  // namespace math
78 }  // namespace stan
79 
80 #endif
81