1 #include "chainerx/strides.h"
2
3 #include <numeric>
4 #include <vector>
5
6 #include <absl/types/span.h>
7 #include <gtest/gtest.h>
8
9 #include "chainerx/axes.h"
10
11 namespace chainerx {
12 namespace {
13
CheckSpanEqual(std::initializer_list<int64_t> expect,absl::Span<const int64_t> actual)14 void CheckSpanEqual(std::initializer_list<int64_t> expect, absl::Span<const int64_t> actual) {
15 EXPECT_EQ(absl::MakeConstSpan(expect.begin(), expect.end()), actual);
16 }
17
// Exercises the Strides constructor forms used in this file: default, initializer list, span,
// iterator pair, and shape combined with an element size or a dtype. Also checks that inputs
// longer than kMaxNdim are rejected with DimensionError.
TEST(StridesTest, Ctor) {
    {  // Default construction gives zero dimensions
        const Strides empty{};
        EXPECT_EQ(0, empty.ndim());
        EXPECT_EQ(size_t{0}, empty.size());
    }
    {  // std::initializer_list
        const Strides from_list{48, 16, 4};
        EXPECT_EQ(3, from_list.ndim());
        EXPECT_EQ(size_t{3}, from_list.size());
        CheckSpanEqual({48, 16, 4}, from_list.span());
    }
    {  // Span over a fixed-size array
        const std::array<int64_t, 3> values{48, 16, 4};
        const Strides from_span{absl::MakeConstSpan(values)};
        EXPECT_EQ(3, from_span.ndim());
        CheckSpanEqual({48, 16, 4}, from_span.span());
    }
    {  // Iterator pair
        const std::vector<int64_t> values{48, 16, 4};
        const Strides from_iters{values.begin(), values.end()};
        EXPECT_EQ(3, from_iters.ndim());
        CheckSpanEqual({48, 16, 4}, from_iters.span());
    }
    {  // Empty std::initializer_list
        const Strides from_list(std::initializer_list<int64_t>{});
        EXPECT_EQ(0, from_list.ndim());
        CheckSpanEqual({}, from_list.span());
    }
    {  // Empty span
        const std::array<int64_t, 0> values{};
        const Strides from_span{absl::MakeConstSpan(values)};
        EXPECT_EQ(0, from_span.ndim());
        CheckSpanEqual({}, from_span.span());
    }
    {  // Empty iterator range
        const std::vector<int64_t> values{};
        const Strides from_iters{values.begin(), values.end()};
        EXPECT_EQ(0, from_iters.ndim());
        CheckSpanEqual({}, from_iters.span());
    }
    {  // Shape plus element size in bytes
        const Strides from_shape{{2, 3, 4}, 4};
        EXPECT_EQ(3, from_shape.ndim());
        EXPECT_EQ(size_t{3}, from_shape.size());
        CheckSpanEqual({48, 16, 4}, from_shape.span());
    }
    {  // Shape plus dtype
        const Strides from_shape{{2, 3, 4}, Dtype::kInt32};
        EXPECT_EQ(3, from_shape.ndim());
        EXPECT_EQ(size_t{3}, from_shape.size());
        CheckSpanEqual({48, 16, 4}, from_shape.span());
    }
    {  // Over-long std::initializer_list
        // The literal has 11 elements; this case presumes that is more than kMaxNdim — revisit if
        // kMaxNdim is ever raised.
        EXPECT_THROW(Strides({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), DimensionError);
    }
    {  // Over-long span
        const std::array<int64_t, kMaxNdim + 1> oversized{1};
        EXPECT_THROW(Strides{absl::MakeConstSpan(oversized)}, DimensionError);
    }
    {  // Over-long iterator range
        std::vector<int64_t> oversized(kMaxNdim + 1);
        std::iota(oversized.begin(), oversized.end(), int64_t{1});
        EXPECT_THROW(Strides({oversized.begin(), oversized.end()}), DimensionError);
    }
}
85
// Checks indexed element access and that out-of-range indices raise DimensionError.
TEST(StridesTest, Subscript) {
    const Strides s{48, 16, 4};

    // In-range indices return the stored values in order.
    EXPECT_EQ(48, s[0]);
    EXPECT_EQ(16, s[1]);
    EXPECT_EQ(4, s[2]);

    // One below the first index and one past the last index are both rejected.
    EXPECT_THROW(s[-1], DimensionError);
    EXPECT_THROW(s[3], DimensionError);
}
94
// Checks operator== and operator!= on Strides. Every case asserts both operators so that an
// implementation in which the two disagree (e.g. both returning true) is caught.
TEST(StridesTest, Compare) {
    {  // Same length, same values
        const Strides strides = {48, 16, 4};
        const Strides strides2 = {48, 16, 4};
        EXPECT_TRUE(strides == strides2);
        EXPECT_FALSE(strides != strides2);
    }
    {  // Different length, common prefix
        const Strides strides = {48, 16, 4};
        const Strides strides2 = {48, 16};
        EXPECT_TRUE(strides != strides2);
        EXPECT_FALSE(strides == strides2);
    }
    {  // Same length, different values
        const Strides strides = {48, 16, 4};
        const Strides strides2 = {4, 8, 24};
        EXPECT_TRUE(strides != strides2);
        EXPECT_FALSE(strides == strides2);
    }
}
112
// Checks that CheckEqual accepts matching strides and raises DimensionError on a mismatch.
TEST(StridesTest, CheckEqual) {
    {  // Matching strides
        const Strides strides = {48, 16, 4};
        const Strides strides2 = {48, 16, 4};
        // Wrapped in an assertion so an unexpected throw is reported as a failure at this line.
        EXPECT_NO_THROW(CheckEqual(strides, strides2));
    }
    {  // Non-empty strides against empty strides
        const Strides strides = {48, 16, 4};
        const Strides strides2 = {};
        EXPECT_THROW(CheckEqual(strides, strides2), DimensionError);
    }
}
125
// Checks that forward and reverse iteration visit the elements in the expected order.
TEST(StridesTest, Iterator) {
    const Strides s{48, 16, 4};

    // Copy each traversal into a vector so it can be compared as a span.
    const std::vector<int64_t> forward{s.begin(), s.end()};
    const std::vector<int64_t> backward{s.rbegin(), s.rend()};

    CheckSpanEqual({48, 16, 4}, absl::MakeConstSpan(forward));
    CheckSpanEqual({4, 16, 48}, absl::MakeConstSpan(backward));
}
131
// Checks the text produced by ToString() for zero, one, and several dimensions.
TEST(StridesTest, ToString) {
    // No dimensions: empty parentheses.
    const Strides none{};
    EXPECT_EQ(none.ToString(), "()");

    // One dimension: the single value is followed by a trailing comma.
    const Strides single{4};
    EXPECT_EQ(single.ToString(), "(4,)");

    // Several dimensions: values separated by a comma and a space.
    const Strides triple{48, 16, 4};
    EXPECT_EQ(triple.ToString(), "(48, 16, 4)");
}
146
// Checks that a Strides object can be passed directly to absl::MakeConstSpan.
TEST(StridesTest, SpanFromStrides) {
    const Strides s{2, 3, 4};
    const auto view = absl::MakeConstSpan(s);
    CheckSpanEqual({2, 3, 4}, view);
}
151
// Checks that Permute selects the strides at the given axes and rejects out-of-range axes.
TEST(StridesTest, Permute) {
    const Strides s{2, 3, 4};

    // Axes 1 and 2 pick the second and third stride values.
    CheckSpanEqual({3, 4}, s.Permute(Axes{1, 2}).span());

    // An axis equal to the number of dimensions, and a negative axis, both raise DimensionError.
    EXPECT_THROW(s.Permute(Axes{3}), DimensionError);
    EXPECT_THROW(s.Permute(Axes{-1}), DimensionError);
}
158
159 struct GetDataRangeTestParams {
160 Shape shape;
161 Strides strides;
162 size_t itemsize;
163 int64_t first;
164 int64_t last;
165 };
166
167 class GetDataRangeTest : public ::testing::TestWithParam<GetDataRangeTestParams> {};
168 INSTANTIATE_TEST_CASE_P(
169 GetDataRangeTest,
170 GetDataRangeTest,
171 ::testing::Values(
172 GetDataRangeTestParams{{2, 3, 4}, {96, 32, 8}, 8, 0, 192},
173 GetDataRangeTestParams{{10, 12}, {160, 8}, 8, 0, 1536},
174 GetDataRangeTestParams{{}, {}, 8, 0, 8},
175 GetDataRangeTestParams{{3, 0, 3}, {24, 24, 8}, 8, 0, 0},
176 GetDataRangeTestParams{{10, 3, 4}, {-96, 32, 8}, 8, -864, 96},
177 GetDataRangeTestParams{{10, 3, 4}, {-96, -32, -8}, 8, -952, 8},
178 GetDataRangeTestParams{{3, 4}, {8, 24}, 8, 0, 96},
179 GetDataRangeTestParams{{100}, {24}, 8, 0, 2384}));
180
// Calls GetDataRange with the parameter set's shape, strides and item size, and compares the
// returned pair with the expected (first, last) values.
TEST_P(GetDataRangeTest, GetDataRange) {
    auto param = GetParam();
    const std::tuple<int64_t, int64_t> expected = std::make_tuple(param.first, param.last);
    const std::tuple<int64_t, int64_t> actual = GetDataRange(param.shape, param.strides, param.itemsize);
    EXPECT_EQ(actual, expected);
}
186
187 } // namespace
188 } // namespace chainerx
189