//
//  MetalConvolutionWinograd.metal
//  MNN
//
//  Created by MNN on 2019/02/01.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <metal_stdlib>
#include "MetalConvolutionActivation.metal"

using namespace metal;

// Uniform parameters shared by the Winograd source (input) and dest (output) transform kernels below.
struct winograd_constants {
    int4 input_shape;   // .x / .y: width / height of one input slice; .z * 4 is the row stride of the
                        // transformed-source buffer (presumably .z = number of input slices -- TODO confirm)
    int4 output_shape;  // .x / .y: width / height of one output slice
    int pad_x;          // subtracted from each tile's x origin in the source transform
    int pad_y;          // subtracted from each tile's y origin in the source transform
    int unit_width;     // number of tiles along x (threads with gid.x >= unit_width do nothing)
    int unit_height;    // number of tiles along y (threads with gid.y >= unit_height do nothing)
    int unit;           // tile step in pixels; the dest kernels below always emit a 2x2 tile
    conv_activation_type activation;  // passed to activate() by set_output()
};

// Reads one ftype4 from a single input slice, returning 0 for any (x, y) outside
// [0, width) x [0, height). This supplies the zero padding for tiles that overhang the border.
static inline ftype4 get_input(const device ftype4 *input, int x, int y, constant winograd_constants &cst) {
    return x < cst.input_shape.x && y < cst.input_shape.y && x >= 0 && y >= 0 ? input[x + y * cst.input_shape.x] : 0;
}

