1 /* ************************************************************************ 2 * Copyright 2013 Advanced Micro Devices, Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 * ************************************************************************/ 16 17 18 /* 19 * Related to BLAS memory patterns 20 */ 21 22 #ifndef BLAS_MEMPAT_H_ 23 #define BLAS_MEMPAT_H_ 24 25 #include <clBLAS.h> 26 #include <mempat.h> 27 #include <clkern.h> 28 #include <kern_cache.h> 29 30 /** 31 * @brief Type of internal function implementation 32 */ 33 typedef enum clblasImplementation { 34 35 clblasDefaultGemm, /**< Default: let the library decide what to use. */ 36 clblasLdsBlockGemm, /**< Use blocked GEMM with LDS optimization. */ 37 clblasImageBlockGemm, /**< Use blocked GEMM with image-based... */ 38 clblasBlockGemmWithCaching, /**< Use blocked GEMM with cache-usage optimization. */ 39 clblasSubgroupGemmWithCaching,/**< Use subgroup GEMM with cache-usage optimization. */ 40 41 clblasDefaultTrmm, /**< Default: let the library decide what to use. */ 42 clblasLdsBlockTrmm, /**< Use blocked TRMM with LDS optimization. */ 43 clblasImageBlockTrmm, /**< Use blocked TRMM with image-based... */ 44 clblasBlockTrmmWithCaching, /**< Use blocked TRMM with cache-usage optimization. */ 45 clblasSubgroupTrmmWithCaching,/**< Use subgroup TRMM with cache-usage optimization. */ 46 47 clblasDefaultTrsm, /**< Default: let the library decide what to use. */ 48 clblasLdsBlockTrsm, /**< Use blocked TRSM with LDS optimization. */ 49 clblasImageBlockTrsm, /**< Use blocked TRSM with image-based... */ 50 clblasBlockTrsmWithCaching, /**< Use blocked TRSM with cache-usage optimization. */ 51 clblasBlockTrsmWithoutLds, 52 53 clblasDefaultSyrk, 54 clblasBlockSyrk, 55 clblasSubgSyrk, 56 57 clblasDefaultSyr2k, 58 clblasBlockSyr2k, 59 clblasSubgSyr2k 60 61 } clblasImplementation; 62 63 /** 64 * @internal 65 * @brief extra information for a memory pattern 66 * used for BLAS problem solving 67 * @ingroup BLAS_SOLVERIF_SPEC 68 */ 69 typedef struct CLBLASMpatExtra { 70 /** memory levels used to store blocks of matrix A */ 71 meml_set_t aMset; 72 /** memory levels used to store blocks of matrix B */ 73 meml_set_t bMset; 74 CLMemType mobjA; 75 CLMemType mobjB; 76 } CLBLASMpatExtra; 77 78 /* 79 * init memory patterns for the xGEMM functions 80 * 81 * Returns number of the initialized patterns 82 */ 83 unsigned int 84 initGemmMemPatterns(MemoryPattern *mempats); 85 86 /* 87 * Get index of the specific xGEMM pattern 88 */ 89 int 90 getGemmMemPatternIndex(clblasImplementation impl); 91 92 /* 93 * Get preferred xGEMM pattern 94 */ 95 clblasImplementation 96 getGemmPreferredPattern(void); 97 98 /* 99 * init memory patterns for the xGEMV functions 100 * 101 * Returns number of the initialized patterns 102 */ 103 unsigned int 104 initGemvMemPatterns(MemoryPattern *mempats); 105 106 /* 107 * Get index of the specific xGEMV pattern 108 */ 109 int 110 getGemvMemPatternIndex(clblasImplementation impl); 111 112 /* 113 * init memory patterns for the xSYMV functions 114 * 115 * Returns number of the initialized patterns 116 */ 117 unsigned int 118 initSymvMemPatterns(MemoryPattern *mempats); 119 120 /* 121 * Get index of the specific xSYMV pattern 122 */ 123 int 124 getSymvMemPatternIndex(clblasImplementation impl); 125 126 /* 127 * init memory patterns for the xTRMM functions 128 * 129 * Returns number of the initialized patterns 130 */ 131 unsigned int 132 initTrmmMemPatterns(MemoryPattern *mempats); 133 134 /* 135 * Get index of the specific xTRMM pattern 136 */ 137 int 138 getTrmmMemPatternIndex(clblasImplementation impl); 139 140 /* 141 * Get preferred xTRMM pattern 142 */ 143 clblasImplementation 144 getTrmmPreferredPattern(void); 145 146 /* 147 * init memory patterns for the xTRSM functions 148 * 149 * Returns number of the initialized patterns 150 */ 151 unsigned int 152 initTrsmMemPatterns(MemoryPattern *mempats); 153 154 /* 155 * Get index of the specific xTRSM pattern 156 */ 157 int 158 getTrsmMemPatternIndex(clblasImplementation impl); 159 160 /* 161 * Get preferred xTRSM pattern 162 */ 163 clblasImplementation 164 getTrsmPreferredPattern(void); 165 166 /* 167 * init memory patterns for the xSYR2K functions 168 * 169 * Returns number of the initialized patterns 170 */ 171 unsigned int 172 initSyr2kMemPatterns(MemoryPattern *mempats); 173 174 /* 175 * Get index of the specific xSYR2K pattern 176 */ 177 int 178 getSyr2kMemPatternIndex(clblasImplementation impl); 179 180 /* 181 * init memory patterns for the xSYRK functions 182 * 183 * Returns number of the initialized patterns 184 */ 185 unsigned int 186 initSyrkMemPatterns(MemoryPattern *mempats); 187 188 /* 189 * Get index of the specific xSYRK pattern 190 */ 191 int 192 getSyrkMemPatternIndex(clblasImplementation impl); 193 194 /* 195 * init memory patters for TRMV routine 196 * Returns the number of inited patterns 197 */ 198 unsigned int 199 initTrmvMemPatterns(MemoryPattern *mempats); 200 201 int 202 getTrmvMemPatternIndex(clblasImplementation impl); 203 204 /* 205 * init memory patterns for TRSV TRTRI routine 206 * Returns the number of inited patterns 207 */ 208 unsigned int 209 initTrsvMemPatterns(MemoryPattern *mempats); 210 211 int 212 getTrsvMemPatternIndex(clblasImplementation impl); 213 214 unsigned int 215 initTrsvGemvMemPatterns(MemoryPattern *mempats); 216 217 int 218 getTrsvGemvMemPatternIndex(clblasImplementation impl); 219 220 unsigned int 221 initSymmMemPatterns(MemoryPattern *mempats); 222 223 int 224 getSymmMemPatternIndex(clblasImplementation impl); 225 226 unsigned int 227 initGemmV2MemPatterns(MemoryPattern *mempats); 228 229 int 230 getGemmV2MemPatternIndex(clblasImplementation impl); 231 232 unsigned int 233 initGemmV2TailMemPatterns(MemoryPattern *mempats); 234 235 int 236 getGemmV2TailMemPatternIndex(clblasImplementation impl); 237 238 /* 239 * init memory patterns for the xSYR functions 240 * 241 * Returns number of the initialized patterns 242 */ 243 unsigned int 244 initSyrMemPatterns(MemoryPattern *mempats); 245 246 /* 247 * Get index of the specific xSYR pattern 248 */ 249 int 250 getSyrMemPatternIndex(clblasImplementation impl); 251 252 /* 253 * init memory patterns for the xSYR2 functions 254 * 255 * Returns number of the initialized patterns 256 */ 257 unsigned int 258 initSyr2MemPatterns(MemoryPattern *mempats); 259 260 /* 261 * Get index of the specific xSYR2 pattern 262 */ 263 int 264 getSyr2MemPatternIndex(clblasImplementation impl); 265 266 267 /* 268 * init memory patters for GER routine 269 * Returns the number of inited patterns 270 */ 271 unsigned int 272 initGerMemPatterns(MemoryPattern *mempats); 273 274 int 275 getGerMemPatternIndex(clblasImplementation impl); 276 277 unsigned int 278 initHerMemPatterns(MemoryPattern *mempats); 279 280 /* 281 * Get index of the specific xSYR pattern 282 */ 283 int 284 getHerMemPatternIndex(clblasImplementation impl); 285 286 /* 287 * init memory patterns for the xHER2 functions 288 * 289 * Returns number of the initialized patterns 290 */ 291 unsigned int 292 initHer2MemPatterns(MemoryPattern *mempats); 293 294 /* 295 * Get index of the specific xHER2 pattern 296 */ 297 int 298 getHer2MemPatternIndex(clblasImplementation impl); 299 300 unsigned int 301 initGbmvMemPatterns(MemoryPattern *mempats); 302 303 int 304 getGbmvMemPatternIndex(clblasImplementation impl); 305 306 unsigned int 307 initSwapMemPatterns(MemoryPattern *mempats); 308 309 int 310 getSwapMemPatternIndex(clblasImplementation impl); 311 312 unsigned int 313 initScalMemPatterns(MemoryPattern *mempats); 314 315 int 316 getScalMemPatternIndex(clblasImplementation impl); 317 318 unsigned int 319 initCopyMemPatterns(MemoryPattern *mempats); 320 321 int 322 getCopyMemPatternIndex(clblasImplementation impl); 323 324 unsigned int 325 initDotMemPatterns(MemoryPattern *mempats); 326 327 int 328 getDotMemPatternIndex(clblasImplementation impl); 329 330 unsigned int 331 initAxpyMemPatterns(MemoryPattern *mempats); 332 333 int 334 getAxpyMemPatternIndex(clblasImplementation impl); 335 336 unsigned int 337 initReductionMemPatterns(MemoryPattern *mempats); 338 339 int 340 getReductionMemPatternIndex(clblasImplementation impl); 341 342 unsigned int 343 initRotgMemPatterns(MemoryPattern *mempats); 344 345 int 346 getRotgMemPatternIndex(clblasImplementation impl); 347 348 unsigned int 349 initRotmgMemPatterns(MemoryPattern *mempats); 350 351 int 352 getRotmgMemPatternIndex(clblasImplementation impl); 353 354 unsigned int 355 initRotmMemPatterns(MemoryPattern *mempats); 356 357 int 358 getRotmMemPatternIndex(clblasImplementation impl); 359 360 unsigned int 361 initiAmaxMemPatterns(MemoryPattern *mempats); 362 363 int 364 getiAmaxMemPatternIndex(clblasImplementation impl); 365 366 unsigned int 367 initNrm2MemPatterns(MemoryPattern *mempats); 368 369 int 370 getNrm2MemPatternIndex(clblasImplementation impl); 371 372 unsigned int 373 initAsumMemPatterns(MemoryPattern *mempats); 374 375 int 376 getAsumMemPatternIndex(clblasImplementation impl); 377 378 #endif /* BLAS_MEMPAT_H_ */ 379