1//
2//  MetalPooling.metal
3//  MNN
4//
5//  Created by MNN on 2018/08/24.
6//  Copyright © 2018, Alibaba Group Holding Limited
7//
8
9#include <metal_stdlib>
10#include "MetalDefine.metal"
11
12using namespace metal;
13
// Shape, window, stride and padding parameters consumed by the pooling
// kernels in this file. Both kernels take it by constant reference at
// buffer(2).
// NOTE(review): presumably mirrored by a host-side struct that fills the
// buffer -- confirm before reordering or retyping any field.
struct pooling_sizes {
    int input_width;    // input row length; also the row stride used to index `in`
    int input_height;   // number of input rows per slice
    int output_width;   // output row length; upper bound for gid.x
    int output_height;  // number of output rows per slice; upper bound for gid.y
    int slice;          // number of slices; upper bound for gid.z
    int kernel_width;   // pooling window extent along x
    int kernel_height;  // pooling window extent along y
    int stride_width;   // multiplied by gid.x to place the window along x
    int stride_height;  // multiplied by gid.y to place the window along y
    int pad_width;      // subtracted from the window origin along x
    int pad_height;     // subtracted from the window origin along y
};
27
// Max pooling. Each thread writes one output element: position
// (gid.x, gid.y) of slice gid.z receives the maximum over its pooling window.
//   in  - input; element (x, y) of slice z lives at
//         z * input_width * input_height + y * input_width + x
//   out - output; addressed the same way using the output dimensions
//   s   - sizes, window, stride and padding (see pooling_sizes)
//   gid - thread position in the grid
// Window taps that land outside the input are clamped to the nearest valid
// row/column instead of being skipped, so every read stays inside the slice.
kernel void pooling_max(const device ftype4 *in     [[buffer(0)]],
                        device ftype4 *out          [[buffer(1)]],
                        constant pooling_sizes& s   [[buffer(2)]],
                        uint3 gid                   [[thread_position_in_grid]]) {
    // Threads that fall outside the output volume have nothing to write.
    if (any(gid >= uint3(s.output_width, s.output_height, s.slice))) {
        return;
    }

    // Largest valid input coordinates, used as clamp bounds.
    const int last_col = s.input_width  - 1;
    const int last_row = s.input_height - 1;

    // Window bounds in input coordinates: [win_left, win_right) x [win_top, win_bottom).
    // The origin may be negative when the window starts inside the padding.
    const int win_left   = gid.x * s.stride_width - s.pad_width;
    const int win_top    = gid.y * s.stride_height - s.pad_height;
    const int win_right  = win_left + s.kernel_width;
    const int win_bottom = win_top + s.kernel_height;

    // First element of the slice this thread reads from.
    const device ftype4 *plane = in + (int)gid.z * s.input_width * s.input_height;

    // Seed the running maximum with the (clamped) window origin.
    ftype4 best = ftype4(plane[clamp(win_top, 0, last_row) * s.input_width + clamp(win_left, 0, last_col)]);
    for (int row = win_top; row < win_bottom; ++row) {
        const int row_base = clamp(row, 0, last_row) * s.input_width;
        for (int col = win_left; col < win_right; ++col) {
            best = max(best, plane[row_base + clamp(col, 0, last_col)]);
        }
    }

    const int out_index = (int)gid.z * s.output_width * s.output_height + (int)gid.y * s.output_width + (int)gid.x;
    out[out_index] = best;
}
51
// Average pooling. Each thread writes one output element: position
// (gid.x, gid.y) of slice gid.z receives the mean of the input elements
// covered by its pooling window.
//   in  - input; element (x, y) of slice z lives at
//         z * input_width * input_height + y * input_width + x
//   out - output; addressed the same way using the output dimensions
//   s   - sizes, window, stride and padding (see pooling_sizes)
//   gid - thread position in the grid
// The window is cropped to the input before summing and the divisor is the
// size of the cropped window, so padded positions contribute neither to the
// sum nor to the count. Accumulation is done in float4.
kernel void pooling_avg(const device ftype4 *in     [[buffer(0)]],
                        device ftype4 *out          [[buffer(1)]],
                        constant pooling_sizes& s   [[buffer(2)]],
                        uint3 gid                   [[thread_position_in_grid]]) {
    // Threads that fall outside the output volume have nothing to write.
    if (any(gid >= uint3(s.output_width, s.output_height, s.slice))) {
        return;
    }

    // Uncropped window origin; negative when it starts inside the padding.
    const int origin_x = gid.x * s.stride_width - s.pad_width;
    const int origin_y = gid.y * s.stride_height - s.pad_height;

    // Crop the window to the input: begin at 0 at the earliest, and extend
    // no further than the input's width/height.
    const int begin_x = max(origin_x, 0);
    const int begin_y = max(origin_y, 0);
    const int span_x  = min(s.kernel_width, s.input_width - origin_x);
    const int span_y  = min(s.kernel_height, s.input_height - origin_y);
    const int end_x   = origin_x + span_x;
    const int end_y   = origin_y + span_y;

    // First element of the slice this thread reads from.
    const device ftype4 *plane = in + (int)gid.z * s.input_width * s.input_height;

    float4 sum = float4(0);
    for (int row = begin_y; row < end_y; ++row) {
        const device ftype4 *line = plane + row * s.input_width;
        for (int col = begin_x; col < end_x; ++col) {
            sum += float4(line[col]);
        }
    }

    // Only divide by a positive sample count; otherwise leave the sum unscaled.
    const int samples = (end_y - begin_y) * (end_x - begin_x);
    const float4 scale = float4(samples > 0 ? 1.f / samples : 1.f);

    const int out_index = (int)gid.z * s.output_width * s.output_height + (int)gid.y * s.output_width + (int)gid.x;
    out[out_index] = ftype4(sum * scale);
}
76