1 #include "chainerx/axes.h"
2
3 #include <algorithm>
4 #include <cstddef>
5 #include <cstdint>
6 #include <numeric>
7 #include <ostream>
8 #include <set>
9 #include <sstream>
10 #include <string>
11 #include <vector>
12
13 #include "chainerx/macro.h"
14
15 namespace chainerx {
16
ToString() const17 std::string Axes::ToString() const {
18 std::ostringstream os;
19 os << *this;
20 return os.str();
21 }
22
operator <<(std::ostream & os,const Axes & axes)23 std::ostream& operator<<(std::ostream& os, const Axes& axes) {
24 os << "(";
25 for (auto iter = axes.begin(); iter != axes.end(); ++iter) {
26 if (iter != axes.begin()) {
27 os << ", ";
28 }
29 os << static_cast<int>(*iter);
30 }
31 // same as Python tuples with trailing comma in case of length 1
32 return os << (axes.ndim() == 1 ? ",)" : ")");
33 }
34
35 namespace internal {
36
IsAxesPermutation(const Axes & axes,int8_t ndim)37 bool IsAxesPermutation(const Axes& axes, int8_t ndim) {
38 CHAINERX_ASSERT(ndim >= 0);
39 if (axes.size() != static_cast<size_t>(ndim)) {
40 return false;
41 }
42
43 Axes sorted_axes = axes;
44 std::sort(sorted_axes.begin(), sorted_axes.end());
45 for (int8_t i = 0; i < ndim; ++i) {
46 if (sorted_axes[i] != i) {
47 return false;
48 }
49 }
50 return true;
51 }
52
NormalizeAxis(int8_t axis,int8_t ndim)53 int8_t NormalizeAxis(int8_t axis, int8_t ndim) {
54 if (axis < -ndim || ndim <= axis) {
55 throw DimensionError{"Axis ", axis, " is out of bounds for array of dimension ", ndim};
56 }
57 if (axis < 0) {
58 return axis + ndim;
59 }
60 return axis;
61 }
62
// Normalizes every entry of `axis` into [0, ndim) (see NormalizeAxis), preserving the caller's order.
// Throws DimensionError if any axis is out of bounds or if the same axis appears more than once
// (after normalization, so e.g. 0 and -ndim count as duplicates).
Axes GetNormalizedAxes(const Axes& axis, int8_t ndim) {
    Axes normalized_axis = axis;

    for (auto& a : normalized_axis) {
        a = NormalizeAxis(a, ndim);
    }

    // std::unique only detects *adjacent* duplicates, so an unsorted input like (0, 1, 0) would slip
    // through. Check on a sorted copy instead; the returned axes must keep the original order.
    Axes sorted_axis = normalized_axis;
    std::sort(sorted_axis.begin(), sorted_axis.end());
    if (std::adjacent_find(sorted_axis.begin(), sorted_axis.end()) != sorted_axis.end()) {
        throw DimensionError{"Duplicate axis values."};
    }

    // normalized_axis is unique, and within bounds [0, ndim).
    CHAINERX_ASSERT(std::set<int8_t>(normalized_axis.begin(), normalized_axis.end()).size() == normalized_axis.size());
    CHAINERX_ASSERT(std::all_of(normalized_axis.begin(), normalized_axis.end(), [ndim](int8_t x) -> bool { return 0 <= x && x < ndim; }));
    return normalized_axis;
}
78
// Returns the axes of `axis` normalized by GetNormalizedAxes and then sorted in ascending order.
// Any error thrown by GetNormalizedAxes propagates unchanged.
Axes GetSortedAxes(const Axes& axis, int8_t ndim) {
    Axes result = GetNormalizedAxes(axis, ndim);
    std::sort(result.begin(), result.end());

    // Postconditions: ascending, free of duplicates, and every entry in [0, ndim).
    CHAINERX_ASSERT(std::is_sorted(result.begin(), result.end()));
    CHAINERX_ASSERT(std::set<int8_t>(result.begin(), result.end()).size() == result.size());
    CHAINERX_ASSERT(std::all_of(result.begin(), result.end(), [ndim](int8_t x) -> bool { return 0 <= x && x < ndim; }));
    return result;
}
89
GetSortedAxesOrAll(const OptionalAxes & axis,int8_t ndim)90 Axes GetSortedAxesOrAll(const OptionalAxes& axis, int8_t ndim) {
91 if (axis.has_value()) {
92 return GetSortedAxes(*axis, ndim);
93 }
94 // Fill with all axes
95 Axes sorted_axis{};
96 sorted_axis.resize(ndim);
97 std::iota(sorted_axis.begin(), sorted_axis.end(), int8_t{0});
98 return sorted_axis;
99 }
100
101 } // namespace internal
102 } // namespace chainerx
103