//
//  MetalPooling.metal
//  MNN
//
//  Created by MNN on 2018/08/24.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <metal_stdlib>
#include "MetalDefine.metal"

using namespace metal;

// Geometry parameters shared by the pooling kernels, bound at [[buffer(2)]].
// The byte layout must match whatever the host writes into that buffer, so do
// not reorder or retype these fields without updating the host side too.
//
// Input and output are addressed as planes of ftype4 elements:
//   element(z, y, x) = base[z * width * height + y * width + x]
struct pooling_sizes {
    int input_width;    // width of one input plane
    int input_height;   // height of one input plane
    int output_width;   // width of one output plane (grid x extent)
    int output_height;  // height of one output plane (grid y extent)
    int slice;          // number of planes (grid z extent)
    int kernel_width;   // pooling window width, in taps
    int kernel_height;  // pooling window height, in taps
    int stride_width;   // horizontal distance between consecutive windows
    int stride_height;  // vertical distance between consecutive windows
    int pad_width;      // subtracted from the window's x origin
    int pad_height;     // subtracted from the window's y origin
};

// Max pooling: one thread per output element (gid.x, gid.y) of plane gid.z.
//
// The window's top-left tap is (gid * stride - pad). Taps that fall outside the
// input are clamped to the nearest edge pixel, so padding never introduces a
// value that is not already in the input. For a window that overlaps the input
// this is the max over its in-bounds taps.
//
// Clamping is monotone, so the clamped coordinates of the contiguous tap range
// [off, off + kernel - 1] form the contiguous range
// [clamp(off), clamp(off + kernel - 1)]. Iterating that range directly visits
// the same pixels as clamping every tap. It avoids a clamp per tap in the inner
// loop and avoids re-reading the edge pixel once per padded tap.
// Assumes kernel_width/kernel_height >= 1 and a non-empty input plane.
kernel void pooling_max(const device ftype4 *in [[buffer(0)]],
                        device ftype4 *out [[buffer(1)]],
                        constant pooling_sizes& s [[buffer(2)]],
                        uint3 gid [[thread_position_in_grid]]) {
    if (any(gid >= uint3(s.output_width, s.output_height, s.slice))) return;

    // Window origin in input coordinates; may be negative inside the padding.
    // gid is cast first so the padding subtraction happens in signed int.
    int off_x = (int)gid.x * s.stride_width - s.pad_width;
    int off_y = (int)gid.y * s.stride_height - s.pad_height;
    int x_max = s.input_width - 1;
    int y_max = s.input_height - 1;

    // Inclusive bounds of the window after clamping into the input plane.
    int sx = clamp(off_x, 0, x_max);
    int sy = clamp(off_y, 0, y_max);
    int lx = clamp(off_x + s.kernel_width - 1, 0, x_max);
    int ly = clamp(off_y + s.kernel_height - 1, 0, y_max);

    auto z_in = in + (int)gid.z * s.input_width * s.input_height;
    auto result = ftype4(z_in[sy * s.input_width + sx]);
    for (int y = sy; y <= ly; y++) {
        auto y_in = z_in + y * s.input_width;
        for (int x = sx; x <= lx; x++) {
            result = max(result, y_in[x]);
        }
    }
    out[(int)gid.z * s.output_width * s.output_height + (int)gid.y * s.output_width + (int)gid.x] = result;
}

// Average pooling: one thread per output element (gid.x, gid.y) of plane gid.z.
//
// The window is intersected with the input plane. The sum is divided by the
// number of in-bounds taps only, so padded taps are excluded from the average.
// A window with no in-bounds taps writes zero.
kernel void pooling_avg(const device ftype4 *in [[buffer(0)]],
                        device ftype4 *out [[buffer(1)]],
                        constant pooling_sizes& s [[buffer(2)]],
                        uint3 gid [[thread_position_in_grid]]) {
    if (any(gid >= uint3(s.output_width, s.output_height, s.slice))) return;

    // Window origin in input coordinates; may be negative inside the padding.
    int off_x = (int)gid.x * s.stride_width - s.pad_width;
    int off_y = (int)gid.y * s.stride_height - s.pad_height;

    // Half-open window [sx, ex) x [sy, ey) intersected with the input plane.
    int sx = max(off_x, 0);
    int sy = max(off_y, 0);
    int ex = min(off_x + s.kernel_width, s.input_width);
    int ey = min(off_y + s.kernel_height, s.input_height);

    // Accumulate in float regardless of ftype. ftype is defined in
    // MetalDefine.metal and may be a half type, so this limits rounding error.
    float4 result = 0;
    auto z_in = in + (int)gid.z * s.input_width * s.input_height;
    for (int y = sy; y < ey; y++) {
        for (int x = sx; x < ex; x++) {
            result += float4(z_in[y * s.input_width + x]);
        }
    }

    // Clamp each extent at zero so a window lying entirely outside the input
    // yields count == 0. Otherwise two negative extents could multiply into a
    // spurious positive count. The sum is 0 in that case either way.
    int count = max(ey - sy, 0) * max(ex - sx, 0);
    float4 div = count > 0 ? 1.f / count : 1;
    out[(int)gid.z * s.output_width * s.output_height + (int)gid.y * s.output_width + (int)gid.x] = ftype4(result * div);
}