#include "chainerx/routines/manipulation.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <absl/types/optional.h>

#include "chainerx/array.h"
#include "chainerx/axes.h"
#include "chainerx/backend.h"
#include "chainerx/backprop_mode.h"
#include "chainerx/backward_builder.h"
#include "chainerx/backward_context.h"
#include "chainerx/device.h"
#include "chainerx/dtype.h"
#include "chainerx/error.h"
#include "chainerx/graph.h"
#include "chainerx/kernels/creation.h"
#include "chainerx/kernels/indexing.h"
#include "chainerx/kernels/misc.h"
#include "chainerx/macro.h"
#include "chainerx/routines/creation.h"
#include "chainerx/routines/indexing.h"
#include "chainerx/routines/routines_util.h"
#include "chainerx/routines/type_util.h"
#include "chainerx/shape.h"
#include "chainerx/strides.h"

namespace chainerx {

Scalar AsScalar(const Array& a) {
    if (a.GetTotalSize() != 1) {
        throw DimensionError{"Cannot convert an array of size ", a.GetTotalSize(), " to a scalar, size must be 1."};
    }

    // Copy to the native device
    Array native_copy = a.ToNative();

    // Retrieve the value
    return VisitDtype(a.dtype(), [&native_copy](auto pt) -> Scalar {
        using T = typename decltype(pt)::type;
        const uint8_t* ptr = static_cast<const uint8_t*>(native_copy.data().get()) + native_copy.offset();
        auto typed_ptr = reinterpret_cast<const T*>(ptr);  // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
        return Scalar{*typed_ptr};
    });
}
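
// Usage sketch for AsScalar (illustrative comment; assumes Full from
// "chainerx/routines/creation.h", which this file already includes):
//
//     Array a = Full(Shape{1}, Scalar{3.0f});  // size-1 array
//     Scalar s = AsScalar(a);                  // s holds 3.0f
//
// Any array whose GetTotalSize() is not 1 makes AsScalar throw DimensionError.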

Array RollAxis(const Array& a, int8_t axis, int8_t start) {
    // TODO(hvy): Optimize the implementation.
    axis = internal::NormalizeAxis(axis, a.ndim());

    // start can be a.ndim() so we cannot use NormalizeAxis here.
    if (start < -a.ndim() || start > a.ndim()) {
        throw DimensionError{"start arg out of bounds. start: ", start, ", ndim: ", a.ndim()};
    }
    if (start < 0) {
        start += a.ndim();
    }

    Axes axes;
    for (int8_t i = 0; i < a.ndim(); ++i) {
        if (i == start) {
            axes.emplace_back(axis);
        }
        if (i != axis) {
            axes.emplace_back(i);
        }
    }
    if (start == a.ndim()) {
        axes.emplace_back(axis);
    }
    return Transpose(a, axes);
}
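
// Usage sketch for RollAxis: the given axis is moved so that it ends up at
// position `start` (NumPy rollaxis semantics), returning a view:
//
//     // a has shape (3, 4, 5, 6)
//     Array b = RollAxis(a, 2, 0);  // b has shape (5, 3, 4, 6) and shares a's data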

Array Transpose(const Array& a, const OptionalAxes& axes) {
    Axes real_axes;
    if (axes.has_value()) {
        if (axes->ndim() != a.ndim()) {
            throw DimensionError{"Axes do not match, input array dimensions: ", a.ndim(), " but axes: ", axes->ndim()};
        }
        real_axes = internal::GetNormalizedAxes(*axes, a.ndim());
    } else {
        for (int8_t i = 0; i < a.ndim(); ++i) {
            real_axes.emplace_back(a.ndim() - i - 1);
        }
    }
    CHAINERX_ASSERT(real_axes.ndim() == a.ndim());

    Shape out_shape;
    Strides out_strides;
    for (int8_t axis : real_axes) {
        out_shape.emplace_back(a.shape()[axis]);
        out_strides.emplace_back(a.strides()[axis]);
    }

    Array out = internal::MakeArray(out_shape, out_strides, a.dtype(), a.device(), a.data(), a.offset());

    BackwardBuilder bb{"transpose", a, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([real_axes](BackwardContext& bctx) {
            Axes backward_axes;
            backward_axes.resize(real_axes.ndim());
            for (int8_t i = 0; i < real_axes.ndim(); ++i) {
                backward_axes[real_axes[i]] = i;
            }
            bctx.input_grad() = bctx.output_grad()->Transpose(backward_axes);
        });
    }
    bb.Finalize();

    return out;
}
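
// Usage sketch for Transpose: both results are views; only shape and strides
// are permuted, no data moves.
//
//     // a has shape (2, 3, 4)
//     Array t0 = Transpose(a, absl::nullopt);  // axes reversed: shape (4, 3, 2)
//     Array t1 = Transpose(a, Axes{2, 0, 1});  // shape (4, 2, 3)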

namespace {

// Returns a shape where the length of at most one dimension is inferred from the total size and the remaining dimensions.
// Such a dimension is given by a negative length, i.e. Shape{2, 3, -1}.
// If the given shape does not contain such a dimension, this function returns a copy of the given shape.
// If multiple dimensions have negative lengths, or if the negative-length dimension cannot be inferred due to
// non-divisibility, a DimensionError is thrown.
Shape GetInferredShape(const Shape& shape, int64_t total_size) {
    Shape inferred_shape = shape;

    auto it = std::find_if(inferred_shape.begin(), inferred_shape.end(), [](int64_t dim) { return dim < 0; });
    if (it != inferred_shape.end()) {
        if (std::find_if(std::next(it), inferred_shape.end(), [](int64_t dim) { return dim < 0; }) != inferred_shape.end()) {
            throw DimensionError{"Can only specify one unknown dimension"};
        }
        int64_t rest_size = std::accumulate(inferred_shape.begin(), it, int64_t{1}, std::multiplies<>()) *
                            std::accumulate(std::next(it), inferred_shape.end(), int64_t{1}, std::multiplies<>());
        if (rest_size == 0) {
            throw DimensionError{"Cannot reshape array of size ", total_size, " into an ambiguous shape ", shape};
        }
        *it = total_size / rest_size;
    }

    if (total_size != inferred_shape.GetTotalSize()) {
        throw DimensionError{"Cannot reshape array of size ", total_size, " into shape ", shape};
    }
    return inferred_shape;
}
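
// Sketch of the inference rule:
//
//     GetInferredShape(Shape{2, 3, -1}, 24);  // -> Shape{2, 3, 4}, since 24 / (2 * 3) == 4
//     GetInferredShape(Shape{2, -1}, 7);      // throws DimensionError: 7 is not divisible by 2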

}  // namespace

Array Reshape(const Array& a, const Shape& newshape) {
    const Shape& in_shape = a.shape();
    const Strides& in_strides = a.strides();

    // If the shape is unchanged, just return a view.
    if (in_shape == newshape) {
        return a.MakeView();
    }

    // Check for invalid shape.
    int64_t total_size = in_shape.GetTotalSize();
    Shape out_shape = GetInferredShape(newshape, total_size);
    int64_t item_size = a.GetItemSize();
    Strides strides{};
    if (total_size == 0) {
        // Calculate the strides for a 0-sized array.
        strides.resize(out_shape.ndim());
        strides.back() = item_size;
        for (int8_t i = out_shape.ndim() - 1; i >= 1; --i) {
            strides[i - 1] = strides[i] * std::max(int64_t{1}, out_shape[i]);
        }
    } else {
        // Calculate the strides for a non-0-sized array.

        // reduced_shape and reduced_strides are the shortest shape and strides to which the input shape and strides can be
        // converted without a copy.
        Shape reduced_shape{};
        Strides reduced_strides{};
        if (total_size == 1) {
            reduced_shape.emplace_back(int64_t{1});
            reduced_strides.emplace_back(item_size);
        } else {
            int8_t i = 0;
            // Ignore preceding 1-length dimensions.
            while (i < in_shape.ndim() && in_shape[i] == 1) {
                ++i;
            }
            // Add the first pair.
            reduced_shape.emplace_back(in_shape[i]);
            reduced_strides.emplace_back(in_strides[i]);
            ++i;
            // Reduce the remaining pairs.
            for (; i < in_shape.ndim(); ++i) {
                int64_t dim = in_shape[i];
                int64_t st = in_strides[i];
                CHAINERX_ASSERT(dim > 0);
                if (dim == 1) {
                    // If the axis has unit length, skip this dimension.
                } else if (dim * st == reduced_strides.back()) {
                    // If the pair is compatible with the previous stride, fold the pair into it.
                    reduced_shape.back() *= dim;
                    reduced_strides.back() = st;
                } else {
                    // Otherwise, add a new shape and stride.
                    reduced_shape.emplace_back(dim);
                    reduced_strides.emplace_back(st);
                }
            }
        }
        CHAINERX_ASSERT(reduced_shape.size() == reduced_strides.size());
        CHAINERX_ASSERT(!reduced_shape.empty());

        // Construct the strides for a no-copy reshape.
        // If that is not possible, can_reshape_without_copy will be false.
        bool can_reshape_without_copy = true;
        if (out_shape.ndim() > 0) {
            int64_t last_stride = reduced_shape[0] * reduced_strides[0];
            size_t i_dim = 0;
            for (int64_t dim : out_shape) {
                if (dim <= 1) {
                    strides.emplace_back(last_stride);
                    continue;
                }
                if (i_dim >= reduced_shape.size() || reduced_shape[i_dim] % dim != 0) {
                    strides.clear();
                    can_reshape_without_copy = false;
                    break;
                }
                reduced_shape[i_dim] /= dim;
                last_stride = reduced_shape[i_dim] * reduced_strides[i_dim];
                strides.emplace_back(last_stride);
                if (reduced_shape[i_dim] == 1) {
                    ++i_dim;
                }
            }
        }

        if (!can_reshape_without_copy) {
            // A copy is required.
            return a.Copy().Reshape(out_shape);
        }
        CHAINERX_ASSERT(strides.size() == out_shape.size());
    }

    Array out = internal::MakeArray(out_shape, strides, a.dtype(), a.device(), a.data(), a.offset());

    BackwardBuilder bb{"reshape", a, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([in_shape](BackwardContext& bctx) { bctx.input_grad() = bctx.output_grad()->Reshape(in_shape); });
    }
    bb.Finalize();

    CHAINERX_ASSERT(out.shape() == out_shape);
    CHAINERX_ASSERT(out.strides().size() == out_shape.size());
    return out;
}
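
// Usage sketch for Reshape (illustrative comment; assumes Arange from
// "chainerx/routines/creation.h", which this file already includes):
//
//     Array a = Arange(6, Dtype::kFloat32);        // shape (6,), contiguous
//     Array b = Reshape(a, Shape{2, 3});           // no-copy view: strides are derivable
//     Array c = Reshape(b.Transpose(), Shape{6});  // transposed strides are incompatible,
//                                                  // so this falls back to a.Copy().Reshape(...)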

Array Squeeze(const Array& a, const OptionalAxes& axis) {
    const Shape& in_shape = a.shape();
    const Strides& in_strides = a.strides();

    Shape out_shape{};
    Strides out_strides{};

    if (axis.has_value()) {
        const Axes sorted_axis = internal::GetSortedAxes(*axis, in_shape.ndim());

        int64_t i_axis = 0;
        for (int64_t i = 0; i < in_shape.ndim(); ++i) {
            if (i_axis < static_cast<int64_t>(sorted_axis.size()) && sorted_axis[i_axis] == i) {
                ++i_axis;
                if (in_shape[i] != 1) {
                    std::ostringstream os;
                    os << "Cannot squeeze out non-unit-length axes, where shape was " << in_shape.ToString();
                    os << " and axes were (";
                    for (auto it = axis->begin(); it != axis->end(); ++it) {
                        if (it != axis->begin()) {
                            os << ", ";
                        }
                        os << *it;
                    }
                    os << (axis->size() == 1 ? ",)." : ").");
                    throw DimensionError{os.str()};
                }
            } else {
                out_shape.emplace_back(in_shape[i]);
                out_strides.emplace_back(in_strides[i]);
            }
        }
    } else {  // All axes are candidates for removal if none are given.
        for (int64_t i = 0; i < in_shape.ndim(); ++i) {
            if (in_shape[i] != 1) {
                out_shape.emplace_back(in_shape[i]);
                out_strides.emplace_back(in_strides[i]);
            }
        }
    }

    if (in_shape.size() == out_shape.size()) {
        return a;
    }

    Array out = internal::MakeArray(out_shape, out_strides, a.dtype(), a.device(), a.data(), a.offset());

    BackwardBuilder bb{"squeeze", a, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([in_shape](BackwardContext& bctx) { bctx.input_grad() = bctx.output_grad()->Reshape(in_shape); });
    }
    bb.Finalize();

    return out;
}
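
// Usage sketch for Squeeze:
//
//     // a has shape (1, 3, 1)
//     Array b = Squeeze(a, absl::nullopt);  // all unit axes removed: shape (3,)
//     Array c = Squeeze(a, Axes{0});        // only axis 0 removed: shape (3, 1)
//     // Squeeze(a, Axes{1}) throws DimensionError, since axis 1 has length 3.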

Array BroadcastTo(const Array& array, const Shape& shape) {
    const Shape& in_shape = array.shape();
    const Strides& in_strides = array.strides();

    if (in_shape.size() > shape.size()) {
        throw DimensionError{"Cannot broadcast to smaller dimensions from ", in_shape, " to ", shape, "."};
    }

    // Compute the new set of strides after broadcasting.
    Strides strides;
    strides.resize(shape.ndim());
    int8_t i_in = in_shape.ndim() - 1;
    for (int8_t i_out = shape.ndim() - 1; i_out >= 0; --i_out) {
        int64_t out_dim = shape[i_out];
        // If this dimension is to be broadcast, nonbroadcast_stride is unset.
        // Otherwise, it holds the new stride.
        absl::optional<int64_t> nonbroadcast_stride{};
        if (i_in >= 0) {
            int64_t in_dim = in_shape[i_in];
            if (in_dim == 1) {
                // do nothing; broadcast
            } else if (in_dim == out_dim) {
                nonbroadcast_stride = in_strides[i_in];
            } else {
                throw DimensionError{"Invalid broadcast from ", in_shape, " to ", shape};
            }
            --i_in;
        } else {
            // do nothing; broadcast
        }

        if (nonbroadcast_stride.has_value()) {
            // non-broadcast dimension
            strides[i_out] = nonbroadcast_stride.value();
        } else {
            // broadcast dimension
            strides[i_out] = int64_t{0};
        }
    }
    CHAINERX_ASSERT(i_in == -1);
    CHAINERX_ASSERT(strides.ndim() == shape.ndim());

    Array out = internal::MakeArray(shape, strides, array.dtype(), array.device(), array.data(), array.offset());

    BackwardBuilder bb{"broadcast_to", array, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([in_shape](BackwardContext& bctx) {
            const Array& gout = *bctx.output_grad();
            if (gout.shape() == in_shape) {
                bctx.input_grad() = gout;
                return;
            }

            int8_t lead = gout.ndim() - in_shape.ndim();
            Axes lead_axis{};
            lead_axis.resize(lead);
            std::iota(lead_axis.begin(), lead_axis.end(), int8_t{0});

            Axes axis{lead_axis};
            for (int8_t i = 0; i < in_shape.ndim(); ++i) {
                if (in_shape[i] == 1) {
                    axis.emplace_back(i + lead);
                }
            }
            axis.erase(std::unique(axis.begin(), axis.end()), axis.end());  // Sum does not accept axes with duplicate elements.

            Array gin = gout.Sum(axis, true);
            if (lead > 0) {
                bctx.input_grad() = gin.Squeeze(lead_axis);
            } else {
                bctx.input_grad() = std::move(gin);
            }
        });
    }
    bb.Finalize();

    return out;
}
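
// Usage sketch for BroadcastTo:
//
//     // a has shape (3, 1)
//     Array b = BroadcastTo(a, Shape{2, 3, 4});  // shape (2, 3, 4)
//
// No data is copied: broadcast dimensions simply get stride 0, so all elements
// along them alias the same memory.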

namespace {

Array ConcatenateImpl(const std::vector<Array>& arrays, int8_t axis) {
    if (arrays.empty()) {
        throw DimensionError{"Need at least one array to concatenate"};
    }

    Shape shape = arrays.front().shape();
    Dtype out_dtype = ResultType(arrays);
    Device& device = arrays.front().device();
    int8_t ndim = arrays.front().ndim();
    axis = internal::NormalizeAxis(axis, ndim);
    shape[axis] = 0;
    std::vector<int64_t> indices;
    indices.reserve(arrays.size() - 1);

    for (const Array& array : arrays) {
        const Shape& s = array.shape();
        if (ndim != array.ndim()) {
            throw DimensionError{"All the input arrays must have the same number of dimensions"};
        }
        for (int8_t i = 0; i < ndim; ++i) {
            if (axis == i) {
                shape[i] += s[i];
            } else if (shape[i] != s[i]) {
                throw DimensionError{"All the input array dimensions except for the concatenation axis must match exactly"};
            }
        }
        if (indices.size() < arrays.size() - 1) {
            indices.emplace_back(shape[axis]);
        }
    }

    Strides strides{shape, out_dtype};

    // Align with NumPy's strides behavior.
    auto last_zero_it = std::find(shape.rbegin(), shape.rend(), int64_t{0});
    if (last_zero_it != shape.rend()) {
        std::fill(strides.rbegin() + (last_zero_it - shape.rbegin() + 1), strides.rend(), int64_t{0});
    }

    Array out = internal::Empty(shape, out_dtype, strides, device);

    size_t in_size = arrays.size();

    // If input dtypes are mixed, elements in the input arrays are cast to the resulting dtype.
    // Their original dtypes must therefore be remembered in order to cast the computed gradients back in the backward pass.
    std::vector<Dtype> in_dtypes;
    in_dtypes.reserve(in_size);

    std::vector<ConstArrayRef> array_refs;
    array_refs.reserve(in_size);

    {
        NoBackpropModeScope scope{};
        int64_t out_offset = 0;
        for (const Array& array : arrays) {
            const Shape& shape = array.shape();
            Array sliced_out = internal::MakeArray(shape, strides, out_dtype, device, out.data(), out_offset);
            Dtype in_dtype = array.dtype();
            in_dtypes.emplace_back(in_dtype);
            // Note: CopyKernel casts the input array elements to the dtype of the output array.
            device.backend().CallKernel<CopyKernel>(array, sliced_out);
            array_refs.emplace_back(ConstArrayRef{array});
            out_offset += strides[axis] * shape[axis];
        }
    }

    {
        BackwardBuilder bb{"concatenate", array_refs, out};
        if (BackwardBuilder::Target bt = bb.CreateTarget()) {
            bt.Define([indices = std::move(indices), axis, in_dtypes = std::move(in_dtypes)](BackwardContext& bctx) {
                const Array& gy = *bctx.output_grad();
                Dtype out_dtype = gy.dtype();
                std::vector<Array> gxs = Split(gy, indices, axis);
                for (size_t i = 0; i < gxs.size(); ++i) {
                    Dtype in_dtype = in_dtypes[i];
                    if (out_dtype != in_dtype) {
                        bctx.input_grad(i) = gxs[i].AsType(in_dtype);
                    } else {
                        bctx.input_grad(i) = std::move(gxs[i]);
                    }
                }
            });
        }
        bb.Finalize();
    }

    return out;
}

}  // namespace

Array Concatenate(const std::vector<Array>& arrays) { return ConcatenateImpl(arrays, 0); }

Array Concatenate(const std::vector<Array>& arrays, absl::optional<int8_t> axis) {
    if (!axis.has_value()) {
        // Special case: make the input arrays 1-dimensional and concatenate along the first axis.
        std::vector<Array> raveled_arrays;
        raveled_arrays.reserve(arrays.size());
        std::transform(arrays.begin(), arrays.end(), std::back_inserter(raveled_arrays), [](const Array& array) {
            Shape shape{array.GetTotalSize()};
            return array.Reshape(shape);
        });
        return ConcatenateImpl(raveled_arrays, 0);
    }
    return ConcatenateImpl(arrays, *axis);
}
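
// Usage sketch for Concatenate:
//
//     // a and b both have shape (2, 3)
//     Array c = Concatenate({a, b});                // axis 0: shape (4, 3)
//     Array d = Concatenate({a, b}, 1);             // axis 1: shape (2, 6)
//     Array e = Concatenate({a, b}, absl::nullopt); // raveled first: shape (12,)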

Array Stack(const std::vector<Array>& arrays, int8_t axis) {
    std::vector<Array> reshaped_arrays;
    reshaped_arrays.reserve(arrays.size());
    std::transform(arrays.begin(), arrays.end(), std::back_inserter(reshaped_arrays), [axis](const Array& array) {
        return ExpandDims(array, axis);
    });
    return ConcatenateImpl(reshaped_arrays, axis);
}
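
// Usage sketch for Stack: each input gains a new axis at `axis`, then the
// inputs are concatenated along it.
//
//     // a, b, c each have shape (2, 3)
//     Array s0 = Stack({a, b, c}, 0);  // shape (3, 2, 3)
//     Array s1 = Stack({a, b, c}, 1);  // shape (2, 3, 3)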

namespace {

// Defines the backward pass for Split, both by-sections and by-indices.
void DefineSplitBackward(const Array& ary, const std::vector<Array>& out, int8_t axis_norm) {
    // TODO(hvy): Avoid creating an intermediate vector of references when BackwardBuilder accepts std::vector<Array>.
    std::vector<ConstArrayRef> out_refs{};
    out_refs.reserve(out.size());
    std::transform(out.begin(), out.end(), std::back_inserter(out_refs), [](const Array& array) { return ConstArrayRef{array}; });

    // TODO(imanishi): Avoid creating shapes of forward outputs.
    std::vector<Shape> shapes;
    shapes.reserve(out.size());
    std::transform(out.begin(), out.end(), std::back_inserter(shapes), [](const Array& array) { return array.shape(); });

    BackwardBuilder bb{"split", ary, out_refs};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([axis_norm, shapes = std::move(shapes), dtype = ary.dtype(), &device = ary.device()](BackwardContext& bctx) {
            std::vector<Array> output_grads;
            output_grads.reserve(bctx.output_count());
            for (size_t i = 0; i < bctx.output_count(); ++i) {
                const absl::optional<Array>& gy = bctx.output_grad(i);
                output_grads.emplace_back(gy.has_value() ? *gy : Zeros(shapes[i], dtype, device));
            }
            bctx.input_grad() = ConcatenateImpl(output_grads, axis_norm);
        });
    }
    bb.Finalize();
}

}  // namespace

std::vector<Array> Split(const Array& ary, int64_t sections, int8_t axis) {
    if (sections < 1) {
        throw DimensionError("Number of sections must be larger than 0.");
    }

    const Shape& in_shape = ary.shape();
    int8_t axis_norm = internal::NormalizeAxis(axis, ary.ndim());
    int64_t in_dim = in_shape[axis_norm];

    if (in_dim % sections != 0) {
        throw DimensionError("Array split does not result in an equal division.");
    }

    Shape out_shape = in_shape;
    int64_t out_dim = in_dim / sections;
    out_shape[axis_norm] = out_dim;
    int64_t out_stride = ary.strides()[axis_norm];
    int64_t out_offset = ary.offset();
    bool is_empty = ary.GetTotalSize() == 0;

    std::vector<Array> out{};
    out.reserve(sections);

    for (int64_t i = 0; i < sections; ++i) {
        out.emplace_back(internal::MakeArray(out_shape, ary.strides(), ary.dtype(), ary.device(), ary.data(), out_offset));

        // Empty arrays should all have an offset of 0 to avoid e.g. out-of-memory errors.
        if (!is_empty) {
            out_offset += out_stride * out_dim;
        }
    }

    DefineSplitBackward(ary, out, axis_norm);

    return out;
}
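
// Usage sketch for Split by sections:
//
//     // a has shape (6, 4)
//     std::vector<Array> parts = Split(a, 3, 0);  // three views of shape (2, 4)
//     // Split(a, 4, 0) throws DimensionError: 6 is not divisible by 4.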

std::vector<Array> Split(const Array& ary, std::vector<int64_t> indices, int8_t axis) {
    const Shape& in_shape = ary.shape();
    int8_t axis_norm = internal::NormalizeAxis(axis, ary.ndim());
    int64_t in_dim = in_shape[axis_norm];

    // Wrap negative indices.
    std::transform(
            indices.begin(), indices.end(), indices.begin(), [in_dim](int64_t index) { return index >= 0 ? index : index + in_dim; });
    indices.emplace_back(in_dim);

    Shape out_shape = in_shape;
    int64_t out_stride = ary.strides()[axis_norm];
    int64_t out_offset = ary.offset();
    int64_t slice_start = 0;
    bool is_empty = ary.GetTotalSize() == 0;

    std::vector<Array> out{};
    out.reserve(indices.size());

    for (int64_t index : indices) {
        int64_t slice_stop = std::min(in_dim, std::max(int64_t{0}, index));
        int64_t slice_length = slice_stop - slice_start;

        // Update the dimension of interest in the output shape.
        out_shape[axis_norm] = std::max(int64_t{0}, slice_length);

        out.emplace_back(internal::MakeArray(out_shape, ary.strides(), ary.dtype(), ary.device(), ary.data(), out_offset));

        // Empty arrays should all have an offset of 0 to avoid e.g. out-of-memory errors.
        if (!is_empty) {
            out_offset += out_stride * slice_length;
        }

        slice_start = slice_stop;
    }

    DefineSplitBackward(ary, out, axis_norm);

    return out;
}
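
// Usage sketch for Split by indices (NumPy split semantics):
//
//     // a has shape (6,)
//     std::vector<Array> parts = Split(a, {1, 3}, 0);
//     // parts[0] is a[0:1], parts[1] is a[1:3], parts[2] is a[3:6]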

std::vector<Array> DSplit(const Array& ary, int64_t sections) {
    if (sections < 1) {
        throw DimensionError("Number of sections must be larger than 0.");
    }

    if (ary.ndim() < 3) {
        throw DimensionError("dsplit only works on arrays of 3 or more dimensions.");
    }

    return Split(ary, sections, 2);
}

std::vector<Array> DSplit(const Array& ary, std::vector<int64_t> indices) {
    if (ary.ndim() < 3) {
        throw DimensionError("dsplit only works on arrays of 3 or more dimensions.");
    }

    return Split(ary, std::move(indices), 2);
}

std::vector<Array> VSplit(const Array& ary, int64_t sections) {
    if (sections < 1) {
        throw DimensionError("Number of sections must be larger than 0.");
    }

    if (ary.ndim() < 2) {
        throw DimensionError("vsplit only works on arrays of 2 or more dimensions.");
    }

    return Split(ary, sections, 0);
}

std::vector<Array> VSplit(const Array& ary, std::vector<int64_t> indices) {
    if (ary.ndim() < 2) {
        throw DimensionError("vsplit only works on arrays of 2 or more dimensions.");
    }

    return Split(ary, std::move(indices), 0);
}

std::vector<Array> HSplit(const Array& ary, int64_t sections) {
    if (sections < 1) {
        throw DimensionError("Number of sections must be larger than 0.");
    }

    if (ary.ndim() == 0) {
        throw DimensionError("hsplit only works on arrays of 1 or more dimensions.");
    }

    if (ary.ndim() > 1) {
        return Split(ary, sections, 1);
    }

    return Split(ary, sections, 0);
}

std::vector<Array> HSplit(const Array& ary, std::vector<int64_t> indices) {
    if (ary.ndim() == 0) {
        throw DimensionError("hsplit only works on arrays of 1 or more dimensions.");
    }

    if (ary.ndim() > 1) {
        return Split(ary, std::move(indices), 1);
    }

    return Split(ary, std::move(indices), 0);
}

Array Swapaxes(const Array& a, int8_t axis1, int8_t axis2) {
    Shape shape = a.shape();
    Strides strides = a.strides();

    axis1 = internal::NormalizeAxis(axis1, a.ndim());
    axis2 = internal::NormalizeAxis(axis2, a.ndim());

    std::iter_swap(shape.begin() + axis1, shape.begin() + axis2);
    std::iter_swap(strides.begin() + axis1, strides.begin() + axis2);
    Array out = internal::MakeArray(shape, strides, a.dtype(), a.device(), a.data(), a.offset());

    BackwardBuilder bb{"swapaxes", a, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([axis1, axis2](BackwardContext& bctx) {
            const Array& gout = *bctx.output_grad();
            bctx.input_grad() = Swapaxes(gout, axis1, axis2);
        });
    }
    bb.Finalize();

    return out;
}
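
// Usage sketch for Swapaxes:
//
//     // a has shape (2, 3, 4)
//     Array b = Swapaxes(a, 0, 2);  // shape (4, 3, 2); a view, since only the
//                                   // shape and strides entries are exchanged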

Array Ravel(const Array& a) { return a.Reshape({a.GetTotalSize()}); }

Array Repeat(const Array& a, int64_t repeats, absl::optional<int8_t> axis) {
    if (repeats < 0) {
        throw DimensionError("repeats must not be negative.");
    }

    int8_t target_axis = 0;
    Array target_array;

    if (axis.has_value()) {
        target_axis = internal::NormalizeAxis(*axis, a.ndim());
        target_array = a;
    } else {
        target_array = Reshape(a, Shape({a.shape().GetTotalSize()}));
    }

    Shape broadcast_shape = target_array.shape();
    broadcast_shape.insert(broadcast_shape.begin() + target_axis + 1, repeats);

    Shape reshape_shape = target_array.shape();
    reshape_shape[target_axis] *= repeats;

    Array expanded_array = ExpandDims(target_array, target_axis + 1);
    Array broadcasted_array = BroadcastTo(expanded_array, broadcast_shape);
    Array reshaped_array = Reshape(broadcasted_array, reshape_shape);
    return AsContiguousArray(reshaped_array);
}
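
// Usage sketch for Repeat with a scalar count:
//
//     // a is [1, 2, 3] with shape (3,)
//     Array b = Repeat(a, 2, 0);  // [1, 1, 2, 2, 3, 3], shape (6,)
//
// With axis == absl::nullopt the input is raveled first, matching NumPy.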

Array Repeat(const Array& a, const std::vector<int64_t>& repeats, absl::optional<int8_t> axis) {
    if (repeats.size() == 1) {
        return Repeat(a, repeats[0], axis);
    }

    if (axis.has_value()) {
        int8_t target_axis = internal::NormalizeAxis(*axis, a.ndim());

        if (repeats.size() != static_cast<size_t>(a.shape()[target_axis])) {
            throw DimensionError("The number of repeats must match the length of the axis.");
        }

        if (std::any_of(repeats.begin(), repeats.end(), [](int64_t x) -> bool { return x < 0; })) {
            throw DimensionError("repeats must not be negative.");
        }

        // TODO(durswd): Should be optimized.
        std::vector<Array> output_elements;
        std::vector<Array> splitted = Split(a, a.shape()[target_axis], target_axis);

        for (size_t i = 0; i < splitted.size(); ++i) {
            for (int32_t j = 0; j < repeats[i]; ++j) {
                output_elements.push_back(splitted[i]);
            }
        }

        Array out = Concatenate(output_elements, target_axis);

        return AsContiguousArray(out);
    }

    if (repeats.size() != static_cast<size_t>(a.shape().GetTotalSize())) {
        throw DimensionError("The number of repeats must match the total size of the array.");
    }

    Array reshaped = Reshape(a, Shape({a.shape().GetTotalSize()}));
    return Repeat(reshaped, repeats, 0);
}
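
// Usage sketch for Repeat with per-element counts:
//
//     // a is [1, 2] with shape (2,)
//     Array b = Repeat(a, std::vector<int64_t>{1, 2}, 0);  // [1, 2, 2], shape (3,)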

Array ExpandDims(const Array& a, int8_t axis) {
    Shape shape = a.shape();

    axis = internal::NormalizeAxis(axis, a.ndim() + 1);

    shape.insert(shape.begin() + axis, 1);

    Array out = a.Reshape(shape);

    // A trivial reshape that adds a new axis should just return a view of the input.
    CHAINERX_ASSERT(out.raw_data() == a.raw_data());

    return out;
}
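
// Usage sketch for ExpandDims:
//
//     // a has shape (3, 4)
//     Array b = ExpandDims(a, 1);   // shape (3, 1, 4)
//     Array c = ExpandDims(a, -1);  // shape (3, 4, 1)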

Array Flip(const Array& m, const OptionalAxes& axes) {
    Axes real_axes;
    if (axes.has_value()) {
        real_axes = internal::GetNormalizedAxes(*axes, m.ndim());
    } else {
        for (int8_t i = 0; i < m.ndim(); ++i) {
            real_axes.emplace_back(m.ndim() - i - 1);
        }
    }

    Strides strides = m.strides();
    Shape shape = m.shape();
    int64_t offset = m.offset();
    for (auto axis : real_axes) {
        // Point the offset at the last element of this dimension.
        offset += std::max<int64_t>(shape[axis] - 1, 0) * strides[axis];
        if (shape[axis] != 0) {
            strides[axis] = -strides[axis];
        }
    }

    auto is_zero = std::find(shape.begin(), shape.end(), 0);
    if (is_zero != shape.end()) {
        offset = 0;
    }

    Array out = internal::MakeArray(m.shape(), strides, m.dtype(), m.device(), m.data(), offset);

    BackwardBuilder bb{"flip", m, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([real_axes](BackwardContext& bctx) {
            const Array& gout = *bctx.output_grad();
            bctx.input_grad() = Flip(gout, real_axes);
        });
    }
    bb.Finalize();

    return out;
}
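
// Usage sketch for Flip:
//
//     // a is [[1, 2], [3, 4]] with shape (2, 2)
//     Array b = Flip(a, Axes{0});        // [[3, 4], [1, 2]]
//     Array c = Flip(a, absl::nullopt);  // both axes flipped: [[4, 3], [2, 1]]
//
// The result is a view: the offset is moved to the last element along each
// flipped axis and the corresponding strides are negated.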

Array Fliplr(const Array& m) {
    if (m.ndim() < 2) {
        throw DimensionError{"Input must be >= 2-d."};
    }
    return Flip(m, Axes{1});
}

Array Flipud(const Array& m) {
    if (m.ndim() < 1) {
        throw DimensionError{"Input must be >= 1-d."};
    }
    return Flip(m, Axes{0});
}

Array AtLeast2D(const Array& x) {
    Array out;

    {
        NoBackpropModeScope scope;

        switch (x.ndim()) {
            case 0:
                out = x.Reshape({1, 1});
                break;
            case 1: {
                Shape shape = x.shape();
                Strides strides = x.strides();
                shape.insert(shape.begin(), 1);
                strides.insert(strides.begin(), 0);
                out = internal::MakeArray(shape, strides, x.dtype(), x.device(), x.data());
            } break;
            default:
                out = x.MakeView();
                break;
        }
    }

    BackwardBuilder bb{"atleast_2d", x, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([in_shape = x.shape(), ndim = x.ndim()](BackwardContext& bctx) {
            if (ndim <= 1) {
                bctx.input_grad() = bctx.output_grad()->Reshape(in_shape);
            } else {
                bctx.input_grad() = *bctx.output_grad();
            }
        });
    }
    bb.Finalize();

    return out;
}

Array AtLeast3D(const Array& x) {
    Array out;

    {
        NoBackpropModeScope scope;

        switch (x.ndim()) {
            case 0:
                out = x.Reshape({1, 1, 1});
                break;
            case 1: {
                Shape shape = x.shape();
                Strides strides = x.strides();
                shape.insert(shape.begin(), 1);
                shape.insert(shape.end(), 1);
                strides.insert(strides.begin(), 0);
                strides.insert(strides.end(), 0);
                out = internal::MakeArray(shape, strides, x.dtype(), x.device(), x.data());
            } break;
            case 2: {
                Shape shape = x.shape();
                Strides strides = x.strides();
                shape.insert(shape.end(), 1);
                strides.insert(strides.end(), 0);
                out = internal::MakeArray(shape, strides, x.dtype(), x.device(), x.data());
            } break;
            default:
                out = x.MakeView();
                break;
        }
    }

    BackwardBuilder bb{"atleast_3d", x, out};
    if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
        bt.Define([in_shape = x.shape(), ndim = x.ndim()](BackwardContext& bctx) {
            if (ndim <= 2) {
                bctx.input_grad() = bctx.output_grad()->Reshape(in_shape);
            } else {
                bctx.input_grad() = *bctx.output_grad();
            }
        });
    }
    bb.Finalize();

    return out;
}
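
// Shape sketch for AtLeast2D / AtLeast3D (matching NumPy's atleast_2d/atleast_3d):
//
//     AtLeast2D: () -> (1, 1)     (3,) -> (1, 3)     (2, 3) -> (2, 3)
//     AtLeast3D: () -> (1, 1, 1)  (3,) -> (1, 3, 1)  (2, 3) -> (2, 3, 1)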

Array HStack(const std::vector<Array>& arrays) {
    if (arrays.empty()) {
        throw DimensionError{"Need at least one array to stack"};
    }

    if (arrays.front().ndim() <= 1) {
        return Concatenate(arrays, 0);
    }
    return Concatenate(arrays, 1);
}

Array VStack(const std::vector<Array>& arrays) {
    if (arrays.empty()) {
        throw DimensionError{"Need at least one array to stack"};
    }

    std::vector<Array> reshaped_arrays(arrays.size());
    std::transform(arrays.begin(), arrays.end(), reshaped_arrays.begin(), AtLeast2D);

    return Concatenate(reshaped_arrays, 0);
}

Array DStack(const std::vector<Array>& arrays) {
    if (arrays.empty()) {
        throw DimensionError{"Need at least one array to stack"};
    }

    std::vector<Array> reshaped_arrays(arrays.size());
    std::transform(arrays.begin(), arrays.end(), reshaped_arrays.begin(), AtLeast3D);
    return Concatenate(reshaped_arrays, 2);
}
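
// Shape sketch for the stack helpers, given a and b of shape (2, 3):
//
//     HStack({a, b});  // concatenates along axis 1: shape (2, 6)
//     VStack({a, b});  // concatenates along axis 0: shape (4, 3)
//     DStack({a, b});  // adds a third axis, then concatenates along it: shape (2, 3, 2)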

Array Moveaxis(const Array& a, const Axes& source, const Axes& destination) {
    if (source.size() != destination.size()) {
        throw DimensionError{"Invalid Source or Destination Axes"};
    }

    if (source.empty()) {
        return a;
    }

    const Axes& normalized_source = internal::GetNormalizedAxes(source, a.ndim());
    const Axes& normalized_destination = internal::GetNormalizedAxes(destination, a.ndim());

    Axes order, source_axes, destination_axes;
    order.resize(a.ndim());
    source_axes.resize(a.ndim());
    destination_axes.resize(a.ndim());

    std::iota(source_axes.begin(), source_axes.end(), 0);
    std::iota(destination_axes.begin(), destination_axes.end(), 0);

    for (int8_t i = 0; i < source.ndim(); ++i) {
        order[normalized_destination[i]] = normalized_source[i];
        source_axes[normalized_source[i]] = -1;
        destination_axes[normalized_destination[i]] = -1;
    }

    auto source_iter = std::remove(source_axes.begin(), source_axes.end(), -1);
    auto destination_iter = std::remove(destination_axes.begin(), destination_axes.end(), -1);

    int8_t rest_dim = a.ndim() - source.ndim();
    CHAINERX_ASSERT(a.ndim() - destination.ndim() == rest_dim);
    CHAINERX_ASSERT(static_cast<int8_t>(source_iter - source_axes.begin()) == rest_dim);
    CHAINERX_ASSERT(static_cast<int8_t>(destination_iter - destination_axes.begin()) == rest_dim);

    for (int8_t i = 0; i < rest_dim; ++i) {
        order[destination_axes[i]] = source_axes[i];
    }

    return a.Transpose(order);
}
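
// Usage sketch for Moveaxis:
//
//     // a has shape (2, 3, 4)
//     Array b = Moveaxis(a, Axes{0}, Axes{-1});       // shape (3, 4, 2)
//     Array c = Moveaxis(a, Axes{0, 1}, Axes{1, 0});  // shape (3, 2, 4)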

void CopyTo(const Array& dst, const Array& src, CastingMode casting, const Array& where) {
    internal::CheckNoUnsafeInplace(dst, {dst, src, where});

    switch (casting) {
        case CastingMode::kNo:
            if (dst.dtype() != src.dtype()) {
                throw DtypeError{"Source and destination must have the same dtype."};
            }
            break;
        default:
            CHAINERX_NEVER_REACH();
    }

    const Array& src_b = src.shape() != dst.shape() ? src.BroadcastTo(dst.shape()) : src;
    const Array& where_b = where.shape() != dst.shape() ? where.BroadcastTo(dst.shape()) : where;

    {
        NoBackpropModeScope scope;
        dst.device().backend().CallKernel<WhereKernel>(where_b, src_b, dst, dst);
    }
}
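
// Usage sketch for CopyTo (illustrative comment; assumes Zeros and Full from
// "chainerx/routines/creation.h", which this file already includes):
//
//     Array dst = Zeros(Shape{2, 3}, Dtype::kFloat32);
//     Array src = Full(Shape{2, 3}, Scalar{1.0f});
//     Array mask = ...;  // boolean array broadcastable to dst's shape
//     CopyTo(dst, src, CastingMode::kNo, mask);  // dst[i] = mask[i] ? src[i] : dst[i]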

}  // namespace chainerx