1 #pragma once 2 3 #include <cstdint> 4 5 #include "chainerx/array.h" 6 #include "chainerx/kernel.h" 7 #include "chainerx/scalar.h" 8 9 namespace chainerx { 10 11 class ArangeKernel : public Kernel { 12 public: 13 virtual void Call(Scalar start, Scalar step, const Array& out) = 0; 14 }; 15 16 class CopyKernel : public Kernel { 17 public: 18 // Copies the elements from one array to the other. 19 // 20 // The arrays must match in shape and dtype and need to reside on this device. 21 virtual void Call(const Array& a, const Array& out) = 0; 22 }; 23 24 class IdentityKernel : public Kernel { 25 public: 26 // Creates the identity array. 27 // out must be a square 2-dim array. 28 virtual void Call(const Array& out) = 0; 29 }; 30 31 class EyeKernel : public Kernel { 32 public: 33 // Creates a 2-dimensional array with ones along the k-th diagonal and zeros elsewhere. 34 // out must be a square 2-dim array. 35 virtual void Call(int64_t k, const Array& out) = 0; 36 }; 37 38 class DiagflatKernel : public Kernel { 39 public: 40 virtual void Call(const Array& v, int64_t k, const Array& out) = 0; 41 }; 42 43 class LinspaceKernel : public Kernel { 44 public: 45 // Creates an evenly spaced 1-d array. 46 // `out.ndim()` must be 1 with at least 1 elements. 47 virtual void Call(double start, double stop, const Array& out) = 0; 48 }; 49 50 class TriKernel : public Kernel { 51 public: 52 // Creates a 2-dimensional array with ones at and below the given diagonal. 53 // out must be a 2-dim array. 54 virtual void Call(int64_t k, const Array& out) = 0; 55 }; 56 57 } // namespace chainerx 58