1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 #include "blas_subgroup.h"
18 #include <stdio.h>
19 #include <clblas_stddef.h>
20 
21 #include <matrix_props.h>
22 #include <matrix_dims.h>
23 #include <dis_warning.h>
24 
25 #include "blas_kgen.h"
26 #include "gen_helper.h"
27 #include "tile_iter.h"
28 #include "kerngen.h"
29 
30 static int
31 calcMergeStepSubgrN(
32     const BlasGenSettings* pGSet,
33     DataType dtype);
34 
35 static int declareSubgrLDS(
36     struct KgenContext* pCtx,
37     const BlasGenSettings* pGSet,
38     DataType dtype);
39 
40 //-----------------------------------------------------------------------------
41 // calculates best number of subgroups to be engaged in each merge step
42 // simultaneously
43 // Calculation is based on the register usage estimation
44 // in order not to limit
45 // the number of workgroups scheduled on the SIMD engine
46 static int
calcMergeStepSubgrN(const BlasGenSettings * pGSet,DataType dtype)47 calcMergeStepSubgrN(
48     const BlasGenSettings* pGSet,
49     DataType dtype)
50 {
51     // hardware-specific options
52     const int deviceLDS = 32768;
53     const unsigned int gprsPerUnit = 240;
54 
55     int vecLenA = 0;
56     int vecLenB = 0;
57     int vecLenC = 0;
58 
59     int vecNumA = 0;
60     int vecNumB = 0;
61     int vecNumC = 0;
62 
63     int subgPerStep = 0;
64     int bestLDS = 0;
65     int gprsUsed = 0;
66     int subgNum = 0;
67 
68     int itemsPerSubgroup = 0;
69 
70     if( NULL == pGSet || NULL == pGSet->pgran ){
71         return -EINVAL;
72     }
73 
74     itemsPerSubgroup = pGSet->subdims[0].bwidth/
75         pGSet->subdims[1].bwidth;
76 
77     subgNum = (pGSet->subdims[0].x/pGSet->subdims[1].x)*
78         (pGSet->subdims[0].y/pGSet->subdims[1].y);
79 
80     vecLenA = pGSet->tileA.vecLen;
81     vecLenB = pGSet->tileBX.vecLen;
82     vecLenC = pGSet->tileCY.vecLen;
83 
84     vecNumA = tileVectorsNum( &pGSet->tileA );
85     vecNumB = tileVectorsNum( &pGSet->tileBX );
86     vecNumC = tileVectorsNum( &pGSet->tileCY );
87 
88     // registers hold 4-vectors of 32-bit floats or 2-vectors of doubles
89     switch(dtype){
90 
91         case TYPE_FLOAT:
92 
93             // each register holds 4 4-byte float values
94             // 10 registers are used address, etc
95             gprsUsed =  vecNumA * (vecLenA/4) +
96                         vecNumB * (vecLenB/4) +
97                         vecNumC * (vecLenC/4) + 10;
98 
99             bestLDS = deviceLDS/(gprsPerUnit/gprsUsed);
100 
101             subgPerStep = bestLDS/(itemsPerSubgroup *
102                                    vecNumC *
103                                    vecLenC * 4 );//4-byte floats
104             break;
105 
106         case TYPE_DOUBLE:
107 
108             // each register can hold 2 double values
109             // 10 registers are used for address, etc
110             gprsUsed =  vecNumA * (vecLenA/2) +
111                         vecNumB * (vecLenB/2) +
112                         vecNumC * (vecLenC/2) + 10;
113 
114             bestLDS = deviceLDS/(gprsPerUnit/gprsUsed);
115 
116             subgPerStep = bestLDS/(itemsPerSubgroup *
117                                    vecNumC *
118                                    vecLenC * 8 );//8-byte doubles
119             break;
120 
121         case TYPE_COMPLEX_FLOAT:
122 
123             // each register holds 2 4-byte float-based complex values
124             // 10 registers are used address, etc
125             gprsUsed =  vecNumA * (vecLenA/2) +
126                         vecNumB * (vecLenB/2) +
127                         vecNumC * (vecLenC/2) + 10;
128 
129             bestLDS = deviceLDS/(gprsPerUnit/gprsUsed);
130 
131             subgPerStep = bestLDS/(itemsPerSubgroup *
132                                    vecNumC *
133                                    vecLenC * 8 );//2x4-byte floats
134             break;
135 
136         case TYPE_COMPLEX_DOUBLE:
137 
138             // each register can hold 1 double-based complex value
139             // 10 registers are used for address, etc
140             gprsUsed =  vecNumA * (vecLenA) +
141                         vecNumB * (vecLenB) +
142                         vecNumC * (vecLenC) + 10;
143 
144             bestLDS = deviceLDS/(gprsPerUnit/gprsUsed);
145 
146             subgPerStep = bestLDS/(itemsPerSubgroup *
147                                    vecNumC *
148                                    vecLenC * 16 );//2x8-byte double
149             break;
150 
151         default:
152           break ;
153     }
154 
155     if( 0==subgPerStep ){
156         subgPerStep = 1;
157     }
158 
159     // do not exceed physical number of subgroups in workgroup
160     if( subgPerStep > subgNum ){
161         subgPerStep = subgNum;
162     }
163 
164     return subgPerStep;
165 }
166 
167 //-----------------------------------------------------------------------------
168 // Add LDS array declaration(based on C matrix parameters) to the context
169 // each row of C Matrix block may be splitted into separate vectors
170 
declareSubgrLDS(struct KgenContext * pCtx,const BlasGenSettings * pGSet,DataType dtype)171 static int declareSubgrLDS(
172     struct KgenContext* pCtx,
173     const BlasGenSettings* pGSet,
174     DataType dtype)
175 {
176     int vecLenC = 0;
177     int vecNumC = 0;
178     const char* typeName;
179     const KernelVarNames *vnames = NULL;
180     char tmp[512];
181     int itemsPerSubgroup = 0;
182     int subgrPerStep = 0;
183 
184     if( NULL == pCtx || NULL == pGSet ){
185         return -EINVAL;
186     }
187 
188     itemsPerSubgroup = pGSet->subdims[0].bwidth / pGSet->subdims[1].bwidth;
189     subgrPerStep = calcMergeStepSubgrN(pGSet, dtype);
190 
191     vecLenC = pGSet->tileCY.vecLen;
192     vecNumC = tileVectorsNum( &pGSet->tileCY );
193     typeName = dtypeBuiltinType(dtype);
194     vnames = &pGSet->varNames;
195 
196     switch(dtype){
197 
198         case TYPE_FLOAT:
199         case TYPE_DOUBLE:
200 
201             if( vecLenC > 1){
202                 sprintf(
203                     tmp,
204                     "__local %s%d a%s[%d*%d*%d];\n"
205                     "__local %s%d *%s = a%s;\n",
206                     typeName,
207                     vecLenC,
208                     vnames->LDS,
209                     itemsPerSubgroup,
210                     subgrPerStep,
211                     vecNumC,
212                     typeName,
213                     vecLenC,
214                     vnames->LDS,
215                     vnames->LDS);
216             }
217             else{
218                 sprintf(
219                     tmp,
220                     "__local %s a%s[%d*%d*%d];\n"
221                     "__local %s *%s = a%s;\n",
222                     typeName,
223                     vnames->LDS,
224                     itemsPerSubgroup,
225                     subgrPerStep,
226                     vecNumC,
227                     typeName,
228                     vnames->LDS,
229                     vnames->LDS);
230             }
231 
232             break;
233 
234         case TYPE_COMPLEX_FLOAT:
235 
236             sprintf(
237                 tmp,
238                 "__local float%d a%s[%d*%d*%d];\n"
239                 "__local float%d *%s = a%s;\n",
240                 vecLenC*2,
241                 vnames->LDS,
242                 itemsPerSubgroup,
243                 subgrPerStep,
244                 vecNumC,
245                 vecLenC*2,
246                 vnames->LDS,
247                 vnames->LDS);
248 
249             break;
250 
251         case TYPE_COMPLEX_DOUBLE:
252 
253              sprintf(
254                 tmp,
255                 "__local double%d a%s[%d*%d*%d];\n"
256                 "__local double%d *%s = a%s;\n",
257                 vecLenC*2,
258                 vnames->LDS,
259                 itemsPerSubgroup,
260                 subgrPerStep,
261                 vecNumC,
262                 vecLenC*2,
263                 vnames->LDS,
264                 vnames->LDS);
265 
266             break;
267 
268         default:  // to avoid compilation warning
269             break;
270     }
271 
272     kgenAddStmt( pCtx, tmp );
273 
274     return 0;
275 }
276 
277 //-----------------------------------------------------------------------------
278 
279 int
mergeUpdateResult(struct KgenContext * pCtx,BlasFunctionID funcID,struct BlasGenSettings * pGSet,struct SubgVarNames * pSubgVNames,UpdateResultFlags upResFlags,UpresProcPtr upresProcPtr)280 mergeUpdateResult( struct KgenContext* pCtx,
281     BlasFunctionID funcID,
282     struct BlasGenSettings* pGSet,
283     struct SubgVarNames* pSubgVNames,
284     UpdateResultFlags upResFlags,
285     UpresProcPtr upresProcPtr )
286 {
287     char tmp[2048];
288     int subgN = 0;
289     int subgItems = 0;
290     int aBlkH = 0;
291     DataType dtype;
292     Tile tileC;
293     Tile tileScratch;
294     KernelVarNames* pVNames;
295     unsigned int vecLenC;
296     unsigned int vecNumC;
297 
298     int subgPerStep = 0;
299 
300     if( NULL == pCtx || NULL == pGSet ){
301         return -EINVAL;
302     }
303 
304     dtype = pGSet->kextra->dtype;
305     subgN = ( pGSet->subdims[0].x/pGSet->subdims[1].x ) *
306         ( pGSet->subdims[0].y/pGSet->subdims[1].y );
307 
308     subgItems = pGSet->subdims[0].bwidth/
309         pGSet->subdims[1].bwidth;
310 
311     aBlkH = pGSet->subdims[1].y;
312     pVNames = &pGSet->varNames;
313 
314     // calculate best number of subgroups to be engaged in each merge step
315     subgPerStep = calcMergeStepSubgrN( pGSet, dtype );
316 
317     vecLenC = pGSet->tileCY.vecLen;
318     vecNumC = tileVectorsNum( &pGSet->tileCY );
319 
320     kgenAddStmt(pCtx,"//-----MergeUpdateResult\n");
321     kgenAddBlankLine(pCtx);
322 
323     // declare local data storage array
324     kgenAddStmt( pCtx, "// veclenC scratch[SUBG_ITEMS*MSTEP_SUBG*vecNumC]\n");
325     declareSubgrLDS( pCtx,
326         pGSet,
327         dtype);
328 
329     kgenAddBlankLine( pCtx );
330 
331     kgenAddStmt(pCtx,
332                 "//LDS block has the same vectorization as C matrix block\n");
333     kgenAddStmt(
334         pCtx,
335         "//VNUM_C*((get_local_id(1)%MSTEP_SUBG)*SUBG_ITEMS"
336         " +get_local_id(0) );\n");
337 
338     sprintf(tmp,
339         "scratch += "
340             "%d*("
341                 "(%s.y%%%d)*%d +"
342                 "%s.x );\n",
343             vecNumC,
344             pSubgVNames->itemId,
345             subgPerStep,
346             subgItems,
347             pSubgVNames->itemId );
348     kgenAddStmt(pCtx, tmp);
349 
350 
351     sprintf(
352         tmp,
353         "\nfor( uint mstep = 0; mstep < %d; mstep += %d )",
354         subgN,
355         subgPerStep);
356     kgenBeginBranch(pCtx,tmp);
357     kgenAddBlankLine(pCtx);
358 
359     sprintf(
360         tmp,
361         "if( (%s.y >= mstep)&&(%s.y < (mstep+%d)) )",
362         pSubgVNames->itemId,
363         pSubgVNames->itemId,
364         subgPerStep);
365     kgenBeginBranch(pCtx,tmp);
366 
367     // the LDS block size is similar to C matrix block size
368     kgenAddBlankLine(pCtx);
369     initTile(&tileC,
370             "c",
371             (unsigned int)pGSet->subdims[1].y,
372             (unsigned int)pGSet->subdims[1].x,
373             vecLenC,
374             dtype,
375             pGSet->tileCY.storType,
376             pGSet->tileCY.trans,
377             pGSet->tileCY.packed);
378 
379     initTile(&tileScratch,
380             "scratch",
381             (unsigned int)pGSet->subdims[1].y,
382             (unsigned int)pGSet->subdims[1].x,
383             vecLenC,
384             dtype,
385             PRIV_STORAGE_ARRAY,
386             pGSet->tileCY.trans,
387             pGSet->tileCY.packed);
388 
389     genTileCopy(pCtx,
390                 &tileScratch,
391                 &tileC,
392                 TILECOPY_ASSIGN);
393 
394     genZeroTile(pCtx,
395                 &tileC);
396 
397     // split merge if
398     kgenEndBranch( pCtx, NULL ); // merge step if
399     kgenAddBlankLine( pCtx );
400 
401     //splitting if on two, to prevent barrier issue
402     kgenAddBarrier( pCtx, CLK_LOCAL_MEM_FENCE );
403     kgenAddBlankLine( pCtx );
404     //----------------------------------------------
405 
406     sprintf( tmp,
407         "if( (%s.y >= mstep)&&(%s.y < (mstep+%d)) )",
408         pSubgVNames->itemId,
409         pSubgVNames->itemId,
410         subgPerStep);
411     kgenBeginBranch(pCtx,tmp);
412 
413     sprintf( tmp,
414         "if ( 0 == %s.x )",
415         pSubgVNames->itemId );
416     kgenBeginBranch( pCtx, tmp );
417 
418     kgenAddBlankLine(pCtx);
419 
420     // Zero element of each subgroup also performs LDS merge
421     sprintf(
422         tmp,
423         "for(uint k = 0; k < %d * %d; k += %d)",
424         subgItems,
425         aBlkH,
426         aBlkH);
427 
428     kgenBeginBranch(pCtx, tmp);
429     kgenAddBlankLine(pCtx);
430 
431     genTileCopy(pCtx,
432                 &tileC,
433                 &tileScratch,
434                 TILECOPY_ADD_ASSIGN );
435     kgenAddStmt(pCtx,
436                 "//Adding the LDS block size in vectors\n");
437     sprintf(tmp,
438             "%s += %d;",
439             pVNames->LDS,
440             vecNumC);
441     kgenAddStmt(pCtx, tmp);
442     kgenAddBlankLine(pCtx);
443 
444     kgenEndBranch( pCtx, NULL ); // merge for()
445     kgenAddBlankLine( pCtx );
446 
447     // Write into global memory -------------------------------
448     if ( NULL != upresProcPtr ) {
449 
450         (*upresProcPtr)( pCtx,
451             funcID,
452             pGSet,
453             upResFlags /*| UPRES_INDEXING_WITH_CONSTANTS*/,
454             NULL,
455             NULL,
456             NULL );
457     }
458 
459     kgenAddBlankLine(pCtx);
460 
461     kgenEndBranch(pCtx, NULL); // merge and global write if
462     kgenEndBranch(pCtx, NULL); // LDS write if
463 
464     kgenAddBarrier(pCtx, CLK_LOCAL_MEM_FENCE);
465     //LDS write for
466     kgenEndBranch(pCtx, NULL);
467 
468 
469     return 0;
470 }
471 
472 //-----------------------------------------------------------------------------
473 
474 int
subgGetDefaultDecomp(PGranularity * pgran,SubproblemDim * subdims,void * pArgs)475 subgGetDefaultDecomp(
476     PGranularity *pgran,
477     SubproblemDim *subdims,
478     void* pArgs )
479 {
480     int itemsPerSubg = 8;
481     int subgA = 4;
482     int subgB = 2;
483 
484     int bw1 = 8;
485     int x1 = 4;
486     int y1 = 4;
487     CLBlasKargs *kargs;
488 
489     if ( NULL == pArgs ) {
490         return -EINVAL;
491     }
492 
493     kargs = (CLBlasKargs *)pArgs;
494 
495     if( isComplexType(kargs->dtype) ){
496         bw1 /= 2;
497     }
498     if( isDoubleBasedType(kargs->dtype) ){
499         bw1 /= 2;
500     }
501 
502     subdims[1].bwidth = bw1;
503     subdims[1].x = subdims[1].itemX = x1;
504     subdims[1].y = subdims[1].itemY = y1;
505 
506     subdims[0].bwidth = bw1 * itemsPerSubg;
507     subdims[0].itemX = x1 * subgB;
508     subdims[0].x = x1*subgB;
509 
510     subdims[0].itemY = y1*subgA;
511     subdims[0].y = y1*subgA;
512 
513     switch ( pgran->wgDim ) {
514 
515         case 1:
516             pgran->wgSize[0] = 64;
517             pgran->wgSize[1] = 1;
518             break;
519 
520         case 2:
521             pgran->wgSize[0] = itemsPerSubg;
522             pgran->wgSize[1] = 64/itemsPerSubg;
523             break;
524 
525         default:
526             pgran->wgSize[0] = 64;
527             pgran->wgSize[1] = 1;
528             break;
529     }
530 
531     return 0;
532 }
533