1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \brief Pooling op constructions
22  * \file nn/pooling.h
23  */
24 #ifndef TVM_TOPI_NN_POOLING_H_
25 #define TVM_TOPI_NN_POOLING_H_
26 
27 #include <tvm/arith/analyzer.h>
28 #include <tvm/topi/detail/pad_utils.h>
29 #include <tvm/topi/nn.h>
30 #include <tvm/topi/reduction.h>
31 #include <tvm/topi/tags.h>
32 
33 #include <algorithm>
34 #include <string>
35 #include <vector>
36 
37 namespace tvm {
38 namespace topi {
39 namespace nn {
40 
41 using namespace tvm::te;
42 
43 /*! \brief Pooling type */
enum PoolType : int {
  kAvgPool,  // average pooling: sum over the window divided by the element count
  kMaxPool,  // max pooling: maximum element in the window
};
48 
49 /*!
50  * \brief Perform pooling on height and width dimension of data.
51  *
52  * \param x The input tensor
53  * \param kernel_size Vector of two ints: {kernel_height, kernel_width}
54  * \param stride_size Vector of two ints: {stride_height, stride_width}
 * \param padding_size Vector of four ints: {pad_top, pad_left, pad_bottom, pad_right}
56  * \param pool_type The type of pooling operator
57  * \param ceil_mode Whether to use ceil when calculating the output size
58  * \param height_axis index of the height dimension
59  * \param width_axis index of the width dimension
60  * \param count_include_pad Whether include padding in the calculation
61  *
62  * \return The output tensor in same layout order
63  */
pool_impl(const Tensor & x,const Array<PrimExpr> & kernel_size,const Array<PrimExpr> & stride_size,const Array<PrimExpr> & padding_size,PoolType pool_type,bool ceil_mode,const size_t height_axis,const size_t width_axis,bool count_include_pad)64 inline Tensor pool_impl(const Tensor& x, const Array<PrimExpr>& kernel_size,
65                         const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
66                         PoolType pool_type, bool ceil_mode, const size_t height_axis,
67                         const size_t width_axis, bool count_include_pad) {
68   CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)";
69   CHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
70   CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
71   CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";
72 
73   auto kernel_height = cast(DataType::DataType::Int(32), kernel_size[0]);
74   auto kernel_width = cast(DataType::DataType::Int(32), kernel_size[1]);
75   auto stride_height = cast(DataType::DataType::Int(32), stride_size[0]);
76   auto stride_width = cast(DataType::DataType::Int(32), stride_size[1]);
77 
78   auto height = x->shape[height_axis];
79   auto width = x->shape[width_axis];
80 
81   auto pad_top = cast(DataType::DataType::Int(32), padding_size[0]);
82   auto pad_left = cast(DataType::DataType::Int(32), padding_size[1]);
83   auto pad_bottom = cast(DataType::DataType::Int(32), padding_size[2]);
84   auto pad_right = cast(DataType::DataType::Int(32), padding_size[3]);
85 
86   if (ceil_mode) {
87     // Additional padding to ensure we do ceil instead of floor when
88     // dividing by stride.
89     pad_bottom += stride_height - 1;
90     pad_right += stride_width - 1;
91   }
92 
93   Array<PrimExpr> pad_before(std::vector<PrimExpr>(x->shape.size(), 0));
94   pad_before.Set(height_axis, pad_top);
95   pad_before.Set(width_axis, pad_left);
96 
97   Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
98   pad_after.Set(height_axis, pad_bottom);
99   pad_after.Set(width_axis, pad_right);
100   arith::Analyzer analyzer;
101   auto out_height =
102       analyzer.Simplify(indexdiv(height - kernel_height + pad_top + pad_bottom, stride_height) + 1);
103   auto out_width =
104       analyzer.Simplify(indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1);
105 
106   auto dheight = tvm::te::reduce_axis(Range(0, kernel_height));
107   auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width));
108 
109   Array<PrimExpr> out_shape = x->shape;
110   out_shape.Set(height_axis, out_height);
111   out_shape.Set(width_axis, out_width);
112 
113   const int64_t* padding_h0 = as_const_int(pad_top);
114   const int64_t* padding_w0 = as_const_int(pad_left);
115   const int64_t* padding_h1 = as_const_int(pad_bottom);
116   const int64_t* padding_w1 = as_const_int(pad_right);
117   const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
118                       ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));
119 
120   if (pool_type == kMaxPool) {
121     auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
122     return tvm::te::compute(
123         out_shape,
124         [&](const Array<Var>& output) {
125           Array<PrimExpr> indices;
126           for (const Var& var : output) indices.push_back(var);
127           indices.Set(height_axis, output[height_axis] * stride_height + dheight);
128           indices.Set(width_axis, output[width_axis] * stride_width + dwidth);
129           return tvm::max(temp(indices), {dheight, dwidth});
130         },
131         "tensor", "pool_max");
132   } else if (pool_type == kAvgPool) {
133     // Pad the inputs
134     auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;
135 
136     // TVM compute for summing the pooling window.
137     auto pool_sum = tvm::te::compute(
138         out_shape,
139         [&](const Array<Var>& output) {
140           Array<PrimExpr> indices;
141           for (const Var& var : output) indices.push_back(var);
142           indices.Set(height_axis, output[height_axis] * stride_height + dheight);
143           indices.Set(width_axis, output[width_axis] * stride_width + dwidth);
144           return tvm::sum(temp(indices), {dheight, dwidth});
145         },
146         "tensor", "pool_sum");
147 
148     // TVM compute for dividing the reduced window sum by kernel size.
149     return tvm::te::compute(
150         out_shape,
151         [&](const Array<Var>& output) {
152           Array<PrimExpr> indices;
153           for (const Var& var : output) indices.push_back(var);
154           if (count_include_pad) {
155             return div(pool_sum(indices), (kernel_height * kernel_width));
156           } else {
157             PrimExpr h_start = output[height_axis] * stride_height - pad_top;
158             PrimExpr w_start = output[width_axis] * stride_width - pad_left;
159 
160             PrimExpr h_end = min(h_start + kernel_height, height);
161             PrimExpr w_end = min(w_start + kernel_width, width);
162             h_start = max(h_start, make_const(DataType::DataType::Int(32), 0));
163             w_start = max(w_start, make_const(DataType::DataType::Int(32), 0));
164             PrimExpr divide_factor = max((h_end - h_start) * (w_end - w_start),
165                                          make_const(DataType::DataType::Int(32), 1));
166             return div(pool_sum(indices), divide_factor);
167           }
168         },
169         "tensor", kElementWise);
170   } else {
171     LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
172     return x;
173   }
174 }
175 
pool_grad_impl(const Tensor & out_grad,const Tensor & x,const Array<PrimExpr> & kernel_size,const Array<PrimExpr> & stride_size,const Array<PrimExpr> & padding_size,PoolType pool_type,bool ceil_mode,const size_t height_axis,const size_t width_axis,bool count_include_pad)176 inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
177                              const Array<PrimExpr>& kernel_size, const Array<PrimExpr>& stride_size,
178                              const Array<PrimExpr>& padding_size, PoolType pool_type,
179                              bool ceil_mode, const size_t height_axis, const size_t width_axis,
180                              bool count_include_pad) {
181   CHECK(out_grad->shape.size() >= 2) << "Pooling grad output must >= 2-D (H, W)";
182   CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)";
183   CHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
184   CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
185   CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";
186 
187   auto kernel_height = cast(DataType::DataType::Int(32), kernel_size[0]);
188   auto kernel_width = cast(DataType::DataType::Int(32), kernel_size[1]);
189   auto stride_height = cast(DataType::DataType::Int(32), stride_size[0]);
190   auto stride_width = cast(DataType::DataType::Int(32), stride_size[1]);
191 
192   auto height = x->shape[height_axis];
193   auto width = x->shape[width_axis];
194 
195   auto pad_top = cast(DataType::DataType::Int(32), padding_size[0]);
196   auto pad_left = cast(DataType::DataType::Int(32), padding_size[1]);
197   auto pad_bottom = cast(DataType::DataType::Int(32), padding_size[2]);
198   auto pad_right = cast(DataType::DataType::Int(32), padding_size[3]);
199 
200   if (ceil_mode) {
201     // Additional padding to ensure we do ceil instead of floor when
202     // dividing by stride.
203     pad_bottom += stride_height - 1;
204     pad_right += stride_width - 1;
205   }
206 
207   Array<PrimExpr> pad_before(std::vector<PrimExpr>(x->shape.size(), 0));
208   pad_before.Set(height_axis, pad_top);
209   pad_before.Set(width_axis, pad_left);
210 
211   Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
212   pad_after.Set(height_axis, pad_bottom);
213   pad_after.Set(width_axis, pad_right);
214   arith::Analyzer analyzer;
215   auto out_height =
216       analyzer.Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
217   auto out_width =
218       analyzer.Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1);
219 
220   auto dheight = tvm::te::reduce_axis(Range(0, kernel_height));
221   auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width));
222 
223   Array<PrimExpr> out_shape = x->shape;
224   out_shape.Set(height_axis, out_height);
225   out_shape.Set(width_axis, out_width);
226 
227   const int64_t* padding_h0 = as_const_int(pad_top);
228   const int64_t* padding_w0 = as_const_int(pad_left);
229   const int64_t* padding_h1 = as_const_int(pad_bottom);
230   const int64_t* padding_w1 = as_const_int(pad_right);
231   const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
232                       ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));
233 
234   if (pool_type == kMaxPool) {
235     Array<PrimExpr> ravel_shape{x->shape.begin(), x->shape.end()};
236     ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
237     ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);
238 
239     auto windowh =
240         tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height));
241     auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width));
242 
243     auto argmax = MakeArgmaxReducer();
244     auto pad_x = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
245 
246     auto mp_argmax = tvm::te::compute(
247         out_shape,
248         [&](const Array<Var>& inds) {
249           Array<PrimExpr> window_inds{inds.begin(), inds.end()};
250           window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
251           window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
252           auto idx = detail::RavelIndex(window_inds, ravel_shape);
253           return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr);
254         },
255         "maxpool_grad_argmax", kCommReduceIdx);
256 
257     auto mp_inds = mp_argmax[0];
258 
259     return tvm::te::compute(
260         x->shape,
261         [&](const Array<Var>& inds) {
262           Array<PrimExpr> pad_inds{inds.begin(), inds.end()};
263           pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);
264           pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left);
265           auto idx = detail::RavelIndex(pad_inds, ravel_shape);
266 
267           Array<PrimExpr> out_idx{inds.begin(), inds.end()};
268           out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
269           out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);
270 
271           PrimExpr out_idx_lower_h = tir::Select(
272               pad_inds[height_axis] < kernel_height, make_const(DataType::DataType::Int(32), 0),
273               (pad_inds[height_axis] - kernel_height) / stride_height + 1);
274           PrimExpr out_idx_lower_w = tir::Select(
275               pad_inds[width_axis] < kernel_width, make_const(DataType::DataType::Int(32), 0),
276               (pad_inds[width_axis] - kernel_width) / stride_width + 1);
277 
278           return tvm::sum(
279               tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
280                                                   out_idx[width_axis] >= out_idx_lower_w),
281                                          mp_inds(out_idx) == idx),
282                                 out_grad(out_idx), make_const(x->dtype, 0)),
283               {windowh, windoww});
284         },
285         "T_pool_grad", "pool_grad_max");
286   } else if (pool_type == kAvgPool) {
287     auto windowh =
288         tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height));
289     auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width));
290     return tvm::te::compute(
291         x->shape,
292         [&](const Array<Var>& inds) {
293           PrimExpr pad_h_idx = inds[height_axis] + pad_top;
294           PrimExpr pad_w_idx = inds[width_axis] + pad_left;
295 
296           // output indices whose pooling windows cover current input element (can be out-of-bound)
297           Array<PrimExpr> out_idx{inds.begin(), inds.end()};
298           out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh));
299           out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));
300 
301           PrimExpr out_idx_lower_h =
302               tir::Select(pad_h_idx < kernel_height, make_const(DataType::Int(32), 0),
303                           (pad_h_idx - kernel_height) / stride_height + 1);
304           PrimExpr out_idx_lower_w =
305               tir::Select(pad_w_idx < kernel_width, make_const(DataType::Int(32), 0),
306                           (pad_w_idx - kernel_width) / stride_width + 1);
307 
308           PrimExpr divide_factor;  // number of pooled elements
309           if (count_include_pad) {
310             divide_factor = kernel_height * kernel_width;
311           } else {
312             PrimExpr h_start = out_idx[height_axis] * stride_height - pad_top;
313             PrimExpr w_start = out_idx[width_axis] * stride_width - pad_left;
314 
315             PrimExpr h_end = min(h_start + kernel_height, height);
316             PrimExpr w_end = min(w_start + kernel_width, width);
317             h_start = max(h_start, make_const(DataType::Int(32), 0));
318             w_start = max(w_start, make_const(DataType::Int(32), 0));
319             divide_factor =
320                 max((h_end - h_start) * (w_end - w_start), make_const(DataType::Int(32), 1));
321           }
322           return tvm::sum(
323               tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
324                                                   out_idx[height_axis] < out_height),
325                                          tir::And(out_idx[width_axis] >= out_idx_lower_w,
326                                                   out_idx[width_axis] < out_width)),
327                                 out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)),
328               {windowh, windoww});
329         },
330         "T_pool_grad", "pool_grad_avg");
331   } else {
332     LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
333     return Tensor();
334   }
335 }
336 
// Locate the D/H/W axes in a layout string such as "NCDHW" or "NCDHW16c".
// Only upper-case letters count as axes; digits are skipped; a lower-case
// 'd'/'h'/'w' denotes a split depth/height/width axis, which is unsupported.
// Returns true iff all three axes are found exactly once.
inline bool find_depth_height_width(const std::string& layout, int* depth_axis, int* height_axis,
                                    int* width_axis) {
  *depth_axis = -1;
  *height_axis = -1;
  *width_axis = -1;
  int axis_pos = 0;
  for (char c : layout) {
    const bool is_letter = (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
    if (!is_letter) continue;  // skip split-factor digits, e.g. the "16" in NCHW16c
    switch (c) {
      case 'D':
        if (*depth_axis != -1) return false;  // duplicate depth axis
        *depth_axis = axis_pos;
        break;
      case 'H':
        if (*height_axis != -1) return false;  // duplicate height axis
        *height_axis = axis_pos;
        break;
      case 'W':
        if (*width_axis != -1) return false;  // duplicate width axis
        *width_axis = axis_pos;
        break;
      case 'd':
      case 'h':
      case 'w':
        // do not support split on depth, height or width, e.g., NCHW16w
        return false;
      default:
        break;  // any other axis letter is fine
    }
    ++axis_pos;
  }
  return *depth_axis != -1 && *height_axis != -1 && *width_axis != -1;
}
364 
find_height_width(const std::string & layout,int * height_axis,int * width_axis)365 inline bool find_height_width(const std::string& layout, int* height_axis, int* width_axis) {
366   int dummy;
367   CHECK_EQ(find_depth_height_width(layout, &dummy, height_axis, width_axis), false);
368   if (*height_axis != -1 && *width_axis != -1) {
369     return true;
370   }
371   return false;
372 }
373 
find_width(const std::string & layout,int * width_axis)374 inline bool find_width(const std::string& layout, int* width_axis) {
375   int dummy;
376   CHECK_EQ(find_depth_height_width(layout, &dummy, &dummy, width_axis), false);
377   if (*width_axis != -1) {
378     return true;
379   }
380   return false;
381 }
382 
383 /*!
384  * \brief Perform pooling on height and width dimension of data.
385  *        It decides the height and width dimension according to the layout string,
386  *        in which 'W' and 'H' means width and height respectively.
387  *        Width and height dimension cannot be split.
388  *        For example, NCHW, NCHW16c, etc. are valid for pool,
389  *        while NCHW16w, NCHW16h are not.
390  *        See \a layout for more information of the layout string convention.
391  * \param x The input tensor.
392  * \param kernel_size Vector of two ints: {kernel_height, kernel_width}
393  * \param stride_size Vector of two ints: {stride_height, stride_width}
 * \param padding_size Vector of four ints: {pad_top, pad_left, pad_bottom, pad_right}
395  * \param pool_type The type of pooling operator
396  * \param ceil_mode Whether to use ceil when calculating the output size
397  * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear.
398  *        The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
399  *        where upper case indicates a dimension and
400  *        the corresponding lower case (with factor size) indicates the split dimension.
401  *        For example, NCHW16c can describe a 5-D tensor of
402  *        [batch_size, channel, height, width, channel_block].
403  *        (in which factor size `16` will not be used in pooling but for other operators,
404  *        it can be used to decide the output shape).
405  *        Since pooling does not care about the factor size of dimensions
406  *        other than `H` and `W`, one can pass `NCHWc` as well.
407  * \param  count_include_pad Whether include padding in the calculation when pool_type is 'avg'
408  *
409  *
410  * \return The output tensor in the same layout
411  */
412 inline Tensor pool(const Tensor& x, const Array<PrimExpr>& kernel_size,
413                    const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
414                    PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW",
415                    bool count_include_pad = true) {
416   int height_axis = -1, width_axis = -1;
417   CHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
418   return pool_impl(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, height_axis,
419                    width_axis, count_include_pad);
420 }
421 
422 /*!
423  * \brief Calculate gradient of pooling on height and width dimension of data.
424  *        It decides the height and width dimension according to the layout string,
425  *        in which 'W' and 'H' means width and height respectively.
426  *        Width and height dimension cannot be split.
427  *        For example, NCHW, NCHW16c, etc. are valid for pool,
428  *        while NCHW16w, NCHW16h are not.
429  *        See \a layout for more information of the layout string convention.
430  * \param out_grad The output gradient tensor.
431  * \param x The input tensor.
432  * \param kernel_size Vector of two ints: {kernel_height, kernel_width}
433  * \param stride_size Vector of two ints: {stride_height, stride_width}
 * \param padding_size Vector of four ints: {pad_top, pad_left, pad_bottom, pad_right}
435  * \param pool_type The type of pooling operator
436  * \param ceil_mode Whether to use ceil when calculating the output size
437  * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear.
438  *        The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
439  *        where upper case indicates a dimension and
440  *        the corresponding lower case (with factor size) indicates the split dimension.
441  *        For example, NCHW16c can describe a 5-D tensor of
442  *        [batch_size, channel, height, width, channel_block].
443  *        (in which factor size `16` will not be used in pooling but for other operators,
444  *        it can be used to decide the output shape).
445  *        Since pooling does not care about the factor size of dimensions
446  *        other than `H` and `W`, one can pass `NCHWc` as well.
447  * \param  count_include_pad Whether include padding in the calculation when pool_type is 'avg'
448  *
449  *
450  * \return The output tensor in the same layout
451  */
452 inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array<PrimExpr>& kernel_size,
453                         const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
454                         PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW",
455                         bool count_include_pad = true) {
456   int height_axis = -1, width_axis = -1;
457   CHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
458   return pool_grad_impl(out_grad, x, kernel_size, stride_size, padding_size, pool_type, ceil_mode,
459                         height_axis, width_axis, count_include_pad);
460 }
461 
// First input index covered by adaptive-pooling output position `out_index`,
// when an input dimension of size `idim` is mapped onto an output dimension of
// size `odim`: floor(out_index * idim / odim).
inline PrimExpr start_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
  return indexdiv(out_index * idim, odim);
}
465 
// One-past-the-last input index covered by adaptive-pooling output position
// `out_index`: ceil((out_index + 1) * idim / odim).
inline PrimExpr end_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
  PrimExpr scaled = (out_index + 1) * idim;
  PrimExpr quotient = indexdiv(scaled, odim);
  // Bump the floor quotient by one unless the division is exact.
  return tvm::tir::Select(indexmod(scaled, odim) == 0, quotient, quotient + 1);
}
470 
471 /*!
472  * \brief Perform adaptive pooling on N dimensional data
473  *
474  * \param x The input tensor
 * \param output_size Vector of ints: the target output size along each pooled dimension
476  * \param pool_type The type of pooling operator
477  * \param axes indices of each dimension
478  *
479  * \return The output tensor in same layout order
480  */
inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_size,
                                 PoolType pool_type, const std::vector<int>& axes) {
  const auto n_dim = output_size.size();
  CHECK_EQ(axes.size(), n_dim) << "The number of axes not equal to the in/out dimension";

  // Build the output shape: copy the input shape and overwrite each pooled
  // axis with its requested (int32-cast) output size.
  Array<PrimExpr> out_shape = x->shape;
  Array<PrimExpr> in_size, out_size;
  for (size_t i = 0; i < n_dim; ++i) {
    in_size.push_back(x->shape[axes[i]]);
    out_size.push_back(cast(DataType::Int(32), output_size[i]));
    out_shape.Set(axes[i], out_size[i]);
  }

  // For a given output coordinate, build (a) the input indices to read and
  // (b) one reduce axis per pooled dimension spanning that output element's
  // adaptive window [start_index, end_index). When `reduce_indices` is false,
  // only the reduce axes (window extents) are produced; indices stay as-is.
  // Captures by value so the returned tuple is self-contained.
  auto get_iter_vars = [=](const Array<Var>& output, bool reduce_indices) {
    Array<PrimExpr> indices;
    for (size_t i = 0; i < output.size(); ++i) indices.push_back(output[i]);
    Array<tir::IterVar> reduce_axes;
    for (size_t i = 0; i < n_dim; ++i) {
      auto i_start = start_index(output[axes[i]], out_size[i], in_size[i]);
      auto i_end = end_index(output[axes[i]], out_size[i], in_size[i]);
      auto rv_name = "rv" + std::to_string(i);
      auto rv_axis = tvm::te::reduce_axis(Range(0, i_end - i_start), rv_name);
      reduce_axes.push_back(rv_axis);
      if (reduce_indices) {
        indices.Set(axes[i], i_start + rv_axis);
      }
    }
    return std::make_tuple(indices, reduce_axes);
  };

  if (pool_type == kMaxPool) {
    // Max over each adaptive window.
    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          Array<tir::IterVar> reduce_axes;
          std::tie(indices, reduce_axes) = get_iter_vars(output, true);
          return tvm::max(x(indices), reduce_axes);  // NOLINT(*)
        },
        "tensor", "adaptive_pool_max");
  } else if (pool_type == kAvgPool) {
    // Sum over each adaptive window, then divide by the window volume.
    auto pool_sum = tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          Array<tir::IterVar> reduce_axes;
          std::tie(indices, reduce_axes) = get_iter_vars(output, true);
          return tvm::sum(x(indices), reduce_axes);
        },
        "tensor", "adaptive_pool_sum");

    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          Array<tir::IterVar> reduce_axes;
          std::tie(indices, reduce_axes) = get_iter_vars(output, false);

          // Window volume = product of each reduce axis' extent; windows can
          // differ in size, so this is computed per output element.
          PrimExpr divide_factor = tvm::cast(x->dtype, 1);
          for (size_t i = 0; i < n_dim; ++i) {
            divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent);
          }

          return div(pool_sum(indices), divide_factor);
        },
        "tensor", kElementWise);
  } else {
    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
    return x;
  }
}
552 
553 /*!
554  * \brief Adaptively perform pooling on height and width dimension of data.
555  *        The pooling kernel and stride sizes are automatically chosen for desired output sizes.
556  *        It decides the height and width dimension according to the layout string,
557  *        in which 'W' and 'H' means width and height respectively.
558  *        Width and height dimension cannot be split.
559  *        For example, NCHW, NCHW16c, etc. are valid for pool,
560  *        while NCHW16w, NCHW16h are not.
561  *        See \a layout for more information of the layout string convention.
562  *
563  * \param x The input tensor
564  * \param output_size Vector of two ints: {output_height, output_width}
565  * \param pool_type The type of pooling operator
566  * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear.
567  *        The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
568  *        where upper case indicates a dimension and
569  *        the corresponding lower case (with factor size) indicates the split dimension.
570  *        For example, NCHW16c can describe a 5-D tensor of
571  *        [batch_size, channel, height, width, channel_block].
572  *        (in which factor size `16` will not be used in pooling but for other operators,
573  *        it can be used to decide the output shape).
574  *        Since pooling does not care about the factor size of dimensions
575  *        other than `H` and `W`, one can pass `NCHWc` as well.
576  *
577  * \return The output tensor in same layout order
578  */
579 inline Tensor adaptive_pool(const Tensor& x, const Array<PrimExpr>& output_size, PoolType pool_type,
580                             const std::string& layout = "NCHW") {
581   int height_axis = -1, width_axis = -1;
582   CHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
583   return adaptive_pool_impl(x, output_size, pool_type, {height_axis, width_axis});
584 }
585 
586 /*!
587  * \brief Adaptively perform pooling on three dimensional data.
588  *        See the two dimensional version above for details.
589  * \param x The input tensor
590  * \param output_size Vector of three ints: {output_depth, output_height, output_width}
591  * \param pool_type The type of pooling operator
592  * \param layout The input layout. The default is "NCDHW".
593  */
594 inline Tensor adaptive_pool3d(const Tensor& x, const Array<PrimExpr>& output_size,
595                               PoolType pool_type, const std::string& layout = "NCDHW") {
596   int depth_axis = -1, height_axis = -1, width_axis = -1;
597   CHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
598       << "Unsupported layout " << layout;
599   return adaptive_pool_impl(x, output_size, pool_type, {depth_axis, height_axis, width_axis});
600 }
601 
602 /*!
603  * \brief Perform global pooling on height and width dimension of data.
604  *        It decides the height and width dimension according to the layout string,
605  *        in which 'W' and 'H' means width and height respectively.
606  *        Width and height dimension cannot be split.
607  *        For example, NCHW, NCHW16c, ... are valid for global_pool,
608  *        while NCHW16w, NCHW16h are not.
609  *        See \a layout for more information of the layout string convention.
610  *
611  * \param x The input tensor represent as layout
612  * \param pool_type The type of pooling operator
613  * \param layout The input layout. global-pooling supports any layout as long as 'H' and 'W' appear.
614  *        The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
615  *        where upper case indicates a dimension and
616  *        the corresponding lower case (with factor size) indicates the sub-dimension.
617  *        For example, `NCHW16c` can describe a 5-D tensor of
618  *        [batch_size, channel, height, width, channel_block].
619  *        (in which factor size `16` will not be used in pooling but for other operators,
620  *        it can be used to decide the output shape).
621  *        Since pooling does not care about the factor size of
622  *        dimensions other than `H` and `W`, one can pass `NCHWc` as well.
623  *
624  * \return The output tensor in same layout with height and width dimension size of 1.
625  *         e.g., for NCHW, the output shape will be [batch, channel, 1, 1]
626  */
627 inline Tensor global_pool(const Tensor& x, PoolType pool_type, const std::string& layout = "NCHW") {
628   return adaptive_pool(x, Array<PrimExpr>{1, 1}, pool_type, layout);
629 }
630 
631 /*!
632  * \brief Perform pooling on N-dimension of data.
633  *
634  * \param x The input tensor
635  * \param kernel_size Vector of N ints
636  * \param stride_size Vector of N ints
637  * \param padding_size Vector of N*2 ints [head_pad_d1, head_pad_d2, ...,
638  *        head_pad_dN, tail_pad_d1, tail_pad_d2, ..., tail_pad_dN]
639  * \param pool_type The type of pooling operator
640  * \param ceil_mode Whether to use ceil when calculating the output size
641  * \param axis Vector of indices for the N dimensions
642  * \param count_include_pad Whether include padding in the calculation
643  *
644  * \return The output tensor in same layout order
645  */
pool_impl_nd(const Tensor & x,const Array<PrimExpr> & kernel_size,const Array<PrimExpr> & stride_size,const Array<PrimExpr> & padding_size,PoolType pool_type,bool ceil_mode,const std::vector<int> & axis,bool count_include_pad)646 inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
647                            const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
648                            PoolType pool_type, bool ceil_mode, const std::vector<int>& axis,
649                            bool count_include_pad) {
650   int k_size = kernel_size.size();
651   int x_size = x->shape.size();
652   CHECK_EQ(stride_size.size(), k_size) << "Pooling stride_size must have same elements as kernel";
653   CHECK_EQ(padding_size.size(), k_size * 2) << "Pooling padding_size must has double elements of"
654                                                " kernel";
655   CHECK_EQ(axis.size(), k_size) << "axis must have same elements as kernel";
656 
657   Array<IterVar> daxis;
658   std::vector<PrimExpr> kernel(k_size);
659   std::vector<PrimExpr> stride(k_size);
660   std::vector<PrimExpr> pad_head(k_size);
661   std::vector<PrimExpr> pad_tail(k_size);
662   Array<PrimExpr> pad_before(std::vector<PrimExpr>(x_size, 0));
663   Array<PrimExpr> pad_after(std::vector<PrimExpr>(x_size, 0));
664   Array<PrimExpr> out_shape = x->shape;
665 
666   bool do_pad = false;
667   for (int i = 0; i < k_size; i++) {
668     int ii = axis[i];
669     kernel[i] = cast(DataType::Int(32), kernel_size[i]);
670     stride[i] = cast(DataType::Int(32), stride_size[i]);
671     pad_head[i] = cast(DataType::Int(32), padding_size[i]);
672     pad_tail[i] = cast(DataType::Int(32), padding_size[i + k_size]);
673     const int64_t* padding0 = as_const_int(pad_head[i]);
674     const int64_t* padding1 = as_const_int(pad_tail[i]);
675     do_pad = (do_pad) ? do_pad : ((padding0 && *padding0) || (padding1 && *padding1));
676 
677     if (ceil_mode) {
678       // Additional padding to ensure we do ceil instead of floor when
679       // dividing by stride.
680       pad_tail[i] += stride[i] - 1;
681     }
682 
683     daxis.push_back(tvm::te::reduce_axis(Range(0, kernel[i])));
684 
685     pad_before.Set(ii, pad_head[i]);
686     pad_after.Set(ii, pad_tail[i]);
687 
688     arith::Analyzer analyzer;
689     auto out_dim = analyzer.Simplify(
690         indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1);
691 
692     out_shape.Set(ii, out_dim);
693   }
694 
695   if (pool_type == kMaxPool) {
696     auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
697     return tvm::te::compute(
698         out_shape,
699         [&](const Array<Var>& output) {
700           Array<PrimExpr> indices;
701           for (const Var& var : output) indices.push_back(var);
702 
703           for (int i = 0; i < k_size; i++) {
704             int ii = axis[i];
705             indices.Set(ii, output[ii] * stride[i] + daxis[i]);
706           }
707 
708           return tvm::max(temp(indices), daxis);
709         },
710         "tensor", "pool_max");
711   } else if (pool_type == kAvgPool) {
712     // Pad the inputs
713     auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;
714 
715     // TVM compute for summing the pooling window.
716     auto pool_sum = tvm::te::compute(
717         out_shape,
718         [&](const Array<Var>& output) {
719           Array<PrimExpr> indices;
720           for (const Var& var : output) indices.push_back(var);
721 
722           for (int i = 0; i < k_size; i++) {
723             int ii = axis[i];
724             indices.Set(ii, output[ii] * stride[i] + daxis[i]);
725           }
726           return tvm::sum(temp(indices), daxis);
727         },
728         "tensor", "pool_sum");
729 
730     // TVM compute for dividing the reduced window sum by kernel size.
731     return tvm::te::compute(
732         out_shape,
733         [&](const Array<Var>& output) {
734           Array<PrimExpr> indices;
735           for (const Var& var : output) indices.push_back(var);
736           if (count_include_pad) {
737             auto kernel_size = make_const(DataType::Int(32), 1);
738             for (int i = 0; i < k_size; i++) {
739               kernel_size *= kernel[i];
740             }
741             return div(pool_sum(indices), kernel_size);
742           } else {
743             std::vector<PrimExpr> start(k_size);
744             std::vector<PrimExpr> end(k_size);
745             auto kernel_size = make_const(DataType::Int(32), 1);
746             for (int i = 0; i < k_size; i++) {
747               int ii = axis[i];
748               start[i] = output[ii] * stride[i] - pad_head[i];
749               end[i] = min(start[i] + kernel[i], x->shape[ii]);
750               start[i] = max(start[i], make_const(DataType::Int(32), 0));
751               kernel_size *= (end[i] - start[i]);
752             }
753 
754             PrimExpr divide_factor = max(kernel_size, make_const(DataType::Int(32), 1));
755             return div(pool_sum(indices), divide_factor);
756           }
757         },
758         "tensor", kElementWise);
759   } else {
760     LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
761     return x;
762   }
763 }
764 
765 /*!
766  * \brief Perform pooling on the width dimension of data.
767  *        Width axis is determined by the layout string
768  *        in which 'W' means width.
769  *        Width dimension cannot be split.
770  *        For example, NCW, NCW16c, etc. are valid for pool,
771  *        while NCW16w is not.
772  *        See \a layout for more information of the layout string convention.
773  * \param x The input tensor.
 * \param kernel_size Vector of one int: {kernel_width}
 * \param stride_size Vector of one int: {stride_width}
 * \param padding_size Vector of two ints: {head_pad_width, tail_pad_width}
777  * \param pool_type The type of pooling operator
778  * \param ceil_mode Whether to use ceil when calculating the output size
779  * \param layout The input layout. Pooling supports any layout as long as 'W' appears.
780  *        The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
781  *        where upper case indicates a dimension and
782  *        the corresponding lower case (with factor size) indicates the split dimension.
783  *        For example, NCW16c can describe a 4-D tensor of
784  *        [batch_size, channel, width, channel_block].
785  *        (in which factor size `16` will not be used in pooling but for other operators,
786  *        it can be used to decide the output shape).
787  *        Since pooling does not care about the factor size of dimensions
788  *        other than `W`, one can pass `NCWc` as well.
789  * \param  count_include_pad Whether include padding in the calculation when pool_type is 'avg'
790  *
791  *
792  * \return The output tensor in the same layout
793  */
794 inline Tensor pool1d(const Tensor& x, const Array<PrimExpr>& kernel_size,
795                      const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
796                      PoolType pool_type, bool ceil_mode, const std::string& layout = "NCW",
797                      bool count_include_pad = true) {
798   int width_axis = -1;
799   CHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
800   std::vector<int> axis = {width_axis};
801   return pool_impl_nd(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, axis,
802                       count_include_pad);
803 }
804 
805 /*!
806  * \brief Perform pooling on depth, height and width dimension of data.
807  *        It decides the depth, height and width dimension according to the layout string,
808  *        in which 'D', 'W' and 'H' means depth, width and height respectively.
809  *        Depth, Width and height dimension cannot be split.
810  *        For example, NCDHW, NCDHW16c, etc. are valid for pool,
811  *        while NCDHW16d, NCDHW16w or NCDHW16h are not.
812  *        See \a layout for more information of the layout string convention.
813  * \param x The input tensor.
814  * \param kernel_size Vector of three ints: {kernel_depth, kernel_height, kernel_width}
815  * \param stride_size Vector of three ints: {stride_depth, stride_height, stride_width}
816  * \param padding_size Vector of six ints: {head_pad_depth, head_pad_height, head_pad_width,
817  *        tail_pad_depth, tail_pad_height, tail_pad_width}
818  * \param pool_type The type of pooling operator
819  * \param ceil_mode Whether to use ceil when calculating the output size
820  * \param layout The input layout. Pooling supports any layout as long as 'D', 'H' and 'W' appear.
821  *        The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
822  *        where upper case indicates a dimension and
823  *        the corresponding lower case (with factor size) indicates the split dimension.
824  *        For example, NCDHW16c can describe a 6-D tensor of
825  *        [batch_size, channel, depth, height, width, channel_block].
826  *        (in which factor size `16` will not be used in pooling but for other operators,
827  *        it can be used to decide the output shape).
828  *        Since pooling does not care about the factor size of dimensions
829  *        other than `D`, `H` and `W`, one can pass `NCDHWc` as well.
830  * \param  count_include_pad Whether include padding in the calculation when pool_type is 'avg'
831  *
832  *
833  * \return The output tensor in the same layout
834  */
835 inline Tensor pool3d(const Tensor& x, const Array<PrimExpr>& kernel_size,
836                      const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
837                      PoolType pool_type, bool ceil_mode, const std::string& layout = "NCDHW",
838                      bool count_include_pad = true) {
839   int depth_axis = -1, height_axis = -1, width_axis = -1;
840   CHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
841       << "Unsupported layout " << layout;
842   std::vector<int> axis = {depth_axis, height_axis, width_axis};
843   return pool_impl_nd(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, axis,
844                       count_include_pad);
845 }
846 
847 }  // namespace nn
848 }  // namespace topi
849 }  // namespace tvm
850 #endif  // TVM_TOPI_NN_POOLING_H_
851