1/* 2 This file is part of Leela Zero. 3 Copyright (C) 2017-2019 Gian-Carlo Pascutto and contributors 4 5 Leela Zero is free software: you can redistribute it and/or modify 6 it under the terms of the GNU General Public License as published by 7 the Free Software Foundation, either version 3 of the License, or 8 (at your option) any later version. 9 10 Leela Zero is distributed in the hope that it will be useful, 11 but WITHOUT ANY WARRANTY; without even the implied warranty of 12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 GNU General Public License for more details. 14 15 You should have received a copy of the GNU General Public License 16 along with Leela Zero. If not, see <http://www.gnu.org/licenses/>. 17 18 Additional permission under GNU GPL version 3 section 7 19 20 If you modify this Program, or any covered work, by linking or 21 combining it with NVIDIA Corporation's libraries from the 22 NVIDIA CUDA Toolkit and/or the NVIDIA CUDA Deep Neural 23 Network library and/or the NVIDIA TensorRT inference library 24 (or a modified version of those libraries), containing parts covered 25 by the terms of the respective license agreement, the licensors of 26 this Program grant you additional permission to convey the resulting 27 work. 28*/ 29 30// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string 31// literal). Comment-out this line for syntax-highlighting when developing. 
R"(

// Default work-group shapes. Each macro can be predefined by the build to
// override the value given here.

// Local size in dimension 0 (output channels per work-group) that
// out_transform_fused_bn_in sizes its local board buffer for.
#ifndef OUTIN_KWG
#define OUTIN_KWG 2
#endif

// Required local size of out_transform_fused_bn in dimension 0 (channels).
#ifndef OUT_KWG
#define OUT_KWG 32
#endif

// Required local size of out_transform_fused_bn in dimension 1 (tiles).
#ifndef OUT_BWG
#define OUT_BWG 2
#endif

// NOTE: the hand-unrolled helpers below take exactly six inputs and produce
// six (input transform) or four (output transform) results, so the code in
// this file only makes sense for WINOGRAD_ALPHA == 6 and WINOGRAD_M == 4.

// Input transform matrix transpose(B), row-major. Read directly by the
// WINOGRAD_SIMD path of __in_transform_eq; multiply_bt hard-codes the same
// coefficients.
__constant real Bt[WINOGRAD_ALPHA * WINOGRAD_ALPHA] =
                   {1.0f,  0.0f,     -5.0f/2.0f,  0.0f,      1.0f, 0.0f,
                    0.0f, -SQ2,      -2.0f,       SQ2/2.0f,  1.0f, 0.0f,
                    0.0f,  SQ2,      -2.0f,      -SQ2/2.0f,  1.0f, 0.0f,
                    0.0f, -SQ2/2.0f, -1.0f/2.0f,  SQ2,       1.0f, 0.0f,
                    0.0f,  SQ2/2.0f, -1.0f/2.0f, -SQ2,       1.0f, 0.0f,
                    0.0f,  1.0f,      0.0f,      -5.0f/2.0f, 0.0f, 1.0f};

// Computes (o0..o5) = Bt * (i0..i5) without touching the Bt table.
// Rows 1/2 and rows 3/4 of Bt differ only in the sign of their odd-index
// terms, so each pair shares one odd and one even partial sum.
void multiply_bt(
    real * o0, real * o1, real * o2, real * o3, real * o4, real * o5,
    real i0, real i1, real i2, real i3, real i4, real i5
) {
    const real odd_12 = i1 * -SQ2 + i3 * (SQ2 / 2.0f);
    const real even_12 = i2 * -2.0f + i4 * 1.0f;

    const real odd_34 = i3 * (SQ2) + i1 * (-SQ2/2.0f);
    const real even_34 = i2 * (-1.0f/2.0f) + i4;

    *o0 = i0 + i2 * (-5.0f/2.0f) + i4;
    *o1 = odd_12 + even_12;
    *o2 = -odd_12 + even_12;
    *o3 = odd_34 + even_34;
    *o4 = -odd_34 + even_34;
    *o5 = i1 + i3 * (-5.0f/2.0f) + i5;
}


