1 /* ************************************************************************
2 * Copyright 2013 Advanced Micro Devices, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * ************************************************************************/
16
17
18 #include <stdio.h>
19
20 #include "../blas_kgen.h"
21 #include "trsm_kgen_legacy.h"
22
23 void
genUpdateIntermTrsmResult(struct KgenContext * ctx,const BlasGenSettings * gset,const char * optFuncName,const char * genericFuncName,bool withMhitCond)24 genUpdateIntermTrsmResult(
25 struct KgenContext *ctx,
26 const BlasGenSettings *gset,
27 const char *optFuncName,
28 const char *genericFuncName,
29 bool withMhitCond)
30 {
31 char tmp[1024];
32 const char *coordY, *coordX;
33 char *revAlp, *alp;
34 DataType dtype = gset->kextra->dtype;
35 KernelExtraFlags kflags = gset->kextra->flags;
36 const SubproblemDim *dim = &gset->subdims[1];
37 const KernelVarNames *kvarNames = &gset->varNames;
38
39 if (isComplexType(dtype)) {
40 if (dtype == TYPE_COMPLEX_FLOAT) {
41 revAlp = "div((float2)(-1.f, 0), alpha)";
42 alp = "(float2)(1.f, 0)";
43 }
44 else {
45 revAlp = "div((double2)(-1., 0), alpha)";
46 alp = "(double2)(1., 0)";
47 }
48 }
49 else {
50 revAlp = "-1. / alpha";
51 alp = "1.";
52 }
53
54 coordY = kvarNames->coordA;
55 coordX = kvarNames->coordB;
56
57 if (!(kflags & (KEXTRA_TAILS_M | KEXTRA_TAILS_N))) {
58 sprintf(tmp, "%s(B, c, %s, %s, %s, ldb, %s);\n",
59 optFuncName, alp, coordY, coordX, revAlp);
60 kgenAddStmt(ctx, tmp);
61 }
62 else {
63 if (withMhitCond) {
64 sprintf(tmp, "if ((%s < %s) && (%s < %s))",
65 coordY, kvarNames->sizeM, coordX, kvarNames->sizeN);
66 kgenBeginBranch(ctx, tmp);
67 }
68 else {
69 /* for x, y variables scope */
70 kgenBeginBranch(ctx, NULL);
71 }
72
73 sprintf(tmp, "uint y = min(%luu, %s - (uint)%s);\n"
74 "uint x = min(%luu, %s - (uint)%s);\n"
75 "if ((y == %luu) && (x == %luu)) {\n"
76 " %s(B, c, %s, %s, %s, ldb, %s);\n"
77 "}\n"
78 "else {\n"
79 " %s(B, c, %s, %s, %s, ldb, %s, y, x);\n"
80 "}\n",
81 dim->y, kvarNames->sizeM, coordY,
82 dim->x, kvarNames->sizeN, coordX,
83 dim->y, dim->x,
84 optFuncName, alp, coordY, coordX, revAlp,
85 genericFuncName, alp, coordY, coordX, revAlp);
86
87 kgenAddStmt(ctx, tmp);
88
89 kgenEndBranch(ctx, NULL);
90 }
91 }
92
93 void
genHeapTrsmResultToLDS(struct KgenContext * ctx,const BlasGenSettings * gset,const char * funcName,const char * dstName)94 genHeapTrsmResultToLDS(
95 struct KgenContext *ctx,
96 const BlasGenSettings *gset,
97 const char *funcName,
98 const char *dstName)
99 {
100 char tmp[1024];
101 char *alp;
102 unsigned int l1Pans;
103 DataType dtype = gset->kextra->dtype;
104 const SubproblemDim *dims = gset->subdims;
105
106 if(isComplexType(dtype)) {
107 if (dtype == TYPE_COMPLEX_FLOAT) {
108 alp = "(float2)(1.f, 0)";
109 }
110 else {
111 alp = "(double2)(1., 0)";
112 }
113 }
114 else {
115 alp = "1.";
116 }
117
118 l1Pans = (unsigned int)dims[0].x / (unsigned int)dims[1].x;
119 sprintf(tmp, "%s(%s, c, %s, (lid / %u * %lu), (lid %% %u * %lu), %lu);\n",
120 funcName, dstName, alp, l1Pans, dims[1].y, l1Pans, dims[1].x,
121 dims[0].bwidth);
122 kgenAddStmt(ctx, tmp);
123 }
124
125 void
genInvertingBlockFunc(struct KgenContext * ctx,size_t pitch,DataType dtype,KernelExtraFlags kflags)126 genInvertingBlockFunc(
127 struct KgenContext *ctx,
128 size_t pitch,
129 DataType dtype,
130 KernelExtraFlags kflags)
131 {
132 char tmp[1024];
133 const char *ctype;
134 ctype = dtypeBuiltinType(dtype);
135
136 sprintf(tmp, "void\ninvert(__local %s *src, __local %s *dst, int lid, "
137 "int lastRow)\n", ctype, ctype);
138 kgenDeclareFunction(ctx, tmp);
139 kgenBeginFuncBody(ctx);
140 kgenAddStmt(ctx, "int i, k;\n");
141
142 if (isComplexType(dtype)) {
143 sprintf(tmp, "dst[lid * %lu + lid].x = 1.f;\n", pitch);
144 }
145 else {
146 sprintf(tmp, "dst[lid * %lu + lid] = 1.f;\n", pitch);
147 }
148 kgenAddStmt(ctx, tmp);
149
150 if (isMatrixUpper(kflags)) {
151 sprintf(tmp, "for (i = lastRow - 1; i >= 0; i--)");
152 }
153 else {
154 sprintf(tmp, "for (i = 0; i < lastRow; i++)");
155 }
156 kgenBeginBranch(ctx, tmp);
157
158 if (isComplexType(dtype)) {
159 sprintf(tmp, "dst[i * %lu + lid] = div(dst[i * %lu + lid], "
160 "src[i * %lu + i]);\n", pitch, pitch, pitch);
161 }
162 else {
163 sprintf(tmp, "dst[i * %lu + lid] = dst[i * %lu + lid] / "
164 "src[i * %lu + i];\n", pitch, pitch, pitch);
165 }
166 kgenAddStmt(ctx, tmp);
167
168 if (isMatrixUpper(kflags)) {
169 sprintf(tmp, "for (k = 0; k < i; k++)");
170 }
171 else {
172 sprintf(tmp, "for (k = i + 1; k < %lu; k++)", pitch);
173 }
174 kgenBeginBranch(ctx, tmp);
175 if (isComplexType(dtype)) {
176 sprintf(tmp, "dst[k * %lu + lid] = dst[k * %lu + lid] - "
177 "mul(src[k * %lu + i], dst[i * %lu + lid]);\n",
178 pitch, pitch, pitch, pitch);
179 }
180 else {
181 sprintf(tmp, "dst[k * %lu + lid] = dst[k * %lu + lid] - "
182 "dst[i * %lu + lid] * src[k * %lu + i];\n",
183 pitch, pitch, pitch, pitch);
184 }
185 kgenAddStmt(ctx, tmp);
186 kgenEndBranch(ctx, NULL);
187 kgenEndBranch(ctx, NULL);
188 kgenEndFuncBody(ctx);
189 }
190
191