#include "chainerx/strides.h"

#include <array>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <numeric>
#include <tuple>
#include <vector>

#include <absl/types/span.h>
#include <gtest/gtest.h>

#include "chainerx/axes.h"
10 
11 namespace chainerx {
12 namespace {
13 
CheckSpanEqual(std::initializer_list<int64_t> expect,absl::Span<const int64_t> actual)14 void CheckSpanEqual(std::initializer_list<int64_t> expect, absl::Span<const int64_t> actual) {
15     EXPECT_EQ(absl::MakeConstSpan(expect.begin(), expect.end()), actual);
16 }
17 
TEST(StridesTest,Ctor)18 TEST(StridesTest, Ctor) {
19     {  // Default ctor
20         const Strides strides{};
21         EXPECT_EQ(0, strides.ndim());
22         EXPECT_EQ(size_t{0}, strides.size());
23     }
24     {  // From std::initializer_list
25         const Strides strides{48, 16, 4};
26         EXPECT_EQ(3, strides.ndim());
27         EXPECT_EQ(size_t{3}, strides.size());
28         CheckSpanEqual({48, 16, 4}, strides.span());
29     }
30     {  // From span
31         const std::array<int64_t, 3> dims{48, 16, 4};
32         const Strides strides{absl::MakeConstSpan(dims)};
33         EXPECT_EQ(3, strides.ndim());
34         CheckSpanEqual({48, 16, 4}, strides.span());
35     }
36     {  // From iterators
37         const std::vector<int64_t> dims{48, 16, 4};
38         const Strides strides{dims.begin(), dims.end()};
39         EXPECT_EQ(3, strides.ndim());
40         CheckSpanEqual({48, 16, 4}, strides.span());
41     }
42     {  // From empty std::initializer_list
43         const Strides strides(std::initializer_list<int64_t>{});
44         EXPECT_EQ(0, strides.ndim());
45         CheckSpanEqual({}, strides.span());
46     }
47     {  // From empty span
48         const std::array<int64_t, 0> dims{};
49         const Strides strides{absl::MakeConstSpan(dims)};
50         EXPECT_EQ(0, strides.ndim());
51         CheckSpanEqual({}, strides.span());
52     }
53     {  // From empty iterators
54         const std::vector<int64_t> dims{};
55         const Strides strides{dims.begin(), dims.end()};
56         EXPECT_EQ(0, strides.ndim());
57         CheckSpanEqual({}, strides.span());
58     }
59     {  // From shape and element size
60         const Strides strides{{2, 3, 4}, 4};
61         EXPECT_EQ(3, strides.ndim());
62         EXPECT_EQ(size_t{3}, strides.size());
63         CheckSpanEqual({48, 16, 4}, strides.span());
64     }
65     {  // From shape and dtype
66         const Strides strides{{2, 3, 4}, Dtype::kInt32};
67         EXPECT_EQ(3, strides.ndim());
68         EXPECT_EQ(size_t{3}, strides.size());
69         CheckSpanEqual({48, 16, 4}, strides.span());
70     }
71     {  // Too long std::initializer_list
72         EXPECT_THROW(Strides({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), DimensionError);
73     }
74     {  // Too long span
75         const std::array<int64_t, kMaxNdim + 1> too_long{1};
76         EXPECT_THROW(Strides{absl::MakeConstSpan(too_long)}, DimensionError);
77     }
78     {  // Too long iterators
79         std::vector<int64_t> dims{};
80         dims.resize(kMaxNdim + 1);
81         std::iota(dims.begin(), dims.end(), int64_t{1});
82         EXPECT_THROW(Strides({dims.begin(), dims.end()}), DimensionError);
83     }
84 }
85 
// operator[] returns the stored strides; a negative index and the past-the-end index throw DimensionError.
TEST(StridesTest, Subscript) {
    const Strides strides{48, 16, 4};

    // Indices just outside the valid range on either side.
    EXPECT_THROW(strides[-1], DimensionError);
    EXPECT_THROW(strides[3], DimensionError);

    // Every valid index.
    EXPECT_EQ(48, strides[0]);
    EXPECT_EQ(16, strides[1]);
    EXPECT_EQ(4, strides[2]);
}
94 
// operator== / operator!= compare strides element-wise, including their length.
TEST(StridesTest, Compare) {
    const Strides base{48, 16, 4};
    const Strides same{48, 16, 4};
    const Strides shorter{48, 16};
    const Strides different{4, 8, 24};

    EXPECT_TRUE(base == same);
    // A prefix of different length is not equal.
    EXPECT_TRUE(base != shorter);
    // Same length, different values.
    EXPECT_TRUE(base != different);
}
112 
// CheckEqual is silent for equal strides and throws DimensionError when they differ.
TEST(StridesTest, CheckEqual) {
    const Strides reference{48, 16, 4};
    const Strides identical{48, 16, 4};
    const Strides empty{};

    // An exception escaping here fails the test, so this doubles as a "does not throw" check.
    CheckEqual(reference, identical);
    EXPECT_THROW(CheckEqual(reference, empty), DimensionError);
}
125 
// begin()/end() visit the strides in stored order; rbegin()/rend() visit them reversed.
TEST(StridesTest, Iterator) {
    const Strides strides{48, 16, 4};
    const std::vector<int64_t> forward(strides.begin(), strides.end());
    const std::vector<int64_t> backward(strides.rbegin(), strides.rend());

    CheckSpanEqual({48, 16, 4}, absl::MakeConstSpan(forward));
    CheckSpanEqual({4, 16, 48}, absl::MakeConstSpan(backward));
}
131 
// ToString renders a tuple-like string; a single element keeps the trailing comma.
TEST(StridesTest, ToString) {
    const Strides empty{};
    const Strides single{4};
    const Strides triple{48, 16, 4};

    EXPECT_EQ(empty.ToString(), "()");
    EXPECT_EQ(single.ToString(), "(4,)");
    EXPECT_EQ(triple.ToString(), "(48, 16, 4)");
}
146 
// absl::MakeConstSpan accepts a Strides directly and yields a view of its elements.
TEST(StridesTest, SpanFromStrides) {
    const Strides strides{2, 3, 4};
    const auto view = absl::MakeConstSpan(strides);
    CheckSpanEqual({2, 3, 4}, view);
}
151 
// Permute picks the strides of the given axes in order; axis 3 and axis -1 throw DimensionError here.
TEST(StridesTest, Permute) {
    const Strides strides{2, 3, 4};

    const auto permuted = strides.Permute(Axes{1, 2});
    CheckSpanEqual({3, 4}, permuted.span());

    EXPECT_THROW(strides.Permute(Axes{3}), DimensionError);
    EXPECT_THROW(strides.Permute(Axes{-1}), DimensionError);
}
158 
// One GetDataRange test case: the inputs (shape, strides, itemsize) and the expected result.
// Member order matters: the rows of the instantiation table are aggregate-initialized positionally.
struct GetDataRangeTestParams {
    Shape shape;
    Strides strides;
    size_t itemsize;
    // Expected first tuple element: in the table this is the lowest byte offset reached (may be negative).
    int64_t first;
    // Expected second tuple element: in the table this is one past the highest byte offset reached.
    int64_t last;
};
166 
// Value-parameterized fixture; each test instance receives one GetDataRangeTestParams row below.
class GetDataRangeTest : public ::testing::TestWithParam<GetDataRangeTestParams> {};
// Rows are {shape, strides, itemsize, first, last}. In every row `first` is the sum of
// (dim - 1) * stride over the negative strides, and `last` is the sum of (dim - 1) * stride
// over the positive strides plus itemsize (both 0 when any dimension is 0).
INSTANTIATE_TEST_CASE_P(
        GetDataRangeTest,
        GetDataRangeTest,
        ::testing::Values(
                // C-contiguous: 2 * 3 * 4 elements of 8 bytes.
                GetDataRangeTestParams{{2, 3, 4}, {96, 32, 8}, 8, 0, 192},
                // Padded rows (160 > 12 * 8): 9 * 160 + 11 * 8 + 8.
                GetDataRangeTestParams{{10, 12}, {160, 8}, 8, 0, 1536},
                // 0-dim shape: range of exactly one itemsize.
                GetDataRangeTestParams{{}, {}, 8, 0, 8},
                // A zero-length dimension gives an empty range.
                GetDataRangeTestParams{{3, 0, 3}, {24, 24, 8}, 8, 0, 0},
                // Negative leading stride: first = 9 * -96, last = 2 * 32 + 3 * 8 + 8.
                GetDataRangeTestParams{{10, 3, 4}, {-96, 32, 8}, 8, -864, 96},
                // All strides negative: first = 9 * -96 + 2 * -32 + 3 * -8, last = itemsize.
                GetDataRangeTestParams{{10, 3, 4}, {-96, -32, -8}, 8, -952, 8},
                // Column-major layout: 2 * 8 + 3 * 24 + 8.
                GetDataRangeTestParams{{3, 4}, {8, 24}, 8, 0, 96},
                // 1-D with gaps between elements (stride 24 > itemsize 8): 99 * 24 + 8.
                GetDataRangeTestParams{{100}, {24}, 8, 0, 2384}));
180 
// Checks GetDataRange(shape, strides, itemsize) against the expected (first, last) of each table row.
TEST_P(GetDataRangeTest, GetDataRange) {
    // GetParam() returns a const reference to the stored row; bind to it instead of copying Shape/Strides.
    const GetDataRangeTestParams& param = GetParam();
    const std::tuple<int64_t, int64_t> actual = GetDataRange(param.shape, param.strides, param.itemsize);
    // Expected value first, matching the EXPECT_EQ(expected, actual) convention used in this file.
    EXPECT_EQ(std::make_tuple(param.first, param.last), actual);
}
186 
187 }  // namespace
188 }  // namespace chainerx
189