// Winograd F(2,5) input transform: computes B^T * d * B for a 6x6 input tile d.
//
// One thread per (tile x, tile y, slice) = gid. SXY / mXY naming: first digit is the x offset inside
// the tile, second digit the y offset. The first pass (m..) transforms along y, the second pass
// (written to out) transforms along x.
//
// B^T rows correspond to the interpolation points {0, 1, -1, 2, -2, inf}, each row scaled by
// 1/4, -1/6, -1/6, 1/24, 1/24, 1 respectively:
//   [ 1,    0,     -5/4,   0,     1/4,   0 ]
//   [ 0,    2/3,    2/3,  -1/6,  -1/6,   0 ]
//   [ 0,   -2/3,    2/3,   1/6,  -1/6,   0 ]
//   [ 0,   -1/12,  -1/24,  1/12,  1/24,  0 ]
//   [ 0,    1/12,  -1/24, -1/12,  1/24,  0 ]
//   [ 0,    4,      0,    -5,     0,     1 ]
// The row scaling must be undone by the weight transform, which is not in this file (TODO confirm).
// Fractions are written with enough digits (0.66666667, 0.16666667, 0.083333333, 0.041666667) to
// round to the correctly-rounded float value. The previous 6-digit literals were several ulps off.
//
// Output layout: transformed element t (0..35, t = 6 * y_row + x_row) of linear tile index
// `tile` in slice z is stored at
//   out[t * stride + (tile / 4) * (input_shape.z * 4) + tile % 4 + 4 * z],
// where stride = UP_DIV(unit_width * unit_height, 4) * input_shape.z * 4.
kernel void winograd_transform_source2_5_1(const device ftype4 *in [[buffer(0)]],
                                           device ftype4 *out [[buffer(1)]],
                                           constant winograd_constants &cst [[buffer(2)]],
                                           uint3 gid [[thread_position_in_grid]]) {
    auto pos = int3(gid);
    if (pos.x < cst.unit_width && pos.y < cst.unit_height) {
        // Top-left corner of this tile in input coordinates (may be negative because of padding).
        int ix = pos.x * cst.unit - cst.pad_x;
        int iy = pos.y * cst.unit - cst.pad_y;

        auto z_in = in + pos.z * cst.input_shape.x * cst.input_shape.y;
        auto S00 = get_input(z_in, ix + 0, iy + 0, cst);
        auto S10 = get_input(z_in, ix + 1, iy + 0, cst);
        auto S20 = get_input(z_in, ix + 2, iy + 0, cst);
        auto S30 = get_input(z_in, ix + 3, iy + 0, cst);
        auto S40 = get_input(z_in, ix + 4, iy + 0, cst);
        auto S50 = get_input(z_in, ix + 5, iy + 0, cst);
        auto S01 = get_input(z_in, ix + 0, iy + 1, cst);
        auto S11 = get_input(z_in, ix + 1, iy + 1, cst);
        auto S21 = get_input(z_in, ix + 2, iy + 1, cst);
        auto S31 = get_input(z_in, ix + 3, iy + 1, cst);
        auto S41 = get_input(z_in, ix + 4, iy + 1, cst);
        auto S51 = get_input(z_in, ix + 5, iy + 1, cst);
        auto S02 = get_input(z_in, ix + 0, iy + 2, cst);
        auto S12 = get_input(z_in, ix + 1, iy + 2, cst);
        auto S22 = get_input(z_in, ix + 2, iy + 2, cst);
        auto S32 = get_input(z_in, ix + 3, iy + 2, cst);
        auto S42 = get_input(z_in, ix + 4, iy + 2, cst);
        auto S52 = get_input(z_in, ix + 5, iy + 2, cst);
        auto S03 = get_input(z_in, ix + 0, iy + 3, cst);
        auto S13 = get_input(z_in, ix + 1, iy + 3, cst);
        auto S23 = get_input(z_in, ix + 2, iy + 3, cst);
        auto S33 = get_input(z_in, ix + 3, iy + 3, cst);
        auto S43 = get_input(z_in, ix + 4, iy + 3, cst);
        auto S53 = get_input(z_in, ix + 5, iy + 3, cst);
        auto S04 = get_input(z_in, ix + 0, iy + 4, cst);
        auto S14 = get_input(z_in, ix + 1, iy + 4, cst);
        auto S24 = get_input(z_in, ix + 2, iy + 4, cst);
        auto S34 = get_input(z_in, ix + 3, iy + 4, cst);
        auto S44 = get_input(z_in, ix + 4, iy + 4, cst);
        auto S54 = get_input(z_in, ix + 5, iy + 4, cst);
        auto S05 = get_input(z_in, ix + 0, iy + 5, cst);
        auto S15 = get_input(z_in, ix + 1, iy + 5, cst);
        auto S25 = get_input(z_in, ix + 2, iy + 5, cst);
        auto S35 = get_input(z_in, ix + 3, iy + 5, cst);
        auto S45 = get_input(z_in, ix + 4, iy + 5, cst);
        auto S55 = get_input(z_in, ix + 5, iy + 5, cst);

        // Pass 1: apply B^T along y for each of the 6 columns.
        auto m00 = +S00 - 1.25 * S02 + 0.25 * S04;
        auto m10 = +S10 - 1.25 * S12 + 0.25 * S14;
        auto m20 = +S20 - 1.25 * S22 + 0.25 * S24;
        auto m30 = +S30 - 1.25 * S32 + 0.25 * S34;
        auto m40 = +S40 - 1.25 * S42 + 0.25 * S44;
        auto m50 = +S50 - 1.25 * S52 + 0.25 * S54;
        auto m01 = +0.66666667 * S01 + 0.66666667 * S02 - 0.16666667 * S03 - 0.16666667 * S04;
        auto m11 = +0.66666667 * S11 + 0.66666667 * S12 - 0.16666667 * S13 - 0.16666667 * S14;
        auto m21 = +0.66666667 * S21 + 0.66666667 * S22 - 0.16666667 * S23 - 0.16666667 * S24;
        auto m31 = +0.66666667 * S31 + 0.66666667 * S32 - 0.16666667 * S33 - 0.16666667 * S34;
        auto m41 = +0.66666667 * S41 + 0.66666667 * S42 - 0.16666667 * S43 - 0.16666667 * S44;
        auto m51 = +0.66666667 * S51 + 0.66666667 * S52 - 0.16666667 * S53 - 0.16666667 * S54;
        auto m02 = -0.66666667 * S01 + 0.66666667 * S02 + 0.16666667 * S03 - 0.16666667 * S04;
        auto m12 = -0.66666667 * S11 + 0.66666667 * S12 + 0.16666667 * S13 - 0.16666667 * S14;
        auto m22 = -0.66666667 * S21 + 0.66666667 * S22 + 0.16666667 * S23 - 0.16666667 * S24;
        auto m32 = -0.66666667 * S31 + 0.66666667 * S32 + 0.16666667 * S33 - 0.16666667 * S34;
        auto m42 = -0.66666667 * S41 + 0.66666667 * S42 + 0.16666667 * S43 - 0.16666667 * S44;
        auto m52 = -0.66666667 * S51 + 0.66666667 * S52 + 0.16666667 * S53 - 0.16666667 * S54;
        auto m03 = -0.083333333 * S01 - 0.041666667 * S02 + 0.083333333 * S03 + 0.041666667 * S04;
        auto m13 = -0.083333333 * S11 - 0.041666667 * S12 + 0.083333333 * S13 + 0.041666667 * S14;
        auto m23 = -0.083333333 * S21 - 0.041666667 * S22 + 0.083333333 * S23 + 0.041666667 * S24;
        auto m33 = -0.083333333 * S31 - 0.041666667 * S32 + 0.083333333 * S33 + 0.041666667 * S34;
        auto m43 = -0.083333333 * S41 - 0.041666667 * S42 + 0.083333333 * S43 + 0.041666667 * S44;
        auto m53 = -0.083333333 * S51 - 0.041666667 * S52 + 0.083333333 * S53 + 0.041666667 * S54;
        auto m04 = +0.083333333 * S01 - 0.041666667 * S02 - 0.083333333 * S03 + 0.041666667 * S04;
        auto m14 = +0.083333333 * S11 - 0.041666667 * S12 - 0.083333333 * S13 + 0.041666667 * S14;
        auto m24 = +0.083333333 * S21 - 0.041666667 * S22 - 0.083333333 * S23 + 0.041666667 * S24;
        auto m34 = +0.083333333 * S31 - 0.041666667 * S32 - 0.083333333 * S33 + 0.041666667 * S34;
        auto m44 = +0.083333333 * S41 - 0.041666667 * S42 - 0.083333333 * S43 + 0.041666667 * S44;
        auto m54 = +0.083333333 * S51 - 0.041666667 * S52 - 0.083333333 * S53 + 0.041666667 * S54;
        auto m05 = +4.0 * S01 - 5.0 * S03 + S05;
        auto m15 = +4.0 * S11 - 5.0 * S13 + S15;
        auto m25 = +4.0 * S21 - 5.0 * S23 + S25;
        auto m35 = +4.0 * S31 - 5.0 * S33 + S35;
        auto m45 = +4.0 * S41 - 5.0 * S43 + S45;
        auto m55 = +4.0 * S51 - 5.0 * S53 + S55;

        // Tiles are grouped in fours along the row; each of the 36 transformed elements occupies
        // its own plane of `stride` ftype4 values (see layout formula above).
        int dst_x_origin = pos.z;
        int dst_y_origin = cst.unit_width * pos.y + pos.x;  // linear tile index
        int dst_y_stride = cst.input_shape.z * 4;
        int dst_y = dst_y_origin / 4;
        int dst_x = dst_y_origin % 4 + 4 * dst_x_origin;
        int src_height = UP_DIV(cst.unit_width * cst.unit_height, 4);
        int stride = src_height * dst_y_stride;
        auto xy_out = out + dst_y * dst_y_stride + dst_x;

        // Pass 2: apply B^T along x for each of the 6 rows, writing t = 6 * row + k in order.
        *xy_out = +m00 - 1.25 * m20 + 0.25 * m40;
        xy_out += stride; *xy_out = +0.66666667 * m10 + 0.66666667 * m20 - 0.16666667 * m30 - 0.16666667 * m40;
        xy_out += stride; *xy_out = -0.66666667 * m10 + 0.66666667 * m20 + 0.16666667 * m30 - 0.16666667 * m40;
        xy_out += stride; *xy_out = -0.083333333 * m10 - 0.041666667 * m20 + 0.083333333 * m30 + 0.041666667 * m40;
        xy_out += stride; *xy_out = +0.083333333 * m10 - 0.041666667 * m20 - 0.083333333 * m30 + 0.041666667 * m40;
        xy_out += stride; *xy_out = +4.0 * m10 - 5.0 * m30 + m50;
        xy_out += stride; *xy_out = +m01 - 1.25 * m21 + 0.25 * m41;
        xy_out += stride; *xy_out = +0.66666667 * m11 + 0.66666667 * m21 - 0.16666667 * m31 - 0.16666667 * m41;
        xy_out += stride; *xy_out = -0.66666667 * m11 + 0.66666667 * m21 + 0.16666667 * m31 - 0.16666667 * m41;
        xy_out += stride; *xy_out = -0.083333333 * m11 - 0.041666667 * m21 + 0.083333333 * m31 + 0.041666667 * m41;
        xy_out += stride; *xy_out = +0.083333333 * m11 - 0.041666667 * m21 - 0.083333333 * m31 + 0.041666667 * m41;
        xy_out += stride; *xy_out = +4.0 * m11 - 5.0 * m31 + m51;
        xy_out += stride; *xy_out = +m02 - 1.25 * m22 + 0.25 * m42;
        xy_out += stride; *xy_out = +0.66666667 * m12 + 0.66666667 * m22 - 0.16666667 * m32 - 0.16666667 * m42;
        xy_out += stride; *xy_out = -0.66666667 * m12 + 0.66666667 * m22 + 0.16666667 * m32 - 0.16666667 * m42;
        xy_out += stride; *xy_out = -0.083333333 * m12 - 0.041666667 * m22 + 0.083333333 * m32 + 0.041666667 * m42;
        xy_out += stride; *xy_out = +0.083333333 * m12 - 0.041666667 * m22 - 0.083333333 * m32 + 0.041666667 * m42;
        xy_out += stride; *xy_out = +4.0 * m12 - 5.0 * m32 + m52;
        xy_out += stride; *xy_out = +m03 - 1.25 * m23 + 0.25 * m43;
        xy_out += stride; *xy_out = +0.66666667 * m13 + 0.66666667 * m23 - 0.16666667 * m33 - 0.16666667 * m43;
        xy_out += stride; *xy_out = -0.66666667 * m13 + 0.66666667 * m23 + 0.16666667 * m33 - 0.16666667 * m43;
        xy_out += stride; *xy_out = -0.083333333 * m13 - 0.041666667 * m23 + 0.083333333 * m33 + 0.041666667 * m43;
        xy_out += stride; *xy_out = +0.083333333 * m13 - 0.041666667 * m23 - 0.083333333 * m33 + 0.041666667 * m43;
        xy_out += stride; *xy_out = +4.0 * m13 - 5.0 * m33 + m53;
        xy_out += stride; *xy_out = +m04 - 1.25 * m24 + 0.25 * m44;
        xy_out += stride; *xy_out = +0.66666667 * m14 + 0.66666667 * m24 - 0.16666667 * m34 - 0.16666667 * m44;
        xy_out += stride; *xy_out = -0.66666667 * m14 + 0.66666667 * m24 + 0.16666667 * m34 - 0.16666667 * m44;
        xy_out += stride; *xy_out = -0.083333333 * m14 - 0.041666667 * m24 + 0.083333333 * m34 + 0.041666667 * m44;
        xy_out += stride; *xy_out = +0.083333333 * m14 - 0.041666667 * m24 - 0.083333333 * m34 + 0.041666667 * m44;
        xy_out += stride; *xy_out = +4.0 * m14 - 5.0 * m34 + m54;
        xy_out += stride; *xy_out = +m05 - 1.25 * m25 + 0.25 * m45;
        xy_out += stride; *xy_out = +0.66666667 * m15 + 0.66666667 * m25 - 0.16666667 * m35 - 0.16666667 * m45;
        xy_out += stride; *xy_out = -0.66666667 * m15 + 0.66666667 * m25 + 0.16666667 * m35 - 0.16666667 * m45;
        xy_out += stride; *xy_out = -0.083333333 * m15 - 0.041666667 * m25 + 0.083333333 * m35 + 0.041666667 * m45;
        xy_out += stride; *xy_out = +0.083333333 * m15 - 0.041666667 * m25 - 0.083333333 * m35 + 0.041666667 * m45;
        xy_out += stride; *xy_out = +4.0 * m15 - 5.0 * m35 + m55;
    }
}

