1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 
18 #include <stdio.h>
19 
20 #include "../blas_kgen.h"
21 #include "trsm_kgen_legacy.h"
22 
23 void
genUpdateIntermTrsmResult(struct KgenContext * ctx,const BlasGenSettings * gset,const char * optFuncName,const char * genericFuncName,bool withMhitCond)24 genUpdateIntermTrsmResult(
25     struct KgenContext *ctx,
26     const BlasGenSettings *gset,
27     const char *optFuncName,
28     const char *genericFuncName,
29     bool withMhitCond)
30 {
31     char tmp[1024];
32     const char *coordY, *coordX;
33     char *revAlp, *alp;
34     DataType dtype = gset->kextra->dtype;
35     KernelExtraFlags kflags = gset->kextra->flags;
36     const SubproblemDim *dim = &gset->subdims[1];
37     const KernelVarNames *kvarNames = &gset->varNames;
38 
39     if (isComplexType(dtype)) {
40         if (dtype == TYPE_COMPLEX_FLOAT) {
41             revAlp = "div((float2)(-1.f, 0), alpha)";
42             alp = "(float2)(1.f, 0)";
43         }
44         else {
45             revAlp = "div((double2)(-1., 0), alpha)";
46             alp = "(double2)(1., 0)";
47         }
48     }
49     else {
50         revAlp = "-1. / alpha";
51         alp = "1.";
52     }
53 
54     coordY = kvarNames->coordA;
55     coordX = kvarNames->coordB;
56 
57     if (!(kflags & (KEXTRA_TAILS_M | KEXTRA_TAILS_N))) {
58         sprintf(tmp, "%s(B, c, %s, %s, %s, ldb, %s);\n",
59                 optFuncName, alp, coordY, coordX, revAlp);
60         kgenAddStmt(ctx, tmp);
61     }
62     else {
63         if (withMhitCond) {
64             sprintf(tmp, "if ((%s < %s) && (%s < %s))",
65                     coordY, kvarNames->sizeM, coordX, kvarNames->sizeN);
66             kgenBeginBranch(ctx, tmp);
67         }
68         else {
69             /* for x, y variables scope */
70             kgenBeginBranch(ctx, NULL);
71         }
72 
73         sprintf(tmp, "uint y = min(%luu, %s - (uint)%s);\n"
74                      "uint x = min(%luu, %s - (uint)%s);\n"
75                      "if ((y == %luu) && (x == %luu)) {\n"
76                      "    %s(B, c, %s, %s, %s, ldb, %s);\n"
77                      "}\n"
78                      "else {\n"
79                      "    %s(B, c, %s, %s, %s, ldb, %s, y, x);\n"
80                      "}\n",
81                 dim->y, kvarNames->sizeM, coordY,
82                 dim->x, kvarNames->sizeN, coordX,
83                 dim->y, dim->x,
84                 optFuncName, alp, coordY, coordX, revAlp,
85                 genericFuncName, alp, coordY, coordX, revAlp);
86 
87         kgenAddStmt(ctx, tmp);
88 
89         kgenEndBranch(ctx, NULL);
90     }
91 }
92 
93 void
genHeapTrsmResultToLDS(struct KgenContext * ctx,const BlasGenSettings * gset,const char * funcName,const char * dstName)94 genHeapTrsmResultToLDS(
95     struct KgenContext *ctx,
96     const BlasGenSettings *gset,
97     const char *funcName,
98     const char *dstName)
99 {
100     char tmp[1024];
101     char *alp;
102     unsigned int l1Pans;
103     DataType dtype = gset->kextra->dtype;
104     const SubproblemDim *dims = gset->subdims;
105 
106     if(isComplexType(dtype)) {
107         if (dtype == TYPE_COMPLEX_FLOAT) {
108             alp = "(float2)(1.f, 0)";
109         }
110         else {
111             alp = "(double2)(1., 0)";
112         }
113     }
114     else {
115         alp = "1.";
116     }
117 
118     l1Pans = (unsigned int)dims[0].x / (unsigned int)dims[1].x;
119     sprintf(tmp, "%s(%s, c, %s, (lid / %u * %lu), (lid %% %u * %lu), %lu);\n",
120             funcName, dstName, alp, l1Pans, dims[1].y, l1Pans, dims[1].x,
121             dims[0].bwidth);
122     kgenAddStmt(ctx, tmp);
123 }
124 
125 void
genInvertingBlockFunc(struct KgenContext * ctx,size_t pitch,DataType dtype,KernelExtraFlags kflags)126 genInvertingBlockFunc(
127     struct KgenContext *ctx,
128     size_t pitch,
129     DataType dtype,
130     KernelExtraFlags kflags)
131 {
132     char tmp[1024];
133     const char *ctype;
134     ctype = dtypeBuiltinType(dtype);
135 
136     sprintf(tmp, "void\ninvert(__local %s *src, __local %s *dst, int lid, "
137                               "int lastRow)\n", ctype, ctype);
138     kgenDeclareFunction(ctx, tmp);
139     kgenBeginFuncBody(ctx);
140     kgenAddStmt(ctx, "int i, k;\n");
141 
142     if (isComplexType(dtype)) {
143         sprintf(tmp, "dst[lid * %lu + lid].x = 1.f;\n", pitch);
144     }
145     else {
146         sprintf(tmp, "dst[lid * %lu + lid] = 1.f;\n", pitch);
147     }
148     kgenAddStmt(ctx, tmp);
149 
150     if (isMatrixUpper(kflags)) {
151         sprintf(tmp, "for (i = lastRow - 1; i >= 0; i--)");
152     }
153     else {
154         sprintf(tmp, "for (i = 0; i < lastRow; i++)");
155     }
156     kgenBeginBranch(ctx, tmp);
157 
158     if (isComplexType(dtype)) {
159         sprintf(tmp, "dst[i * %lu + lid] = div(dst[i * %lu + lid], "
160                      "src[i * %lu + i]);\n", pitch, pitch, pitch);
161     }
162     else {
163         sprintf(tmp, "dst[i * %lu + lid] = dst[i * %lu + lid] / "
164                      "src[i * %lu + i];\n", pitch, pitch, pitch);
165     }
166     kgenAddStmt(ctx, tmp);
167 
168     if (isMatrixUpper(kflags)) {
169         sprintf(tmp, "for (k = 0; k < i; k++)");
170     }
171     else {
172         sprintf(tmp, "for (k = i + 1; k < %lu; k++)", pitch);
173     }
174     kgenBeginBranch(ctx, tmp);
175     if (isComplexType(dtype)) {
176         sprintf(tmp, "dst[k * %lu + lid] = dst[k * %lu + lid] - "
177                      "mul(src[k * %lu + i], dst[i * %lu + lid]);\n",
178                 pitch, pitch, pitch, pitch);
179     }
180     else {
181         sprintf(tmp, "dst[k * %lu + lid] = dst[k * %lu + lid] - "
182                       "dst[i * %lu + lid] * src[k * %lu + i];\n",
183                 pitch, pitch, pitch, pitch);
184     }
185     kgenAddStmt(ctx, tmp);
186     kgenEndBranch(ctx, NULL);
187     kgenEndBranch(ctx, NULL);
188     kgenEndFuncBody(ctx);
189 }
190 
191