1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file init_op.h
22  * \brief Function definition of initialization op
23  */
24 #ifndef MXNET_OPERATOR_TENSOR_INIT_OP_H_
25 #define MXNET_OPERATOR_TENSOR_INIT_OP_H_
26 
#include <mxnet/base.h>
#include <mxnet/operator_util.h>
#include <mxnet/op_attr_types.h>
#include <mxnet/imperative.h>
#include <dmlc/parameter.h>
#include <dmlc/optional.h>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <limits>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
#include "../mxnet_op.h"
#include "../mshadow_op.h"
41 
42 
43 namespace mxnet {
44 namespace op {
45 
46 struct InitOpParam : public dmlc::Parameter<InitOpParam> {
47   mxnet::TShape shape;
48   std::string ctx;
49   int dtype;
DMLC_DECLARE_PARAMETERInitOpParam50   DMLC_DECLARE_PARAMETER(InitOpParam) {
51     DMLC_DECLARE_FIELD(shape)
52     .set_default(mxnet::TShape(0, 1))
53     .describe("The shape of the output");
54     DMLC_DECLARE_FIELD(ctx)
55     .set_default("")
56     .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
57               "Only used for imperative calls.");
58     DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
59     MXNET_ADD_ALL_TYPES_WITH_BOOL
60     .describe("Target data type.");
61   }
SetAttrDictInitOpParam62   void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
63     std::ostringstream shape_s, dtype_s;
64     shape_s << shape;
65     dtype_s << dtype;
66     (*dict)["shape"] = shape_s.str();
67     (*dict)["dtype"] = dtype_s.str();
68     // We do not set ctx, because ctx has been set in dict instead of InitOpParam.
69     // Setting ctx here results in an error.
70   }
71 };
72 
73 struct InitOpWithoutDTypeParam : public dmlc::Parameter<InitOpWithoutDTypeParam> {
74   mxnet::TShape shape;
75   std::string ctx;
76   int dtype;
DMLC_DECLARE_PARAMETERInitOpWithoutDTypeParam77   DMLC_DECLARE_PARAMETER(InitOpWithoutDTypeParam) {
78     DMLC_DECLARE_FIELD(shape)
79     .set_default(mxnet::TShape())
80     .describe("The shape of the output");
81     DMLC_DECLARE_FIELD(ctx)
82     .set_default("")
83     .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
84               "Only used for imperative calls.");
85     DMLC_DECLARE_FIELD(dtype)
86     .set_default(-1)
87     .describe("Target data type.");
88   }
89 };
90 
91 struct FullLikeOpParam : public dmlc::Parameter<FullLikeOpParam> {
92   double fill_value;
93   std::string ctx;
94   dmlc::optional<int> dtype;
DMLC_DECLARE_PARAMETERFullLikeOpParam95   DMLC_DECLARE_PARAMETER(FullLikeOpParam) {
96     DMLC_DECLARE_FIELD(fill_value)
97       .describe("Value with which to fill newly created tensor");
98     DMLC_DECLARE_FIELD(ctx)
99       .set_default("")
100       .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
101                 "Only used for imperative calls.");
102     DMLC_DECLARE_FIELD(dtype)
103       .set_default(dmlc::optional<int>())
104       MXNET_ADD_ALL_TYPES_WITH_BOOL
105       .describe("Target data type.");
106   }
107 };
108 
109 /*! \brief Infer type of FullLikeOpCompute*/
110 template<typename ParamType>
FullLikeOpType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_attrs,std::vector<int> * out_attrs)111 inline bool FullLikeOpType(const nnvm::NodeAttrs& attrs,
112                            std::vector<int> *in_attrs,
113                            std::vector<int> *out_attrs) {
114   const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
115   CHECK_EQ(in_attrs->size(), 1U);
116   CHECK_EQ(out_attrs->size(), 1U);
117   if (param.dtype.has_value()) {
118     TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
119   } else {
120     TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
121   }
122   return out_attrs->at(0) != -1;;
123 }
124 
125 struct EyeParam : public dmlc::Parameter<EyeParam> {
126   nnvm::dim_t N;
127   nnvm::dim_t M;
128   nnvm::dim_t k;
129   std::string ctx;
130   int dtype;
131 
DMLC_DECLARE_PARAMETEREyeParam132   DMLC_DECLARE_PARAMETER(EyeParam) {
133     DMLC_DECLARE_FIELD(N)
134     .describe("Number of rows in the output.");
135     DMLC_DECLARE_FIELD(M)
136     .set_default(0)
137     .describe("Number of columns in the output. If 0, defaults to N");
138     DMLC_DECLARE_FIELD(k)
139     .set_default(0)
140     .describe("Index of the diagonal. 0 (the default) refers to the main diagonal."
141               "A positive value refers to an upper diagonal."
142               "A negative value to a lower diagonal.");
143     DMLC_DECLARE_FIELD(ctx)
144     .set_default("")
145     .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
146               "Only used for imperative calls.");
147     DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
148     .add_enum("float32", mshadow::kFloat32)
149     .add_enum("float64", mshadow::kFloat64)
150     .add_enum("float16", mshadow::kFloat16)
151     .add_enum("uint8", mshadow::kUint8)
152     .add_enum("int8", mshadow::kInt8)
153     .add_enum("int32", mshadow::kInt32)
154     .add_enum("int64", mshadow::kInt64)
155     .describe("Target data type.");
156   }
157 };
158 
159 template<typename ParamType>
InitEyeShape(const nnvm::NodeAttrs & attrs,mxnet::ShapeVector * in_attrs,mxnet::ShapeVector * out_attrs)160 inline bool InitEyeShape(const nnvm::NodeAttrs& attrs,
161                          mxnet::ShapeVector *in_attrs,
162                          mxnet::ShapeVector *out_attrs) {
163   const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
164   CHECK_EQ(in_attrs->size(), 0U);
165   CHECK_EQ(out_attrs->size(), 1U);
166   SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape2(param.N, param.M > 0 ? param.M : param.N));
167   return true;
168 }
169 
170 template<int req>
171 struct eye_dns_fill {
172   template<typename DType>
Mapeye_dns_fill173   MSHADOW_XINLINE static void Map(int i, DType* out_data,
174                                   const nnvm::dim_t init_col,
175                                   const nnvm::dim_t k,
176                                   const nnvm::dim_t num_cols) {
177     KERNEL_ASSIGN(out_data[(i+init_col-k)*num_cols+i+init_col], req, static_cast<DType>(1));
178   }
179 };
180 
181 
182 struct RangeParam : public dmlc::Parameter<RangeParam> {
183   double start;
184   dmlc::optional<double> stop;
185   double step;
186   int repeat;
187   bool infer_range;
188   std::string ctx;
189   int dtype;
DMLC_DECLARE_PARAMETERRangeParam190   DMLC_DECLARE_PARAMETER(RangeParam) {
191     DMLC_DECLARE_FIELD(start)
192     .describe("Start of interval. The interval includes this value. The default start value is 0.");
193     DMLC_DECLARE_FIELD(stop)
194     .set_default(dmlc::optional<double>())
195     .describe("End of interval. The interval does not include this value,"
196               " except in some cases where step is not an integer and"
197               " floating point round-off affects the length of out.");
198     DMLC_DECLARE_FIELD(step)
199     .set_default(1)
200     .describe("Spacing between values.");
201     DMLC_DECLARE_FIELD(repeat)
202     .set_default(1)
203     .describe("The repeating time of all elements."
204               " E.g repeat=3, the element a will be repeated three times --> a, a, a.");
205     DMLC_DECLARE_FIELD(infer_range)
206     .set_default(false)
207     .describe("When set to True, infer the stop position from the start, step, "
208               "repeat, and output tensor size.");
209     DMLC_DECLARE_FIELD(ctx)
210     .set_default("")
211     .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
212               "Only used for imperative calls.");
213     DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
214     MXNET_ADD_ALL_TYPES
215     .describe("Target data type.");
216   }
217 };
218 
219 struct RangeLikeParam : public dmlc::Parameter<RangeLikeParam> {
220   double start;
221   double step;
222   int repeat;
223   std::string ctx;
224   dmlc::optional<int> axis;
225 
DMLC_DECLARE_PARAMETERRangeLikeParam226   DMLC_DECLARE_PARAMETER(RangeLikeParam) {
227     DMLC_DECLARE_FIELD(start)
228     .set_default(0)
229     .describe("Start of interval. The interval includes this value. The default start value is 0.");
230     DMLC_DECLARE_FIELD(step)
231     .set_default(1)
232     .describe("Spacing between values.");
233     DMLC_DECLARE_FIELD(repeat)
234     .set_default(1)
235     .describe("The repeating time of all elements."
236               " E.g repeat=3, the element a will be repeated three times --> a, a, a.");
237     DMLC_DECLARE_FIELD(ctx)
238     .set_default("")
239     .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
240               "Only used for imperative calls.");
241     DMLC_DECLARE_FIELD(axis)
242     .set_default(dmlc::optional<int>())
243     .describe("Arange elements according to the size of a certain axis of input array."
244               " The negative numbers are interpreted counting from the backward."
245               " If not provided, will arange elements according to the input shape.");
246   }
247 };
248 
249 /*! \brief Initialize and fill output with an arbitrary value */
250 struct InitOpWithScalarParam : dmlc::Parameter<InitOpWithScalarParam> {
251   mxnet::TShape shape;
252   std::string ctx;
253   int dtype;
254   double value;
DMLC_DECLARE_PARAMETERInitOpWithScalarParam255   DMLC_DECLARE_PARAMETER(InitOpWithScalarParam) {
256     DMLC_DECLARE_FIELD(shape)
257       .set_default(mxnet::TShape())
258       .describe("The shape of the output");
259     DMLC_DECLARE_FIELD(ctx)
260       .set_default("")
261       .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
262                   "Only used for imperative calls.");
263     DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
264       MXNET_ADD_ALL_TYPES_WITH_BOOL
265       .describe("Target data type.");
266     DMLC_DECLARE_FIELD(value)
267       .describe("Value with which to fill newly created tensor");
268   }
269 };
270 
271 /*! \brief Parse keyword arguments as PType arguments and save to parsed */
RangeParamParser(nnvm::NodeAttrs * attrs)272 inline void RangeParamParser(nnvm::NodeAttrs* attrs) {
273   RangeParam param;
274   param.Init(attrs->dict);
275   if (!static_cast<bool>(param.infer_range) && !static_cast<bool>(param.stop)) {
276     param.stop = param.start;
277     param.start = 0;
278   }
279   attrs->parsed = std::move(param);
280 }
281 
282 struct LinspaceParam : public dmlc::Parameter<LinspaceParam> {
283   double start;
284   double stop;
285   int num;
286   bool endpoint;
287   std::string ctx;
288   int dtype;
DMLC_DECLARE_PARAMETERLinspaceParam289   DMLC_DECLARE_PARAMETER(LinspaceParam) {
290     DMLC_DECLARE_FIELD(start)
291     .describe("The starting value of the sequence.");
292     DMLC_DECLARE_FIELD(stop)
293     .describe("The ending value of the sequence");
294     DMLC_DECLARE_FIELD(num)
295     .describe("Number of samples to generate. Must be non-negative.");
296     DMLC_DECLARE_FIELD(endpoint)
297     .set_default(true)
298     .describe("If True, stop is the last sample. Otherwise, it is not included.");
299     DMLC_DECLARE_FIELD(ctx)
300     .set_default("")
301     .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
302               "Only used for imperative calls.");
303     DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
304     MXNET_ADD_ALL_TYPES
305     .describe("Target data type.");
306   }
307 };
308 
309 template<typename ParamType>
InitShape(const nnvm::NodeAttrs & attrs,mxnet::ShapeVector * in_attrs,mxnet::ShapeVector * out_attrs)310 inline bool InitShape(const nnvm::NodeAttrs& attrs,
311                       mxnet::ShapeVector *in_attrs,
312                       mxnet::ShapeVector *out_attrs) {
313   const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
314   CHECK_EQ(in_attrs->size(), 0U);
315   CHECK_EQ(out_attrs->size(), 1U);
316   mxnet::TShape param_shape = param.shape;
317   if (shape_is_known(param_shape) && !features::is_enabled(features::INT64_TENSOR_SIZE)) {
318     CHECK_LT(param_shape.Size(), (int64_t{1} << 31) - 1) <<
319               "[InitShape-input] Size of tensor you are trying to allocate is larger than "
320               "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
321   }
322   if (!Imperative::Get()->is_np_shape()) {
323     common::ConvertToNumpyShape(&param_shape);
324   }
325   if (shape_is_known((*out_attrs)[0]) && !shape_is_known(param_shape)) {
326     if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
327       CHECK_LT(out_attrs->at(0).Size() , (int64_t{1} << 31) - 1) <<
328                 "[InitShape-output] Size of tensor you are trying to allocate is larger than "
329                 "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
330     }
331     return true;
332   }
333   SHAPE_ASSIGN_CHECK(*out_attrs, 0, param_shape);
334   return shape_is_known(out_attrs->at(0));
335 }
336 
337 template<typename ParamType, int num_in = 0U>
InitType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_attrs,std::vector<int> * out_attrs)338 inline bool InitType(const nnvm::NodeAttrs& attrs,
339                        std::vector<int> *in_attrs,
340                        std::vector<int> *out_attrs) {
341   const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
342   CHECK_EQ(in_attrs->size(), num_in);
343   CHECK_EQ(out_attrs->size(), 1U);
344   TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype);
345   return true;
346 }
347 
348 template<typename ParamType, bool rsp, bool csr>
InitStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_attrs,std::vector<int> * out_attrs)349 inline bool InitStorageType(const nnvm::NodeAttrs& attrs,
350                             const int dev_mask,
351                             DispatchMode* dispatch_mode,
352                             std::vector<int> *in_attrs,
353                             std::vector<int> *out_attrs) {
354   CHECK_EQ(in_attrs->size(), 0U);
355   CHECK_EQ(out_attrs->size(), 1U);
356   auto &out_stype = out_attrs->at(0);
357   bool dispatched = false;
358   type_assign(&out_stype, kDefaultStorage);
359   if (!dispatched && out_stype == kDefaultStorage) {
360     // default
361     dispatched = storage_type_assign(out_attrs, kDefaultStorage,
362                                      dispatch_mode, DispatchMode::kFCompute);
363   }
364   if (!dispatched && rsp && out_stype == kRowSparseStorage) {
365     // rsp
366     dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
367                                      dispatch_mode, DispatchMode::kFComputeEx);
368   }
369   if (!dispatched && csr && out_stype == kCSRStorage) {
370     // csr
371     dispatched = storage_type_assign(out_attrs, kCSRStorage,
372                                      dispatch_mode, DispatchMode::kFComputeEx);
373   }
374   if (!dispatched) {
375     dispatched = dispatch_fallback(out_attrs, dispatch_mode);
376   }
377   return dispatched;
378 }
379 
380 /*!
381  * \brief General-purpose blob value-filling function
382  * \tparam xpu cpu or gpu
383  * \tparam ValueType Data type of supplied value
384  * \tparam is_integer Whether to optimize for an integer value
385  * \param s Stream
386  * \param b The blob to fill with a value
387  * \param req Request type (kNullOp, kWriteTo, etc)
388  * \param val The value to use for the filling operation
389  */
390 template <bool is_integer = false, typename ValueType, typename xpu>
Fill(mshadow::Stream<xpu> * s,const TBlob & b,const OpReqType req,ValueType val)391 void Fill(mshadow::Stream<xpu> *s, const TBlob& b, const OpReqType req, ValueType val) {
392   // If b is a zero-size tensor, do nothing.
393   if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
394     CHECK_LT(b.Size(), (int64_t{1} << 31) - 1) <<
395               "[Fill] Size of tensor you are trying to allocate is larger than "
396               "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
397   }
398   if (b.Size() == 0) return;
399   if (req != kNullOp) {
400     const size_t size = b.Size();
401     if (val == 0) {
402       if (req != kAddTo) {
403         if (b.dev_mask() == cpu::kDevMask && size < 50000) {
404           MSHADOW_TYPE_SWITCH_WITH_BOOL(b.type_flag_, DType, {
405             memset(b.dptr_, 0, size * sizeof(DType));
406           });
407         } else {
408           // Optimize common use-case of filling with ones
409           MSHADOW_TYPE_SWITCH_WITH_BOOL(b.type_flag_, DType, {
410             MXNET_ASSIGN_REQ_SWITCH(req, Req, {
411               mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::set_to_int<0>, Req>, xpu>::Launch(
412                 s, b.Size(), b.dptr<DType>());
413             });
414           });
415         }
416       }
417     } else if (is_integer && val == 1) {
418       // Optimize common use-case of filling with ones
419       MSHADOW_TYPE_SWITCH_WITH_BOOL(b.type_flag_, DType, {
420         MXNET_ASSIGN_REQ_SWITCH(req, Req, {
421           mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::set_one, Req>, xpu>::Launch(
422             s, b.Size(), b.dptr<DType>());
423         });
424       });
425     } else {
426       // Generic fill kernel from variable
427       MSHADOW_TYPE_SWITCH_WITH_BOOL(b.type_flag_, DType, {
428         MXNET_ASSIGN_REQ_SWITCH(req, Req, {
429           mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, xpu>::Launch(
430             s, b.Size(), b.dptr<DType>(), static_cast<DType>(val));
431         });
432       });
433     }
434   }
435 }
436 
437 /*! \brief Fill output with a scalar integer value */
438 template<typename xpu, int value>
FillCompute(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)439 void FillCompute(const nnvm::NodeAttrs& attrs,
440                  const OpContext& ctx,
441                  const std::vector<TBlob>& inputs,
442                  const std::vector<OpReqType>& req,
443                  const std::vector<TBlob>& outputs) {
444   Fill<true>(ctx.get_stream<xpu>(), outputs[0], req[0], value);
445 }
446 
447 /*! \brief Fill output with a scalar integer value */
448 template<typename xpu>
FullLikeOpCompute(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)449 void FullLikeOpCompute(const nnvm::NodeAttrs& attrs,
450                        const OpContext& ctx,
451                        const std::vector<TBlob>& inputs,
452                        const std::vector<OpReqType>& req,
453                        const std::vector<TBlob>& outputs) {
454   CHECK_EQ(inputs.size(), 1U);
455   CHECK_EQ(outputs.size(), 1U);
456   const auto& param = nnvm::get<FullLikeOpParam>(attrs.parsed);
457   Fill<false>(ctx.get_stream<xpu>(), outputs[0], req[0], param.fill_value);
458 }
459 
460 /*! \brief Fill output with an arbitrary value */
461 template<typename xpu>
InitFillWithScalarCompute(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)462 void InitFillWithScalarCompute(const nnvm::NodeAttrs &attrs,
463                                const OpContext &ctx,
464                                const std::vector<TBlob> &inputs,
465                                const std::vector<OpReqType> &req,
466                                const std::vector<TBlob> &outputs) {
467   CHECK_EQ(inputs.size(), 0);
468   CHECK_EQ(outputs.size(), 1U);
469   const auto& param = nnvm::get<InitOpWithScalarParam>(attrs.parsed);
470   Fill<false>(ctx.get_stream<xpu>(), outputs[0], req[0], param.value);
471 }
472 
473 struct PopulateFullIdxRspKernel : public mxnet_op::tunable {
474   template<typename IType>
MapPopulateFullIdxRspKernel475   MSHADOW_XINLINE static void Map(int i, IType* out) {
476     KERNEL_ASSIGN(out[i], kWriteTo, i);
477   }
478 };
479 
480 // Fill in the indices and values of a RowSparse NDArray to represent a zeros NDArray,
481 // instead of the usual compact representation.
482 template<typename xpu>
FillDnsZerosRspImpl(mshadow::Stream<xpu> * s,NDArray * dst)483 inline void FillDnsZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
484   using namespace rowsparse;
485   using namespace mshadow::expr;
486   using namespace mshadow;
487   using namespace mxnet_op;
488   CHECK_EQ(dst->storage_type(), kRowSparseStorage);
489   MSHADOW_IDX_TYPE_SWITCH(dst->aux_type(kIdx), IType, {
490     const index_t num_rows = dst->shape()[0];
491     dst->CheckAndAlloc({Shape1(num_rows)});
492     Fill<true>(s, dst->data(), kWriteTo, 0);
493     auto idx = dst->aux_data(kIdx).FlatTo1D<xpu, IType>(s);
494     Kernel<PopulateFullIdxRspKernel, xpu>::Launch(s, num_rows, idx.dptr_);
495   });
496 }
497 
498 /*!
499  * \brief Fill a rsp NDArray with zeros by updating the aux shape.
500  * \tparam xpu - cpu or gpu
501  * \param s - The device stream
502  * \param dst - NDArray which is to be set to "all zeroes"
503  */
504 template<typename xpu>
FillZerosRspImpl(mshadow::Stream<xpu> *,const NDArray & dst)505 void FillZerosRspImpl(mshadow::Stream<xpu> *, const NDArray& dst) {
506   CHECK_EQ(dst.storage_type(), kRowSparseStorage) << "dst should be an RSP NDArray";
507   if (dst.storage_initialized()) {
508     // reset the shapes if it's not zeros (set_aux_shape() will set storage_shape to zero as well)
509     dst.set_aux_shape(rowsparse::kIdx, mxnet::TShape(mshadow::Shape1(0)));
510   }
511 }
512 
513 /*!
514  * \brief Fill a CSR NDArray with zeros by updating the aux shape
515  * \param s - The device stream
516  * \param dst - NDArray which is to be set to "all zeroes"
517  */
FillZerosCsrImpl(mshadow::Stream<mshadow::cpu> * s,const NDArray & dst)518 inline void FillZerosCsrImpl(mshadow::Stream<mshadow::cpu> *s, const NDArray& dst) {
519   CHECK_EQ(dst.storage_type(), kCSRStorage) << "dst is not a CSR NDArray";
520   dst.set_aux_shape(csr::kIdx, mshadow::Shape1(0));
521   dst.CheckAndAllocAuxData(csr::kIndPtr, mshadow::Shape1(dst.shape()[0] + 1));
522   TBlob indptr_data = dst.aux_data(csr::kIndPtr);
523   Fill<true>(s, dst.aux_data(csr::kIndPtr), kWriteTo, 0);
524 }
525 void FillZerosCsrImpl(mshadow::Stream<mshadow::gpu> *s, const NDArray& dst);
526 
527 /*!
528  * \brief Fill an NDArray with zeros
529  * \tparam xpu - cpu or gpu
530  * \param attrs  - node attributes (unused)
531  * \param ctx - Device context
532  * \param inputs - NDArray inputs (unused)
533  * \param req - Request type (i.e. kWrite, kNullOp, etc.)
534  * \param outputs - Array which contains at position zero (0) the array to be set to zeros
535  */
536 template<typename xpu>
FillComputeZerosEx(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)537 void FillComputeZerosEx(const nnvm::NodeAttrs& attrs,
538                         const OpContext& ctx,
539                         const std::vector<NDArray>& inputs,
540                         const std::vector<OpReqType>& req,
541                         const std::vector<NDArray>& outputs) {
542   using namespace mshadow;
543   using namespace mshadow::expr;
544   Stream<xpu> *s = ctx.get_stream<xpu>();
545   CHECK_EQ(outputs.size(), 1);
546   auto stype = outputs[0].storage_type();
547   // x + 0 == x
548   if (req[0] == kNullOp || req[0] == kAddTo) return;
549   if (stype == kRowSparseStorage) {
550     FillZerosRspImpl(s, outputs[0]);
551   } else if (stype == kCSRStorage) {
552     FillZerosCsrImpl(s, outputs[0]);
553   } else {
554     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
555   }
556 }
557 
558 template<typename xpu>
EyeFillImpl(const TBlob & out_data,const OpContext & ctx,const std::vector<OpReqType> & req,const nnvm::dim_t num_cols,const nnvm::dim_t N,const nnvm::dim_t k)559 inline void EyeFillImpl(const TBlob& out_data,
560                         const OpContext& ctx,
561                         const std::vector<OpReqType>& req,
562                         const nnvm::dim_t num_cols,
563                         const nnvm::dim_t N,
564                         const nnvm::dim_t k) {
565   using namespace mxnet_op;
566   const nnvm::dim_t cnnz = std::max(num_cols - std::abs(k), (nnvm::dim_t)0);
567   const nnvm::dim_t rnnz = std::max(N - std::abs(k), (nnvm::dim_t)0);
568   const nnvm::dim_t nnz = k > 0 ? std::min(cnnz, N) :
569                           std::min(rnnz, num_cols);
570   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
571   MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
572       MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
573         Fill(s, out_data, req[0], static_cast<DType>(0));
574         if (nnz > 0) {
575           Kernel<eye_dns_fill<req_type>, xpu>::Launch(s, nnz, out_data.dptr<DType>(),
576             std::max(static_cast<nnvm::dim_t>(0), k), k, num_cols);
577         }
578       });
579   });
580 }
581 
582 template<typename xpu>
EyeFill(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)583 void EyeFill(const nnvm::NodeAttrs& attrs,
584              const OpContext& ctx,
585              const std::vector<TBlob>& inputs,
586              const std::vector<OpReqType>& req,
587              const std::vector<TBlob>& outputs) {
588   CHECK_EQ(inputs.size(), 0U);
589   CHECK_EQ(outputs.size(), 1U);
590   CHECK_EQ(req.size(), 1U);
591   const EyeParam& param = nnvm::get<EyeParam>(attrs.parsed);
592   const TBlob& out_data = outputs[0];
593   const nnvm::dim_t num_cols = param.M > 0 ? param.M : param.N;
594   EyeFillImpl<xpu>(out_data, ctx, req, num_cols, param.N, param.k);
595 }
596 
597 
598 struct range_fwd {
599   template<typename DType>
Maprange_fwd600   MSHADOW_XINLINE static void Map(index_t i, index_t repeat, DType start, DType step,
601                                   int req, DType* out) {
602     KERNEL_ASSIGN(out[i], req, start + (i/repeat) * step);
603   }
604 };
605 
606 template<typename xpu, typename ParamType>
RangeCompute(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)607 void RangeCompute(const nnvm::NodeAttrs& attrs,
608                   const OpContext& ctx,
609                   const std::vector<TBlob>& inputs,
610                   const std::vector<OpReqType>& req,
611                   const std::vector<TBlob>& outputs) {
612   using namespace mxnet_op;
613   Stream<xpu> *s = ctx.get_stream<xpu>();
614   const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
615   MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
616       // Force unsigned params to take two's complement form on ARM to ensure consistency with x86
617       // results.  Casting negative floats to unsigned types is undefined in the CPP standard.
618       auto step = std::is_signed<DType>() ? param.step : static_cast<index_t>(param.step);
619       auto start = std::is_signed<DType>() ? param.start : static_cast<index_t>(param.start);
620       Kernel<range_fwd, xpu>::Launch(s,
621                                      outputs[0].Size(),
622                                      static_cast<int>(param.repeat),
623                                      static_cast<DType>(start),
624                                      static_cast<DType>(step),
625                                      req[0],
626                                      outputs[0].dptr<DType>());
627   });
628 }
629 
630 
RangeShape(const nnvm::NodeAttrs & attrs,mxnet::ShapeVector * in_attrs,mxnet::ShapeVector * out_attrs)631 inline bool RangeShape(const nnvm::NodeAttrs& attrs,
632                        mxnet::ShapeVector *in_attrs,
633                        mxnet::ShapeVector *out_attrs) {
634   const RangeParam& param = nnvm::get<RangeParam>(attrs.parsed);
635   CHECK_EQ(in_attrs->size(), 0U);
636   CHECK_EQ(out_attrs->size(), 1U);
637   CHECK_NE(param.step, 0)
638     << "Range does not support step=0, received " << param.step;
639   CHECK(param.repeat > 0)
640     << "Range only supports repeat > 0, received " << param.repeat;
641   if (param.infer_range && !param.stop.has_value()) {
642     return false;
643   }
644   if (param.step > 0) {
645     CHECK(param.start < param.stop.value())
646       << "Invalid range (start, stop, step) = "
647       << "(" << param.start << "," << param.stop.value() << "," << param.step << ")";
648   } else {
649     CHECK(param.start > param.stop.value())
650       << "Invalid range (start, stop, step)= "
651       << "(" << param.start << "," << param.stop.value() << "," << param.step << ")";
652   }
653   const double out_size = std::ceil((param.stop.value() - param.start) / param.step)
654                           * param.repeat;
655   mxnet::TShape output_shape = mxnet::TShape({static_cast<nnvm::dim_t>(out_size)});
656   if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
657     CHECK_LT(output_shape.Size(), (int64_t{1} << 31) - 1) <<
658               "[RangeShape] Size of tensor you are trying to allocate is larger than "
659               "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
660   }
661   SHAPE_ASSIGN_CHECK(*out_attrs, 0, output_shape);
662   return true;
663 }
664 
665 struct linspace_fwd {
666   template<typename DType>
Maplinspace_fwd667   MSHADOW_XINLINE static void Map(index_t i, double start, double stop, double step,
668                                   int req, DType* out) {
669     KERNEL_ASSIGN(out[i], req, static_cast<DType>(start + step * i));
670   }
671 };
672 
673 template<typename xpu>
LinspaceCompute(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)674 void LinspaceCompute(const nnvm::NodeAttrs& attrs,
675                      const OpContext& ctx,
676                      const std::vector<TBlob>& inputs,
677                      const std::vector<OpReqType>& req,
678                      const std::vector<TBlob>& outputs) {
679   using namespace mxnet_op;
680   Stream<xpu> *s = ctx.get_stream<xpu>();
681   const LinspaceParam& param = nnvm::get<LinspaceParam>(attrs.parsed);
682   MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
683       int step_num = param.endpoint ? param.num - 1 : param.num;
684       double step = step_num > 0 ? (param.stop - param.start) / step_num : 0.0f;
685       Kernel<linspace_fwd, xpu>::Launch(s,
686                                         outputs[0].Size(),
687                                         param.start,
688                                         param.stop,
689                                         step,
690                                         req[0],
691                                         outputs[0].dptr<DType>());
692   });
693 }
694 
LinspaceShape(const nnvm::NodeAttrs & attrs,mxnet::ShapeVector * in_attrs,mxnet::ShapeVector * out_attrs)695 inline bool LinspaceShape(const nnvm::NodeAttrs& attrs,
696                        mxnet::ShapeVector *in_attrs,
697                        mxnet::ShapeVector *out_attrs) {
698   const LinspaceParam& param = nnvm::get<LinspaceParam>(attrs.parsed);
699   CHECK_EQ(in_attrs->size(), 0U);
700   CHECK_EQ(out_attrs->size(), 1U);
701   CHECK_GE(param.num, 0)
702     << "Number of sequence should be non-negative, received " << param.num;
703   mxnet::TShape shape = mxnet::TShape({static_cast<nnvm::dim_t>(param.num)});
704   SHAPE_ASSIGN_CHECK(*out_attrs, 0, shape);
705   return true;
706 }
707 
RangeLikeShape(const nnvm::NodeAttrs & attrs,mxnet::ShapeVector * in_attrs,mxnet::ShapeVector * out_attrs)708 inline bool RangeLikeShape(const nnvm::NodeAttrs& attrs,
709                            mxnet::ShapeVector *in_attrs,
710                            mxnet::ShapeVector *out_attrs) {
711   const RangeLikeParam& param = nnvm::get<RangeLikeParam>(attrs.parsed);
712   CHECK_EQ(in_attrs->size(), 1U);
713   CHECK_EQ(out_attrs->size(), 1U);
714   int real_axis = -1;
715   if (param.axis.has_value()) {
716     real_axis = param.axis.value() < 0 ?
717         (param.axis.value() + (*in_attrs)[0].ndim()) : param.axis.value();
718     CHECK(real_axis >=0 && real_axis < (*in_attrs)[0].ndim())
719         << "cannot handle param.axis " << param.axis.value() << ".";
720   }
721   if (real_axis == -1) {
722     SHAPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
723   } else {
724     const index_t out_size = (*in_attrs)[0][real_axis];
725     SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({static_cast<nnvm::dim_t>(out_size)}));
726   }
727   return true;
728 }
729 
730 }  // namespace op
731 }  // namespace mxnet
732 
733 #endif  // MXNET_OPERATOR_TENSOR_INIT_OP_H_
734