1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 /*!
20  * \file boolean_mask.cc
21 */
22 
23 #include "./boolean_mask-inl.h"
24 
25 namespace mxnet {
26 namespace op {
27 
// Register BooleanMaskParam so operator attributes (e.g. `axis`) can be
// parsed from NodeAttrs by ParamParser below.
DMLC_REGISTER_PARAMETER(BooleanMaskParam);
29 
BooleanMaskType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_attrs,std::vector<int> * out_attrs)30 bool BooleanMaskType(const nnvm::NodeAttrs& attrs,
31                      std::vector<int> *in_attrs,
32                      std::vector<int> *out_attrs) {
33   CHECK_EQ(in_attrs->size(), 2);
34   CHECK_EQ(out_attrs->size(), 1);
35   TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
36   TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
37   return in_attrs->at(0) != -1 && in_attrs->at(1) != -1 && out_attrs->at(0) != -1;
38 }
39 
BooleanMaskStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_attrs,std::vector<int> * out_attrs)40 bool BooleanMaskStorageType(const nnvm::NodeAttrs& attrs,
41                             const int dev_mask,
42                             DispatchMode* dispatch_mode,
43                             std::vector<int> *in_attrs,
44                             std::vector<int> *out_attrs) {
45   CHECK_EQ(in_attrs->size(), 2);
46   CHECK_EQ(out_attrs->size(), 1);
47   for (int &attr : *in_attrs) {
48     CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported";
49   }
50   for (int &attr : *out_attrs) {
51     attr = kDefaultStorage;
52   }
53   *dispatch_mode = DispatchMode::kFComputeEx;
54   return true;
55 }
56 
BooleanMaskBackStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_attrs,std::vector<int> * out_attrs)57 bool BooleanMaskBackStorageType(const nnvm::NodeAttrs& attrs,
58                                 const int dev_mask,
59                                 DispatchMode* dispatch_mode,
60                                 std::vector<int> *in_attrs,
61                                 std::vector<int> *out_attrs) {
62   CHECK_EQ(in_attrs->size(), 3);
63   CHECK_EQ(out_attrs->size(), 2);
64   for (int &attr : *in_attrs) {
65     CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported";
66   }
67   for (int &attr : *out_attrs) {
68     attr = kDefaultStorage;
69   }
70   for (size_t i = 0; i < out_attrs->size(); i++)
71     out_attrs->at(i) = kDefaultStorage;
72   *dispatch_mode = DispatchMode::kFComputeEx;
73   return true;
74 }
75 
76 struct BooleanMaskBackwardCPUWriteKernel {
77   template<typename DType>
Mapmxnet::op::BooleanMaskBackwardCPUWriteKernel78   static void Map(int i,
79                   DType* igrad,
80                   const OpReqType /*req*/,
81                   const DType* ograd,
82                   const int32_t* idx,
83                   const size_t col_size) {
84     // i is row id already
85     int32_t prev = (i == 0) ? 0 : idx[i - 1];
86     int32_t curr = idx[i];
87     if (prev != curr) {
88       std::memcpy(igrad + i * col_size, ograd + prev * col_size, col_size * sizeof(DType));
89     } else {
90       std::memset(igrad + i * col_size, 0, col_size * sizeof(DType));
91     }
92   }
93 };
94 
95 template<>
BooleanMaskForward(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)96 inline void BooleanMaskForward<cpu>(const nnvm::NodeAttrs& attrs,
97                                     const OpContext &ctx,
98                                     const std::vector<NDArray> &inputs,
99                                     const std::vector<OpReqType> &req,
100                                     const std::vector<NDArray> &outputs) {
101   CHECK_EQ(inputs.size(), 2U);
102   CHECK_EQ(outputs.size(), 1U);
103   CHECK(req[0] == kWriteTo || req[0] == kWriteInplace);
104   const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
105   const int axis = param.axis;
106   const NDArray &data = inputs[0];
107   const NDArray &idx = inputs[1];
108   const NDArray &out = outputs[0];
109   CHECK_EQ(axis, 0) << "Not supported yet";
110   CHECK_EQ(data.shape()[axis], idx.shape()[0]);
111   CHECK_EQ(idx.shape().ndim(), 1U);  // idx is required to be 1-d.
112   // count the number of 1s in `idx`, so that we could know the output dimension
113   size_t idx_size = idx.shape()[0];
114   std::vector<int32_t> prefix_sum(idx_size, 0);
115   size_t valid_num = 0;
116   // Calculate prefix sum
117   MSHADOW_TYPE_SWITCH_WITH_BOOL(idx.dtype(), DType, {
118     DType* idx_dptr = idx.data().dptr<DType>();
119     for (size_t i = 0; i < idx_size; i++) {
120       prefix_sum[i] = (i == 0) ? 0 : prefix_sum[i - 1];
121       prefix_sum[i] += (idx_dptr[i]) ? 1 : 0;
122     }
123     valid_num = prefix_sum[idx_size - 1];
124   });
125   // set the output shape forcefully
126   mxnet::TShape s = data.shape();
127   s[axis] = valid_num;
128 
129   const_cast<NDArray &>(out).Init(s);
130   // do the copy
131   MSHADOW_TYPE_SWITCH_WITH_BOOL(data.dtype(), DType, {
132     size_t input_size = data.shape().Size();
133     size_t col_size = input_size / idx_size;
134     mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
135     mxnet_op::Kernel<BooleanMaskForwardCPUKernel, cpu>::Launch(
136       stream, idx_size, out.data().dptr<DType>(), data.data().dptr<DType>(),
137       prefix_sum.data(), col_size);
138   });
139 }
140 
// CPU backward pass of boolean_mask.
// Rebuilds the inclusive prefix sum of the mask on the host, then scatters
// rows of ograd back to the positions of grad_data that were selected in the
// forward pass. kAddTo accumulates via BooleanMaskBackwardKernel; any other
// (non-null) req overwrites, zero-filling unselected rows.
template<>
inline void BooleanMaskBackward<cpu>(const nnvm::NodeAttrs& attrs,
                                     const OpContext &ctx,
                                     const std::vector<NDArray> &inputs,
                                     const std::vector<OpReqType> &req,
                                     const std::vector<NDArray> &outputs) {
  CHECK_EQ(inputs.size(), 3U);
  CHECK_EQ(outputs.size(), 2U);
  // No gradient requested for data; the index gradient is never produced.
  if (req[0] == kNullOp) return;
  // inputs: {ograd, data, idx}
  // outputs: {igrad_data, igrad_idx}
  const NDArray& ograd = inputs[0];
  const NDArray& idx = inputs[2];
  const NDArray& igrad_data = outputs[0];
  MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, {
    MSHADOW_TYPE_SWITCH_WITH_BOOL(idx.dtype(), IType, {
      size_t input_size = igrad_data.shape().Size();
      size_t idx_size = idx.shape()[0];
      // NOTE(review): if idx_size == 0 this divides by zero — presumably
      // callers never reach here with an empty index; confirm upstream.
      size_t col_size = input_size / idx_size;
      // Inclusive prefix sum: prefix_sum[i] = #non-zero in idx[0..i].
      std::vector<int32_t> prefix_sum(idx_size, 0);
      IType* idx_dptr = idx.data().dptr<IType>();
      for (size_t i = 0; i < idx_size; i++) {
        prefix_sum[i] = (i == 0) ? 0 : prefix_sum[i - 1];
        prefix_sum[i] += (idx_dptr[i]) ? 1 : 0;
      }
      mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
      if (req[0] == kAddTo) {
        // Accumulating kernel: leaves unselected rows untouched.
        mxnet_op::Kernel<BooleanMaskBackwardKernel, cpu>::Launch(
          stream, idx_size, igrad_data.data().dptr<DType>(), req[0],
          ograd.data().dptr<DType>(), prefix_sum.data(), col_size);
      } else {
        // Overwriting kernel: zero-fills rows whose mask entry was zero.
        mxnet_op::Kernel<BooleanMaskBackwardCPUWriteKernel, cpu>::Launch(
          stream, idx_size, igrad_data.data().dptr<DType>(), req[0],
          ograd.data().dptr<DType>(), prefix_sum.data(), col_size);
      }
    });
  });
}
179 
180 NNVM_REGISTER_OP(_contrib_boolean_mask)
181 .add_alias("_npi_boolean_mask")
182 .describe(R"code(
183 Given an n-d NDArray data, and a 1-d NDArray index,
184 the operator produces an un-predeterminable shaped n-d NDArray out,
185 which stands for the rows in x where the corresonding element in index is non-zero.
186 
187 >>> data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
188 >>> index = mx.nd.array([0, 1, 0])
189 >>> out = mx.nd.contrib.boolean_mask(data, index)
190 >>> out
191 
192 [[4. 5. 6.]]
193 <NDArray 1x3 @cpu(0)>
194 
195 )code" ADD_FILELINE)
196 .set_attr_parser(ParamParser<BooleanMaskParam>)
197 .set_num_inputs(2)
198 .set_num_outputs(1)
199 .set_attr<nnvm::FListInputNames>("FListInputNames",
__anon79815e990102(const NodeAttrs& attrs) 200   [](const NodeAttrs& attrs) {
201     return std::vector<std::string>{"data", "index"};
202   })
203 .set_attr<nnvm::FInferType>("FInferType", BooleanMaskType)
204 .set_attr<FComputeEx>("FComputeEx<cpu>", BooleanMaskForward<cpu>)
205 .set_attr<FInferStorageType>("FInferStorageType", BooleanMaskStorageType)
206 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_contrib_boolean_mask"})
207 .add_argument("data", "NDArray-or-Symbol", "Data")
208 .add_argument("index", "NDArray-or-Symbol", "Mask")
209 .add_arguments(BooleanMaskParam::__FIELDS__());
210 
// Registration of the backward pass of boolean_mask.
// ElemwiseGradUseIn above feeds it {ograd, data, index} as inputs; it emits
// {grad_data, grad_index} as outputs.
NNVM_REGISTER_OP(_backward_contrib_boolean_mask)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", BooleanMaskBackStorageType)
.set_attr<FComputeEx>("FComputeEx<cpu>", BooleanMaskBackward<cpu>)
.add_arguments(BooleanMaskParam::__FIELDS__());
218 
219 }  // namespace op
220 }  // namespace mxnet
221