#version 450
// Number of invocations per workgroup along x; consumed by the local_size_x layout qualifier.
#define LOCAL_SZ_X 256
// Pooling parameters supplied as push constants.
// Member order and types define the push-constant layout, so they must stay
// in sync with whatever the host side pushes.
layout(push_constant) uniform pushBlock {
      int channels;     // channel count; used to split a flat output index into (n, c)
      int in_h;         // input plane height
      int in_w;         // input plane width
      int out_h;        // output plane height
      int out_w;        // output plane width
      int padding_h;    // padding subtracted from the window's starting row
      int padding_w;    // padding subtracted from the window's starting column
      int filter_h;     // pooling window height
      int filter_w;     // pooling window width
      int stride_h;     // vertical step between consecutive windows
      int stride_w;     // horizontal step between consecutive windows
      int total;        // number of flat output indices to process (loop bound in main)
      int padded_area;  // 1: divide by the window area measured before clipping to the
                        //    input (padding cells count); otherwise divide by the
                        //    clipped area only
} p;

// Read-only input tensor. main() addresses it as
// ((n * channels + c) * in_h + h) * in_w + w, i.e. width varies fastest.
layout(binding = 0) readonly buffer Input0{
    float in_buffer[];
};

// Write-only output tensor; main() stores one averaged value per flat output index.
layout(binding = 1) writeonly buffer Output{
    float out_buffer[];
};

// One-dimensional workgroup of LOCAL_SZ_X invocations.
layout(local_size_x = LOCAL_SZ_X, local_size_y = 1, local_size_z = 1) in;

// Average pooling.
//
// Each invocation starts at its global x id and advances by the total number
// of invocations in the dispatch until p.total output elements are covered,
// so any dispatch size produces every output element.
//
// For each flat output index the kernel:
//   1. splits the index into (n, c, ph, pw) with pw varying fastest,
//   2. derives the pooling window from stride, padding and filter size,
//   3. sums the input values inside the window clipped to the input plane,
//   4. divides by the window area (padded or clipped, see p.padded_area).
void main()
{
  // Number of invocations along x in the whole dispatch; used as the loop stride.
  const int global_size = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);
  const int gid = int(gl_GlobalInvocationID.x);

  for (int index = gid; index < p.total; index += global_size)
  {
      // Flat output index -> (n, c, ph, pw).
      const int pw = index % p.out_w;
      const int out_row = index / p.out_w;
      const int ph = out_row % p.out_h;
      const int out_plane = out_row / p.out_h;
      const int c = out_plane % p.channels;
      const int n = out_plane / p.channels;

      // Window expressed in input coordinates; starts may be negative (inside
      // the padding) and ends are limited to the padded extent.
      int hstart = ph * p.stride_h - p.padding_h;
      int wstart = pw * p.stride_w - p.padding_w;
      int hend = min(hstart + p.filter_h, p.in_h + p.padding_h);
      int wend = min(wstart + p.filter_w, p.in_w + p.padding_w);

      // Window area before clipping to the real input: padding cells included.
      const int padded_count = (hend - hstart) * (wend - wstart);

      // Clip the window to the real input so only valid elements are read.
      hstart = max(hstart, 0);
      wstart = max(wstart, 0);
      hend = min(hend, p.in_h);
      wend = min(wend, p.in_w);

      // Divisor: padded area when p.padded_area == 1, clipped area otherwise.
      // NOTE(review): this is 0 when the selected window extent is empty in one
      // dimension, and the division below is then 0/0 - presumably the host
      // only passes parameters where every window overlaps the input; confirm.
      const int pool_size = (p.padded_area == 1)
                                ? padded_count
                                : (hend - hstart) * (wend - wstart);

      // Sum over the clipped window of plane (n, c).
      float aveval = 0;
      const int off = (n * p.channels + c) * p.in_h * p.in_w;
      for (int h = hstart; h < hend; ++h) {
        const int row_off = off + h * p.in_w;
        for (int w = wstart; w < wend; ++w) {
          aveval += in_buffer[row_off + w];
        }
      }

      out_buffer[index] = aveval / float(pool_size);
  }
}