1/*
2    This file is part of Leela Zero.
3    Copyright (C) 2017-2019 Gian-Carlo Pascutto and contributors
4
5    Leela Zero is free software: you can redistribute it and/or modify
6    it under the terms of the GNU General Public License as published by
7    the Free Software Foundation, either version 3 of the License, or
8    (at your option) any later version.
9
10    Leela Zero is distributed in the hope that it will be useful,
11    but WITHOUT ANY WARRANTY; without even the implied warranty of
12    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13    GNU General Public License for more details.
14
15    You should have received a copy of the GNU General Public License
16    along with Leela Zero.  If not, see <http://www.gnu.org/licenses/>.
17
18    Additional permission under GNU GPL version 3 section 7
19
20    If you modify this Program, or any covered work, by linking or
21    combining it with NVIDIA Corporation's libraries from the
22    NVIDIA CUDA Toolkit and/or the NVIDIA CUDA Deep Neural
23    Network library and/or the NVIDIA TensorRT inference library
24    (or a modified version of those libraries), containing parts covered
25    by the terms of the respective license agreement, the licensors of
26    this Program grant you additional permission to convey the resulting
27    work.
28*/
29
30// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string
31// literal). Comment-out this line for syntax-highlighting when developing.
32
33R"(
34
// Default work-group tuning constants. Each one is only defined here when
// it has not already been defined before this point, so a build may supply
// its own values.

// Sizes the local buffer of out_transform_fused_bn_in
// (OUTIN_KWG * NUM_INTERSECTIONS elements).
#if !defined(OUTIN_KWG)
#define OUTIN_KWG 2
#endif

// First dimension of the required work-group size of out_transform_fused_bn.
#if !defined(OUT_KWG)
#define OUT_KWG 32
#endif

