1//
2//  MetalConvolutionWinograd.metal
3//  MNN
4//
5//  Created by MNN on 2019/02/01.
6//  Copyright © 2018, Alibaba Group Holding Limited
7//
8
9#include <metal_stdlib>
10#include "MetalConvolutionActivation.metal"
11
12using namespace metal;
13
// Uniform parameters read by the Winograd source/destination transform kernels in this file.
// NOTE(review): this is bound as a `constant` buffer, so the field order presumably has to
// mirror the host-side layout — confirm before reordering or inserting fields.
struct winograd_constants {
    // .x / .y bound the readable input plane (.x is also the row pitch when indexing);
    // .z scales the row pitch of the transformed buffer (pitch = .z * 4).
    int4 input_shape;
    // .x / .y bound the writable output plane (.x is also the row pitch when indexing).
    int4 output_shape;
    // Subtracted from the tile's x origin before sampling the input.
    int pad_x;
    // Subtracted from the tile's y origin before sampling the input.
    int pad_y;
    // Number of tiles along x; threads with pos.x >= unit_width do nothing.
    int unit_width;
    // Number of tiles along y; threads with pos.y >= unit_height do nothing.
    int unit_height;
    // Step, in pixels, between the origins of neighbouring tiles.
    int unit;
    // Forwarded to activate() when output pixels are stored.
    conv_activation_type activation;
};
24
// Reads the element at (x, y) of one input plane whose row pitch is cst.input_shape.x.
// Coordinates outside [0, input_shape.x) x [0, input_shape.y) yield an all-zero vector,
// which lets callers sample padded regions without branching themselves.
static inline ftype4 get_input(const device ftype4 *input, int x, int y, constant winograd_constants &cst) {
    bool inside = x >= 0 && y >= 0 && x < cst.input_shape.x && y < cst.input_shape.y;
    if (!inside) {
        return ftype4(0);
    }
    return input[x + y * cst.input_shape.x];
}
28
// Input transform for a 6x6 tile (presumably Winograd F(2x2, 5x5), going by the kernel name).
//
// One thread handles one tile: gid.x / gid.y select the tile, gid.z selects the input plane.
// The thread loads a zero-padded 6x6 patch, applies the same 6-tap transform first along y
// and then along x, and scatters the 36 resulting coefficients into `out`.
//
// Output layout (derived from the index math below): tiles are numbered row-major
// (unit_width * y + x) and packed four per row; consecutive coefficients of one tile are
// `plane_stride` elements apart.
kernel void winograd_transform_source2_5_1(const device ftype4 *in          [[buffer(0)]],
                                           device ftype4 *out               [[buffer(1)]],
                                           constant winograd_constants &cst [[buffer(2)]],
                                           uint3 gid                        [[thread_position_in_grid]]) {
    auto tid = int3(gid);
    // Threads that fall outside the tile grid have nothing to do.
    if (tid.x >= cst.unit_width || tid.y >= cst.unit_height) {
        return;
    }

    // Top-left input coordinate of this tile; may be negative inside the padding.
    int x0 = tid.x * cst.unit - cst.pad_x;
    int y0 = tid.y * cst.unit - cst.pad_y;
    auto plane = in + tid.z * cst.input_shape.x * cst.input_shape.y;

    // Load the 6x6 patch: dXY is the sample at column X, row Y (zero outside the plane).
    auto d00 = get_input(plane, x0 + 0, y0 + 0, cst);
    auto d10 = get_input(plane, x0 + 1, y0 + 0, cst);
    auto d20 = get_input(plane, x0 + 2, y0 + 0, cst);
    auto d30 = get_input(plane, x0 + 3, y0 + 0, cst);
    auto d40 = get_input(plane, x0 + 4, y0 + 0, cst);
    auto d50 = get_input(plane, x0 + 5, y0 + 0, cst);
    auto d01 = get_input(plane, x0 + 0, y0 + 1, cst);
    auto d11 = get_input(plane, x0 + 1, y0 + 1, cst);
    auto d21 = get_input(plane, x0 + 2, y0 + 1, cst);
    auto d31 = get_input(plane, x0 + 3, y0 + 1, cst);
    auto d41 = get_input(plane, x0 + 4, y0 + 1, cst);
    auto d51 = get_input(plane, x0 + 5, y0 + 1, cst);
    auto d02 = get_input(plane, x0 + 0, y0 + 2, cst);
    auto d12 = get_input(plane, x0 + 1, y0 + 2, cst);
    auto d22 = get_input(plane, x0 + 2, y0 + 2, cst);
    auto d32 = get_input(plane, x0 + 3, y0 + 2, cst);
    auto d42 = get_input(plane, x0 + 4, y0 + 2, cst);
    auto d52 = get_input(plane, x0 + 5, y0 + 2, cst);
    auto d03 = get_input(plane, x0 + 0, y0 + 3, cst);
    auto d13 = get_input(plane, x0 + 1, y0 + 3, cst);
    auto d23 = get_input(plane, x0 + 2, y0 + 3, cst);
    auto d33 = get_input(plane, x0 + 3, y0 + 3, cst);
    auto d43 = get_input(plane, x0 + 4, y0 + 3, cst);
    auto d53 = get_input(plane, x0 + 5, y0 + 3, cst);
    auto d04 = get_input(plane, x0 + 0, y0 + 4, cst);
    auto d14 = get_input(plane, x0 + 1, y0 + 4, cst);
    auto d24 = get_input(plane, x0 + 2, y0 + 4, cst);
    auto d34 = get_input(plane, x0 + 3, y0 + 4, cst);
    auto d44 = get_input(plane, x0 + 4, y0 + 4, cst);
    auto d54 = get_input(plane, x0 + 5, y0 + 4, cst);
    auto d05 = get_input(plane, x0 + 0, y0 + 5, cst);
    auto d15 = get_input(plane, x0 + 1, y0 + 5, cst);
    auto d25 = get_input(plane, x0 + 2, y0 + 5, cst);
    auto d35 = get_input(plane, x0 + 3, y0 + 5, cst);
    auto d45 = get_input(plane, x0 + 4, y0 + 5, cst);
    auto d55 = get_input(plane, x0 + 5, y0 + 5, cst);

    // First pass: transform along y. tXK is coefficient K of column X.
    auto t00 = +d00 - 1.25 * d02 + 0.25 * d04;
    auto t10 = +d10 - 1.25 * d12 + 0.25 * d14;
    auto t20 = +d20 - 1.25 * d22 + 0.25 * d24;
    auto t30 = +d30 - 1.25 * d32 + 0.25 * d34;
    auto t40 = +d40 - 1.25 * d42 + 0.25 * d44;
    auto t50 = +d50 - 1.25 * d52 + 0.25 * d54;
    auto t01 = +0.666667 * d01 + 0.666667 * d02 - 0.166667 * d03 - 0.166667 * d04;
    auto t11 = +0.666667 * d11 + 0.666667 * d12 - 0.166667 * d13 - 0.166667 * d14;
    auto t21 = +0.666667 * d21 + 0.666667 * d22 - 0.166667 * d23 - 0.166667 * d24;
    auto t31 = +0.666667 * d31 + 0.666667 * d32 - 0.166667 * d33 - 0.166667 * d34;
    auto t41 = +0.666667 * d41 + 0.666667 * d42 - 0.166667 * d43 - 0.166667 * d44;
    auto t51 = +0.666667 * d51 + 0.666667 * d52 - 0.166667 * d53 - 0.166667 * d54;
    auto t02 = -0.666667 * d01 + 0.666667 * d02 + 0.166667 * d03 - 0.166667 * d04;
    auto t12 = -0.666667 * d11 + 0.666667 * d12 + 0.166667 * d13 - 0.166667 * d14;
    auto t22 = -0.666667 * d21 + 0.666667 * d22 + 0.166667 * d23 - 0.166667 * d24;
    auto t32 = -0.666667 * d31 + 0.666667 * d32 + 0.166667 * d33 - 0.166667 * d34;
    auto t42 = -0.666667 * d41 + 0.666667 * d42 + 0.166667 * d43 - 0.166667 * d44;
    auto t52 = -0.666667 * d51 + 0.666667 * d52 + 0.166667 * d53 - 0.166667 * d54;
    auto t03 = -0.0833333 * d01 - 0.0416667 * d02 + 0.0833333 * d03 + 0.0416667 * d04;
    auto t13 = -0.0833333 * d11 - 0.0416667 * d12 + 0.0833333 * d13 + 0.0416667 * d14;
    auto t23 = -0.0833333 * d21 - 0.0416667 * d22 + 0.0833333 * d23 + 0.0416667 * d24;
    auto t33 = -0.0833333 * d31 - 0.0416667 * d32 + 0.0833333 * d33 + 0.0416667 * d34;
    auto t43 = -0.0833333 * d41 - 0.0416667 * d42 + 0.0833333 * d43 + 0.0416667 * d44;
    auto t53 = -0.0833333 * d51 - 0.0416667 * d52 + 0.0833333 * d53 + 0.0416667 * d54;
    auto t04 = +0.0833333 * d01 - 0.0416667 * d02 - 0.0833333 * d03 + 0.0416667 * d04;
    auto t14 = +0.0833333 * d11 - 0.0416667 * d12 - 0.0833333 * d13 + 0.0416667 * d14;
    auto t24 = +0.0833333 * d21 - 0.0416667 * d22 - 0.0833333 * d23 + 0.0416667 * d24;
    auto t34 = +0.0833333 * d31 - 0.0416667 * d32 - 0.0833333 * d33 + 0.0416667 * d34;
    auto t44 = +0.0833333 * d41 - 0.0416667 * d42 - 0.0833333 * d43 + 0.0416667 * d44;
    auto t54 = +0.0833333 * d51 - 0.0416667 * d52 - 0.0833333 * d53 + 0.0416667 * d54;
    auto t05 = +4.0 * d01 - 5.0 * d03 + d05;
    auto t15 = +4.0 * d11 - 5.0 * d13 + d15;
    auto t25 = +4.0 * d21 - 5.0 * d23 + d25;
    auto t35 = +4.0 * d31 - 5.0 * d33 + d35;
    auto t45 = +4.0 * d41 - 5.0 * d43 + d45;
    auto t55 = +4.0 * d51 - 5.0 * d53 + d55;

    // Destination addressing: four tiles share a row, one coefficient per plane_stride.
    int tile_index   = cst.unit_width * tid.y + tid.x;
    int row_pitch    = cst.input_shape.z * 4;
    int tile_rows    = UP_DIV(cst.unit_width * cst.unit_height, 4);
    int plane_stride = tile_rows * row_pitch;
    auto dst = out + (tile_index / 4) * row_pitch + (tile_index % 4 + 4 * tid.z);

    // Second pass: transform along x and store, coefficient row by coefficient row.
    // Row 0
    *dst = +t00 - 1.25 * t20 + 0.25 * t40;
    dst += plane_stride;
    *dst = +0.666667 * t10 + 0.666667 * t20 - 0.166667 * t30 - 0.166667 * t40;
    dst += plane_stride;
    *dst = -0.666667 * t10 + 0.666667 * t20 + 0.166667 * t30 - 0.166667 * t40;
    dst += plane_stride;
    *dst = -0.0833333 * t10 - 0.0416667 * t20 + 0.0833333 * t30 + 0.0416667 * t40;
    dst += plane_stride;
    *dst = +0.0833333 * t10 - 0.0416667 * t20 - 0.0833333 * t30 + 0.0416667 * t40;
    dst += plane_stride;
    *dst = +4.0 * t10 - 5.0 * t30 + t50;
    dst += plane_stride;
    // Row 1
    *dst = +t01 - 1.25 * t21 + 0.25 * t41;
    dst += plane_stride;
    *dst = +0.666667 * t11 + 0.666667 * t21 - 0.166667 * t31 - 0.166667 * t41;
    dst += plane_stride;
    *dst = -0.666667 * t11 + 0.666667 * t21 + 0.166667 * t31 - 0.166667 * t41;
    dst += plane_stride;
    *dst = -0.0833333 * t11 - 0.0416667 * t21 + 0.0833333 * t31 + 0.0416667 * t41;
    dst += plane_stride;
    *dst = +0.0833333 * t11 - 0.0416667 * t21 - 0.0833333 * t31 + 0.0416667 * t41;
    dst += plane_stride;
    *dst = +4.0 * t11 - 5.0 * t31 + t51;
    dst += plane_stride;
    // Row 2
    *dst = +t02 - 1.25 * t22 + 0.25 * t42;
    dst += plane_stride;
    *dst = +0.666667 * t12 + 0.666667 * t22 - 0.166667 * t32 - 0.166667 * t42;
    dst += plane_stride;
    *dst = -0.666667 * t12 + 0.666667 * t22 + 0.166667 * t32 - 0.166667 * t42;
    dst += plane_stride;
    *dst = -0.0833333 * t12 - 0.0416667 * t22 + 0.0833333 * t32 + 0.0416667 * t42;
    dst += plane_stride;
    *dst = +0.0833333 * t12 - 0.0416667 * t22 - 0.0833333 * t32 + 0.0416667 * t42;
    dst += plane_stride;
    *dst = +4.0 * t12 - 5.0 * t32 + t52;
    dst += plane_stride;
    // Row 3
    *dst = +t03 - 1.25 * t23 + 0.25 * t43;
    dst += plane_stride;
    *dst = +0.666667 * t13 + 0.666667 * t23 - 0.166667 * t33 - 0.166667 * t43;
    dst += plane_stride;
    *dst = -0.666667 * t13 + 0.666667 * t23 + 0.166667 * t33 - 0.166667 * t43;
    dst += plane_stride;
    *dst = -0.0833333 * t13 - 0.0416667 * t23 + 0.0833333 * t33 + 0.0416667 * t43;
    dst += plane_stride;
    *dst = +0.0833333 * t13 - 0.0416667 * t23 - 0.0833333 * t33 + 0.0416667 * t43;
    dst += plane_stride;
    *dst = +4.0 * t13 - 5.0 * t33 + t53;
    dst += plane_stride;
    // Row 4
    *dst = +t04 - 1.25 * t24 + 0.25 * t44;
    dst += plane_stride;
    *dst = +0.666667 * t14 + 0.666667 * t24 - 0.166667 * t34 - 0.166667 * t44;
    dst += plane_stride;
    *dst = -0.666667 * t14 + 0.666667 * t24 + 0.166667 * t34 - 0.166667 * t44;
    dst += plane_stride;
    *dst = -0.0833333 * t14 - 0.0416667 * t24 + 0.0833333 * t34 + 0.0416667 * t44;
    dst += plane_stride;
    *dst = +0.0833333 * t14 - 0.0416667 * t24 - 0.0833333 * t34 + 0.0416667 * t44;
    dst += plane_stride;
    *dst = +4.0 * t14 - 5.0 * t34 + t54;
    dst += plane_stride;
    // Row 5
    *dst = +t05 - 1.25 * t25 + 0.25 * t45;
    dst += plane_stride;
    *dst = +0.666667 * t15 + 0.666667 * t25 - 0.166667 * t35 - 0.166667 * t45;
    dst += plane_stride;
    *dst = -0.666667 * t15 + 0.666667 * t25 + 0.166667 * t35 - 0.166667 * t45;
    dst += plane_stride;
    *dst = -0.0833333 * t15 - 0.0416667 * t25 + 0.0833333 * t35 + 0.0416667 * t45;
    dst += plane_stride;
    *dst = +0.0833333 * t15 - 0.0416667 * t25 - 0.0833333 * t35 + 0.0416667 * t45;
    dst += plane_stride;
    *dst = +4.0 * t15 - 5.0 * t35 + t55;
}
159
// Input transform for a 4x4 tile (presumably Winograd F(2x2, 3x3), going by the kernel name).
//
// One thread handles one tile: gid.x / gid.y select the tile, gid.z selects the input plane.
// The thread loads a zero-padded 4x4 patch, applies the same 4-tap transform first along y
// and then along x, and scatters the 16 resulting coefficients into `out`.
//
// Output layout (derived from the index math below): tiles are numbered row-major
// (unit_width * y + x) and packed four per row; consecutive coefficients of one tile are
// `plane_stride` elements apart.
kernel void winograd_transform_source2_3_1(const device ftype4 *in          [[buffer(0)]],
                                           device ftype4 *out               [[buffer(1)]],
                                           constant winograd_constants &cst [[buffer(2)]],
                                           uint3 gid                        [[thread_position_in_grid]]) {
    auto tid = int3(gid);
    // Threads that fall outside the tile grid have nothing to do.
    if (tid.x >= cst.unit_width || tid.y >= cst.unit_height) {
        return;
    }

    // Top-left input coordinate of this tile; may be negative inside the padding.
    int x0 = tid.x * cst.unit - cst.pad_x;
    int y0 = tid.y * cst.unit - cst.pad_y;
    auto plane = in + tid.z * cst.input_shape.x * cst.input_shape.y;

    // Load the 4x4 patch: dXY is the sample at column X, row Y (zero outside the plane).
    auto d00 = get_input(plane, x0 + 0, y0 + 0, cst);
    auto d10 = get_input(plane, x0 + 1, y0 + 0, cst);
    auto d20 = get_input(plane, x0 + 2, y0 + 0, cst);
    auto d30 = get_input(plane, x0 + 3, y0 + 0, cst);
    auto d01 = get_input(plane, x0 + 0, y0 + 1, cst);
    auto d11 = get_input(plane, x0 + 1, y0 + 1, cst);
    auto d21 = get_input(plane, x0 + 2, y0 + 1, cst);
    auto d31 = get_input(plane, x0 + 3, y0 + 1, cst);
    auto d02 = get_input(plane, x0 + 0, y0 + 2, cst);
    auto d12 = get_input(plane, x0 + 1, y0 + 2, cst);
    auto d22 = get_input(plane, x0 + 2, y0 + 2, cst);
    auto d32 = get_input(plane, x0 + 3, y0 + 2, cst);
    auto d03 = get_input(plane, x0 + 0, y0 + 3, cst);
    auto d13 = get_input(plane, x0 + 1, y0 + 3, cst);
    auto d23 = get_input(plane, x0 + 2, y0 + 3, cst);
    auto d33 = get_input(plane, x0 + 3, y0 + 3, cst);

    // First pass: transform along y. tXK is coefficient K of column X.
    auto t00 = +d00 - d02;
    auto t10 = +d10 - d12;
    auto t20 = +d20 - d22;
    auto t30 = +d30 - d32;
    auto t01 = +0.5 * d01 + 0.5 * d02;
    auto t11 = +0.5 * d11 + 0.5 * d12;
    auto t21 = +0.5 * d21 + 0.5 * d22;
    auto t31 = +0.5 * d31 + 0.5 * d32;
    auto t02 = -0.5 * d01 + 0.5 * d02;
    auto t12 = -0.5 * d11 + 0.5 * d12;
    auto t22 = -0.5 * d21 + 0.5 * d22;
    auto t32 = -0.5 * d31 + 0.5 * d32;
    auto t03 = -d01 + d03;
    auto t13 = -d11 + d13;
    auto t23 = -d21 + d23;
    auto t33 = -d31 + d33;

    // Destination addressing: four tiles share a row, one coefficient per plane_stride.
    int tile_index   = cst.unit_width * tid.y + tid.x;
    int row_pitch    = cst.input_shape.z * 4;
    int tile_rows    = UP_DIV(cst.unit_width * cst.unit_height, 4);
    int plane_stride = tile_rows * row_pitch;
    auto dst = out + (tile_index / 4) * row_pitch + (tile_index % 4 + 4 * tid.z);

    // Second pass: transform along x and store, coefficient row by coefficient row.
    // Row 0
    *dst = +t00 - t20;
    dst += plane_stride;
    *dst = +0.5 * t10 + 0.5 * t20;
    dst += plane_stride;
    *dst = -0.5 * t10 + 0.5 * t20;
    dst += plane_stride;
    *dst = -t10 + t30;
    dst += plane_stride;
    // Row 1
    *dst = +t01 - t21;
    dst += plane_stride;
    *dst = +0.5 * t11 + 0.5 * t21;
    dst += plane_stride;
    *dst = -0.5 * t11 + 0.5 * t21;
    dst += plane_stride;
    *dst = -t11 + t31;
    dst += plane_stride;
    // Row 2
    *dst = +t02 - t22;
    dst += plane_stride;
    *dst = +0.5 * t12 + 0.5 * t22;
    dst += plane_stride;
    *dst = -0.5 * t12 + 0.5 * t22;
    dst += plane_stride;
    *dst = -t12 + t32;
    dst += plane_stride;
    // Row 3
    *dst = +t03 - t23;
    dst += plane_stride;
    *dst = +0.5 * t13 + 0.5 * t23;
    dst += plane_stride;
    *dst = -0.5 * t13 + 0.5 * t23;
    dst += plane_stride;
    *dst = -t13 + t33;
}
230
// Applies the configured activation to `value` and stores it at (x, y) of one output
// plane whose row pitch is cst.output_shape.x. No bounds check is done here; callers
// must only pass coordinates that lie inside the plane.
static inline void set_output(constant winograd_constants &cst, device ftype4 *output, int x, int y, ftype4 value) {
    int offset = y * cst.output_shape.x + x;
    output[offset] = activate(value, cst.activation);
}
234
// Output transform for a 6x6 coefficient tile producing a 2x2 output patch
// (presumably Winograd F(2x2, 5x5), going by the kernel name).
//
// One thread handles one tile: gid.x / gid.y select the tile, gid.z selects the output
// plane and its bias entry. The 36 coefficients of a tile are read `quad_count` elements
// apart, reduced along y and then along x, biased, activated and written to `out`.
// Only the top-left output pixel is written unconditionally; its right/bottom
// neighbours are skipped when they fall outside output_shape.
kernel void winograd_transform_dest2_5_1(const device ftype4 *in            [[buffer(0)]],
                                         const device ftype4 *biasTerms     [[buffer(1)]],
                                         device ftype4 *out                 [[buffer(2)]],
                                         constant winograd_constants &cst   [[buffer(3)]],
                                         uint3 gid                          [[thread_position_in_grid]]) {
    auto tid = int3(gid);
    // Threads that fall outside the tile grid have nothing to do.
    if (tid.x >= cst.unit_width || tid.y >= cst.unit_height) {
        return;
    }

    // Locate this tile's first coefficient: tiles are packed four at a time.
    int quad_count = UP_DIV(cst.unit_width * cst.unit_height, 4);
    int tile_index = cst.unit_width * tid.y + tid.x;
    int src_row    = 4 * tid.z + tile_index % 4;
    int src_col    = tile_index / 4;
    auto src = in + src_row * (quad_count * 36) + src_col;

    // Gather the 36 coefficients; wXY is column X, row Y, stored x-fastest.
    auto w00 = *src; src += quad_count;
    auto w10 = *src; src += quad_count;
    auto w20 = *src; src += quad_count;
    auto w30 = *src; src += quad_count;
    auto w40 = *src; src += quad_count;
    auto w50 = *src; src += quad_count;
    auto w01 = *src; src += quad_count;
    auto w11 = *src; src += quad_count;
    auto w21 = *src; src += quad_count;
    auto w31 = *src; src += quad_count;
    auto w41 = *src; src += quad_count;
    auto w51 = *src; src += quad_count;
    auto w02 = *src; src += quad_count;
    auto w12 = *src; src += quad_count;
    auto w22 = *src; src += quad_count;
    auto w32 = *src; src += quad_count;
    auto w42 = *src; src += quad_count;
    auto w52 = *src; src += quad_count;
    auto w03 = *src; src += quad_count;
    auto w13 = *src; src += quad_count;
    auto w23 = *src; src += quad_count;
    auto w33 = *src; src += quad_count;
    auto w43 = *src; src += quad_count;
    auto w53 = *src; src += quad_count;
    auto w04 = *src; src += quad_count;
    auto w14 = *src; src += quad_count;
    auto w24 = *src; src += quad_count;
    auto w34 = *src; src += quad_count;
    auto w44 = *src; src += quad_count;
    auto w54 = *src; src += quad_count;
    auto w05 = *src; src += quad_count;
    auto w15 = *src; src += quad_count;
    auto w25 = *src; src += quad_count;
    auto w35 = *src; src += quad_count;
    auto w45 = *src; src += quad_count;
    auto w55 = *src;

    // First pass: collapse the six rows of every column into two output rows.
    auto r00 = +w00 + w01 + w02 + w03 + w04;
    auto r10 = +w10 + w11 + w12 + w13 + w14;
    auto r20 = +w20 + w21 + w22 + w23 + w24;
    auto r30 = +w30 + w31 + w32 + w33 + w34;
    auto r40 = +w40 + w41 + w42 + w43 + w44;
    auto r50 = +w50 + w51 + w52 + w53 + w54;
    auto r01 = +w01 - w02 + 2.0 * w03 - 2.0 * w04 + w05;
    auto r11 = +w11 - w12 + 2.0 * w13 - 2.0 * w14 + w15;
    auto r21 = +w21 - w22 + 2.0 * w23 - 2.0 * w24 + w25;
    auto r31 = +w31 - w32 + 2.0 * w33 - 2.0 * w34 + w35;
    auto r41 = +w41 - w42 + 2.0 * w43 - 2.0 * w44 + w45;
    auto r51 = +w51 - w52 + 2.0 * w53 - 2.0 * w54 + w55;

    // Second pass: collapse along x, add bias, activate and store.
    auto bias = biasTerms[tid.z];
    int ox = tid.x * cst.unit;
    int oy = tid.y * cst.unit;
    auto plane = out + tid.z * cst.output_shape.x * cst.output_shape.y;

    // The +1 neighbours assume a 2x2 output patch per tile (cst.unit == 2) — TODO confirm.
    bool has_right = ox + 1 < cst.output_shape.x;
    bool has_below = oy + 1 < cst.output_shape.y;

    set_output(cst, plane, ox + 0, oy + 0, bias + r00 + r10 + r20 + r30 + r40);
    if (has_right) {
        set_output(cst, plane, ox + 1, oy + 0, bias + r10 - r20 + 2.0 * r30 - 2.0 * r40 + r50);
    }
    if (has_below) {
        set_output(cst, plane, ox + 0, oy + 1, bias + r01 + r11 + r21 + r31 + r41);
    }
    if (has_right && has_below) {
        set_output(cst, plane, ox + 1, oy + 1, bias + r11 - r21 + 2.0 * r31 - 2.0 * r41 + r51);
    }
}
319
// Output transform for a 4x4 coefficient tile producing a 2x2 output patch
// (presumably Winograd F(2x2, 3x3), going by the kernel name).
//
// One thread handles one tile: gid.x / gid.y select the tile, gid.z selects the output
// plane and its bias entry. The 16 coefficients of a tile are read `quad_count` elements
// apart, reduced along y and then along x, biased, activated and written to `out`.
// Only the top-left output pixel is written unconditionally; its right/bottom
// neighbours are skipped when they fall outside output_shape.
kernel void winograd_transform_dest2_3_1(const device ftype4 *in            [[buffer(0)]],
                                         const device ftype4 *biasTerms     [[buffer(1)]],
                                         device ftype4 *out                 [[buffer(2)]],
                                         constant winograd_constants &cst   [[buffer(3)]],
                                         uint3 gid                          [[thread_position_in_grid]]) {
    auto tid = int3(gid);
    // Threads that fall outside the tile grid have nothing to do.
    if (tid.x >= cst.unit_width || tid.y >= cst.unit_height) {
        return;
    }

    // Locate this tile's first coefficient: tiles are packed four at a time.
    int quad_count = UP_DIV(cst.unit_width * cst.unit_height, 4);
    int tile_index = cst.unit_width * tid.y + tid.x;
    int src_row    = 4 * tid.z + tile_index % 4;
    int src_col    = tile_index / 4;
    auto src = in + src_row * (quad_count * 16) + src_col;

    // Gather the 16 coefficients; wXY is column X, row Y, stored x-fastest.
    auto w00 = *src; src += quad_count;
    auto w10 = *src; src += quad_count;
    auto w20 = *src; src += quad_count;
    auto w30 = *src; src += quad_count;
    auto w01 = *src; src += quad_count;
    auto w11 = *src; src += quad_count;
    auto w21 = *src; src += quad_count;
    auto w31 = *src; src += quad_count;
    auto w02 = *src; src += quad_count;
    auto w12 = *src; src += quad_count;
    auto w22 = *src; src += quad_count;
    auto w32 = *src; src += quad_count;
    auto w03 = *src; src += quad_count;
    auto w13 = *src; src += quad_count;
    auto w23 = *src; src += quad_count;
    auto w33 = *src;

    // First pass: collapse the four rows of every column into two output rows.
    auto r00 = +w00 + w01 + w02;
    auto r10 = +w10 + w11 + w12;
    auto r20 = +w20 + w21 + w22;
    auto r30 = +w30 + w31 + w32;
    auto r01 = +w01 - w02 + w03;
    auto r11 = +w11 - w12 + w13;
    auto r21 = +w21 - w22 + w23;
    auto r31 = +w31 - w32 + w33;

    // Second pass: collapse along x, add bias, activate and store.
    auto bias = biasTerms[tid.z];
    int ox = tid.x * cst.unit;
    int oy = tid.y * cst.unit;
    auto plane = out + tid.z * cst.output_shape.x * cst.output_shape.y;

    // The +1 neighbours assume a 2x2 output patch per tile (cst.unit == 2) — TODO confirm.
    bool has_right = ox + 1 < cst.output_shape.x;
    bool has_below = oy + 1 < cst.output_shape.y;

    set_output(cst, plane, ox + 0, oy + 0, bias + r00 + r10 + r20);
    if (has_right) {
        set_output(cst, plane, ox + 1, oy + 0, bias + r10 - r20 + r30);
    }
    if (has_below) {
        set_output(cst, plane, ox + 0, oy + 1, bias + r01 + r11 + r21);
    }
    if (has_right && has_below) {
        set_output(cst, plane, ox + 1, oy + 1, bias + r11 - r21 + r31);
    }
}
380