#include "chainerx/scalar.h"

#include <cmath>
#include <cstdint>
#include <string>

#include <gtest/gtest.h>
4 
5 namespace chainerx {
6 namespace {
7 
// Checks the DtypeKind reported by Scalars built from each primitive type.
TEST(ScalarTest, Type) {
    // Booleans report the boolean kind.
    EXPECT_EQ(Scalar(true).kind(), DtypeKind::kBool);
    EXPECT_EQ(Scalar(false).kind(), DtypeKind::kBool);

    // Integral inputs, including uint8_t, all report the integer kind.
    EXPECT_EQ(Scalar(int8_t{1}).kind(), DtypeKind::kInt);
    EXPECT_EQ(Scalar(int16_t{2}).kind(), DtypeKind::kInt);
    EXPECT_EQ(Scalar(int32_t{3}).kind(), DtypeKind::kInt);
    EXPECT_EQ(Scalar(int64_t{4}).kind(), DtypeKind::kInt);
    EXPECT_EQ(Scalar(uint8_t{5}).kind(), DtypeKind::kInt);

    // Single and double precision both report the floating-point kind.
    EXPECT_EQ(Scalar(6.7f).kind(), DtypeKind::kFloat);
    EXPECT_EQ(Scalar(8.9).kind(), DtypeKind::kFloat);
}
19 
20 template <typename T1, typename T2>
ExpectScalarEqual(T1 value1,T2 value2)21 void ExpectScalarEqual(T1 value1, T2 value2) {
22     EXPECT_EQ(Scalar(value1), Scalar(value2));
23     EXPECT_EQ(Scalar(value2), Scalar(value1));
24 }
25 
// Expects a Scalar built from `value` to equal another Scalar built from the
// very same value and primitive type.
template <typename T>
void ExpectScalarSelfEqual(T value) {
    ExpectScalarEqual(value, value);
}

// Checks operator== across identical types, types of the same kind, and
// types of different kinds.
TEST(ScalarTest, Equality) {
    // Same primitive type: zero of every integral type.
    ExpectScalarSelfEqual(int8_t{0});
    ExpectScalarSelfEqual(int16_t{0});
    ExpectScalarSelfEqual(int32_t{0});
    ExpectScalarSelfEqual(int64_t{0});
    ExpectScalarSelfEqual(uint8_t{0});
    ExpectScalarSelfEqual(uint16_t{0});
    ExpectScalarSelfEqual(uint32_t{0});

    // Same primitive type: one of every integral type.
    ExpectScalarSelfEqual(int8_t{1});
    ExpectScalarSelfEqual(int16_t{1});
    ExpectScalarSelfEqual(int32_t{1});
    ExpectScalarSelfEqual(int64_t{1});
    ExpectScalarSelfEqual(uint8_t{1});
    ExpectScalarSelfEqual(uint16_t{1});
    ExpectScalarSelfEqual(uint32_t{1});

    // Same primitive type: floating point of both signs, then booleans.
    ExpectScalarSelfEqual(1.5);
    ExpectScalarSelfEqual(1.5f);
    ExpectScalarSelfEqual(-1.5);
    ExpectScalarSelfEqual(-1.5f);
    ExpectScalarSelfEqual(true);
    ExpectScalarSelfEqual(false);

    // Different primitive types and same kind
    ExpectScalarEqual(uint8_t{1}, int64_t{1});
    ExpectScalarEqual(uint8_t{1}, uint32_t{1});
    ExpectScalarEqual(int8_t{1}, int32_t{1});
    ExpectScalarEqual(1.5f, 1.5);

    // Different primitive types and different kinds: true/false behave like 1/0.
    ExpectScalarEqual(int32_t{1}, 1.0f);
    ExpectScalarEqual(true, int16_t{1});
    ExpectScalarEqual(false, int16_t{0});
    ExpectScalarEqual(false, 0.0f);
    ExpectScalarEqual(true, 1.0f);
}
62 
63 template <typename T1, typename T2>
ExpectScalarNotEqual(T1 value1,T2 value2)64 void ExpectScalarNotEqual(T1 value1, T2 value2) {
65     EXPECT_NE(Scalar(value1), Scalar(value2));
66     EXPECT_NE(Scalar(value2), Scalar(value1));
67 }
68 
// Checks operator!= across identical types, types of the same kind, and
// types of different kinds. Each case is filed under the section that matches
// its operand types.
TEST(ScalarTest, Inequality) {
    // Same primitive type
    ExpectScalarNotEqual(0, 1);
    ExpectScalarNotEqual(-1, 1);
    ExpectScalarNotEqual(-1.0001, -1.0);
    ExpectScalarNotEqual(true, false);
    ExpectScalarNotEqual(1.0001, 1.0002);
    // NaN must not compare equal, not even to itself, at either precision.
    ExpectScalarNotEqual(std::nan(""), std::nan(""));
    ExpectScalarNotEqual(std::nanf(""), std::nanf(""));

    // Different primitive types and same kind
    ExpectScalarNotEqual(int32_t{1}, int16_t{2});
    ExpectScalarNotEqual(uint8_t{1}, int8_t{2});
    ExpectScalarNotEqual(uint16_t{1}, uint8_t{2});
    // Values are compared, not bit patterns: 0xff and -1 share a byte but differ.
    ExpectScalarNotEqual(uint8_t{0xff}, int8_t{-1});
    ExpectScalarNotEqual(1.0f, 2.0);

    // Different primitive types and different kinds
    ExpectScalarNotEqual(int32_t{2}, 1.0);
    ExpectScalarNotEqual(-1.0001, -1);
    ExpectScalarNotEqual(true, 1.1);
    ExpectScalarNotEqual(true, int16_t{2});
    ExpectScalarNotEqual(true, int16_t{-1});
    ExpectScalarNotEqual(false, int16_t{2});
    ExpectScalarNotEqual(false, int16_t{-1});
    ExpectScalarNotEqual(false, 0.1);
    ExpectScalarNotEqual(true, 0.9);
    ExpectScalarNotEqual(true, -1);
    ExpectScalarNotEqual(true, std::nan(""));
    ExpectScalarNotEqual(false, std::nan(""));
}
99 
// Checks explicit conversion of Scalars to every supported primitive type,
// grouped by the kind of the source value.
TEST(ScalarTest, Cast) {
    // To bool: non-zero values of any kind are true, zeros are false.
    EXPECT_TRUE(static_cast<bool>(Scalar(true)));
    EXPECT_FALSE(static_cast<bool>(Scalar(false)));
    EXPECT_TRUE(static_cast<bool>(Scalar(1)));
    EXPECT_FALSE(static_cast<bool>(Scalar(0)));
    EXPECT_TRUE(static_cast<bool>(Scalar(-3.2)));
    EXPECT_FALSE(static_cast<bool>(Scalar(0.0f)));

    // Integer sources.
    EXPECT_EQ(static_cast<int8_t>(Scalar(1)), 1);
    EXPECT_EQ(static_cast<int16_t>(Scalar(-2)), -2);
    EXPECT_EQ(static_cast<int32_t>(Scalar(3)), 3);
    EXPECT_EQ(static_cast<int64_t>(Scalar(4)), 4);
    EXPECT_EQ(static_cast<uint8_t>(Scalar(5)), 5);
    EXPECT_FLOAT_EQ(static_cast<float>(Scalar(-6)), -6.0f);
    EXPECT_DOUBLE_EQ(static_cast<double>(Scalar(8)), 8.0);

    // Single-precision sources: integral targets drop the fractional part
    // (toward zero, e.g. -1.1f -> -1).
    EXPECT_EQ(static_cast<int8_t>(Scalar(-1.1f)), -1);
    EXPECT_EQ(static_cast<int16_t>(Scalar(2.2f)), 2);
    EXPECT_EQ(static_cast<int32_t>(Scalar(3.3f)), 3);
    EXPECT_EQ(static_cast<int64_t>(Scalar(4.4f)), 4);
    EXPECT_EQ(static_cast<uint8_t>(Scalar(5.5f)), 5);
    EXPECT_FLOAT_EQ(static_cast<float>(Scalar(6.7f)), 6.7f);
    EXPECT_DOUBLE_EQ(static_cast<double>(Scalar(-8.9f)), double{-8.9f});

    // Double-precision sources.
    EXPECT_EQ(static_cast<int8_t>(Scalar(1.1)), 1);
    EXPECT_EQ(static_cast<int16_t>(Scalar(2.2)), 2);
    EXPECT_EQ(static_cast<int32_t>(Scalar(-3.3)), -3);
    EXPECT_EQ(static_cast<int64_t>(Scalar(-4.4)), -4);
    EXPECT_EQ(static_cast<uint8_t>(Scalar(5.0)), 5);
    EXPECT_FLOAT_EQ(static_cast<float>(Scalar(6.7)), 6.7f);
    EXPECT_DOUBLE_EQ(static_cast<double>(Scalar(8.9)), 8.9);
}
137 
// Checks Scalar's unary minus and unary plus operators. The test belongs to
// the ScalarTest suite because it exercises Scalar, not Dtype.
TEST(ScalarTest, UnaryOps) {
    // Negating a boolean Scalar is rejected with DtypeError.
    EXPECT_THROW(-Scalar(true), DtypeError);
    EXPECT_THROW(-Scalar(false), DtypeError);

    // Unary minus
    EXPECT_EQ(static_cast<int8_t>(-Scalar(1)), -1);
    EXPECT_EQ(static_cast<int16_t>(-Scalar(2)), -2);
    EXPECT_EQ(static_cast<int32_t>(-Scalar(3)), -3);
    EXPECT_EQ(static_cast<int64_t>(-Scalar(4)), -4);
    // Converting a negative result to an unsigned type wraps like a plain C++ cast.
    EXPECT_EQ(static_cast<uint8_t>(-Scalar(5)), static_cast<uint8_t>(-5));
    EXPECT_FLOAT_EQ(static_cast<float>(-Scalar(6)), -6.0f);
    EXPECT_FLOAT_EQ(static_cast<float>(-Scalar(6.7f)), -6.7f);
    EXPECT_FLOAT_EQ(static_cast<float>(-Scalar(6.7)), -6.7f);
    EXPECT_DOUBLE_EQ(static_cast<double>(-Scalar(8)), -8.0);
    EXPECT_DOUBLE_EQ(static_cast<double>(-Scalar(8.9)), -8.9);

    // Unary plus
    EXPECT_EQ(static_cast<int8_t>(+Scalar(1)), 1);
    EXPECT_EQ(static_cast<int16_t>(+Scalar(2)), 2);
    EXPECT_EQ(static_cast<int32_t>(+Scalar(3)), 3);
    EXPECT_EQ(static_cast<int64_t>(+Scalar(4)), 4);
    EXPECT_EQ(static_cast<uint8_t>(+Scalar(5)), 5);
    // Floating-point results are compared with EXPECT_FLOAT_EQ/EXPECT_DOUBLE_EQ, never EXPECT_EQ.
    EXPECT_FLOAT_EQ(static_cast<float>(+Scalar(5)), 5.0f);
    EXPECT_FLOAT_EQ(static_cast<float>(+Scalar(6)), 6.0f);
    EXPECT_FLOAT_EQ(static_cast<float>(+Scalar(6.7f)), 6.7f);
    EXPECT_FLOAT_EQ(static_cast<float>(+Scalar(6.7)), 6.7f);
    EXPECT_DOUBLE_EQ(static_cast<double>(+Scalar(8)), 8.0);
    EXPECT_DOUBLE_EQ(static_cast<double>(+Scalar(8.9)), 8.9);
}
162 
// Checks the textual rendering produced by Scalar::ToString.
TEST(ScalarTest, ToString) {
    // Booleans are spelled with a leading capital letter.
    EXPECT_EQ("True", Scalar(true).ToString());
    EXPECT_EQ("False", Scalar(false).ToString());

    // Integral values render exactly as std::to_string renders the same value.
    EXPECT_EQ(std::to_string(int8_t{1}), Scalar(int8_t{1}).ToString());
    EXPECT_EQ(std::to_string(int16_t{2}), Scalar(int16_t{2}).ToString());
    EXPECT_EQ(std::to_string(int32_t{3}), Scalar(int32_t{3}).ToString());
    EXPECT_EQ(std::to_string(int64_t{4}), Scalar(int64_t{4}).ToString());
    EXPECT_EQ(std::to_string(uint8_t{5}), Scalar(uint8_t{5}).ToString());
}
172 
173 }  // namespace
174 }  // namespace chainerx
175