// Winograd F(2,3) input transform: computes B^T * d * B for a 4x4 input tile d.
//
// B^T (points {0, 1, -1, inf}, middle rows scaled by 1/2):
//   [ 1,    0,   -1,   0 ]
//   [ 0,  1/2,  1/2,   0 ]
//   [ 0, -1/2,  1/2,   0 ]
//   [ 0,   -1,    0,   1 ]
// Naming, thread mapping and output layout match winograd_transform_source2_5_1, with 16
// transformed elements per tile instead of 36.
kernel void winograd_transform_source2_3_1(const device ftype4 *in [[buffer(0)]],
                                           device ftype4 *out [[buffer(1)]],
                                           constant winograd_constants &cst [[buffer(2)]],
                                           uint3 gid [[thread_position_in_grid]]) {
    auto pos = int3(gid);
    if (pos.x < cst.unit_width && pos.y < cst.unit_height) {
        // Top-left corner of this tile in input coordinates (may be negative because of padding).
        int ix = pos.x * cst.unit - cst.pad_x;
        int iy = pos.y * cst.unit - cst.pad_y;

        auto z_in = in + pos.z * cst.input_shape.x * cst.input_shape.y;
        auto S00 = get_input(z_in, ix + 0, iy + 0, cst);
        auto S10 = get_input(z_in, ix + 1, iy + 0, cst);
        auto S20 = get_input(z_in, ix + 2, iy + 0, cst);
        auto S30 = get_input(z_in, ix + 3, iy + 0, cst);
        auto S01 = get_input(z_in, ix + 0, iy + 1, cst);
        auto S11 = get_input(z_in, ix + 1, iy + 1, cst);
        auto S21 = get_input(z_in, ix + 2, iy + 1, cst);
        auto S31 = get_input(z_in, ix + 3, iy + 1, cst);
        auto S02 = get_input(z_in, ix + 0, iy + 2, cst);
        auto S12 = get_input(z_in, ix + 1, iy + 2, cst);
        auto S22 = get_input(z_in, ix + 2, iy + 2, cst);
        auto S32 = get_input(z_in, ix + 3, iy + 2, cst);
        auto S03 = get_input(z_in, ix + 0, iy + 3, cst);
        auto S13 = get_input(z_in, ix + 1, iy + 3, cst);
        auto S23 = get_input(z_in, ix + 2, iy + 3, cst);
        auto S33 = get_input(z_in, ix + 3, iy + 3, cst);

        // Pass 1: apply B^T along y for each of the 4 columns.
        auto m00 = +S00 - S02;
        auto m10 = +S10 - S12;
        auto m20 = +S20 - S22;
        auto m30 = +S30 - S32;
        auto m01 = +0.5 * S01 + 0.5 * S02;
        auto m11 = +0.5 * S11 + 0.5 * S12;
        auto m21 = +0.5 * S21 + 0.5 * S22;
        auto m31 = +0.5 * S31 + 0.5 * S32;
        auto m02 = -0.5 * S01 + 0.5 * S02;
        auto m12 = -0.5 * S11 + 0.5 * S12;
        auto m22 = -0.5 * S21 + 0.5 * S22;
        auto m32 = -0.5 * S31 + 0.5 * S32;
        auto m03 = -S01 + S03;
        auto m13 = -S11 + S13;
        auto m23 = -S21 + S23;
        auto m33 = -S31 + S33;

        int dst_x_origin = pos.z;
        int dst_y_origin = cst.unit_width * pos.y + pos.x;  // linear tile index
        int dst_y_stride = cst.input_shape.z * 4;
        int dst_y = dst_y_origin / 4;
        int dst_x = dst_y_origin % 4 + 4 * dst_x_origin;
        int src_height = UP_DIV(cst.unit_width * cst.unit_height, 4);
        int stride = src_height * dst_y_stride;
        auto xy_out = out + dst_y * dst_y_stride + dst_x;

        // Pass 2: apply B^T along x for each of the 4 rows, writing t = 4 * row + k in order.
        *xy_out = +m00 - m20;
        xy_out += stride; *xy_out = +0.5 * m10 + 0.5 * m20;
        xy_out += stride; *xy_out = -0.5 * m10 + 0.5 * m20;
        xy_out += stride; *xy_out = -m10 + m30;
        xy_out += stride; *xy_out = +m01 - m21;
        xy_out += stride; *xy_out = +0.5 * m11 + 0.5 * m21;
        xy_out += stride; *xy_out = -0.5 * m11 + 0.5 * m21;
        xy_out += stride; *xy_out = -m11 + m31;
        xy_out += stride; *xy_out = +m02 - m22;
        xy_out += stride; *xy_out = +0.5 * m12 + 0.5 * m22;
        xy_out += stride; *xy_out = -0.5 * m12 + 0.5 * m22;
        xy_out += stride; *xy_out = -m12 + m32;
        xy_out += stride; *xy_out = +m03 - m23;
        xy_out += stride; *xy_out = +0.5 * m13 + 0.5 * m23;
        xy_out += stride; *xy_out = -0.5 * m13 + 0.5 * m23;
        xy_out += stride; *xy_out = -m13 + m33;
    }
}

