1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19 /*!
20 * \file boolean_mask.cc
21 */
22
23 #include "./boolean_mask-inl.h"
24
25 namespace mxnet {
26 namespace op {
27
// Register the parameter struct shared by the forward and backward boolean_mask
// operators below (both expose it via BooleanMaskParam::__FIELDS__()).
DMLC_REGISTER_PARAMETER(BooleanMaskParam);
29
BooleanMaskType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_attrs,std::vector<int> * out_attrs)30 bool BooleanMaskType(const nnvm::NodeAttrs& attrs,
31 std::vector<int> *in_attrs,
32 std::vector<int> *out_attrs) {
33 CHECK_EQ(in_attrs->size(), 2);
34 CHECK_EQ(out_attrs->size(), 1);
35 TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
36 TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
37 return in_attrs->at(0) != -1 && in_attrs->at(1) != -1 && out_attrs->at(0) != -1;
38 }
39
BooleanMaskStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_attrs,std::vector<int> * out_attrs)40 bool BooleanMaskStorageType(const nnvm::NodeAttrs& attrs,
41 const int dev_mask,
42 DispatchMode* dispatch_mode,
43 std::vector<int> *in_attrs,
44 std::vector<int> *out_attrs) {
45 CHECK_EQ(in_attrs->size(), 2);
46 CHECK_EQ(out_attrs->size(), 1);
47 for (int &attr : *in_attrs) {
48 CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported";
49 }
50 for (int &attr : *out_attrs) {
51 attr = kDefaultStorage;
52 }
53 *dispatch_mode = DispatchMode::kFComputeEx;
54 return true;
55 }
56
BooleanMaskBackStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_attrs,std::vector<int> * out_attrs)57 bool BooleanMaskBackStorageType(const nnvm::NodeAttrs& attrs,
58 const int dev_mask,
59 DispatchMode* dispatch_mode,
60 std::vector<int> *in_attrs,
61 std::vector<int> *out_attrs) {
62 CHECK_EQ(in_attrs->size(), 3);
63 CHECK_EQ(out_attrs->size(), 2);
64 for (int &attr : *in_attrs) {
65 CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported";
66 }
67 for (int &attr : *out_attrs) {
68 attr = kDefaultStorage;
69 }
70 for (size_t i = 0; i < out_attrs->size(); i++)
71 out_attrs->at(i) = kDefaultStorage;
72 *dispatch_mode = DispatchMode::kFComputeEx;
73 return true;
74 }
75
76 struct BooleanMaskBackwardCPUWriteKernel {
77 template<typename DType>
Mapmxnet::op::BooleanMaskBackwardCPUWriteKernel78 static void Map(int i,
79 DType* igrad,
80 const OpReqType /*req*/,
81 const DType* ograd,
82 const int32_t* idx,
83 const size_t col_size) {
84 // i is row id already
85 int32_t prev = (i == 0) ? 0 : idx[i - 1];
86 int32_t curr = idx[i];
87 if (prev != curr) {
88 std::memcpy(igrad + i * col_size, ograd + prev * col_size, col_size * sizeof(DType));
89 } else {
90 std::memset(igrad + i * col_size, 0, col_size * sizeof(DType));
91 }
92 }
93 };
94
95 template<>
BooleanMaskForward(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)96 inline void BooleanMaskForward<cpu>(const nnvm::NodeAttrs& attrs,
97 const OpContext &ctx,
98 const std::vector<NDArray> &inputs,
99 const std::vector<OpReqType> &req,
100 const std::vector<NDArray> &outputs) {
101 CHECK_EQ(inputs.size(), 2U);
102 CHECK_EQ(outputs.size(), 1U);
103 CHECK(req[0] == kWriteTo || req[0] == kWriteInplace);
104 const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
105 const int axis = param.axis;
106 const NDArray &data = inputs[0];
107 const NDArray &idx = inputs[1];
108 const NDArray &out = outputs[0];
109 CHECK_EQ(axis, 0) << "Not supported yet";
110 CHECK_EQ(data.shape()[axis], idx.shape()[0]);
111 CHECK_EQ(idx.shape().ndim(), 1U); // idx is required to be 1-d.
112 // count the number of 1s in `idx`, so that we could know the output dimension
113 size_t idx_size = idx.shape()[0];
114 std::vector<int32_t> prefix_sum(idx_size, 0);
115 size_t valid_num = 0;
116 // Calculate prefix sum
117 MSHADOW_TYPE_SWITCH_WITH_BOOL(idx.dtype(), DType, {
118 DType* idx_dptr = idx.data().dptr<DType>();
119 for (size_t i = 0; i < idx_size; i++) {
120 prefix_sum[i] = (i == 0) ? 0 : prefix_sum[i - 1];
121 prefix_sum[i] += (idx_dptr[i]) ? 1 : 0;
122 }
123 valid_num = prefix_sum[idx_size - 1];
124 });
125 // set the output shape forcefully
126 mxnet::TShape s = data.shape();
127 s[axis] = valid_num;
128
129 const_cast<NDArray &>(out).Init(s);
130 // do the copy
131 MSHADOW_TYPE_SWITCH_WITH_BOOL(data.dtype(), DType, {
132 size_t input_size = data.shape().Size();
133 size_t col_size = input_size / idx_size;
134 mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
135 mxnet_op::Kernel<BooleanMaskForwardCPUKernel, cpu>::Launch(
136 stream, idx_size, out.data().dptr<DType>(), data.data().dptr<DType>(),
137 prefix_sum.data(), col_size);
138 });
139 }
140
141 template<>
BooleanMaskBackward(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)142 inline void BooleanMaskBackward<cpu>(const nnvm::NodeAttrs& attrs,
143 const OpContext &ctx,
144 const std::vector<NDArray> &inputs,
145 const std::vector<OpReqType> &req,
146 const std::vector<NDArray> &outputs) {
147 CHECK_EQ(inputs.size(), 3U);
148 CHECK_EQ(outputs.size(), 2U);
149 if (req[0] == kNullOp) return;
150 // inputs: {ograd, data, idx}
151 // outputs: {igrad_data, igrad_idx}
152 const NDArray& ograd = inputs[0];
153 const NDArray& idx = inputs[2];
154 const NDArray& igrad_data = outputs[0];
155 MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, {
156 MSHADOW_TYPE_SWITCH_WITH_BOOL(idx.dtype(), IType, {
157 size_t input_size = igrad_data.shape().Size();
158 size_t idx_size = idx.shape()[0];
159 size_t col_size = input_size / idx_size;
160 std::vector<int32_t> prefix_sum(idx_size, 0);
161 IType* idx_dptr = idx.data().dptr<IType>();
162 for (size_t i = 0; i < idx_size; i++) {
163 prefix_sum[i] = (i == 0) ? 0 : prefix_sum[i - 1];
164 prefix_sum[i] += (idx_dptr[i]) ? 1 : 0;
165 }
166 mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
167 if (req[0] == kAddTo) {
168 mxnet_op::Kernel<BooleanMaskBackwardKernel, cpu>::Launch(
169 stream, idx_size, igrad_data.data().dptr<DType>(), req[0],
170 ograd.data().dptr<DType>(), prefix_sum.data(), col_size);
171 } else {
172 mxnet_op::Kernel<BooleanMaskBackwardCPUWriteKernel, cpu>::Launch(
173 stream, idx_size, igrad_data.data().dptr<DType>(), req[0],
174 ograd.data().dptr<DType>(), prefix_sum.data(), col_size);
175 }
176 });
177 });
178 }
179
180 NNVM_REGISTER_OP(_contrib_boolean_mask)
181 .add_alias("_npi_boolean_mask")
182 .describe(R"code(
183 Given an n-d NDArray data, and a 1-d NDArray index,
184 the operator produces an un-predeterminable shaped n-d NDArray out,
185 which stands for the rows in x where the corresonding element in index is non-zero.
186
187 >>> data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
188 >>> index = mx.nd.array([0, 1, 0])
189 >>> out = mx.nd.contrib.boolean_mask(data, index)
190 >>> out
191
192 [[4. 5. 6.]]
193 <NDArray 1x3 @cpu(0)>
194
195 )code" ADD_FILELINE)
196 .set_attr_parser(ParamParser<BooleanMaskParam>)
197 .set_num_inputs(2)
198 .set_num_outputs(1)
199 .set_attr<nnvm::FListInputNames>("FListInputNames",
__anon79815e990102(const NodeAttrs& attrs) 200 [](const NodeAttrs& attrs) {
201 return std::vector<std::string>{"data", "index"};
202 })
203 .set_attr<nnvm::FInferType>("FInferType", BooleanMaskType)
204 .set_attr<FComputeEx>("FComputeEx<cpu>", BooleanMaskForward<cpu>)
205 .set_attr<FInferStorageType>("FInferStorageType", BooleanMaskStorageType)
206 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_contrib_boolean_mask"})
207 .add_argument("data", "NDArray-or-Symbol", "Data")
208 .add_argument("index", "NDArray-or-Symbol", "Mask")
209 .add_arguments(BooleanMaskParam::__FIELDS__());
210
211 NNVM_REGISTER_OP(_backward_contrib_boolean_mask)
212 .set_num_inputs(3)
213 .set_num_outputs(2)
214 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
215 .set_attr<FInferStorageType>("FInferStorageType", BooleanMaskBackStorageType)
216 .set_attr<FComputeEx>("FComputeEx<cpu>", BooleanMaskBackward<cpu>)
217 .add_arguments(BooleanMaskParam::__FIELDS__());
218
219 } // namespace op
220 } // namespace mxnet
221