#include "chainerx/routines/manipulation.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <absl/types/optional.h>

#include "chainerx/array.h"
#include "chainerx/axes.h"
#include "chainerx/backend.h"
#include "chainerx/backprop_mode.h"
#include "chainerx/backward_builder.h"
#include "chainerx/backward_context.h"
#include "chainerx/device.h"
#include "chainerx/dtype.h"
#include "chainerx/error.h"
#include "chainerx/graph.h"
#include "chainerx/kernels/creation.h"
#include "chainerx/kernels/indexing.h"
#include "chainerx/kernels/misc.h"
#include "chainerx/macro.h"
#include "chainerx/routines/creation.h"
#include "chainerx/routines/indexing.h"
#include "chainerx/routines/routines_util.h"
#include "chainerx/routines/type_util.h"
#include "chainerx/shape.h"
#include "chainerx/strides.h"

namespace chainerx {

Scalar AsScalar(const Array& a) {
    if (a.GetTotalSize() != 1) {
        throw DimensionError{"Cannot convert an array of size ", a.GetTotalSize(), " to a scalar, size must be 1."};
    }

    // Copy to the native device
    Array native_copy = a.ToNative();

    // Retrieve the value
    return VisitDtype(a.dtype(), [&native_copy](auto pt) -> Scalar {
        using T = typename decltype(pt)::type;
        const uint8_t* ptr = static_cast<const uint8_t*>(native_copy.data().get()) + native_copy.offset();
        auto typed_ptr = reinterpret_cast<const T*>(ptr); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
        return Scalar{*typed_ptr};
    });
}

Array RollAxis(const Array& a, int8_t axis, int8_t start) {
    // TODO(hvy): Optimize the implementation.
    axis = internal::NormalizeAxis(axis, a.ndim());

    // start can be a.ndim() so we cannot use NormalizeAxis here.
    if (start < -a.ndim() || start > a.ndim()) {
        throw DimensionError{"start arg out of bounds. start: ", start, ", ndim: ", a.ndim()};
    }
    if (start < 0) {
        start += a.ndim();
    }

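    // Build the permutation by inserting `axis` at position `start` and keeping the other axes in their original order.
    // For example, on a 4-d array, RollAxis(a, 2, 0) uses the permutation (2, 0, 1, 3), matching numpy.rollaxis.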
    Axes axes;
    for (int8_t i = 0; i < a.ndim(); ++i) {
        if (i == start) {
            axes.emplace_back(axis);
        }
        if (i != axis) {
            axes.emplace_back(i);
        }
    }
    if (start == a.ndim()) {
        axes.emplace_back(axis);
    }
    return Transpose(a, axes);
}

Array Transpose(const Array& a, const OptionalAxes& axes) {
    Axes real_axes;
    if (axes.has_value()) {
        if (axes->ndim() != a.ndim()) {
            throw DimensionError{"Axes do not match, input array dimensions: ", a.ndim(), " but axes: ", axes->ndim()};
        }
        real_axes = internal::GetNormalizedAxes(*axes, a.ndim());
    } else {
        for (int8_t i = 0; i < a.ndim(); ++i) {
            real_axes.emplace_back(a.ndim() - i - 1);
        }
    }
    CHAINERX_ASSERT(real_axes.ndim() == a.ndim());

    Shape out_shape;
    Strides out_strides;
    for (int8_t axis : real_axes) {
        out_shape.emplace_back(a.shape()[axis]);
        out_strides.emplace_back(a.strides()[axis]);
    }

    Array out = internal::MakeArray(out_shape, out_strides, a.dtype(), a.device(), a.data(), a.offset());

    BackwardBuilder bb{"transpose", a, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([real_axes](BackwardContext& bctx) {
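            // The gradient of a transpose is a transpose by the inverse permutation: real_axes maps output axis i to
            // input axis real_axes[i], so backward_axes[real_axes[i]] = i undoes it. E.g. the inverse of (2, 0, 1) is (1, 2, 0).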
            Axes backward_axes;
            backward_axes.resize(real_axes.ndim());
            for (int8_t i = 0; i < real_axes.ndim(); ++i) {
                backward_axes[real_axes[i]] = i;
            }
            bctx.input_grad() = bctx.output_grad()->Transpose(backward_axes);
        });
    }
    bb.Finalize();

    return out;
}

namespace {

// Returns a shape where the length of at most one dimension is inferred from the total size and the remaining dimensions.
// Such a dimension is given by a negative length, i.e. Shape{2, 3, -1}.
// If the given shape does not contain such a dimension, this function returns a copy of the given shape.
// If there are multiple negative lengths, or if the negative-length dimension cannot be inferred due to non-divisibility, a
// DimensionError is thrown.
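// For example, GetInferredShape(Shape{2, 3, -1}, 24) returns Shape{2, 3, 4}, while GetInferredShape(Shape{2, -1}, 7)
// throws because 7 is not divisible by 2.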
Shape GetInferredShape(const Shape& shape, int64_t total_size) {
    Shape inferred_shape = shape;

    auto it = std::find_if(inferred_shape.begin(), inferred_shape.end(), [](int64_t dim) { return dim < 0; });
    if (it != inferred_shape.end()) {
        if (std::find_if(std::next(it), inferred_shape.end(), [](int64_t dim) { return dim < 0; }) != inferred_shape.end()) {
            throw DimensionError{"Can only specify one unknown dimension"};
        }
        int64_t rest_size = std::accumulate(inferred_shape.begin(), it, int64_t{1}, std::multiplies<>()) *
                            std::accumulate(std::next(it), inferred_shape.end(), int64_t{1}, std::multiplies<>());
        if (rest_size == 0) {
            throw DimensionError{"Cannot reshape array of size ", total_size, " into an ambiguous shape ", shape};
        }
        *it = total_size / rest_size;
    }

    if (total_size != inferred_shape.GetTotalSize()) {
        throw DimensionError{"Cannot reshape array of size ", total_size, " into shape ", shape};
    }
    return inferred_shape;
}

} // namespace

Array Reshape(const Array& a, const Shape& newshape) {
    const Shape& in_shape = a.shape();
    const Strides& in_strides = a.strides();

    // If the shape is unchanged, just return a view.
    if (in_shape == newshape) {
        return a.MakeView();
    }

    // Check for invalid shape.
    int64_t total_size = in_shape.GetTotalSize();
    Shape out_shape = GetInferredShape(newshape, total_size);
    int64_t item_size = a.GetItemSize();
    Strides strides{};
    if (total_size == 0) {
        // Calculate the strides for 0-sized array.
        strides.resize(out_shape.ndim());
        strides.back() = item_size;
        for (int8_t i = out_shape.ndim() - 1; i >= 1; --i) {
            strides[i - 1] = strides[i] * std::max(int64_t{1}, out_shape[i]);
        }
    } else {
        // Calculate the strides for non-0-sized array.

        // reduced_shape and reduced_strides are the shortest shape and strides to which the input shape and strides can be
        // reduced without copying.
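        // For example, a contiguous float32 array of shape (2, 1, 3) with strides (12, 12, 4) reduces to shape (6) with
        // strides (4): the unit-length axis is dropped and the remaining axes merge because 3 * 4 == 12.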
        Shape reduced_shape{};
        Strides reduced_strides{};
        if (total_size == 1) {
            reduced_shape.emplace_back(int64_t{1});
            reduced_strides.emplace_back(item_size);
        } else {
            int8_t i = 0;
            // Ignore preceding 1-length dimensions
            while (i < in_shape.ndim() && in_shape[i] == 1) {
                ++i;
            }
            // Add the first pair
            reduced_shape.emplace_back(in_shape[i]);
            reduced_strides.emplace_back(in_strides[i]);
            ++i;
            // Reduce the remaining
            for (; i < in_shape.ndim(); ++i) {
                int64_t dim = in_shape[i];
                int64_t st = in_strides[i];
                CHAINERX_ASSERT(dim > 0);
                if (dim == 1) {
                    // If the axis has unit-length, skip this dimension.
                } else if (dim * st == reduced_strides.back()) {
                    // If the pair is compatible with the previous stride, reduce the pair to it.
                    reduced_shape.back() *= dim;
                    reduced_strides.back() = st;
                } else {
                    // Otherwise, add a new shape and stride.
                    reduced_shape.emplace_back(dim);
                    reduced_strides.emplace_back(st);
                }
            }
        }
        CHAINERX_ASSERT(reduced_shape.size() == reduced_strides.size());
        CHAINERX_ASSERT(!reduced_shape.empty());

        // Construct the strides for no-copy reshape.
        // If it's not possible, can_reshape_without_copy will be false.
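        // A copy is typically needed when the memory layout is incompatible with the requested shape, e.g. reshaping the
        // transpose of a contiguous 2-d array, whose reduced strides cannot be re-split to match the new dimensions.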
        bool can_reshape_without_copy = true;
        if (out_shape.ndim() > 0) {
            int64_t last_stride = reduced_shape[0] * reduced_strides[0];
            size_t i_dim = 0;
            for (int64_t dim : out_shape) {
                if (dim <= 1) {
                    strides.emplace_back(last_stride);
                    continue;
                }
                if (i_dim >= reduced_shape.size() || reduced_shape[i_dim] % dim != 0) {
                    strides.clear();
                    can_reshape_without_copy = false;
                    break;
                }
                reduced_shape[i_dim] /= dim;
                last_stride = reduced_shape[i_dim] * reduced_strides[i_dim];
                strides.emplace_back(last_stride);
                if (reduced_shape[i_dim] == 1) {
                    ++i_dim;
                }
            }
        }

        if (!can_reshape_without_copy) {
            // Copy is required.
            return a.Copy().Reshape(out_shape);
        }
        CHAINERX_ASSERT(strides.size() == out_shape.size());
    }

    Array out = internal::MakeArray(out_shape, strides, a.dtype(), a.device(), a.data(), a.offset());

    BackwardBuilder bb{"reshape", a, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([in_shape](BackwardContext& bctx) { bctx.input_grad() = bctx.output_grad()->Reshape(in_shape); });
    }
    bb.Finalize();

    CHAINERX_ASSERT(out.shape() == out_shape);
    CHAINERX_ASSERT(out.strides().size() == out_shape.size());
    return out;
}

Array Squeeze(const Array& a, const OptionalAxes& axis) {
    const Shape& in_shape = a.shape();
    const Strides& in_strides = a.strides();

    Shape out_shape{};
    Strides out_strides{};

    if (axis.has_value()) {
        const Axes sorted_axis = internal::GetSortedAxes(*axis, in_shape.ndim());

        int64_t i_axis = 0;
        for (int64_t i = 0; i < in_shape.ndim(); ++i) {
            if (i_axis < static_cast<int64_t>(sorted_axis.size()) && sorted_axis[i_axis] == i) {
                ++i_axis;
                if (in_shape[i] != 1) {
                    std::ostringstream os;
                    os << "Cannot squeeze out non-unit-length axes, where shape was " << in_shape.ToString();
                    os << " and axes were (";
                    for (auto it = axis->begin(); it != axis->end(); ++it) {
                        if (it != axis->begin()) {
                            os << ", ";
                        }
                        os << *it;
                    }
                    os << (axis->size() == 1 ? ",)." : ").");
                    throw DimensionError{os.str()};
                }
            } else {
                out_shape.emplace_back(in_shape[i]);
                out_strides.emplace_back(in_strides[i]);
            }
        }
    } else { // All axes are candidates for removal if none are given.
        for (int64_t i = 0; i < in_shape.ndim(); ++i) {
            if (in_shape[i] != 1) {
                out_shape.emplace_back(in_shape[i]);
                out_strides.emplace_back(in_strides[i]);
            }
        }
    }

    if (in_shape.size() == out_shape.size()) {
        return a;
    }

    Array out = internal::MakeArray(out_shape, out_strides, a.dtype(), a.device(), a.data(), a.offset());

    BackwardBuilder bb{"squeeze", a, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([in_shape](BackwardContext& bctx) { bctx.input_grad() = bctx.output_grad()->Reshape(in_shape); });
    }
    bb.Finalize();

    return out;
}

Array BroadcastTo(const Array& array, const Shape& shape) {
    const Shape& in_shape = array.shape();
    const Strides& in_strides = array.strides();

    if (in_shape.size() > shape.size()) {
        throw DimensionError{"Cannot broadcast to smaller dimensions from ", in_shape, " to ", shape, "."};
    }

    // Compute the new set of strides after broadcasting.
    Strides strides;
    strides.resize(shape.ndim());
    int8_t i_in = in_shape.ndim() - 1;
    for (int8_t i_out = shape.ndim() - 1; i_out >= 0; --i_out) {
        int64_t out_dim = shape[i_out];
        // If this dimension is to be broadcast, nonbroadcast_stride is unset.
        // Otherwise, it holds the new stride.
        absl::optional<int64_t> nonbroadcast_stride{};
        if (i_in >= 0) {
            int64_t in_dim = in_shape[i_in];
            if (in_dim == 1) {
                // do nothing; broadcast
            } else if (in_dim == out_dim) {
                nonbroadcast_stride = in_strides[i_in];
            } else {
                throw DimensionError{"Invalid broadcast from ", in_shape, " to ", shape};
            }
            --i_in;
        } else {
            // do nothing; broadcast
        }

        if (nonbroadcast_stride.has_value()) {
            // non-broadcast dimension
            strides[i_out] = nonbroadcast_stride.value();
        } else {
            // broadcast dimension
            strides[i_out] = int64_t{0};
        }
    }
    CHAINERX_ASSERT(i_in == -1);
    CHAINERX_ASSERT(strides.ndim() == shape.ndim());

    Array out = internal::MakeArray(shape, strides, array.dtype(), array.device(), array.data(), array.offset());

    BackwardBuilder bb{"broadcast_to", array, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([in_shape](BackwardContext& bctx) {
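            // The gradient of a broadcast is the sum of the output gradient over the broadcast axes: the leading axes that
            // were prepended and the input axes of length 1. E.g. broadcasting (3, 1) to (2, 3, 4) sums gout over axes
            // (0, 2) with keepdims and then squeezes the leading axis, recovering shape (3, 1).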
            const Array& gout = *bctx.output_grad();
            if (gout.shape() == in_shape) {
                bctx.input_grad() = gout;
                return;
            }

            int8_t lead = gout.ndim() - in_shape.ndim();
            Axes lead_axis{};
            lead_axis.resize(lead);
            std::iota(lead_axis.begin(), lead_axis.end(), int8_t{0});

            Axes axis{lead_axis};
            for (int8_t i = 0; i < in_shape.ndim(); ++i) {
                if (in_shape[i] == 1) {
                    axis.emplace_back(i + lead);
                }
            }
            axis.erase(std::unique(axis.begin(), axis.end()), axis.end()); // Sum does not accept axis with duplicate elements

            Array gin = gout.Sum(axis, true);
            if (lead > 0) {
                bctx.input_grad() = gin.Squeeze(lead_axis);
            } else {
                bctx.input_grad() = std::move(gin);
            }
        });
    }
    bb.Finalize();

    return out;
}

namespace {

Array ConcatenateImpl(const std::vector<Array>& arrays, int8_t axis) {
    if (arrays.empty()) {
        throw DimensionError{"Need at least one array to concatenate"};
    }

    Shape shape = arrays.front().shape();
    Dtype out_dtype = ResultType(arrays);
    Device& device = arrays.front().device();
    int8_t ndim = arrays.front().ndim();
    axis = internal::NormalizeAxis(axis, ndim);
    shape[axis] = 0;
    std::vector<int64_t> indices;
    indices.reserve(arrays.size() - 1);

    for (const Array& array : arrays) {
        const Shape& s = array.shape();
        if (ndim != array.ndim()) {
            throw DimensionError{"All the input arrays must have same number of dimensions"};
        }
        for (int8_t i = 0; i < ndim; ++i) {
            if (axis == i) {
                shape[i] += s[i];
            } else if (shape[i] != s[i]) {
                throw DimensionError{"All the input array dimensions except for the concatenation axis must match exactly"};
            }
        }
        if (indices.size() < arrays.size() - 1) {
            indices.emplace_back(shape[axis]);
        }
    }

    Strides strides{shape, out_dtype};

    // Aligning with NumPy strides behavior
    auto last_zero_it = std::find(shape.rbegin(), shape.rend(), int64_t{0});
    if (last_zero_it != shape.rend()) {
        std::fill(strides.rbegin() + (last_zero_it - shape.rbegin() + 1), strides.rend(), int64_t{0});
    }

    Array out = internal::Empty(shape, out_dtype, strides, device);

    size_t in_size = arrays.size();

    // If input dtypes are mixed, elements in the input arrays are cast to the resulting dtype.
    // Their original dtypes must therefore be remembered in order to cast the computed gradients back in the backward pass.
    std::vector<Dtype> in_dtypes;
    in_dtypes.reserve(in_size);

    std::vector<ConstArrayRef> array_refs;
    array_refs.reserve(in_size);

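    // Each input is copied into a view of `out` that reuses the output strides but takes the input's shape and a growing
    // offset along the concatenation axis, so every copy lands in the next slab of the output.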
    {
        NoBackpropModeScope scope{};
        int64_t out_offset = 0;
        for (const Array& array : arrays) {
            const Shape& shape = array.shape();
            Array sliced_out = internal::MakeArray(shape, strides, out_dtype, device, out.data(), out_offset);
            Dtype in_dtype = array.dtype();
            in_dtypes.emplace_back(in_dtype);
            // Note: In CopyKernel, input array elements are cast to the dtype of the output array.
            device.backend().CallKernel<CopyKernel>(array, sliced_out);
            array_refs.emplace_back(ConstArrayRef{array});
            out_offset += strides[axis] * shape[axis];
        }
    }

    {
        BackwardBuilder bb{"concatenate", array_refs, out};
        if (BackwardBuilder::Target bt = bb.CreateTarget()) {
            bt.Define([indices = std::move(indices), axis, in_dtypes = std::move(in_dtypes)](BackwardContext& bctx) {
                const Array& gy = *bctx.output_grad();
                Dtype out_dtype = gy.dtype();
                std::vector<Array> gxs = Split(gy, indices, axis);
                for (size_t i = 0; i < gxs.size(); ++i) {
                    Dtype in_dtype = in_dtypes[i];
                    if (out_dtype != in_dtype) {
                        bctx.input_grad(i) = gxs[i].AsType(in_dtype);
                    } else {
                        bctx.input_grad(i) = std::move(gxs[i]);
                    }
                }
            });
        }
        bb.Finalize();
    }

    return out;
}

} // namespace

Array Concatenate(const std::vector<Array>& arrays) { return ConcatenateImpl(arrays, 0); }

Array Concatenate(const std::vector<Array>& arrays, absl::optional<int8_t> axis) {
    if (!axis.has_value()) {
        // Special case, making input arrays 1-dimensional and concatenating along the first axis.
        std::vector<Array> raveled_arrays;
        raveled_arrays.reserve(arrays.size());
        std::transform(arrays.begin(), arrays.end(), std::back_inserter(raveled_arrays), [](const Array& array) {
            Shape shape{array.GetTotalSize()};
            return array.Reshape(shape);
        });
        return ConcatenateImpl(raveled_arrays, 0);
    }
    return ConcatenateImpl(arrays, *axis);
}

Array Stack(const std::vector<Array>& arrays, int8_t axis) {
    std::vector<Array> reshaped_arrays;
    reshaped_arrays.reserve(arrays.size());
    std::transform(arrays.begin(), arrays.end(), std::back_inserter(reshaped_arrays), [axis](const Array& array) {
        return ExpandDims(array, axis);
    });
    return ConcatenateImpl(reshaped_arrays, axis);
}

namespace {

// Defines the backward pass for Split, for both by-sections and by-indices.
void DefineSplitBackward(const Array& ary, const std::vector<Array>& out, int8_t axis_norm) {
    // TODO(hvy): Avoid creating an intermediate vector of reference when BackwardBuilder accepts std::vector<Array>.
    std::vector<ConstArrayRef> out_refs{};
    out_refs.reserve(out.size());
    std::transform(out.begin(), out.end(), std::back_inserter(out_refs), [](const Array& array) { return ConstArrayRef{array}; });

    // TODO(imanishi): Avoid creating shapes of forward outputs.
    std::vector<Shape> shapes;
    shapes.reserve(out.size());
    std::transform(out.begin(), out.end(), std::back_inserter(shapes), [](const Array& array) { return array.shape(); });

    BackwardBuilder bb{"split", ary, out_refs};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([axis_norm, shapes = std::move(shapes), dtype = ary.dtype(), &device = ary.device()](BackwardContext& bctx) {
            std::vector<Array> output_grads;
            output_grads.reserve(bctx.output_count());
            for (size_t i = 0; i < bctx.output_count(); ++i) {
                const absl::optional<Array>& gy = bctx.output_grad(i);
                output_grads.emplace_back(gy.has_value() ? *gy : Zeros(shapes[i], dtype, device));
            }
            bctx.input_grad() = ConcatenateImpl(output_grads, axis_norm);
        });
    }
    bb.Finalize();
}

} // namespace

std::vector<Array> Split(const Array& ary, int64_t sections, int8_t axis) {
    if (sections < 1) {
        throw DimensionError("Number of sections must be larger than 0.");
    }

    const Shape& in_shape = ary.shape();
    int8_t axis_norm = internal::NormalizeAxis(axis, ary.ndim());
    int64_t in_dim = in_shape[axis_norm];

    if (in_dim % sections != 0) {
        throw DimensionError("Array split does not result in an equal division.");
    }

    Shape out_shape = in_shape;
    int64_t out_dim = in_dim / sections;
    out_shape[axis_norm] = out_dim;
    int64_t out_stride = ary.strides()[axis_norm];
    int64_t out_offset = ary.offset();
    bool is_empty = ary.GetTotalSize() == 0;

    std::vector<Array> out{};
    out.reserve(sections);

    for (int64_t i = 0; i < sections; ++i) {
        out.emplace_back(internal::MakeArray(out_shape, ary.strides(), ary.dtype(), ary.device(), ary.data(), out_offset));

        // Empty arrays should all have offsets of 0 to e.g. avoid out-of-memory errors.
        if (!is_empty) {
            out_offset += out_stride * out_dim;
        }
    }

    DefineSplitBackward(ary, out, axis_norm);

    return out;
}

std::vector<Array> Split(const Array& ary, std::vector<int64_t> indices, int8_t axis) {
    const Shape& in_shape = ary.shape();
    int8_t axis_norm = internal::NormalizeAxis(axis, ary.ndim());
    int64_t in_dim = in_shape[axis_norm];

    // Wrap negative indices.
    std::transform(
            indices.begin(), indices.end(), indices.begin(), [in_dim](int64_t index) { return index >= 0 ? index : index + in_dim; });
    indices.emplace_back(in_dim);
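    // For example, Split(ary, {2, 5}, 0) on an axis of length 8 produces views over the ranges [0, 2), [2, 5) and [5, 8);
    // indices past the end of the axis are clamped, so any remaining splits become empty arrays.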

    Shape out_shape = in_shape;
    int64_t out_stride = ary.strides()[axis_norm];
    int64_t out_offset = ary.offset();
    int64_t slice_start = 0;
    bool is_empty = ary.GetTotalSize() == 0;

    std::vector<Array> out{};
    out.reserve(indices.size());

    for (int64_t index : indices) {
        int64_t slice_stop = std::min(in_dim, std::max(int64_t{0}, index));
        int64_t slice_step = slice_stop - slice_start;

        // Update the dimension of interest in the output shape.
        out_shape[axis_norm] = std::max(int64_t{0}, slice_step);

        out.emplace_back(internal::MakeArray(out_shape, ary.strides(), ary.dtype(), ary.device(), ary.data(), out_offset));

        // Empty arrays should all have offsets of 0 to e.g. avoid out-of-memory errors.
        if (!is_empty) {
            out_offset += out_stride * slice_step;
        }

        slice_start = slice_stop;
    }

    DefineSplitBackward(ary, out, axis_norm);

    return out;
}

std::vector<Array> DSplit(const Array& ary, int64_t sections) {
    if (sections < 1) {
        throw DimensionError("Number of sections must be larger than 0.");
    }

    if (ary.ndim() < 3) {
        throw DimensionError("dsplit only works on arrays of 3 or more dimensions.");
    }

    return Split(ary, sections, 2);
}

std::vector<Array> DSplit(const Array& ary, std::vector<int64_t> indices) {
    if (ary.ndim() < 3) {
        throw DimensionError("dsplit only works on arrays of 3 or more dimensions.");
    }

    return Split(ary, std::move(indices), 2);
}

std::vector<Array> VSplit(const Array& ary, int64_t sections) {
    if (sections < 1) {
        throw DimensionError("Number of sections must be larger than 0.");
    }

    if (ary.ndim() < 2) {
        throw DimensionError("vsplit only works on arrays of 2 or more dimensions.");
    }

    return Split(ary, sections, 0);
}

std::vector<Array> VSplit(const Array& ary, std::vector<int64_t> indices) {
    if (ary.ndim() < 2) {
        throw DimensionError("vsplit only works on arrays of 2 or more dimensions.");
    }

    return Split(ary, std::move(indices), 0);
}

std::vector<Array> HSplit(const Array& ary, int64_t sections) {
    if (sections < 1) {
        throw DimensionError("Number of sections must be larger than 0.");
    }

    if (ary.ndim() == 0) {
        throw DimensionError("hsplit only works on arrays of 1 or more dimensions.");
    }

    if (ary.ndim() > 1) {
        return Split(ary, sections, 1);
    }

    return Split(ary, sections, 0);
}

std::vector<Array> HSplit(const Array& ary, std::vector<int64_t> indices) {
    if (ary.ndim() == 0) {
        throw DimensionError("hsplit only works on arrays of 1 or more dimensions.");
    }

    if (ary.ndim() > 1) {
        return Split(ary, std::move(indices), 1);
    }

    return Split(ary, std::move(indices), 0);
}

Array Swapaxes(const Array& a, int8_t axis1, int8_t axis2) {
    Shape shape = a.shape();
    Strides strides = a.strides();

    axis1 = internal::NormalizeAxis(axis1, a.ndim());
    axis2 = internal::NormalizeAxis(axis2, a.ndim());

    std::iter_swap(shape.begin() + axis1, shape.begin() + axis2);
    std::iter_swap(strides.begin() + axis1, strides.begin() + axis2);
    Array out = internal::MakeArray(shape, strides, a.dtype(), a.device(), a.data(), a.offset());

    BackwardBuilder bb{"swapaxes", a, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([axis1, axis2](BackwardContext& bctx) {
            const Array& gout = *bctx.output_grad();
            bctx.input_grad() = Swapaxes(gout, axis1, axis2);
        });
    }
    bb.Finalize();

    return out;
}

Array Ravel(const Array& a) { return a.Reshape({a.GetTotalSize()}); }

Array Repeat(const Array& a, int64_t repeats, absl::optional<int8_t> axis) {
    if (repeats < 0) {
        throw DimensionError("repeats must be larger than 0.");
    }

    int8_t target_axis = 0;
    Array target_array;

    if (axis.has_value()) {
        target_axis = internal::NormalizeAxis(*axis, a.ndim());
        target_array = a;
    } else {
        target_array = Reshape(a, Shape({a.shape().GetTotalSize()}));
    }

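    // Repetition is expressed as expand + broadcast + reshape: e.g. repeating a (2, 3) array twice along axis 1 expands it
    // to (2, 3, 1), broadcasts to (2, 3, 2) and reshapes to (2, 6), so each element is duplicated along that axis.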
    Shape broadcast_shape = target_array.shape();
    broadcast_shape.insert(broadcast_shape.begin() + target_axis + 1, repeats);

    Shape reshape_shape = target_array.shape();
    reshape_shape[target_axis] *= repeats;

    Array expanded_array = ExpandDims(target_array, target_axis + 1);
    Array broadcasted_array = BroadcastTo(expanded_array, broadcast_shape);
    Array reshaped_array = Reshape(broadcasted_array, reshape_shape);
    return AsContiguousArray(reshaped_array);
}

Array Repeat(const Array& a, const std::vector<int64_t>& repeats, absl::optional<int8_t> axis) {
    if (repeats.size() == 1) {
        return Repeat(a, repeats[0], axis);
    }

    if (axis.has_value()) {
        int8_t target_axis = internal::NormalizeAxis(*axis, a.ndim());

        if (repeats.size() != static_cast<size_t>(a.shape()[target_axis])) {
            throw DimensionError("The number of repeats must be same with a shape in the axis direction.");
        }

        if (std::any_of(repeats.begin(), repeats.end(), [](int64_t x) -> bool { return x < 0; })) {
            throw DimensionError("repeats must be larger than 0.");
        }

        // TODO(durswd): Should be optimized.
        std::vector<Array> output_elements;
        std::vector<Array> splitted = Split(a, a.shape()[target_axis], target_axis);

        for (size_t i = 0; i < splitted.size(); ++i) {
            for (int32_t j = 0; j < repeats[i]; ++j) {
                output_elements.push_back(splitted[i]);
            }
        }

        Array out = Concatenate(output_elements, target_axis);

        return AsContiguousArray(out);
    }

    if (repeats.size() != static_cast<size_t>(a.shape().GetTotalSize())) {
        throw DimensionError("The number of repeats must be same with a shape.");
    }

    Array reshaped = Reshape(a, Shape({a.shape().GetTotalSize()}));
    return Repeat(reshaped, repeats, 0);
}

Array ExpandDims(const Array& a, int8_t axis) {
    Shape shape = a.shape();

    axis = internal::NormalizeAxis(axis, a.ndim() + 1);

    shape.insert(shape.begin() + axis, 1);

    Array out = a.Reshape(shape);

    // A trivial reshape of adding a new axis should just return a view of the input.
    CHAINERX_ASSERT(out.raw_data() == a.raw_data());

    return out;
}

Array Flip(const Array& m, const OptionalAxes& axes) {
    Axes real_axes;
    if (axes.has_value()) {
        real_axes = internal::GetNormalizedAxes(*axes, m.ndim());
    } else {
        for (int8_t i = 0; i < m.ndim(); ++i) {
            real_axes.emplace_back(m.ndim() - i - 1);
        }
    }

    Strides strides = m.strides();
    Shape shape = m.shape();
    int64_t offset = m.offset();
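
    // Flipping is a view trick, no data is copied: for each flipped axis the view starts at the last element of that axis
    // and the stride is negated so that iteration walks backwards through memory.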
    for (auto axis : real_axes) {
        // Point the offset at the last element of that dimension.
        offset += std::max<int64_t>(shape[axis] - 1, 0) * strides[axis];
        if (shape[axis] != 0) {
            strides[axis] = -strides[axis];
        }
    }

    auto is_zero = std::find(shape.begin(), shape.end(), 0);
    if (is_zero != shape.end()) {
        offset = 0;
    }

    Array out = internal::MakeArray(m.shape(), strides, m.dtype(), m.device(), m.data(), offset);

    BackwardBuilder bb{"flip", m, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([real_axes](BackwardContext& bctx) {
            const Array& gout = *bctx.output_grad();
            bctx.input_grad() = Flip(gout, real_axes);
        });
    }
    bb.Finalize();

    return out;
}

Array Fliplr(const Array& m) {
    if (m.ndim() < 2) {
        throw DimensionError{"Input must be >= 2-d."};
    }
    return Flip(m, Axes{1});
}

Array Flipud(const Array& m) {
    if (m.ndim() < 1) {
        throw DimensionError{"Input must be >= 1-d."};
    }
    return Flip(m, Axes{0});
}

Array AtLeast2D(const Array& x) {
    Array out;

    {
        NoBackpropModeScope scope;

        switch (x.ndim()) {
            case 0:
                out = x.Reshape({1, 1});
                break;
            case 1: {
                Shape shape = x.shape();
                Strides strides = x.strides();
                shape.insert(shape.begin(), 1);
                strides.insert(strides.begin(), 0);
                out = internal::MakeArray(shape, strides, x.dtype(), x.device(), x.data());
            } break;
            default:
                out = x.MakeView();
                break;
        }
    }

    BackwardBuilder bb{"atleast_2d", x, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([in_shape = x.shape(), ndim = x.ndim()](BackwardContext& bctx) {
            if (ndim <= 1) {
                bctx.input_grad() = bctx.output_grad()->Reshape(in_shape);
            } else {
                bctx.input_grad() = *bctx.output_grad();
            }
        });
    }
    bb.Finalize();

    return out;
}

Array AtLeast3D(const Array& x) {
    Array out;

    {
        NoBackpropModeScope scope;

        switch (x.ndim()) {
            case 0:
                out = x.Reshape({1, 1, 1});
                break;
            case 1: {
                Shape shape = x.shape();
                Strides strides = x.strides();
                shape.insert(shape.begin(), 1);
                shape.insert(shape.end(), 1);
                strides.insert(strides.begin(), 0);
                strides.insert(strides.end(), 0);
                out = internal::MakeArray(shape, strides, x.dtype(), x.device(), x.data());
            } break;
            case 2: {
                Shape shape = x.shape();
                Strides strides = x.strides();
                shape.insert(shape.end(), 1);
                strides.insert(strides.end(), 0);
                out = internal::MakeArray(shape, strides, x.dtype(), x.device(), x.data());
            } break;
            default:
                out = x.MakeView();
                break;
        }
    }

    BackwardBuilder bb{"atleast_3d", x, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([in_shape = x.shape(), ndim = x.ndim()](BackwardContext& bctx) {
            if (ndim <= 2) {
                bctx.input_grad() = bctx.output_grad()->Reshape(in_shape);
            } else {
                bctx.input_grad() = *bctx.output_grad();
            }
        });
    }
    bb.Finalize();

    return out;
}

Array HStack(const std::vector<Array>& arrays) {
    if (arrays.empty()) {
        throw DimensionError{"Need at least one array to stack"};
    }

    if (arrays.front().ndim() <= 1) {
        return Concatenate(arrays, 0);
    }
    return Concatenate(arrays, 1);
}

Array VStack(const std::vector<Array>& arrays) {
    if (arrays.empty()) {
        throw DimensionError{"Need at least one array to stack"};
    }

    std::vector<Array> reshaped_arrays(arrays.size());
    std::transform(arrays.begin(), arrays.end(), reshaped_arrays.begin(), AtLeast2D);

    return Concatenate(reshaped_arrays, 0);
}

Array DStack(const std::vector<Array>& arrays) {
    if (arrays.empty()) {
        throw DimensionError{"Need at least one array to stack"};
    }

    std::vector<Array> reshaped_arrays(arrays.size());
    std::transform(arrays.begin(), arrays.end(), reshaped_arrays.begin(), AtLeast3D);
    return Concatenate(reshaped_arrays, 2);
}

Array Moveaxis(const Array& a, const Axes& source, const Axes& destination) {
    if (source.size() != destination.size()) {
        throw DimensionError{"Invalid Source or Destination Axes"};
    }

    if (source.empty()) {
        return a;
    }

    const Axes& normalized_source = internal::GetNormalizedAxes(source, a.ndim());
    const Axes& normalized_destination = internal::GetNormalizedAxes(destination, a.ndim());

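    // Build the transpose order by placing each moved source axis at its destination position and filling the remaining
    // positions with the untouched axes in their original order. E.g. Moveaxis(a, {0}, {-1}) on a 3-d array yields the
    // order (1, 2, 0).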
    Axes order, source_axes, destination_axes;
    order.resize(a.ndim());
    source_axes.resize(a.ndim());
    destination_axes.resize(a.ndim());

    std::iota(source_axes.begin(), source_axes.end(), 0);
    std::iota(destination_axes.begin(), destination_axes.end(), 0);

    for (int8_t i = 0; i < source.ndim(); ++i) {
        order[normalized_destination[i]] = normalized_source[i];
        source_axes[normalized_source[i]] = -1;
        destination_axes[normalized_destination[i]] = -1;
    }

    auto source_iter = std::remove(source_axes.begin(), source_axes.end(), -1);
    auto destination_iter = std::remove(destination_axes.begin(), destination_axes.end(), -1);

    int8_t rest_dim = a.ndim() - source.ndim();
    CHAINERX_ASSERT(a.ndim() - destination.ndim() == rest_dim);
    CHAINERX_ASSERT(static_cast<int8_t>(source_iter - source_axes.begin()) == rest_dim);
    CHAINERX_ASSERT(static_cast<int8_t>(destination_iter - destination_axes.begin()) == rest_dim);

    for (int8_t i = 0; i < rest_dim; ++i) {
        order[destination_axes[i]] = source_axes[i];
    }

    return a.Transpose(order);
}

void CopyTo(const Array& dst, const Array& src, CastingMode casting, const Array& where) {
    internal::CheckNoUnsafeInplace(dst, {dst, src, where});

    switch (casting) {
        case CastingMode::kNo:
            if (dst.dtype() != src.dtype()) {
                throw DtypeError{"Source and destination must have same dtype."};
            }
            break;
        default:
            CHAINERX_NEVER_REACH();
    }

    const Array& src_b = src.shape() != dst.shape() ? src.BroadcastTo(dst.shape()) : src;
    const Array& where_b = where.shape() != dst.shape() ? where.BroadcastTo(dst.shape()) : where;

    {
        NoBackpropModeScope scope;
        dst.device().backend().CallKernel<WhereKernel>(where_b, src_b, dst, dst);
    }
}

} // namespace chainerx