1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19 /*!
20 * \file np_nonzero_op.cc
21 */
#include "np_nonzero_op-inl.h"

#include <limits>
23
24 namespace mxnet {
25 namespace op {
26
NonzeroType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_attrs,std::vector<int> * out_attrs)27 bool NonzeroType(const nnvm::NodeAttrs& attrs,
28 std::vector<int> *in_attrs,
29 std::vector<int> *out_attrs) {
30 CHECK_EQ(in_attrs->size(), 1);
31 CHECK_EQ(out_attrs->size(), 1);
32 // Output must be int64.
33 TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
34 return out_attrs->at(0) != -1;
35 }
36
// Highest input rank accepted by NonzeroForwardCPU; larger ranks are rejected
// with CHECK_LE. NOTE(review): presumably chosen to match the ranks handled by
// MXNET_NDIM_SWITCH — confirm before raising it.
#define MAXDIM 5
38
NonzeroStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_attrs,std::vector<int> * out_attrs)39 bool NonzeroStorageType(const nnvm::NodeAttrs& attrs,
40 const int dev_mask,
41 DispatchMode* dispatch_mode,
42 std::vector<int> *in_attrs,
43 std::vector<int> *out_attrs) {
44 CHECK_EQ(in_attrs->size(), 1);
45 CHECK_EQ(out_attrs->size(), 1);
46 for (int &attr : *in_attrs) {
47 CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported";
48 }
49 for (int &attr : *out_attrs) {
50 attr = kDefaultStorage;
51 }
52 *dispatch_mode = DispatchMode::kFComputeEx;
53 return true;
54 }
55
NonzeroForwardCPU(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)56 void NonzeroForwardCPU(const nnvm::NodeAttrs& attrs,
57 const OpContext &ctx,
58 const std::vector<NDArray> &inputs,
59 const std::vector<OpReqType> &req,
60 const std::vector<NDArray> &outputs) {
61 CHECK_EQ(inputs.size(), 1U);
62 CHECK_EQ(outputs.size(), 1U);
63 const NDArray &in = inputs[0];
64 const NDArray &out = outputs[0];
65 CHECK_LE(in.shape().ndim(), MAXDIM) << "ndim of input cannot larger than " << MAXDIM;
66 // 0-dim
67 if (0 == in.shape().ndim()) {
68 MSHADOW_TYPE_SWITCH_WITH_BOOL(in.dtype(), DType, {
69 DType* in_dptr = in.data().dptr<DType>();
70 if (*in_dptr) {
71 mxnet::TShape s(2, 1);
72 const_cast<NDArray &>(out).Init(s);
73 *(out.data().dptr<int64_t>()) = 0;
74 } else {
75 mxnet::TShape s(2, 1);
76 s[0] = 0;
77 const_cast<NDArray &>(out).Init(s);
78 }
79 });
80 return;
81 }
82 size_t in_size = in.shape().Size();
83 // 0-shape
84 if (0 == in_size) {
85 mxnet::TShape s(2, in.shape().ndim());
86 s[0] = 0;
87 const_cast<NDArray &>(out).Init(s);
88 return;
89 }
90 std::vector<int32_t> prefix_sum(in_size, 0);
91 size_t valid_num = 0;
92 // Calculate prefix sum
93 MSHADOW_TYPE_SWITCH_WITH_BOOL(in.dtype(), DType, {
94 DType* in_dptr = in.data().dptr<DType>();
95 for (size_t i = 0; i < in_size; i++) {
96 prefix_sum[i] = (i == 0) ? 0 : prefix_sum[i - 1];
97 prefix_sum[i] += (in_dptr[i]) ? 1 : 0;
98 }
99 });
100 valid_num = prefix_sum[in_size - 1];
101 // set the output shape forcefully
102 mxnet::TShape s(2, in.shape().ndim());
103 s[0] = valid_num;
104 const_cast<NDArray &>(out).Init(s);
105 // get the shape from the input
106 MXNET_NDIM_SWITCH(in.shape().ndim(), ndim, {
107 mshadow::Shape<ndim> shape = in.shape().get<ndim>();
108 mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
109 mxnet_op::Kernel<NonzeroForwardKernel, cpu>::Launch(
110 stream, in_size, out.data().dptr<int64_t>(), prefix_sum.data(), shape);
111 })
112 }
113
114 NNVM_REGISTER_OP(_npx_nonzero)
115 .add_alias("_npi_nonzero")
116 .set_num_inputs(1)
117 .set_num_outputs(1)
118 .set_attr<nnvm::FListInputNames>("FListInputNames",
__anonb17d171c0102(const NodeAttrs& attrs) 119 [](const NodeAttrs& attrs) {
120 return std::vector<std::string>{"x"};
121 })
122 .set_attr<nnvm::FInferType>("FInferType", NonzeroType)
123 .set_attr<FComputeEx>("FComputeEx<cpu>", NonzeroForwardCPU)
124 .set_attr<FInferStorageType>("FInferStorageType", NonzeroStorageType)
125 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
126 .add_argument("x", "NDArray-or-Symbol", "The input array.");
127
128 } // namespace op
129 } // namespace mxnet
130