1 // SPDX-License-Identifier: Apache-2.0 2 /* 3 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions 7 * are met: 8 * * Redistributions of source code must retain the above copyright 9 * notice, this list of conditions and the following disclaimer. 10 * * Redistributions in binary form must reproduce the above copyright 11 * notice, this list of conditions and the following disclaimer in the 12 * documentation and/or other materials provided with the distribution. 13 * * Neither the name of NVIDIA CORPORATION nor the names of its 14 * contributors may be used to endorse or promote products derived 15 * from this software without specific prior written permission. 16 * 17 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 18 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 20 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 21 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 22 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 23 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 24 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 25 * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 */ 29 30 /* 31 Extended example for building on-the-fly kernels with C interface. 32 Simple examples demonstrating different ways to load source code 33 and call kernels. 34 */ 35 36 #include "GB_jit_launcher.h" 37 #include "../templates/reduceUnrolled.cu.jit" 38 #include "../templates/sparseDotProduct.cu.jit" 39 #include "../templates/denseDotProduct.cu.jit" 40 #include "../templates/GB_jit_AxB_dot3_phase1.cu.jit" 41 #include "../templates/GB_jit_AxB_dot3_phase2.cu.jit" 42 #include "../templates/GB_jit_AxB_dot3_phase3_dndn.cu.jit" 43 #include "../templates/GB_jit_AxB_dot3_phase3_vsvs.cu.jit" 44 #include "../templates/GB_jit_AxB_dot3_phase3_vssp.cu.jit" 45 #include "../templates/GB_jit_AxB_dot3_phase3_spdn.cu.jit" 46 #include "../templates/GB_jit_AxB_dot3_phase3_mp.cu.jit" 47 #include "../templates/GB_jit_AxB_dot3_phase3_warpix.cu.jit" 48 49 #include "type_name.hpp" 50 51 #define JITIFY_PRINT_INSTANTIATION 1 52 #define JITIFY_PRINT_SOURCE 1 53 #define JITIFY_PRINT_LOG 1 54 #define JITIFY_PRINT_PTX 1 55 #define JITIFY_PRINT_LINKER_LOG 1 56 #define JITIFY_PRINT_LAUNCH 1 57 58 #include "dataFactory.hpp" 59 #include "semiringFactory.hpp" 60 #include "../GB_cuda.h" 61 62 63 #if __cplusplus >= 201103L 64 65 66 //Kernel jitifiers 67 template<typename T> class reduceFactory ; 68 template<typename T1, typename T2, typename T3> class dotFactory ; 69 template<typename T1, typename T2, typename T3> class spdotFactory ; 70 71 72 //AxB_dot3_phase1 kernel launchers 73 template< typename T_C, typename T_M, typename T_A, typename T_B> class phase1launchFactory ; 74 75 //AxB_dot3_phase3 kernel launchers 76 77 template< typename T_C, typename T_M, 78 typename T_A, typename T_B, typename T_xy, typename T_z> class launchFactory ; 79 80 81 const std::vector<std::string> compiler_flags{ 82 "-std=c++14", 83 "-remove-unused-globals", 84 "-w", 85 "-D__CUDACC_RTC__", 86 "-I.", 87 "-I..", 88 "-I../../Include", 89 "-I../../Source", 90 "-I../../Source/Template", 91 "-Ilocal_cub/block", 92 "-Itemplates", 93 "-I/usr/local/cuda/include" 94 }; 95 96 const std::vector<std::string> header_names ={}; 97 98 template< typename T_C, typename T_M, typename T_A, typename T_B> 99 class phase1launchFactory 100 { 101 std::string base_name = "GB_jit_"; 102 std::string kernel_name = "AxB_dot3_phase1"; 103 std::string template_name = "templates_GB_jit_AxB_dot3_phase1_cu"; 104 105 public: 106 jitGridBlockLaunch(int gridsz,int blocksz,int64_t * nanobuckets,int64_t * blockBucket,matrix<T_C> * C,matrix<T_M> * M,matrix<T_A> * A,matrix<T_B> * B)107 bool jitGridBlockLaunch(int gridsz, int blocksz, 108 int64_t *nanobuckets, int64_t *blockBucket, 109 matrix<T_C> *C, matrix<T_M> *M, matrix<T_A> *A, matrix<T_B> *B) 110 { 111 112 bool result = false; 113 114 T_C dumC; 115 T_M dumM; 116 T_A dumA; 117 T_B dumB; 118 119 dim3 grid(gridsz); 120 dim3 block(blocksz); 121 122 std::cout<< kernel_name<< 123 " with types "<<GET_TYPE_NAME(dumC)<<"," 124 <<GET_TYPE_NAME(dumM)<<"," 125 <<GET_TYPE_NAME(dumA)<<"," 126 <<GET_TYPE_NAME(dumB)<<std::endl; 127 128 129 jit::launcher( base_name + kernel_name, 130 jit_template, 131 header_names, 132 compiler_flags, 133 file_callback) 134 135 .set_kernel_inst( base_name + kernel_name , 136 { GET_TYPE_NAME(dumC), 137 GET_TYPE_NAME(dumM), 138 GET_TYPE_NAME(dumA), 139 GET_TYPE_NAME(dumB), 140 }) 141 .configure(grid, block) 142 .launch( nanobuckets, blockBucket, C, M, A, B); 143 144 checkCudaErrors( cudaDeviceSynchronize() ); 145 result= true; 146 147 return result; 148 } 149 150 151 }; 152 153 template< typename T_C> 154 class phase2launchFactory 155 { 156 std::string base_name = "GB_jit_"; 157 std::string kernel_name = "AxB_dot3_phase2"; 158 std::string template_name = "templates_GB_jit_AxB_dot3_phase2_cu"; 159 160 public: 161 jitGridBlockLaunch(int gridsz,int blocksz,int64_t * nanobuckets,int64_t * blockBucket,int64_t * bucketp,int64_t * bucket,matrix<T_C> * C,const int64_t cnz,const int64_t nblocks)162 bool jitGridBlockLaunch(int gridsz, int blocksz, 163 int64_t *nanobuckets, int64_t *blockBucket, 164 int64_t *bucketp, int64_t *bucket, 165 matrix<T_C> *C, const int64_t cnz, const int64_t nblocks ) 166 { 167 168 bool result = false; 169 170 T_C dumC; 171 172 dim3 grid(gridsz); 173 dim3 block(blocksz); 174 175 std::cout<< kernel_name<<" with types " <<GET_TYPE_NAME(dumC)<<std::endl; 176 177 178 jit::launcher( base_name + kernel_name, 179 jit_template, 180 header_names, 181 compiler_flags, 182 file_callback) 183 184 .set_kernel_inst( base_name + kernel_name , 185 { GET_TYPE_NAME(dumC) }) 186 .configure(grid, block) 187 .launch( nanobuckets, blockBucket, bucketp, bucket, C, cnz); 188 189 checkCudaErrors( cudaDeviceSynchronize() ); 190 result= true; 191 192 return result; 193 } 194 195 }; 196 197 template< typename T_C> 198 class phase2endlaunchFactory 199 { 200 std::string base_name = "GB_jit_"; 201 std::string kernel_name = "AxB_dot3_phase2end"; 202 std::string template_name = "templates_GB_jit_AxB_dot3_phase2_cu"; 203 204 public: 205 jitGridBlockLaunch(int gridsz,int blocksz,int64_t * nanobuckets,int64_t * blockBucket,int64_t * bucketp,int64_t * bucket,matrix<T_C> * C,const int64_t cnz)206 bool jitGridBlockLaunch(int gridsz, int blocksz, 207 int64_t *nanobuckets, int64_t *blockBucket, 208 int64_t *bucketp, int64_t *bucket, 209 matrix<T_C> *C, const int64_t cnz) 210 { 211 212 bool result = false; 213 214 jit_template = templates_GB_jit_AxB_dot3_phase2_cu; 215 216 T_C dumC; 217 218 dim3 grid(gridsz); 219 dim3 block(blocksz); 220 221 std::cout<< kernel_name<<" with types " <<GET_TYPE_NAME(dumC)<<std::endl; 222 223 224 jit::launcher( base_name + kernel_name, 225 jit_template, 226 header_names, 227 compiler_flags, 228 file_callback) 229 230 .set_kernel_inst( base_name + kernel_name , 231 { GET_TYPE_NAME(dumC) }) 232 .configure(grid, block) 233 .launch( nanobuckets, blockBucket, bucketp, bucket, C, cnz); 234 235 checkCudaErrors( cudaDeviceSynchronize() ); 236 result= true; 237 238 return result; 239 } 240 241 }; 242 243 template< typename T_C, typename T_M, typename T_A, typename T_B, typename T_XY, typename T_Z> 244 class launchFactory 245 { 246 std::string base_name = "GB_jit_"; 247 std::string kernel_name = "AxB_dot3_phase3_"; 248 std::string template_name = "___templates_GB_jit_AxB_dot3_phase3_"; 249 std::string OpName; 250 std::string SR; 251 252 GB_callback callback_generator; 253 254 public: launchFactory(std::string SemiRing,std::string Optype)255 launchFactory(std::string SemiRing, std::string Optype) { 256 if (SemiRing == "PLUS_TIMES") { 257 std::cout<<"loading PLUS_TIMES semi-ring"<<std::endl; 258 file_callback = semiring_plus_times_callback; 259 //callback_generator.load_string( "mySemiring.h", semiring_string); 260 //file_callback = callback_generator.callback; 261 } 262 else if (SemiRing == "MIN_PLUS") { 263 std::cout<<"loading MIN_PLUS semi-ring"<<std::endl; 264 file_callback = semiring_min_plus_callback; 265 } 266 else if (SemiRing == "MAX_PLUS") { 267 std::cout<<"loading MAX_PLUS semi-ring"<<std::endl; 268 file_callback = semiring_max_plus_callback; 269 } 270 OpName = Optype; 271 SR = SemiRing; 272 273 } 274 jitGridBlockLaunch(int gridsz,int blocksz,int64_t start,int64_t end,int64_t * Bucket,matrix<T_C> * C,matrix<T_M> * M,matrix<T_A> * A,matrix<T_B> * B,int sz)275 bool jitGridBlockLaunch(int gridsz, int blocksz, 276 int64_t start, int64_t end, int64_t *Bucket, 277 matrix<T_C> *C, matrix<T_M> *M, matrix<T_A> *A, matrix<T_B> *B, 278 int sz) 279 { 280 281 bool result = false; 282 283 T_C dumC; 284 T_M dumM; 285 T_A dumA; 286 T_B dumB; 287 T_XY dumXY; 288 T_Z dumZ; 289 290 dim3 grid(gridsz); 291 dim3 block(blocksz); 292 293 std::cout<< kernel_name<<SR<<OpName<< 294 " with types "<<GET_TYPE_NAME(dumC)<<"," 295 <<GET_TYPE_NAME(dumM)<<"," 296 <<GET_TYPE_NAME(dumA)<<"," 297 <<GET_TYPE_NAME(dumB)<<"," 298 <<GET_TYPE_NAME(dumXY)<<"," 299 <<GET_TYPE_NAME(dumZ)<<std::endl; 300 301 const char* jit_template; 302 if (OpName == "dndn") { 303 jit_template = ___templates_GB_jit_AxB_dot3_phase3_dndn_cu; 304 } 305 if (OpName == "vsvs") { 306 jit_template = ___templates_GB_jit_AxB_dot3_phase3_vsvs_cu; 307 } 308 if (OpName == "vssp") { 309 jit_template = ___templates_GB_jit_AxB_dot3_phase3_vssp_cu; 310 } 311 if (OpName == "spdn") { 312 jit_template = ___templates_GB_jit_AxB_dot3_phase3_spdn_cu; 313 } 314 if (OpName == "mp") { 315 jit_template = ___templates_GB_jit_AxB_dot3_phase3_mp_cu; 316 } 317 if (OpName == "warp") { 318 jit_template = ___templates_GB_jit_AxB_dot3_phase3_warpix_cu; 319 } 320 321 jit::launcher( base_name + SR + OpName+ GET_TYPE_NAME(dumZ), 322 jit_template, 323 header_names, 324 compiler_flags, 325 file_callback) 326 327 .set_kernel_inst( kernel_name+OpName, 328 { GET_TYPE_NAME(dumC), 329 GET_TYPE_NAME(dumA), 330 GET_TYPE_NAME(dumB), 331 GET_TYPE_NAME(dumXY), 332 GET_TYPE_NAME(dumXY), 333 GET_TYPE_NAME(dumZ) 334 }) 335 .configure(grid, block) 336 .launch( start, end, Bucket, 337 C, M, A, B, sz); 338 339 checkCudaErrors( cudaDeviceSynchronize() ); 340 result= true; 341 342 return result; 343 } 344 345 }; 346 347 template<typename T1, typename T2, typename T3> 348 class spdotFactory 349 { 350 std::string base_name = "GBjit_spDot_"; 351 public: spdotFactory()352 spdotFactory() { 353 } 354 jitGridBlockLaunch(int gridsz,int blocksz,unsigned int xn,unsigned int * xi,T1 * x,unsigned int yn,unsigned int * yi,T2 * y,T3 * output,std::string OpName)355 bool jitGridBlockLaunch(int gridsz, int blocksz, unsigned int xn, unsigned int *xi, T1* x, 356 unsigned int yn, unsigned int *yi, T2* y, 357 T3* output, std::string OpName) 358 { 359 360 bool result = false; 361 if (OpName == "PLUS_TIMES") { 362 file_callback = &semiring_plus_times_callback; 363 } 364 else if (OpName == "MIN_PLUS") { 365 file_callback = &semiring_min_plus_callback; 366 } 367 368 T1 dum1; 369 T2 dum2; 370 T3 dum3; 371 372 dim3 grid(gridsz); 373 dim3 block(blocksz); 374 375 jit::launcher( base_name + OpName, 376 ___templates_sparseDotProduct_cu, 377 header_names, 378 compiler_flags, 379 file_callback) 380 381 .set_kernel_inst("sparseDotProduct", 382 { GET_TYPE_NAME(dum1), 383 GET_TYPE_NAME(dum2), 384 GET_TYPE_NAME(dum3)}) 385 .configure(grid, block) 386 .launch(xn, xi, x, yn, yi, y, output); 387 388 389 checkCudaErrors( cudaDeviceSynchronize() ); 390 result= true; 391 392 return result; 393 } 394 395 }; 396 397 template<typename T1, typename T2, typename T3> 398 class dotFactory 399 { 400 std::string base_name = "GBjit_dnDot_"; 401 public: dotFactory()402 dotFactory() { 403 } 404 405 jitGridBlockLaunch(int gridsz,int blocksz,T1 * x,T2 * y,T3 * output,unsigned int N,std::string OpName)406 bool jitGridBlockLaunch(int gridsz, int blocksz, T1* x, T2* y, T3* output, unsigned int N, std::string OpName) 407 { 408 409 bool result = false; 410 if (OpName == "PLUS_TIMES") { 411 file_callback = &semiring_plus_times_callback; 412 } 413 else if (OpName == "MIN_PLUS") { 414 file_callback = &semiring_min_plus_callback; 415 } 416 417 T1 dum1; 418 T2 dum2; 419 T3 dum3; 420 421 dim3 grid(gridsz); 422 dim3 block(blocksz); 423 424 jit::launcher( base_name + OpName, 425 ___templates_denseDotProduct_cu, 426 header_names, 427 compiler_flags, 428 file_callback) 429 430 .set_kernel_inst("denseDotProduct", 431 { GET_TYPE_NAME(dum1), 432 GET_TYPE_NAME(dum2), 433 GET_TYPE_NAME(dum3)}) 434 .configure(grid, block) 435 .launch(x, y, output, N); 436 437 checkCudaErrors( cudaDeviceSynchronize() ); 438 result= true; 439 440 return result; 441 } 442 443 }; 444 445 template<typename T> 446 class reduceFactory 447 { 448 std::string base_name = "GBjit_reduce_"; 449 450 public: reduceFactory()451 reduceFactory() { 452 } 453 jitGridBlockLaunch(int gridsz,int blocksz,T * indata,T * output,unsigned int N,std::string OpName)454 bool jitGridBlockLaunch(int gridsz, int blocksz, 455 T* indata, T* output, unsigned int N, 456 std::string OpName) 457 { 458 dim3 grid(gridsz); 459 dim3 block(blocksz); 460 bool result = false; 461 T dummy; 462 463 std::cout<<" indata type ="<< GET_TYPE_NAME(dummy)<<std::endl; 464 465 if (OpName == "PLUS") { 466 file_callback = &file_callback_plus; 467 } 468 else if (OpName == "MIN") { 469 file_callback = &file_callback_min; 470 } 471 else if (OpName == "MAX") { 472 file_callback = &file_callback_max; 473 } 474 475 476 jit::launcher( base_name + OpName, 477 ___templates_reduceUnrolled_cu, 478 header_names, 479 compiler_flags, 480 file_callback) 481 .set_kernel_inst("reduceUnrolled", 482 { GET_TYPE_NAME(dummy) }) 483 .configure(grid, block) 484 .launch( indata, output, N); 485 486 checkCudaErrors( cudaDeviceSynchronize() ); 487 488 result= true; 489 490 491 return result; 492 } 493 494 }; 495 496 #endif // C++11 497 498