// Writes activate(value, cst.activation) to pixel (x, y) of one output slice. No bounds check:
// callers must guarantee (x, y) lies inside output_shape.
static inline void set_output(constant winograd_constants &cst, device ftype4 *output, int x, int y, ftype4 value) {
    output[y * cst.output_shape.x + x] = activate(value, cst.activation);
}

// Winograd F(2,5) output transform: computes A^T * M * A for a 6x6 product tile M and writes
// the resulting 2x2 pixels (plus bias, then activation).
//
// A^T:
//   [ 1, 1,  1, 1,  1, 0 ]
//   [ 0, 1, -1, 2, -2, 1 ]
//
// Input layout: element t (0..35) of linear tile index `tile` for output slice z is read from
//   in[(4 * z + tile % 4) * (dst_w * 36) + tile / 4 + t * dst_w],  dst_w = UP_DIV(tiles, 4).
kernel void winograd_transform_dest2_5_1(const device ftype4 *in [[buffer(0)]],
                                         const device ftype4 *biasTerms [[buffer(1)]],
                                         device ftype4 *out [[buffer(2)]],
                                         constant winograd_constants &cst [[buffer(3)]],
                                         uint3 gid [[thread_position_in_grid]]) {
    auto pos = int3(gid);
    if (pos.x < cst.unit_width && pos.y < cst.unit_height) {
        int dst_w = UP_DIV(cst.unit_width * cst.unit_height, 4);
        int dst_x_origin = cst.unit_width * pos.y + pos.x;  // linear tile index
        int dst_x = dst_x_origin / 4;
        int dst_y = 4 * pos.z + dst_x_origin % 4;
        int dst_y_stride = dst_w * 36;                      // 36 transformed elements per row group
        auto xy_in = in + dst_y * dst_y_stride + dst_x;

        auto S00 = *xy_in; xy_in += dst_w;
        auto S10 = *xy_in; xy_in += dst_w;
        auto S20 = *xy_in; xy_in += dst_w;
        auto S30 = *xy_in; xy_in += dst_w;
        auto S40 = *xy_in; xy_in += dst_w;
        auto S50 = *xy_in; xy_in += dst_w;
        auto S01 = *xy_in; xy_in += dst_w;
        auto S11 = *xy_in; xy_in += dst_w;
        auto S21 = *xy_in; xy_in += dst_w;
        auto S31 = *xy_in; xy_in += dst_w;
        auto S41 = *xy_in; xy_in += dst_w;
        auto S51 = *xy_in; xy_in += dst_w;
        auto S02 = *xy_in; xy_in += dst_w;
        auto S12 = *xy_in; xy_in += dst_w;
        auto S22 = *xy_in; xy_in += dst_w;
        auto S32 = *xy_in; xy_in += dst_w;
        auto S42 = *xy_in; xy_in += dst_w;
        auto S52 = *xy_in; xy_in += dst_w;
        auto S03 = *xy_in; xy_in += dst_w;
        auto S13 = *xy_in; xy_in += dst_w;
        auto S23 = *xy_in; xy_in += dst_w;
        auto S33 = *xy_in; xy_in += dst_w;
        auto S43 = *xy_in; xy_in += dst_w;
        auto S53 = *xy_in; xy_in += dst_w;
        auto S04 = *xy_in; xy_in += dst_w;
        auto S14 = *xy_in; xy_in += dst_w;
        auto S24 = *xy_in; xy_in += dst_w;
        auto S34 = *xy_in; xy_in += dst_w;
        auto S44 = *xy_in; xy_in += dst_w;
        auto S54 = *xy_in; xy_in += dst_w;
        auto S05 = *xy_in; xy_in += dst_w;
        auto S15 = *xy_in; xy_in += dst_w;
        auto S25 = *xy_in; xy_in += dst_w;
        auto S35 = *xy_in; xy_in += dst_w;
        auto S45 = *xy_in; xy_in += dst_w;
        auto S55 = *xy_in;

        // Pass 1: apply A^T along y (6 rows -> 2) for each of the 6 columns.
        auto m00 = +S00 + S01 + S02 + S03 + S04;
        auto m10 = +S10 + S11 + S12 + S13 + S14;
        auto m20 = +S20 + S21 + S22 + S23 + S24;
        auto m30 = +S30 + S31 + S32 + S33 + S34;
        auto m40 = +S40 + S41 + S42 + S43 + S44;
        auto m50 = +S50 + S51 + S52 + S53 + S54;
        auto m01 = +S01 - S02 + 2.0 * S03 - 2.0 * S04 + S05;
        auto m11 = +S11 - S12 + 2.0 * S13 - 2.0 * S14 + S15;
        auto m21 = +S21 - S22 + 2.0 * S23 - 2.0 * S24 + S25;
        auto m31 = +S31 - S32 + 2.0 * S33 - 2.0 * S34 + S35;
        auto m41 = +S41 - S42 + 2.0 * S43 - 2.0 * S44 + S45;
        auto m51 = +S51 - S52 + 2.0 * S53 - 2.0 * S54 + S55;

        // write output
        auto b4 = biasTerms[int(pos.z)];
        int oy = pos.y * cst.unit;
        int ox = pos.x * cst.unit;
        auto z_out = out + pos.z * cst.output_shape.x * cst.output_shape.y;

        // Pass 2: apply A^T along x. The top-left pixel is written unconditionally; this assumes
        // unit_width/unit_height are ceil(output / unit) so every tile origin is inside the
        // output (TODO confirm against the host code). The other three pixels are clipped at the
        // right/bottom edge.
        /* if true */ {
            set_output(cst, z_out, ox + 0, oy + 0, b4 + m00 + m10 + m20 + m30 + m40);
        }
        if (ox + 1 < cst.output_shape.x) {
            set_output(cst, z_out, ox + 1, oy + 0, b4 + m10 - m20 + 2.0 * m30 - 2.0 * m40 + m50);
        }
        if (oy + 1 < cst.output_shape.y) {
            set_output(cst, z_out, ox + 0, oy + 1, b4 + m01 + m11 + m21 + m31 + m41);
        }
        if (ox + 1 < cst.output_shape.x && oy + 1 < cst.output_shape.y) {
            set_output(cst, z_out, ox + 1, oy + 1, b4 + m11 - m21 + 2.0 * m31 - 2.0 * m41 + m51);
        }
    }
}

