1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \brief Pooling op constructions
22 * \file nn/pooling.h
23 */
24 #ifndef TVM_TOPI_NN_POOLING_H_
25 #define TVM_TOPI_NN_POOLING_H_
26
27 #include <tvm/arith/analyzer.h>
28 #include <tvm/topi/detail/pad_utils.h>
29 #include <tvm/topi/nn.h>
30 #include <tvm/topi/reduction.h>
31 #include <tvm/topi/tags.h>
32
33 #include <algorithm>
34 #include <string>
35 #include <vector>
36
37 namespace tvm {
38 namespace topi {
39 namespace nn {
40
41 using namespace tvm::te;
42
43 /*! \brief Pooling type */
/*! \brief Pooling type */
enum PoolType : int {
  kAvgPool = 0,  // average pooling: mean of the (possibly padded) window
  kMaxPool = 1,  // max pooling: maximum element of the window
};
48
49 /*!
50 * \brief Perform pooling on height and width dimension of data.
51 *
52 * \param x The input tensor
53 * \param kernel_size Vector of two ints: {kernel_height, kernel_width}
54 * \param stride_size Vector of two ints: {stride_height, stride_width}
 * \param padding_size Vector of four ints: {pad_top, pad_left, pad_bottom, pad_right}
56 * \param pool_type The type of pooling operator
57 * \param ceil_mode Whether to use ceil when calculating the output size
58 * \param height_axis index of the height dimension
59 * \param width_axis index of the width dimension
60 * \param count_include_pad Whether include padding in the calculation
61 *
62 * \return The output tensor in same layout order
63 */
inline Tensor pool_impl(const Tensor& x, const Array<PrimExpr>& kernel_size,
                        const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
                        PoolType pool_type, bool ceil_mode, const size_t height_axis,
                        const size_t width_axis, bool count_include_pad) {
  CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)";
  CHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
  CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
  CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";

  // Normalize kernel/stride/padding to int32 expressions.
  auto kernel_height = cast(DataType::Int(32), kernel_size[0]);
  auto kernel_width = cast(DataType::Int(32), kernel_size[1]);
  auto stride_height = cast(DataType::Int(32), stride_size[0]);
  auto stride_width = cast(DataType::Int(32), stride_size[1]);

  auto height = x->shape[height_axis];
  auto width = x->shape[width_axis];

  // padding_size layout: {pad_top, pad_left, pad_bottom, pad_right}
  auto pad_top = cast(DataType::Int(32), padding_size[0]);
  auto pad_left = cast(DataType::Int(32), padding_size[1]);
  auto pad_bottom = cast(DataType::Int(32), padding_size[2]);
  auto pad_right = cast(DataType::Int(32), padding_size[3]);

  if (ceil_mode) {
    // Additional padding to ensure we do ceil instead of floor when
    // dividing by stride.
    pad_bottom += stride_height - 1;
    pad_right += stride_width - 1;
  }

  // Zero padding on every axis except the height/width axes.
  Array<PrimExpr> pad_before(std::vector<PrimExpr>(x->shape.size(), 0));
  pad_before.Set(height_axis, pad_top);
  pad_before.Set(width_axis, pad_left);

  Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
  pad_after.Set(height_axis, pad_bottom);
  pad_after.Set(width_axis, pad_right);
  arith::Analyzer analyzer;
  // out_dim = floor((in + pad - kernel) / stride) + 1
  auto out_height =
      analyzer.Simplify(indexdiv(height - kernel_height + pad_top + pad_bottom, stride_height) + 1);
  auto out_width =
      analyzer.Simplify(indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1);

  // Reduction axes ranging over one pooling window.
  auto dheight = tvm::te::reduce_axis(Range(0, kernel_height));
  auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width));

  Array<PrimExpr> out_shape = x->shape;
  out_shape.Set(height_axis, out_height);
  out_shape.Set(width_axis, out_width);

  // Skip the pad stage entirely when all four paddings are statically zero.
  const int64_t* padding_h0 = as_const_int(pad_top);
  const int64_t* padding_w0 = as_const_int(pad_left);
  const int64_t* padding_h1 = as_const_int(pad_bottom);
  const int64_t* padding_w1 = as_const_int(pad_right);
  const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
                      ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));

  if (pool_type == kMaxPool) {
    // Pad with the dtype minimum so padded cells can never win the max-reduction.
    auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          for (const Var& var : output) indices.push_back(var);
          indices.Set(height_axis, output[height_axis] * stride_height + dheight);
          indices.Set(width_axis, output[width_axis] * stride_width + dwidth);
          return tvm::max(temp(indices), {dheight, dwidth});
        },
        "tensor", "pool_max");
  } else if (pool_type == kAvgPool) {
    // Pad the inputs with zeros (padded cells contribute nothing to the sum).
    auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;

    // TVM compute for summing the pooling window.
    auto pool_sum = tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          for (const Var& var : output) indices.push_back(var);
          indices.Set(height_axis, output[height_axis] * stride_height + dheight);
          indices.Set(width_axis, output[width_axis] * stride_width + dwidth);
          return tvm::sum(temp(indices), {dheight, dwidth});
        },
        "tensor", "pool_sum");

    // TVM compute for dividing the reduced window sum by kernel size.
    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          for (const Var& var : output) indices.push_back(var);
          if (count_include_pad) {
            // Fixed divisor: every window nominally has kernel_h * kernel_w elements.
            return div(pool_sum(indices), (kernel_height * kernel_width));
          } else {
            // Divide only by the number of in-bounds elements of this window,
            // clamped to at least 1 to avoid division by zero.
            PrimExpr h_start = output[height_axis] * stride_height - pad_top;
            PrimExpr w_start = output[width_axis] * stride_width - pad_left;

            PrimExpr h_end = min(h_start + kernel_height, height);
            PrimExpr w_end = min(w_start + kernel_width, width);
            h_start = max(h_start, make_const(DataType::Int(32), 0));
            w_start = max(w_start, make_const(DataType::Int(32), 0));
            PrimExpr divide_factor =
                max((h_end - h_start) * (w_end - w_start), make_const(DataType::Int(32), 1));
            return div(pool_sum(indices), divide_factor);
          }
        },
        "tensor", kElementWise);
  } else {
    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
    return x;
  }
}
175
/*!
 * \brief Calculate the gradient of pooling on the height and width dimensions of data.
 *
 * \param out_grad The gradient of the pooled output
 * \param x The input tensor of the forward pooling
 * \param kernel_size Vector of two ints: {kernel_height, kernel_width}
 * \param stride_size Vector of two ints: {stride_height, stride_width}
 * \param padding_size Vector of four ints: {pad_top, pad_left, pad_bottom, pad_right}
 * \param pool_type The type of pooling operator
 * \param ceil_mode Whether to use ceil when calculating the output size
 * \param height_axis index of the height dimension
 * \param width_axis index of the width dimension
 * \param count_include_pad Whether include padding in the calculation (avg pool only)
 *
 * \return The input-gradient tensor with the same shape and layout as \p x
 */
inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
                             const Array<PrimExpr>& kernel_size, const Array<PrimExpr>& stride_size,
                             const Array<PrimExpr>& padding_size, PoolType pool_type,
                             bool ceil_mode, const size_t height_axis, const size_t width_axis,
                             bool count_include_pad) {
  CHECK(out_grad->shape.size() >= 2) << "Pooling grad output must >= 2-D (H, W)";
  CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)";
  CHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
  CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
  CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";

  // Normalize kernel/stride/padding to int32 expressions.
  auto kernel_height = cast(DataType::Int(32), kernel_size[0]);
  auto kernel_width = cast(DataType::Int(32), kernel_size[1]);
  auto stride_height = cast(DataType::Int(32), stride_size[0]);
  auto stride_width = cast(DataType::Int(32), stride_size[1]);

  auto height = x->shape[height_axis];
  auto width = x->shape[width_axis];

  // padding_size layout: {pad_top, pad_left, pad_bottom, pad_right}
  auto pad_top = cast(DataType::Int(32), padding_size[0]);
  auto pad_left = cast(DataType::Int(32), padding_size[1]);
  auto pad_bottom = cast(DataType::Int(32), padding_size[2]);
  auto pad_right = cast(DataType::Int(32), padding_size[3]);

  if (ceil_mode) {
    // Additional padding to ensure we do ceil instead of floor when
    // dividing by stride.
    pad_bottom += stride_height - 1;
    pad_right += stride_width - 1;
  }

  // Zero padding on every axis except the height/width axes.
  Array<PrimExpr> pad_before(std::vector<PrimExpr>(x->shape.size(), 0));
  pad_before.Set(height_axis, pad_top);
  pad_before.Set(width_axis, pad_left);

  Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
  pad_after.Set(height_axis, pad_bottom);
  pad_after.Set(width_axis, pad_right);
  arith::Analyzer analyzer;
  // Forward output extents; must match what pool_impl produced.
  auto out_height =
      analyzer.Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
  auto out_width =
      analyzer.Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1);

  // Reduction axes ranging over one pooling window.
  auto dheight = tvm::te::reduce_axis(Range(0, kernel_height));
  auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width));

  Array<PrimExpr> out_shape = x->shape;
  out_shape.Set(height_axis, out_height);
  out_shape.Set(width_axis, out_width);

  // Skip the pad stage entirely when all four paddings are statically zero.
  const int64_t* padding_h0 = as_const_int(pad_top);
  const int64_t* padding_w0 = as_const_int(pad_left);
  const int64_t* padding_h1 = as_const_int(pad_bottom);
  const int64_t* padding_w1 = as_const_int(pad_right);
  const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
                      ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));

  if (pool_type == kMaxPool) {
    // Shape of the padded input; window indices are raveled against this shape
    // so each element of the padded input gets a unique flat id.
    Array<PrimExpr> ravel_shape{x->shape.begin(), x->shape.end()};
    ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
    ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);

    // Number of output windows that can overlap one input element:
    // ceil(kernel / stride) along each spatial axis.
    auto windowh =
        tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height));
    auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width));

    auto argmax = MakeArgmaxReducer();
    auto pad_x = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;

    // For each output window, record the flat (raveled) index of its maximum element.
    auto mp_argmax = tvm::te::compute(
        out_shape,
        [&](const Array<Var>& inds) {
          Array<PrimExpr> window_inds{inds.begin(), inds.end()};
          window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
          window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
          auto idx = detail::RavelIndex(window_inds, ravel_shape);
          return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr);
        },
        "maxpool_grad_argmax", kCommReduceIdx);

    auto mp_inds = mp_argmax[0];

    // Scatter expressed as a gather: each input element sums out_grad over all
    // candidate windows whose argmax equals this element's flat index.
    return tvm::te::compute(
        x->shape,
        [&](const Array<Var>& inds) {
          // Flat index of this input element in padded coordinates.
          Array<PrimExpr> pad_inds{inds.begin(), inds.end()};
          pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);
          pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left);
          auto idx = detail::RavelIndex(pad_inds, ravel_shape);

          // Candidate output windows covering this element (may be out of bounds).
          Array<PrimExpr> out_idx{inds.begin(), inds.end()};
          out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
          out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);

          // First valid window index along each axis whose extent reaches this element.
          PrimExpr out_idx_lower_h =
              tir::Select(pad_inds[height_axis] < kernel_height, make_const(DataType::Int(32), 0),
                          (pad_inds[height_axis] - kernel_height) / stride_height + 1);
          PrimExpr out_idx_lower_w =
              tir::Select(pad_inds[width_axis] < kernel_width, make_const(DataType::Int(32), 0),
                          (pad_inds[width_axis] - kernel_width) / stride_width + 1);

          return tvm::sum(
              tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
                                                  out_idx[width_axis] >= out_idx_lower_w),
                                         mp_inds(out_idx) == idx),
                                out_grad(out_idx), make_const(x->dtype, 0)),
              {windowh, windoww});
        },
        "T_pool_grad", "pool_grad_max");
  } else if (pool_type == kAvgPool) {
    // Number of output windows that can overlap one input element.
    auto windowh =
        tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height));
    auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width));
    return tvm::te::compute(
        x->shape,
        [&](const Array<Var>& inds) {
          PrimExpr pad_h_idx = inds[height_axis] + pad_top;
          PrimExpr pad_w_idx = inds[width_axis] + pad_left;

          // output indices whose pooling windows cover current input element (can be out-of-bound)
          Array<PrimExpr> out_idx{inds.begin(), inds.end()};
          out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh));
          out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));

          PrimExpr out_idx_lower_h =
              tir::Select(pad_h_idx < kernel_height, make_const(DataType::Int(32), 0),
                          (pad_h_idx - kernel_height) / stride_height + 1);
          PrimExpr out_idx_lower_w =
              tir::Select(pad_w_idx < kernel_width, make_const(DataType::Int(32), 0),
                          (pad_w_idx - kernel_width) / stride_width + 1);

          PrimExpr divide_factor;  // number of pooled elements
          if (count_include_pad) {
            divide_factor = kernel_height * kernel_width;
          } else {
            // Mirror the forward pass: divide by the in-bounds window size only.
            PrimExpr h_start = out_idx[height_axis] * stride_height - pad_top;
            PrimExpr w_start = out_idx[width_axis] * stride_width - pad_left;

            PrimExpr h_end = min(h_start + kernel_height, height);
            PrimExpr w_end = min(w_start + kernel_width, width);
            h_start = max(h_start, make_const(DataType::Int(32), 0));
            w_start = max(w_start, make_const(DataType::Int(32), 0));
            divide_factor =
                max((h_end - h_start) * (w_end - w_start), make_const(DataType::Int(32), 1));
          }
          return tvm::sum(
              tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
                                                  out_idx[height_axis] < out_height),
                                         tir::And(out_idx[width_axis] >= out_idx_lower_w,
                                                  out_idx[width_axis] < out_width)),
                                out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)),
              {windowh, windoww});
        },
        "T_pool_grad", "pool_grad_avg");
  } else {
    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
    return Tensor();
  }
}
336
/*!
 * \brief Locate the D, H and W axes in a layout string.
 *
 * Only letters count as dimensions; digits (split factors) are skipped.
 * Fails when 'D', 'H' or 'W' is missing or duplicated, or when a split
 * spatial sub-dimension ('d', 'h', 'w', e.g. NCHW16w) is present.
 *
 * \param layout The layout string, e.g. "NCDHW" or "NCDHW16c"
 * \param depth_axis Output: index of the depth dimension
 * \param height_axis Output: index of the height dimension
 * \param width_axis Output: index of the width dimension
 * \return true iff all three axes were found exactly once
 */
