// SPDX-License-Identifier: Apache-2.0
/*
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * * Redistributions of source code must retain the above copyright
 *   notice, this list of conditions and the following disclaimer.
 * * Redistributions in binary form must reproduce the above copyright
 *   notice, this list of conditions and the following disclaimer in the
 *   documentation and/or other materials provided with the distribution.
 * * Neither the name of NVIDIA CORPORATION nor the names of its
 *   contributors may be used to endorse or promote products derived
 *   from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

/*
  Extended example for building on-the-fly kernels with a C interface.
  Simple examples demonstrating different ways to load source code
  and call kernels.
 */

#include "GB_jit_launcher.h"
#include "../templates/reduceUnrolled.cu.jit"
#include "../templates/sparseDotProduct.cu.jit"
#include "../templates/denseDotProduct.cu.jit"
#include "../templates/GB_jit_AxB_dot3_phase1.cu.jit"
#include "../templates/GB_jit_AxB_dot3_phase2.cu.jit"
#include "../templates/GB_jit_AxB_dot3_phase3_dndn.cu.jit"
#include "../templates/GB_jit_AxB_dot3_phase3_vsvs.cu.jit"
#include "../templates/GB_jit_AxB_dot3_phase3_vssp.cu.jit"
#include "../templates/GB_jit_AxB_dot3_phase3_spdn.cu.jit"
#include "../templates/GB_jit_AxB_dot3_phase3_mp.cu.jit"
#include "../templates/GB_jit_AxB_dot3_phase3_warpix.cu.jit"

#include "type_name.hpp"

#define JITIFY_PRINT_INSTANTIATION 1
#define JITIFY_PRINT_SOURCE 1
#define JITIFY_PRINT_LOG 1
#define JITIFY_PRINT_PTX 1
#define JITIFY_PRINT_LINKER_LOG 1
#define JITIFY_PRINT_LAUNCH 1

#include "dataFactory.hpp"
#include "semiringFactory.hpp"
#include "../GB_cuda.h"


#if __cplusplus >= 201103L


// Kernel jitifiers
template<typename T> class reduceFactory ;
template<typename T1, typename T2, typename T3> class dotFactory ;
template<typename T1, typename T2, typename T3> class spdotFactory ;


// AxB_dot3_phase1 kernel launchers
template<  typename T_C, typename T_M, typename T_A, typename T_B> class phase1launchFactory ;

// AxB_dot3_phase3 kernel launchers

template<  typename T_C, typename T_M,
         typename T_A, typename T_B, typename T_xy, typename T_z> class launchFactory ;


const std::vector<std::string> compiler_flags{
   "-std=c++14",
   "-remove-unused-globals",
   "-w",
   "-D__CUDACC_RTC__",
   "-I.",
   "-I..",
   "-I../../Include",
   "-I../../Source",
   "-I../../Source/Template",
   "-Ilocal_cub/block",
   "-Itemplates",
   "-I/usr/local/cuda/include"
};

const std::vector<std::string> header_names ={};

template<  typename T_C, typename T_M, typename T_A, typename T_B>
class phase1launchFactory
{
  std::string base_name = "GB_jit_";
  std::string kernel_name = "AxB_dot3_phase1";
  std::string template_name = "templates_GB_jit_AxB_dot3_phase1_cu";

public:

  bool jitGridBlockLaunch(int gridsz, int blocksz,
                          int64_t *nanobuckets, int64_t *blockBucket,
                          matrix<T_C> *C, matrix<T_M> *M, matrix<T_A> *A, matrix<T_B> *B)
     {

      bool result = false;

      T_C dumC;
      T_M dumM;
      T_A dumA;
      T_B dumB;

      dim3 grid(gridsz);
      dim3 block(blocksz);

      std::cout<< kernel_name<<
                      " with types "<<GET_TYPE_NAME(dumC)<<","
                                    <<GET_TYPE_NAME(dumM)<<","
                                    <<GET_TYPE_NAME(dumA)<<","
                                    <<GET_TYPE_NAME(dumB)<<std::endl;

      // kernel source string generated by the phase1 .cu.jit include above
      // (variable name as recorded in template_name)
      const char* jit_template = templates_GB_jit_AxB_dot3_phase1_cu;

      jit::launcher( base_name + kernel_name,
                     jit_template,
                     header_names,
                     compiler_flags,
                     file_callback)

                   .set_kernel_inst(  base_name + kernel_name ,
                                    { GET_TYPE_NAME(dumC),
                                      GET_TYPE_NAME(dumM),
                                      GET_TYPE_NAME(dumA),
                                      GET_TYPE_NAME(dumB),
                                      })
                   .configure(grid, block)
                   .launch( nanobuckets, blockBucket, C, M, A, B);

      checkCudaErrors( cudaDeviceSynchronize() );
      result = true;

      return result;
     }

};
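
// Usage sketch (an assumption, not part of this file): the caller owns the
// matrix<T> wrappers from dataFactory.hpp and device-resident scratch arrays
// for the per-block bucket counts; the type parameters below are illustrative.
//
//   phase1launchFactory<int32_t, bool, int32_t, int32_t> p1 ;
//   p1.jitGridBlockLaunch (gridsz, blocksz, nanobuckets, blockBucket,
//                          C, M, A, B) ;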
152 
153 template<  typename T_C>
154 class phase2launchFactory
155 {
156   std::string base_name = "GB_jit_";
157   std::string kernel_name = "AxB_dot3_phase2";
158   std::string template_name = "templates_GB_jit_AxB_dot3_phase2_cu";
159 
160 public:
161 
jitGridBlockLaunch(int gridsz,int blocksz,int64_t * nanobuckets,int64_t * blockBucket,int64_t * bucketp,int64_t * bucket,matrix<T_C> * C,const int64_t cnz,const int64_t nblocks)162   bool jitGridBlockLaunch(int gridsz, int blocksz,
163                           int64_t *nanobuckets, int64_t *blockBucket,
164                           int64_t *bucketp, int64_t *bucket,
165                           matrix<T_C> *C, const int64_t cnz, const int64_t nblocks )
166      {
167 
168       bool result = false;
169 
170       T_C dumC;
171 
172       dim3 grid(gridsz);
173       dim3 block(blocksz);
174 
175       std::cout<< kernel_name<<" with types " <<GET_TYPE_NAME(dumC)<<std::endl;
176 
177 
178       jit::launcher( base_name + kernel_name,
179                      jit_template,
180                      header_names,
181                      compiler_flags,
182                      file_callback)
183 
184                    .set_kernel_inst(  base_name + kernel_name ,
185                                     { GET_TYPE_NAME(dumC) })
186                    .configure(grid, block)
187                    .launch( nanobuckets, blockBucket, bucketp, bucket, C, cnz);
188 
189       checkCudaErrors( cudaDeviceSynchronize() );
190       result= true;
191 
192       return result;
193      }
194 
195 };
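
// Usage sketch (illustrative type and values): phase2 takes the nanobuckets
// and blockBucket arrays produced by phase1 together with the bucketp/bucket
// work arrays supplied by the caller.
//
//   phase2launchFactory<int32_t> p2 ;
//   p2.jitGridBlockLaunch (gridsz, blocksz, nanobuckets, blockBucket,
//                          bucketp, bucket, C, cnz, nblocks) ;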
196 
197 template<  typename T_C>
198 class phase2endlaunchFactory
199 {
200   std::string base_name = "GB_jit_";
201   std::string kernel_name = "AxB_dot3_phase2end";
202   std::string template_name = "templates_GB_jit_AxB_dot3_phase2_cu";
203 
204 public:
205 
jitGridBlockLaunch(int gridsz,int blocksz,int64_t * nanobuckets,int64_t * blockBucket,int64_t * bucketp,int64_t * bucket,matrix<T_C> * C,const int64_t cnz)206   bool jitGridBlockLaunch(int gridsz, int blocksz,
207                           int64_t *nanobuckets, int64_t *blockBucket,
208                           int64_t *bucketp, int64_t *bucket,
209                           matrix<T_C> *C, const int64_t cnz)
210      {
211 
212       bool result = false;
213 
214       jit_template = templates_GB_jit_AxB_dot3_phase2_cu;
215 
216       T_C dumC;
217 
218       dim3 grid(gridsz);
219       dim3 block(blocksz);
220 
221       std::cout<< kernel_name<<" with types " <<GET_TYPE_NAME(dumC)<<std::endl;
222 
223 
224       jit::launcher( base_name + kernel_name,
225                      jit_template,
226                      header_names,
227                      compiler_flags,
228                      file_callback)
229 
230                    .set_kernel_inst(  base_name + kernel_name ,
231                                     { GET_TYPE_NAME(dumC) })
232                    .configure(grid, block)
233                    .launch( nanobuckets, blockBucket, bucketp, bucket, C, cnz);
234 
235       checkCudaErrors( cudaDeviceSynchronize() );
236       result= true;
237 
238       return result;
239      }
240 
241 };

template<  typename T_C, typename T_M, typename T_A, typename T_B, typename T_XY, typename T_Z>
class launchFactory
{
  std::string base_name = "GB_jit_";
  std::string kernel_name = "AxB_dot3_phase3_";
  std::string template_name = "___templates_GB_jit_AxB_dot3_phase3_";
  std::string OpName;
  std::string SR;

  GB_callback callback_generator;

public:
  launchFactory(std::string SemiRing, std::string Optype) {
      if (SemiRing == "PLUS_TIMES") {
         std::cout<<"loading PLUS_TIMES semi-ring"<<std::endl;
         file_callback = semiring_plus_times_callback;
         //callback_generator.load_string( "mySemiring.h", semiring_string);
         //file_callback = callback_generator.callback;
      }
      else if (SemiRing == "MIN_PLUS") {
         std::cout<<"loading MIN_PLUS semi-ring"<<std::endl;
         file_callback = semiring_min_plus_callback;
      }
      else if (SemiRing == "MAX_PLUS") {
         std::cout<<"loading MAX_PLUS semi-ring"<<std::endl;
         file_callback = semiring_max_plus_callback;
      }
      OpName = Optype;
      SR = SemiRing;

  }

  bool jitGridBlockLaunch(int gridsz, int blocksz,
                          int64_t start, int64_t end, int64_t *Bucket,
                          matrix<T_C> *C, matrix<T_M> *M, matrix<T_A> *A, matrix<T_B> *B,
                          int sz)
     {

      bool result = false;

      T_C dumC;
      T_M dumM;
      T_A dumA;
      T_B dumB;
      T_XY dumXY;
      T_Z dumZ;

      dim3 grid(gridsz);
      dim3 block(blocksz);

      std::cout<< kernel_name<<SR<<OpName<<
                      " with types "<<GET_TYPE_NAME(dumC)<<","
                                    <<GET_TYPE_NAME(dumM)<<","
                                    <<GET_TYPE_NAME(dumA)<<","
                                    <<GET_TYPE_NAME(dumB)<<","
                                    <<GET_TYPE_NAME(dumXY)<<","
                                    <<GET_TYPE_NAME(dumZ)<<std::endl;

      // select the phase3 source string that matches the requested bucket code
      const char* jit_template = nullptr;
      if (OpName == "dndn") {
         jit_template = ___templates_GB_jit_AxB_dot3_phase3_dndn_cu;
      }
      else if (OpName == "vsvs") {
         jit_template = ___templates_GB_jit_AxB_dot3_phase3_vsvs_cu;
      }
      else if (OpName == "vssp") {
         jit_template = ___templates_GB_jit_AxB_dot3_phase3_vssp_cu;
      }
      else if (OpName == "spdn") {
         jit_template = ___templates_GB_jit_AxB_dot3_phase3_spdn_cu;
      }
      else if (OpName == "mp") {
         jit_template = ___templates_GB_jit_AxB_dot3_phase3_mp_cu;
      }
      else if (OpName == "warp") {
         jit_template = ___templates_GB_jit_AxB_dot3_phase3_warpix_cu;
      }

      jit::launcher( base_name + SR + OpName + GET_TYPE_NAME(dumZ),
                     jit_template,
                     header_names,
                     compiler_flags,
                     file_callback)

                   .set_kernel_inst(  kernel_name + OpName,
                                    { GET_TYPE_NAME(dumC),
                                      GET_TYPE_NAME(dumA),
                                      GET_TYPE_NAME(dumB),
                                      GET_TYPE_NAME(dumXY),
                                      GET_TYPE_NAME(dumXY),
                                      GET_TYPE_NAME(dumZ)
                                      })
                   .configure(grid, block)
                   .launch( start, end, Bucket,
                            C, M, A, B, sz);

      checkCudaErrors( cudaDeviceSynchronize() );
      result = true;

      return result;
     }

};
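
// Usage sketch (hypothetical types and bucket choice): the semiring string
// selects the file_callback that supplies the semiring header, and the op
// string selects one of the phase3 source templates
// ("dndn", "vsvs", "vssp", "spdn", "mp", "warp").
//
//   launchFactory<int32_t, bool, int32_t, int32_t, int32_t, int32_t>
//       p3 ("PLUS_TIMES", "mp") ;
//   p3.jitGridBlockLaunch (gridsz, blocksz, start, end, Bucket,
//                          C, M, A, B, sz) ;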

template<typename T1, typename T2, typename T3>
class spdotFactory
{
  std::string base_name = "GBjit_spDot_";
public:
  spdotFactory() {
  }

  bool jitGridBlockLaunch(int gridsz, int blocksz, unsigned int xn, unsigned int *xi, T1* x,
                                                   unsigned int yn, unsigned int *yi, T2* y,
                                                        T3* output, std::string OpName)
  {

      bool result = false;
      if (OpName == "PLUS_TIMES") {
         file_callback = &semiring_plus_times_callback;
      }
      else if (OpName == "MIN_PLUS") {
         file_callback = &semiring_min_plus_callback;
      }

      T1 dum1;
      T2 dum2;
      T3 dum3;

      dim3 grid(gridsz);
      dim3 block(blocksz);

      jit::launcher( base_name + OpName,
                     ___templates_sparseDotProduct_cu,
                     header_names,
                     compiler_flags,
                     file_callback)

                   .set_kernel_inst("sparseDotProduct",
                                    { GET_TYPE_NAME(dum1),
                                      GET_TYPE_NAME(dum2),
                                      GET_TYPE_NAME(dum3)})
                   .configure(grid, block)
                   .launch(xn, xi, x, yn, yi, y, output);


      checkCudaErrors( cudaDeviceSynchronize() );
      result = true;

      return result;
  }

};
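
// Usage sketch (illustrative types; x and y are assumed to be device-resident
// sparse vectors given as index/value arrays):
//
//   spdotFactory<float, float, float> sdot ;
//   sdot.jitGridBlockLaunch (gridsz, blocksz, xn, xi, x, yn, yi, y,
//                            output, "PLUS_TIMES") ;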

template<typename T1, typename T2, typename T3>
class dotFactory
{
  std::string base_name = "GBjit_dnDot_";
public:
  dotFactory() {
  }


  bool jitGridBlockLaunch(int gridsz, int blocksz, T1* x, T2* y, T3* output, unsigned int N, std::string OpName)
  {

      bool result = false;
      if (OpName == "PLUS_TIMES") {
         file_callback = &semiring_plus_times_callback;
      }
      else if (OpName == "MIN_PLUS") {
         file_callback = &semiring_min_plus_callback;
      }

      T1 dum1;
      T2 dum2;
      T3 dum3;

      dim3 grid(gridsz);
      dim3 block(blocksz);

      jit::launcher( base_name + OpName,
                     ___templates_denseDotProduct_cu,
                     header_names,
                     compiler_flags,
                     file_callback)

                   .set_kernel_inst("denseDotProduct",
                                    { GET_TYPE_NAME(dum1),
                                      GET_TYPE_NAME(dum2),
                                      GET_TYPE_NAME(dum3)})
                   .configure(grid, block)
                   .launch(x, y, output, N);

      checkCudaErrors( cudaDeviceSynchronize() );
      result = true;

      return result;
  }

};
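
// Usage sketch (illustrative types; x, y, and output are device pointers,
// N is the vector length):
//
//   dotFactory<float, float, float> ddot ;
//   ddot.jitGridBlockLaunch (gridsz, blocksz, x, y, output, N, "PLUS_TIMES") ;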

template<typename T>
class reduceFactory
{
  std::string base_name = "GBjit_reduce_";

public:
  reduceFactory() {
  }

  bool jitGridBlockLaunch(int gridsz, int blocksz,
                          T* indata, T* output, unsigned int N,
                          std::string OpName)
  {
      dim3 grid(gridsz);
      dim3 block(blocksz);
      bool result = false;
      T dummy;

      std::cout<<" indata type ="<< GET_TYPE_NAME(dummy)<<std::endl;

      if (OpName == "PLUS") {
         file_callback = &file_callback_plus;
      }
      else if (OpName == "MIN") {
         file_callback = &file_callback_min;
      }
      else if (OpName == "MAX") {
         file_callback = &file_callback_max;
      }


      jit::launcher( base_name + OpName,
                     ___templates_reduceUnrolled_cu,
                     header_names,
                     compiler_flags,
                     file_callback)
                   .set_kernel_inst("reduceUnrolled",
                                    { GET_TYPE_NAME(dummy) })
                   .configure(grid, block)
                   .launch( indata, output, N);

      checkCudaErrors( cudaDeviceSynchronize() );

      result = true;


      return result;
  }

};
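
// Usage sketch (illustrative type; indata and output are device pointers,
// OpName is one of "PLUS", "MIN", "MAX"):
//
//   reduceFactory<int32_t> red ;
//   red.jitGridBlockLaunch (gridsz, blocksz, indata, output, N, "PLUS") ;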

#endif  // C++11
