1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file histogram-inl.h
22 * \brief Function definition of histogram operator
23 */
24 #ifndef MXNET_OPERATOR_TENSOR_HISTOGRAM_INL_H_
25 #define MXNET_OPERATOR_TENSOR_HISTOGRAM_INL_H_
26
27 #include <dmlc/logging.h>
28 #include <dmlc/parameter.h>
29 #include <mxnet/operator.h>
30 #include <mxnet/operator_util.h>
31 #include <dmlc/optional.h>
32 #include <mshadow/tensor.h>
33 #include <nnvm/op.h>
34 #include <nnvm/node.h>
35 #include <nnvm/op_attr_types.h>
36 #include <vector>
37 #include <type_traits>
38 #include "./util/tensor_util-inl.h"
39 #include "../elemwise_op_common.h"
40 #include "../mshadow_op.h"
41 #include "../mxnet_op.h"
42 #include "../operator_common.h"
43
44 namespace mxnet {
45 namespace op {
46
47 struct HistogramParam : public dmlc::Parameter<HistogramParam> {
48 dmlc::optional<int> bin_cnt;
49 dmlc::optional<mxnet::Tuple<double>> range;
DMLC_DECLARE_PARAMETERHistogramParam50 DMLC_DECLARE_PARAMETER(HistogramParam) {
51 DMLC_DECLARE_FIELD(bin_cnt)
52 .set_default(dmlc::optional<int>())
53 .describe("Number of bins for uniform case");
54 DMLC_DECLARE_FIELD(range)
55 .set_default(dmlc::optional<mxnet::Tuple<double>>())
56 .describe("The lower and upper range of the bins. if not provided, "
57 "range is simply (a.min(), a.max()). values outside the "
58 "range are ignored. the first element of the range must be "
59 "less than or equal to the second. range affects the automatic "
60 "bin computation as well. while bin width is computed to be "
61 "optimal based on the actual data within range, the bin count "
62 "will fill the entire range including portions containing no data.");
63 }
64 };
65
66 struct FillBinBoundsKernel {
67 template<typename DType>
MapFillBinBoundsKernel68 static MSHADOW_XINLINE void Map(int i, DType* bin_bounds, int bin_cnt, double min, double max) {
69 if (i <= bin_cnt) {
70 bin_bounds[i] = DType((max * i + (bin_cnt - i) * min) / bin_cnt);
71 }
72 }
73 };
74
HistogramOpShape(const nnvm::NodeAttrs & attrs,mxnet::ShapeVector * in_attrs,mxnet::ShapeVector * out_attrs)75 inline bool HistogramOpShape(const nnvm::NodeAttrs& attrs,
76 mxnet::ShapeVector* in_attrs,
77 mxnet::ShapeVector* out_attrs) {
78 HistogramParam param = nnvm::get<HistogramParam>(attrs.parsed);
79 const bool has_cnt = param.bin_cnt.has_value();
80 const bool has_range = param.range.has_value();
81 const bool legal_param = (has_cnt && has_range) || (!has_cnt && !has_range);
82 CHECK_EQ(in_attrs->size(), has_cnt ? 1U : 2U);
83 CHECK_EQ(out_attrs->size(), 2U);
84 CHECK(legal_param) << "cnt and range should both or neither specified";
85
86 if (has_cnt) {
87 // if cnt is specified, the output histogram has shape (cnt,)
88 // while output bins has shape (cnt+1,)
89 const dim_t bin_cnt = param.bin_cnt.value();
90 SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(1, bin_cnt));
91 SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape(1, bin_cnt + 1));
92 } else {
93 // if cnt is not specified, the output histogram has shape (bins.Size() - 1)
94 // while output bins has same shape as input bins
95 mxnet::TShape oshape = (*in_attrs)[1];
96
97 CHECK_EQ(oshape.ndim(), 1U) << "bins argument should be an 1D vector";
98 CHECK_GE(oshape.Size(), 2U) << "number of bounds should be >= 2";
99
100 SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(1, oshape[0] - 1));
101 SHAPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(1));
102 }
103
104 return shape_is_known(out_attrs->at(0)) && shape_is_known(out_attrs->at(1)) &&
105 out_attrs->at(0).Size() == out_attrs->at(1).Size() - 1;
106 }
107
HistogramOpType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_attrs,std::vector<int> * out_attrs)108 inline bool HistogramOpType(const nnvm::NodeAttrs& attrs,
109 std::vector<int>* in_attrs,
110 std::vector<int>* out_attrs) {
111 CHECK_EQ(out_attrs->size(), 2U);
112
113 TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
114 TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(0));
115 return !type_is_none(out_attrs->at(0)) && !type_is_none(out_attrs->at(1));
116 }
117
118 template<typename xpu>
119 void HistogramForwardImpl(const OpContext& ctx,
120 const TBlob& in_data,
121 const TBlob& bin_bounds,
122 const TBlob& out_data,
123 const TBlob& out_bins);
124
125 template<typename xpu>
126 void HistogramForwardImpl(const OpContext& ctx,
127 const TBlob& in_data,
128 const TBlob& out_data,
129 const TBlob& out_bins,
130 const int bin_cnt,
131 const double min,
132 const double max);
133
134 template<typename xpu>
HistogramOpForward(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)135 void HistogramOpForward(const nnvm::NodeAttrs& attrs,
136 const OpContext& ctx,
137 const std::vector<TBlob>& inputs,
138 const std::vector<OpReqType>& req,
139 const std::vector<TBlob>& outputs) {
140 CHECK_EQ(req.size(), 2U);
141 CHECK_EQ(req[0], kWriteTo);
142 CHECK_EQ(req[1], kWriteTo);
143 const HistogramParam& param = nnvm::get<HistogramParam>(attrs.parsed);
144 const bool has_cnt = param.bin_cnt.has_value();
145 const bool has_range = param.range.has_value();
146 const bool legal_params = (has_cnt && has_range) || (!has_cnt && !has_range);
147 CHECK(legal_params) << "width and range should both or neither be specified";
148
149 const TBlob& in_data = inputs[0];
150 const TBlob& out_data = outputs[0];
151 const TBlob& out_bins = outputs[1];
152
153 if (has_cnt) {
154 CHECK((param.range.value().ndim() == 2U)) << "range should be a tuple with only 2 elements";
155 CHECK(param.range.value()[0] <= param.range.value()[1])
156 << "left hand side of range(" << param.range.value()[0]
157 << ")should be less than or equal to right hand side(" << param.range.value()[1] << ")";
158 double max = param.range.value()[1];
159 double min = param.range.value()[0];
160 const int bin_cnt = param.bin_cnt.value();
161 if (min == max) {
162 min -= 0.5f;
163 max += 0.5f;
164 LOG(INFO) << min << " " << max;
165 }
166 HistogramForwardImpl<xpu>(ctx, in_data, out_data, out_bins, bin_cnt, min, max);
167 } else {
168 const TBlob& bin_bounds = inputs[1];
169 HistogramForwardImpl<xpu>(ctx, in_data, bin_bounds, out_data, out_bins);
170 }
171 }
172
173 } // namespace op
174 } // namespace mxnet
175
176 #endif // MXNET_OPERATOR_TENSOR_HISTOGRAM_INL_H_
177