// Output transform matrix transpose(A), row-major. multiply_atv hard-codes
// the same coefficients; the table itself is not referenced in this file.
__constant real At[WINOGRAD_M * WINOGRAD_ALPHA] =
                   {1.0f, 1.0f,      1.0f,       1.0f,      1.0f,     0.0f,
                    0.0f, SQ2/2.0f, -SQ2/2.0f,   SQ2,      -SQ2,      0.0f,
                    0.0f, 1.0f/2.0f, 1.0f/2.0f,  2.0f,      2.0f,     0.0f,
                    0.0f, SQ2/4.0f, -SQ2/4.0f,   2.0f*SQ2, -2.0f*SQ2, 1.0f};

// Computes *o = At * (i0..i5), returning the four results as one vector.
// The partial sums are pre-scaled so that every row of At is reached with
// additions only.
void multiply_atv(
    real4 * o,
    real i0, real i1, real i2, real i3, real i4, real i5
) {
    const real half_sum12 = (i1 + i2) * (1.0f / 2.0f);
    const real scaled_diff12 = (i1 - i2) * (SQ2/4.0f);
    const real sum34 = i3 + i4;
    const real scaled_diff34 = (i3 - i4) * (SQ2);

    o->x = i0 + half_sum12 + half_sum12 + sum34;
    o->y = scaled_diff12 + scaled_diff12 + scaled_diff34;
    o->z = half_sum12 + sum34 + sum34;
    o->w = scaled_diff12 + scaled_diff34 + scaled_diff34 + i5;
}


// Scalar-output wrapper around multiply_atv: writes the four results
// through separate pointers instead of a real4.
void multiply_at(
    real * o0, real * o1, real * o2, real * o3,
    real i0, real i1, real i2, real i3, real i4, real i5
) {
    real4 res;
    multiply_atv(&res, i0, i1, i2, i3, i4, i5);

    *o0 = res.x;
    *o1 = res.y;
    *o2 = res.z;
    *o3 = res.w;
}

// Applies the input transform to one tile and scatters the result into V.
//
//   x      - input tile, already transposed by the caller (x[column][row]).
//   V      - destination buffer.
//   offset - position of this (channel, tile) pair inside one element plane.
//   CPpad  - stride between consecutive element planes of V.
//
// Element (i, j) of the transformed tile is stored at
// V[(i * WINOGRAD_ALPHA + j) * CPpad + offset].
void __in_transform_eq(real x[WINOGRAD_ALPHA][WINOGRAD_ALPHA], __global net_t * restrict V, int offset, int CPpad) {
    real T1[WINOGRAD_ALPHA][WINOGRAD_ALPHA];
    real T2[WINOGRAD_ALPHA][WINOGRAD_ALPHA];

    // Calculates transpose(B).x.B
#ifdef WINOGRAD_SIMD
    // First product: T1[i][j] = dot(row i of Bt, x[j]), two lanes at a time.
    for (int i = 0; i < WINOGRAD_ALPHA; i++) {
        for (int j = 0; j < WINOGRAD_ALPHA; j++) {
            real2 acc = {ZERO, ZERO};
            real2 *x_pairs = (real2 *)&x[j][0];
            for (int k = 0; k < WINOGRAD_ALPHA/2; k++) {
                real2 bt_pair;
                bt_pair.x = Bt[i * WINOGRAD_ALPHA + 2*k];
                bt_pair.y = Bt[i * WINOGRAD_ALPHA + 2*k + 1];
                acc += bt_pair * x_pairs[k];
            }
            T1[i][j] = acc.x + acc.y;
        }
    }

    // Second product: T2[i][j] = dot(T1[i], row j of Bt).
    for (int i = 0; i < WINOGRAD_ALPHA; i++) {
        for (int j = 0; j < WINOGRAD_ALPHA; j++) {
            real2 acc = {ZERO, ZERO};
            real2 *t1_pairs = (real2 *)&T1[i][0];
            for (int k = 0; k < WINOGRAD_ALPHA/2; k++) {
                real2 bt_pair;
                bt_pair.x = Bt[j * WINOGRAD_ALPHA + 2*k];
                bt_pair.y = Bt[j * WINOGRAD_ALPHA + 2*k + 1];
                acc += t1_pairs[k] * bt_pair;
            }
            T2[i][j] = acc.x + acc.y;
        }
    }
#else
    // First product: column j of T1 is Bt applied to x[j].
    for (int j = 0; j < WINOGRAD_ALPHA; j++) {
        multiply_bt(
            &(T1[0][j]), &(T1[1][j]), &(T1[2][j]), &(T1[3][j]), &(T1[4][j]), &(T1[5][j]),
            x[j][0], x[j][1], x[j][2], x[j][3], x[j][4], x[j][5]
        );
    }

    // Second product: row i of T2 is Bt applied to row i of T1.
    for (int i = 0; i < WINOGRAD_ALPHA; i++) {
        multiply_bt(
            &(T2[i][0]), &(T2[i][1]), &(T2[i][2]), &(T2[i][3]), &(T2[i][4]), &(T2[i][5]),
            T1[i][0], T1[i][1], T1[i][2], T1[i][3], T1[i][4], T1[i][5]
        );
    }
#endif

    // Scatter each sub element in tile to separate matrices
    for (int i = 0; i < WINOGRAD_ALPHA; i++) {
        for (int j = 0; j < WINOGRAD_ALPHA; j++) {
            vstore_net_t(T2[i][j], (i*WINOGRAD_ALPHA + j)*CPpad + offset, V);
        }
    }
}