// Second dimension of the required work-group size of out_transform_fused_bn.
#if !defined(OUT_BWG)
#define OUT_BWG 2
#endif
46
47__constant real Bt[WINOGRAD_ALPHA * WINOGRAD_ALPHA] = \
48                   {1.0f,  0.0f,     -5.0f/2.0f,  0.0f,      1.0f, 0.0f,
49                    0.0f, -SQ2,      -2.0f,       SQ2/2.0f,  1.0f, 0.0f,
50                    0.0f,  SQ2,      -2.0f,      -SQ2/2.0f,  1.0f, 0.0f,
51                    0.0f, -SQ2/2.0f, -1.0f/2.0f,  SQ2,       1.0f, 0.0f,
52                    0.0f,  SQ2/2.0f, -1.0f/2.0f, -SQ2,       1.0f, 0.0f,
53                    0.0f,  1.0f,      0.0f,      -5.0f/2.0f, 0.0f, 1.0f};
54void multiply_bt(
55    real * o0, real * o1, real * o2, real * o3, real * o4, real * o5,
56    real i0, real i1, real i2, real i3, real i4, real i5
57) {
58    real i3m1 = i1 * -SQ2 + i3 * (SQ2 / 2.0f);
59    real i4m2 = i2 * -2.0f + i4 * 1.0f;
60
61    *o0 = i0 + i2 * (-5.0f/2.0f) + i4;
62    *o1 = i3m1 + i4m2;
63    *o2 = -i3m1 + i4m2;
64
65    real i3m1_2 = i3 * (SQ2) + i1 * (-SQ2/2.0f);
66    real i4m2_2 = i2 * (-1.0f/2.0f) + i4;
67
68    *o3 = i3m1_2 + i4m2_2;
69    *o4 = -i3m1_2 + i4m2_2;
70
71    *o5 = i1 + i3 * (-5.0f/2.0f) + i5;
72}
73
74
75__constant real At[WINOGRAD_M * WINOGRAD_ALPHA] = \
76                   {1.0f, 1.0f,      1.0f,       1.0f,      1.0f,     0.0f,
77                    0.0f, SQ2/2.0f, -SQ2/2.0f,   SQ2,      -SQ2,      0.0f,
78                    0.0f, 1.0f/2.0f, 1.0f/2.0f,  2.0f,      2.0f,     0.0f,
79                    0.0f, SQ2/4.0f, -SQ2/4.0f,   2.0f*SQ2, -2.0f*SQ2, 1.0f};
80void multiply_atv(
81    real4 * o,
82    real i0, real i1, real i2, real i3, real i4, real i5
83) {
84    real t1p2 = (i1 + i2) * (1.0f / 2.0f);
85    real t1m2 = (i1 - i2) * (SQ2/4.0f);
86    real t3p4 = i3 + i4;
87    real t3m4 = (i3 - i4) * (SQ2);
88
89    (*o).x = i0 + t1p2 + t1p2 + t3p4;
90    (*o).y = t1m2 + t1m2 + t3m4;
91    (*o).z = t1p2 + t3p4 + t3p4;
92    (*o).w = t1m2 + t3m4 + t3m4 + i5;
93}
94
95
96void multiply_at(
97    real * o0, real * o1, real * o2, real * o3,
98    real i0, real i1, real i2, real i3, real i4, real i5
99) {
100    real4 o;
101    multiply_atv(&o, i0, i1, i2, i3, i4, i5);
102
103    *o0 = o.x;
104    *o1 = o.y;
105    *o2 = o.z;
106    *o3 = o.w;
107}
108
109void __in_transform_eq(real x[WINOGRAD_ALPHA][WINOGRAD_ALPHA], __global net_t * restrict V, int offset, int CPpad) {
110
111    const int W = BOARD_SIZE;
112    const int H = BOARD_SIZE;
113    const int P = WTILES * WTILES;
114
115    real T1[WINOGRAD_ALPHA][WINOGRAD_ALPHA];
116    real T2[WINOGRAD_ALPHA][WINOGRAD_ALPHA];
117
118    // Calculates transpose(B).x.B
119#ifdef WINOGRAD_SIMD
120    for (int i = 0; i < WINOGRAD_ALPHA; i++){
121        for (int j = 0; j < WINOGRAD_ALPHA; j++) {
122            real2 acc = {ZERO, ZERO};
123            real2 *x2 = (real2 *)&x[j][0];
124            for (int k = 0; k < WINOGRAD_ALPHA/2; k++) {
125                real2 x1;
126                x1.x = Bt[i * WINOGRAD_ALPHA + 2*k];
127                x1.y = Bt[i * WINOGRAD_ALPHA + 2*k + 1];
128                acc += x1 * x2[k];
129            }
130            T1[i][j] = acc.x + acc.y;
131        }
132    }
133#else
134    for (int j = 0; j < WINOGRAD_ALPHA; j++) {
135        multiply_bt(
136            &(T1[0][j]), &(T1[1][j]), &(T1[2][j]), &(T1[3][j]), &(T1[4][j]), &(T1[5][j]),
137            x[j][0], x[j][1], x[j][2], x[j][3], x[j][4], x[j][5]
138        );
139    }
140#endif
141
142#ifdef WINOGRAD_SIMD
143    for (int i = 0; i < WINOGRAD_ALPHA; i++){
144        for (int j = 0; j < WINOGRAD_ALPHA; j++) {
145            real2 acc = {ZERO, ZERO};
146            real2 *x1 = (real2 *)&T1[i][0];
147            for (int k = 0; k < WINOGRAD_ALPHA/2; k++) {
148                real2 x2;
149                x2.x = Bt[j * WINOGRAD_ALPHA + 2*k];
150                x2.y = Bt[j * WINOGRAD_ALPHA + 2*k + 1];
151                acc += x1[k] * x2;
152            }
153            T2[i][j] = acc.x + acc.y;
154        }
155    }
156#else
157    for (int i = 0; i < WINOGRAD_ALPHA; i++){
158        multiply_bt(
159            &(T2[i][0]),  &(T2[i][1]),  &(T2[i][2]),  &(T2[i][3]),  &(T2[i][4]),  &(T2[i][5]),
160            T1[i][0], T1[i][1], T1[i][2], T1[i][3], T1[i][4], T1[i][5]
161        );
162    }
163#endif
164
165    // Scatter each sub element in tile to separate matrices
166    for (int i = 0; i < WINOGRAD_ALPHA; i++) {
167        for (int j = 0; j < WINOGRAD_ALPHA; j++) {
168            vstore_net_t(T2[i][j], (i*WINOGRAD_ALPHA + j)*CPpad + offset, V);
169        }
170    }
171}
172
173__kernel void in_transform(__global net_t * restrict in, __global net_t * restrict V,
174                           const int C, const int Cpad,
175                           const int Ppad, const int batch_size) {
176    const int W = BOARD_SIZE;
177    const int H = BOARD_SIZE;
178    const int P = WTILES * WTILES;
179    const int CPpad = Ppad * Cpad;
180
181    const int block = get_global_id(0);
182    const int ch = get_global_id(1);
183
184    const int batch = block / P;
185    const int block_x = (block - P * batch) % WTILES;
186    const int block_y = (block - P * batch) / WTILES;
187
188    // 6x6 tiles overlap by 2
189    const int yin = WINOGRAD_M * block_y - 1;
190    const int xin = WINOGRAD_M * block_x - 1;
191
192    if (block < batch_size * P && ch < C) {
193        // Cache input tile and handle zero padding
194        real x[WINOGRAD_ALPHA][WINOGRAD_ALPHA];
195        for (int i = 0; i < WINOGRAD_ALPHA; i++) {
196            for (int j = 0; j < WINOGRAD_ALPHA; j++) {
197                int a = xin + j;
198                int b = yin + i;
199                // x is transposed here for better layout later
200                if (b >= 0 && a >= 0 && b < H && a < W) {
201                    x[j][i] = vload_net_t(batch * C * NUM_INTERSECTIONS +
202                        ch * NUM_INTERSECTIONS + b * W + a, in);
203                } else {
204                    x[j][i] = ZERO;
205                }
206            }
207        }
208
209        // V dimensions are [36, input_channels, batch_size * tiles].
210        // Padded with zeros as necessary for SGEMM
211        // = [36, Cpad, Ppad]
212
213        const int offset = ch * Ppad + block;
214        __in_transform_eq(x, V, offset, CPpad);
215    }
216}
217
218__kernel __attribute__((reqd_work_group_size(OUT_KWG, OUT_BWG, 1)))
219void out_transform_fused_bn(__global const net_t * restrict M,
220                                     __global net_t * restrict Y,
221                                     const int K,
222                                     const int Kpad, const int Ppad,
223                                     const int batch_size,
224                                     __global const net_t * restrict residual,
225                                     __constant const net_t * restrict means,
226                                     __constant const net_t * restrict stddivs) {
227
228    const int W = BOARD_SIZE;
229    const int H = BOARD_SIZE;
230    const int P = WTILES * WTILES;
231
232    const int k = get_global_id(0);
233    const int block = get_global_id(1);
234
235    // Adding some padding decreases bank conflicts
236    __local real out_buf[OUT_KWG][OUT_BWG][WINOGRAD_M][WINOGRAD_M + 1];
237
238    volatile int kid = get_local_id(0);
239    volatile int bid = get_local_id(1);
240
241    if (k < K && block < batch_size * P) {
242        const real mean = vload_net_t(k, means);
243        const real scale_stddiv = vload_net_t(k, stddivs);
244
245        real temp[WINOGRAD_M][WINOGRAD_ALPHA];
246
247        // M dimensions are [36, outputs, batch_size * tiles].
248        // Plus zero padding from SGEMM.
249        const int offset = block * Kpad + k;
250
251        // Calculates transpose(A).temp_m
252        for (int xn = 0; xn < WINOGRAD_ALPHA; xn++) {
253            real temp_m0 = vload_net_t((0 * WINOGRAD_ALPHA + xn) * Kpad * Ppad + offset, M);
254            real temp_m1 = vload_net_t((1 * WINOGRAD_ALPHA + xn) * Kpad * Ppad + offset, M);
255            real temp_m2 = vload_net_t((2 * WINOGRAD_ALPHA + xn) * Kpad * Ppad + offset, M);
256            real temp_m3 = vload_net_t((3 * WINOGRAD_ALPHA + xn) * Kpad * Ppad + offset, M);
257            real temp_m4 = vload_net_t((4 * WINOGRAD_ALPHA + xn) * Kpad * Ppad + offset, M);
258            real temp_m5 = vload_net_t((5 * WINOGRAD_ALPHA + xn) * Kpad * Ppad + offset, M);
259            multiply_at(
260                &(temp[0][xn]), &(temp[1][xn]), &(temp[2][xn]), &(temp[3][xn]),
261                temp_m0, temp_m1, temp_m2, temp_m3, temp_m4, temp_m5
262            );
263        }
264
265        // Calculates temp.A
266        for (int i = 0; i < WINOGRAD_M; i++){
267            real4 r;
268            multiply_atv(
269                &r,
270                temp[i][0], temp[i][1], temp[i][2], temp[i][3], temp[i][4], temp[i][5]
271            );
272
273            r = (r - mean) * scale_stddiv;
274            out_buf[kid][bid][i][0] = r.x;
275            out_buf[kid][bid][i][1] = r.y;
276            out_buf[kid][bid][i][2] = r.z;
277            out_buf[kid][bid][i][3] = r.w;
278        }
279    }
280
281    barrier(CLK_LOCAL_MEM_FENCE);
282
283    for (int idx = get_local_id(0) + get_local_size(0) * get_local_id(1); idx < OUT_BWG * OUT_KWG * WINOGRAD_M * WINOGRAD_M; idx += get_local_size(0) * get_local_size(1)) {
284        // Calculate indexing for coalesced memory access.
285        // This should be simplified somehow.
286        const int k_local = idx / (OUT_BWG * WINOGRAD_M * WINOGRAD_M);
287
288        const int idx_block = (idx - k_local * OUT_BWG * WINOGRAD_M * WINOGRAD_M);
289
290        const int row = idx_block / (WINOGRAD_M * OUT_BWG);
291        const int col = (idx_block - row * WINOGRAD_M * OUT_BWG);
292        const int block_local = col / WINOGRAD_M;
293
294        const int j = col % WINOGRAD_M;
295        const int i = row % WINOGRAD_M;
296
297        const int blockt = get_group_id(1) * get_local_size(1) + block_local;
298        const int kt = get_group_id(0) * get_local_size(0) + k_local;
299
300        const int batch = blockt / P;
301        const int blockt_x = (blockt - P * batch) % WTILES;
302        const int blockt_y = (blockt - P * batch) / WTILES;
303
304        const int x = WINOGRAD_M * blockt_x;
305        const int y = WINOGRAD_M * blockt_y;
306        const int out_idx = batch * K * NUM_INTERSECTIONS + kt * NUM_INTERSECTIONS + (y + i) * W + (x + j);
307
308        if (kt < K && blockt < batch_size * P && y + i < H && x + j < W) {
309            real acc = out_buf[k_local][block_local][i][j];
310            if (residual) {
311                acc += vload_net_t(out_idx, residual);
312            }
313            acc = acc > ZERO ? acc : ZERO;
314
315            vstore_net_t(acc, out_idx, Y);
316        }
317    }
318}
319
320__kernel void out_transform_fused_bn_in(
321                                     __global const net_t * restrict M,
322                                     __global net_t * restrict Y,
323                                     __global net_t * restrict V,
324                                     const int K,
325                                     const int Kpad, const int Ppad, const int Cpad,
326                                     __global const net_t * restrict residual,
327                                     __constant const net_t * restrict means,
328                                     __constant const net_t * restrict stddivs) {
329
330    const int W = BOARD_SIZE;
331    const int H = BOARD_SIZE;
332    const int P = WTILES * WTILES;
333
334    const int k = get_global_id(0);
335    const int kg = get_local_id(0);
336    const int block = get_global_id(1);
337    const int batch = get_global_id(2);
338
339    const int block_x = block % WTILES;
340    const int block_y = block / WTILES;
341
342    const int x = WINOGRAD_M * block_x;
343    const int y = WINOGRAD_M * block_y;
344
345    const int kHW = batch * K * NUM_INTERSECTIONS + k * NUM_INTERSECTIONS;
346
347    __local real ybuf[OUTIN_KWG * NUM_INTERSECTIONS];
348
349    if (k < K && block < P) {
350
351        const real mean = vload_net_t(k, means);
352        const real scale_stddiv = vload_net_t(k, stddivs);
353
354        real temp[WINOGRAD_M][WINOGRAD_ALPHA];
355
356        // M dimensions are [36, outputs, batch_size * tiles].
357        // Plus zero padding from SGEMM.
358
359        const int offset = block * Kpad + k;
360
361        // Calculates transpose(A).temp_m
362        for (int xn = 0; xn < WINOGRAD_ALPHA; xn++) {
363            real temp_m0 = vload_net_t((0 * WINOGRAD_ALPHA + xn) * Kpad * Ppad + offset, M);
364            real temp_m1 = vload_net_t((1 * WINOGRAD_ALPHA + xn) * Kpad * Ppad + offset, M);
365            real temp_m2 = vload_net_t((2 * WINOGRAD_ALPHA + xn) * Kpad * Ppad + offset, M);
366            real temp_m3 = vload_net_t((3 * WINOGRAD_ALPHA + xn) * Kpad * Ppad + offset, M);
367            real temp_m4 = vload_net_t((4 * WINOGRAD_ALPHA + xn) * Kpad * Ppad + offset, M);
368            real temp_m5 = vload_net_t((5 * WINOGRAD_ALPHA + xn) * Kpad * Ppad + offset, M);
369
370            multiply_at(
371                &(temp[0][xn]), &(temp[1][xn]), &(temp[2][xn]), &(temp[3][xn]),
372                temp_m0, temp_m1, temp_m2, temp_m3, temp_m4, temp_m5
373            );
374        }
375
376        // Calculates temp.A
377        for (int i = 0; i < WINOGRAD_M; i++){
378            real4 r;
379            multiply_atv(
380                &r,
381                temp[i][0], temp[i][1], temp[i][2], temp[i][3], temp[i][4], temp[i][5]
382            );
383
384            r = scale_stddiv * (r - mean);
385            if (y + i < H && x + 0 < W) {
386                const int out_idx = (y + i) * W + (x + 0);
387                ybuf[kg * NUM_INTERSECTIONS + out_idx] = r.x;
388            }
389            if (y + i < H && x + 1 < W) {
390                const int out_idx = (y + i) * W + (x + 1);
391                ybuf[kg * NUM_INTERSECTIONS + out_idx] = r.y;
392            }
393            if (y + i < H && x + 2 < W) {
394                const int out_idx = (y + i) * W + (x + 2);
395                ybuf[kg * NUM_INTERSECTIONS + out_idx] = r.z;
396            }
397            if (y + i < H && x + 3 < W) {
398                const int out_idx = (y + i) * W + (x + 3);
399                ybuf[kg * NUM_INTERSECTIONS + out_idx] = r.w;
400            }
401        }
402    }
403
404    barrier(CLK_LOCAL_MEM_FENCE);
405
406    const int ks = get_local_size(0);
407    const int k0 = get_group_id(0) * get_local_size(0);
408
409    for (int x = get_local_id(0) + ks * get_local_id(1); x < ks * NUM_INTERSECTIONS; x += get_local_size(1) * get_local_size(0)) {
410        const int kx = x / NUM_INTERSECTIONS;
411        const int idx = x - kx * NUM_INTERSECTIONS;
412
413        const int kHWx = batch * K * NUM_INTERSECTIONS + (k0 + kx) * NUM_INTERSECTIONS;
414
415        real acc = ybuf[kx * NUM_INTERSECTIONS + idx];
416        if (residual) {
417            acc += vload_net_t(kHWx + idx, residual);
418        }
419        acc = acc > ZERO ? acc : ZERO;
420
421        if (Y) {
422            vstore_net_t(acc, kHWx + idx, Y);
423        }
424        ybuf[kx * NUM_INTERSECTIONS + idx] = acc;
425    }
426
427    barrier(CLK_LOCAL_MEM_FENCE);
428
429    const int yin = WINOGRAD_M * block_y - 1;
430    const int xin = WINOGRAD_M * block_x - 1;
431
432    if (block < P && k < K) {
433        const int CPpad = Ppad * Cpad;
434        // Cache input tile and handle zero padding
435        real xx[WINOGRAD_ALPHA][WINOGRAD_ALPHA];
436        for (int i = 0; i < WINOGRAD_ALPHA; i++) {
437            int b = yin + i;
438            for (int j = 0; j < WINOGRAD_ALPHA; j++) {
439                int a = xin + j;
440                // x is transposed here for better layout later
441                if (b >= 0 && a >= 0 && b < H && a < W) {
442                    xx[j][i] = ybuf[kg * NUM_INTERSECTIONS + b * W + a];
443                } else {
444                    xx[j][i] = ZERO;
445                }
446            }
447        }
448
449        const int offset = k * Ppad + P * batch + block;
450        __in_transform_eq(xx, V, offset, CPpad);
451    }
452}
453
454// End of the C++11 raw string literal
455)"
456