#include "chainerx/scalar.h"

#include <cmath>
#include <cstdint>
#include <string>

#include <gtest/gtest.h>
4
5 namespace chainerx {
6 namespace {
7
// Verifies that Scalar::kind() reports the DtypeKind matching the category of
// the primitive value the Scalar was constructed from.
TEST(ScalarTest, Type) {
    // Both boolean values map to the boolean kind.
    for (bool flag : {true, false}) {
        EXPECT_EQ(Scalar(flag).kind(), DtypeKind::kBool);
    }

    // Fixed-width integers, including uint8_t, map to the integer kind.
    EXPECT_EQ(Scalar(int8_t{1}).kind(), DtypeKind::kInt);
    EXPECT_EQ(Scalar(int16_t{2}).kind(), DtypeKind::kInt);
    EXPECT_EQ(Scalar(int32_t{3}).kind(), DtypeKind::kInt);
    EXPECT_EQ(Scalar(int64_t{4}).kind(), DtypeKind::kInt);
    EXPECT_EQ(Scalar(uint8_t{5}).kind(), DtypeKind::kInt);

    // Single- and double-precision values map to the floating kind.
    EXPECT_EQ(Scalar(6.7f).kind(), DtypeKind::kFloat);
    EXPECT_EQ(Scalar(8.9).kind(), DtypeKind::kFloat);
}
19
20 template <typename T1, typename T2>
ExpectScalarEqual(T1 value1,T2 value2)21 void ExpectScalarEqual(T1 value1, T2 value2) {
22 EXPECT_EQ(Scalar(value1), Scalar(value2));
23 EXPECT_EQ(Scalar(value2), Scalar(value1));
24 }
25
// Verifies Scalar equality for identical values of the same type, and for
// numerically equal values spanning different types and kinds.
TEST(ScalarTest, Equality) {
    // Same primitive type: every integer width, first with 0 and then with 1.
    for (int raw : {0, 1}) {
        ExpectScalarEqual(static_cast<int8_t>(raw), static_cast<int8_t>(raw));
        ExpectScalarEqual(static_cast<int16_t>(raw), static_cast<int16_t>(raw));
        ExpectScalarEqual(static_cast<int32_t>(raw), static_cast<int32_t>(raw));
        ExpectScalarEqual(static_cast<int64_t>(raw), static_cast<int64_t>(raw));
        ExpectScalarEqual(static_cast<uint8_t>(raw), static_cast<uint8_t>(raw));
        ExpectScalarEqual(static_cast<uint16_t>(raw), static_cast<uint16_t>(raw));
        ExpectScalarEqual(static_cast<uint32_t>(raw), static_cast<uint32_t>(raw));
    }

    // Same primitive type: each floating value in double, then float precision.
    for (double real : {1.5, -1.5}) {
        ExpectScalarEqual(real, real);
        ExpectScalarEqual(static_cast<float>(real), static_cast<float>(real));
    }

    // Same primitive type: booleans.
    for (bool flag : {true, false}) {
        ExpectScalarEqual(flag, flag);
    }

    // Different primitive types and same kind
    ExpectScalarEqual(uint8_t{1}, int64_t{1});
    ExpectScalarEqual(uint8_t{1}, uint32_t{1});
    ExpectScalarEqual(int8_t{1}, int32_t{1});
    ExpectScalarEqual(1.5f, 1.5);

    // Different primitive types and different kinds
    ExpectScalarEqual(int32_t{1}, 1.0f);
    ExpectScalarEqual(true, int16_t{1});
    ExpectScalarEqual(false, int16_t{0});
    ExpectScalarEqual(false, 0.0f);
    ExpectScalarEqual(true, 1.0f);
}
62
63 template <typename T1, typename T2>
ExpectScalarNotEqual(T1 value1,T2 value2)64 void ExpectScalarNotEqual(T1 value1, T2 value2) {
65 EXPECT_NE(Scalar(value1), Scalar(value2));
66 EXPECT_NE(Scalar(value2), Scalar(value1));
67 }
68
// Verifies that Scalars holding different values compare unequal in both
// operand orders. Cases are grouped by whether the operands share a primitive
// type, share only a kind, or differ in kind.
TEST(ScalarTest, Inequality) {
    // Same primitive type
    ExpectScalarNotEqual(0, 1);
    ExpectScalarNotEqual(-1, 1);
    ExpectScalarNotEqual(-1.0001, -1.0);
    ExpectScalarNotEqual(true, false);
    ExpectScalarNotEqual(1.0001, 1.0002);
    // NaN is expected to compare unequal even to another NaN.
    ExpectScalarNotEqual(std::nan(""), std::nan(""));

    // Different primitive types and same kind
    ExpectScalarNotEqual(int32_t{1}, int16_t{2});
    ExpectScalarNotEqual(uint8_t{1}, int8_t{2});
    ExpectScalarNotEqual(uint16_t{1}, uint8_t{2});
    // 0xff and -1 share a bit pattern in 8 bits but are different values.
    ExpectScalarNotEqual(uint8_t{0xff}, int8_t{-1});
    ExpectScalarNotEqual(1.0f, 2.0);

    // Different primitive types and different kinds
    // (These two previously sat under "Same primitive type" by mistake:
    // double vs int, and bool vs double.)
    ExpectScalarNotEqual(-1.0001, -1);
    ExpectScalarNotEqual(true, 1.1);
    ExpectScalarNotEqual(int32_t{2}, 1.0);
    ExpectScalarNotEqual(true, int16_t{2});
    ExpectScalarNotEqual(true, int16_t{-1});
    ExpectScalarNotEqual(false, int16_t{2});
    ExpectScalarNotEqual(false, int16_t{-1});
    ExpectScalarNotEqual(false, 0.1);
    ExpectScalarNotEqual(true, 0.9);
    ExpectScalarNotEqual(true, -1);
    ExpectScalarNotEqual(true, std::nan(""));
    ExpectScalarNotEqual(false, std::nan(""));
}
99
// Verifies explicit conversions out of Scalar. After the truthiness checks,
// expectations are grouped by the primitive type the Scalar was built from.
// Fractional values are expected to truncate toward zero when cast to an
// integer type.
TEST(ScalarTest, Cast) {
    // Truthiness: non-zero values are expected to convert to true, zeros to false.
    EXPECT_TRUE(static_cast<bool>(Scalar(true)));
    EXPECT_FALSE(static_cast<bool>(Scalar(false)));
    EXPECT_TRUE(static_cast<bool>(Scalar(1)));
    EXPECT_FALSE(static_cast<bool>(Scalar(0)));
    EXPECT_TRUE(static_cast<bool>(Scalar(-3.2)));
    EXPECT_FALSE(static_cast<bool>(Scalar(0.0f)));

    // Built from int.
    EXPECT_EQ(static_cast<int8_t>(Scalar(1)), 1);
    EXPECT_EQ(static_cast<int16_t>(Scalar(-2)), -2);
    EXPECT_EQ(static_cast<int32_t>(Scalar(3)), 3);
    EXPECT_EQ(static_cast<int64_t>(Scalar(4)), 4);
    EXPECT_EQ(static_cast<uint8_t>(Scalar(5)), 5);
    EXPECT_FLOAT_EQ(static_cast<float>(Scalar(-6)), -6.0f);
    EXPECT_DOUBLE_EQ(static_cast<double>(Scalar(8)), 8.0);

    // Built from float.
    EXPECT_EQ(static_cast<int8_t>(Scalar(-1.1f)), -1);
    EXPECT_EQ(static_cast<int16_t>(Scalar(2.2f)), 2);
    EXPECT_EQ(static_cast<int32_t>(Scalar(3.3f)), 3);
    EXPECT_EQ(static_cast<int64_t>(Scalar(4.4f)), 4);
    EXPECT_EQ(static_cast<uint8_t>(Scalar(5.5f)), 5);
    EXPECT_FLOAT_EQ(static_cast<float>(Scalar(6.7f)), 6.7f);
    // Widening must preserve the float's exact value, not the decimal 8.9.
    EXPECT_DOUBLE_EQ(static_cast<double>(Scalar(-8.9f)), double{-8.9f});

    // Built from double.
    EXPECT_EQ(static_cast<int8_t>(Scalar(1.1)), 1);
    EXPECT_EQ(static_cast<int16_t>(Scalar(2.2)), 2);
    EXPECT_EQ(static_cast<int32_t>(Scalar(-3.3)), -3);
    EXPECT_EQ(static_cast<int64_t>(Scalar(-4.4)), -4);
    EXPECT_EQ(static_cast<uint8_t>(Scalar(5.0)), 5);
    EXPECT_FLOAT_EQ(static_cast<float>(Scalar(6.7)), 6.7f);
    EXPECT_DOUBLE_EQ(static_cast<double>(Scalar(8.9)), 8.9);
}
137
// Verifies unary minus and unary plus on Scalar. Negating a boolean Scalar is
// expected to throw DtypeError; numeric results are checked after an explicit
// cast to each target type.
// (Suite renamed from the copy-pasted "DtypeTest" to "ScalarTest" to match the
// other tests in this file.)
TEST(ScalarTest, UnaryOps) {
    // Negation is rejected for booleans.
    EXPECT_THROW(-Scalar(true), DtypeError);
    EXPECT_THROW(-Scalar(false), DtypeError);

    // Negation of numeric values.
    EXPECT_EQ(static_cast<int8_t>(-Scalar(1)), -1);
    EXPECT_EQ(static_cast<int16_t>(-Scalar(2)), -2);
    EXPECT_EQ(static_cast<int32_t>(-Scalar(3)), -3);
    EXPECT_EQ(static_cast<int64_t>(-Scalar(4)), -4);
    // Expected to match the built-in conversion of -5 to uint8_t (i.e. 251).
    EXPECT_EQ(static_cast<uint8_t>(-Scalar(5)), static_cast<uint8_t>(-5));
    EXPECT_FLOAT_EQ(static_cast<float>(-Scalar(6)), -6.0f);
    EXPECT_FLOAT_EQ(static_cast<float>(-Scalar(6.7)), -6.7f);
    EXPECT_DOUBLE_EQ(static_cast<double>(-Scalar(8)), -8.0);
    EXPECT_DOUBLE_EQ(static_cast<double>(-Scalar(8.9)), -8.9);

    // Unary plus leaves the value unchanged.
    EXPECT_EQ(static_cast<int8_t>(+Scalar(1)), 1);
    EXPECT_EQ(static_cast<int16_t>(+Scalar(2)), 2);
    EXPECT_EQ(static_cast<int32_t>(+Scalar(3)), 3);
    EXPECT_EQ(static_cast<int64_t>(+Scalar(4)), 4);
    EXPECT_EQ(static_cast<uint8_t>(+Scalar(5)), 5);
    // Floating results use the float-aware comparison, consistent with the
    // rest of this file (was an exact EXPECT_EQ against an int).
    EXPECT_FLOAT_EQ(static_cast<float>(+Scalar(5)), 5.0f);
    EXPECT_FLOAT_EQ(static_cast<float>(+Scalar(6)), 6.0f);
    EXPECT_FLOAT_EQ(static_cast<float>(+Scalar(6.7)), 6.7f);
    EXPECT_DOUBLE_EQ(static_cast<double>(+Scalar(8)), 8.0);
    EXPECT_DOUBLE_EQ(static_cast<double>(+Scalar(8.9)), 8.9);
}
162
// Verifies Scalar::ToString(): booleans render as "True"/"False", and integers
// render as their plain decimal representation.
TEST(ScalarTest, ToString) {
    // Boolean spellings.
    EXPECT_EQ(Scalar(true).ToString(), "True");
    EXPECT_EQ(Scalar(false).ToString(), "False");

    // Integers of each width print as plain decimal digits.
    EXPECT_EQ(Scalar(int8_t{1}).ToString(), "1");
    EXPECT_EQ(Scalar(int16_t{2}).ToString(), "2");
    EXPECT_EQ(Scalar(int32_t{3}).ToString(), "3");
    EXPECT_EQ(Scalar(int64_t{4}).ToString(), "4");
    EXPECT_EQ(Scalar(uint8_t{5}).ToString(), "5");
}
172
173 } // namespace
174 } // namespace chainerx
175