// Input transform kernel: one work-item per (tile, channel) pair.
//
//   in         - input planes, indexed [batch][channel][board position].
//   V          - transformed output, see __in_transform_eq for the layout.
//   C          - number of input channels actually present.
//   Cpad, Ppad - padded channel / tile counts that define the strides of V.
//   batch_size - number of boards in the batch.
//
// Global id 0 enumerates tiles over the whole batch (batch * P + tile),
// global id 1 enumerates channels.
__kernel void in_transform(__global net_t * restrict in, __global net_t * restrict V,
                           const int C, const int Cpad,
                           const int Ppad, const int batch_size) {
    const int W = BOARD_SIZE;
    const int H = BOARD_SIZE;
    const int P = WTILES * WTILES;
    const int CPpad = Ppad * Cpad;

    const int block = get_global_id(0);
    const int ch = get_global_id(1);

    // Work-items beyond the real tile / channel range have nothing to do.
    // Returning early is safe because this kernel contains no barrier.
    if (block >= batch_size * P || ch >= C) {
        return;
    }

    const int batch = block / P;
    const int tile = block - P * batch;
    const int tile_x = tile % WTILES;
    const int tile_y = tile / WTILES;

    // The input tile starts one intersection above and to the left of the
    // WINOGRAD_M-aligned output tile, so neighbouring input tiles overlap.
    const int y_first = WINOGRAD_M * tile_y - 1;
    const int x_first = WINOGRAD_M * tile_x - 1;

    const int plane_base = batch * C * NUM_INTERSECTIONS + ch * NUM_INTERSECTIONS;

    // Cache input tile and handle zero padding
    real tile_in[WINOGRAD_ALPHA][WINOGRAD_ALPHA];
    for (int row = 0; row < WINOGRAD_ALPHA; row++) {
        const int by = y_first + row;
        for (int col = 0; col < WINOGRAD_ALPHA; col++) {
            const int bx = x_first + col;
            // The tile is stored transposed ([col][row]) for better layout later
            if (by < 0 || bx < 0 || by >= H || bx >= W) {
                tile_in[col][row] = ZERO;
            } else {
                tile_in[col][row] = vload_net_t(plane_base + by * W + bx, in);
            }
        }
    }

    // Within one element plane of V, channel is the slow index and tile
    // position the fast one; the planes are zero padded for SGEMM.
    const int offset = ch * Ppad + block;
    __in_transform_eq(tile_in, V, offset, CPpad);
}

