1 #pragma once
2 
3 #include <algorithm>
4 #include <array>
5 #include <cstdint>
6 #include <initializer_list>
7 #include <iterator>
8 #include <sstream>
9 #include <string>
10 
11 #include <absl/types/span.h>
12 #include <gsl/gsl>
13 
14 #include "chainerx/constant.h"
15 #include "chainerx/error.h"
16 #include "chainerx/optional_container_arg.h"
17 #include "chainerx/stack_vector.h"
18 
19 namespace chainerx {
20 
21 class Axes : public StackVector<int8_t, kMaxNdim> {
22     using BaseVector = StackVector<int8_t, kMaxNdim>;
23 
24 public:
25     using const_iterator = BaseVector::const_iterator;
26     using const_reverse_iterator = BaseVector::const_reverse_iterator;
27     // TODO(niboshi): Declare other types required for this class to be a container.
28 
29     Axes() = default;
30 
31     ~Axes() = default;
32 
33     // by iterators
34     template <typename InputIt>
Axes(InputIt first,InputIt last)35     Axes(InputIt first, InputIt last) {
36         if (std::distance(first, last) > kMaxNdim) {
37             throw DimensionError{"too many dimensions: ", std::distance(first, last)};
38         }
39         insert(begin(), first, last);
40     }
41 
42     // by span
Axes(absl::Span<const int8_t> axes)43     explicit Axes(absl::Span<const int8_t> axes) : Axes{axes.begin(), axes.end()} {}
44 
45     // by initializer list
Axes(std::initializer_list<int8_t> axes)46     Axes(std::initializer_list<int8_t> axes) : Axes{axes.begin(), axes.end()} {}
47 
48     // copy
49     Axes(const Axes&) = default;
50     Axes& operator=(const Axes&) = default;
51 
52     // move
53     Axes(Axes&&) = default;
54     Axes& operator=(Axes&&) = default;
55 
56     std::string ToString() const;
57 
ndim()58     int8_t ndim() const noexcept { return gsl::narrow_cast<int8_t>(size()); }
59 
60     int8_t& operator[](int8_t index) {
61         if (!(0 <= index && static_cast<size_t>(index) < size())) {
62             throw DimensionError{"index out of bounds"};
63         }
64         return this->StackVector::operator[](index);
65     }
66 
67     const int8_t& operator[](int8_t index) const {
68         if (!(0 <= index && static_cast<size_t>(index) < size())) {
69             throw DimensionError{"index out of bounds"};
70         }
71         return this->StackVector::operator[](index);
72     }
73 
74     // span
span()75     absl::Span<const int8_t> span() const { return {*this}; }
76 };
77 
// Stream insertion operator for Axes: takes the destination stream and the axes to write, and returns a stream reference.
// Only declared in this header; the output format is determined by the definition elsewhere.
std::ostream& operator<<(std::ostream& os, const Axes& axes);
79 
// Alias of OptionalContainerArg<Axes>. Used as the type of axis arguments that may be left unspecified
// (see internal::GetSortedAxesOrAll, which accepts nullopt).
using OptionalAxes = OptionalContainerArg<Axes>;
81 
// Helpers for validating and normalizing axis arguments. All functions are only declared here.
namespace internal {

// NOTE(review): no contract is stated in this header. Judging by the name, this presumably reports whether `axes` is a
// permutation of the axes [0, ndim) -- confirm against the definition.
bool IsAxesPermutation(const Axes& axes, int8_t ndim);

// Normalizes a possibly-negative axis to a non-negative axis in [0, ndim).
// If `axis` does not fit in [-ndim, ndim), DimensionError is thrown.
int8_t NormalizeAxis(int8_t axis, int8_t ndim);

// Resolves the axis argument of many operations.
// Negative axes are converted to non-negative ones (by wrapping at ndim).
// The result is returned as a new Axes; the input order is presumably kept, since sorting is only documented for
// GetSortedAxes -- confirm against the definition.
Axes GetNormalizedAxes(const Axes& axis, int8_t ndim);

// Resolves the axis argument of many operations.
// Negative axes are converted to non-negative ones (by wrapping at ndim).
// Axes are then sorted.
// The result is returned as a new Axes.
Axes GetSortedAxes(const Axes& axis, int8_t ndim);

// Resolves the axis argument of many operations.
// Negative axes are converted to non-negative ones (by wrapping at ndim).
// Axes are then sorted.
// nullopt is converted to a vector of all axes.
// The result is returned as a new Axes.
Axes GetSortedAxesOrAll(const OptionalAxes& axis, int8_t ndim);

}  // namespace internal
106 }  // namespace chainerx
107