inline bool find_depth_height_width(const std::string& layout, int* depth_axis, int* height_axis,
                                    int* width_axis) {
  *depth_axis = -1;
  *height_axis = -1;
  *width_axis = -1;
  int dim_idx = 0;
  for (char c : layout) {
    const bool is_letter = (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
    if (!is_letter) continue;  // digits are split factors, not dimensions
    switch (c) {
      case 'D':
        if (*depth_axis != -1) return false;  // duplicate depth axis
        *depth_axis = dim_idx;
        break;
      case 'H':
        if (*height_axis != -1) return false;  // duplicate height axis
        *height_axis = dim_idx;
        break;
      case 'W':
        if (*width_axis != -1) return false;  // duplicate width axis
        *width_axis = dim_idx;
        break;
      case 'd':
      case 'h':
      case 'w':
        // do not support split on depth, height or width, e.g., NCHW16w
        return false;
      default:
        break;  // any other dimension letter is fine
    }
    ++dim_idx;
  }
  return *depth_axis != -1 && *height_axis != -1 && *width_axis != -1;
}
364
find_height_width(const std::string & layout,int * height_axis,int * width_axis)365 inline bool find_height_width(const std::string& layout, int* height_axis, int* width_axis) {
366 int dummy;
367 CHECK_EQ(find_depth_height_width(layout, &dummy, height_axis, width_axis), false);
368 if (*height_axis != -1 && *width_axis != -1) {
369 return true;
370 }
371 return false;
372 }
373
find_width(const std::string & layout,int * width_axis)374 inline bool find_width(const std::string& layout, int* width_axis) {
375 int dummy;
376 CHECK_EQ(find_depth_height_width(layout, &dummy, &dummy, width_axis), false);
377 if (*width_axis != -1) {
378 return true;
379 }
380 return false;
381 }
382
383 /*!
384 * \brief Perform pooling on height and width dimension of data.
385 * It decides the height and width dimension according to the layout string,
386 * in which 'W' and 'H' means width and height respectively.
387 * Width and height dimension cannot be split.
388 * For example, NCHW, NCHW16c, etc. are valid for pool,
389 * while NCHW16w, NCHW16h are not.
390 * See \a layout for more information of the layout string convention.
391 * \param x The input tensor.
392 * \param kernel_size Vector of two ints: {kernel_height, kernel_width}
393 * \param stride_size Vector of two ints: {stride_height, stride_width}
 * \param padding_size Vector of four ints: {pad_top, pad_left, pad_bottom, pad_right}
395 * \param pool_type The type of pooling operator
396 * \param ceil_mode Whether to use ceil when calculating the output size
397 * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear.
398 * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
399 * where upper case indicates a dimension and
400 * the corresponding lower case (with factor size) indicates the split dimension.
401 * For example, NCHW16c can describe a 5-D tensor of
402 * [batch_size, channel, height, width, channel_block].
403 * (in which factor size `16` will not be used in pooling but for other operators,
404 * it can be used to decide the output shape).
405 * Since pooling does not care about the factor size of dimensions
406 * other than `H` and `W`, one can pass `NCHWc` as well.
407 * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
408 *
409 *
410 * \return The output tensor in the same layout
411 */
412 inline Tensor pool(const Tensor& x, const Array<PrimExpr>& kernel_size,
413 const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
414 PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW",
415 bool count_include_pad = true) {
416 int height_axis = -1, width_axis = -1;
417 CHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
418 return pool_impl(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, height_axis,
419 width_axis, count_include_pad);
420 }
421
422 /*!
423 * \brief Calculate gradient of pooling on height and width dimension of data.
424 * It decides the height and width dimension according to the layout string,
425 * in which 'W' and 'H' means width and height respectively.
426 * Width and height dimension cannot be split.
427 * For example, NCHW, NCHW16c, etc. are valid for pool,
428 * while NCHW16w, NCHW16h are not.
429 * See \a layout for more information of the layout string convention.
430 * \param out_grad The output gradient tensor.
431 * \param x The input tensor.
432 * \param kernel_size Vector of two ints: {kernel_height, kernel_width}
433 * \param stride_size Vector of two ints: {stride_height, stride_width}
 * \param padding_size Vector of four ints: {pad_top, pad_left, pad_bottom, pad_right}
435 * \param pool_type The type of pooling operator
436 * \param ceil_mode Whether to use ceil when calculating the output size
437 * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear.
438 * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
439 * where upper case indicates a dimension and
440 * the corresponding lower case (with factor size) indicates the split dimension.
441 * For example, NCHW16c can describe a 5-D tensor of
442 * [batch_size, channel, height, width, channel_block].
443 * (in which factor size `16` will not be used in pooling but for other operators,
444 * it can be used to decide the output shape).
445 * Since pooling does not care about the factor size of dimensions
446 * other than `H` and `W`, one can pass `NCHWc` as well.
447 * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
448 *
449 *
450 * \return The output tensor in the same layout
451 */
452 inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array<PrimExpr>& kernel_size,
453 const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
454 PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW",
455 bool count_include_pad = true) {
456 int height_axis = -1, width_axis = -1;
457 CHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
458 return pool_grad_impl(out_grad, x, kernel_size, stride_size, padding_size, pool_type, ceil_mode,
459 height_axis, width_axis, count_include_pad);
460 }
461
/*!
 * \brief First input index covered by adaptive-pooling output cell \p out_index,
 *        i.e. floor(out_index * idim / odim).
 * \param out_index Index along the output dimension.
 * \param odim Extent of the output dimension.
 * \param idim Extent of the corresponding input dimension.
 */
