1 #include "chainerx/array_repr.h"
2
3 #include <cassert>
4 #include <cmath>
5 #include <cstdint>
6 #include <iomanip>
7 #include <iostream>
8 #include <memory>
9 #include <sstream>
10 #include <utility>
11 #include <vector>
12
13 #include "chainerx/array.h"
14 #include "chainerx/array_index.h"
15 #include "chainerx/array_node.h"
16 #include "chainerx/backprop_mode.h"
17 #include "chainerx/constant.h"
18 #include "chainerx/device.h"
19 #include "chainerx/dtype.h"
20 #include "chainerx/indexable_array.h"
21 #include "chainerx/indexer.h"
22 #include "chainerx/native/data_type.h"
23 #include "chainerx/numeric.h"
24 #include "chainerx/routines/indexing.h"
25 #include "chainerx/routines/manipulation.h"
26 #include "chainerx/shape.h"
27
28 namespace chainerx {
29
30 namespace {
31
// Counts the decimal digits of `value`; the sign is ignored because the division truncates towards zero.
// Zero yields 0 (not 1), so callers that need a minimum width must supply it themselves.
int GetNDigits(int64_t value) {
    int count = 0;
    for (int64_t rest = value; rest != 0; rest /= 10) {
        ++count;
    }
    return count;
}
40
// Writes the character `c` to `os` `n` times. Nothing is written when `n` is zero or negative.
void PrintNTimes(std::ostream& os, char c, int n) {
    for (int written = 0; written < n; ++written) {
        os << c;
    }
}
46
47 class IntFormatter {
48 public:
Scan(int64_t value)49 void Scan(int64_t value) {
50 int digits = 0;
51 if (value < 0) {
52 ++digits;
53 value = -value;
54 }
55 digits += GetNDigits(value);
56 if (max_digits_ < digits) {
57 max_digits_ = digits;
58 }
59 }
60
Print(std::ostream & os,int64_t value) const61 void Print(std::ostream& os, int64_t value) const { os << std::setw(max_digits_) << std::right << value; }
62
63 private:
64 int max_digits_ = 1;
65 };
66
67 class FloatFormatter {
68 public:
Scan(Float16 value)69 void Scan(Float16 value) { Scan(static_cast<double>(value)); }
70
Scan(double value)71 void Scan(double value) {
72 int b_digits = 0;
73 if (value < 0) {
74 has_minus_ = true;
75 ++b_digits;
76 value = -value;
77 }
78 if (IsInf(value) || IsNan(value)) {
79 b_digits += 3;
80 if (digits_before_point_ < b_digits) {
81 digits_before_point_ = b_digits;
82 }
83 return;
84 }
85 if (value >= 100'000'000) {
86 int e_digits = GetNDigits(static_cast<int64_t>(std::log10(value)));
87 if (digits_after_e_ < e_digits) {
88 digits_after_e_ = e_digits;
89 }
90 }
91 if (digits_after_e_ > 0) {
92 return;
93 }
94
95 const auto int_frac_parts = IntFracPartsToPrint(value);
96
97 b_digits += GetNDigits(int_frac_parts.first);
98 if (digits_before_point_ < b_digits) {
99 digits_before_point_ = b_digits;
100 }
101
102 const int a_digits = GetNDigits(int_frac_parts.second) - 1;
103 if (digits_after_point_ < a_digits) {
104 digits_after_point_ = a_digits;
105 }
106 }
107
Print(std::ostream & os,Float16 value)108 void Print(std::ostream& os, Float16 value) { Print(os, static_cast<double>(value)); }
109
Print(std::ostream & os,double value)110 void Print(std::ostream& os, double value) {
111 if (digits_after_e_ > 0) {
112 int width = 12 + (has_minus_ ? 1 : 0) + digits_after_e_;
113 if (has_minus_ && !std::signbit(value)) {
114 os << ' ';
115 --width;
116 }
117 os << std::scientific << std::left << std::setw(width) << std::setprecision(8) << value;
118 } else {
119 if (IsInf(value) || IsNan(value)) {
120 os << std::right << std::setw(digits_before_point_ + digits_after_point_ + 1) << value;
121 return;
122 }
123 const auto int_frac_parts = IntFracPartsToPrint(value);
124 const int a_digits = GetNDigits(int_frac_parts.second) - 1;
125 os << std::fixed << std::right << std::setw(digits_before_point_ + a_digits + 1) << std::setprecision(a_digits)
126 << std::showpoint << value;
127 PrintNTimes(os, ' ', digits_after_point_ - a_digits);
128 }
129 }
130
131 private:
132 // Returns the integral part and fractional part as integers.
133 // Note that the fractional part is prefixed by 1 so that the information of preceding zeros is not missed.
IntFracPartsToPrint(double value)134 static std::pair<int64_t, int64_t> IntFracPartsToPrint(double value) {
135 double int_part;
136 const double frac_part = std::modf(value, &int_part);
137
138 auto shifted_frac_part = static_cast<int64_t>((std::abs(frac_part) + 1) * 100'000'000);
139 while ((shifted_frac_part % 10) == 0) {
140 shifted_frac_part /= 10;
141 }
142
143 return {static_cast<int64_t>(int_part), shifted_frac_part};
144 }
145
146 int digits_before_point_ = 1;
147 int digits_after_point_ = 0;
148 int digits_after_e_ = 0;
149 bool has_minus_ = false;
150 };
151
// Formatter for the bool dtype. Both spellings are five characters wide, so no scanning pass is required.
class BoolFormatter {
public:
    // Nothing to collect; present only so that all formatters share the same interface.
    void Scan(bool value) { static_cast<void>(value); }

    // Writes " True" or "False" to `os`.
    void Print(std::ostream& os, bool value) const {
        const char* const text = value ? " True" : "False";  // NOLINTER
        os << text;
    }
};
160
161 template <typename T>
162 using Formatter = std::
163 conditional_t<std::is_same<T, bool>::value, BoolFormatter, std::conditional_t<IsFloatingPointV<T>, FloatFormatter, IntFormatter>>;
164
165 template <int8_t Ndim>
166 struct ArrayReprImpl {
167 template <typename T>
operator ()chainerx::__anonbeaca1150111::ArrayReprImpl168 void operator()(const Array& array, std::ostream& os) const {
169 Array native_array = array.AsGradStopped().ToNative();
170 Formatter<T> formatter;
171
172 // array should be already synchronized
173 // Let formatter scan all elements to print.
174 Indexer<Ndim> indexer{array.shape()};
175 IndexableArray<const T, Ndim> iarray{native_array};
176 for (auto it = indexer.It(0); it; ++it) {
177 T value = native::StorageToDataType<const T>(iarray[it]);
178 formatter.Scan(value);
179 }
180
181 os << "array(";
182 if (array.GetTotalSize() == 0) {
183 os << "[]";
184 } else {
185 NoBackpropModeScope scope{};
186 bool should_abbreviate = array.GetTotalSize() > kThreshold;
187 ArrayReprRecursive<T>(native_array, formatter, 7, os, should_abbreviate);
188 }
189
190 // Print the footer
191 os << ", shape=" << array.shape();
192 os << ", dtype=" << array.dtype();
193 os << ", device='" << array.device().name() << "'";
194 const std::vector<std::shared_ptr<internal::ArrayNode>>& array_nodes = internal::GetArrayBody(array)->nodes();
195 if (!array_nodes.empty()) {
196 os << ", backprop_ids=[";
197 for (size_t i = 0; i < array_nodes.size(); ++i) {
198 if (i > 0) {
199 os << ", ";
200 }
201 os << '\'' << array_nodes[i]->backprop_id() << '\'';
202 }
203 os << ']';
204 }
205 os << ')';
206 }
207
208 private:
209 static constexpr int kMaxItemNumPerLine = 10;
210 static constexpr int64_t kThreshold = 1000;
211 static constexpr int64_t kEdgeItems = 3;
212
213 // The behavior of this function is recursively defined as follows:
214 // array.ndim() == 0 => Returns a string represenation of the single scalar.
215 // array.ndim() == 1 => Returns a string: space separated scalars.
216 // array.ndim() >= 2 => Returns a string: newline separated arrays with one less dimension.
217 template <typename T>
ArrayReprRecursivechainerx::__anonbeaca1150111::ArrayReprImpl218 void ArrayReprRecursive(const Array& array, Formatter<T>& formatter, size_t indent, std::ostream& os, bool abbreviate = false) const {
219 const uint8_t ndim = array.ndim();
220 if (ndim == 0) {
221 formatter.Print(os, static_cast<T>(AsScalar(array)));
222 return;
223 }
224 auto print_indent = [ndim, indent, &os](int64_t i) {
225 if (i != 0) {
226 os << ",";
227 if (ndim > 1 || i % kMaxItemNumPerLine == 0) {
228 PrintNTimes(os, '\n', ndim - 1);
229 PrintNTimes(os, ' ', indent);
230 } else {
231 os << ' ';
232 }
233 }
234 };
235 os << "[";
236 int64_t size = array.shape().front();
237 if (abbreviate && size > kEdgeItems * 2) {
238 for (int64_t i = 0; i < kEdgeItems; ++i) {
239 print_indent(i);
240 ArrayReprRecursive<T>(internal::At(array, {ArrayIndex{i}}), formatter, indent + 1, os, abbreviate);
241 }
242 print_indent(1);
243 os << "...";
244 print_indent(1);
245 for (int64_t i = 0; i < kEdgeItems; ++i) {
246 print_indent(i);
247 ArrayReprRecursive<T>(internal::At(array, {ArrayIndex{i - kEdgeItems}}), formatter, indent + 1, os, abbreviate);
248 }
249 } else {
250 for (int64_t i = 0; i < size; ++i) {
251 print_indent(i);
252 ArrayReprRecursive<T>(internal::At(array, {ArrayIndex{i}}), formatter, indent + 1, os, abbreviate);
253 }
254 }
255 os << "]";
256 }
257 };
258
259 } // namespace
260
operator <<(std::ostream & os,const Array & array)261 std::ostream& operator<<(std::ostream& os, const Array& array) {
262 // TODO(hvy): We need to determine the output specification of this function, whether or not to align with Python repr specification,
263 // and also whether this functionality should be defined in C++ layer or Python layer.
264 // TODO(hvy): Consider using a static dimensionality.
265 VisitDtype(array.dtype(), [&os, &array](auto pt) { ArrayReprImpl<kDynamicNdim>{}.operator()<typename decltype(pt)::type>(array, os); });
266 return os;
267 }
268
// Returns the text that streaming `array` with operator<< produces.
std::string ArrayRepr(const Array& array) {
    std::ostringstream buffer;
    buffer << array;
    return buffer.str();
}
274
275 } // namespace chainerx
276