#version 450
#define LOCAL_SZ_X 256

// Average pooling compute shader.
//
// Input is indexed as contiguous [n][channels][in_h][in_w] float planes and
// output as [n][channels][out_h][out_w] (see the index math in main()).
// Each invocation handles output elements gid, gid + S, gid + 2S, ... where
// S is the number of invocations dispatched along x, so any dispatch size
// covers all p.total outputs.

layout(push_constant) uniform pushBlock {
    int channels;
    int in_h;
    int in_w;
    int out_h;
    int out_w;
    int padding_h;
    int padding_w;
    int filter_h;
    int filter_w;
    int stride_h;
    int stride_w;
    int total;        // number of output elements to produce
    int padded_area;  // 1: divide by the window area including padding;
                      // otherwise: divide by the in-bounds area only
} p;

layout(binding = 0) readonly buffer Input0{
    float in_buffer[];
};

layout(binding = 1) writeonly buffer Output{
    float out_buffer[];
};

layout(local_size_x = LOCAL_SZ_X, local_size_y = 1, local_size_z = 1) in;

void main()
{
    int global_size = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);
    int gid = int(gl_GlobalInvocationID.x);
    for (int index = gid; index < p.total; index += global_size)
    {
        // Decompose the flat output index as [n][c][ph][pw].
        const int pw = index % p.out_w;
        const int ph = (index / p.out_w) % p.out_h;
        const int c = (index / p.out_w / p.out_h) % p.channels;
        const int n = index / p.out_w / p.out_h / p.channels;

        // Pooling window in input coordinates. It may start in the leading
        // padding and is capped at the far edge of the padded input.
        int hstart = ph * p.stride_h - p.padding_h;
        int wstart = pw * p.stride_w - p.padding_w;
        int hend = min(hstart + p.filter_h, p.in_h + p.padding_h);
        int wend = min(wstart + p.filter_w, p.in_w + p.padding_w);

        // Window area counting padded cells; must be measured before clamping.
        const int padded_pool_size = (hend - hstart) * (wend - wstart);

        // Clamp the window to the real input extent.
        hstart = max(hstart, 0);
        wstart = max(wstart, 0);
        hend = min(hend, p.in_h);
        wend = min(wend, p.in_w);

        // Number of real input cells covered. max(..., 0) keeps a window that
        // lies fully in the padding from yielding a negative extent.
        const int valid_pool_size = max(hend - hstart, 0) * max(wend - wstart, 0);

        const int pool_size = (p.padded_area == 1) ? padded_pool_size : valid_pool_size;

        float aveval = 0.0;
        int off = (n * p.channels + c) * p.in_h * p.in_w;
        for (int h = hstart; h < hend; ++h) {
            for (int w = wstart; w < wend; ++w) {
                aveval += in_buffer[off + h * p.in_w + w];
            }
        }

        // An empty window (possible when padding is excluded from the divisor)
        // has nothing to average; write 0 instead of the NaN from 0/0.
        out_buffer[index] = (pool_size > 0) ? aveval / float(pool_size) : 0.0;
    }
}