1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 
18 /*
19  * common stuff for blas related
20  * kernel generators
21  */
22 
23 #include <string.h>
24 #include <stdlib.h>
25 #include <stdio.h>
26 #include <assert.h>
27 
28 #include <list.h>
29 #include <clblas_stddef.h>
30 
31 #include <matrix_props.h>
32 #include <matrix_dims.h>
33 #include <dis_warning.h>
34 
35 #include "blas_kgen.h"
36 #include "gen_helper.h"
37 #include "tile_iter.h"
38 #include "kerngen.h"
39 
/* Marker for an invalid index: the largest value of unsigned int. */
#define IDX_INVAL (~0u)

enum {
    /* NOTE(review): presumably the size of coordinate string buffers - confirm */
    COORD_STRLEN = 64
};
45 
46 static unsigned int
getTmpVecLen(const BlasGenSettings * gset,UpdateResultFlags uflags,const char ** vecName)47 getTmpVecLen(
48     const BlasGenSettings *gset,
49     UpdateResultFlags uflags,
50     const char **vecName)
51 {
52     const CLBLASKernExtra *kextra = gset->kextra;
53     unsigned int vecLen;
54 
55     if (isComplexType(kextra->dtype) || (uflags & (UPRES_GENERIC |
56                                          UPRES_NO_VECTORIZATION))) {
57         vecLen = 1;
58     }
59     else {
60         vecLen = (gset->flags & BGF_DISTINCT_VECLEN) ? kextra->vecLenC :
61                                                        kextra->vecLen;
62         getVectorTypeName(kextra->dtype, vecLen, vecName, NULL);
63     }
64 
65     return vecLen;
66 }
67 
/*
 * Try to convert a kernel string to an unsigned integer.
 *
 * The string is treated as a number only if strtol() consumes at least one
 * character and stops exactly at the terminating '\0'. On success the value
 * is stored to *num and 0 is returned; otherwise -1 is returned and *num is
 * not modified.
 */
static int
stringToInt(const char *str, unsigned int *num)
{
    char *tail = NULL;
    const long parsed = strtol(str, &tail, 10);

    // nothing was consumed or some characters are left: not a number
    if ((tail == str) || (*tail != '\0')) {
        return -1;
    }

    *num = (unsigned int)parsed;

    return 0;
}
88 
/*
 * Print the component selector for a chunk of an OpenCL vector.
 *
 * If the chunk covers the whole vector (clen == vecLen) an empty string is
 * produced. Otherwise the result is ".s" followed by up to clen hexadecimal
 * component digits starting from component vecOff, e.g. ".s4567" for
 * clen = 4 and vecOff = 4.
 *
 * The buffer must have room for at least clen + 3 characters.
 */