// Output transform fused with batch-norm scaling, optional residual add and
// ReLU.
//
//   M          - SGEMM result. Element e of tile position p and channel k is
//                read from M[e * Kpad * Ppad + p * Kpad + k].
//   Y          - output planes, indexed [batch][channel][board position].
//   K          - number of output channels actually present.
//   Kpad, Ppad - padded channel / tile counts that define the strides of M.
//   batch_size - number of boards in the batch.
//   residual   - optional planes added before the ReLU; skipped when null.
//   means      - per-channel value subtracted from the transform result.
//   stddivs    - per-channel factor applied after the subtraction.
//
// Global id 0 enumerates channels, global id 1 enumerates tiles over the
// whole batch (batch * P + tile).
__kernel __attribute__((reqd_work_group_size(OUT_KWG, OUT_BWG, 1)))
void out_transform_fused_bn(__global const net_t * restrict M,
                            __global net_t * restrict Y,
                            const int K,
                            const int Kpad, const int Ppad,
                            const int batch_size,
                            __global const net_t * restrict residual,
                            __constant const net_t * restrict means,
                            __constant const net_t * restrict stddivs) {

    const int W = BOARD_SIZE;
    const int H = BOARD_SIZE;
    const int P = WTILES * WTILES;

    const int k = get_global_id(0);
    const int block = get_global_id(1);

    // Adding some padding decreases bank conflicts
    __local real out_buf[OUT_KWG][OUT_BWG][WINOGRAD_M][WINOGRAD_M + 1];

    // NOTE(review): the volatile qualifiers are kept exactly as found; they
    // look like a compiler workaround, so confirm before removing them.
    volatile int kid = get_local_id(0);
    volatile int bid = get_local_id(1);

    // Phase 1: every in-range work-item transforms its own tile into out_buf.
    if (k < K && block < batch_size * P) {
        const real mean = vload_net_t(k, means);
        const real scale_stddiv = vload_net_t(k, stddivs);

        real temp[WINOGRAD_M][WINOGRAD_ALPHA];

        const int offset = block * Kpad + k;

        // Calculates transpose(A).temp_m
        for (int xn = 0; xn < WINOGRAD_ALPHA; xn++) {
            real column[WINOGRAD_ALPHA];
            for (int e = 0; e < WINOGRAD_ALPHA; e++) {
                column[e] = vload_net_t((e * WINOGRAD_ALPHA + xn) * Kpad * Ppad + offset, M);
            }
            multiply_at(
                &(temp[0][xn]), &(temp[1][xn]), &(temp[2][xn]), &(temp[3][xn]),
                column[0], column[1], column[2], column[3], column[4], column[5]
            );
        }

        // Calculates temp.A
        for (int i = 0; i < WINOGRAD_M; i++) {
            real4 r;
            multiply_atv(
                &r,
                temp[i][0], temp[i][1], temp[i][2], temp[i][3], temp[i][4], temp[i][5]
            );

            r = (r - mean) * scale_stddiv;
            out_buf[kid][bid][i][0] = r.x;
            out_buf[kid][bid][i][1] = r.y;
            out_buf[kid][bid][i][2] = r.z;
            out_buf[kid][bid][i][3] = r.w;
        }
    }

    // All tiles of the work-group must be in out_buf before they are
    // redistributed below.
    barrier(CLK_LOCAL_MEM_FENCE);

    // Phase 2: the work-group walks its out_buf contents with a flat index
    // so that consecutive work-items handle consecutive columns of a row.
    const int tile_elems = WINOGRAD_M * WINOGRAD_M;
    const int row_span = WINOGRAD_M * OUT_BWG;

    for (int idx = get_local_id(0) + get_local_size(0) * get_local_id(1);
         idx < OUT_BWG * OUT_KWG * WINOGRAD_M * WINOGRAD_M;
         idx += get_local_size(0) * get_local_size(1)) {
        // idx = ((k_local * WINOGRAD_M + i) * OUT_BWG + block_local) * WINOGRAD_M + j
        const int k_local = idx / (OUT_BWG * tile_elems);
        const int within_k = idx % (OUT_BWG * tile_elems);

        const int i = within_k / row_span;
        const int within_row = within_k % row_span;

        const int block_local = within_row / WINOGRAD_M;
        const int j = within_row % WINOGRAD_M;

        const int blockt = get_group_id(1) * get_local_size(1) + block_local;
        const int kt = get_group_id(0) * get_local_size(0) + k_local;

        const int batch = blockt / P;
        const int tile = blockt - P * batch;
        const int x = WINOGRAD_M * (tile % WTILES);
        const int y = WINOGRAD_M * (tile / WTILES);

        // Skip padding work-items and the parts of edge tiles that fall off
        // the board.
        if (kt < K && blockt < batch_size * P && y + i < H && x + j < W) {
            const int out_idx = batch * K * NUM_INTERSECTIONS
                              + kt * NUM_INTERSECTIONS
                              + (y + i) * W + (x + j);

            real acc = out_buf[k_local][block_local][i][j];
            if (residual) {
                acc += vload_net_t(out_idx, residual);
            }
            // ReLU
            acc = acc > ZERO ? acc : ZERO;

            vstore_net_t(acc, out_idx, Y);
        }
    }
}

