/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file init_op.h
 * \brief Function definitions of initialization operators
 */
#ifndef MXNET_OPERATOR_TENSOR_INIT_OP_H_
#define MXNET_OPERATOR_TENSOR_INIT_OP_H_

#include <mxnet/base.h>
#include <mxnet/operator_util.h>
#include <mxnet/op_attr_types.h>
#include <mxnet/imperative.h>
#include <dmlc/parameter.h>
#include <dmlc/optional.h>
#include <vector>
#include <string>
#include <sstream>
#include <unordered_map>
#include <algorithm>
#include <limits>
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
#include "../mxnet_op.h"

namespace mxnet {
namespace op {

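/*! \brief Parameters for init ops that take an explicit shape, context and dtype,
 * such as the zeros/ones family of operators.
 */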
struct InitOpParam : public dmlc::Parameter<InitOpParam> {
  mxnet::TShape shape;
  std::string ctx;
  int dtype;
  DMLC_DECLARE_PARAMETER(InitOpParam) {
    DMLC_DECLARE_FIELD(shape)
        .set_default(mxnet::TShape(0, 1))
        .describe("The shape of the output");
    DMLC_DECLARE_FIELD(ctx)
        .set_default("")
        .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
                  " Only used for imperative calls.");
    DMLC_DECLARE_FIELD(dtype)
        .set_default(mshadow::kFloat32)
        MXNET_ADD_ALL_TYPES_WITH_BOOL
        .describe("Target data type.");
  }
  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
    std::ostringstream shape_s, dtype_s;
    shape_s << shape;
    dtype_s << dtype;
    (*dict)["shape"] = shape_s.str();
    (*dict)["dtype"] = dtype_s.str();
    // ctx is stored in the attribute dict itself rather than in InitOpParam;
    // setting it here as well would result in an error.
  }
};

struct InitOpWithoutDTypeParam : public dmlc::Parameter<InitOpWithoutDTypeParam> {
  mxnet::TShape shape;
  std::string ctx;
  int dtype;
  DMLC_DECLARE_PARAMETER(InitOpWithoutDTypeParam) {
    DMLC_DECLARE_FIELD(shape)
        .set_default(mxnet::TShape())
        .describe("The shape of the output");
    DMLC_DECLARE_FIELD(ctx)
        .set_default("")
        .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
                  " Only used for imperative calls.");
    DMLC_DECLARE_FIELD(dtype)
        .set_default(-1)
        .describe("Target data type.");
  }
};

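/*! \brief Parameters for full_like-style ops: a fill value plus an optional dtype
 * override. When dtype is unset, the output inherits the input's dtype
 * (see FullLikeOpType below).
 */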
struct FullLikeOpParam : public dmlc::Parameter<FullLikeOpParam> {
  double fill_value;
  std::string ctx;
  dmlc::optional<int> dtype;
  DMLC_DECLARE_PARAMETER(FullLikeOpParam) {
    DMLC_DECLARE_FIELD(fill_value)
        .describe("Value with which to fill the newly created tensor");
    DMLC_DECLARE_FIELD(ctx)
        .set_default("")
        .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
                  " Only used for imperative calls.");
    DMLC_DECLARE_FIELD(dtype)
        .set_default(dmlc::optional<int>())
        MXNET_ADD_ALL_TYPES_WITH_BOOL
        .describe("Target data type.");
  }
};

/*! \brief Infer the output type of FullLikeOpCompute */
template<typename ParamType>
inline bool FullLikeOpType(const nnvm::NodeAttrs& attrs,
                           std::vector<int>* in_attrs,
                           std::vector<int>* out_attrs) {
  const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);
  if (param.dtype.has_value()) {
    TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
  } else {
    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
  }
  return out_attrs->at(0) != -1;
}

struct EyeParam : public dmlc::Parameter<EyeParam> {
  nnvm::dim_t N;
  nnvm::dim_t M;
  nnvm::dim_t k;
  std::string ctx;
  int dtype;

  DMLC_DECLARE_PARAMETER(EyeParam) {
    DMLC_DECLARE_FIELD(N)
        .describe("Number of rows in the output.");
    DMLC_DECLARE_FIELD(M)
        .set_default(0)
        .describe("Number of columns in the output. If 0, defaults to N.");
    DMLC_DECLARE_FIELD(k)
        .set_default(0)
        .describe("Index of the diagonal. 0 (the default) refers to the main diagonal,"
                  " a positive value refers to an upper diagonal,"
                  " and a negative value to a lower diagonal.");
    DMLC_DECLARE_FIELD(ctx)
        .set_default("")
        .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
                  " Only used for imperative calls.");
    DMLC_DECLARE_FIELD(dtype)
        .set_default(mshadow::kFloat32)
        .add_enum("float32", mshadow::kFloat32)
        .add_enum("float64", mshadow::kFloat64)
        .add_enum("float16", mshadow::kFloat16)
        .add_enum("uint8", mshadow::kUint8)
        .add_enum("int8", mshadow::kInt8)
        .add_enum("int32", mshadow::kInt32)
        .add_enum("int64", mshadow::kInt64)
        .describe("Target data type.");
  }
};

template<typename ParamType>
inline bool InitEyeShape(const nnvm::NodeAttrs& attrs,
                         mxnet::ShapeVector* in_attrs,
                         mxnet::ShapeVector* out_attrs) {
  const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 0U);
  CHECK_EQ(out_attrs->size(), 1U);
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape2(param.N, param.M > 0 ? param.M : param.N));
  return true;
}

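/*! \brief Kernel writing ones along the k-th diagonal of a zero-initialized dense
 * matrix: thread i sets row-major element (i + init_col - k, i + init_col), where
 * init_col = max(0, k) is the column at which the diagonal enters the matrix.
 */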
template<int req>
struct eye_dns_fill {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out_data,
                                  const nnvm::dim_t init_col,
                                  const nnvm::dim_t k,
                                  const nnvm::dim_t num_cols) {
    KERNEL_ASSIGN(out_data[(i + init_col - k) * num_cols + i + init_col], req,
                  static_cast<DType>(1));
  }
};

struct RangeParam : public dmlc::Parameter<RangeParam> {
  double start;
  dmlc::optional<double> stop;
  double step;
  int repeat;
  bool infer_range;
  std::string ctx;
  int dtype;
  DMLC_DECLARE_PARAMETER(RangeParam) {
    DMLC_DECLARE_FIELD(start)
        .describe("Start of interval. The interval includes this value."
                  " The default start value is 0.");
    DMLC_DECLARE_FIELD(stop)
        .set_default(dmlc::optional<double>())
        .describe("End of interval. The interval does not include this value,"
                  " except in some cases where step is not an integer and"
                  " floating point round-off affects the length of out.");
    DMLC_DECLARE_FIELD(step)
        .set_default(1)
        .describe("Spacing between values.");
    DMLC_DECLARE_FIELD(repeat)
        .set_default(1)
        .describe("The repeating time of all elements."
                  " E.g. repeat=3, the element a will be repeated three times --> a, a, a.");
    DMLC_DECLARE_FIELD(infer_range)
        .set_default(false)
        .describe("When set to True, infer the stop position from the start, step,"
                  " repeat, and output tensor size.");
    DMLC_DECLARE_FIELD(ctx)
        .set_default("")
        .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
                  " Only used for imperative calls.");
    DMLC_DECLARE_FIELD(dtype)
        .set_default(mshadow::kFloat32)
        MXNET_ADD_ALL_TYPES
        .describe("Target data type.");
  }
};

struct RangeLikeParam : public dmlc::Parameter<RangeLikeParam> {
  double start;
  double step;
  int repeat;
  std::string ctx;
  dmlc::optional<int> axis;

  DMLC_DECLARE_PARAMETER(RangeLikeParam) {
    DMLC_DECLARE_FIELD(start)
        .set_default(0)
        .describe("Start of interval. The interval includes this value."
                  " The default start value is 0.");
    DMLC_DECLARE_FIELD(step)
        .set_default(1)
        .describe("Spacing between values.");
    DMLC_DECLARE_FIELD(repeat)
        .set_default(1)
        .describe("The repeating time of all elements."
                  " E.g. repeat=3, the element a will be repeated three times --> a, a, a.");
    DMLC_DECLARE_FIELD(ctx)
        .set_default("")
        .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
                  " Only used for imperative calls.");
    DMLC_DECLARE_FIELD(axis)
        .set_default(dmlc::optional<int>())
        .describe("Arange elements according to the size of a certain axis of the input array."
                  " Negative numbers are interpreted as counting from the back."
                  " If not provided, will arange elements according to the input shape.");
  }
};

/*! \brief Initialize and fill output with an arbitrary value */
struct InitOpWithScalarParam : dmlc::Parameter<InitOpWithScalarParam> {
  mxnet::TShape shape;
  std::string ctx;
  int dtype;
  double value;
  DMLC_DECLARE_PARAMETER(InitOpWithScalarParam) {
    DMLC_DECLARE_FIELD(shape)
        .set_default(mxnet::TShape())
        .describe("The shape of the output");
    DMLC_DECLARE_FIELD(ctx)
        .set_default("")
        .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
                  " Only used for imperative calls.");
    DMLC_DECLARE_FIELD(dtype)
        .set_default(mshadow::kFloat32)
        MXNET_ADD_ALL_TYPES_WITH_BOOL
        .describe("Target data type.");
    DMLC_DECLARE_FIELD(value)
        .describe("Value with which to fill the newly created tensor");
  }
};

/*! \brief Parse keyword arguments into a RangeParam and save it to attrs->parsed */
inline void RangeParamParser(nnvm::NodeAttrs* attrs) {
  RangeParam param;
  param.Init(attrs->dict);
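  // Mirror Python's range(): when only one positional value is given it is the
  // stop value, so shift it over and start from 0.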
  if (!param.infer_range && !param.stop.has_value()) {
    param.stop = param.start;
    param.start = 0;
  }
  attrs->parsed = std::move(param);
}

struct LinspaceParam : public dmlc::Parameter<LinspaceParam> {
  double start;
  double stop;
  int num;
  bool endpoint;
  std::string ctx;
  int dtype;
  DMLC_DECLARE_PARAMETER(LinspaceParam) {
    DMLC_DECLARE_FIELD(start)
        .describe("The starting value of the sequence.");
    DMLC_DECLARE_FIELD(stop)
        .describe("The ending value of the sequence.");
    DMLC_DECLARE_FIELD(num)
        .describe("Number of samples to generate. Must be non-negative.");
    DMLC_DECLARE_FIELD(endpoint)
        .set_default(true)
        .describe("If True, stop is the last sample. Otherwise, it is not included.");
    DMLC_DECLARE_FIELD(ctx)
        .set_default("")
        .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
                  " Only used for imperative calls.");
    DMLC_DECLARE_FIELD(dtype)
        .set_default(mshadow::kFloat32)
        MXNET_ADD_ALL_TYPES
        .describe("Target data type.");
  }
};

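/*! \brief Shape inference for init ops: assigns the shape given in the parameters
 * unless the output shape is already fully known, checking in either case that the
 * tensor stays below 2^31 elements on builds without USE_INT64_TENSOR_SIZE.
 */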
template<typename ParamType>
inline bool InitShape(const nnvm::NodeAttrs& attrs,
                      mxnet::ShapeVector* in_attrs,
                      mxnet::ShapeVector* out_attrs) {
  const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 0U);
  CHECK_EQ(out_attrs->size(), 1U);
  mxnet::TShape param_shape = param.shape;
  if (shape_is_known(param_shape) && !features::is_enabled(features::INT64_TENSOR_SIZE)) {
    CHECK_LT(param_shape.Size(), (int64_t{1} << 31) - 1) <<
      "[InitShape-input] Size of tensor you are trying to allocate is larger than "
      "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
  }
  if (!Imperative::Get()->is_np_shape()) {
    common::ConvertToNumpyShape(&param_shape);
  }
  if (shape_is_known((*out_attrs)[0]) && !shape_is_known(param_shape)) {
    if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
      CHECK_LT(out_attrs->at(0).Size(), (int64_t{1} << 31) - 1) <<
        "[InitShape-output] Size of tensor you are trying to allocate is larger than "
        "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
    }
    return true;
  }
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, param_shape);
  return shape_is_known(out_attrs->at(0));
}

template<typename ParamType, int num_in = 0>
inline bool InitType(const nnvm::NodeAttrs& attrs,
                     std::vector<int>* in_attrs,
                     std::vector<int>* out_attrs) {
  const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), num_in);
  CHECK_EQ(out_attrs->size(), 1U);
  TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype);
  return true;
}

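/*! \brief Storage type inference for init ops: defaults to dense outputs dispatched
 * to FCompute; when the rsp/csr template flags allow it, row-sparse and CSR outputs
 * dispatch to FComputeEx, with dispatch_fallback as the last resort.
 */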
template<typename ParamType, bool rsp, bool csr>
inline bool InitStorageType(const nnvm::NodeAttrs& attrs,
                            const int dev_mask,
                            DispatchMode* dispatch_mode,
                            std::vector<int>* in_attrs,
                            std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 0U);
  CHECK_EQ(out_attrs->size(), 1U);
  auto& out_stype = out_attrs->at(0);
  bool dispatched = false;
  type_assign(&out_stype, kDefaultStorage);
  if (!dispatched && out_stype == kDefaultStorage) {
    // dense output
    dispatched = storage_type_assign(out_attrs, kDefaultStorage,
                                     dispatch_mode, DispatchMode::kFCompute);
  }
  if (!dispatched && rsp && out_stype == kRowSparseStorage) {
    // row-sparse output
    dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
                                     dispatch_mode, DispatchMode::kFComputeEx);
  }
  if (!dispatched && csr && out_stype == kCSRStorage) {
    // CSR output
    dispatched = storage_type_assign(out_attrs, kCSRStorage,
                                     dispatch_mode, DispatchMode::kFComputeEx);
  }
  if (!dispatched) {
    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
  }
  return dispatched;
}

/*!
 * \brief General-purpose blob value-filling function
 * \tparam is_integer Whether to optimize for an integer value
 * \tparam ValueType Data type of the supplied value
 * \tparam xpu cpu or gpu
 * \param s Stream
 * \param b The blob to fill with a value
 * \param req Request type (kNullOp, kWriteTo, etc.)
 * \param val The value to use for the filling operation
 */
template <bool is_integer = false, typename ValueType, typename xpu>
void Fill(mshadow::Stream<xpu>* s, const TBlob& b, const OpReqType req, ValueType val) {
  if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
    CHECK_LT(b.Size(), (int64_t{1} << 31) - 1) <<
      "[Fill] Size of tensor you are trying to allocate is larger than "
      "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
  }
  // If b is a zero-size tensor, do nothing.
  if (b.Size() == 0) return;
  if (req != kNullOp) {
    const size_t size = b.Size();
    if (val == 0) {
      if (req != kAddTo) {
        if (b.dev_mask() == cpu::kDevMask && size < 50000) {
          // Fast path: memset small CPU blobs to zero.
          MSHADOW_TYPE_SWITCH_WITH_BOOL(b.type_flag_, DType, {
            memset(b.dptr_, 0, size * sizeof(DType));
          });
        } else {
          // Kernel-based zero fill for GPU blobs and large CPU blobs.
          MSHADOW_TYPE_SWITCH_WITH_BOOL(b.type_flag_, DType, {
            MXNET_ASSIGN_REQ_SWITCH(req, Req, {
              mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::set_to_int<0>, Req>, xpu>::Launch(
                s, b.Size(), b.dptr<DType>());
            });
          });
        }
      }
    } else if (is_integer && val == 1) {
      // Optimize the common use-case of filling with ones.
      MSHADOW_TYPE_SWITCH_WITH_BOOL(b.type_flag_, DType, {
        MXNET_ASSIGN_REQ_SWITCH(req, Req, {
          mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::set_one, Req>, xpu>::Launch(
            s, b.Size(), b.dptr<DType>());
        });
      });
    } else {
      // Generic fill kernel from a variable value.
      MSHADOW_TYPE_SWITCH_WITH_BOOL(b.type_flag_, DType, {
        MXNET_ASSIGN_REQ_SWITCH(req, Req, {
          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, xpu>::Launch(
            s, b.Size(), b.dptr<DType>(), static_cast<DType>(val));
        });
      });
    }
  }
}
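
// A minimal usage sketch (assuming a CPU OpContext `ctx` and a float TBlob `out`,
// both hypothetical here), showing which branch each call takes:
//
//   mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
//   Fill<false>(s, out, kWriteTo, 0.0f);   // zero fill: memset fast path on small CPU blobs
//   Fill<true>(s, out, kWriteTo, 1);       // integer one: set_one kernel fast path
//   Fill<false>(s, out, kWriteTo, 3.14f);  // anything else: generic fill kernel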

/*! \brief Fill output with a scalar integer value */
template<typename xpu, int value>
void FillCompute(const nnvm::NodeAttrs& attrs,
                 const OpContext& ctx,
                 const std::vector<TBlob>& inputs,
                 const std::vector<OpReqType>& req,
                 const std::vector<TBlob>& outputs) {
  Fill<true>(ctx.get_stream<xpu>(), outputs[0], req[0], value);
}

/*! \brief Fill output with the fill_value from FullLikeOpParam */
template<typename xpu>
void FullLikeOpCompute(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
                       const std::vector<TBlob>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
  const auto& param = nnvm::get<FullLikeOpParam>(attrs.parsed);
  Fill<false>(ctx.get_stream<xpu>(), outputs[0], req[0], param.fill_value);
}

/*! \brief Fill output with an arbitrary value */
template<typename xpu>
void InitFillWithScalarCompute(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx,
                               const std::vector<TBlob>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<TBlob>& outputs) {
  CHECK_EQ(inputs.size(), 0U);
  CHECK_EQ(outputs.size(), 1U);
  const auto& param = nnvm::get<InitOpWithScalarParam>(attrs.parsed);
  Fill<false>(ctx.get_stream<xpu>(), outputs[0], req[0], param.value);
}

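/*! \brief Kernel writing the identity mapping out[i] = i, used below to mark every
 * row of a row-sparse array as present.
 */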
struct PopulateFullIdxRspKernel : public mxnet_op::tunable {
  template<typename IType>
  MSHADOW_XINLINE static void Map(int i, IType* out) {
    KERNEL_ASSIGN(out[i], kWriteTo, i);
  }
};

// Fill in the indices and values of a RowSparse NDArray to represent a zeros NDArray,
// instead of the usual compact representation.
template<typename xpu>
inline void FillDnsZerosRspImpl(mshadow::Stream<xpu>* s, NDArray* dst) {
  using namespace rowsparse;
  using namespace mshadow::expr;
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(dst->storage_type(), kRowSparseStorage);
  MSHADOW_IDX_TYPE_SWITCH(dst->aux_type(kIdx), IType, {
    const index_t num_rows = dst->shape()[0];
    dst->CheckAndAlloc({Shape1(num_rows)});
    Fill<true>(s, dst->data(), kWriteTo, 0);
    auto idx = dst->aux_data(kIdx).FlatTo1D<xpu, IType>(s);
    Kernel<PopulateFullIdxRspKernel, xpu>::Launch(s, num_rows, idx.dptr_);
  });
}

/*!
 * \brief Fill an rsp NDArray with zeros by updating the aux shape.
 * \tparam xpu - cpu or gpu
 * \param s - The device stream
 * \param dst - NDArray which is to be set to "all zeroes"
 */
template<typename xpu>
void FillZerosRspImpl(mshadow::Stream<xpu>*, const NDArray& dst) {
  CHECK_EQ(dst.storage_type(), kRowSparseStorage) << "dst should be an RSP NDArray";
  if (dst.storage_initialized()) {
    // Reset the shapes if it's not zeros (set_aux_shape() will set storage_shape to zero as well).
    dst.set_aux_shape(rowsparse::kIdx, mxnet::TShape(mshadow::Shape1(0)));
  }
}

/*!
 * \brief Fill a CSR NDArray with zeros by updating the aux shape
 * \param s - The device stream
 * \param dst - NDArray which is to be set to "all zeroes"
 */
inline void FillZerosCsrImpl(mshadow::Stream<mshadow::cpu>* s, const NDArray& dst) {
  CHECK_EQ(dst.storage_type(), kCSRStorage) << "dst is not a CSR NDArray";
  dst.set_aux_shape(csr::kIdx, mshadow::Shape1(0));
  dst.CheckAndAllocAuxData(csr::kIndPtr, mshadow::Shape1(dst.shape()[0] + 1));
  TBlob indptr_data = dst.aux_data(csr::kIndPtr);
  Fill<true>(s, indptr_data, kWriteTo, 0);
}
void FillZerosCsrImpl(mshadow::Stream<mshadow::gpu>* s, const NDArray& dst);

/*!
 * \brief Fill an NDArray with zeros
 * \tparam xpu - cpu or gpu
 * \param attrs - node attributes (unused)
 * \param ctx - Device context
 * \param inputs - NDArray inputs (unused)
 * \param req - Request type (i.e. kWriteTo, kNullOp, etc.)
 * \param outputs - Array which contains at position zero (0) the array to be set to zeros
 */
template<typename xpu>
void FillComputeZerosEx(const nnvm::NodeAttrs& attrs,
                        const OpContext& ctx,
                        const std::vector<NDArray>& inputs,
                        const std::vector<OpReqType>& req,
                        const std::vector<NDArray>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  Stream<xpu>* s = ctx.get_stream<xpu>();
  CHECK_EQ(outputs.size(), 1U);
  auto stype = outputs[0].storage_type();
  // x + 0 == x, so kAddTo is also a no-op.
  if (req[0] == kNullOp || req[0] == kAddTo) return;
  if (stype == kRowSparseStorage) {
    FillZerosRspImpl(s, outputs[0]);
  } else if (stype == kCSRStorage) {
    FillZerosCsrImpl(s, outputs[0]);
  } else {
    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
  }
}

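/*! \brief Zero the output, then write ones along diagonal k. The diagonal length is
 * clipped to the matrix: min(num_cols - |k|, N) for k > 0, min(N - |k|, num_cols)
 * otherwise.
 */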
template<typename xpu>
inline void EyeFillImpl(const TBlob& out_data,
                        const OpContext& ctx,
                        const std::vector<OpReqType>& req,
                        const nnvm::dim_t num_cols,
                        const nnvm::dim_t N,
                        const nnvm::dim_t k) {
  using namespace mxnet_op;
  const nnvm::dim_t cnnz = std::max(num_cols - std::abs(k), static_cast<nnvm::dim_t>(0));
  const nnvm::dim_t rnnz = std::max(N - std::abs(k), static_cast<nnvm::dim_t>(0));
  const nnvm::dim_t nnz = k > 0 ? std::min(cnnz, N) : std::min(rnnz, num_cols);
  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
      Fill(s, out_data, req[0], static_cast<DType>(0));
      if (nnz > 0) {
        Kernel<eye_dns_fill<req_type>, xpu>::Launch(s, nnz, out_data.dptr<DType>(),
                                                    std::max(static_cast<nnvm::dim_t>(0), k),
                                                    k, num_cols);
      }
    });
  });
}

template<typename xpu>
void EyeFill(const nnvm::NodeAttrs& attrs,
             const OpContext& ctx,
             const std::vector<TBlob>& inputs,
             const std::vector<OpReqType>& req,
             const std::vector<TBlob>& outputs) {
  CHECK_EQ(inputs.size(), 0U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);
  const EyeParam& param = nnvm::get<EyeParam>(attrs.parsed);
  const TBlob& out_data = outputs[0];
  const nnvm::dim_t num_cols = param.M > 0 ? param.M : param.N;
  EyeFillImpl<xpu>(out_data, ctx, req, num_cols, param.N, param.k);
}

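/*! \brief Kernel for range: out[i] = start + (i / repeat) * step, so each value in
 * the sequence is written `repeat` times consecutively (integer division).
 */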
struct range_fwd {
  template<typename DType>
  MSHADOW_XINLINE static void Map(index_t i, index_t repeat, DType start, DType step,
                                  int req, DType* out) {
    KERNEL_ASSIGN(out[i], req, start + (i / repeat) * step);
  }
};

template<typename xpu, typename ParamType>
void RangeCompute(const nnvm::NodeAttrs& attrs,
                  const OpContext& ctx,
                  const std::vector<TBlob>& inputs,
                  const std::vector<OpReqType>& req,
                  const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  Stream<xpu>* s = ctx.get_stream<xpu>();
  const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    // Force unsigned params to take two's complement form on ARM to ensure consistency with x86
    // results. Casting negative floats to unsigned types is undefined in the C++ standard.
    auto step = std::is_signed<DType>() ? param.step : static_cast<index_t>(param.step);
    auto start = std::is_signed<DType>() ? param.start : static_cast<index_t>(param.start);
    Kernel<range_fwd, xpu>::Launch(s,
                                   outputs[0].Size(),
                                   static_cast<int>(param.repeat),
                                   static_cast<DType>(start),
                                   static_cast<DType>(step),
                                   req[0],
                                   outputs[0].dptr<DType>());
  });
}

inline bool RangeShape(const nnvm::NodeAttrs& attrs,
                       mxnet::ShapeVector* in_attrs,
                       mxnet::ShapeVector* out_attrs) {
  const RangeParam& param = nnvm::get<RangeParam>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 0U);
  CHECK_EQ(out_attrs->size(), 1U);
  CHECK_NE(param.step, 0)
      << "Range does not support step=0, received " << param.step;
  CHECK(param.repeat > 0)
      << "Range only supports repeat > 0, received " << param.repeat;
  if (param.infer_range && !param.stop.has_value()) {
    return false;
  }
  if (param.step > 0) {
    CHECK(param.start < param.stop.value())
        << "Invalid range (start, stop, step) = "
        << "(" << param.start << ", " << param.stop.value() << ", " << param.step << ")";
  } else {
    CHECK(param.start > param.stop.value())
        << "Invalid range (start, stop, step) = "
        << "(" << param.start << ", " << param.stop.value() << ", " << param.step << ")";
  }
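  // Worked example: start=0, stop=10, step=3, repeat=1 gives
  // ceil((10 - 0) / 3) = 4 elements: [0, 3, 6, 9].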
  const double out_size = std::ceil((param.stop.value() - param.start) / param.step)
                          * param.repeat;
  mxnet::TShape output_shape = mxnet::TShape({static_cast<nnvm::dim_t>(out_size)});
  if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
    CHECK_LT(output_shape.Size(), (int64_t{1} << 31) - 1) <<
      "[RangeShape] Size of tensor you are trying to allocate is larger than "
      "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
  }
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, output_shape);
  return true;
}

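/*! \brief Kernel for linspace: out[i] = start + step * i, with step precomputed by
 * the caller from (stop - start) and the number of steps.
 */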
struct linspace_fwd {
  template<typename DType>
  MSHADOW_XINLINE static void Map(index_t i, double start, double stop, double step,
                                  int req, DType* out) {
    KERNEL_ASSIGN(out[i], req, static_cast<DType>(start + step * i));
  }
};

template<typename xpu>
void LinspaceCompute(const nnvm::NodeAttrs& attrs,
                     const OpContext& ctx,
                     const std::vector<TBlob>& inputs,
                     const std::vector<OpReqType>& req,
                     const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  Stream<xpu>* s = ctx.get_stream<xpu>();
  const LinspaceParam& param = nnvm::get<LinspaceParam>(attrs.parsed);
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    int step_num = param.endpoint ? param.num - 1 : param.num;
    double step = step_num > 0 ? (param.stop - param.start) / step_num : 0.0;
    Kernel<linspace_fwd, xpu>::Launch(s,
                                      outputs[0].Size(),
                                      param.start,
                                      param.stop,
                                      step,
                                      req[0],
                                      outputs[0].dptr<DType>());
  });
}

inline bool LinspaceShape(const nnvm::NodeAttrs& attrs,
                          mxnet::ShapeVector* in_attrs,
                          mxnet::ShapeVector* out_attrs) {
  const LinspaceParam& param = nnvm::get<LinspaceParam>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 0U);
  CHECK_EQ(out_attrs->size(), 1U);
  CHECK_GE(param.num, 0)
      << "Number of samples must be non-negative, received " << param.num;
  mxnet::TShape shape = mxnet::TShape({static_cast<nnvm::dim_t>(param.num)});
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, shape);
  return true;
}

inline bool RangeLikeShape(const nnvm::NodeAttrs& attrs,
                           mxnet::ShapeVector* in_attrs,
                           mxnet::ShapeVector* out_attrs) {
  const RangeLikeParam& param = nnvm::get<RangeLikeParam>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);
  int real_axis = -1;
  if (param.axis.has_value()) {
    real_axis = param.axis.value() < 0 ?
        (param.axis.value() + (*in_attrs)[0].ndim()) : param.axis.value();
    CHECK(real_axis >= 0 && real_axis < (*in_attrs)[0].ndim())
        << "cannot handle param.axis " << param.axis.value() << ".";
  }
  if (real_axis == -1) {
    SHAPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
  } else {
    const index_t out_size = (*in_attrs)[0][real_axis];
    SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({static_cast<nnvm::dim_t>(out_size)}));
  }
  return true;
}

}  // namespace op
}  // namespace mxnet

#endif  // MXNET_OPERATOR_TENSOR_INIT_OP_H_