1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
 * \file np_bincount_op.cc
22 * \brief numpy compatible bincount operator CPU registration
23 */
24
25 #include "./np_bincount_op-inl.h"
26
27 namespace mxnet {
28 namespace op {
29
BinNumberCount(const NDArray & data,const int & minlength,const NDArray & out,const size_t & N)30 void BinNumberCount(const NDArray& data, const int& minlength,
31 const NDArray& out, const size_t& N) {
32 int bin = minlength;
33 MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
34 DType* data_ptr = data.data().dptr<DType>();
35 for (size_t i = 0; i < N; i++) {
36 CHECK_GE(data_ptr[i], 0) << "input should be nonnegative number";
37 if (data_ptr[i] + 1 > bin) {
38 bin = data_ptr[i] + 1;
39 }
40 }
41 }); // bin number = max(max(data) + 1, minlength)
42 mxnet::TShape s(1, bin);
43 const_cast<NDArray &>(out).Init(s); // set the output shape forcefully
44 }
45
/*!
 * \brief Weighted bincount on CPU: adds weights[i] to the bin selected by data[i].
 *
 * `out` is accumulated into, not cleared; no bounds checking is performed on
 * the bin index, so the caller must size `out` to cover every value in `data`.
 *
 * \param data bin indices, data_n elements
 * \param weights per-element weights, data_n elements
 * \param out accumulator array indexed by bin
 * \param data_n number of elements in `data` and `weights`
 */
template<typename DType, typename OType>
void BincountCpuWeights(const DType* data, const OType* weights,
                        OType* out, const size_t& data_n) {
  size_t idx = 0;
  while (idx < data_n) {
    // The input value is converted to int to serve as the bin index.
    int bin = data[idx];
    out[bin] += weights[idx];
    ++idx;
  }
}
54
/*!
 * \brief Unweighted bincount on CPU: increments the bin selected by each data[i].
 *
 * `out` is accumulated into, not cleared; no bounds checking is performed on
 * the bin index, so the caller must size `out` to cover every value in `data`.
 *
 * \param data bin indices, data_n elements
 * \param out counter array indexed by bin
 * \param data_n number of elements in `data`
 */
template<typename DType, typename OType>
void BincountCpu(const DType* data, OType* out, const size_t& data_n) {
  size_t idx = 0;
  while (idx < data_n) {
    // The input value is converted to int to serve as the bin index.
    int bin = data[idx];
    out[bin] += 1;
    ++idx;
  }
}
62
63 template<>
NumpyBincountForwardImpl(const OpContext & ctx,const NDArray & data,const NDArray & weights,const NDArray & out,const size_t & data_n,const int & minlength)64 void NumpyBincountForwardImpl<cpu>(const OpContext &ctx,
65 const NDArray &data,
66 const NDArray &weights,
67 const NDArray &out,
68 const size_t &data_n,
69 const int &minlength) {
70 using namespace mxnet_op;
71 BinNumberCount(data, minlength, out, data_n);
72 mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
73 MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
74 MSHADOW_TYPE_SWITCH(weights.dtype(), OType, {
75 size_t out_size = out.shape()[0];
76 Kernel<set_zero, cpu>::Launch(s, out_size, out.data().dptr<OType>());
77 BincountCpuWeights(data.data().dptr<DType>(), weights.data().dptr<OType>(),
78 out.data().dptr<OType>(), data_n);
79 });
80 });
81 }
82
83 template<>
NumpyBincountForwardImpl(const OpContext & ctx,const NDArray & data,const NDArray & out,const size_t & data_n,const int & minlength)84 void NumpyBincountForwardImpl<cpu>(const OpContext &ctx,
85 const NDArray &data,
86 const NDArray &out,
87 const size_t &data_n,
88 const int &minlength) {
89 using namespace mxnet_op;
90 BinNumberCount(data, minlength, out, data_n);
91 mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
92 MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
93 MSHADOW_TYPE_SWITCH(out.dtype(), OType, {
94 size_t out_size = out.shape()[0];
95 Kernel<set_zero, cpu>::Launch(s, out_size, out.data().dptr<OType>());
96 BincountCpu(data.data().dptr<DType>(), out.data().dptr<OType>(), data_n);
97 });
98 });
99 }
100
101 DMLC_REGISTER_PARAMETER(NumpyBincountParam);
102
103 NNVM_REGISTER_OP(_npi_bincount)
104 .set_attr_parser(ParamParser<NumpyBincountParam>)
__anone35d09bd0102(const NodeAttrs& attrs) 105 .set_num_inputs([](const NodeAttrs& attrs) {
106 const NumpyBincountParam& params =
107 nnvm::get<NumpyBincountParam>(attrs.parsed);
108 return params.has_weights? 2 : 1;
109 })
110 .set_num_outputs(1)
111 .set_attr<nnvm::FListInputNames>("FListInputNames",
__anone35d09bd0202(const NodeAttrs& attrs) 112 [](const NodeAttrs& attrs) {
113 const NumpyBincountParam& params =
114 nnvm::get<NumpyBincountParam>(attrs.parsed);
115 return params.has_weights ?
116 std::vector<std::string>{"data", "weights"} :
117 std::vector<std::string>{"data"};
118 })
119 .set_attr<FResourceRequest>("FResourceRequest",
__anone35d09bd0302(const NodeAttrs& attrs) 120 [](const NodeAttrs& attrs) {
121 return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
122 })
123 .set_attr<nnvm::FInferType>("FInferType", NumpyBincountType)
124 .set_attr<FInferStorageType>("FInferStorageType", NumpyBincountStorageType)
125 .set_attr<FComputeEx>("FComputeEx<cpu>", NumpyBincountForward<cpu>)
126 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
127 .add_argument("data", "NDArray-or-Symbol", "Data")
128 .add_argument("weights", "NDArray-or-Symbol", "Weights")
129 .add_arguments(NumpyBincountParam::__FIELDS__());
130
131 } // namespace op
132 } // namespace mxnet
133