1/*******************************************************************************
2* Copyright 2019-2020 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8*     http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/ocl/ocl_post_ops.h"
18#include "gpu/ocl/ocl_types.h"
19
20#if DT_F32 != 1
21#error "Only f32 implemented."
22#endif
23
// One multiply-accumulate step of the NN variant for a single row of the
// C tile:
//     c[i_div_4].s<i_mod_4> += broadcast(a[hh].s<i_div_16>) * b.s<hh>
//
//   hh        index into a[] and component of the float4 `b`
//   i_mod_16  sub-group lane handed to sub_group_broadcast()
//   i_div_16  component of a[hh] that is broadcast
//   i_mod_4   component of the accumulator vector that is updated
//   i_div_4   index of the accumulator vector inside c[]
//
// hh, i_div_16 and i_mod_4 are token-pasted into component selectors, so
// they must be plain integer literals. Expects `a`, `b` and `c` in scope.
#define DO_FMA_NN(hh, i_mod_16, i_div_16, i_mod_4, i_div_4) \
    do { \
        const float a_bcast \
                = sub_group_broadcast(a[hh].s##i_div_16, i_mod_16); \
        c[i_div_4].s##i_mod_4 \
                = mad(a_bcast, b.s##hh, c[i_div_4].s##i_mod_4); \
    } while (0)
30
// One multiply-accumulate step of the NT variant for a single row of the
// C tile:
//     c[i_div_4].s<i_mod_4> += broadcast(a[hh].s<i_div_16>) * b[hh]
//
//   hh        index into both a[] and the scalar array b[]
//   i_mod_16  sub-group lane handed to sub_group_broadcast()
//   i_div_16  component of a[hh] that is broadcast
//   i_mod_4   component of the accumulator vector that is updated
//   i_div_4   index of the accumulator vector inside c[]
//
// i_div_16 and i_mod_4 are token-pasted into component selectors, so they
// must be plain integer literals. Expects `a`, `b` and `c` in scope.
#define DO_FMA_NT(hh, i_mod_16, i_div_16, i_mod_4, i_div_4) \
    do { \
        const float a_bcast \
                = sub_group_broadcast(a[hh].s##i_div_16, i_mod_16); \
        c[i_div_4].s##i_mod_4 = mad(a_bcast, b[hh], c[i_div_4].s##i_mod_4); \
    } while (0)
37
// Pick the kernel variant from the transpose flags. A transposed A is not
// implemented and stops the build. Otherwise TRANS_B chooses between the
// NN variant (TRANS_B undefined) and the NT variant (TRANS_B defined), and
// DO_FMA is bound to the matching single-row multiply-accumulate macro.
#if defined(TRANS_A)
#error "No superkernel implementation."
#elif defined(TRANS_B)
#define NT
#define DO_FMA DO_FMA_NT
#else
#define NN
#define DO_FMA DO_FMA_NN
#endif
49
// Accumulate k-step `hh` into all 32 rows of the C tile.
//
// Each line is DO_FMA(hh, i % 16, i / 16, i % 4, i / 4) for row i = 0..31:
//   - i % 16 is the sub-group lane whose A value is broadcast,
//   - i / 16 is the component of a[hh] that is broadcast,
//   - i % 4 and i / 4 locate the accumulator c[i / 4].s<i % 4>.
//
// The lane argument must stay within 0..15: the kernels in this file fix the
// sub-group size at 16, so rows 16..31 reuse lanes 0..15 and differ only in
// the a[hh] component. Arguments are written as literals because DO_FMA
// token-pastes them.
#define FMA_I_LOOP_32_ROW(hh) \
    do { \
        DO_FMA(hh, 0, 0, 0, 0); \
        DO_FMA(hh, 1, 0, 1, 0); \
        DO_FMA(hh, 2, 0, 2, 0); \
        DO_FMA(hh, 3, 0, 3, 0); \
        DO_FMA(hh, 4, 0, 0, 1); \
        DO_FMA(hh, 5, 0, 1, 1); \
        DO_FMA(hh, 6, 0, 2, 1); \
        DO_FMA(hh, 7, 0, 3, 1); \
        DO_FMA(hh, 8, 0, 0, 2); \
        DO_FMA(hh, 9, 0, 1, 2); \
        DO_FMA(hh, 10, 0, 2, 2); \
        DO_FMA(hh, 11, 0, 3, 2); \
        DO_FMA(hh, 12, 0, 0, 3); \
        DO_FMA(hh, 13, 0, 1, 3); \
        DO_FMA(hh, 14, 0, 2, 3); \
        DO_FMA(hh, 15, 0, 3, 3); \
        DO_FMA(hh, 0, 1, 0, 4); \
        DO_FMA(hh, 1, 1, 1, 4); \
        DO_FMA(hh, 2, 1, 2, 4); \
        DO_FMA(hh, 3, 1, 3, 4); \
        DO_FMA(hh, 4, 1, 0, 5); \
        DO_FMA(hh, 5, 1, 1, 5); \
        DO_FMA(hh, 6, 1, 2, 5); \
        DO_FMA(hh, 7, 1, 3, 5); \
        DO_FMA(hh, 8, 1, 0, 6); \
        DO_FMA(hh, 9, 1, 1, 6); \
        DO_FMA(hh, 10, 1, 2, 6); \
        DO_FMA(hh, 11, 1, 3, 6); \
        DO_FMA(hh, 12, 1, 0, 7); \
        DO_FMA(hh, 13, 1, 1, 7); \
        DO_FMA(hh, 14, 1, 2, 7); \
        DO_FMA(hh, 15, 1, 3, 7); \
    } while (0)
85
// Helper for FMA_I_LOOP_16_ROW: four DO_FMA steps for rows r0..r3, which
// update components s0..s3 of the same accumulator vector c[i_div_4].
// Arguments are forwarded to DO_FMA and must be integer literals.
#define FMA_I_LOOP_4_ROW(hh, r0, r1, r2, r3, i_div_16, i_div_4) \
    do { \
        DO_FMA(hh, r0, i_div_16, 0, i_div_4); \
        DO_FMA(hh, r1, i_div_16, 1, i_div_4); \
        DO_FMA(hh, r2, i_div_16, 2, i_div_4); \
        DO_FMA(hh, r3, i_div_16, 3, i_div_4); \
    } while (0)

// Accumulate k-step `hh` into rows 0..15 of the C tile, i.e. into
// c[0]..c[3]. Row i broadcasts component 0 of a[hh] from lane i.
#define FMA_I_LOOP_16_ROW(hh) \
    do { \
        FMA_I_LOOP_4_ROW(hh, 0, 1, 2, 3, 0, 0); \
        FMA_I_LOOP_4_ROW(hh, 4, 5, 6, 7, 0, 1); \
        FMA_I_LOOP_4_ROW(hh, 8, 9, 10, 11, 0, 2); \
        FMA_I_LOOP_4_ROW(hh, 12, 13, 14, 15, 0, 3); \
    } while (0)
105
// POST_OP(val): optional element-wise post-operation on one output value.
// When the kernel is built with WITH_ELTWISE == 1, `val` is replaced by
// fwd_eltwise(val, eltwise_alpha, eltwise_beta, eltwise_scale), but only if
// `last_k_block` is non-zero. In every other build the macro expands to
// nothing. `val` must be an assignable expression.
#if WITH_ELTWISE == 1
#define POST_OP(val) \
    do { \
        if (last_k_block) { \
            (val) = fwd_eltwise( \
                    (val), eltwise_alpha, eltwise_beta, eltwise_scale); \
        } \
    } while (0)
#else
#define POST_OP(val)
#endif
116
// Write back one element of the C tile and advance the output pointer.
//
//   i         row inside the tile (compared against irem, selects c[i / 4])
//   ii        component of c[i / 4] holding that row; token-pasted, so it
//             must be an integer literal
//   betaZero  when non-zero the old value of *C is not read
//
// The store happens only while the element is in range (jrem > 0 and
// irem > i); `C` is incremented unconditionally. Expects alpha, beta, c,
// C, irem and jrem in scope.
#define UPDATE_C_ROW(i, ii, betaZero) \
    do { \
        if ((jrem > 0) && (irem > i)) { \
            const float scaled_old = (betaZero) ? 0 : beta * *C; \
            float out = alpha * c[i / 4].s##ii + scaled_old; \
            POST_OP(out); \
            *C = out; \
        } \
        C++; \
    } while (0)
128
// Helper for UPDATE_C_32_ROW: write back rows r0..r3, which live in
// components s0..s3 of one accumulator vector. Arguments are forwarded to
// UPDATE_C_ROW and must be integer literals.
#define UPDATE_C_4_ROW(r0, r1, r2, r3, betaZero) \
    do { \
        UPDATE_C_ROW(r0, 0, betaZero); \
        UPDATE_C_ROW(r1, 1, betaZero); \
        UPDATE_C_ROW(r2, 2, betaZero); \
        UPDATE_C_ROW(r3, 3, betaZero); \
    } while (0)

// Write back rows 0..31 of the C tile in ascending order, one
// UPDATE_C_ROW(i, i % 4, betaZero) per row. `C` ends up advanced by 32.
#define UPDATE_C_32_ROW(betaZero) \
    do { \
        UPDATE_C_4_ROW(0, 1, 2, 3, betaZero); \
        UPDATE_C_4_ROW(4, 5, 6, 7, betaZero); \
        UPDATE_C_4_ROW(8, 9, 10, 11, betaZero); \
        UPDATE_C_4_ROW(12, 13, 14, 15, betaZero); \
        UPDATE_C_4_ROW(16, 17, 18, 19, betaZero); \
        UPDATE_C_4_ROW(20, 21, 22, 23, betaZero); \
        UPDATE_C_4_ROW(24, 25, 26, 27, betaZero); \
        UPDATE_C_4_ROW(28, 29, 30, 31, betaZero); \
    } while (0)
164
// Opens the work loop shared by both kernel variants. The matching
// SUPERKERNEL_EPILOGUE closes the `while` started here.
//
// - Applies offsetA/offsetB/offsetC to the base pointers once.
// - Starts with id = work-group id and loops while id < threads.
// - For the current id, reads two words from the plan: plan[2 * id + 2]
//   and plan[2 * id + 3]. Bit 31 of each word is split off into kid0 and
//   kid1; the low 31 bits become i0 and j0.
// - Adds the work-item's local id to j0.
//
// Introduces p, id, i0, j0, kid0 and kid1 into the enclosing scope. The
// kernels in this file branch on kid0 and never read p or kid1.
#define SUPERKERNEL_PROLOGUE \
    global volatile int *p = plan; \
    int id = get_group_id(0); \
\
    A0 += offsetA; \
    B0 += offsetB; \
    C0 += offsetC; \
\
    while (id < threads) { \
        uint i0 = plan[2 * id + 2]; \
        uint j0 = plan[2 * id + 3]; \
        uint kid0 = i0 >> 31; \
        uint kid1 = j0 >> 31; \
\
        i0 &= 0x7FFFFFFFu; \
        j0 &= 0x7FFFFFFFu; \
        j0 += get_local_id(0);
184
// Closes the work loop opened by SUPERKERNEL_PROLOGUE.
//
// End of each iteration: sub-group lane 0 takes the next id by atomically
// incrementing plan[0]; after a sub-group barrier the value is broadcast to
// every lane and the loop condition is re-evaluated.
//
// After the loop: lane 0 atomically increments plan[1]. The work-group that
// reads back get_num_groups(0) - 1 then issues a global memory fence and
// resets plan[0] to get_num_groups(0) and plan[1] to 0.
// NOTE(review): the reset presumably re-arms the plan for a later launch;
// confirm against the host code.
#define SUPERKERNEL_EPILOGUE \
        if (get_sub_group_local_id() == 0) { id = atomic_inc(plan); } \
\
        sub_group_barrier(0); \
        id = sub_group_broadcast(id, 0); \
    } \
    if (get_sub_group_local_id() == 0) { \
        const int finished_before = atomic_inc(plan + 1); \
        if (finished_before == (get_num_groups(0) - 1)) { \
            mem_fence(CLK_GLOBAL_MEM_FENCE); \
            plan[0] = get_num_groups(0); \
            plan[1] = 0; \
        } \
    }
198
#ifdef NN
// NN variant: C = alpha * A * B + beta * C with neither A nor B transposed.
//
// Addressing used below: A(i, h) = A0[i + h * lda], B(h, j) = B0[h + j * ldb],
// C(i, j) = C0[i + j * ldc], each after adding the matching offset.
//
//   plan            work list; plan[0] and plan[1] are atomic counters,
//                   tile descriptors start at plan[2] (see the prologue
//                   and epilogue macros)
//   threads         number of tile descriptors in the plan
//   A0, B0, C0      matrix base pointers
//   offsetA/B/C     element offsets added to the base pointers
//   lda, ldb, ldc   leading dimensions, in elements
//   m, n, k         problem sizes
//   alpha, beta     scaling factors; beta == 0 skips reading C
//   last_k_block,
//   eltwise_*       consumed by POST_OP when built with WITH_ELTWISE == 1
//
// Each sub-group of 16 work-items handles one tile per loop iteration: 32
// rows when the descriptor's kid0 flag is 0, otherwise 16 rows, by 16
// columns (one column per work-item).
__attribute__((intel_reqd_sub_group_size(16))) // attr:no-format
kernel void
gen9_gemm_nocopy_superkernel_f32(global int *plan, int threads,
        global float *A0, global float *B0, global float *C0, long offsetA,
        long offsetB, long offsetC, int lda, int ldb, int ldc, int m, int n,
        int k, float alpha, float beta, int last_k_block, float eltwise_alpha,
        float eltwise_beta, float eltwise_scale) {
    SUPERKERNEL_PROLOGUE

    float2 a[4]; // 32 x 4  block of A, 4x 32x1 block accesses
    float4 b; // 4  x 16 block of B, 1x 4x16 scattered access
    float4 c[8]; // 32 x 16 block of C, 8x 4x16 scattered access

    // Rows / columns of C left from this work-item's origin, clamped at 0.
    int irem = m - i0;
    int jrem = n - j0;
    if (irem < 0) irem = 0;
    if (jrem < 0) jrem = 0;

    global float *A = A0 + i0;
    global float *B = B0 + j0 * ldb;
    global float *C = C0 + i0 + j0 * ldc;

    // One pointer per column of the current 4-wide k block of A.
    global float *A_cols[4] = {A, A + lda, A + 2 * lda, A + 3 * lda};

    int ldax4 = lda << 2;

    // k-steps left over after the 4-wide main loop (0..3).
    int krem = k & 3;

    if (kid0 == 0) {
        // 32-row tile.
        for (int z = 0; z < 8; z++)
            c[z] = 0.f;

        for (int h = 0; h < (k >> 2); h++) {
            // Load A
            for (int j = 0; j < 4; j++) {
                a[j] = as_float2(
                        intel_sub_group_block_read2((global uint *)A_cols[j]));
                A_cols[j] += ldax4;
            }

            // Load B
            b = vload4(0, B);
            B += 4;

            // FMAs
            FMA_I_LOOP_32_ROW(0);
            FMA_I_LOOP_32_ROW(1);
            FMA_I_LOOP_32_ROW(2);
            FMA_I_LOOP_32_ROW(3);
        }

        if (krem > 0) {
            // Load only the krem valid columns of A and elements of B.
            // Entries at index >= krem lie past the k extent and are never
            // used by the guarded FMAs below.
            for (int j = 0; j < krem; j++)
                a[j] = as_float2(
                        intel_sub_group_block_read2((global uint *)A_cols[j]));

            b.s0 = B[0];
            if (krem > 1) b.s1 = B[1];
            if (krem > 2) b.s2 = B[2];

            FMA_I_LOOP_32_ROW(0);
            if (krem > 1) FMA_I_LOOP_32_ROW(1);
            if (krem > 2) FMA_I_LOOP_32_ROW(2);
        }
    } else {
        // 16-row tile: only c[0..3] and component 0 of a[] are used, and the
        // write-back below is limited to 16 rows through irem.
        if (irem > 16) irem = 16;

        for (int z = 0; z < 4; z++)
            c[z] = 0.f;

        for (int h = 0; h < (k >> 2); h++) {
            for (int j = 0; j < 4; j++) {
                a[j].s0 = as_float(
                        intel_sub_group_block_read((global uint *)A_cols[j]));
                A_cols[j] += ldax4;
            }

            b = vload4(0, B);
            B += 4;

            FMA_I_LOOP_16_ROW(0);
            FMA_I_LOOP_16_ROW(1);
            FMA_I_LOOP_16_ROW(2);
            FMA_I_LOOP_16_ROW(3);
        }

        if (krem > 0) {
            // Same bounded loads as in the 32-row remainder.
            for (int j = 0; j < krem; j++)
                a[j].s0 = as_float(
                        intel_sub_group_block_read((global uint *)A_cols[j]));

            b.s0 = B[0];
            if (krem > 1) b.s1 = B[1];
            if (krem > 2) b.s2 = B[2];

            FMA_I_LOOP_16_ROW(0);
            if (krem > 1) FMA_I_LOOP_16_ROW(1);
            if (krem > 2) FMA_I_LOOP_16_ROW(2);
        }
    }

    // Write back; the betaZero flavour never reads the old contents of C.
    if (beta == 0)
        UPDATE_C_32_ROW(1);
    else
        UPDATE_C_32_ROW(0);

    SUPERKERNEL_EPILOGUE
}
#endif
306
#ifdef NT
// NT variant: C = alpha * A * B^T + beta * C, with B stored transposed.
//
// Addressing used below: A(i, h) = A0[i + h * lda], B(h, j) = B0[j + h * ldb],
// C(i, j) = C0[i + j * ldc], each after adding the matching offset.
//
// Parameters mirror the NN variant: `plan`/`threads` describe the work list
// consumed by the prologue and epilogue macros, offsetA/B/C and lda/ldb/ldc
// are in elements, and last_k_block / eltwise_* are only consumed by
// POST_OP when built with WITH_ELTWISE == 1.
//
// Each sub-group of 16 work-items handles one tile per loop iteration: 32
// rows when the descriptor's kid0 flag is 0, otherwise 16 rows, by 16
// columns. k is consumed two steps at a time, with one trailing step when
// k is odd.
__attribute__((intel_reqd_sub_group_size(16))) // attr:no-format
kernel void
gen9_gemm_nocopy_superkernel_f32(global int *plan, int threads,
        global float *A0, global float *B0, global float *C0, long offsetA,
        long offsetB, long offsetC, int lda, int ldb, int ldc, int m, int n,
        int k, float alpha, float beta, int last_k_block, float eltwise_alpha,
        float eltwise_beta, float eltwise_scale) {
    SUPERKERNEL_PROLOGUE

    float2 a[2]; // 32 x 2  block of A, 2x 32x1 block accesses
    float b[2]; // 2  x 16 block of B, 2x 1x16 block accesses
    float4 c[8]; // 32 x 16 block of C, 8x 4x16 scattered access

    // Rows / columns of C left from this work-item's origin, clamped at 0.
    int irem = m - i0;
    int jrem = n - j0;
    if (irem < 0) irem = 0;
    if (jrem < 0) jrem = 0;

    // Two consecutive columns of A and two consecutive rows of B.
    // NOTE(review): j0 already contains the work-item's local id, so the B
    // pointers differ per lane while being used for sub-group block reads;
    // confirm this matches the block-read addressing rules.
    global float *a_col0 = A0 + i0;
    global float *a_col1 = a_col0 + lda;
    global float *b_row0 = B0 + j0;
    global float *b_row1 = b_row0 + ldb;
    global float *C = C0 + i0 + j0 * ldc;

    const int a_step = lda << 1;
    const int b_step = ldb << 1;
    const int k_pairs = k >> 1;
    const int k_odd = k & 1;

    if (kid0 == 0) {
        // 32-row tile.
        for (int z = 0; z < 8; z++)
            c[z] = (float4)(0.f);

        for (int h = 0; h < k_pairs; h++) {
            // Load A
            a[0] = as_float2(
                    intel_sub_group_block_read2((global uint *)a_col0));
            a_col0 += a_step;
            a[1] = as_float2(
                    intel_sub_group_block_read2((global uint *)a_col1));
            a_col1 += a_step;

            // Load B
            b[0] = as_float(intel_sub_group_block_read((global uint *)b_row0));
            b_row0 += b_step;
            b[1] = as_float(intel_sub_group_block_read((global uint *)b_row1));
            b_row1 += b_step;

            // FMAs
            FMA_I_LOOP_32_ROW(0);
            FMA_I_LOOP_32_ROW(1);
        }

        if (k_odd) {
            // Trailing k-step: only the first column of A / row of B.
            a[0] = as_float2(
                    intel_sub_group_block_read2((global uint *)a_col0));
            b[0] = as_float(intel_sub_group_block_read((global uint *)b_row0));

            FMA_I_LOOP_32_ROW(0);
        }
    } else {
        // 16-row tile: only c[0..3] and component 0 of a[] are used, and the
        // write-back below is limited to 16 rows through irem.
        if (irem > 16) irem = 16;

        for (int z = 0; z < 4; z++)
            c[z] = (float4)(0.f);

        for (int h = 0; h < k_pairs; h++) {
            a[0].s0 = as_float(
                    intel_sub_group_block_read((global uint *)a_col0));
            a_col0 += a_step;
            a[1].s0 = as_float(
                    intel_sub_group_block_read((global uint *)a_col1));
            a_col1 += a_step;

            b[0] = as_float(intel_sub_group_block_read((global uint *)b_row0));
            b_row0 += b_step;
            b[1] = as_float(intel_sub_group_block_read((global uint *)b_row1));
            b_row1 += b_step;

            FMA_I_LOOP_16_ROW(0);
            FMA_I_LOOP_16_ROW(1);
        }

        if (k_odd) {
            a[0].s0 = as_float(
                    intel_sub_group_block_read((global uint *)a_col0));
            b[0] = as_float(intel_sub_group_block_read((global uint *)b_row0));

            FMA_I_LOOP_16_ROW(0);
        }
    }

    // Write back; the betaZero flavour never reads the old contents of C.
    if (beta == 0) {
        UPDATE_C_32_ROW(1);
    } else {
        UPDATE_C_32_ROW(0);
    }

    SUPERKERNEL_EPILOGUE
}
#endif
412