void
sprintfVecChunk(
    char *chunk,
    unsigned int vecLen,
    unsigned int clen,
    unsigned int vecOff)
{
    static const char components[] = "0123456789abcdef";

    if (clen == vecLen) {
        *chunk = '\0';
        return;
    }

    // the size limit cuts the component list after clen digits
    snprintf(chunk, clen + 3, ".s%s", &components[vecOff]);
    chunk[clen + 2] = '\0';
}
106 
107 unsigned int
getVecLen(const BlasGenSettings * gset,BlasFunctionID funcID,MatrixRole mrole)108 getVecLen(const BlasGenSettings *gset, BlasFunctionID funcID, MatrixRole mrole)
109 {
110     unsigned int vecLen = 0;
111     const CLBLASKernExtra *kextra = gset->kextra;
112 
113     DUMMY_ARG_USAGE(funcID);
114 
115     if (!(gset->flags & BGF_DISTINCT_VECLEN)) {
116         vecLen = umin(kextra->vecLenA, kextra->vecLenB);
117         vecLen = umin(vecLen, kextra->vecLenC);
118     }
119     else {
120         switch (mrole) {
121         case MATRIX_A:
122             vecLen = kextra->vecLenA;
123             break;
124         case MATRIX_B:
125             vecLen = kextra->vecLenB;
126             break;
127         case MATRIX_C:
128             vecLen = kextra->vecLenC;
129             break;
130         default:
131             break;
132         }
133     }
134 
135     return vecLen;
136 }
137 
138 void
genScaleLeadingDimensions(struct KgenContext * ctx,const BlasGenSettings * gset)139 genScaleLeadingDimensions(struct KgenContext *ctx, const BlasGenSettings *gset)
140 {
141     const KernelVarNames *kvars;
142     unsigned int vecLen;
143     bool done = false;
144 
145     if (!(gset->flags & BGF_LD_IN_VECTORS)) {
146         return;
147     }
148 
149     kvars = &gset->varNames;
150 
151     vecLen = getVecLen(gset, CLBLAS_GEMM, MATRIX_A);
152     if ((kvars->lda != NULL) && (vecLen > 1)) {
153         kgenPrintf(ctx, "%s /= %u;\n", kvars->lda, vecLen);
154         done = true;
155     }
156 
157     vecLen = getVecLen(gset, CLBLAS_GEMM, MATRIX_B);
158     if ((kvars->ldb != NULL) && (vecLen > 1) && (kvars->ldb != kvars->lda)) {
159         kgenPrintf(ctx, "%s /= %u;\n", kvars->ldb, vecLen);
160         done = true;
161     }
162 
163     vecLen = getVecLen(gset, CLBLAS_GEMM, MATRIX_C);
164     if ((kvars->ldc != NULL) && (vecLen > 1) &&
165         (kvars->ldc != kvars->lda) && (kvars->ldc != kvars->ldb)) {
166 
167         kgenPrintf(ctx, "%s /= %u;\n", kvars->ldc, vecLen);
168         done = true;
169     }
170 
171     if (done) {
172         kgenAddBlankLine(ctx);
173     }
174 }
175 
176 void
getPrivateAreaInfo(const BlasGenSettings * gset,BlasFunctionID funcID,MatrixRole mrole,PrivateArea * area)177 getPrivateAreaInfo(
178     const BlasGenSettings *gset,
179     BlasFunctionID funcID,
180     MatrixRole mrole,
181     PrivateArea *area)
182 {
183     const CLBLASKernExtra *kextra = gset->kextra;
184     const SubproblemDim *dim = &gset->subdims[1];
185 
186     area->vecLen = getVecLen(gset, funcID, mrole);
187     getVectorTypeName(kextra->dtype, area->vecLen, &area->typeName, NULL);
188     if (mrole == MATRIX_C) {
189         area->size = (unsigned int)(divRoundUp(dim->x, area->vecLen) * dim->y);
190     }
191     else {
192         size_t h = (mrole == MATRIX_A) ? dim->y : dim->x;
193 
194         area->size = (unsigned int)(h * dim->bwidth / area->vecLen);
195     }
196 }
197 
198 void
declarePrivateArea(struct KgenContext * ctx,const PrivateArea * area,const char * baseName,PrivateStorageType storType)199 declarePrivateArea(
200     struct KgenContext *ctx,
201     const PrivateArea *area,
202     const char *baseName,
203     PrivateStorageType storType)
204 {
205     char tmp[1024];
206     unsigned int i;
207 
208     // TODO: separate case for size equal to 1
209     if (storType == PRIV_STORAGE_ARRAY) {
210         sprintf(tmp, "%s %s[%u];\n", area->typeName, baseName, area->size);
211     }
212     else {
213         char *p;
214 
215         sprintf(tmp, "%s %s0", area->typeName, baseName);
216         p = tmp + strlen(tmp);
217         for (i = 1; i < area->size; i++) {
218             sprintf(p, ", %s%u", baseName, i);
219             p += strlen(p);
220         }
221         strcpy(p, ";\n");
222     }
223 
224     kgenAddStmt(ctx, tmp);
225 }
226 
227 int
defaultTilePostFetch(struct KgenContext * ctx,MatrixRole mrole,void * priv)228 defaultTilePostFetch(
229     struct KgenContext *ctx,
230     MatrixRole mrole,
231     void *priv)
232 {
233     char tmp[1024], cond[128];
234     Kstring src;
235     TilePostFetchPrivate *pfPriv = (TilePostFetchPrivate*)priv;
236     bool distVect = (pfPriv->gset->flags & BGF_DISTINCT_VECLEN);
237     const KernelVarNames *vnames = &pfPriv->gset->varNames;
238     const CLBLASKernExtra *kextra = pfPriv->gset->kextra;
239     const SubproblemDim *dim = &pfPriv->gset->subdims[1];
240     BlasFunctionID funcID = pfPriv->funcID;
241     const Tile* tile;
242     bool partA;
243     unsigned int step;
244     unsigned int i, j;
245     int ret = 0;
246     unsigned int maxJ = 0;
247     unsigned int maxI = 0;
248 
249     if (!isNeedZeroTileTail(funcID, dim, kextra, mrole, distVect)) {
250         return 0;
251     }
252 
253     if (mrole == MATRIX_A) {
254         tile = &pfPriv->gset->tileA;
255         maxJ = tile->nrCols;
256         maxI = tile->nrRows;
257     }
258     else {
259         tile = &pfPriv->gset->tileBX;
260         maxJ = tile->nrRows;
261         maxI = tile->nrCols;
262     }
263 
264     partA = (mrole == MATRIX_A) && tile->trans &&
265             !(pfPriv->gset->flags & BGF_WHOLE_A);
266     step = tileLineSegmentLen(tile);
267     step = (tile->trans ^ (mrole == MATRIX_A)) ? 1 : step;
268 
269     for (j = 0; (j < maxJ) && !ret; j++) {
270         unsigned int k;
271 
272         k = umax(j, (unsigned int)pfPriv->fetchNumA);
273         if (k) {
274             sprintf(tmp, " + %u", k);
275         }
276         else {
277             tmp[0] = '\0';
278         }
279         sprintf(cond, "(%s%s < %s)", vnames->k, tmp, vnames->sizeK);
280 
281         for (i = 0; (i < maxI) && !ret; i += step) {
282             if (mrole != MATRIX_A) {
283                 sprintfTileElement(&src, tile, j, i, step);
284             }
285             else {
286                 sprintfTileElement(&src, tile, i, j, step);
287             }
288             sprintf(tmp, "%s = %s ? %s : 0;\n", src.buf, cond, src.buf);
289             ret = kgenAddStmt(ctx, tmp);
290         }
291     }
292 
293     if (partA) {
294         pfPriv->fetchNumA++;
295     }
296 
297     if ((tile->nrCols * tile->nrRows / tile->vecLen > 1) && !ret) {
298         ret = kgenAddBlankLine(ctx);
299     }
300 
301     return ret;
302 }
303 
304 char
dtypeToBlasPrefix(DataType dtype)305 dtypeToBlasPrefix(DataType dtype)
306 {
307     char c;
308 
309     if (dtype == TYPE_FLOAT) {
310         c = 's';
311     }
312     else {
313         c = dtypeToPrefix(dtype);
314     }
315 
316     return c;
317 }
318 
319 TileMulFlags
kextraToTilemulFlags(BlasFunctionID funcID,KernelExtraFlags kflags)320 kextraToTilemulFlags(BlasFunctionID funcID, KernelExtraFlags kflags)
321 {
322     TileMulFlags mf = TILEMUL_NO_FLAGS;
323 
324     if (isMatrixAccessColMaj(funcID, kflags, MATRIX_A)) {
325         mf |= TILEMUL_TRA;
326     }
327     if (isMatrixConj(kflags, MATRIX_A)) {
328         mf |= TILEMUL_CONJA;
329     }
330     if (!isMatrixAccessColMaj(funcID, kflags, MATRIX_B)) {
331         mf |= TILEMUL_TRB;
332     }
333     if (isMatrixConj(kflags, MATRIX_B)) {
334         mf |= TILEMUL_CONJB;
335     }
336 
337     return mf;
338 }
339 
340 void
getResultGPRsInfo(DataType dtype,const SubproblemDim * dims,unsigned int vecLen,unsigned int * nrRegs,const char ** typeName)341 getResultGPRsInfo(
342     DataType dtype,
343     const SubproblemDim *dims,
344     unsigned int vecLen,
345     unsigned int *nrRegs,
346     const char **typeName)
347 {
348     if (isComplexType(dtype)) {
349         if (nrRegs) {
350             *nrRegs = (unsigned int)(dims->x * dims->y);
351         }
352         if (typeName != NULL) {
353             *typeName = dtypeBuiltinType(dtype);
354         }
355     }
356     else {
357         // handle different vecLen values and fetch vector sizes
358         if (nrRegs) {
359             *nrRegs = (unsigned int)(divRoundUp(dims->x, vecLen) * dims->y);
360         }
361         if (typeName != NULL) {
362             getVectorTypeName(dtype, vecLen, typeName, NULL);
363         }
364     }
365 }
366 
genVectorCPtr(struct KgenContext * pCtx,const BlasGenSettings * pGSet,const char * GPtrName,const char * VCPtrName)367 static void genVectorCPtr( struct KgenContext *pCtx,
368     const BlasGenSettings *pGSet,
369     const char* GPtrName,
370     const char* VCPtrName )
371 {
372     const char *typeName;
373     unsigned int vecLen = 0;
374 
375     vecLen = getVecLen( pGSet, 0, MATRIX_C );
376     vecLen = vecLen > pGSet->tileCY.vecLen ?
377         pGSet->tileCY.vecLen :
378         vecLen;
379 
380     getVectorTypeName( pGSet->kextra->dtype,
381         vecLen,
382         &typeName,
383         NULL );
384 
385     if ( 0 == (pGSet->flags & BGF_LD_IN_VECTORS) ) {
386 
387         vecLen = 1;
388     }
389     // Blas function ID is omitted
390     if ( isComplexType( pGSet->kextra->dtype ) ) {
391         vecLen *= 2;
392     }
393 
394     if ( isDoubleBasedType(pGSet->kextra->dtype) ) {
395 
396         if ( 1 == vecLen ) {
397 
398             kgenPrintf(
399                 pCtx,
400                 "__global %s *%s = %s.d;\n",
401                 typeName,
402                 VCPtrName,
403                 GPtrName);
404         }
405         else {
406 
407             kgenPrintf( pCtx,
408                 "__global %s *%s = %s.d%dv;\n",
409                 typeName,
410                 VCPtrName,
411                 GPtrName,
412                 vecLen);
413         }
414     }
415     else {
416 
417         if ( 1 == vecLen ) {
418 
419             kgenPrintf(
420                 pCtx,
421                 "__global %s *%s = %s.f;\n",
422                 typeName,
423                 VCPtrName,
424                 GPtrName);
425         }
426         else {
427 
428             kgenPrintf( pCtx,
429                 "__global %s *%s = %s.f%dv;\n",
430                 typeName,
431                 VCPtrName,
432                 GPtrName,
433                 vecLen);
434         }
435     }
436 }
437 
438 static void
updateOptimResultGen(struct KgenContext * pCtx,const BlasGenSettings * pGSet,BlasFunctionID funcID,UpdateResultOp op,UpdateResultFlags flags)439 updateOptimResultGen(
440     struct KgenContext *pCtx,
441     const BlasGenSettings *pGSet,
442     BlasFunctionID funcID,
443     UpdateResultOp op,
444     UpdateResultFlags flags)
445 {
446     KernelExtraFlags kflags = pGSet->kextra->flags;
447     Tile tempCTile;
448     Tile fullCTile;
449     unsigned int physVecLenC;
450     DataType dtype;
451     const KernelVarNames *pVNames = NULL;
452     PhysTileIterator physIter;
453     PhysTileIterator blkIter;
454     char cPtrName[] = "pC";
455     const char *typeNameC;
456     bool phyTrans = 0;
457     unsigned int vecLen = 0;
458     unsigned int nBlocks = 0;
459     unsigned int i = 0;
460 
461     Kstring cElem;
462     Kstring tempCElem;
463     Kstring kstrFirst;
464     Kstring kstrSecond;
465     Kstring kstrThird;
466     Kstring expr;
467 
468     //EINVAL
469     if ( NULL == pCtx ||
470         NULL == pGSet ) {
471 
472         return;
473     }
474 
475     dtype = pGSet->kextra->dtype;
476     pVNames = &pGSet->varNames;
477     phyTrans = ( (flags & UPRES_COLUMN_MAJOR ) != 0 );
478 
479     physVecLenC = getVecLen( pGSet, funcID, MATRIX_C );
480     getVectorTypeName( dtype,
481         getVecLen( pGSet,0,MATRIX_C ),
482         &typeNameC,
483         NULL );
484 
485     // declare private C pointer
486     genVectorCPtr( pCtx, pGSet, "uC", "pC" );
487 
488     kgenAddBlankLine( pCtx );
489 
490     // calculate the number of blocks, update should be divided on
491     nBlocks = pGSet->tileCY.nrCols * pGSet->tileCY.nrRows/(
492         pGSet->tileA.nrCols*pGSet->tileA.nrRows +
493         pGSet->tileBX.nrCols*pGSet->tileBX.nrRows );
494 
495     if( pGSet->tileCY.nrCols * pGSet->tileCY.nrRows%(
496         pGSet->tileA.nrCols*pGSet->tileA.nrRows +
497         pGSet->tileBX.nrCols*pGSet->tileBX.nrRows ) ){
498 
499         nBlocks++;
500     }
501 
502     nBlocks = roundUpPow2( (int)nBlocks );
503 
504     // declare the temporary C tile
505     // temporary C tile must have the same transposition as C matrix
506     // for read-write optimization it also has the same vectorization
507     if ( phyTrans ) {
508 
509         if ( nBlocks > pGSet->tileCY.nrCols ) {
510             nBlocks = pGSet->tileCY.nrCols;
511         }
512 
513         initTile( &tempCTile,
514             "tempC",
515             pGSet->tileCY.nrRows,
516             pGSet->tileCY.nrCols/nBlocks,
517             pGSet->tileCY.vecLen,
518             dtype,
519             PRIV_STORAGE_VARIABLE_SET,
520             phyTrans,
521             true );
522 
523         initTile( &fullCTile,
524             "fullC",
525             pGSet->tileCY.nrRows,
526             pGSet->tileCY.nrCols,
527             pGSet->tileCY.vecLen,
528             dtype,
529             PRIV_STORAGE_VARIABLE_SET,
530             phyTrans,
531             true);
532     }
533     else {
534 
535         if ( nBlocks > pGSet->tileCY.nrRows ) {
536             nBlocks = pGSet->tileCY.nrRows;
537         }
538 
539         initTile( &tempCTile,
540             "tempC",
541             pGSet->tileCY.nrRows/nBlocks,
542             pGSet->tileCY.nrCols,
543             pGSet->tileCY.vecLen,
544             dtype,
545             PRIV_STORAGE_VARIABLE_SET,
546             phyTrans,
547             true );
548 
549         initTile( &fullCTile,
550             "fullC",
551             pGSet->tileCY.nrRows,
552             pGSet->tileCY.nrCols,
553             pGSet->tileCY.vecLen,
554             dtype,
555             PRIV_STORAGE_VARIABLE_SET,
556             phyTrans,
557             true);
558     }
559 
560     declareOneTileStorage( pCtx, &tempCTile );
561 
562     // splitting update result on several blocks to prevent
563     // increasing GPR usage
564     for ( i = 0; i < nBlocks; i++ ) {
565 
566         kgenAddBlankLine(pCtx);
567 
568         // fetch ------------------------------------------------------------------
569         vecLen = umin( physVecLenC, pGSet->tileCY.vecLen );
570         vecLen = umin( vecLen, tileLineSegmentLen(&tempCTile) );
571 
572         iterInit( &blkIter, &tempCTile, vecLen, 0 );
573         iterInit( &physIter, &fullCTile, vecLen, 0 );
574 
575         iterSeekPhys( &physIter, blkIter.nrLines * i, blkIter.vec );
576 
577         if (op == UPRES_SUM) {
578             for ( ; 0 == iterIsEnd( &blkIter ); iterIterate( &blkIter ),
579                                                iterIterate( &physIter ) ) {
580 
581                 emptyKstring( &kstrFirst );
582                 emptyKstring( &kstrSecond );
583                 emptyKstring( &kstrThird );
584                 emptyKstring( &cElem );
585                 emptyKstring( &tempCElem );
586 
587                 sprintfTileElement( &tempCElem,
588                     &tempCTile,
589                     blkIter.row,
590                     blkIter.col,
591                     vecLen);
592 
593                 ksprintf( &kstrFirst, "%d", physIter.line );
594                 ksprintf( &kstrSecond, "%s", pVNames->ldc );
595                 ksprintf( &kstrThird, "%d", blkIter.vec );
596 
597                 sprintfFastScalarMad( &expr,
598                     &kstrFirst,
599                     &kstrSecond,
600                     vecLen,//physVecLenC,//scale ldc
601                     &kstrThird);
602 
603                 kgenPrintf( pCtx,
604                     "%s = %s[%s];\n",
605                     tempCElem.buf,
606                     cPtrName,
607                     expr.buf );
608 
609             }
610         }
611 
612         // beta ---------------------------------------------------------------
613         if ( flags & UPRES_WITH_BETA ) {
614 
615             if ( isComplexType(dtype) ||
616                 ( pGSet->tileCY.trans != tempCTile.trans ) ) {
617                 vecLen = 1;
618             }
619             //TODO: for real datatype find longest available veclen can be used
620             //to generate more compact code
621             else {
622                 vecLen = pGSet->tileCY.vecLen;
623             }
624             vecLen = umin( vecLen, tileLineSegmentLen(&tempCTile) );
625 
626             iterInit( &blkIter, &tempCTile, vecLen, 0 );
627 
628             for ( ; 0 == iterIsEnd( &blkIter ); iterIterate( &blkIter ) ) {
629 
630                 sprintfTileElement( &tempCElem,
631                     &tempCTile,
632                     blkIter.row,
633                     blkIter.col,
634                     vecLen);
635 
636                 if ( isComplexType(dtype) ) {
637                     //complex mad
638                     ksprintf( &kstrSecond, "%s", pVNames->beta );
639                     sprintfComplexMulUpdate( &expr,
640                         &tempCElem,
641                         &tempCElem,
642                         &kstrSecond,
643                         NULL,
644                         isDoubleBasedType(dtype),
645                         0,
646                         0,
647                         0 );
648                     kgenPrintf( pCtx, "%s", expr.buf );
649                 }
650                 else {
651                     if ((kflags & KEXTRA_ENABLE_MAD) != 0) {
652                         kgenPrintf( pCtx,
653                             "%s = mad(%s, %s, 0);\n",
654                             tempCElem.buf,
655                             tempCElem.buf,
656                             pVNames->beta);
657                     }
658                     else {
659                         kgenPrintf( pCtx,
660                             "%s = %s * %s;\n",
661                             tempCElem.buf,
662                             tempCElem.buf,
663                             pVNames->beta);
664                     }
665                 }
666             }
667         }
668 
669         // alpha---------------------------------------------------------------
670         if ( (phyTrans == pGSet->tileCY.trans) && (!isComplexType(dtype)) ) {
671 
672             vecLen = pGSet->tileCY.vecLen;
673         }
674         else {
675             vecLen = 1;
676         }
677         vecLen = umin( vecLen, tileLineSegmentLen(&tempCTile) );
678 
679         iterInit( &blkIter, &tempCTile, vecLen, 0 );
680         iterInit( &physIter, &fullCTile, vecLen, 0 );
681 
682         iterSeekPhys( &physIter, blkIter.nrLines * i, blkIter.vec );
683 
684         for ( ; 0 == iterIsEnd( &blkIter ); iterIterate( &blkIter ),
685                                             iterIterate( &physIter) ) {
686 
687             const Kstring *dst;
688 
689             dst = (flags & UPRES_PRIV_DEST) ? &cElem : &tempCElem;
690 
691             sprintfTileElement( &tempCElem,
692                 &tempCTile,
693                 blkIter.row,
694                 blkIter.col,
695                 vecLen);
696 
697             sprintfTileElement( &cElem,
698                 &pGSet->tileCY,
699                 physIter.row,
700                 physIter.col,
701                 vecLen);
702 
703             // complex
704             if ( isComplexType(dtype) ) {
705 
706                 ksprintf( &kstrSecond, "%s", pVNames->alpha );
707 
708                 // upres op: sum or set, if set, third argument
709                 // of complex mad() is zero
710                 sprintfComplexMulUpdate( &expr,
711                     dst,
712                     &cElem,
713                     &kstrSecond,
714                     (op == UPRES_SUM) ? &tempCElem : NULL,
715                     isDoubleBasedType(dtype),
716                     0,
717                     0,
718                     0);
719                 kgenPrintf( pCtx, "%s", expr.buf );
720 
721             }
722             // real
723             else {
724 
725                 // upres op: sum or set, if set, third argument
726                 // of mad() is zero
727                 if ((kflags & KEXTRA_ENABLE_MAD) != 0) {
728                     kgenPrintf( pCtx,
729                         "%s = mad(%s, %s, %s);\n",
730                         dst,
731                         cElem.buf,
732                         pVNames->alpha,
733                         (op == UPRES_SUM) ? tempCElem.buf : "0" );
734                 }
735                 else {
736                     kgenPrintf( pCtx,
737                         "%s = %s * %s + %s;\n",
738                         dst,
739                         cElem.buf,
740                         pVNames->alpha,
741                         (op == UPRES_SUM) ? tempCElem.buf : "0" );
742                 }
743             }
744         }
745 
746         if (flags & UPRES_PRIV_DEST) {
747             return;
748         }
749 
750         // store---------------------------------------------------------------
751         vecLen = umin( physVecLenC, pGSet->tileCY.vecLen );
752         vecLen = umin( vecLen, tileLineSegmentLen( &tempCTile ) );
753 
754         iterInit( &blkIter, &tempCTile, vecLen, 0 );
755         iterInit( &physIter, &fullCTile, vecLen, 0 );
756 
757         iterSeekPhys( &physIter, blkIter.nrLines * i, blkIter.vec );
758 
759         for ( ; 0 == iterIsEnd( &blkIter ); iterIterate( &blkIter ),
760                                             iterIterate( &physIter ) ) {
761 
762             emptyKstring( &kstrFirst );
763             emptyKstring( &kstrSecond );
764             emptyKstring( &kstrThird );
765             emptyKstring( &cElem );
766             emptyKstring( &tempCElem );
767 
768             sprintfTileElement( &tempCElem,
769                 &tempCTile,
770                 blkIter.row,
771                 blkIter.col,
772                 vecLen);
773 
774             ksprintf( &kstrFirst, "%d", physIter.line );
775             ksprintf( &kstrSecond, "%s", pVNames->ldc );
776             ksprintf( &kstrThird, "%d", blkIter.vec );
777 
778             sprintfFastScalarMad( &expr,
779                 &kstrFirst,
780                 &kstrSecond,
781                 vecLen,//physVecLenC,//scale ldc
782                 &kstrThird);
783 
784             kgenPrintf( pCtx,
785                 "%s[%s] = %s;\n",
786                 cPtrName,
787                 expr.buf,
788                 tempCElem.buf );
789 
790         }
791     }
792 
793 }
794 
795 int
genUpdateResultSingle(struct KgenContext * ctx,const char * dst,const char * src,const BlasGenSettings * gset,UpdateResultOp op,UpdateResultFlags flags)796 genUpdateResultSingle(
797     struct KgenContext *ctx,
798     const char *dst,
799     const char *src,
800     const BlasGenSettings *gset,
801     UpdateResultOp op,
802     UpdateResultFlags flags)
803 {
804     char tmp[1024];
805     char *p;
806     const char *opStr;
807     UpdateResultFlags m;
808     int r;
809     bool isComplex = isComplexType(gset->kextra->dtype);
810 
811     // copy destination with respective operator and additional operations
812     if (flags & UPRES_WITH_BETA) {
813         if (isComplex) {
814             sprintf(tmp, "%s = %s * betaR + %s.yx * betaI + ",
815                     dst, dst, dst);
816         }
817         else {
818             sprintf(tmp, "%s = %s * beta + ", dst, dst);
819         }
820     }
821     else {
822         opStr = (op == UPRES_SET) ? "=" : "+=";
823         sprintf(tmp, "%s %s ", dst, opStr);
824     }
825 
826     m = UPRES_WITH_BETA | UPRES_GENERIC;
827     if (isComplex && ((flags & m) == m)) {
828         strcat(tmp, "\n                    ");
829     }
830     p = tmp + strlen(tmp);
831 
832     // multiply source
833     if (flags & UPRES_WITHOUT_ALPHA) {
834         sprintf(p, "%s;\n", src);
835     }
836     else {
837         if (isComplex) {
838             sprintf(p, "%s * alphaR + %s.yx * alphaI;\n", src, src);
839         }
840         else {
841             sprintf(p, "%s * alpha;\n", src);
842         }
843     }
844 
845     r = kgenAddStmt(ctx, tmp);
846 
847     return (r) ? -EOVERFLOW : 0;
848 }
849 
850 static void
updateGenericResultGen(struct KgenContext * ctx,const BlasGenSettings * gset,size_t pitch,UpresVarNames * uvars,UpdateResultOp op,UpdateResultFlags flags,const char * cachedName)851 updateGenericResultGen(
852     struct KgenContext *ctx,
853     const BlasGenSettings *gset,
854     size_t pitch,
855     UpresVarNames* uvars,
856     UpdateResultOp op,
857     UpdateResultFlags flags,
858     const char *cachedName)
859 {
860     char tmp[1024], dst[128], src[128];
861     const char *boundNames[2] = {uvars->nrRows, uvars->nrCols};
862     const char *vecType = NULL;
863     const char *vFieldVectorized;
864     DataType dtype = gset->kextra->dtype;
865     unsigned int wvlen;
866     unsigned int sizes[2];
867     const char*  vfield = dtypeUPtrField(dtype);
868     bool tra = ((flags & UPRES_COLUMN_MAJOR) != 0);
869     bool row = ((flags & UPRES_TAIL_ROW));
870     bool col = ((flags & UPRES_TAIL_COL));
871     bool iwc = ((flags & UPRES_INDEXING_WITH_CONSTANTS) != 0) ||
872                 (gset->tileCY.storType != PRIV_STORAGE_ARRAY);
873     int l0;
874     int l1;
875     bool revert = false;
876 
877     Kstring kstr;
878     int rowId;
879     int colId;
880 
881     sizes[0] = (unsigned int)gset->subdims[1].y;
882     sizes[1] = (unsigned int)gset->subdims[1].x;
883 
884     if (iwc) {
885         const char* l0var =  boundNames[tra];
886         revert =  (tra && col) || (!tra && row);
887 
888         if (revert) {
889             sprintf(tmp, "uC.%s += (%s-1) * %s;\n", vfield, l0var, uvars->ld);
890         }
891         else {
892             sprintf(tmp, "\n");
893         }
894         kgenAddStmt(ctx, tmp);
895 
896     }
897     wvlen = getTmpVecLen(gset, flags, &vecType);
898     if (!iwc) {
899         getVectorTypeName(dtype, wvlen, NULL, &vFieldVectorized);
900         sprintf(tmp, "res.%s = c;\n", vFieldVectorized);
901         kgenAddStmt(ctx, tmp);
902     }
903 
904     if (flags & (UPRES_TAIL_ROW | UPRES_TAIL_COL)) {
905         char offStr[64];
906         char *p = offStr;
907 
908         offStr[0] = '\0';
909         if (flags & UPRES_TAIL_ROW) {
910             sprintf(offStr, " + (%u - %s) * %lu",
911                     sizes[0], uvars->nrRows, pitch);
912             p += strlen(offStr);
913         }
914         if (flags & UPRES_TAIL_COL) {
915             sprintf(p, " + (%u - %s)", sizes[1], uvars->nrCols);
916         }
917         if (iwc) {
918             sprintf(tmp, "res.%s = uC.%s%s;\n", vfield, vfield, offStr);
919             sprintf(tmp, "\n");
920         }
921         else {
922             sprintf(tmp, "res.%s = res.%s%s;\n", vfield, vfield, offStr);
923         }
924         kgenAddStmt(ctx, tmp);
925 
926     }
927     if (iwc) {
928         int l0st = 1; int l0en = sizes[tra];
929         int l1st = 1; int l1en = sizes[1-tra];
930 
931         const char* l0var =  boundNames[tra];
932         const char* l1var = boundNames[1-tra];
933 
934         for (l0 = l0en; l0 >= l0st; l0--) {
935 
936             sprintf(tmp, "if (%s) ",l0var);
937             kgenBeginBranch(ctx, tmp);
938 
939             sprintf(tmp, "switch (%s)", l1var);
940             kgenBeginBranch(ctx, tmp);
941 
942             for (l1 = l1en; l1 >= l1st; l1--) {
943                 sprintf(tmp, "case %d:\n", l1);
944                 kgenAddStmt(ctx, tmp);
945 
946                 if (tra) {
947                     rowId = (row)? (l1en-l1): (l1-l1st);
948                     colId = (col)? (l0-l0st): (l0en-l0);
949                 }
950                 else {
951                     ///////////////////////////
952                     rowId = (row)? (l0-l0st): (l0en-l0);
953                     colId = (col)? (l1en-l1) : (l1-l1st);
954                 }
955 
956                 if ((tra && row) || (!tra && col)) {
957                      sprintf(dst, "uC.%s[(%s+%d) %% %i]",
958                              vfield, l1var, (l1en - l1),  (int)l1en);
959                 }
960                 else {
961                    sprintf(dst, "uC.%s[%d]", vfield, (l1-l1st));
962                 }
963 
964                 sprintfTileElement(&kstr, &gset->tileCY, rowId, colId, wvlen);
965 
966                 if (flags & UPRES_PRIV_DEST) {
967                     genUpdateResultSingle(ctx, kstr.buf, dst, gset, op, flags);
968                 }
969                 else {
970                     genUpdateResultSingle(ctx, dst, kstr.buf, gset, op, flags);
971                 }
972             }
973             kgenEndBranch(ctx, NULL);
974 
975             if (revert) {
976                 sprintf(tmp, "uC.%s -= %s;\n", vfield, uvars->ld);
977             }
978             else {
979                 sprintf(tmp, "uC.%s += %s;\n", vfield, uvars->ld);
980             }
981 
982             kgenAddStmt(ctx, tmp);
983 
984             sprintf(tmp, "%s--;\n", l0var);
985             kgenAddStmt(ctx, tmp);
986             kgenEndBranch(ctx, NULL);
987         }
988 
989     }
990     else {
991         sprintf(tmp, "for (i = 0; i < %s; i++)", boundNames[tra]);
992         kgenBeginBranch(ctx, tmp);
993         sprintf(tmp, "for (j = 0; j < %s; j++)", boundNames[1 - tra]);
994         kgenBeginBranch(ctx, tmp);
995         sprintf(dst, "uC.%s[i * %s + j]", vfield, uvars->ld);
996         if (cachedName) {
997             unsigned int i;
998             char tmpcachedName[80] = " = ";
999             strcat(tmpcachedName, cachedName);
1000             for (i = 3; i < strlen(tmpcachedName); i++) {
1001                 if (strncmp(tmpcachedName+i, "%u", 2) == 0) {
1002                     tmpcachedName[i+1] = 's';
1003                 }
1004             }
1005             sprintf(tmp, tmpcachedName, "i", "[j]");
1006             strcat(dst, tmp);
1007         }
1008         // result (res) can be transposed independently of the matrix C
1009         // If the transposition of "C" and "result" is not consistent
1010         // then change the calculation of the index for "result"
1011         if (gset->tileCY.trans ^ tra) {
1012             sprintf(src, "res.%s[j * %lu + i]", vfield, pitch);
1013         }
1014         else {
1015             sprintf(src, "res.%s[i * %lu + j]", vfield, pitch);
1016         }
1017         if (flags & UPRES_PRIV_DEST) {
1018             genUpdateResultSingle(ctx, src, dst, gset, op, flags);
1019         }
1020         else {
1021             genUpdateResultSingle(ctx, dst, src, gset, op, flags);
1022         }
1023         kgenEndBranch(ctx, NULL);
1024         kgenEndBranch(ctx, NULL);
1025     }
1026 }
1027 
1028 //-----------------------------------------------------------------------------
1029 
1030 int
updateResultGen(struct KgenContext * ctx,const BlasGenSettings * gset,BlasFunctionID funcID,UpdateResultOp op,UpdateResultFlags flags,const UpresVarNames * uvarNames)1031 updateResultGen(
1032     struct KgenContext *ctx,
1033     const BlasGenSettings *gset,
1034     BlasFunctionID funcID,
1035     UpdateResultOp op,
1036     UpdateResultFlags flags,
1037     const UpresVarNames *uvarNames)
1038 {
1039     char tmp[1024];
1040     char *p = tmp;
1041     const char *typeName;
1042     const char *vecType = NULL;
1043     const char *vfield;
1044     const char *suff1;
1045     const char *suff2;
1046     int ret = 0;
1047     unsigned int sizes[2];
1048     bool generic, tra;
1049     unsigned int wvlen;     // length of vectors to copy with
1050     unsigned int uplen;     // length of vectors to update result with
1051     size_t pitch;
1052     char LG;
1053     DataType dtype = gset->kextra->dtype;
1054     unsigned int vecLen;
1055     bool isInlined = (flags & UPRES_INLINE);
1056     UpresVarNames uvars;
1057 
1058     vecLen = (gset->flags & BGF_DISTINCT_VECLEN) ? gset->kextra->vecLenC :
1059                                                    gset->kextra->vecLen;
1060     sizes[0] = (unsigned int)gset->subdims[1].y;
1061     sizes[1] = (unsigned int)gset->subdims[1].x;
1062 
1063     if (isComplexType(dtype)) {
1064         vecLen = 1;
1065     }
1066 
1067     if ((flags & UPRES_WITH_BETA) && (op != UPRES_SUM)) {
1068         return -EINVAL;
1069     }
1070 
1071     tra = ((flags & UPRES_COLUMN_MAJOR) != 0);
1072     generic = ((flags & UPRES_GENERIC) != 0);
1073     typeName = dtypeBuiltinType(dtype);
1074     vfield = dtypeUPtrField(dtype);
1075     pitch = roundUp(sizes[1], vecLen);
1076 
1077     // select write vectorization
1078     wvlen = getTmpVecLen(gset, flags, &vecType);
1079     uplen = (tra ^ gset->tileCY.trans
1080              || (flags & UPRES_NO_VECTORIZATION)) ? 1 : vecLen;
1081 
1082     suff1 = (generic) ? "Generic" : "";
1083     suff2 = (flags & UPRES_PRIV_DEST) ? "Rev" : "";
1084     LG = (flags & UPRES_USE_LDS) ? 'L' : 'G';
1085 
1086     if (!isInlined) {
1087         const char *outTypeName;
1088         const char *memPref = (flags & UPRES_USE_LDS) ? "__local" :
1089                                                            "__global";
1090 
1091         getResultGPRsInfo(dtype, NULL, vecLen, NULL, &outTypeName);
1092 
1093         // define the function
1094         sprintf(tmp, "void\n"
1095                      "updateResult%s%s%c(\n"
1096                      "    %s %s *C,\n"
1097                      "    %s *c,\n"
1098                      "    %s alpha,\n"
1099                      "    uint startRow,\n"
1100                      "    uint startCol,\n"
1101                      "    uint ld",
1102                      suff1, suff2, LG, memPref, typeName,
1103                      outTypeName, typeName);
1104 
1105         p += strlen(p);
1106         if (flags & UPRES_WITH_BETA) {
1107             sprintf(p, ",\n    %s beta", typeName);
1108             p += strlen(p);
1109         }
1110         if (generic) {
1111             sprintf(p, ",\n    uint nrRows,\n"
1112                        "    uint nrCols");
1113         }
1114 
1115         uvars.result = "C";
1116         uvars.ld = "ld";
1117         uvars.startRow = "startRow";
1118         uvars.startCol = "startCol";
1119         uvars.nrRows = "nrRows";
1120         uvars.nrCols = "nrCols";
1121 
1122         strcat(p, ")\n");
1123         kgenDeclareFunction(ctx, tmp);
1124         kgenBeginFuncBody(ctx);
1125     }
1126     else {
1127         memcpy(&uvars, uvarNames, sizeof(uvars));
1128     }
1129 
1130     // declare local variables
1131     sprintf(tmp, "%cPtr uC;\n", LG);
1132     kgenAddStmt(ctx, tmp);
1133     if (generic) {
1134         kgenAddStmt(ctx, "int i, j;\n"
1135                          "PPtr res;\n");
1136     }
1137     else {
1138         /*
1139          * temporary pointer to pass correctly over the
1140          * destination array since destination rows can be
1141          * not aligned on a vector bound
1142          */
1143         if (sizes[1 - tra] % wvlen != 0) {
1144             sprintf(tmp, "%cPtr tmpC;\n", LG);
1145             kgenAddStmt(ctx, tmp);
1146         }
1147         if (wvlen > uplen) {
1148             sprintf(tmp, "%s tmp;\n", vecType);
1149             kgenAddStmt(ctx, tmp);
1150         }
1151     }
1152     if (isComplexType(dtype) && !(flags & UPRES_WITHOUT_ALPHA)) {
1153         declareComplexMultParts(ctx, "alpha", typeName);
1154         if (flags & UPRES_WITH_BETA) {
1155             declareComplexMultParts(ctx, "beta", typeName);
1156         }
1157 
1158     }
1159     kgenAddBlankLine(ctx);
1160 
1161     // LD is scaled
1162     if ( gset->flags & BGF_LD_IN_VECTORS ) {
1163 
1164         vecLen = getVecLen(gset, 0, MATRIX_C);
1165     }
1166     else {
1167 
1168         vecLen = 1;
1169     }
1170 
1171     if (tra) {
1172 
1173         if ( vecLen > 1 ) {
1174 
1175             sprintf(tmp,
1176                 "uC.%s = %s + (%s * %s + %s)/%d;\n",
1177                 vfield,
1178                 uvars.result,
1179                 uvars.startCol,
1180                 uvars.ld,
1181                 uvars.startRow,
1182                 vecLen);
1183         }
1184         else {
1185 
1186             sprintf(tmp,
1187                 "uC.%s = %s + %s * %s + %s;\n",
1188                 vfield,
1189                 uvars.result,
1190                 uvars.startCol,
1191                 uvars.ld,
1192                 uvars.startRow);
1193         }
1194     }
1195     else {
1196 
1197         if ( vecLen > 1 ) {
1198 
1199             sprintf(tmp,
1200                 "uC.%s = %s + (%s * %s + %s)/%d;\n",
1201                 vfield,
1202                 uvars.result,
1203                 uvars.startRow,
1204                 uvars.ld,
1205                 uvars.startCol,
1206                 vecLen);
1207 
1208         }
1209         else {
1210 
1211             sprintf(tmp,
1212                 "uC.%s = %s + %s * %s + %s;\n",
1213                 vfield,
1214                 uvars.result,
1215                 uvars.startRow,
1216                 uvars.ld,
1217                 uvars.startCol);
1218         }
1219     }
1220     kgenAddStmt(ctx, tmp);
1221 
1222     if ((sizes[1 - tra] % wvlen != 0) && !generic) {
1223         kgenAddStmt(ctx, "tmpC = uC;\n");
1224     }
1225     ret = kgenAddBlankLine(ctx);
1226 
1227     if (generic) {
1228         updateGenericResultGen(ctx, gset, pitch, &uvars, op, flags,
1229                                uvarNames ? uvarNames->cachedName : NULL);
1230     }
1231     else {
1232         updateOptimResultGen(ctx,
1233         gset,
1234         funcID,
1235         op,
1236         flags);
1237     }
1238 
1239     if (!isInlined) {
1240         ret = kgenEndFuncBody(ctx);
1241     }
1242 
1243     return (ret) ? -EOVERFLOW : 0;
1244 }
1245 
1246 TailFetch
checkForTailFetches(BlasFunctionID funcID,const SubproblemDim * dim,const CLBLASKernExtra * kextra,MatrixRole mrole,bool distVect,bool lowerTails)1247 checkForTailFetches(
1248     BlasFunctionID funcID,
1249     const SubproblemDim *dim,
1250     const CLBLASKernExtra *kextra,
1251     MatrixRole mrole,
1252     bool distVect,
1253     bool lowerTails)
1254 {
1255     TailFetch ret = FETCH_NO_TAILS;
1256     size_t x;
1257     KernelExtraFlags tailFlag;
1258     unsigned int vecLen;
1259     KernelExtraFlags tailFlagM, tailFlagN, tailFlagK;
1260 
1261     tailFlagM = lowerTails ? KEXTRA_TAILS_M_LOWER : KEXTRA_TAILS_M;
1262     tailFlagN = lowerTails ? KEXTRA_TAILS_N_LOWER : KEXTRA_TAILS_N;
1263     tailFlagK = lowerTails ? KEXTRA_TAILS_K_LOWER : KEXTRA_TAILS_K;
1264 
1265     if (mrole == MATRIX_A) {
1266         x = dim->y;
1267         tailFlag = tailFlagM;
1268         vecLen = (distVect) ? kextra->vecLenA : kextra->vecLen;
1269     }
1270     else {
1271         x = dim->x;
1272         tailFlag = tailFlagN;
1273         vecLen = (distVect) ? kextra->vecLenB : kextra->vecLen;
1274     }
1275 
1276     if (isMatrixAccessColMaj(funcID, kextra->flags, mrole)) {
1277         if ((kextra->flags & tailFlag) && (x != vecLen)) {
1278             ret |= FETCH_TAIL_COL;
1279         }
1280         if (kextra->flags & tailFlagK) {
1281             ret |= FETCH_TAIL_ROW;
1282         }
1283     }
1284     else if (kextra->flags & tailFlagK) {
1285         ret |= FETCH_TAIL_COL;
1286     }
1287 
1288     return ret;
1289 }
1290 
1291 bool
isNeedZeroTileTail(BlasFunctionID funcID,const SubproblemDim * dim,const CLBLASKernExtra * kextra,MatrixRole mrole,bool distVect)1292 isNeedZeroTileTail(
1293     BlasFunctionID funcID,
1294     const SubproblemDim *dim,
1295     const CLBLASKernExtra *kextra,
1296     MatrixRole mrole,
1297     bool distVect)
1298 {
1299     bool trans;
1300     TailFetch tf;
1301 
1302     trans = isMatrixAccessColMaj(funcID, kextra->flags, mrole);
1303     tf = checkForTailFetches(funcID, dim, kextra, mrole, distVect, true);
1304 
1305     return (trans && (tf & FETCH_TAIL_ROW)) ||
1306            (!trans && (tf & FETCH_TAIL_COL));
1307 }
1308 
1309 TailStatus
checkGenAdjustTailCoords(struct KgenContext * ctx,BlasFunctionID funcID,const BlasGenSettings * gset,int * error)1310 checkGenAdjustTailCoords(
1311     struct KgenContext *ctx,
1312     BlasFunctionID funcID,
1313     const BlasGenSettings *gset,
1314     int *error)
1315 {
1316     char tmp[1024];
1317     const SubproblemDim *dim = &gset->subdims[1];
1318     const KernelVarNames *varNames = &gset->varNames;
1319     KernelExtraFlags kflags = gset->kextra->flags;
1320     TailStatus status = 0;
1321     int err = 0;
1322     int n = 0;
1323 
1324     if (!isMatrixAccessColMaj(funcID, kflags, MATRIX_A) &&
1325         (kflags & KEXTRA_TAILS_M_LOWER)) {
1326 
1327         status |= TAIL_A_RAISED;
1328         sprintf(tmp, "if (%s + %lu > %s) {\n"
1329                      "    %s -= %lu - %s %% %lu;\n"
1330                      "}\n",
1331                 varNames->coordA, dim->y, varNames->sizeM,
1332                 varNames->coordA, dim->y, varNames->sizeM,
1333                 dim->y);
1334         if (ctx != NULL) {
1335             err = kgenAddStmt(ctx, tmp);
1336             n++;
1337         }
1338     }
1339 
1340     if (!isMatrixAccessColMaj(funcID, kflags, MATRIX_B) &&
1341         (kflags & KEXTRA_TAILS_N_LOWER) && !err) {
1342 
1343         status |= TAIL_B_RAISED;
1344         sprintf(tmp, "if (%s + %lu > %s) {\n"
1345                      "    %s -= %lu - %s %% %lu;\n"
1346                      "}\n",
1347                 varNames->coordB, dim->x, varNames->sizeN,
1348                 varNames->coordB, dim->x, varNames->sizeN,
1349                 dim->x);
1350         if (ctx != NULL) {
1351             err = kgenAddStmt(ctx, tmp);
1352             n++;
1353         }
1354     }
1355 
1356     if (n && !err) {
1357         err = kgenAddBlankLine(ctx);
1358     }
1359 
1360     if (error != NULL) {
1361         *error = err;
1362     }
1363 
1364     return status;
1365 }
1366 
1367 int
checkGenRestoreTailCoords(struct KgenContext * ctx,const BlasGenSettings * gset,TailStatus status)1368 checkGenRestoreTailCoords(
1369     struct KgenContext *ctx,
1370     const BlasGenSettings *gset,
1371     TailStatus status)
1372 {
1373     char tmp[1024];
1374     const SubproblemDim *dim = &gset->subdims[1];
1375     const KernelVarNames *varNames = &gset->varNames;
1376     int ret = 0;
1377     int n = 0;
1378 
1379     if (status & TAIL_A_RAISED) {
1380         sprintf(tmp, "if ((%s + %lu == %s) && (%s %% %lu)) {\n"
1381                      "    %s += %lu - %s %% %lu;\n"
1382                      "}\n",
1383                 varNames->coordA, dim->y, varNames->sizeM,
1384                 varNames->sizeM, dim->y, varNames->coordA,
1385                 dim->y, varNames->sizeM, dim->y);
1386         ret = kgenAddStmt(ctx, tmp);
1387         n++;
1388     }
1389 
1390     if ((status & TAIL_B_RAISED) && !ret) {
1391 
1392         sprintf(tmp, "if ((%s + %lu == %s) && (%s %% %lu)) {\n"
1393                      "    %s += %lu - %s %% %lu;\n"
1394                      "}\n",
1395                 varNames->coordB, dim->x, varNames->sizeN,
1396                 varNames->sizeN, dim->x, varNames->coordB,
1397                 dim->x, varNames->sizeN, dim->x);
1398         kgenAddStmt(ctx, tmp);
1399         n++;
1400     }
1401 
1402     if (n) {
1403         ret = kgenAddBlankLine(ctx);
1404     }
1405 
1406     return (ret) ? -EOVERFLOW : 0;
1407 }
1408 
1409 UpdateResultFlags
tailStatusToUpresFlags(TailStatus status)1410 tailStatusToUpresFlags(TailStatus status)
1411 {
1412     UpdateResultFlags flags = 0;
1413 
1414     if (status & TAIL_A_RAISED) {
1415         flags |= UPRES_TAIL_ROW;
1416     }
1417     if (status & TAIL_B_RAISED) {
1418         flags |= UPRES_TAIL_COL;
1419     }
1420 
1421     return flags;
1422 }
1423 
/*
 * Add declarations of two variables named "<baseName>R" and "<baseName>I"
 * of the type 'typeName' to the generated code. The first is initialized
 * with (typeName)(baseName.x), the second with
 * (typeName)(-baseName.y, baseName.y).
 *
 * Returns 0 on success, -EOVERFLOW if the statement could not be added.
 */