inline PrimExpr start_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
  return indexdiv(out_index * idim, odim);
}
465
/*!
 * \brief One-past-last input index covered by adaptive-pooling output cell \p out_index,
 *        i.e. ceil((out_index + 1) * idim / odim), expressed via a Select on the remainder.
 * \param out_index Index along the output dimension.
 * \param odim Extent of the output dimension.
 * \param idim Extent of the corresponding input dimension.
 */
inline PrimExpr end_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
  PrimExpr tmp = indexdiv((out_index + 1) * idim, odim);
  // Exact division -> tmp is already the ceiling; otherwise round up by one.
  return tvm::tir::Select(indexmod((out_index + 1) * idim, odim) == 0, tmp, tmp + 1);
}
470
471 /*!
472 * \brief Perform adaptive pooling on N dimensional data
473 *
474 * \param x The input tensor
475 * \param output_size int vector of size in each dimension
476 * \param pool_type The type of pooling operator
477 * \param axes indices of each dimension
478 *
479 * \return The output tensor in same layout order
480 */
inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_size,
                                 PoolType pool_type, const std::vector<int>& axes) {
  const auto n_dim = output_size.size();
  CHECK_EQ(axes.size(), n_dim) << "The number of axes not equal to the in/out dimension";

  // Replace each pooled axis extent with the requested output size.
  Array<PrimExpr> out_shape = x->shape;
  Array<PrimExpr> in_size, out_size;
  for (size_t i = 0; i < n_dim; ++i) {
    in_size.push_back(x->shape[axes[i]]);
    out_size.push_back(cast(DataType::Int(32), output_size[i]));
    out_shape.Set(axes[i], out_size[i]);
  }

  // Build the read indices and per-axis reduction ranges for one output cell.
  // Each output cell i covers input range [start_index, end_index) along its axis,
  // so kernel extents vary per cell. When reduce_indices is false only the
  // reduction axes are returned (used by the divide stage to read the extents).
  auto get_iter_vars = [=](const Array<Var>& output, bool reduce_indices) {
    Array<PrimExpr> indices;
    for (size_t i = 0; i < output.size(); ++i) indices.push_back(output[i]);
    Array<tir::IterVar> reduce_axes;
    for (size_t i = 0; i < n_dim; ++i) {
      auto i_start = start_index(output[axes[i]], out_size[i], in_size[i]);
      auto i_end = end_index(output[axes[i]], out_size[i], in_size[i]);
      auto rv_name = "rv" + std::to_string(i);
      auto rv_axis = tvm::te::reduce_axis(Range(0, i_end - i_start), rv_name);
      reduce_axes.push_back(rv_axis);
      if (reduce_indices) {
        indices.Set(axes[i], i_start + rv_axis);
      }
    }
    return std::make_tuple(indices, reduce_axes);
  };

  if (pool_type == kMaxPool) {
    // Max over each adaptive window.
    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          Array<tir::IterVar> reduce_axes;
          std::tie(indices, reduce_axes) = get_iter_vars(output, true);
          return tvm::max(x(indices), reduce_axes);  // NOLINT(*)
        },
        "tensor", "adaptive_pool_max");
  } else if (pool_type == kAvgPool) {
    // Sum stage: accumulate each adaptive window.
    auto pool_sum = tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          Array<tir::IterVar> reduce_axes;
          std::tie(indices, reduce_axes) = get_iter_vars(output, true);
          return tvm::sum(x(indices), reduce_axes);
        },
        "tensor", "adaptive_pool_sum");

    // Divide stage: divide each window sum by the product of its window extents.
    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          Array<tir::IterVar> reduce_axes;
          std::tie(indices, reduce_axes) = get_iter_vars(output, false);

          PrimExpr divide_factor = tvm::cast(x->dtype, 1);
          for (size_t i = 0; i < n_dim; ++i) {
            divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent);
          }

          return div(pool_sum(indices), divide_factor);
        },
        "tensor", kElementWise);
  } else {
    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
    return x;
  }
}
552
553 /*!
554 * \brief Adaptively perform pooling on height and width dimension of data.
555 * The pooling kernel and stride sizes are automatically chosen for desired output sizes.
556 * It decides the height and width dimension according to the layout string,
557 * in which 'W' and 'H' means width and height respectively.
558 * Width and height dimension cannot be split.
559 * For example, NCHW, NCHW16c, etc. are valid for pool,
560 * while NCHW16w, NCHW16h are not.
561 * See \a layout for more information of the layout string convention.
562 *
563 * \param x The input tensor
564 * \param output_size Vector of two ints: {output_height, output_width}
565 * \param pool_type The type of pooling operator
566 * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear.
567 * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
568 * where upper case indicates a dimension and
569 * the corresponding lower case (with factor size) indicates the split dimension.
570 * For example, NCHW16c can describe a 5-D tensor of
571 * [batch_size, channel, height, width, channel_block].
572 * (in which factor size `16` will not be used in pooling but for other operators,
573 * it can be used to decide the output shape).
574 * Since pooling does not care about the factor size of dimensions
575 * other than `H` and `W`, one can pass `NCHWc` as well.
576 *
577 * \return The output tensor in same layout order
578 */
579 inline Tensor adaptive_pool(const Tensor& x, const Array<PrimExpr>& output_size, PoolType pool_type,
580 const std::string& layout = "NCHW") {
581 int height_axis = -1, width_axis = -1;
582 CHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
583 return adaptive_pool_impl(x, output_size, pool_type, {height_axis, width_axis});
584 }
585
586 /*!
587 * \brief Adaptively perform pooling on three dimensional data.
588 * See the two dimensional version above for details.
589 * \param x The input tensor
590 * \param output_size Vector of three ints: {output_depth, output_height, output_width}
591 * \param pool_type The type of pooling operator
592 * \param layout The input layout. The default is "NCDHW".
593 */
594 inline Tensor adaptive_pool3d(const Tensor& x, const Array<PrimExpr>& output_size,
595 PoolType pool_type, const std::string& layout = "NCDHW") {
596 int depth_axis = -1, height_axis = -1, width_axis = -1;
597 CHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
598 << "Unsupported layout " << layout;
599 return adaptive_pool_impl(x, output_size, pool_type, {depth_axis, height_axis, width_axis});
600 }
601
602 /*!
603 * \brief Perform global pooling on height and width dimension of data.
604 * It decides the height and width dimension according to the layout string,
605 * in which 'W' and 'H' means width and height respectively.
606 * Width and height dimension cannot be split.
607 * For example, NCHW, NCHW16c, ... are valid for global_pool,
608 * while NCHW16w, NCHW16h are not.
609 * See \a layout for more information of the layout string convention.
610 *
611 * \param x The input tensor represent as layout
612 * \param pool_type The type of pooling operator
613 * \param layout The input layout. global-pooling supports any layout as long as 'H' and 'W' appear.
614 * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
615 * where upper case indicates a dimension and
616 * the corresponding lower case (with factor size) indicates the sub-dimension.
617 * For example, `NCHW16c` can describe a 5-D tensor of
618 * [batch_size, channel, height, width, channel_block].
619 * (in which factor size `16` will not be used in pooling but for other operators,
620 * it can be used to decide the output shape).
621 * Since pooling does not care about the factor size of
622 * dimensions other than `H` and `W`, one can pass `NCHWc` as well.
623 *
624 * \return The output tensor in same layout with height and width dimension size of 1.
625 * e.g., for NCHW, the output shape will be [batch, channel, 1, 1]
626 */
627 inline Tensor global_pool(const Tensor& x, PoolType pool_type, const std::string& layout = "NCHW") {
628 return adaptive_pool(x, Array<PrimExpr>{1, 1}, pool_type, layout);
629 }
630
631 /*!
632 * \brief Perform pooling on N-dimension of data.
633 *
634 * \param x The input tensor
635 * \param kernel_size Vector of N ints
636 * \param stride_size Vector of N ints
637 * \param padding_size Vector of N*2 ints [head_pad_d1, head_pad_d2, ...,
638 * head_pad_dN, tail_pad_d1, tail_pad_d2, ..., tail_pad_dN]
639 * \param pool_type The type of pooling operator
640 * \param ceil_mode Whether to use ceil when calculating the output size
641 * \param axis Vector of indices for the N dimensions
642 * \param count_include_pad Whether include padding in the calculation
643 *
644 * \return The output tensor in same layout order
645 */
inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
                           const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
                           PoolType pool_type, bool ceil_mode, const std::vector<int>& axis,
                           bool count_include_pad) {
  int k_size = static_cast<int>(kernel_size.size());
  int x_size = static_cast<int>(x->shape.size());
  CHECK_EQ(stride_size.size(), k_size) << "Pooling stride_size must have same elements as kernel";
  CHECK_EQ(padding_size.size(), k_size * 2) << "Pooling padding_size must have double elements of"
                                               " kernel";
  CHECK_EQ(axis.size(), k_size) << "axis must have same elements as kernel";

  // Reduction iteration variables, one per pooled axis.
  Array<IterVar> daxis;
  std::vector<PrimExpr> kernel(k_size);
  std::vector<PrimExpr> stride(k_size);
  std::vector<PrimExpr> pad_head(k_size);
  std::vector<PrimExpr> pad_tail(k_size);
  // Full-rank padding vectors; non-pooled axes keep zero padding.
  Array<PrimExpr> pad_before(std::vector<PrimExpr>(x_size, 0));
  Array<PrimExpr> pad_after(std::vector<PrimExpr>(x_size, 0));
  Array<PrimExpr> out_shape = x->shape;

  // The analyzer is loop-invariant: construct it once instead of per axis.
  arith::Analyzer analyzer;

  bool do_pad = false;
  for (int i = 0; i < k_size; i++) {
    int ii = axis[i];
    kernel[i] = cast(DataType::Int(32), kernel_size[i]);
    stride[i] = cast(DataType::Int(32), stride_size[i]);
    pad_head[i] = cast(DataType::Int(32), padding_size[i]);
    pad_tail[i] = cast(DataType::Int(32), padding_size[i + k_size]);
    // Pad only when some constant padding value is non-zero.
    // NOTE(review): symbolic (non-constant) padding leaves do_pad false —
    // confirm callers always pass constant padding here.
    const int64_t* padding0 = as_const_int(pad_head[i]);
    const int64_t* padding1 = as_const_int(pad_tail[i]);
    do_pad = do_pad || (padding0 && *padding0) || (padding1 && *padding1);

    if (ceil_mode) {
      // Additional padding to ensure we do ceil instead of floor when
      // dividing by stride.
      pad_tail[i] += stride[i] - 1;
    }

    daxis.push_back(tvm::te::reduce_axis(Range(0, kernel[i])));

    pad_before.Set(ii, pad_head[i]);
    pad_after.Set(ii, pad_tail[i]);

    // out_dim = floor((in + pad_head + pad_tail - kernel) / stride) + 1
    auto out_dim = analyzer.Simplify(
        indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1);

    out_shape.Set(ii, out_dim);
  }

  if (pool_type == kMaxPool) {
    // Pad with the dtype's minimum value so padded elements never win the max.
    auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          for (const Var& var : output) indices.push_back(var);

          // Map each pooled output coordinate to its input window position.
          for (int i = 0; i < k_size; i++) {
            int ii = axis[i];
            indices.Set(ii, output[ii] * stride[i] + daxis[i]);
          }

          return tvm::max(temp(indices), daxis);
        },
        "tensor", "pool_max");
  } else if (pool_type == kAvgPool) {
    // Pad the inputs with zeros, which are neutral for the window sum.
    auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;

    // TVM compute for summing the pooling window.
    auto pool_sum = tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          for (const Var& var : output) indices.push_back(var);

          for (int i = 0; i < k_size; i++) {
            int ii = axis[i];
            indices.Set(ii, output[ii] * stride[i] + daxis[i]);
          }
          return tvm::sum(temp(indices), daxis);
        },
        "tensor", "pool_sum");

    // TVM compute for dividing the reduced window sum by kernel size.
    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          for (const Var& var : output) indices.push_back(var);
          if (count_include_pad) {
            // Divide by the full window size, counting padded elements.
            auto kernel_size = make_const(DataType::Int(32), 1);
            for (int i = 0; i < k_size; i++) {
              kernel_size *= kernel[i];
            }
            return div(pool_sum(indices), kernel_size);
          } else {
            // Divide by the number of in-bounds elements: clip each window
            // [start, end) against the un-padded extent of its pooled axis.
            std::vector<PrimExpr> start(k_size);
            std::vector<PrimExpr> end(k_size);
            auto kernel_size = make_const(DataType::Int(32), 1);
            for (int i = 0; i < k_size; i++) {
              int ii = axis[i];
              start[i] = output[ii] * stride[i] - pad_head[i];
              end[i] = min(start[i] + kernel[i], x->shape[ii]);
              start[i] = max(start[i], make_const(DataType::Int(32), 0));
              kernel_size *= (end[i] - start[i]);
            }

            // Guard against an all-padding window (avoid division by zero).
            PrimExpr divide_factor = max(kernel_size, make_const(DataType::Int(32), 1));
            return div(pool_sum(indices), divide_factor);
          }
        },
        "tensor", kElementWise);
  } else {
    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
    return x;
  }
}
764
765 /*!
766 * \brief Perform pooling on the width dimension of data.
767 * Width axis is determined by the layout string
768 * in which 'W' means width.
769 * Width dimension cannot be split.
770 * For example, NCW, NCW16c, etc. are valid for pool,
771 * while NCW16w is not.
772 * See \a layout for more information of the layout string convention.
773 * \param x The input tensor.
 * \param kernel_size Vector of one int: {kernel_width}
 * \param stride_size Vector of one int: {stride_width}
 * \param padding_size Vector of two ints: {head_pad_width, tail_pad_width}
777 * \param pool_type The type of pooling operator
778 * \param ceil_mode Whether to use ceil when calculating the output size
779 * \param layout The input layout. Pooling supports any layout as long as 'W' appears.
780 * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
781 * where upper case indicates a dimension and
782 * the corresponding lower case (with factor size) indicates the split dimension.
783 * For example, NCW16c can describe a 4-D tensor of
784 * [batch_size, channel, width, channel_block].
785 * (in which factor size `16` will not be used in pooling but for other operators,
786 * it can be used to decide the output shape).
787 * Since pooling does not care about the factor size of dimensions
788 * other than `W`, one can pass `NCWc` as well.
789 * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
790 *
791 *
792 * \return The output tensor in the same layout
793 */
794 inline Tensor pool1d(const Tensor& x, const Array<PrimExpr>& kernel_size,
795 const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
796 PoolType pool_type, bool ceil_mode, const std::string& layout = "NCW",
797 bool count_include_pad = true) {
798 int width_axis = -1;
799 CHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
800 std::vector<int> axis = {width_axis};
801 return pool_impl_nd(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, axis,
802 count_include_pad);
803 }
804
805 /*!
806 * \brief Perform pooling on depth, height and width dimension of data.
807 * It decides the depth, height and width dimension according to the layout string,
808 * in which 'D', 'W' and 'H' means depth, width and height respectively.
809 * Depth, Width and height dimension cannot be split.
810 * For example, NCDHW, NCDHW16c, etc. are valid for pool,
811 * while NCDHW16d, NCDHW16w or NCDHW16h are not.
812 * See \a layout for more information of the layout string convention.
813 * \param x The input tensor.
814 * \param kernel_size Vector of three ints: {kernel_depth, kernel_height, kernel_width}
815 * \param stride_size Vector of three ints: {stride_depth, stride_height, stride_width}
816 * \param padding_size Vector of six ints: {head_pad_depth, head_pad_height, head_pad_width,
817 * tail_pad_depth, tail_pad_height, tail_pad_width}
818 * \param pool_type The type of pooling operator
819 * \param ceil_mode Whether to use ceil when calculating the output size
820 * \param layout The input layout. Pooling supports any layout as long as 'D', 'H' and 'W' appear.
821 * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
822 * where upper case indicates a dimension and
823 * the corresponding lower case (with factor size) indicates the split dimension.
824 * For example, NCDHW16c can describe a 6-D tensor of
825 * [batch_size, channel, depth, height, width, channel_block].
826 * (in which factor size `16` will not be used in pooling but for other operators,
827 * it can be used to decide the output shape).
828 * Since pooling does not care about the factor size of dimensions
829 * other than `D`, `H` and `W`, one can pass `NCDHWc` as well.
830 * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
831 *
832 *
833 * \return The output tensor in the same layout
834 */
835 inline Tensor pool3d(const Tensor& x, const Array<PrimExpr>& kernel_size,
836 const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
837 PoolType pool_type, bool ceil_mode, const std::string& layout = "NCDHW",
838 bool count_include_pad = true) {
839 int depth_axis = -1, height_axis = -1, width_axis = -1;
840 CHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
841 << "Unsupported layout " << layout;
842 std::vector<int> axis = {depth_axis, height_axis, width_axis};
843 return pool_impl_nd(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, axis,
844 count_include_pad);
845 }
846
847 } // namespace nn
848 } // namespace topi
849 } // namespace tvm
850 #endif // TVM_TOPI_NN_POOLING_H_
851