1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file histogram-inl.h
22  * \brief Function definition of histogram operator
23 */
24 #ifndef MXNET_OPERATOR_TENSOR_HISTOGRAM_INL_H_
25 #define MXNET_OPERATOR_TENSOR_HISTOGRAM_INL_H_
26 
27 #include <dmlc/logging.h>
28 #include <dmlc/parameter.h>
29 #include <mxnet/operator.h>
30 #include <mxnet/operator_util.h>
31 #include <dmlc/optional.h>
32 #include <mshadow/tensor.h>
33 #include <nnvm/op.h>
34 #include <nnvm/node.h>
35 #include <nnvm/op_attr_types.h>
36 #include <vector>
37 #include <type_traits>
38 #include "./util/tensor_util-inl.h"
39 #include "../elemwise_op_common.h"
40 #include "../mshadow_op.h"
41 #include "../mxnet_op.h"
42 #include "../operator_common.h"
43 
44 namespace mxnet {
45 namespace op {
46 
47 struct HistogramParam : public dmlc::Parameter<HistogramParam> {
48     dmlc::optional<int> bin_cnt;
49     dmlc::optional<mxnet::Tuple<double>> range;
DMLC_DECLARE_PARAMETERHistogramParam50     DMLC_DECLARE_PARAMETER(HistogramParam) {
51       DMLC_DECLARE_FIELD(bin_cnt)
52         .set_default(dmlc::optional<int>())
53         .describe("Number of bins for uniform case");
54       DMLC_DECLARE_FIELD(range)
55         .set_default(dmlc::optional<mxnet::Tuple<double>>())
56         .describe("The lower and upper range of the bins. if not provided, "
57                   "range is simply (a.min(), a.max()). values outside the "
58                   "range are ignored. the first element of the range must be "
59                   "less than or equal to the second. range affects the automatic "
60                   "bin computation as well. while bin width is computed to be "
61                   "optimal based on the actual data within range, the bin count "
62                   "will fill the entire range including portions containing no data.");
63     }
64 };
65 
66 struct FillBinBoundsKernel {
67   template<typename DType>
MapFillBinBoundsKernel68   static MSHADOW_XINLINE void Map(int i, DType* bin_bounds, int bin_cnt, double min, double max) {
69     if (i <= bin_cnt) {
70       bin_bounds[i] = DType((max * i + (bin_cnt - i) * min) / bin_cnt);
71     }
72   }
73 };
74 
HistogramOpShape(const nnvm::NodeAttrs & attrs,mxnet::ShapeVector * in_attrs,mxnet::ShapeVector * out_attrs)75 inline bool HistogramOpShape(const nnvm::NodeAttrs& attrs,
76                              mxnet::ShapeVector* in_attrs,
77                              mxnet::ShapeVector* out_attrs) {
78   HistogramParam param = nnvm::get<HistogramParam>(attrs.parsed);
79   const bool has_cnt = param.bin_cnt.has_value();
80   const bool has_range = param.range.has_value();
81   const bool legal_param = (has_cnt && has_range) || (!has_cnt && !has_range);
82   CHECK_EQ(in_attrs->size(), has_cnt ? 1U : 2U);
83   CHECK_EQ(out_attrs->size(), 2U);
84   CHECK(legal_param) << "cnt and range should both or neither specified";
85 
86   if (has_cnt) {
87     // if cnt is specified, the output histogram has shape (cnt,)
88     // while output bins has shape (cnt+1,)
89     const dim_t bin_cnt = param.bin_cnt.value();
90     SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(1, bin_cnt));
91     SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape(1, bin_cnt + 1));
92   } else {
93     // if cnt is not specified, the output histogram has shape (bins.Size() - 1)
94     // while output bins has same shape as input bins
95     mxnet::TShape oshape = (*in_attrs)[1];
96 
97     CHECK_EQ(oshape.ndim(), 1U) << "bins argument should be an 1D vector";
98     CHECK_GE(oshape.Size(), 2U) << "number of bounds should be >= 2";
99 
100     SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(1, oshape[0] - 1));
101     SHAPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(1));
102   }
103 
104   return shape_is_known(out_attrs->at(0)) && shape_is_known(out_attrs->at(1)) &&
105          out_attrs->at(0).Size() == out_attrs->at(1).Size() - 1;
106 }
107 
HistogramOpType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_attrs,std::vector<int> * out_attrs)108 inline bool HistogramOpType(const nnvm::NodeAttrs& attrs,
109                             std::vector<int>* in_attrs,
110                             std::vector<int>* out_attrs) {
111   CHECK_EQ(out_attrs->size(), 2U);
112 
113   TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
114   TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(0));
115   return !type_is_none(out_attrs->at(0)) && !type_is_none(out_attrs->at(1));
116 }
117 
118 template<typename xpu>
119 void HistogramForwardImpl(const OpContext& ctx,
120                           const TBlob& in_data,
121                           const TBlob& bin_bounds,
122                           const TBlob& out_data,
123                           const TBlob& out_bins);
124 
125 template<typename xpu>
126 void HistogramForwardImpl(const OpContext& ctx,
127                           const TBlob& in_data,
128                           const TBlob& out_data,
129                           const TBlob& out_bins,
130                           const int bin_cnt,
131                           const double min,
132                           const double max);
133 
134 template<typename xpu>
HistogramOpForward(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)135 void HistogramOpForward(const nnvm::NodeAttrs& attrs,
136                         const OpContext& ctx,
137                         const std::vector<TBlob>& inputs,
138                         const std::vector<OpReqType>& req,
139                         const std::vector<TBlob>& outputs) {
140   CHECK_EQ(req.size(), 2U);
141   CHECK_EQ(req[0], kWriteTo);
142   CHECK_EQ(req[1], kWriteTo);
143   const HistogramParam& param = nnvm::get<HistogramParam>(attrs.parsed);
144   const bool has_cnt = param.bin_cnt.has_value();
145   const bool has_range = param.range.has_value();
146   const bool legal_params = (has_cnt && has_range) || (!has_cnt && !has_range);
147   CHECK(legal_params) << "width and range should both or neither be specified";
148 
149   const TBlob& in_data = inputs[0];
150   const TBlob& out_data = outputs[0];
151   const TBlob& out_bins = outputs[1];
152 
153   if (has_cnt) {
154     CHECK((param.range.value().ndim() == 2U)) << "range should be a tuple with only 2 elements";
155     CHECK(param.range.value()[0] <= param.range.value()[1])
156       << "left hand side of range(" << param.range.value()[0]
157       << ")should be less than or equal to right hand side(" << param.range.value()[1] << ")";
158     double max = param.range.value()[1];
159     double min = param.range.value()[0];
160     const int bin_cnt = param.bin_cnt.value();
161     if (min == max) {
162       min -= 0.5f;
163       max += 0.5f;
164       LOG(INFO) << min << " " << max;
165     }
166     HistogramForwardImpl<xpu>(ctx, in_data, out_data, out_bins, bin_cnt, min, max);
167   } else {
168     const TBlob& bin_bounds = inputs[1];
169     HistogramForwardImpl<xpu>(ctx, in_data, bin_bounds, out_data, out_bins);
170   }
171 }
172 
173 }   // namespace op
174 }   // namespace mxnet
175 
176 #endif  // MXNET_OPERATOR_TENSOR_HISTOGRAM_INL_H_
177