1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 
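/*
 * Declarations related to BLAS memory patterns: the implementation selector
 * enum, per-pattern extra data, and the per-routine functions that initialize
 * memory patterns and look up pattern indices.
 */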
21 
22 #ifndef BLAS_MEMPAT_H_
23 #define BLAS_MEMPAT_H_
24 
25 #include <clBLAS.h>
26 #include <mempat.h>
27 #include <clkern.h>
28 #include <kern_cache.h>
29 
/**
 * @brief Internal implementation variant selected for a BLAS routine.
 *
 * Enumerators are grouped by routine. None carries an explicit value, so each
 * takes its position in declaration order (clblasDefaultGemm == 0); adding,
 * removing or reordering entries renumbers every enumerator after the change.
 */
typedef enum clblasImplementation {
    /* GEMM variants */

    /** Default: let the library decide what to use. */
    clblasDefaultGemm,
    /** Use blocked GEMM with LDS optimization. */
    clblasLdsBlockGemm,
    /** Use blocked GEMM with image-based storage (description was truncated
     *  upstream; presumably OpenCL image objects — confirm). */
    clblasImageBlockGemm,
    /** Use blocked GEMM with cache-usage optimization. */
    clblasBlockGemmWithCaching,
    /** Use subgroup GEMM with cache-usage optimization. */
    clblasSubgroupGemmWithCaching,

    /* TRMM variants */

    /** Default: let the library decide what to use. */
    clblasDefaultTrmm,
    /** Use blocked TRMM with LDS optimization. */
    clblasLdsBlockTrmm,
    /** Use blocked TRMM with image-based storage (see clblasImageBlockGemm). */
    clblasImageBlockTrmm,
    /** Use blocked TRMM with cache-usage optimization. */
    clblasBlockTrmmWithCaching,
    /** Use subgroup TRMM with cache-usage optimization. */
    clblasSubgroupTrmmWithCaching,

    /* TRSM variants */

    /** Default: let the library decide what to use. */
    clblasDefaultTrsm,
    /** Use blocked TRSM with LDS optimization. */
    clblasLdsBlockTrsm,
    /** Use blocked TRSM with image-based storage (see clblasImageBlockGemm). */
    clblasImageBlockTrsm,
    /** Use blocked TRSM with cache-usage optimization. */
    clblasBlockTrsmWithCaching,
    /** Blocked TRSM that, judging by the name, does not use LDS — confirm. */
    clblasBlockTrsmWithoutLds,

    /* SYRK variants (descriptions below are inferred from names — confirm) */

    /** Default SYRK choice, presumably library-decided like other Default* values. */
    clblasDefaultSyrk,
    /** Blocked SYRK. */
    clblasBlockSyrk,
    /** Subgroup SYRK. */
    clblasSubgSyrk,

    /* SYR2K variants (descriptions below are inferred from names — confirm) */

    /** Default SYR2K choice, presumably library-decided like other Default* values. */
    clblasDefaultSyr2k,
    /** Blocked SYR2K. */
    clblasBlockSyr2k,
    /** Subgroup SYR2K. */
    clblasSubgSyr2k
} clblasImplementation;
62 
63 /**
64  * @internal
65  * @brief extra information for a memory pattern
66  *        used for BLAS problem solving
67  * @ingroup BLAS_SOLVERIF_SPEC
68  */
69 typedef struct CLBLASMpatExtra {
70     /** memory levels used to store blocks of matrix A */
71     meml_set_t aMset;
72     /** memory levels used to store blocks of matrix B */
73     meml_set_t bMset;
74     CLMemType mobjA;
75     CLMemType mobjB;
76 } CLBLASMpatExtra;
77 
78 /*
79  * init memory patterns for the xGEMM functions
80  *
81  * Returns number of the initialized patterns
82  */
83 unsigned int
84 initGemmMemPatterns(MemoryPattern *mempats);
85 
86 /*
87  * Get index of the specific xGEMM pattern
88  */
89 int
90 getGemmMemPatternIndex(clblasImplementation impl);
91 
92 /*
93  * Get preferred xGEMM pattern
94  */
95 clblasImplementation
96 getGemmPreferredPattern(void);
97 
98 /*
99  * init memory patterns for the xGEMV functions
100  *
101  * Returns number of the initialized patterns
102  */
103 unsigned int
104 initGemvMemPatterns(MemoryPattern *mempats);
105 
106 /*
107  * Get index of the specific xGEMV pattern
108  */
109 int
110 getGemvMemPatternIndex(clblasImplementation impl);
111 
112 /*
113  * init memory patterns for the xSYMV functions
114  *
115  * Returns number of the initialized patterns
116  */
117 unsigned int
118 initSymvMemPatterns(MemoryPattern *mempats);
119 
120 /*
121  * Get index of the specific xSYMV pattern
122  */
123 int
124 getSymvMemPatternIndex(clblasImplementation impl);
125 
126 /*
127  * init memory patterns for the xTRMM functions
128  *
129  * Returns number of the initialized patterns
130  */
131 unsigned int
132 initTrmmMemPatterns(MemoryPattern *mempats);
133 
134 /*
135  * Get index of the specific xTRMM pattern
136  */
137 int
138 getTrmmMemPatternIndex(clblasImplementation impl);
139 
140 /*
141  * Get preferred xTRMM pattern
142  */
143 clblasImplementation
144 getTrmmPreferredPattern(void);
145 
146 /*
147  * init memory patterns for the xTRSM functions
148  *
149  * Returns number of the initialized patterns
150  */
151 unsigned int
152 initTrsmMemPatterns(MemoryPattern *mempats);
153 
154 /*
155  * Get index of the specific xTRSM pattern
156  */
157 int
158 getTrsmMemPatternIndex(clblasImplementation impl);
159 
160 /*
161  * Get preferred xTRSM pattern
162  */
163 clblasImplementation
164 getTrsmPreferredPattern(void);
165 
166 /*
167  * init memory patterns for the xSYR2K functions
168  *
169  * Returns number of the initialized patterns
170  */
171 unsigned int
172 initSyr2kMemPatterns(MemoryPattern *mempats);
173 
174 /*
175  * Get index of the specific xSYR2K pattern
176  */
177 int
178 getSyr2kMemPatternIndex(clblasImplementation impl);
179 
180 /*
181  * init memory patterns for the xSYRK functions
182  *
183  * Returns number of the initialized patterns
184  */
185 unsigned int
186 initSyrkMemPatterns(MemoryPattern *mempats);
187 
188 /*
189  * Get index of the specific xSYRK pattern
190  */
191 int
192 getSyrkMemPatternIndex(clblasImplementation impl);
193 
194 /*
195  * init memory patters for TRMV routine
196  * Returns the number of inited patterns
197  */
198 unsigned int
199 initTrmvMemPatterns(MemoryPattern *mempats);
200 
201 int
202 getTrmvMemPatternIndex(clblasImplementation impl);
203 
204 /*
205  * init memory patterns for TRSV TRTRI routine
206  * Returns the number of inited patterns
207  */
208 unsigned int
209 initTrsvMemPatterns(MemoryPattern *mempats);
210 
211 int
212 getTrsvMemPatternIndex(clblasImplementation impl);
213 
214 unsigned int
215 initTrsvGemvMemPatterns(MemoryPattern *mempats);
216 
217 int
218 getTrsvGemvMemPatternIndex(clblasImplementation impl);
219 
220 unsigned int
221 initSymmMemPatterns(MemoryPattern *mempats);
222 
223 int
224 getSymmMemPatternIndex(clblasImplementation impl);
225 
226 unsigned int
227 initGemmV2MemPatterns(MemoryPattern *mempats);
228 
229 int
230 getGemmV2MemPatternIndex(clblasImplementation impl);
231 
232 unsigned int
233 initGemmV2TailMemPatterns(MemoryPattern *mempats);
234 
235 int
236 getGemmV2TailMemPatternIndex(clblasImplementation impl);
237 
238 /*
239  * init memory patterns for the xSYR functions
240  *
241  * Returns number of the initialized patterns
242  */
243 unsigned int
244 initSyrMemPatterns(MemoryPattern *mempats);
245 
246 /*
247  * Get index of the specific xSYR pattern
248  */
249 int
250 getSyrMemPatternIndex(clblasImplementation impl);
251 
252 /*
253  * init memory patterns for the xSYR2 functions
254  *
255  * Returns number of the initialized patterns
256  */
257 unsigned int
258 initSyr2MemPatterns(MemoryPattern *mempats);
259 
260 /*
261  * Get index of the specific xSYR2 pattern
262  */
263 int
264 getSyr2MemPatternIndex(clblasImplementation impl);
265 
266 
267 /*
268  * init memory patters for GER routine
269  * Returns the number of inited patterns
270  */
271 unsigned int
272 initGerMemPatterns(MemoryPattern *mempats);
273 
274 int
275 getGerMemPatternIndex(clblasImplementation impl);
276 
277 unsigned int
278 initHerMemPatterns(MemoryPattern *mempats);
279 
280 /*
281  * Get index of the specific xSYR pattern
282  */
283 int
284 getHerMemPatternIndex(clblasImplementation impl);
285 
286 /*
287  * init memory patterns for the xHER2 functions
288  *
289  * Returns number of the initialized patterns
290  */
291 unsigned int
292 initHer2MemPatterns(MemoryPattern *mempats);
293 
294 /*
295  * Get index of the specific xHER2 pattern
296  */
297 int
298 getHer2MemPatternIndex(clblasImplementation impl);
299 
300 unsigned int
301 initGbmvMemPatterns(MemoryPattern *mempats);
302 
303 int
304 getGbmvMemPatternIndex(clblasImplementation impl);
305 
306 unsigned int
307 initSwapMemPatterns(MemoryPattern *mempats);
308 
309 int
310 getSwapMemPatternIndex(clblasImplementation impl);
311 
312 unsigned int
313 initScalMemPatterns(MemoryPattern *mempats);
314 
315 int
316 getScalMemPatternIndex(clblasImplementation impl);
317 
318 unsigned int
319 initCopyMemPatterns(MemoryPattern *mempats);
320 
321 int
322 getCopyMemPatternIndex(clblasImplementation impl);
323 
324 unsigned int
325 initDotMemPatterns(MemoryPattern *mempats);
326 
327 int
328 getDotMemPatternIndex(clblasImplementation impl);
329 
330 unsigned int
331 initAxpyMemPatterns(MemoryPattern *mempats);
332 
333 int
334 getAxpyMemPatternIndex(clblasImplementation impl);
335 
336 unsigned int
337 initReductionMemPatterns(MemoryPattern *mempats);
338 
339 int
340 getReductionMemPatternIndex(clblasImplementation impl);
341 
342 unsigned int
343 initRotgMemPatterns(MemoryPattern *mempats);
344 
345 int
346 getRotgMemPatternIndex(clblasImplementation impl);
347 
348 unsigned int
349 initRotmgMemPatterns(MemoryPattern *mempats);
350 
351 int
352 getRotmgMemPatternIndex(clblasImplementation impl);
353 
354 unsigned int
355 initRotmMemPatterns(MemoryPattern *mempats);
356 
357 int
358 getRotmMemPatternIndex(clblasImplementation impl);
359 
360 unsigned int
361 initiAmaxMemPatterns(MemoryPattern *mempats);
362 
363 int
364 getiAmaxMemPatternIndex(clblasImplementation impl);
365 
366 unsigned int
367 initNrm2MemPatterns(MemoryPattern *mempats);
368 
369 int
370 getNrm2MemPatternIndex(clblasImplementation impl);
371 
372 unsigned int
373 initAsumMemPatterns(MemoryPattern *mempats);
374 
375 int
376 getAsumMemPatternIndex(clblasImplementation impl);
377 
378 #endif /* BLAS_MEMPAT_H_ */
379