1 #pragma once 2 3 #include <algorithm> 4 #include <array> 5 #include <cstdint> 6 #include <initializer_list> 7 #include <iterator> 8 #include <sstream> 9 #include <string> 10 11 #include <absl/types/span.h> 12 #include <gsl/gsl> 13 14 #include "chainerx/constant.h" 15 #include "chainerx/error.h" 16 #include "chainerx/optional_container_arg.h" 17 #include "chainerx/stack_vector.h" 18 19 namespace chainerx { 20 21 class Axes : public StackVector<int8_t, kMaxNdim> { 22 using BaseVector = StackVector<int8_t, kMaxNdim>; 23 24 public: 25 using const_iterator = BaseVector::const_iterator; 26 using const_reverse_iterator = BaseVector::const_reverse_iterator; 27 // TODO(niboshi): Declare other types required for this class to be a container. 28 29 Axes() = default; 30 31 ~Axes() = default; 32 33 // by iterators 34 template <typename InputIt> Axes(InputIt first,InputIt last)35 Axes(InputIt first, InputIt last) { 36 if (std::distance(first, last) > kMaxNdim) { 37 throw DimensionError{"too many dimensions: ", std::distance(first, last)}; 38 } 39 insert(begin(), first, last); 40 } 41 42 // by span Axes(absl::Span<const int8_t> axes)43 explicit Axes(absl::Span<const int8_t> axes) : Axes{axes.begin(), axes.end()} {} 44 45 // by initializer list Axes(std::initializer_list<int8_t> axes)46 Axes(std::initializer_list<int8_t> axes) : Axes{axes.begin(), axes.end()} {} 47 48 // copy 49 Axes(const Axes&) = default; 50 Axes& operator=(const Axes&) = default; 51 52 // move 53 Axes(Axes&&) = default; 54 Axes& operator=(Axes&&) = default; 55 56 std::string ToString() const; 57 ndim()58 int8_t ndim() const noexcept { return gsl::narrow_cast<int8_t>(size()); } 59 60 int8_t& operator[](int8_t index) { 61 if (!(0 <= index && static_cast<size_t>(index) < size())) { 62 throw DimensionError{"index out of bounds"}; 63 } 64 return this->StackVector::operator[](index); 65 } 66 67 const int8_t& operator[](int8_t index) const { 68 if (!(0 <= index && static_cast<size_t>(index) < size())) { 69 throw DimensionError{"index out of bounds"}; 70 } 71 return this->StackVector::operator[](index); 72 } 73 74 // span span()75 absl::Span<const int8_t> span() const { return {*this}; } 76 }; 77 78 std::ostream& operator<<(std::ostream& os, const Axes& axes); 79 80 using OptionalAxes = OptionalContainerArg<Axes>; 81 82 namespace internal { 83 84 bool IsAxesPermutation(const Axes& axes, int8_t ndim); 85 86 // Normalizes possibly-negative axis to non-negative axis in [0, ndim). 87 // If `axis` does not fit in [-ndim, ndim), DimensionError is thrown. 88 int8_t NormalizeAxis(int8_t axis, int8_t ndim); 89 90 // Resolves the axis argument of many operations. 91 // Negative axes are converted to non-negative ones (by wrapping at ndim). 92 Axes GetNormalizedAxes(const Axes& axis, int8_t ndim); 93 94 // Resolves the axis argument of many operations. 95 // Negative axes are converted to non-negative ones (by wrapping at ndim). 96 // Axes are then sorted. 97 Axes GetSortedAxes(const Axes& axis, int8_t ndim); 98 99 // Resolves the axis argument of many operations. 100 // Negative axes are converted to non-negative ones (by wrapping at ndim). 101 // Axes are then sorted. 102 // nullopt is converted to a vector of all axes. 103 Axes GetSortedAxesOrAll(const OptionalAxes& axis, int8_t ndim); 104 105 } // namespace internal 106 } // namespace chainerx 107