1 #include "chainerx/axes.h"
2 
3 #include <algorithm>
4 #include <cstddef>
5 #include <cstdint>
6 #include <numeric>
7 #include <ostream>
8 #include <set>
9 #include <sstream>
10 #include <string>
11 #include <vector>
12 
13 #include "chainerx/macro.h"
14 
15 namespace chainerx {
16 
ToString() const17 std::string Axes::ToString() const {
18     std::ostringstream os;
19     os << *this;
20     return os.str();
21 }
22 
operator <<(std::ostream & os,const Axes & axes)23 std::ostream& operator<<(std::ostream& os, const Axes& axes) {
24     os << "(";
25     for (auto iter = axes.begin(); iter != axes.end(); ++iter) {
26         if (iter != axes.begin()) {
27             os << ", ";
28         }
29         os << static_cast<int>(*iter);
30     }
31     // same as Python tuples with trailing comma in case of length 1
32     return os << (axes.ndim() == 1 ? ",)" : ")");
33 }
34 
35 namespace internal {
36 
IsAxesPermutation(const Axes & axes,int8_t ndim)37 bool IsAxesPermutation(const Axes& axes, int8_t ndim) {
38     CHAINERX_ASSERT(ndim >= 0);
39     if (axes.size() != static_cast<size_t>(ndim)) {
40         return false;
41     }
42 
43     Axes sorted_axes = axes;
44     std::sort(sorted_axes.begin(), sorted_axes.end());
45     for (int8_t i = 0; i < ndim; ++i) {
46         if (sorted_axes[i] != i) {
47             return false;
48         }
49     }
50     return true;
51 }
52 
NormalizeAxis(int8_t axis,int8_t ndim)53 int8_t NormalizeAxis(int8_t axis, int8_t ndim) {
54     if (axis < -ndim || ndim <= axis) {
55         throw DimensionError{"Axis ", axis, " is out of bounds for array of dimension ", ndim};
56     }
57     if (axis < 0) {
58         return axis + ndim;
59     }
60     return axis;
61 }
62 
// Returns a copy of axis in which every element has been passed through NormalizeAxis for an
// array with ndim dimensions. The order of the elements is preserved.
//
// Throws DimensionError if any element is out of bounds (raised by NormalizeAxis) or if two
// elements refer to the same axis after normalization.
Axes GetNormalizedAxes(const Axes& axis, int8_t ndim) {
    Axes normalized_axis = axis;

    for (auto& a : normalized_axis) {
        a = NormalizeAxis(a, ndim);
    }

    // Duplicates are not necessarily adjacent (e.g. (0, 1, 0)), so they are searched for in a
    // sorted copy; the result itself must keep the caller's ordering.
    Axes sorted_copy = normalized_axis;
    std::sort(sorted_copy.begin(), sorted_copy.end());
    if (std::adjacent_find(sorted_copy.begin(), sorted_copy.end()) != sorted_copy.end()) {
        throw DimensionError{"Duplicate axis values."};
    }

    // normalized_axis is unique, and within bounds [0, ndim).
    CHAINERX_ASSERT(std::set<int8_t>(normalized_axis.begin(), normalized_axis.end()).size() == normalized_axis.size());
    CHAINERX_ASSERT(std::all_of(normalized_axis.begin(), normalized_axis.end(), [ndim](int8_t x) -> bool { return 0 <= x && x < ndim; }));
    return normalized_axis;
}
78 
// Normalizes axis with GetNormalizedAxes and returns the result sorted in ascending order.
//
// Any exception thrown by GetNormalizedAxes propagates to the caller.
Axes GetSortedAxes(const Axes& axis, int8_t ndim) {
    Axes result = GetNormalizedAxes(axis, ndim);
    std::sort(result.begin(), result.end());

    // Postconditions: ascending order, no repeated value, every value in [0, ndim).
    CHAINERX_ASSERT(std::is_sorted(result.begin(), result.end()));
    CHAINERX_ASSERT(std::adjacent_find(result.begin(), result.end()) == result.end());
    CHAINERX_ASSERT(std::all_of(result.begin(), result.end(), [ndim](int8_t value) -> bool { return !(value < 0 || value >= ndim); }));
    return result;
}
89 
GetSortedAxesOrAll(const OptionalAxes & axis,int8_t ndim)90 Axes GetSortedAxesOrAll(const OptionalAxes& axis, int8_t ndim) {
91     if (axis.has_value()) {
92         return GetSortedAxes(*axis, ndim);
93     }
94     // Fill with all axes
95     Axes sorted_axis{};
96     sorted_axis.resize(ndim);
97     std::iota(sorted_axis.begin(), sorted_axis.end(), int8_t{0});
98     return sorted_axis;
99 }
100 
101 }  // namespace internal
102 }  // namespace chainerx
103