// Output transform fused with batch-norm scaling, optional residual add,
// ReLU, and the input transform of the following convolution.
//
//   M          - SGEMM result. Element e of tile position p and channel k is
//                read from M[e * Kpad * Ppad + p * Kpad + k], where
//                p = batch * P + tile.
//   Y          - optional output planes [batch][channel][board position];
//                the store is skipped when null.
//   V          - input-transformed result, see __in_transform_eq.
//   K          - number of channels actually present.
//   Kpad, Ppad - padded channel / tile counts that define the strides of M.
//   Cpad       - padded channel count that defines the plane stride of V.
//   residual   - optional planes added before the ReLU; skipped when null.
//   means      - per-channel value subtracted from the transform result.
//   stddivs    - per-channel factor applied after the subtraction.
//
// Global id 0 enumerates channels, global id 1 enumerates the tiles of one
// board, global id 2 enumerates the boards of the batch.
//
// NOTE(review): the whole board of a channel is assembled in local memory,
// which assumes a work-group spans every tile in dimension 1 and at most
// OUTIN_KWG channels in dimension 0 - confirm against the host launch.
__kernel void out_transform_fused_bn_in(
                                     __global const net_t * restrict M,
                                     __global net_t * restrict Y,
                                     __global net_t * restrict V,
                                     const int K,
                                     const int Kpad, const int Ppad, const int Cpad,
                                     __global const net_t * restrict residual,
                                     __constant const net_t * restrict means,
                                     __constant const net_t * restrict stddivs) {

    const int W = BOARD_SIZE;
    const int H = BOARD_SIZE;
    const int P = WTILES * WTILES;

    const int k = get_global_id(0);
    const int kg = get_local_id(0);
    const int block = get_global_id(1);
    const int batch = get_global_id(2);

    const int block_x = block % WTILES;
    const int block_y = block / WTILES;

    // Top-left board coordinate of this work-item's output tile.
    const int x = WINOGRAD_M * block_x;
    const int y = WINOGRAD_M * block_y;

    // One full board per channel handled by this work-group.
    __local real ybuf[OUTIN_KWG * NUM_INTERSECTIONS];

    // Phase 1: output transform of this work-item's tile into ybuf.
    if (k < K && block < P) {

        const real mean = vload_net_t(k, means);
        const real scale_stddiv = vload_net_t(k, stddivs);

        real temp[WINOGRAD_M][WINOGRAD_ALPHA];

        // The tile position inside M must include the batch, exactly as the
        // V offset at the end of this kernel does; without it every batch
        // entry would re-read the tiles of batch 0.
        const int offset = (P * batch + block) * Kpad + k;

        // Calculates transpose(A).temp_m
        for (int xn = 0; xn < WINOGRAD_ALPHA; xn++) {
            real column[WINOGRAD_ALPHA];
            for (int e = 0; e < WINOGRAD_ALPHA; e++) {
                column[e] = vload_net_t((e * WINOGRAD_ALPHA + xn) * Kpad * Ppad + offset, M);
            }

            multiply_at(
                &(temp[0][xn]), &(temp[1][xn]), &(temp[2][xn]), &(temp[3][xn]),
                column[0], column[1], column[2], column[3], column[4], column[5]
            );
        }

        // Calculates temp.A
        for (int i = 0; i < WINOGRAD_M; i++) {
            real4 r;
            multiply_atv(
                &r,
                temp[i][0], temp[i][1], temp[i][2], temp[i][3], temp[i][4], temp[i][5]
            );

            r = scale_stddiv * (r - mean);

            // Edge tiles may extend past the board; keep only on-board cells.
            if (y + i < H) {
                real out_row[4];
                out_row[0] = r.x;
                out_row[1] = r.y;
                out_row[2] = r.z;
                out_row[3] = r.w;
                for (int j = 0; j < 4; j++) {
                    if (x + j < W) {
                        const int out_idx = (y + i) * W + (x + j);
                        ybuf[kg * NUM_INTERSECTIONS + out_idx] = out_row[j];
                    }
                }
            }
        }
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Phase 2: residual add and ReLU over the assembled boards, walking
    // ybuf with a flat index shared by the whole work-group.
    const int ks = get_local_size(0);
    const int k0 = get_group_id(0) * get_local_size(0);

    for (int flat = get_local_id(0) + ks * get_local_id(1);
         flat < ks * NUM_INTERSECTIONS;
         flat += get_local_size(1) * get_local_size(0)) {
        const int kx = flat / NUM_INTERSECTIONS;
        const int idx = flat - kx * NUM_INTERSECTIONS;

        // Channels past K were never written in phase 1 and have no plane in
        // Y or residual; skip them instead of touching memory out of range.
        if (k0 + kx >= K) {
            continue;
        }

        const int kHWx = batch * K * NUM_INTERSECTIONS + (k0 + kx) * NUM_INTERSECTIONS;

        real acc = ybuf[kx * NUM_INTERSECTIONS + idx];
        if (residual) {
            acc += vload_net_t(kHWx + idx, residual);
        }
        // ReLU
        acc = acc > ZERO ? acc : ZERO;

        if (Y) {
            vstore_net_t(acc, kHWx + idx, Y);
        }
        // Keep the activated value: phase 3 reads it back as its input.
        ybuf[kx * NUM_INTERSECTIONS + idx] = acc;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Phase 3: input transform for the next convolution, reading the
    // activated board from ybuf instead of global memory.
    const int yin = WINOGRAD_M * block_y - 1;
    const int xin = WINOGRAD_M * block_x - 1;

    if (block < P && k < K) {
        const int CPpad = Ppad * Cpad;
        // Cache input tile and handle zero padding
        real xx[WINOGRAD_ALPHA][WINOGRAD_ALPHA];
        for (int i = 0; i < WINOGRAD_ALPHA; i++) {
            const int b = yin + i;
            for (int j = 0; j < WINOGRAD_ALPHA; j++) {
                const int a = xin + j;
                // The tile is stored transposed ([j][i]) for better layout later
                if (b < 0 || a < 0 || b >= H || a >= W) {
                    xx[j][i] = ZERO;
                } else {
                    xx[j][i] = ybuf[kg * NUM_INTERSECTIONS + b * W + a];
                }
            }
        }

        const int offset = k * Ppad + P * batch + block;
        __in_transform_eq(xx, V, offset, CPpad);
    }
}

// End of the C++11 raw string literal
)"