int
declareComplexMultParts(
    struct KgenContext *ctx,
    const char *baseName,
    const char *typeName)
{
    char stmt[1024];
    int len;

    // the real part declaration
    len = sprintf(stmt, "%s %sR = (%s)(%s.x);\n",
                  typeName, baseName, typeName, baseName);
    // the imaginary part declaration follows immediately
    sprintf(stmt + len, "%s %sI = (%s)(-%s.y, %s.y);\n",
            typeName, baseName, typeName, baseName, baseName);

    if (kgenAddStmt(ctx, stmt)) {
        return -EOVERFLOW;
    }

    return 0;
}
1441 
1442 void
sprintfFastScalarMad(Kstring * expr,const Kstring * first,const Kstring * second,unsigned int scale,const Kstring * third)1443 sprintfFastScalarMad(
1444     Kstring *expr,
1445     const Kstring *first,
1446     const Kstring *second,
1447     unsigned int scale,
1448     const Kstring *third)
1449 {
1450     unsigned int u1 = 0, u2 = 0, u3 = 0;
1451     bool isNum1, isNum2, isNum3;
1452     int shift;
1453     bool done = false;
1454     const char *thirdStr;
1455     const char *suff3;
1456 
1457     // clear up what are these arguments
1458     if (isKstringEmpty(first)) {
1459         isNum1 = true;
1460     }
1461     else {
1462         isNum1 = !stringToInt(first->buf, &u1);
1463     }
1464 
1465     if (isKstringEmpty(second)) {
1466         isNum2 = true;
1467     }
1468     else {
1469         isNum2 = !stringToInt(second->buf, &u2);
1470     }
1471 
1472     if (!scale) {
1473         scale = 1;
1474     }
1475 
1476     if ((third == NULL) || isKstringEmpty(third)) {
1477         thirdStr = "0";
1478         isNum3 = true;
1479     }
1480     else {
1481         thirdStr = third->buf;
1482         isNum3 = !stringToInt(thirdStr, &u3);
1483     }
1484     suff3 = (isNum3) ? "u" : "";
1485 
1486     // singular case at which only the third component can contribute
1487     if ( (isNum1 && (u1 == 0)) ||
1488          (isNum2 && (u2 /scale == 0))) {
1489 
1490         kstrcpy(expr, thirdStr);
1491         return;
1492     }
1493 
1494     if (isNum1 && isNum2) {
1495         if (isNum3) {
1496             ksprintf(expr, "%u", u1 * u2 / scale + u3);
1497         }
1498         else {
1499             ksprintf(expr, "%u + %s", u1 * u2 / scale, thirdStr);
1500         }
1501         done = true;
1502     }
1503     else if (isNum1) {
1504         /*
1505          * If the third argument is not used, then try to build the expression
1506          * using only shifts if 'scale' and the 'second argument' are both of
1507          * power of 2. Otherwise use mad24.
1508          */
1509         if (isRoundedPow2(u1) && isRoundedPow2(scale)) {
1510             shift = findHighestSetBit(scale) - findHighestSetBit(u1);
1511             if (isNum3 && (u3 == 0)) {
1512                 if (shift < 0) {
1513                     ksprintf(expr, "(%s << %d)", second->buf, -shift);
1514                 }
1515                 else if (shift > 0) {
1516                     ksprintf(expr, "(%s >> %d)", second->buf, shift);
1517                 }
1518                 else {
1519                     kstrcpy(expr, second->buf);
1520                 }
1521             }
1522             else if (shift > 0) {
1523                 ksprintf(expr, "(%s >> %d) + %s",
1524                          second->buf, shift, thirdStr);
1525             }
1526             else if (shift == 0) {
1527                 ksprintf(expr, "%s + %s", second->buf, thirdStr);
1528             }
1529             else {
1530                 ksprintf(expr, "mad24(%uu, %s, %s%s)",
1531                          1u << -shift, second->buf, thirdStr, suff3);
1532             }
1533             done = true;
1534         }
1535     }
1536 
1537     if (!done) {
1538         /*
1539          * Append unsiged suffixes to avoid cases at which one
1540          * operand is signed and the other is unsigned. Typically,
1541          * OpenCL compilers are strict and reject such expressions.
1542          */
1543         if (isNum2) {
1544             if (u2 / scale == 1) {
1545                 if (isNum3 && (u3 == 0)) {
1546                     kstrcpy(expr, first->buf);
1547                 }
1548                 else {
1549                     ksprintf(expr, "%s + %s", first->buf, thirdStr);
1550                 }
1551             }
1552             else {
1553                 ksprintf(expr, "mad24(%s, %uu, %s%s)",
1554                          first->buf, u2 / scale, thirdStr, suff3);
1555             }
1556         }
1557         else {
1558             const char *suff1 = (isNum1) ? "u" : "";
1559             Kstring tmp;
1560             const char *p = NULL;
1561 
1562             if (scale == 1) {
1563                 p = second->buf;
1564             }
1565             else {
1566                 p = tmp.buf;
1567                 if (isRoundedPow2(scale)) {
1568                     shift = findHighestSetBit(scale);
1569                     ksprintf(&tmp, "(%s >> %d)", second->buf, shift);
1570                 }
1571                 else {
1572                     ksprintf(&tmp, "%s / %d", second->buf, scale);
1573                 }
1574             }
1575 
1576             ksprintf(expr, "mad24(%s%s, %s, %s%s)",
1577                      first->buf, suff1, p, thirdStr, suff3);
1578         }
1579     }
1580 }
1581