#include "chainerx/routines/creation.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

#include <absl/types/optional.h>

#include "chainerx/array.h"
#include "chainerx/backend.h"
#include "chainerx/backprop_mode.h"
#include "chainerx/backward_builder.h"
#include "chainerx/backward_context.h"
#include "chainerx/constant.h"
#include "chainerx/device.h"
#include "chainerx/dtype.h"
#include "chainerx/graph.h"
#include "chainerx/kernels/creation.h"
#include "chainerx/kernels/misc.h"
#include "chainerx/macro.h"
#include "chainerx/routines/indexing.h"
#include "chainerx/routines/type_util.h"
#include "chainerx/scalar.h"
#include "chainerx/shape.h"
#include "chainerx/strides.h"

namespace chainerx {
namespace internal {

size_t GetRequiredBytes(const Shape& shape, const Strides& strides, size_t item_size) {
    CHAINERX_ASSERT(shape.ndim() == strides.ndim());

    if (shape.GetTotalSize() == 0) {
        return 0;
    }

    // Calculate the distance between the first and the last element, plus single element size.
    size_t n_bytes = item_size;
    for (int8_t i = 0; i < shape.ndim(); ++i) {
        n_bytes += (shape[i] - 1) * std::abs(strides[i]);
    }
    return n_bytes;
}

Array FromHostData(
        const Shape& shape, Dtype dtype, const std::shared_ptr<void>& data, const Strides& strides, int64_t offset, Device& device) {
    auto range = GetDataRange(shape, strides, GetItemSize(dtype));
    // TODO(niboshi): Copy only required region. Currently the whole preceding (offset) region is copied.
    std::shared_ptr<void> device_data = device.FromHostMemory(data, offset + std::get<1>(range));
    return internal::MakeArray(shape, strides, dtype, device, std::move(device_data), offset);
}

Array Empty(const Shape& shape, Dtype dtype, const Strides& strides, Device& device) {
    auto bytesize = GetRequiredBytes(shape, strides, GetItemSize(dtype));
    std::shared_ptr<void> data = device.Allocate(bytesize);
    return MakeArray(shape, strides, dtype, device, std::move(data));
}

Array EmptyReduced(const Shape& shape, Dtype dtype, const Axes& axes, bool keepdims, Device& device) {
    Shape out_shape = ReduceShape(shape, axes, keepdims);
    if (!keepdims) {
        return Empty(out_shape, dtype, device);
    }
    // Set reduced strides of the output array to 0
    Strides out_strides{out_shape, dtype};
    for (int8_t axis : axes) {
        out_strides[axis] = 0;
    }
    return Empty(out_shape, dtype, out_strides, device);
}

}  // namespace internal

Array FromContiguousHostData(const Shape& shape, Dtype dtype, const std::shared_ptr<void>& data, Device& device) {
    return internal::FromHostData(shape, dtype, data, {shape, dtype}, 0, device);
}

Array FromData(
        const Shape& shape,
        Dtype dtype,
        const std::shared_ptr<void>& data,
        const absl::optional<Strides>& strides,
        int64_t offset,
        Device& device) {
    return internal::MakeArray(
            shape, strides.value_or(Strides{shape, dtype}), dtype, device, device.MakeDataFromForeignPointer(data), offset);
}

Array Empty(const Shape& shape, Dtype dtype, Device& device) {
    auto bytesize = static_cast<size_t>(shape.GetTotalSize() * GetItemSize(dtype));
    std::shared_ptr<void> data = device.Allocate(bytesize);
    return internal::MakeArray(shape, Strides{shape, dtype}, dtype, device, std::move(data));
}

Array Full(const Shape& shape, Scalar fill_value, Dtype dtype, Device& device) {
    Array array = Empty(shape, dtype, device);
    array.Fill(fill_value);
    return array;
}

Array Full(const Shape& shape, Scalar fill_value, Device& device) {
    return Full(shape, fill_value, internal::GetDefaultDtype(fill_value.kind()), device);
}

Array Zeros(const Shape& shape, Dtype dtype, Device& device) { return Full(shape, 0, dtype, device); }

Array Ones(const Shape& shape, Dtype dtype, Device& device) { return Full(shape, 1, dtype, device); }

Array Arange(Scalar start, Scalar stop, Scalar step, Dtype dtype, Device& device) {
    // TODO(hvy): Simplify comparison if Scalar::operator== supports dtype conversion.
    if (static_cast<double>(step) == 0.0) {
        throw ChainerxError("Cannot create an arange array with 0 step size.");
    }

    // Compute the size of the output.
    auto start_value = static_cast<double>(start);
    auto stop_value = static_cast<double>(stop);
    auto step_value = static_cast<double>(step);
    if (step_value < 0) {
        std::swap(start_value, stop_value);
        step_value *= -1;
    }
    auto size = std::max(int64_t{0}, static_cast<int64_t>(std::ceil((stop_value - start_value) / step_value)));
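    // For example, start = 0, stop = 5, step = 2 gives ceil((5 - 0) / 2) == 3 elements: {0, 2, 4}.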
    if (size > 2 && dtype == Dtype::kBool) {
        throw DtypeError{"Cannot create an arange array of booleans with size larger than 2."};
    }

    Array out = Empty({size}, dtype, device);
    device.backend().CallKernel<ArangeKernel>(start, step, out);
    return out;
}

Array Arange(Scalar start, Scalar stop, Scalar step, Device& device) {
    // TODO(hvy): Type promote instead of using the dtype of step.
    return Arange(start, stop, step, internal::GetDefaultDtype(step.kind()), device);
}

Array Arange(Scalar start, Scalar stop, Dtype dtype, Device& device) { return Arange(start, stop, 1, dtype, device); }

Array Arange(Scalar start, Scalar stop, Device& device) {
    // TODO(hvy): Type promote dtype instead of using the dtype of stop.
    return Arange(start, stop, 1, internal::GetDefaultDtype(stop.kind()), device);
}

Array Arange(Scalar stop, Dtype dtype, Device& device) { return Arange(0, stop, 1, dtype, device); }

Array Arange(Scalar stop, Device& device) { return Arange(0, stop, 1, internal::GetDefaultDtype(stop.kind()), device); }

Array EmptyLike(const Array& a, Device& device) { return Empty(a.shape(), a.dtype(), device); }

Array FullLike(const Array& a, Scalar fill_value, Device& device) { return Full(a.shape(), fill_value, a.dtype(), device); }

Array ZerosLike(const Array& a, Device& device) { return Zeros(a.shape(), a.dtype(), device); }

Array OnesLike(const Array& a, Device& device) { return Ones(a.shape(), a.dtype(), device); }

Array Copy(const Array& a) {
    Array out = EmptyLike(a, a.device());
    {
        NoBackpropModeScope scope{};
        a.device().backend().CallKernel<CopyKernel>(a, out);
    }

    BackwardBuilder bb{"copy", a, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([](BackwardContext& bctx) { bctx.input_grad() = *bctx.output_grad(); });
    }
    bb.Finalize();

    CHAINERX_ASSERT(out.IsContiguous());
    return out;
}

// Creates the identity array.
Array Identity(int64_t n, Dtype dtype, Device& device) {
    if (n < 0) {
        throw DimensionError{"Negative dimensions are not allowed"};
    }

    Array out = Empty(Shape{n, n}, dtype, device);
    {
        NoBackpropModeScope scope{};
        device.backend().CallKernel<IdentityKernel>(out);
    }
    return out;
}

Array Eye(int64_t n, absl::optional<int64_t> m, absl::optional<int64_t> k, absl::optional<Dtype> dtype, Device& device) {
    if (!m.has_value()) {
        m = n;
    }
    if (!k.has_value()) {
        k = 0;
    }
    if (!dtype.has_value()) {
        dtype = Dtype::kFloat64;
    }
    if (n < 0 || m < 0) {
        throw DimensionError{"Negative dimensions are not allowed"};
    }

    Array out = Empty({n, m.value()}, dtype.value(), device);
    {
        NoBackpropModeScope scope{};
        device.backend().CallKernel<EyeKernel>(k.value(), out);
    }
    return out;
}

Array AsContiguous(const Array& a, Dtype dtype) {
    if (a.IsContiguous() && a.dtype() == dtype) {
        return a;
    }

    Array out = Empty(a.shape(), dtype, a.device());
    {
        NoBackpropModeScope scope{};
        // Note: In CopyKernel, input array elements are cast to the dtype of the output array.
        a.device().backend().CallKernel<CopyKernel>(a.AsGradStopped(), out);
    }

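    // Define the backward pass only for floating-point outputs; gradients are not defined for
    // non-float dtypes.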
    if (GetKind(dtype) == DtypeKind::kFloat) {
        BackwardBuilder bb{"ascontiguousarray", a, out};
        if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
            bt.Define([src_dtype = a.dtype()](BackwardContext& bctx) {
                const Array& gout = *bctx.output_grad();
                bctx.input_grad() = gout.AsType(src_dtype, false);
            });
        }
        bb.Finalize();
    }

    CHAINERX_ASSERT(out.IsContiguous());
    CHAINERX_ASSERT(out.shape() == a.shape());
    CHAINERX_ASSERT(out.dtype() == dtype);
    return out;
}

Array AsContiguousArray(const Array& a, absl::optional<Dtype> dtype) {
    Dtype src_dt = a.dtype();
    Dtype dt = dtype.value_or(src_dt);

    if (a.IsContiguous() && src_dt == dt) {
        if (a.ndim() == 0) {
            return a.Reshape(Shape{1});
        }
        return a;
    }

    Array out = AsContiguous(a, dt);
    if (a.ndim() == 0) {
        out = out.Reshape({1});
    }
    return out;
}

Array Diag(const Array& v, int64_t k) {
    Array out{};
    Device& device = v.device();

    int8_t ndim = v.ndim();
    if (ndim == 1) {
        // Return a square matrix with filled diagonal.
        int64_t n = v.shape()[0] + std::abs(k);
        out = Empty(Shape{n, n}, v.dtype(), device);
        {
            NoBackpropModeScope scope{};
            device.backend().CallKernel<DiagflatKernel>(v, k, out);
        }
    } else if (ndim == 2) {
        // Return the diagonal as a 1D array.
        int64_t rows = v.shape()[0];
        int64_t cols = v.shape()[1];
        int64_t n = std::min(rows, cols);
        int64_t offset{};
        if (k >= 0) {
            offset = k * v.strides()[1];
            if (cols <= k + n - 1) {
                n = std::max(int64_t{0}, cols - k);
            }
        } else {
            offset = -k * v.strides()[0];
            if (rows <= -k + n - 1) {
                n = std::max(int64_t{0}, rows + k);
            }
        }
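        // The diagonal is returned as a view into v's data: it starts at element (0, k) for k >= 0
        // (or (-k, 0) for k < 0) and advances by strides[0] + strides[1] per step, i.e. one row down
        // and one column to the right.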
        out = internal::MakeArray(Shape{n}, Strides{v.strides()[0] + v.strides()[1]}, v.dtype(), device, v.data(), v.offset() + offset);
    } else {
        throw DimensionError{"Input must be 1D or 2D."};
    }

    BackwardBuilder bb{"diag", v, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([k](BackwardContext& bctx) {
            const Array& gout = *bctx.output_grad();
            bctx.input_grad() = Diag(gout, k);
        });
    }
    bb.Finalize();

    return out;
}

Array Diagflat(const Array& v, int64_t k) {
    // TODO(hvy): Use Ravel or Flatten when implemented instead of Reshape.
    return Diag(v.Reshape({v.GetTotalSize()}), k);
}

// Creates a 1-d array with evenly spaced numbers.
Array Linspace(Scalar start, Scalar stop, absl::optional<int64_t> num, bool endpoint, absl::optional<Dtype> dtype, Device& device) {
    static const int64_t kDefaultNum = 50;

    // Always default to float type.
    // Similar behavior to numpy
    // Ref: https://github.com/numpy/numpy/issues/8597
    Dtype dtype_a = dtype.value_or(internal::GetDefaultDtype(chainerx::DtypeKind::kFloat));
    int64_t num_a = num.value_or(kDefaultNum);

    if (num_a < 0) {
        throw ChainerxError{"Number of samples, ", num_a, ", must be non-negative"};
    }

    Array out = Empty(Shape{num_a}, dtype_a, device);
    if (num_a > 0) {
        auto start_value = static_cast<double>(start);
        auto stop_value = static_cast<double>(stop);
        if (!endpoint) {
            stop_value = start_value + (stop_value - start_value) * (num_a - 1) / num_a;
        }
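        // For example, start = 0, stop = 10, num = 5 with endpoint = false adjusts the stop passed
        // to the kernel to 0 + 10 * 4 / 5 == 8, so the kernel fills {0, 2, 4, 6, 8}.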
        {
            NoBackpropModeScope scope{};
            device.backend().CallKernel<LinspaceKernel>(start_value, stop_value, out);
        }
    }
    return out;
}

std::vector<Array> Meshgrid(const std::vector<Array>& arrays, MeshgridIndexingMode mode) {
    Shape shape;
    Shape broadcast_shape;
    std::vector<Shape> broadcasted_array_shapes;
    std::vector<Array> grid_arrays;

    // Special cases, following NumPy's behavior.
    if (arrays.empty()) {
        return grid_arrays;
    }

    if (arrays.size() == 1) {
        grid_arrays.emplace_back(arrays[0].Flatten());
        return grid_arrays;
    }

    grid_arrays.reserve(arrays.size());
    broadcasted_array_shapes.reserve(arrays.size());

    // Algorithm
    //
    // Step 1: Reshape/view each array so that it is broadcastable based on the
    // number of input vectors.
    // E.g. for a tuple of vectors (n1, n2, n3), where ni is the length of the
    // i-th vector, vector 1 is reshaped to (n1, 1, 1), vector 2 to (1, n2, 1),
    // and so on.
    //
    // Step 2: Broadcast each vector to the shape
    //   (n1, n2, n3) if indexing == "ij"
    //   (n2, n1, n3) if indexing == "xy"
    // Note: For "xy" only n1 and n2 swap their places; all others are the same
    // as for "ij".
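    //
    // For example, with two 1-D inputs of lengths 2 and 3, "ij" indexing produces two arrays of
    // shape (2, 3), while "xy" (MeshgridIndexingMode::kCartesian) produces shape (3, 2).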

    // Step 1
    for (const Array& array : arrays) {
        shape.emplace_back(1);
        broadcast_shape.emplace_back(array.GetTotalSize());
    }

    // Shape for each array based on number of arrays.
    for (size_t i = 0; i < arrays.size(); ++i) {
        Shape temp_shape{shape.begin(), shape.end()};
        temp_shape[i] = arrays[i].GetTotalSize();
        broadcasted_array_shapes.emplace_back(temp_shape);
    }

    // Based on the NumPy documentation and source.
    if (mode == MeshgridIndexingMode::kCartesian) {
        std::swap(broadcasted_array_shapes[0][0], broadcasted_array_shapes[0][1]);
        std::swap(broadcasted_array_shapes[1][0], broadcasted_array_shapes[1][1]);
        std::swap(broadcast_shape[0], broadcast_shape[1]);
    }

    std::vector<Array> reshaped_arrays;
    reshaped_arrays.reserve(arrays.size());
    for (size_t i = 0; i < arrays.size(); ++i) {
        reshaped_arrays.emplace_back(arrays[i].Reshape(broadcasted_array_shapes[i]));
    }

    // Step 2
    for (const Array& reshaped_array : reshaped_arrays) {
        grid_arrays.emplace_back(reshaped_array.BroadcastTo(broadcast_shape));
    }

    return grid_arrays;
}

Array Tri(int64_t n, absl::optional<int64_t> m, absl::optional<int64_t> k, absl::optional<Dtype> dtype, Device& device) {
    if (!m.has_value()) {
        m = n;
    }
    if (!k.has_value()) {
        k = 0;
    }
    if (!dtype.has_value()) {
        dtype = Dtype::kFloat32;
    }
    // NumPy returns a 0-sized array for inputs with negative dimensions. This is a flaw in NumPy's
    // implementation; other array creation routines raise an error for negative dimensions.
    if (n < 0 || m < 0) {
        throw DimensionError{"Negative dimensions are not allowed"};
    }

    Array out = Empty({n, m.value()}, dtype.value(), device);
    {
        NoBackpropModeScope scope{};
        device.backend().CallKernel<TriKernel>(k.value(), out);
    }
    return out;
}

Array Tril(const Array& m, int64_t k) {
    Array out = Empty(m.shape(), m.dtype(), m.device());
    {
        NoBackpropModeScope scope{};
        Array mask{};
        if (m.ndim() >= 2) {
            mask = Tri(m.shape()[m.ndim() - 2], m.shape()[m.ndim() - 1], k, Dtype::kBool, m.device());
        } else {
            mask = Tri(m.shape()[0], m.shape()[0], k, Dtype::kBool, m.device());
        }
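        // mask has ones on and below the k-th diagonal, so Where keeps the lower triangle of m and
        // fills everything above it with zeros.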
        out = Where(mask, m, 0);
    }

    BackwardBuilder bb{"tril", m, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([ndim = m.ndim(), k](BackwardContext& bctx) {
            if (ndim == 1) {
                throw DimensionError{"ChainerX Tril backward is not implemented for 1-dimensional arrays."};
            }
            const Array& gout = *bctx.output_grad();
            bctx.input_grad() = Tril(gout, k);
        });
    }
    bb.Finalize();

    return out;
}

Array Triu(const Array& m, int64_t k) {
    Array out = Empty(m.shape(), m.dtype(), m.device());
    {
        NoBackpropModeScope scope{};
        Array mask{};
        if (m.ndim() >= 2) {
            mask = Tri(m.shape()[m.ndim() - 2], m.shape()[m.ndim() - 1], k - 1, Dtype::kBool, m.device());
        } else {
            mask = Tri(m.shape()[0], m.shape()[0], k - 1, Dtype::kBool, m.device());
        }
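        // mask has ones on and below the (k - 1)-th diagonal, i.e. strictly below the k-th diagonal,
        // so Where zeroes those elements and keeps the upper triangle of m.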
        out = Where(mask, 0, m);
    }

    BackwardBuilder bb{"triu", m, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([ndim = m.ndim(), k](BackwardContext& bctx) {
            if (ndim == 1) {
                throw DimensionError{"ChainerX Triu backward is not implemented for 1-dimensional arrays."};
            }
            const Array& gout = *bctx.output_grad();
            bctx.input_grad() = Triu(gout, k);
        });
    }
    bb.Finalize();

    return out;
}

}  // namespace chainerx