1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 /*!
20  * \file np_nonzero_op.cc
21 */
22 #include "np_nonzero_op-inl.h"
23 
24 namespace mxnet {
25 namespace op {
26 
NonzeroType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_attrs,std::vector<int> * out_attrs)27 bool NonzeroType(const nnvm::NodeAttrs& attrs,
28                  std::vector<int> *in_attrs,
29                  std::vector<int> *out_attrs) {
30   CHECK_EQ(in_attrs->size(), 1);
31   CHECK_EQ(out_attrs->size(), 1);
32   // Output must be int64.
33   TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
34   return out_attrs->at(0) != -1;
35 }
36 
// Maximum input rank handled by the nonzero kernel (bounds MXNET_NDIM_SWITCH below).
#define MAXDIM 5
38 
NonzeroStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_attrs,std::vector<int> * out_attrs)39 bool NonzeroStorageType(const nnvm::NodeAttrs& attrs,
40                         const int dev_mask,
41                         DispatchMode* dispatch_mode,
42                         std::vector<int> *in_attrs,
43                         std::vector<int> *out_attrs) {
44   CHECK_EQ(in_attrs->size(), 1);
45   CHECK_EQ(out_attrs->size(), 1);
46   for (int &attr : *in_attrs) {
47     CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported";
48   }
49   for (int &attr : *out_attrs) {
50     attr = kDefaultStorage;
51   }
52   *dispatch_mode = DispatchMode::kFComputeEx;
53   return true;
54 }
55 
NonzeroForwardCPU(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)56 void NonzeroForwardCPU(const nnvm::NodeAttrs& attrs,
57                        const OpContext &ctx,
58                        const std::vector<NDArray> &inputs,
59                        const std::vector<OpReqType> &req,
60                        const std::vector<NDArray> &outputs) {
61   CHECK_EQ(inputs.size(), 1U);
62   CHECK_EQ(outputs.size(), 1U);
63   const NDArray &in = inputs[0];
64   const NDArray &out = outputs[0];
65   CHECK_LE(in.shape().ndim(), MAXDIM) << "ndim of input cannot larger than " << MAXDIM;
66   // 0-dim
67   if (0 == in.shape().ndim()) {
68     MSHADOW_TYPE_SWITCH_WITH_BOOL(in.dtype(), DType, {
69       DType* in_dptr = in.data().dptr<DType>();
70       if (*in_dptr) {
71         mxnet::TShape s(2, 1);
72         const_cast<NDArray &>(out).Init(s);
73         *(out.data().dptr<int64_t>()) = 0;
74       } else {
75         mxnet::TShape s(2, 1);
76         s[0] = 0;
77         const_cast<NDArray &>(out).Init(s);
78       }
79     });
80     return;
81   }
82   size_t in_size = in.shape().Size();
83   // 0-shape
84   if (0 == in_size) {
85     mxnet::TShape s(2, in.shape().ndim());
86     s[0] = 0;
87     const_cast<NDArray &>(out).Init(s);
88     return;
89   }
90   std::vector<int32_t> prefix_sum(in_size, 0);
91   size_t valid_num = 0;
92   // Calculate prefix sum
93   MSHADOW_TYPE_SWITCH_WITH_BOOL(in.dtype(), DType, {
94     DType* in_dptr = in.data().dptr<DType>();
95     for (size_t i = 0; i < in_size; i++) {
96       prefix_sum[i] = (i == 0) ? 0 : prefix_sum[i - 1];
97       prefix_sum[i] += (in_dptr[i]) ? 1 : 0;
98     }
99   });
100   valid_num = prefix_sum[in_size - 1];
101   // set the output shape forcefully
102   mxnet::TShape s(2, in.shape().ndim());
103   s[0] = valid_num;
104   const_cast<NDArray &>(out).Init(s);
105   // get the shape from the input
106   MXNET_NDIM_SWITCH(in.shape().ndim(), ndim, {
107     mshadow::Shape<ndim> shape = in.shape().get<ndim>();
108     mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
109     mxnet_op::Kernel<NonzeroForwardKernel, cpu>::Launch(
110       stream, in_size, out.data().dptr<int64_t>(), prefix_sum.data(), shape);
111   })
112 }
113 
// Registration of np.nonzero for CPU.  One input, one output; type, storage
// and compute attributes are the functions defined above in this file.
NNVM_REGISTER_OP(_npx_nonzero)
.add_alias("_npi_nonzero")
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"x"};
  })
.set_attr<nnvm::FInferType>("FInferType", NonzeroType)
// FComputeEx (not FCompute): the forward pass resizes its output NDArray.
.set_attr<FComputeEx>("FComputeEx<cpu>", NonzeroForwardCPU)
.set_attr<FInferStorageType>("FInferStorageType", NonzeroStorageType)
// nonzero is not differentiable; its gradient is defined as zero.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("x", "NDArray-or-Symbol", "The input array.");
127 
128 }  // namespace op
129 }  // namespace mxnet
130