1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 
18 /*
19  * Generators initialization
20  */
21 
22 #include <blas_mempat.h>
23 
24 #include "clblas-internal.h"
25 #include "init.h"
26 
27 unsigned int
initGemmMemPatterns(MemoryPattern * mempats)28 initGemmMemPatterns(MemoryPattern *mempats)
29 {
30     initGemmLdsPattern(&mempats[0]);
31     initGemmImgPattern(&mempats[1]);
32 	InitGEMMCachedBlockPattern(&mempats[2]);
33 	InitGEMMCachedSubgroupPattern(&mempats[3]);
34     return 4;
35 }
36 
37 int
getGemmMemPatternIndex(clblasImplementation impl)38 getGemmMemPatternIndex(clblasImplementation impl)
39 {
40     switch (impl) {
41 	case clblasLdsBlockGemm:				return  0;
42     case clblasImageBlockGemm:			return  1;
43     case clblasBlockGemmWithCaching:		return  2;
44     case clblasSubgroupGemmWithCaching:	return	3;
45 	default:								return -1;
46     }
47 }
48 
49 clblasImplementation
getGemmPreferredPattern(void)50 getGemmPreferredPattern(void)
51 {
52     switch (clblasSolvers[CLBLAS_GEMM].defaultPattern) {
53     case 0:  return clblasLdsBlockGemm;
54     case 1:  return clblasImageBlockGemm;
55     case 2:  return clblasBlockGemmWithCaching;
56     case 3:  return clblasSubgroupGemmWithCaching;
57     default: return clblasDefaultGemm;
58     }
59 }
60 
61 unsigned int
initGemvMemPatterns(MemoryPattern * mempats)62 initGemvMemPatterns(MemoryPattern *mempats)
63 {
64     initGemvPattern(mempats);
65 
66     return 1;
67 }
68 
69 int
getGemvMemPatternIndex(clblasImplementation impl)70 getGemvMemPatternIndex(clblasImplementation impl)
71 {
72     switch (impl) {
73     default:    return -1;
74     }
75 }
76 
77 unsigned int
initSymvMemPatterns(MemoryPattern * mempats)78 initSymvMemPatterns(MemoryPattern *mempats)
79 {
80     initSymvPattern(mempats);
81 
82     return 1;
83 }
84 
85 int
getSymvMemPatternIndex(clblasImplementation impl)86 getSymvMemPatternIndex(clblasImplementation impl)
87 {
88     switch (impl) {
89     default:    return -1;
90     }
91 }
92 
93 unsigned int
initTrmmMemPatterns(MemoryPattern * mempats)94 initTrmmMemPatterns(MemoryPattern *mempats)
95 {
96     initTrmmLdsPattern(mempats);
97     initTrmmImgPattern(&mempats[1]);
98     initTrmmCachedBlockPattern(&mempats[2]);
99     initTrmmCachedSubgroupPattern(&mempats[3]);
100 
101     return 4;
102 }
103 
104 int
getTrmmMemPatternIndex(clblasImplementation impl)105 getTrmmMemPatternIndex(clblasImplementation impl)
106 {
107     switch (impl) {
108 
109         case clblasLdsBlockTrmm:             return  0;
110         case clblasImageBlockTrmm:           return  1;
111         case clblasBlockTrmmWithCaching:     return  2;
112         case clblasSubgroupTrmmWithCaching:  return 3;
113 
114         default: return -1;
115     }
116 }
117 
118 clblasImplementation
getTrmmPreferredPattern(void)119 getTrmmPreferredPattern(void)
120 {
121     switch (clblasSolvers[CLBLAS_TRMM].defaultPattern) {
122 
123         case 0: return clblasLdsBlockTrmm;
124         case 1: return clblasImageBlockTrmm;
125         case 2: return clblasBlockTrmmWithCaching;
126         case 3: return clblasSubgroupTrmmWithCaching;
127 
128         default: return clblasDefaultTrmm;
129     }
130 }
131 
132 unsigned int
initTrsmMemPatterns(MemoryPattern * mempats)133 initTrsmMemPatterns(MemoryPattern *mempats)
134 {
135     initTrsmLdsPattern(mempats);
136     initTrsmImgPattern(&mempats[1]);
137     initTrsmLdsLessCachedPattern(&mempats[2]);
138     initTrsmCachedPattern(&mempats[3]);
139 
140     return 4;
141 }
142 
143 int
getTrsmMemPatternIndex(clblasImplementation impl)144 getTrsmMemPatternIndex(clblasImplementation impl)
145 {
146     switch (impl) {
147     case clblasLdsBlockTrsm:         return  0;
148     case clblasImageBlockTrsm:       return  1;
149     case clblasBlockTrsmWithoutLds:  return  2;
150     case clblasBlockTrsmWithCaching: return  3;
151     default:                            return -1;
152     }
153 }
154 
155 clblasImplementation
getTrsmPreferredPattern(void)156 getTrsmPreferredPattern(void)
157 {
158     switch (clblasSolvers[CLBLAS_TRSM].defaultPattern) {
159     case 0:  return clblasLdsBlockTrsm;
160     case 1:  return clblasImageBlockTrsm;
161     case 2:  return clblasBlockTrsmWithoutLds;
162     case 3:  return clblasBlockTrsmWithCaching;
163     default: return clblasDefaultTrsm;
164     }
165 }
166 
167 unsigned int
initSyrkMemPatterns(MemoryPattern * mempats)168 initSyrkMemPatterns(MemoryPattern *mempats)
169 {
170     initSyrkBlockPattern(&mempats[0]);
171     initSyrkSubgPattern(&mempats[1]);
172 
173     return 2;
174 }
175 
176 clblasImplementation
getSyrkPreferredPattern(void)177 getSyrkPreferredPattern(void)
178 {
179     switch (clblasSolvers[CLBLAS_SYRK].defaultPattern) {
180 
181     case 0:  return clblasBlockSyrk;
182     case 1:  return clblasSubgSyrk;
183     default: return clblasDefaultSyrk;
184 
185     }
186 }
187 
188 int
getSyrkMemPatternIndex(clblasImplementation impl)189 getSyrkMemPatternIndex(clblasImplementation impl)
190 {
191     switch (impl) {
192 
193     case clblasBlockSyrk: return 0;
194     case clblasSubgSyrk: return 1;
195     default:    return -1;
196 
197     }
198 }
199 
200 unsigned int
initSyr2kMemPatterns(MemoryPattern * mempats)201 initSyr2kMemPatterns(MemoryPattern *mempats)
202 {
203     initSyr2kBlockPattern(&mempats[0]);
204     initSyr2kSubgPattern(&mempats[1]);
205 
206     return 2;
207 }
208 
209 clblasImplementation
getSyr2kPreferredPattern(void)210 getSyr2kPreferredPattern(void)
211 {
212     switch (clblasSolvers[CLBLAS_SYR2K].defaultPattern) {
213 
214     case 0:  return clblasBlockSyr2k;
215     case 1:  return clblasSubgSyr2k;
216     default: return clblasDefaultSyr2k;
217 
218     }
219 }
220 
221 int
getSyr2kMemPatternIndex(clblasImplementation impl)222 getSyr2kMemPatternIndex(clblasImplementation impl)
223 {
224     switch (impl) {
225 
226     case clblasBlockSyr2k: return 0;
227     case clblasSubgSyr2k: return 1;
228     default:    return -1;
229 
230     }
231 }
232 
233 unsigned int
initTrmvMemPatterns(MemoryPattern * mempats)234 initTrmvMemPatterns(MemoryPattern *mempats)
235 {
236 	initTrmvRegisterPattern(&mempats[0]);
237 	return 1;
238 }
239 
240 int
getTrmvMemPatternIndex(clblasImplementation impl)241 getTrmvMemPatternIndex(clblasImplementation impl)
242 {
243 	switch(impl) {
244 	default: return -1;
245 	}
246 }
247 
248 unsigned int
initTrsvMemPatterns(MemoryPattern * mempats)249 initTrsvMemPatterns(MemoryPattern *mempats)
250 {
251 	initTrsvDefaultPattern(&mempats[0]);
252 	return 1;
253 }
254 
255 int
getTrsvMemPatternIndex(clblasImplementation impl)256 getTrsvMemPatternIndex(clblasImplementation impl)
257 {
258 	switch(impl) {
259 	default: return -1;
260 	}
261 }
262 
263 unsigned int
initSyrMemPatterns(MemoryPattern * mempats)264 initSyrMemPatterns(MemoryPattern *mempats)
265 {
266     initSyrDefaultPattern(&mempats[0]);
267     return 1;
268 }
269 
270 int
getSyrMemPatternIndex(clblasImplementation impl)271 getSyrMemPatternIndex(clblasImplementation impl)
272 {
273     switch(impl) {
274     default: return -1;
275     }
276 }
277 
278 unsigned int
initSyr2MemPatterns(MemoryPattern * mempats)279 initSyr2MemPatterns(MemoryPattern *mempats)
280 {
281 	initSyr2DefaultPattern(&mempats[0]);
282 	return 1;
283 }
284 
285 int
getSyr2MemPatternIndex(clblasImplementation impl)286 getSyr2MemPatternIndex(clblasImplementation impl)
287 {
288     switch(impl) {
289     default: return -1;
290     }
291 }
292 
293 unsigned int
initTrsvGemvMemPatterns(MemoryPattern * mempats)294 initTrsvGemvMemPatterns(MemoryPattern *mempats)
295 {
296 	initTrsvGemvDefaultPattern(&mempats[0]);
297 	return 1;
298 }
299 
300 int
getTrsvGemvMemPatternIndex(clblasImplementation impl)301 getTrsvGemvMemPatternIndex(clblasImplementation impl)
302 {
303 	switch(impl) {
304 	default: return -1;
305 	}
306 }
307 
308 unsigned int
initSymmMemPatterns(MemoryPattern * mempats)309 initSymmMemPatterns(MemoryPattern *mempats)
310 {
311 	initSymmDefaultPattern(&mempats[0]);
312 	return 1;
313 }
314 
315 
316 int
getSymmMemPatternIndex(clblasImplementation impl)317 getSymmMemPatternIndex(clblasImplementation impl)
318 {
319 	switch(impl) {
320 	default: return -1;
321 	}
322 }
323 
324 unsigned int
initGemmV2MemPatterns(MemoryPattern * mempats)325 initGemmV2MemPatterns(MemoryPattern *mempats)
326 {
327 	initGemmV2CachedPattern(mempats);
328 	return 1;
329 }
330 
331 int
getGemmV2MemPatternIndex(clblasImplementation impl)332 getGemmV2MemPatternIndex(clblasImplementation impl)
333 {
334 	switch(impl) {
335 		default: return -1;
336 	}
337 }
338 
339 unsigned int
initGemmV2TailMemPatterns(MemoryPattern * mempats)340 initGemmV2TailMemPatterns(MemoryPattern *mempats)
341 {
342 	initGemmV2TailCachedPattern(mempats);
343 	return 1;
344 }
345 
346 int
getGemmV2TailMemPatternIndex(clblasImplementation impl)347 getGemmV2TailMemPatternIndex(clblasImplementation impl)
348 {
349 	switch(impl) {
350 		default: return -1;
351 	}
352 }
353 
354 unsigned int
initGerMemPatterns(MemoryPattern * mempats)355 initGerMemPatterns(MemoryPattern *mempats)
356 {
357 	initGerRegisterPattern(&mempats[0]);
358 	return 1;
359 }
360 
361 int
getGerMemPatternIndex(clblasImplementation impl)362 getGerMemPatternIndex(clblasImplementation impl)
363 {
364 	switch(impl) {
365 	default: return -1;
366 	}
367 }
368 
369 unsigned int
initHerMemPatterns(MemoryPattern * mempats)370 initHerMemPatterns(MemoryPattern *mempats)
371 {
372     initHerDefaultPattern(&mempats[0]);
373     return 1;
374 }
375 
376 int
getHerMemPatternIndex(clblasImplementation impl)377 getHerMemPatternIndex(clblasImplementation impl)
378 {
379     switch(impl) {
380     default: return -1;
381     }
382 }
383 
384 unsigned int
initHer2MemPatterns(MemoryPattern * mempats)385 initHer2MemPatterns(MemoryPattern *mempats)
386 {
387 	initHer2DefaultPattern(&mempats[0]);
388 	return 1;
389 }
390 
391 int
getHer2MemPatternIndex(clblasImplementation impl)392 getHer2MemPatternIndex(clblasImplementation impl)
393 {
394     switch(impl) {
395     default: return -1;
396     }
397 }
398 
399 unsigned int
initGbmvMemPatterns(MemoryPattern * mempats)400 initGbmvMemPatterns(MemoryPattern *mempats)
401 {
402 	initGbmvRegisterPattern(&mempats[0]);
403 	return 1;
404 }
405 
406 int
getGbmvMemPatternIndex(clblasImplementation impl)407 getGbmvMemPatternIndex(clblasImplementation impl)
408 {
409 	switch(impl) {
410 	default: return -1;
411 	}
412 }
413 
414 unsigned int
initSwapMemPatterns(MemoryPattern * mempats)415 initSwapMemPatterns(MemoryPattern *mempats)
416 {
417     initSwapRegisterPattern(&mempats[0]);
418     return 1;
419 }
420 
421 int
getSwapMemPatternIndex(clblasImplementation impl)422 getSwapMemPatternIndex(clblasImplementation impl)
423 {
424     switch(impl) {
425     default: return -1;
426     }
427 }
428 
429 unsigned int
initScalMemPatterns(MemoryPattern * mempats)430 initScalMemPatterns(MemoryPattern *mempats)
431 {
432     initScalRegisterPattern(&mempats[0]);
433     return 1;
434 }
435 
436 
437 int
getScalMemPatternIndex(clblasImplementation impl)438 getScalMemPatternIndex(clblasImplementation impl)
439 {
440     switch(impl) {
441     default: return -1;
442     }
443 }
444 
445 unsigned int
initCopyMemPatterns(MemoryPattern * mempats)446 initCopyMemPatterns(MemoryPattern *mempats)
447 {
448     initCopyRegisterPattern(&mempats[0]);
449     return 1;
450 }
451 
452 int
getCopyMemPatternIndex(clblasImplementation impl)453 getCopyMemPatternIndex(clblasImplementation impl)
454 {
455     switch(impl) {
456     default: return -1;
457     }
458 }
459 
460 unsigned int
initAxpyMemPatterns(MemoryPattern * mempats)461 initAxpyMemPatterns(MemoryPattern *mempats)
462 {
463     initAxpyRegisterPattern(&mempats[0]);
464     return 1;
465 }
466 
467 int
getAxpyMemPatternIndex(clblasImplementation impl)468 getAxpyMemPatternIndex(clblasImplementation impl)
469 {
470     switch(impl) {
471     default: return -1;
472     }
473 }
474 
475 unsigned int
initDotMemPatterns(MemoryPattern * mempats)476 initDotMemPatterns(MemoryPattern *mempats)
477 {
478     initDotRegisterPattern(&mempats[0]);
479     return 1;
480 }
481 
482 int
getDotMemPatternIndex(clblasImplementation impl)483 getDotMemPatternIndex(clblasImplementation impl)
484 {
485     switch(impl) {
486     default: return -1;
487     }
488 }
489 
490 unsigned int
initReductionMemPatterns(MemoryPattern * mempats)491 initReductionMemPatterns(MemoryPattern *mempats)
492 {
493     initReductionRegisterPattern(&mempats[0]);
494     return 1;
495 }
496 
497 int
getReductionMemPatternIndex(clblasImplementation impl)498 getReductionMemPatternIndex(clblasImplementation impl)
499 {
500     switch(impl) {
501     default: return -1;
502     }
503 }
504 
505 unsigned int
initRotgMemPatterns(MemoryPattern * mempats)506 initRotgMemPatterns(MemoryPattern *mempats)
507 {
508     initRotgRegisterPattern(&mempats[0]);
509     return 1;
510 }
511 
512 int
getRotgMemPatternIndex(clblasImplementation impl)513 getRotgMemPatternIndex(clblasImplementation impl)
514 {
515     switch(impl) {
516     default: return -1;
517     }
518 }
519 
520 unsigned int
initRotmgMemPatterns(MemoryPattern * mempats)521 initRotmgMemPatterns(MemoryPattern *mempats)
522 {
523     initRotmgRegisterPattern(&mempats[0]);
524     return 1;
525 }
526 
527 int
getRotmgMemPatternIndex(clblasImplementation impl)528 getRotmgMemPatternIndex(clblasImplementation impl)
529 {
530     switch(impl) {
531     default: return -1;
532     }
533 }
534 
535 unsigned int
initRotmMemPatterns(MemoryPattern * mempats)536 initRotmMemPatterns(MemoryPattern *mempats)
537 {
538     initRotmRegisterPattern(&mempats[0]);
539     return 1;
540 }
541 
542 int
getRotmMemPatternIndex(clblasImplementation impl)543 getRotmMemPatternIndex(clblasImplementation impl)
544 {
545     switch(impl) {
546     default: return -1;
547     }
548 }
549 
550 unsigned int
initiAmaxMemPatterns(MemoryPattern * mempats)551 initiAmaxMemPatterns(MemoryPattern *mempats)
552 {
553     initiAmaxRegisterPattern(&mempats[0]);
554     return 1;
555 }
556 
557 int
getiAmaxMemPatternIndex(clblasImplementation impl)558 getiAmaxMemPatternIndex(clblasImplementation impl)
559 {
560     switch(impl) {
561     default: return -1;
562     }
563 }
564 
565 unsigned int
initNrm2MemPatterns(MemoryPattern * mempats)566 initNrm2MemPatterns(MemoryPattern *mempats)
567 {
568     initNrm2RegisterPattern(&mempats[0]);
569     return 1;
570 }
571 
572 int
getNrm2MemPatternIndex(clblasImplementation impl)573 getNrm2MemPatternIndex(clblasImplementation impl)
574 {
575     switch(impl) {
576     default: return -1;
577     }
578 }
579 
580 unsigned int
initAsumMemPatterns(MemoryPattern * mempats)581 initAsumMemPatterns(MemoryPattern *mempats)
582 {
583     initAsumRegisterPattern(&mempats[0]);
584     return 1;
585 }
586 
587 int
getAsumMemPatternIndex(clblasImplementation impl)588 getAsumMemPatternIndex(clblasImplementation impl)
589 {
590     switch(impl) {
591     default: return -1;
592     }
593 }
594