1 #include "chainerx/array_repr.h"
2 
3 #include <cassert>
4 #include <cmath>
5 #include <cstdint>
6 #include <iomanip>
7 #include <iostream>
8 #include <memory>
9 #include <sstream>
10 #include <utility>
11 #include <vector>
12 
13 #include "chainerx/array.h"
14 #include "chainerx/array_index.h"
15 #include "chainerx/array_node.h"
16 #include "chainerx/backprop_mode.h"
17 #include "chainerx/constant.h"
18 #include "chainerx/device.h"
19 #include "chainerx/dtype.h"
20 #include "chainerx/indexable_array.h"
21 #include "chainerx/indexer.h"
22 #include "chainerx/native/data_type.h"
23 #include "chainerx/numeric.h"
24 #include "chainerx/routines/indexing.h"
25 #include "chainerx/routines/manipulation.h"
26 #include "chainerx/shape.h"
27 
28 namespace chainerx {
29 
30 namespace {
31 
// Returns how many decimal digits `value` has, ignoring its sign; zero has no digits and yields 0.
// Negative inputs need no special handling because integer division truncates toward zero.
int GetNDigits(int64_t value) {
    int count = 0;
    for (int64_t rest = value; rest != 0; rest /= 10) {
        ++count;
    }
    return count;
}
40 
// Writes the character `c` to `os` exactly `n` times; nothing is written when `n` is zero or negative.
void PrintNTimes(std::ostream& os, char c, int n) {
    for (int written = 0; written < n; ++written) {
        os << c;
    }
}
46 
47 class IntFormatter {
48 public:
Scan(int64_t value)49     void Scan(int64_t value) {
50         int digits = 0;
51         if (value < 0) {
52             ++digits;
53             value = -value;
54         }
55         digits += GetNDigits(value);
56         if (max_digits_ < digits) {
57             max_digits_ = digits;
58         }
59     }
60 
Print(std::ostream & os,int64_t value) const61     void Print(std::ostream& os, int64_t value) const { os << std::setw(max_digits_) << std::right << value; }
62 
63 private:
64     int max_digits_ = 1;
65 };
66 
67 class FloatFormatter {
68 public:
Scan(Float16 value)69     void Scan(Float16 value) { Scan(static_cast<double>(value)); }
70 
Scan(double value)71     void Scan(double value) {
72         int b_digits = 0;
73         if (value < 0) {
74             has_minus_ = true;
75             ++b_digits;
76             value = -value;
77         }
78         if (IsInf(value) || IsNan(value)) {
79             b_digits += 3;
80             if (digits_before_point_ < b_digits) {
81                 digits_before_point_ = b_digits;
82             }
83             return;
84         }
85         if (value >= 100'000'000) {
86             int e_digits = GetNDigits(static_cast<int64_t>(std::log10(value)));
87             if (digits_after_e_ < e_digits) {
88                 digits_after_e_ = e_digits;
89             }
90         }
91         if (digits_after_e_ > 0) {
92             return;
93         }
94 
95         const auto int_frac_parts = IntFracPartsToPrint(value);
96 
97         b_digits += GetNDigits(int_frac_parts.first);
98         if (digits_before_point_ < b_digits) {
99             digits_before_point_ = b_digits;
100         }
101 
102         const int a_digits = GetNDigits(int_frac_parts.second) - 1;
103         if (digits_after_point_ < a_digits) {
104             digits_after_point_ = a_digits;
105         }
106     }
107 
Print(std::ostream & os,Float16 value)108     void Print(std::ostream& os, Float16 value) { Print(os, static_cast<double>(value)); }
109 
Print(std::ostream & os,double value)110     void Print(std::ostream& os, double value) {
111         if (digits_after_e_ > 0) {
112             int width = 12 + (has_minus_ ? 1 : 0) + digits_after_e_;
113             if (has_minus_ && !std::signbit(value)) {
114                 os << ' ';
115                 --width;
116             }
117             os << std::scientific << std::left << std::setw(width) << std::setprecision(8) << value;
118         } else {
119             if (IsInf(value) || IsNan(value)) {
120                 os << std::right << std::setw(digits_before_point_ + digits_after_point_ + 1) << value;
121                 return;
122             }
123             const auto int_frac_parts = IntFracPartsToPrint(value);
124             const int a_digits = GetNDigits(int_frac_parts.second) - 1;
125             os << std::fixed << std::right << std::setw(digits_before_point_ + a_digits + 1) << std::setprecision(a_digits)
126                << std::showpoint << value;
127             PrintNTimes(os, ' ', digits_after_point_ - a_digits);
128         }
129     }
130 
131 private:
132     // Returns the integral part and fractional part as integers.
133     // Note that the fractional part is prefixed by 1 so that the information of preceding zeros is not missed.
IntFracPartsToPrint(double value)134     static std::pair<int64_t, int64_t> IntFracPartsToPrint(double value) {
135         double int_part;
136         const double frac_part = std::modf(value, &int_part);
137 
138         auto shifted_frac_part = static_cast<int64_t>((std::abs(frac_part) + 1) * 100'000'000);
139         while ((shifted_frac_part % 10) == 0) {
140             shifted_frac_part /= 10;
141         }
142 
143         return {static_cast<int64_t>(int_part), shifted_frac_part};
144     }
145 
146     int digits_before_point_ = 1;
147     int digits_after_point_ = 0;
148     int digits_after_e_ = 0;
149     bool has_minus_ = false;
150 };
151 
// Formats boolean elements. Both spellings are five characters wide (" True" carries a leading space), so the scan
// pass has nothing to record.
class BoolFormatter {
public:
    void Scan(bool /*value*/) {}

    void Print(std::ostream& os, bool value) const {
        const char* const text = value ? " True" : "False";  // NOLINTER
        os << text;
    }
};
160 
161 template <typename T>
162 using Formatter = std::
163         conditional_t<std::is_same<T, bool>::value, BoolFormatter, std::conditional_t<IsFloatingPointV<T>, FloatFormatter, IntFormatter>>;
164 
165 template <int8_t Ndim>
166 struct ArrayReprImpl {
167     template <typename T>
operator ()chainerx::__anonbeaca1150111::ArrayReprImpl168     void operator()(const Array& array, std::ostream& os) const {
169         Array native_array = array.AsGradStopped().ToNative();
170         Formatter<T> formatter;
171 
172         // array should be already synchronized
173         // Let formatter scan all elements to print.
174         Indexer<Ndim> indexer{array.shape()};
175         IndexableArray<const T, Ndim> iarray{native_array};
176         for (auto it = indexer.It(0); it; ++it) {
177             T value = native::StorageToDataType<const T>(iarray[it]);
178             formatter.Scan(value);
179         }
180 
181         os << "array(";
182         if (array.GetTotalSize() == 0) {
183             os << "[]";
184         } else {
185             NoBackpropModeScope scope{};
186             bool should_abbreviate = array.GetTotalSize() > kThreshold;
187             ArrayReprRecursive<T>(native_array, formatter, 7, os, should_abbreviate);
188         }
189 
190         // Print the footer
191         os << ", shape=" << array.shape();
192         os << ", dtype=" << array.dtype();
193         os << ", device='" << array.device().name() << "'";
194         const std::vector<std::shared_ptr<internal::ArrayNode>>& array_nodes = internal::GetArrayBody(array)->nodes();
195         if (!array_nodes.empty()) {
196             os << ", backprop_ids=[";
197             for (size_t i = 0; i < array_nodes.size(); ++i) {
198                 if (i > 0) {
199                     os << ", ";
200                 }
201                 os << '\'' << array_nodes[i]->backprop_id() << '\'';
202             }
203             os << ']';
204         }
205         os << ')';
206     }
207 
208 private:
209     static constexpr int kMaxItemNumPerLine = 10;
210     static constexpr int64_t kThreshold = 1000;
211     static constexpr int64_t kEdgeItems = 3;
212 
213     // The behavior of this function is recursively defined as follows:
214     // array.ndim() == 0 => Returns a string represenation of the single scalar.
215     // array.ndim() == 1 => Returns a string: space separated scalars.
216     // array.ndim() >= 2 => Returns a string: newline separated arrays with one less dimension.
217     template <typename T>
ArrayReprRecursivechainerx::__anonbeaca1150111::ArrayReprImpl218     void ArrayReprRecursive(const Array& array, Formatter<T>& formatter, size_t indent, std::ostream& os, bool abbreviate = false) const {
219         const uint8_t ndim = array.ndim();
220         if (ndim == 0) {
221             formatter.Print(os, static_cast<T>(AsScalar(array)));
222             return;
223         }
224         auto print_indent = [ndim, indent, &os](int64_t i) {
225             if (i != 0) {
226                 os << ",";
227                 if (ndim > 1 || i % kMaxItemNumPerLine == 0) {
228                     PrintNTimes(os, '\n', ndim - 1);
229                     PrintNTimes(os, ' ', indent);
230                 } else {
231                     os << ' ';
232                 }
233             }
234         };
235         os << "[";
236         int64_t size = array.shape().front();
237         if (abbreviate && size > kEdgeItems * 2) {
238             for (int64_t i = 0; i < kEdgeItems; ++i) {
239                 print_indent(i);
240                 ArrayReprRecursive<T>(internal::At(array, {ArrayIndex{i}}), formatter, indent + 1, os, abbreviate);
241             }
242             print_indent(1);
243             os << "...";
244             print_indent(1);
245             for (int64_t i = 0; i < kEdgeItems; ++i) {
246                 print_indent(i);
247                 ArrayReprRecursive<T>(internal::At(array, {ArrayIndex{i - kEdgeItems}}), formatter, indent + 1, os, abbreviate);
248             }
249         } else {
250             for (int64_t i = 0; i < size; ++i) {
251                 print_indent(i);
252                 ArrayReprRecursive<T>(internal::At(array, {ArrayIndex{i}}), formatter, indent + 1, os, abbreviate);
253             }
254         }
255         os << "]";
256     }
257 };
258 
259 }  // namespace
260 
operator <<(std::ostream & os,const Array & array)261 std::ostream& operator<<(std::ostream& os, const Array& array) {
262     // TODO(hvy): We need to determine the output specification of this function, whether or not to align with Python repr specification,
263     // and also whether this functionality should be defined in C++ layer or Python layer.
264     // TODO(hvy): Consider using a static dimensionality.
265     VisitDtype(array.dtype(), [&os, &array](auto pt) { ArrayReprImpl<kDynamicNdim>{}.operator()<typename decltype(pt)::type>(array, os); });
266     return os;
267 }
268 
// Returns the same text that operator<< writes for `array`.
std::string ArrayRepr(const Array& array) {
    std::ostringstream out{};
    out << array;
    std::string repr = out.str();
    return repr;
}
274 
275 }  // namespace chainerx
276