1 #pragma once
2 
3 #include <cstdint>
4 
5 #include "chainerx/array.h"
6 #include "chainerx/kernel.h"
7 #include "chainerx/scalar.h"
8 
9 namespace chainerx {
10 
class ArangeKernel : public Kernel {
public:
    // Fills `out` with a range of values determined by `start` and `step`.
    //
    // NOTE(review): presumably NumPy-style arange semantics, i.e. out[i] = start + i * step,
    // with the number of elements taken from `out` (expected to be 1-dim) — confirm against
    // the backend implementations and the caller.
    virtual void Call(Scalar start, Scalar step, const Array& out) = 0;
};
15 
class CopyKernel : public Kernel {
public:
    // Copies the elements of `a` into `out`.
    //
    // The arrays must match in shape and dtype, and both must reside on the device this
    // kernel runs on.
    virtual void Call(const Array& a, const Array& out) = 0;
};
23 
class IdentityKernel : public Kernel {
public:
    // Fills `out` with the identity matrix: ones on the main diagonal and zeros elsewhere.
    //
    // `out` must be a square 2-dim array.
    virtual void Call(const Array& out) = 0;
};
30 
class EyeKernel : public Kernel {
public:
    // Fills the 2-dim array `out` with ones along the k-th diagonal and zeros elsewhere.
    //
    // `k` selects the diagonal. Presumably 0 is the main diagonal, positive values refer to
    // diagonals above it and negative values to diagonals below it, as in numpy.eye (confirm).
    //
    // `out` must be a 2-dim array. Unlike IdentityKernel, it is not required to be square.
    // NOTE(review): an earlier comment required a square array; numpy.eye-style callers create
    // an N x M output, so confirm that implementations handle non-square `out`.
    virtual void Call(int64_t k, const Array& out) = 0;
};
37 
class DiagflatKernel : public Kernel {
public:
    // Builds a diagonal array in `out` from the elements of `v`, placed along the k-th diagonal.
    //
    // NOTE(review): presumably numpy.diagflat semantics: `v` is treated as flattened, `out` is a
    // square 2-dim array with side length (number of elements of v) + |k|, and elements off that
    // diagonal are zero. Confirm against the backend implementations and the caller.
    virtual void Call(const Array& v, int64_t k, const Array& out) = 0;
};
42 
class LinspaceKernel : public Kernel {
public:
    // Fills `out` with evenly spaced values between `start` and `stop`.
    //
    // `out.ndim()` must be 1, and `out` must have at least 1 element.
    // NOTE(review): presumably `start` is the first element and `stop` the last one (inclusive);
    // confirm against the backend implementations.
    virtual void Call(double start, double stop, const Array& out) = 0;
};
49 
class TriKernel : public Kernel {
public:
    // Fills the 2-dim array `out` with ones at and below the k-th diagonal.
    //
    // NOTE(review): presumably zeros elsewhere, with numpy.tri's convention for `k`
    // (0 = main diagonal, positive = above it, negative = below it). Confirm.
    //
    // `out` must be a 2-dim array.
    virtual void Call(int64_t k, const Array& out) = 0;
};
56 
57 }  // namespace chainerx
58