1 /* ************************************************************************
2 * Copyright 2013 Advanced Micro Devices, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * ************************************************************************/
16
17 #include "blas_subgroup.h"
18 #include <stdio.h>
19 #include <clblas_stddef.h>
20
21 #include <matrix_props.h>
22 #include <matrix_dims.h>
23 #include <dis_warning.h>
24
25 #include "blas_kgen.h"
26 #include "gen_helper.h"
27 #include "tile_iter.h"
28 #include "kerngen.h"
29
30 static int
31 calcMergeStepSubgrN(
32 const BlasGenSettings* pGSet,
33 DataType dtype);
34
35 static int declareSubgrLDS(
36 struct KgenContext* pCtx,
37 const BlasGenSettings* pGSet,
38 DataType dtype);
39
40 //-----------------------------------------------------------------------------
41 // calculates best number of subgroups to be engaged in each merge step
42 // simultaneously
43 // Calculation is based on the register usage estimation
44 // in order not to limit
45 // the number of workgroups scheduled on the SIMD engine
46 static int
calcMergeStepSubgrN(const BlasGenSettings * pGSet,DataType dtype)47 calcMergeStepSubgrN(
48 const BlasGenSettings* pGSet,
49 DataType dtype)
50 {
51 // hardware-specific options
52 const int deviceLDS = 32768;
53 const unsigned int gprsPerUnit = 240;
54
55 int vecLenA = 0;
56 int vecLenB = 0;
57 int vecLenC = 0;
58
59 int vecNumA = 0;
60 int vecNumB = 0;
61 int vecNumC = 0;
62
63 int subgPerStep = 0;
64 int bestLDS = 0;
65 int gprsUsed = 0;
66 int subgNum = 0;
67
68 int itemsPerSubgroup = 0;
69
70 if( NULL == pGSet || NULL == pGSet->pgran ){
71 return -EINVAL;
72 }
73
74 itemsPerSubgroup = pGSet->subdims[0].bwidth/
75 pGSet->subdims[1].bwidth;
76
77 subgNum = (pGSet->subdims[0].x/pGSet->subdims[1].x)*
78 (pGSet->subdims[0].y/pGSet->subdims[1].y);
79
80 vecLenA = pGSet->tileA.vecLen;
81 vecLenB = pGSet->tileBX.vecLen;
82 vecLenC = pGSet->tileCY.vecLen;
83
84 vecNumA = tileVectorsNum( &pGSet->tileA );
85 vecNumB = tileVectorsNum( &pGSet->tileBX );
86 vecNumC = tileVectorsNum( &pGSet->tileCY );
87
88 // registers hold 4-vectors of 32-bit floats or 2-vectors of doubles
89 switch(dtype){
90
91 case TYPE_FLOAT:
92
93 // each register holds 4 4-byte float values
94 // 10 registers are used address, etc
95 gprsUsed = vecNumA * (vecLenA/4) +
96 vecNumB * (vecLenB/4) +
97 vecNumC * (vecLenC/4) + 10;
98
99 bestLDS = deviceLDS/(gprsPerUnit/gprsUsed);
100
101 subgPerStep = bestLDS/(itemsPerSubgroup *
102 vecNumC *
103 vecLenC * 4 );//4-byte floats
104 break;
105
106 case TYPE_DOUBLE:
107
108 // each register can hold 2 double values
109 // 10 registers are used for address, etc
110 gprsUsed = vecNumA * (vecLenA/2) +
111 vecNumB * (vecLenB/2) +
112 vecNumC * (vecLenC/2) + 10;
113
114 bestLDS = deviceLDS/(gprsPerUnit/gprsUsed);
115
116 subgPerStep = bestLDS/(itemsPerSubgroup *
117 vecNumC *
118 vecLenC * 8 );//8-byte doubles
119 break;
120
121 case TYPE_COMPLEX_FLOAT:
122
123 // each register holds 2 4-byte float-based complex values
124 // 10 registers are used address, etc
125 gprsUsed = vecNumA * (vecLenA/2) +
126 vecNumB * (vecLenB/2) +
127 vecNumC * (vecLenC/2) + 10;
128
129 bestLDS = deviceLDS/(gprsPerUnit/gprsUsed);
130
131 subgPerStep = bestLDS/(itemsPerSubgroup *
132 vecNumC *
133 vecLenC * 8 );//2x4-byte floats
134 break;
135
136 case TYPE_COMPLEX_DOUBLE:
137
138 // each register can hold 1 double-based complex value
139 // 10 registers are used for address, etc
140 gprsUsed = vecNumA * (vecLenA) +
141 vecNumB * (vecLenB) +
142 vecNumC * (vecLenC) + 10;
143
144 bestLDS = deviceLDS/(gprsPerUnit/gprsUsed);
145
146 subgPerStep = bestLDS/(itemsPerSubgroup *
147 vecNumC *
148 vecLenC * 16 );//2x8-byte double
149 break;
150
151 default:
152 break ;
153 }
154
155 if( 0==subgPerStep ){
156 subgPerStep = 1;
157 }
158
159 // do not exceed physical number of subgroups in workgroup
160 if( subgPerStep > subgNum ){
161 subgPerStep = subgNum;
162 }
163
164 return subgPerStep;
165 }
166
167 //-----------------------------------------------------------------------------
168 // Add LDS array declaration(based on C matrix parameters) to the context
169 // each row of C Matrix block may be splitted into separate vectors
170
declareSubgrLDS(struct KgenContext * pCtx,const BlasGenSettings * pGSet,DataType dtype)171 static int declareSubgrLDS(
172 struct KgenContext* pCtx,
173 const BlasGenSettings* pGSet,
174 DataType dtype)
175 {
176 int vecLenC = 0;
177 int vecNumC = 0;
178 const char* typeName;
179 const KernelVarNames *vnames = NULL;
180 char tmp[512];
181 int itemsPerSubgroup = 0;
182 int subgrPerStep = 0;
183
184 if( NULL == pCtx || NULL == pGSet ){
185 return -EINVAL;
186 }
187
188 itemsPerSubgroup = pGSet->subdims[0].bwidth / pGSet->subdims[1].bwidth;
189 subgrPerStep = calcMergeStepSubgrN(pGSet, dtype);
190
191 vecLenC = pGSet->tileCY.vecLen;
192 vecNumC = tileVectorsNum( &pGSet->tileCY );
193 typeName = dtypeBuiltinType(dtype);
194 vnames = &pGSet->varNames;
195
196 switch(dtype){
197
198 case TYPE_FLOAT:
199 case TYPE_DOUBLE:
200
201 if( vecLenC > 1){
202 sprintf(
203 tmp,
204 "__local %s%d a%s[%d*%d*%d];\n"
205 "__local %s%d *%s = a%s;\n",
206 typeName,
207 vecLenC,
208 vnames->LDS,
209 itemsPerSubgroup,
210 subgrPerStep,
211 vecNumC,
212 typeName,
213 vecLenC,
214 vnames->LDS,
215 vnames->LDS);
216 }
217 else{
218 sprintf(
219 tmp,
220 "__local %s a%s[%d*%d*%d];\n"
221 "__local %s *%s = a%s;\n",
222 typeName,
223 vnames->LDS,
224 itemsPerSubgroup,
225 subgrPerStep,
226 vecNumC,
227 typeName,
228 vnames->LDS,
229 vnames->LDS);
230 }
231
232 break;
233
234 case TYPE_COMPLEX_FLOAT:
235
236 sprintf(
237 tmp,
238 "__local float%d a%s[%d*%d*%d];\n"
239 "__local float%d *%s = a%s;\n",
240 vecLenC*2,
241 vnames->LDS,
242 itemsPerSubgroup,
243 subgrPerStep,
244 vecNumC,
245 vecLenC*2,
246 vnames->LDS,
247 vnames->LDS);
248
249 break;
250
251 case TYPE_COMPLEX_DOUBLE:
252
253 sprintf(
254 tmp,
255 "__local double%d a%s[%d*%d*%d];\n"
256 "__local double%d *%s = a%s;\n",
257 vecLenC*2,
258 vnames->LDS,
259 itemsPerSubgroup,
260 subgrPerStep,
261 vecNumC,
262 vecLenC*2,
263 vnames->LDS,
264 vnames->LDS);
265
266 break;
267
268 default: // to avoid compilation warning
269 break;
270 }
271
272 kgenAddStmt( pCtx, tmp );
273
274 return 0;
275 }
276
277 //-----------------------------------------------------------------------------
278
279 int
mergeUpdateResult(struct KgenContext * pCtx,BlasFunctionID funcID,struct BlasGenSettings * pGSet,struct SubgVarNames * pSubgVNames,UpdateResultFlags upResFlags,UpresProcPtr upresProcPtr)280 mergeUpdateResult( struct KgenContext* pCtx,
281 BlasFunctionID funcID,
282 struct BlasGenSettings* pGSet,
283 struct SubgVarNames* pSubgVNames,
284 UpdateResultFlags upResFlags,
285 UpresProcPtr upresProcPtr )
286 {
287 char tmp[2048];
288 int subgN = 0;
289 int subgItems = 0;
290 int aBlkH = 0;
291 DataType dtype;
292 Tile tileC;
293 Tile tileScratch;
294 KernelVarNames* pVNames;
295 unsigned int vecLenC;
296 unsigned int vecNumC;
297
298 int subgPerStep = 0;
299
300 if( NULL == pCtx || NULL == pGSet ){
301 return -EINVAL;
302 }
303
304 dtype = pGSet->kextra->dtype;
305 subgN = ( pGSet->subdims[0].x/pGSet->subdims[1].x ) *
306 ( pGSet->subdims[0].y/pGSet->subdims[1].y );
307
308 subgItems = pGSet->subdims[0].bwidth/
309 pGSet->subdims[1].bwidth;
310
311 aBlkH = pGSet->subdims[1].y;
312 pVNames = &pGSet->varNames;
313
314 // calculate best number of subgroups to be engaged in each merge step
315 subgPerStep = calcMergeStepSubgrN( pGSet, dtype );
316
317 vecLenC = pGSet->tileCY.vecLen;
318 vecNumC = tileVectorsNum( &pGSet->tileCY );
319
320 kgenAddStmt(pCtx,"//-----MergeUpdateResult\n");
321 kgenAddBlankLine(pCtx);
322
323 // declare local data storage array
324 kgenAddStmt( pCtx, "// veclenC scratch[SUBG_ITEMS*MSTEP_SUBG*vecNumC]\n");
325 declareSubgrLDS( pCtx,
326 pGSet,
327 dtype);
328
329 kgenAddBlankLine( pCtx );
330
331 kgenAddStmt(pCtx,
332 "//LDS block has the same vectorization as C matrix block\n");
333 kgenAddStmt(
334 pCtx,
335 "//VNUM_C*((get_local_id(1)%MSTEP_SUBG)*SUBG_ITEMS"
336 " +get_local_id(0) );\n");
337
338 sprintf(tmp,
339 "scratch += "
340 "%d*("
341 "(%s.y%%%d)*%d +"
342 "%s.x );\n",
343 vecNumC,
344 pSubgVNames->itemId,
345 subgPerStep,
346 subgItems,
347 pSubgVNames->itemId );
348 kgenAddStmt(pCtx, tmp);
349
350
351 sprintf(
352 tmp,
353 "\nfor( uint mstep = 0; mstep < %d; mstep += %d )",
354 subgN,
355 subgPerStep);
356 kgenBeginBranch(pCtx,tmp);
357 kgenAddBlankLine(pCtx);
358
359 sprintf(
360 tmp,
361 "if( (%s.y >= mstep)&&(%s.y < (mstep+%d)) )",
362 pSubgVNames->itemId,
363 pSubgVNames->itemId,
364 subgPerStep);
365 kgenBeginBranch(pCtx,tmp);
366
367 // the LDS block size is similar to C matrix block size
368 kgenAddBlankLine(pCtx);
369 initTile(&tileC,
370 "c",
371 (unsigned int)pGSet->subdims[1].y,
372 (unsigned int)pGSet->subdims[1].x,
373 vecLenC,
374 dtype,
375 pGSet->tileCY.storType,
376 pGSet->tileCY.trans,
377 pGSet->tileCY.packed);
378
379 initTile(&tileScratch,
380 "scratch",
381 (unsigned int)pGSet->subdims[1].y,
382 (unsigned int)pGSet->subdims[1].x,
383 vecLenC,
384 dtype,
385 PRIV_STORAGE_ARRAY,
386 pGSet->tileCY.trans,
387 pGSet->tileCY.packed);
388
389 genTileCopy(pCtx,
390 &tileScratch,
391 &tileC,
392 TILECOPY_ASSIGN);
393
394 genZeroTile(pCtx,
395 &tileC);
396
397 // split merge if
398 kgenEndBranch( pCtx, NULL ); // merge step if
399 kgenAddBlankLine( pCtx );
400
401 //splitting if on two, to prevent barrier issue
402 kgenAddBarrier( pCtx, CLK_LOCAL_MEM_FENCE );
403 kgenAddBlankLine( pCtx );
404 //----------------------------------------------
405
406 sprintf( tmp,
407 "if( (%s.y >= mstep)&&(%s.y < (mstep+%d)) )",
408 pSubgVNames->itemId,
409 pSubgVNames->itemId,
410 subgPerStep);
411 kgenBeginBranch(pCtx,tmp);
412
413 sprintf( tmp,
414 "if ( 0 == %s.x )",
415 pSubgVNames->itemId );
416 kgenBeginBranch( pCtx, tmp );
417
418 kgenAddBlankLine(pCtx);
419
420 // Zero element of each subgroup also performs LDS merge
421 sprintf(
422 tmp,
423 "for(uint k = 0; k < %d * %d; k += %d)",
424 subgItems,
425 aBlkH,
426 aBlkH);
427
428 kgenBeginBranch(pCtx, tmp);
429 kgenAddBlankLine(pCtx);
430
431 genTileCopy(pCtx,
432 &tileC,
433 &tileScratch,
434 TILECOPY_ADD_ASSIGN );
435 kgenAddStmt(pCtx,
436 "//Adding the LDS block size in vectors\n");
437 sprintf(tmp,
438 "%s += %d;",
439 pVNames->LDS,
440 vecNumC);
441 kgenAddStmt(pCtx, tmp);
442 kgenAddBlankLine(pCtx);
443
444 kgenEndBranch( pCtx, NULL ); // merge for()
445 kgenAddBlankLine( pCtx );
446
447 // Write into global memory -------------------------------
448 if ( NULL != upresProcPtr ) {
449
450 (*upresProcPtr)( pCtx,
451 funcID,
452 pGSet,
453 upResFlags /*| UPRES_INDEXING_WITH_CONSTANTS*/,
454 NULL,
455 NULL,
456 NULL );
457 }
458
459 kgenAddBlankLine(pCtx);
460
461 kgenEndBranch(pCtx, NULL); // merge and global write if
462 kgenEndBranch(pCtx, NULL); // LDS write if
463
464 kgenAddBarrier(pCtx, CLK_LOCAL_MEM_FENCE);
465 //LDS write for
466 kgenEndBranch(pCtx, NULL);
467
468
469 return 0;
470 }
471
472 //-----------------------------------------------------------------------------
473
474 int
subgGetDefaultDecomp(PGranularity * pgran,SubproblemDim * subdims,void * pArgs)475 subgGetDefaultDecomp(
476 PGranularity *pgran,
477 SubproblemDim *subdims,
478 void* pArgs )
479 {
480 int itemsPerSubg = 8;
481 int subgA = 4;
482 int subgB = 2;
483
484 int bw1 = 8;
485 int x1 = 4;
486 int y1 = 4;
487 CLBlasKargs *kargs;
488
489 if ( NULL == pArgs ) {
490 return -EINVAL;
491 }
492
493 kargs = (CLBlasKargs *)pArgs;
494
495 if( isComplexType(kargs->dtype) ){
496 bw1 /= 2;
497 }
498 if( isDoubleBasedType(kargs->dtype) ){
499 bw1 /= 2;
500 }
501
502 subdims[1].bwidth = bw1;
503 subdims[1].x = subdims[1].itemX = x1;
504 subdims[1].y = subdims[1].itemY = y1;
505
506 subdims[0].bwidth = bw1 * itemsPerSubg;
507 subdims[0].itemX = x1 * subgB;
508 subdims[0].x = x1*subgB;
509
510 subdims[0].itemY = y1*subgA;
511 subdims[0].y = y1*subgA;
512
513 switch ( pgran->wgDim ) {
514
515 case 1:
516 pgran->wgSize[0] = 64;
517 pgran->wgSize[1] = 1;
518 break;
519
520 case 2:
521 pgran->wgSize[0] = itemsPerSubg;
522 pgran->wgSize[1] = 64/itemsPerSubg;
523 break;
524
525 default:
526 pgran->wgSize[0] = 64;
527 pgran->wgSize[1] = 1;
528 break;
529 }
530
531 return 0;
532 }
533