// Winograd F(2,3) output transform: computes A^T * M * A for a 4x4 product tile M and writes
// the resulting 2x2 pixels (plus bias, then activation).
//
// A^T:
//   [ 1, 1,  1, 0 ]
//   [ 0, 1, -1, 1 ]
//
// Input layout matches winograd_transform_dest2_5_1 with 16 elements per tile instead of 36.
kernel void winograd_transform_dest2_3_1(const device ftype4 *in [[buffer(0)]],
                                         const device ftype4 *biasTerms [[buffer(1)]],
                                         device ftype4 *out [[buffer(2)]],
                                         constant winograd_constants &cst [[buffer(3)]],
                                         uint3 gid [[thread_position_in_grid]]) {
    auto pos = int3(gid);
    if (pos.x < cst.unit_width && pos.y < cst.unit_height) {
        int dst_w = UP_DIV(cst.unit_width * cst.unit_height, 4);
        int dst_x_origin = cst.unit_width * pos.y + pos.x;  // linear tile index
        int dst_x = dst_x_origin / 4;
        int dst_y = 4 * pos.z + dst_x_origin % 4;
        int dst_y_stride = dst_w * 16;                      // 16 transformed elements per row group
        auto xy_in = in + dst_y * dst_y_stride + dst_x;

        auto S00 = *xy_in; xy_in += dst_w;
        auto S10 = *xy_in; xy_in += dst_w;
        auto S20 = *xy_in; xy_in += dst_w;
        auto S30 = *xy_in; xy_in += dst_w;
        auto S01 = *xy_in; xy_in += dst_w;
        auto S11 = *xy_in; xy_in += dst_w;
        auto S21 = *xy_in; xy_in += dst_w;
        auto S31 = *xy_in; xy_in += dst_w;
        auto S02 = *xy_in; xy_in += dst_w;
        auto S12 = *xy_in; xy_in += dst_w;
        auto S22 = *xy_in; xy_in += dst_w;
        auto S32 = *xy_in; xy_in += dst_w;
        auto S03 = *xy_in; xy_in += dst_w;
        auto S13 = *xy_in; xy_in += dst_w;
        auto S23 = *xy_in; xy_in += dst_w;
        auto S33 = *xy_in;

        // Pass 1: apply A^T along y (4 rows -> 2) for each of the 4 columns.
        auto m00 = +S00 + S01 + S02;
        auto m10 = +S10 + S11 + S12;
        auto m20 = +S20 + S21 + S22;
        auto m30 = +S30 + S31 + S32;
        auto m01 = +S01 - S02 + S03;
        auto m11 = +S11 - S12 + S13;
        auto m21 = +S21 - S22 + S23;
        auto m31 = +S31 - S32 + S33;

        // write output
        auto b4 = biasTerms[int(pos.z)];
        int oy = pos.y * cst.unit;
        int ox = pos.x * cst.unit;
        auto z_out = out + pos.z * cst.output_shape.x * cst.output_shape.y;

        // Pass 2: apply A^T along x. Same edge handling as winograd_transform_dest2_5_1: the
        // top-left pixel is written unconditionally (assumes tile counts are ceil(output / unit) --
        // TODO confirm); the remaining pixels are clipped at the right/bottom edge.
        /* if true */ {
            set_output(cst, z_out, ox + 0, oy + 0, b4 + m00 + m10 + m20);
        }
        if (ox + 1 < cst.output_shape.x) {
            set_output(cst, z_out, ox + 1, oy + 0, b4 + m10 - m20 + m30);
        }
        if (oy + 1 < cst.output_shape.y) {
            set_output(cst, z_out, ox + 0, oy + 1, b4 + m01 + m11 + m21);
        }
        if (ox + 1 < cst.output_shape.x && oy + 1 < cst.output_shape.y) {
            set_output(cst, z_out, ox + 1, oy + 1, b4 + m11 - m21 + m31);
        }
    }
}