1 /* ************************************************************************ 2 * Copyright 2013 Advanced Micro Devices, Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 * ************************************************************************/ 16 17 18 /* 19 * Generators initialization 20 */ 21 22 #include <blas_mempat.h> 23 24 #include "clblas-internal.h" 25 #include "init.h" 26 27 unsigned int initGemmMemPatterns(MemoryPattern * mempats)28initGemmMemPatterns(MemoryPattern *mempats) 29 { 30 initGemmLdsPattern(&mempats[0]); 31 initGemmImgPattern(&mempats[1]); 32 InitGEMMCachedBlockPattern(&mempats[2]); 33 InitGEMMCachedSubgroupPattern(&mempats[3]); 34 return 4; 35 } 36 37 int getGemmMemPatternIndex(clblasImplementation impl)38getGemmMemPatternIndex(clblasImplementation impl) 39 { 40 switch (impl) { 41 case clblasLdsBlockGemm: return 0; 42 case clblasImageBlockGemm: return 1; 43 case clblasBlockGemmWithCaching: return 2; 44 case clblasSubgroupGemmWithCaching: return 3; 45 default: return -1; 46 } 47 } 48 49 clblasImplementation getGemmPreferredPattern(void)50getGemmPreferredPattern(void) 51 { 52 switch (clblasSolvers[CLBLAS_GEMM].defaultPattern) { 53 case 0: return clblasLdsBlockGemm; 54 case 1: return clblasImageBlockGemm; 55 case 2: return clblasBlockGemmWithCaching; 56 case 3: return clblasSubgroupGemmWithCaching; 57 default: return clblasDefaultGemm; 58 } 59 } 60 61 unsigned int initGemvMemPatterns(MemoryPattern * mempats)62initGemvMemPatterns(MemoryPattern *mempats) 63 { 64 initGemvPattern(mempats); 65 66 return 1; 67 } 68 69 int getGemvMemPatternIndex(clblasImplementation impl)70getGemvMemPatternIndex(clblasImplementation impl) 71 { 72 switch (impl) { 73 default: return -1; 74 } 75 } 76 77 unsigned int initSymvMemPatterns(MemoryPattern * mempats)78initSymvMemPatterns(MemoryPattern *mempats) 79 { 80 initSymvPattern(mempats); 81 82 return 1; 83 } 84 85 int getSymvMemPatternIndex(clblasImplementation impl)86getSymvMemPatternIndex(clblasImplementation impl) 87 { 88 switch (impl) { 89 default: return -1; 90 } 91 } 92 93 unsigned int initTrmmMemPatterns(MemoryPattern * mempats)94initTrmmMemPatterns(MemoryPattern *mempats) 95 { 96 initTrmmLdsPattern(mempats); 97 initTrmmImgPattern(&mempats[1]); 98 initTrmmCachedBlockPattern(&mempats[2]); 99 initTrmmCachedSubgroupPattern(&mempats[3]); 100 101 return 4; 102 } 103 104 int getTrmmMemPatternIndex(clblasImplementation impl)105getTrmmMemPatternIndex(clblasImplementation impl) 106 { 107 switch (impl) { 108 109 case clblasLdsBlockTrmm: return 0; 110 case clblasImageBlockTrmm: return 1; 111 case clblasBlockTrmmWithCaching: return 2; 112 case clblasSubgroupTrmmWithCaching: return 3; 113 114 default: return -1; 115 } 116 } 117 118 clblasImplementation getTrmmPreferredPattern(void)119getTrmmPreferredPattern(void) 120 { 121 switch (clblasSolvers[CLBLAS_TRMM].defaultPattern) { 122 123 case 0: return clblasLdsBlockTrmm; 124 case 1: return clblasImageBlockTrmm; 125 case 2: return clblasBlockTrmmWithCaching; 126 case 3: return clblasSubgroupTrmmWithCaching; 127 128 default: return clblasDefaultTrmm; 129 } 130 } 131 132 unsigned int initTrsmMemPatterns(MemoryPattern * mempats)133initTrsmMemPatterns(MemoryPattern *mempats) 134 { 135 initTrsmLdsPattern(mempats); 136 initTrsmImgPattern(&mempats[1]); 137 initTrsmLdsLessCachedPattern(&mempats[2]); 138 initTrsmCachedPattern(&mempats[3]); 139 140 return 4; 141 } 142 143 int getTrsmMemPatternIndex(clblasImplementation impl)144getTrsmMemPatternIndex(clblasImplementation impl) 145 { 146 switch (impl) { 147 case clblasLdsBlockTrsm: return 0; 148 case clblasImageBlockTrsm: return 1; 149 case clblasBlockTrsmWithoutLds: return 2; 150 case clblasBlockTrsmWithCaching: return 3; 151 default: return -1; 152 } 153 } 154 155 clblasImplementation getTrsmPreferredPattern(void)156getTrsmPreferredPattern(void) 157 { 158 switch (clblasSolvers[CLBLAS_TRSM].defaultPattern) { 159 case 0: return clblasLdsBlockTrsm; 160 case 1: return clblasImageBlockTrsm; 161 case 2: return clblasBlockTrsmWithoutLds; 162 case 3: return clblasBlockTrsmWithCaching; 163 default: return clblasDefaultTrsm; 164 } 165 } 166 167 unsigned int initSyrkMemPatterns(MemoryPattern * mempats)168initSyrkMemPatterns(MemoryPattern *mempats) 169 { 170 initSyrkBlockPattern(&mempats[0]); 171 initSyrkSubgPattern(&mempats[1]); 172 173 return 2; 174 } 175 176 clblasImplementation getSyrkPreferredPattern(void)177getSyrkPreferredPattern(void) 178 { 179 switch (clblasSolvers[CLBLAS_SYRK].defaultPattern) { 180 181 case 0: return clblasBlockSyrk; 182 case 1: return clblasSubgSyrk; 183 default: return clblasDefaultSyrk; 184 185 } 186 } 187 188 int getSyrkMemPatternIndex(clblasImplementation impl)189getSyrkMemPatternIndex(clblasImplementation impl) 190 { 191 switch (impl) { 192 193 case clblasBlockSyrk: return 0; 194 case clblasSubgSyrk: return 1; 195 default: return -1; 196 197 } 198 } 199 200 unsigned int initSyr2kMemPatterns(MemoryPattern * mempats)201initSyr2kMemPatterns(MemoryPattern *mempats) 202 { 203 initSyr2kBlockPattern(&mempats[0]); 204 initSyr2kSubgPattern(&mempats[1]); 205 206 return 2; 207 } 208 209 clblasImplementation getSyr2kPreferredPattern(void)210getSyr2kPreferredPattern(void) 211 { 212 switch (clblasSolvers[CLBLAS_SYR2K].defaultPattern) { 213 214 case 0: return clblasBlockSyr2k; 215 case 1: return clblasSubgSyr2k; 216 default: return clblasDefaultSyr2k; 217 218 } 219 } 220 221 int getSyr2kMemPatternIndex(clblasImplementation impl)222getSyr2kMemPatternIndex(clblasImplementation impl) 223 { 224 switch (impl) { 225 226 case clblasBlockSyr2k: return 0; 227 case clblasSubgSyr2k: return 1; 228 default: return -1; 229 230 } 231 } 232 233 unsigned int initTrmvMemPatterns(MemoryPattern * mempats)234initTrmvMemPatterns(MemoryPattern *mempats) 235 { 236 initTrmvRegisterPattern(&mempats[0]); 237 return 1; 238 } 239 240 int getTrmvMemPatternIndex(clblasImplementation impl)241getTrmvMemPatternIndex(clblasImplementation impl) 242 { 243 switch(impl) { 244 default: return -1; 245 } 246 } 247 248 unsigned int initTrsvMemPatterns(MemoryPattern * mempats)249initTrsvMemPatterns(MemoryPattern *mempats) 250 { 251 initTrsvDefaultPattern(&mempats[0]); 252 return 1; 253 } 254 255 int getTrsvMemPatternIndex(clblasImplementation impl)256getTrsvMemPatternIndex(clblasImplementation impl) 257 { 258 switch(impl) { 259 default: return -1; 260 } 261 } 262 263 unsigned int initSyrMemPatterns(MemoryPattern * mempats)264initSyrMemPatterns(MemoryPattern *mempats) 265 { 266 initSyrDefaultPattern(&mempats[0]); 267 return 1; 268 } 269 270 int getSyrMemPatternIndex(clblasImplementation impl)271getSyrMemPatternIndex(clblasImplementation impl) 272 { 273 switch(impl) { 274 default: return -1; 275 } 276 } 277 278 unsigned int initSyr2MemPatterns(MemoryPattern * mempats)279initSyr2MemPatterns(MemoryPattern *mempats) 280 { 281 initSyr2DefaultPattern(&mempats[0]); 282 return 1; 283 } 284 285 int getSyr2MemPatternIndex(clblasImplementation impl)286getSyr2MemPatternIndex(clblasImplementation impl) 287 { 288 switch(impl) { 289 default: return -1; 290 } 291 } 292 293 unsigned int initTrsvGemvMemPatterns(MemoryPattern * mempats)294initTrsvGemvMemPatterns(MemoryPattern *mempats) 295 { 296 initTrsvGemvDefaultPattern(&mempats[0]); 297 return 1; 298 } 299 300 int getTrsvGemvMemPatternIndex(clblasImplementation impl)301getTrsvGemvMemPatternIndex(clblasImplementation impl) 302 { 303 switch(impl) { 304 default: return -1; 305 } 306 } 307 308 unsigned int initSymmMemPatterns(MemoryPattern * mempats)309initSymmMemPatterns(MemoryPattern *mempats) 310 { 311 initSymmDefaultPattern(&mempats[0]); 312 return 1; 313 } 314 315 316 int getSymmMemPatternIndex(clblasImplementation impl)317getSymmMemPatternIndex(clblasImplementation impl) 318 { 319 switch(impl) { 320 default: return -1; 321 } 322 } 323 324 unsigned int initGemmV2MemPatterns(MemoryPattern * mempats)325initGemmV2MemPatterns(MemoryPattern *mempats) 326 { 327 initGemmV2CachedPattern(mempats); 328 return 1; 329 } 330 331 int getGemmV2MemPatternIndex(clblasImplementation impl)332getGemmV2MemPatternIndex(clblasImplementation impl) 333 { 334 switch(impl) { 335 default: return -1; 336 } 337 } 338 339 unsigned int initGemmV2TailMemPatterns(MemoryPattern * mempats)340initGemmV2TailMemPatterns(MemoryPattern *mempats) 341 { 342 initGemmV2TailCachedPattern(mempats); 343 return 1; 344 } 345 346 int getGemmV2TailMemPatternIndex(clblasImplementation impl)347getGemmV2TailMemPatternIndex(clblasImplementation impl) 348 { 349 switch(impl) { 350 default: return -1; 351 } 352 } 353 354 unsigned int initGerMemPatterns(MemoryPattern * mempats)355initGerMemPatterns(MemoryPattern *mempats) 356 { 357 initGerRegisterPattern(&mempats[0]); 358 return 1; 359 } 360 361 int getGerMemPatternIndex(clblasImplementation impl)362getGerMemPatternIndex(clblasImplementation impl) 363 { 364 switch(impl) { 365 default: return -1; 366 } 367 } 368 369 unsigned int initHerMemPatterns(MemoryPattern * mempats)370initHerMemPatterns(MemoryPattern *mempats) 371 { 372 initHerDefaultPattern(&mempats[0]); 373 return 1; 374 } 375 376 int getHerMemPatternIndex(clblasImplementation impl)377getHerMemPatternIndex(clblasImplementation impl) 378 { 379 switch(impl) { 380 default: return -1; 381 } 382 } 383 384 unsigned int initHer2MemPatterns(MemoryPattern * mempats)385initHer2MemPatterns(MemoryPattern *mempats) 386 { 387 initHer2DefaultPattern(&mempats[0]); 388 return 1; 389 } 390 391 int getHer2MemPatternIndex(clblasImplementation impl)392getHer2MemPatternIndex(clblasImplementation impl) 393 { 394 switch(impl) { 395 default: return -1; 396 } 397 } 398 399 unsigned int initGbmvMemPatterns(MemoryPattern * mempats)400initGbmvMemPatterns(MemoryPattern *mempats) 401 { 402 initGbmvRegisterPattern(&mempats[0]); 403 return 1; 404 } 405 406 int getGbmvMemPatternIndex(clblasImplementation impl)407getGbmvMemPatternIndex(clblasImplementation impl) 408 { 409 switch(impl) { 410 default: return -1; 411 } 412 } 413 414 unsigned int initSwapMemPatterns(MemoryPattern * mempats)415initSwapMemPatterns(MemoryPattern *mempats) 416 { 417 initSwapRegisterPattern(&mempats[0]); 418 return 1; 419 } 420 421 int getSwapMemPatternIndex(clblasImplementation impl)422getSwapMemPatternIndex(clblasImplementation impl) 423 { 424 switch(impl) { 425 default: return -1; 426 } 427 } 428 429 unsigned int initScalMemPatterns(MemoryPattern * mempats)430initScalMemPatterns(MemoryPattern *mempats) 431 { 432 initScalRegisterPattern(&mempats[0]); 433 return 1; 434 } 435 436 437 int getScalMemPatternIndex(clblasImplementation impl)438getScalMemPatternIndex(clblasImplementation impl) 439 { 440 switch(impl) { 441 default: return -1; 442 } 443 } 444 445 unsigned int initCopyMemPatterns(MemoryPattern * mempats)446initCopyMemPatterns(MemoryPattern *mempats) 447 { 448 initCopyRegisterPattern(&mempats[0]); 449 return 1; 450 } 451 452 int getCopyMemPatternIndex(clblasImplementation impl)453getCopyMemPatternIndex(clblasImplementation impl) 454 { 455 switch(impl) { 456 default: return -1; 457 } 458 } 459 460 unsigned int initAxpyMemPatterns(MemoryPattern * mempats)461initAxpyMemPatterns(MemoryPattern *mempats) 462 { 463 initAxpyRegisterPattern(&mempats[0]); 464 return 1; 465 } 466 467 int getAxpyMemPatternIndex(clblasImplementation impl)468getAxpyMemPatternIndex(clblasImplementation impl) 469 { 470 switch(impl) { 471 default: return -1; 472 } 473 } 474 475 unsigned int initDotMemPatterns(MemoryPattern * mempats)476initDotMemPatterns(MemoryPattern *mempats) 477 { 478 initDotRegisterPattern(&mempats[0]); 479 return 1; 480 } 481 482 int getDotMemPatternIndex(clblasImplementation impl)483getDotMemPatternIndex(clblasImplementation impl) 484 { 485 switch(impl) { 486 default: return -1; 487 } 488 } 489 490 unsigned int initReductionMemPatterns(MemoryPattern * mempats)491initReductionMemPatterns(MemoryPattern *mempats) 492 { 493 initReductionRegisterPattern(&mempats[0]); 494 return 1; 495 } 496 497 int getReductionMemPatternIndex(clblasImplementation impl)498getReductionMemPatternIndex(clblasImplementation impl) 499 { 500 switch(impl) { 501 default: return -1; 502 } 503 } 504 505 unsigned int initRotgMemPatterns(MemoryPattern * mempats)506initRotgMemPatterns(MemoryPattern *mempats) 507 { 508 initRotgRegisterPattern(&mempats[0]); 509 return 1; 510 } 511 512 int getRotgMemPatternIndex(clblasImplementation impl)513getRotgMemPatternIndex(clblasImplementation impl) 514 { 515 switch(impl) { 516 default: return -1; 517 } 518 } 519 520 unsigned int initRotmgMemPatterns(MemoryPattern * mempats)521initRotmgMemPatterns(MemoryPattern *mempats) 522 { 523 initRotmgRegisterPattern(&mempats[0]); 524 return 1; 525 } 526 527 int getRotmgMemPatternIndex(clblasImplementation impl)528getRotmgMemPatternIndex(clblasImplementation impl) 529 { 530 switch(impl) { 531 default: return -1; 532 } 533 } 534 535 unsigned int initRotmMemPatterns(MemoryPattern * mempats)536initRotmMemPatterns(MemoryPattern *mempats) 537 { 538 initRotmRegisterPattern(&mempats[0]); 539 return 1; 540 } 541 542 int getRotmMemPatternIndex(clblasImplementation impl)543getRotmMemPatternIndex(clblasImplementation impl) 544 { 545 switch(impl) { 546 default: return -1; 547 } 548 } 549 550 unsigned int initiAmaxMemPatterns(MemoryPattern * mempats)551initiAmaxMemPatterns(MemoryPattern *mempats) 552 { 553 initiAmaxRegisterPattern(&mempats[0]); 554 return 1; 555 } 556 557 int getiAmaxMemPatternIndex(clblasImplementation impl)558getiAmaxMemPatternIndex(clblasImplementation impl) 559 { 560 switch(impl) { 561 default: return -1; 562 } 563 } 564 565 unsigned int initNrm2MemPatterns(MemoryPattern * mempats)566initNrm2MemPatterns(MemoryPattern *mempats) 567 { 568 initNrm2RegisterPattern(&mempats[0]); 569 return 1; 570 } 571 572 int getNrm2MemPatternIndex(clblasImplementation impl)573getNrm2MemPatternIndex(clblasImplementation impl) 574 { 575 switch(impl) { 576 default: return -1; 577 } 578 } 579 580 unsigned int initAsumMemPatterns(MemoryPattern * mempats)581initAsumMemPatterns(MemoryPattern *mempats) 582 { 583 initAsumRegisterPattern(&mempats[0]); 584 return 1; 585 } 586 587 int getAsumMemPatternIndex(clblasImplementation impl)588getAsumMemPatternIndex(clblasImplementation impl) 589 { 590 switch(impl) { 591 default: return -1; 592 } 593 } 594