1 // SPDX-License-Identifier: Apache-2.0
2 /*
3  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * * Redistributions of source code must retain the above copyright
9  *   notice, this list of conditions and the following disclaimer.
10  * * Redistributions in binary form must reproduce the above copyright
11  *   notice, this list of conditions and the following disclaimer in the
12  *   documentation and/or other materials provided with the distribution.
13  * * Neither the name of NVIDIA CORPORATION nor the names of its
14  *   contributors may be used to endorse or promote products derived
15  *   from this software without specific prior written permission.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
18  * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
20  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
21  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
22  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
23  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
24  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
25  * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28  */
29 
30 /*
31   -----------
32   Jitify 0.9
33   -----------
34   A C++ library for easy integration of CUDA runtime compilation into
35   existing codes.
36 
37   --------------
38   How to compile
39   --------------
40   Compiler dependencies: <jitify.hpp>, -std=c++11
41   Linker dependencies:   dl cuda nvrtc
42 
43   --------------------------------------
44   Embedding source files into executable
45   --------------------------------------
  With g++:
    g++  ... -ldl -rdynamic -DJITIFY_ENABLE_EMBEDDED_FILES=1
      -Wl,-b,binary,my_kernel.cu,include/my_header.cuh,-b,default
  With nvcc:
    nvcc ... -ldl -Xcompiler "-rdynamic
      -Wl\,-b\,binary\,my_kernel.cu\,include/my_header.cuh\,-b\,default"
  Then declare the embedded files in the source code:
    JITIFY_INCLUDE_EMBEDDED_FILE(my_kernel_cu);
    JITIFY_INCLUDE_EMBEDDED_FILE(include_my_header_cuh);
52 
53   ----
54   TODO
55   ----
56   Extract valid compile options and pass the rest to cuModuleLoadDataEx
57   See if can have stringified headers automatically looked-up
58     by having stringify add them to a (static) global map.
59     The global map can be updated by creating a static class instance
60       whose constructor performs the registration.
61     Can then remove all headers from JitCache constructor in example code
62   See other TODOs in code
63 */
64 
65 /*! \file jitify.hpp
66  *  \brief The Jitify library header
67  */
68 
69 /*! \mainpage Jitify - A C++ library that simplifies the use of NVRTC
70  *  \p Use class jitify::JitCache to manage and launch JIT-compiled CUDA
71  *    kernels.
72  *
73  *  \p Use namespace jitify::reflection to reflect types and values into
74  *    code-strings.
75  *
76  *  \p Use JITIFY_INCLUDE_EMBEDDED_FILE() to declare files that have been
77  *  embedded into the executable using the GCC linker.
78  *
79  *  \p Use jitify::parallel_for and JITIFY_LAMBDA() to generate and launch
80  *  simple kernels.
81  */
82 
83 #pragma once
84 
85 #ifndef JITIFY_THREAD_SAFE
86 #define JITIFY_THREAD_SAFE 1
87 #endif
88 
89 #if JITIFY_ENABLE_EMBEDDED_FILES
90 #include <dlfcn.h>
91 #endif
#include <stdint.h>
#include <algorithm>
#include <cctype>
#include <cstring>  // For strtok_r etc.
#include <deque>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
109 #if JITIFY_THREAD_SAFE
110 #include <mutex>
111 #endif
112 
113 #include <cuda.h>
114 #include <cuda_runtime_api.h>  // For dim3, cudaStream_t
115 #if CUDA_VERSION >= 8000
116 #define NVRTC_GET_TYPE_NAME 1
117 #endif
118 #include <nvrtc.h>
119 
120 // For use by get_current_executable_path().
121 #ifdef __linux__
122 #include <linux/limits.h>  // For PATH_MAX
123 
124 #include <cstdlib>  // For realpath
125 #define JITIFY_PATH_MAX PATH_MAX
126 #elif defined(_WIN32) || defined(_WIN64)
127 #include <windows.h>
128 #define JITIFY_PATH_MAX MAX_PATH
129 #else
130 #error "Unsupported platform"
131 #endif
132 
133 #ifdef _MSC_VER       // MSVC compiler
134 #include <dbghelp.h>  // For UnDecorateSymbolName
135 #else
136 #include <cxxabi.h>  // For abi::__cxa_demangle
137 #endif
138 
139 #if defined(_WIN32) || defined(_WIN64)
140 // WAR for strtok_r being called strtok_s on Windows
141 #pragma push_macro("strtok_r")
142 #undef strtok_r
143 #define strtok_r strtok_s
144 // WAR for min and max possibly being macros defined by windows.h
145 #pragma push_macro("min")
146 #pragma push_macro("max")
147 #undef min
148 #undef max
149 #endif
150 
151 #ifndef JITIFY_PRINT_LOG
152 #define JITIFY_PRINT_LOG 1
153 #endif
154 
155 #if JITIFY_PRINT_ALL
156 #define JITIFY_PRINT_INSTANTIATION 1
157 #define JITIFY_PRINT_SOURCE 1
158 #define JITIFY_PRINT_LOG 1
159 #define JITIFY_PRINT_PTX 1
160 #define JITIFY_PRINT_LINKER_LOG 1
161 #define JITIFY_PRINT_LAUNCH 1
162 #define JITIFY_PRINT_HEADER_PATHS 1
163 #endif
164 
165 #if JITIFY_ENABLE_EMBEDDED_FILES
// Defines a global void* named <x>_forced that holds the address of x, so
// that x is referenced from this translation unit.
#define JITIFY_FORCE_UNDEFINED_SYMBOL(x) void* x##_forced = (void*)&x
/*! Include a source file that has been embedded into the executable using the
 *    GCC linker.
 * \param name The name of the source file (<b>not</b> as a string), which must
 * be sanitized by replacing non-alpha-numeric characters with underscores.
 * E.g., \code{.cpp}JITIFY_INCLUDE_EMBEDDED_FILE(my_header_h)\endcode will
 * include the embedded file "my_header.h".
 * \note Files declared with this macro can be referenced using
 * their original (unsanitized) filenames when creating a \p
 * jitify::Program instance.
 * \note The macro declares the symbols _binary_<name>_start and
 * _binary_<name>_end (under _jitify_binary_-prefixed C++ names) and references
 * both through JITIFY_FORCE_UNDEFINED_SYMBOL.
 */
#define JITIFY_INCLUDE_EMBEDDED_FILE(name)                                \
  extern "C" uint8_t _jitify_binary_##name##_start[] asm("_binary_" #name \
                                                         "_start");       \
  extern "C" uint8_t _jitify_binary_##name##_end[] asm("_binary_" #name   \
                                                       "_end");           \
  JITIFY_FORCE_UNDEFINED_SYMBOL(_jitify_binary_##name##_start);           \
  JITIFY_FORCE_UNDEFINED_SYMBOL(_jitify_binary_##name##_end)
184 #endif  // JITIFY_ENABLE_EMBEDDED_FILES
185 
186 /*! Jitify library namespace
187  */
188 namespace jitify {
189 
190 /*! Source-file load callback.
191  *
192  *  \param filename The name of the requested source file.
193  *  \param tmp_stream A temporary stream that can be used to hold source code.
194  *  \return A pointer to an input stream containing the source code, or NULL
195  *  to defer loading of the file to Jitify's file-loading mechanisms.
196  */
197 typedef std::istream* (*file_callback_type)(std::string filename,
198                                             std::iostream& tmp_stream);
199 
200 // Exclude from Doxygen
201 //! \cond
202 
203 class JitCache;
204 
// Simple cache using an LRU (least-recently-used) discard policy.
//
// Values are stored in a std::map keyed by KeyType. A std::deque ranks the
// keys from most recently used (front) to least recently used (back); when
// room is needed, keys are discarded from the back of the deque.
template <typename KeyType, typename ValueType>
class ObjectCache {
 public:
  typedef KeyType key_type;
  typedef ValueType value_type;

 private:
  typedef std::map<key_type, value_type> object_map;
  typedef std::deque<key_type> key_rank;
  typedef typename key_rank::iterator rank_iterator;
  object_map _objects;
  key_rank _ranked_keys;
  size_t _capacity;

  // Discards least-recently-used entries until at least n slots are free.
  // Throws std::runtime_error if n exceeds the cache capacity.
  inline void discard_old(size_t n = 0) {
    if (n > _capacity) {
      throw std::runtime_error("Insufficient capacity in cache");
    }
    while (_objects.size() > _capacity - n && !_ranked_keys.empty()) {
      key_type discard_key = _ranked_keys.back();
      _ranked_keys.pop_back();
      _objects.erase(discard_key);
    }
  }

 public:
  // Creates an empty cache that holds at most `capacity` objects.
  inline ObjectCache(size_t capacity = 8) : _capacity(capacity) {}
  // Changes the capacity, discarding the oldest entries if it shrinks.
  inline void resize(size_t capacity) {
    _capacity = capacity;
    this->discard_old();
  }
  // Returns true if an object is stored under key k.
  inline bool contains(const key_type& k) const {
    return _objects.find(k) != _objects.end();
  }
  // Marks key k as the most recently used.
  // Throws std::runtime_error if k is not in the cache.
  inline void touch(const key_type& k) {
    if (!this->contains(k)) {
      throw std::runtime_error("Key not found in cache");
    }
    rank_iterator rank = std::find(_ranked_keys.begin(), _ranked_keys.end(), k);
    if (rank == _ranked_keys.end()) {
      _ranked_keys.push_front(k);
    } else if (rank != _ranked_keys.begin()) {
      // Copy the key first: k may refer to the element that is erased.
      key_type key(k);
      _ranked_keys.erase(rank);
      _ranked_keys.push_front(key);
    }
  }
  // Returns the object stored under k and marks it most recently used.
  // Throws std::runtime_error if k is not in the cache.
  inline value_type& get(const key_type& k) {
    typename object_map::iterator found = _objects.find(k);
    if (found == _objects.end()) {
      throw std::runtime_error("Key not found in cache");
    }
    this->touch(k);
    return found->second;
  }
  // Stores a copy of v under k and returns a reference to the stored object.
  // If k is already present the existing object is kept (as with
  // std::map::insert) and is only marked most recently used; nothing is
  // evicted in that case.
  inline value_type& insert(const key_type& k,
                            const value_type& v = value_type()) {
    typename object_map::iterator found = _objects.find(k);
    if (found != _objects.end()) {
      this->touch(k);
      return found->second;
    }
    this->discard_old(1);
    _ranked_keys.push_front(k);
    return _objects.insert(std::make_pair(k, v)).first->second;
  }
  // Constructs an object in place under k from args (perfectly forwarded) and
  // returns a reference to it. If k is already present the existing object is
  // kept and only marked most recently used.
  template <typename... Args>
  inline value_type& emplace(const key_type& k, Args&&... args) {
    typename object_map::iterator found = _objects.find(k);
    if (found != _objects.end()) {
      this->touch(k);
      return found->second;
    }
    this->discard_old(1);
    // Note: Use of piecewise_construct allows non-movable non-copyable types
    auto iter =
        _objects
            .emplace(std::piecewise_construct, std::forward_as_tuple(k),
                     std::forward_as_tuple(std::forward<Args>(args)...))
            .first;
    _ranked_keys.push_front(iter->first);
    return iter->second;
  }
};
276 
277 namespace detail {
278 
// Convenience wrapper for std::vector that provides handy constructors.
// All constructors are intentionally non-explicit so that arguments of this
// type can be written as a count, a C array, a std::vector or a braced list.
template <typename T>
class vector : public std::vector<T> {
  typedef std::vector<T> super_type;

 public:
  // Empty vector.
  vector() : super_type() {}
  // n value-initialized elements.
  vector(size_t n) : super_type(n) {}  // Note: Not explicit, allows =0
  // Copy of an existing std::vector.
  vector(std::vector<T> const& vals) : super_type(vals) {}
  // Copy of the elements of a C array.
  template <int N>
  vector(T const (&vals)[N]) : super_type(vals, vals + N) {}
  // Takes over the contents of an rvalue std::vector. std::move is required
  // because the named parameter is an lvalue and would otherwise be copied.
  vector(std::vector<T>&& vals) : super_type(std::move(vals)) {}
  // Elements of a braced initializer list.
  vector(std::initializer_list<T> vals) : super_type(vals) {}
};
293 
294 // Helper functions for parsing/manipulating source code
295 
// Returns a copy of str in which every character that appears in oldchars
// has been replaced by newchar.
inline std::string replace_characters(std::string str,
                                      std::string const& oldchars,
                                      char newchar) {
  for (char& ch : str) {
    if (oldchars.find(ch) != std::string::npos) {
      ch = newchar;
    }
  }
  return str;
}
sanitize_filename(std::string name)306 inline std::string sanitize_filename(std::string name) {
307   return replace_characters(name, "/\\.-: ?%*|\"<>", '_');
308 }
309 
310 #if JITIFY_ENABLE_EMBEDDED_FILES
311 class EmbeddedData {
312   void* _app;
313   EmbeddedData(EmbeddedData const&);
314   EmbeddedData& operator=(EmbeddedData const&);
315 
316  public:
EmbeddedData()317   EmbeddedData() {
318     _app = dlopen(NULL, RTLD_LAZY);
319     if (!_app) {
320       throw std::runtime_error(std::string("dlopen failed: ") + dlerror());
321     }
322     dlerror();  // Clear any existing error
323   }
~EmbeddedData()324   ~EmbeddedData() {
325     if (_app) {
326       dlclose(_app);
327     }
328   }
operator [](std::string key) const329   const uint8_t* operator[](std::string key) const {
330     key = sanitize_filename(key);
331     key = "_binary_" + key;
332     uint8_t const* data = (uint8_t const*)dlsym(_app, key.c_str());
333     if (!data) {
334       throw std::runtime_error(std::string("dlsym failed: ") + dlerror());
335     }
336     return data;
337   }
begin(std::string key) const338   const uint8_t* begin(std::string key) const {
339     return (*this)[key + "_start"];
340   }
end(std::string key) const341   const uint8_t* end(std::string key) const { return (*this)[key + "_end"]; }
342 };
343 #endif  // JITIFY_ENABLE_EMBEDDED_FILES
344 
// Returns true if c can be part of an identifier-like token: an ASCII
// letter, an ASCII digit or an underscore.
inline bool is_tokenchar(char c) {
  return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
         (c >= '0' && c <= '9') || c == '_';
}
// Returns src with every whole-token occurrence of `token` replaced by
// `replacement`. An occurrence counts as a whole token only when the
// character before it (if any) and the character after it (if any) are both
// non-token characters, so "foo" is not replaced inside "foobar" or "xfoo".
// An empty token matches nothing and src is returned unchanged.
inline std::string replace_token(std::string src, std::string token,
                                 std::string replacement) {
  if (token.empty()) {
    return src;
  }
  size_t i = src.find(token);
  while (i != std::string::npos) {
    const size_t match_end = i + token.size();
    // Each boundary is checked independently; the start and the end of the
    // string both count as boundaries.
    const bool left_ok = (i == 0) || !is_tokenchar(src[i - 1]);
    const bool right_ok =
        (match_end == src.size()) || !is_tokenchar(src[match_end]);
    if (left_ok && right_ok) {
      src.replace(i, token.size(), replacement);
      // Resume after the inserted text so the replacement is not rescanned.
      i += replacement.size();
    } else {
      i += token.size();
    }
    i = src.find(token, i);
  }
  return src;
}
// Returns everything before the last path separator in p, or an empty
// string if p contains no separator.
//   "/usr/local/myfile.dat" -> "/usr/local"
//   "foo/bar"  -> "foo"
//   "foo/bar/" -> "foo/bar"
inline std::string path_base(std::string p) {
#if defined _WIN32 || defined _WIN64
  const char separator = '\\';
#else
  const char separator = '/';
#endif
  const size_t last_separator = p.rfind(separator);
  if (last_separator == std::string::npos) {
    return std::string();
  }
  return p.substr(0, last_separator);
}
// Joins p1 and p2, inserting a path separator when p1 is non-empty and does
// not already end with one.
// Throws std::invalid_argument if both are non-empty and p2 begins with a
// separator.
inline std::string path_join(std::string p1, std::string p2) {
#ifdef _WIN32
  const char separator = '\\';
#else
  const char separator = '/';
#endif
  const bool have_base = !p1.empty();
  if (have_base && !p2.empty() && p2.front() == separator) {
    throw std::invalid_argument("Cannot join to absolute path");
  }
  if (have_base && p1.back() != separator) {
    p1.push_back(separator);
  }
  p1 += p2;
  return p1;
}
// Elides "/." and "/.." tokens from path, and collapses repeated slashes.
//
// Only components that are followed by a '/' are simplified; a final "." or
// ".." component with no trailing slash is kept as written. A trailing slash
// in the input is preserved in the output. If every component cancels out of
// a relative path (e.g. "./" or "a/../") the result is "./", so a relative
// path is never turned into the root directory "/".
// Throws std::runtime_error if ".." would step above the root of an absolute
// path.
inline std::string path_simplify(const std::string& path) {
  std::vector<std::string> dirs;
  std::string cur_dir;
  bool after_slash = false;
  for (size_t i = 0; i < path.size(); ++i) {
    const char c = path[i];
    if (c != '/') {
      after_slash = false;
      cur_dir.push_back(c);
      continue;
    }
    if (after_slash) continue;  // Ignore repeat slashes
    after_slash = true;
    if (cur_dir == ".." && !dirs.empty() && dirs.back() != "..") {
      // A single empty entry is the leading "/" of an absolute path.
      if (dirs.size() == 1 && dirs.front().empty()) {
        throw std::runtime_error(
            "Invalid path: back-traversals exceed depth of absolute path");
      }
      dirs.pop_back();
    } else if (cur_dir != ".") {  // Ignore /./
      dirs.push_back(cur_dir);
    }
    cur_dir.clear();
  }
  if (!after_slash) {
    dirs.push_back(cur_dir);
  }
  if (dirs.empty()) {
    // Reached only for a relative path ending in '/' whose components all
    // cancelled out; emitting a bare "/" here would denote the root.
    return "./";
  }
  std::stringstream ss;
  for (size_t i = 0; i + 1 < dirs.size(); ++i) {
    ss << dirs[i] << "/";
  }
  ss << dirs.back();
  if (after_slash) ss << "/";
  return ss.str();
}
// Hashes the NUL-terminated string s: starting from seed, each character
// updates the hash as hash = hash * 101 + character.
inline unsigned long long hash_larson64(const char* s,
                                        unsigned long long seed = 0) {
  unsigned long long result = seed;
  for (const char* cursor = s; *cursor != '\0'; ++cursor) {
    result = result * 101 + *cursor;
  }
  return result;
}
437 
// Combines two 64-bit hash values into one.
inline uint64_t hash_combine(uint64_t a, uint64_t b) {
  // Note: The magic number comes from the golden ratio
  const uint64_t golden = 0x9E3779B97F4A7C17ull;
  const uint64_t mixed = golden + b + (b >> 2) + (a << 6);
  return a ^ mixed;
}
442 
// Parses a compile log for a "could not open source file" / "cannot open
// source file" error of the form
//   <parent>(<line>): ... cannot open source file "<name>"
// On success, stores the missing file's name, the text preceding '(' on the
// same log line (parent) and the parenthesized line number, then returns
// true. Returns false, leaving the outputs untouched, if no such error is
// found or the log line does not have the expected shape.
inline bool extract_include_info_from_compile_error(std::string log,
                                                    std::string& name,
                                                    std::string& parent,
                                                    int& line_num) {
  static const std::vector<std::string> pattern = {
      "could not open source file \"", "cannot open source file \""};

  for (auto& p : pattern) {
    size_t beg = log.find(p);
    if (beg == std::string::npos) {
      continue;
    }
    beg += p.size();
    // Closing quote of the file name.
    const size_t end = log.find("\"", beg);
    if (end == std::string::npos) {
      continue;
    }

    // Start of the log line that contains the error message.
    size_t line_beg = log.rfind("\n", beg);
    line_beg = (line_beg == std::string::npos) ? 0 : line_beg + 1;

    // "(<line>)" must appear on that line, before the error message;
    // otherwise the offsets below would be computed from npos.
    const size_t split = log.find("(", line_beg);
    if (split == std::string::npos || split >= beg) {
      continue;
    }
    const size_t close = log.find(")", split + 1);
    if (close == std::string::npos || close >= beg) {
      continue;
    }

    name = log.substr(beg, end - beg);
    parent = log.substr(line_beg, split - line_beg);
    line_num = atoi(log.substr(split + 1, close - (split + 1)).c_str());
    return true;
  }

  return false;
}
476 
// Returns true if the first "include" found at or after the start of line
// line_num (1-based) of source is followed by a '"' before any '<', i.e. the
// directive uses the #include "..." form.
// Returns false if source has fewer than line_num lines, or if no "include"
// or no opening '"' / '<' is found.
inline bool is_include_directive_with_quotes(const std::string& source,
                                             int line_num) {
  size_t beg = 0;
  for (int i = 1; i < line_num; ++i) {
    const size_t newline = source.find("\n", beg);
    if (newline == std::string::npos) {
      // npos + 1 would wrap around to the start of the source.
      return false;
    }
    beg = newline + 1;
  }
  beg = source.find("include", beg);
  if (beg == std::string::npos) {
    return false;
  }
  beg = source.find_first_of("\"<", beg + 7);  // 7 == strlen("include")
  if (beg == std::string::npos) {
    return false;
  }
  return source[beg] == '"';
}
488 
// Returns source with "//" inserted at the start of line line_num (1-based).
// If source has fewer than line_num lines it is returned unchanged.
inline std::string comment_out_code_line(int line_num, std::string source) {
  size_t beg = 0;
  for (int i = 1; i < line_num; ++i) {
    const size_t newline = source.find("\n", beg);
    if (newline == std::string::npos) {
      // npos + 1 would wrap around and comment out an unrelated line.
      return source;
    }
    beg = newline + 1;
  }
  source.insert(beg, "//");
  return source;
}
496 
// Writes source to std::cout one line at a time, prefixing each line with
// its 1-based line number right-aligned in a field of width 3.
inline void print_with_line_numbers(std::string const& source) {
  std::istringstream remaining(source);
  std::string text;
  int number = 0;
  while (std::getline(remaining, text)) {
    ++number;
    std::cout << std::setfill(' ') << std::setw(3) << number << " " << text
              << std::endl;
  }
}
505 
// Writes the compile log for program_name to std::cout, framed by rule
// lines and a title line.
inline void print_compile_log(std::string program_name,
                              std::string const& log) {
  const char* const rule =
      "---------------------------------------------------";
  std::cout << rule << std::endl;
  std::cout << "--- JIT compile log for " << program_name << " ---"
            << std::endl;
  std::cout << rule << std::endl;
  std::cout << log << std::endl;
  std::cout << rule << std::endl;
}
518 
// Splits str into tokens separated by any of the characters in delims.
//
// \param str      The string to split.
// \param maxsplit Maximum number of tokens to split off. If 0, str is
//                 returned whole as the only element. With the default of -1
//                 the loop below only ends when no token is left.
// \param delims   The set of delimiter characters (default: space and tab).
// \return The tokens in order. When maxsplit tokens were split off and text
//         remains, the remainder (with leading delimiters skipped) is
//         appended as one final element.
inline std::vector<std::string> split_string(std::string str,
                                             long maxsplit = -1,
                                             std::string delims = " \t") {
  std::vector<std::string> results;
  if (maxsplit == 0) {
    results.push_back(str);
    return results;
  }
  // strtok_r writes into its input, so tokenize a private mutable copy.
  // Note: +1 to include NULL-terminator
  std::vector<char> v_str(str.c_str(), str.c_str() + (str.size() + 1));
  char* c_str = v_str.data();
  char* saveptr = c_str;
  char* token = nullptr;
  for (long i = 0; i != maxsplit; ++i) {
    token = ::strtok_r(c_str, delims.c_str(), &saveptr);
    // Every call after the first passes a null string and continues from
    // saveptr.
    c_str = 0;
    if (!token) {
      // Ran out of tokens before reaching maxsplit.
      return results;
    }
    results.push_back(token);
  }
  // Check if there's a final piece
  // Step past the last token and the terminator that follows it.
  token += ::strlen(token) + 1;
  if (token - v_str.data() < (ptrdiff_t)str.size()) {
    // Find the start of the final piece
    token += ::strspn(token, delims.c_str());
    if (*token) {
      results.push_back(token);
    }
  }
  return results;
}
551 
552 static const std::map<std::string, std::string>& get_jitsafe_headers_map();
553 
load_source(std::string filename,std::map<std::string,std::string> & sources,std::string current_dir="",std::vector<std::string> include_paths=std::vector<std::string> (),file_callback_type file_callback=0,std::map<std::string,std::string> * fullpaths=nullptr,bool search_current_dir=true)554 inline bool load_source(
555     std::string filename, std::map<std::string, std::string>& sources,
556     std::string current_dir = "",
557     std::vector<std::string> include_paths = std::vector<std::string>(),
558     file_callback_type file_callback = 0,
559     std::map<std::string, std::string>* fullpaths = nullptr,
560     bool search_current_dir = true) {
561   std::istream* source_stream = 0;
562   std::stringstream string_stream;
563   std::ifstream file_stream;
564   // First detect direct source-code string ("my_program\nprogram_code...")
565   size_t newline_pos = filename.find("\n");
566   if (newline_pos != std::string::npos) {
567     std::string source = filename.substr(newline_pos + 1);
568     filename = filename.substr(0, newline_pos);
569     string_stream << source;
570     source_stream = &string_stream;
571   }
572   if (sources.count(filename)) {
573     // Already got this one
574     return true;
575   }
576   if (!source_stream) {
577     std::string fullpath = path_join(current_dir, filename);
578     // Try loading from callback
579     if (!file_callback ||
580         !(source_stream = file_callback(fullpath, string_stream))) {
581 #if JITIFY_ENABLE_EMBEDDED_FILES
582       // Try loading as embedded file
583       EmbeddedData embedded;
584       std::string source;
585       try {
586         source.assign(embedded.begin(fullpath), embedded.end(fullpath));
587         string_stream << source;
588         source_stream = &string_stream;
589       } catch (std::runtime_error const&)
590 #endif  // JITIFY_ENABLE_EMBEDDED_FILES
591       {
592         // Try loading from filesystem
593         bool found_file = false;
594         if (search_current_dir) {
595           file_stream.open(fullpath.c_str());
596           if (file_stream) {
597             source_stream = &file_stream;
598             found_file = true;
599           }
600         }
601         // Search include directories
602         if (!found_file) {
603           for (int i = 0; i < (int)include_paths.size(); ++i) {
604             fullpath = path_join(include_paths[i], filename);
605             file_stream.open(fullpath.c_str());
606             if (file_stream) {
607               source_stream = &file_stream;
608               found_file = true;
609               break;
610             }
611           }
612           if (!found_file) {
613             // Try loading from builtin headers
614             fullpath = path_join("__jitify_builtin", filename);
615             auto it = get_jitsafe_headers_map().find(filename);
616             if (it != get_jitsafe_headers_map().end()) {
617               string_stream << it->second;
618               source_stream = &string_stream;
619             } else {
620               return false;
621             }
622           }
623         }
624       }
625     }
626     if (fullpaths) {
627       // Record the full file path corresponding to this include name.
628       (*fullpaths)[filename] = path_simplify(fullpath);
629     }
630   }
631   sources[filename] = std::string();
632   std::string& source = sources[filename];
633   std::string line;
634   size_t linenum = 0;
635   unsigned long long hash = 0;
636   bool pragma_once = false;
637   bool remove_next_blank_line = false;
638   while (std::getline(*source_stream, line)) {
639     ++linenum;
640 
641     // HACK WAR for static variables not allowed on the device (unless
642     // __shared__)
643     // TODO: This breaks static member variables
644     // line = replace_token(line, "static const", "/*static*/ const");
645 
646     // TODO: Need to watch out for /* */ comments too
647     std::string cleanline =
648         line.substr(0, line.find("//"));  // Strip line comments
649     // if( cleanline.back() == "\r" ) { // Remove Windows line ending
650     //	cleanline = cleanline.substr(0, cleanline.size()-1);
651     //}
652     // TODO: Should trim whitespace before checking .empty()
653     if (cleanline.empty() && remove_next_blank_line) {
654       remove_next_blank_line = false;
655       continue;
656     }
657     // Maintain a file hash for use in #pragma once WAR
658     hash = hash_larson64(line.c_str(), hash);
659     if (cleanline.find("#pragma once") != std::string::npos) {
660       pragma_once = true;
661       // Note: This is an attempt to recover the original line numbering,
662       //         which otherwise gets off-by-one due to the include guard.
663       remove_next_blank_line = true;
664       // line = "//" + line; // Comment out the #pragma once line
665       continue;
666     }
667 
668     // HACK WAR for Thrust using "#define FOO #pragma bar"
669     size_t pragma_beg = cleanline.find("#pragma ");
670     if (pragma_beg != std::string::npos) {
671       std::string line_after_pragma = line.substr(pragma_beg);
672       std::vector<std::string> pragma_split =
673           split_string(line_after_pragma, 2);
674       line =
675           (line.substr(0, pragma_beg) + "_Pragma(\"" + pragma_split[1] + "\")");
676       if (pragma_split.size() == 3) {
677         line += " " + pragma_split[2];
678       }
679     }
680 
681     source += line + "\n";
682   }
683   // HACK TESTING (WAR for cub)
684   // source = "#define cudaDeviceSynchronize() cudaSuccess\n" + source;
685   ////source = "cudaError_t cudaDeviceSynchronize() { return cudaSuccess; }\n" +
686   /// source;
687 
688   // WAR for #pragma once causing problems when there are multiple inclusions
689   //   of the same header from different paths.
690   if (pragma_once) {
691     std::stringstream ss;
692     ss << std::uppercase << std::hex << std::setw(8) << std::setfill('0')
693        << hash;
694     std::string include_guard_name = "_JITIFY_INCLUDE_GUARD_" + ss.str() + "\n";
695     std::string include_guard_header;
696     include_guard_header += "#ifndef " + include_guard_name;
697     include_guard_header += "#define " + include_guard_name;
698     std::string include_guard_footer;
699     include_guard_footer += "#endif // " + include_guard_name;
700     source = include_guard_header + source + "\n" + include_guard_footer;
701   }
702   // return filename;
703   return true;
704 }
705 
706 }  // namespace detail
707 
708 //! \endcond
709 
710 /*! Jitify reflection utilities namespace
711  */
712 namespace reflection {
713 
714 //  Provides type and value reflection via a function 'reflect':
715 //    reflect<Type>()   -> "Type"
716 //    reflect(value)    -> "(T)value"
717 //    reflect<VAL>()    -> "VAL"
718 //    reflect<Type,VAL> -> "VAL"
719 //    reflect_template<float,NonType<int,7>,char>() -> "<float,7,char>"
720 //    reflect_template({"float", "7", "char"}) -> "<float,7,char>"
721 
/*! A wrapper class for non-type template parameters.
 *  The wrapped value is exposed as the static constant \p VALUE.
 */
template <typename T, T VALUE_>
struct NonType {
  static constexpr T VALUE = VALUE_;
};
728 
729 // Forward declaration
730 template <typename T>
731 inline std::string reflect(T const& value);
732 
733 //! \cond
734 
735 namespace detail {
736 
// Returns the text produced by streaming x into a string stream.
template <typename T>
inline std::string value_string(const T& x) {
  std::ostringstream text;
  text << x;
  return text.str();
}
// WAR for non-printable characters: character types are printed as their
// integer value rather than as a character.
template <>
inline std::string value_string<char>(const char& x) {
  return value_string<int>(static_cast<int>(x));
}
template <>
inline std::string value_string<signed char>(const signed char& x) {
  return value_string<int>(static_cast<int>(x));
}
template <>
inline std::string value_string<unsigned char>(const unsigned char& x) {
  return value_string<int>(static_cast<int>(x));
}
template <>
inline std::string value_string<wchar_t>(const wchar_t& x) {
  return value_string<long>(static_cast<long>(x));
}
// Specialisation for bool true/false literals
template <>
inline std::string value_string<bool>(const bool& x) {
  if (x) {
    return "true";
  }
  return "false";
}
773 
774 // Removes all tokens that start with double underscores.
strip_double_underscore_tokens(char * s)775 inline void strip_double_underscore_tokens(char* s) {
776   using jitify::detail::is_tokenchar;
777   char* w = s;
778   do {
779     if (*s == '_' && *(s + 1) == '_') {
780       while (is_tokenchar(*++s))
781         ;
782     }
783   } while ((*w++ = *s++));
784 }
785 
786 //#if CUDA_VERSION < 8000
787 #ifdef _MSC_VER  // MSVC compiler
// MSVC build: returns mangled_name unchanged.
inline std::string demangle_cuda_symbol(const char* mangled_name) {
  // We don't have a way to demangle CUDA symbol names under MSVC.
  return std::string(mangled_name);
}
demangle_native_type(const std::type_info & typeinfo)792 inline std::string demangle_native_type(const std::type_info& typeinfo) {
793   // Get the decorated name and skip over the leading '.'.
794   const char* decorated_name = typeinfo.raw_name() + 1;
795   char undecorated_name[4096];
796   if (UnDecorateSymbolName(
797           decorated_name, undecorated_name,
798           sizeof(undecorated_name) / sizeof(*undecorated_name),
799           UNDNAME_NO_ARGUMENTS |          // Treat input as a type name
800               UNDNAME_NAME_ONLY           // No "class" and "struct" prefixes
801           /*UNDNAME_NO_MS_KEYWORDS*/)) {  // No "__cdecl", "__ptr64" etc.
802     // WAR for UNDNAME_NO_MS_KEYWORDS messing up function types.
803     strip_double_underscore_tokens(undecorated_name);
804     return undecorated_name;
805   }
806   throw std::runtime_error("UnDecorateSymbolName failed");
807 }
808 #else   // not MSVC
// Demangles mangled_name with abi::__cxa_demangle.
// Returns the demangled name on success (status 0), or mangled_name itself
// when the demangler reports status -2.
// Throws std::runtime_error for status -1 and for status -3.
inline std::string demangle_cuda_symbol(const char* mangled_name) {
  int status;
  size_t length = 0;
  // __cxa_demangle returns a malloc'ed buffer; release it with free.
  std::unique_ptr<char, decltype(free)*> demangled(
      abi::__cxa_demangle(mangled_name, nullptr, &length, &status), free);
  switch (status) {
    case 0:
      return std::string(demangled.get());  // all worked as expected
    case -2:
      return std::string(mangled_name);  // we interpret this as plain C name
    case -1:
      throw std::runtime_error(
          std::string("memory allocation failure in __cxa_demangle"));
    case -3:
      throw std::runtime_error(
          std::string("invalid argument to __cxa_demangle"));
    default:
      return std::string();
  }
}
demangle_native_type(const std::type_info & typeinfo)828 inline std::string demangle_native_type(const std::type_info& typeinfo) {
829   return demangle_cuda_symbol(typeinfo.name());
830 }
831 #endif  // not MSVC
832 //#endif // CUDA_VERSION < 8000
833 
// Empty tag template. type_reflection wraps a type in it before calling
// typeid so that cv-qualifiers on the wrapped type show up in the name.
template <typename T>
class JitifyTypeNameWrapper_ {};
836 
837 template <typename T>
838 struct type_reflection {
namejitify::reflection::detail::type_reflection839   inline static std::string name() {
840     //#if CUDA_VERSION < 8000
841     // TODO: Use nvrtcGetTypeName once it has the same behavior as this.
842     // WAR for typeid discarding cv qualifiers on value-types
843     // Wrap type in dummy template class to preserve cv-qualifiers, then strip
844     // off the wrapper from the resulting string.
845     std::string wrapped_name =
846         demangle_native_type(typeid(JitifyTypeNameWrapper_<T>));
847     // Note: The reflected name of this class also has namespace prefixes.
848     const std::string wrapper_class_name = "JitifyTypeNameWrapper_<";
849     size_t start = wrapped_name.find(wrapper_class_name);
850     if (start == std::string::npos) {
851       throw std::runtime_error("Type reflection failed: " + wrapped_name);
852     }
853     start += wrapper_class_name.size();
854     std::string name =
855         wrapped_name.substr(start, wrapped_name.size() - (start + 1));
856     return name;
857     //#else
858     //         std::string ret;
859     //         nvrtcResult status = nvrtcGetTypeName<T>(&ret);
860     //         if( status != NVRTC_SUCCESS ) {
861     //                 throw std::runtime_error(std::string("nvrtcGetTypeName
862     // failed:
863     //")+ nvrtcGetErrorString(status));
864     //         }
865     //         return ret;
866     //#endif
867   }
868 };  // namespace detail
869 template <typename T, T VALUE>
870 struct type_reflection<NonType<T, VALUE> > {
namejitify::reflection::detail::type_reflection871   inline static std::string name() {
872     return jitify::reflection::reflect(VALUE);
873   }
874 };
875 
876 }  // namespace detail
877 
878 //! \endcond
879 
/*! Holds a const reference to a value so that the value's run-time
 *  (dynamic) type can be extracted later, e.g., to template on a derived
 *  type when only the abstract base type is known at compile time.
 *  Only a reference is stored; the referenced object is not copied.
 */
template <typename T>
struct Instance {
  Instance(const T& value) : value(value) {}
  const T& value;
};
891 
892 /*! Create an Instance object from which we can extract the value's run-time
893  * type.
894  *  \param value The const value to be captured.
895  */
896 template <typename T>
instance_of(T const & value)897 inline Instance<T const> instance_of(T const& value) {
898   return Instance<T const>(value);
899 }
900 
/*! A wrapper used for representing types as values.
 *  The struct is empty; the wrapped type is carried solely by the template
 *  argument T.
 */
template <typename T>
struct Type {};
905 
906 // Type reflection
907 // E.g., reflect<float>() -> "float"
908 // Note: This strips trailing const and volatile qualifiers
909 /*! Generate a code-string for a type.
910  *  \code{.cpp}reflect<float>() --> "float"\endcode
911  */
912 template <typename T>
reflect()913 inline std::string reflect() {
914   return detail::type_reflection<T>::name();
915 }
916 // Value reflection
917 // E.g., reflect(3.14f) -> "(float)3.14"
918 /*! Generate a code-string for a value.
919  *  \code{.cpp}reflect(3.14f) --> "(float)3.14"\endcode
920  */
921 template <typename T>
reflect(T const & value)922 inline std::string reflect(T const& value) {
923   return "(" + reflect<T>() + ")" + detail::value_string(value);
924 }
925 // Non-type template arg reflection (implicit conversion to int64_t)
926 // E.g., reflect<7>() -> "(int64_t)7"
927 /*! Generate a code-string for an integer non-type template argument.
928  *  \code{.cpp}reflect<7>() --> "(int64_t)7"\endcode
929  */
930 template <int64_t N>
reflect()931 inline std::string reflect() {
932   return reflect<NonType<int64_t, N> >();
933 }
934 // Non-type template arg reflection (explicit type)
935 // E.g., reflect<int,7>() -> "(int)7"
936 /*! Generate a code-string for a generic non-type template argument.
937  *  \code{.cpp} reflect<int,7>() --> "(int)7" \endcode
938  */
939 template <typename T, T N>
reflect()940 inline std::string reflect() {
941   return reflect<NonType<T, N> >();
942 }
943 // Type reflection via value
944 // E.g., reflect(Type<float>()) -> "float"
945 /*! Generate a code-string for a type wrapped as a Type instance.
946  *  \code{.cpp}reflect(Type<float>()) --> "float"\endcode
947  */
948 template <typename T>
reflect(jitify::reflection::Type<T>)949 inline std::string reflect(jitify::reflection::Type<T>) {
950   return reflect<T>();
951 }
952 
953 /*! Generate a code-string for a type wrapped as an Instance instance.
954  *  \code{.cpp}reflect(Instance<float>(3.1f)) --> "float"\endcode
955  *  or more simply when passed to a instance_of helper
956  *  \code{.cpp}reflect(instance_of(3.1f)) --> "float"\endcodei
957  *  This is specifically for the case where we want to extract the run-time
958  * type, e.g., derived type, of an object pointer.
959  */
960 template <typename T>
reflect(jitify::reflection::Instance<T> & value)961 inline std::string reflect(jitify::reflection::Instance<T>& value) {
962   return detail::demangle_native_type(typeid(value.value));
963 }
964 
965 // Type from value
966 // E.g., type_of(3.14f) -> Type<float>()
967 /*! Create a Type object representing a value's type.
968  *  \param value The value whose type is to be captured.
969  */
970 template <typename T>
type_of(T & value)971 inline Type<T> type_of(T& value) {
972   return Type<T>();
973 }
974 /*! Create a Type object representing a value's type.
975  *  \param value The const value whose type is to be captured.
976  */
977 template <typename T>
type_of(T const & value)978 inline Type<T const> type_of(T const& value) {
979   return Type<T const>();
980 }
981 
// Reflects several values in one call.
// Returns one string per argument, in argument order, each produced by the
// matching reflect() overload.
template <typename... Args>
inline std::vector<std::string> reflect_all(Args... args) {
  std::vector<std::string> reflected{reflect(args)...};
  return reflected;
}
987 
reflect_list(jitify::detail::vector<std::string> const & args,std::string opener="",std::string closer="")988 inline std::string reflect_list(jitify::detail::vector<std::string> const& args,
989                                 std::string opener = "",
990                                 std::string closer = "") {
991   std::stringstream ss;
992   ss << opener;
993   for (int i = 0; i < (int)args.size(); ++i) {
994     if (i > 0) ss << ",";
995     ss << args[i];
996   }
997   ss << closer;
998   return ss.str();
999 }
1000 
1001 // Template instantiation reflection
1002 // inline std::string reflect_template(std::vector<std::string> const& args) {
reflect_template(jitify::detail::vector<std::string> const & args)1003 inline std::string reflect_template(
1004     jitify::detail::vector<std::string> const& args) {
1005   // Note: The space in " >" is a WAR to avoid '>>' appearing
1006   return reflect_list(args, "<", " >");
1007 }
1008 // TODO: See if can make this evaluate completely at compile-time
1009 template <typename... Ts>
reflect_template()1010 inline std::string reflect_template() {
1011   return reflect_template({reflect<Ts>()...});
1012   // return reflect_template<sizeof...(Ts)>({reflect<Ts>()...});
1013 }
1014 
1015 }  // namespace reflection
1016 
1017 //! \cond
1018 
1019 namespace detail {
1020 
// Demangles nested variable names using the PTX name mangling scheme
// (which follows the Itanium64 ABI). E.g., _ZN1a3Foo2bcE -> a::Foo::bc.
//   name: NUL-terminated symbol name as it appears in the PTX.
// Returns:
//   - `name` unchanged if it does not start with "_Z" (not mangled);
//   - the "::"-joined identifiers for a nested name ("_ZN...E"), with
//     identifiers starting with "_GLOBAL" shown as "(anonymous namespace)";
//   - an empty string for null input and for mangled names that are not
//     nested names, are malformed, or are truncated.
inline std::string demangle_ptx_variable_name(const char* name) {
  if (!name) return "";
  const char* c = name;
  if (*c++ != '_' || *c++ != 'Z') return name;  // Non-mangled name
  if (*c++ != 'N') return "";  // Not a nested name, unsupported
  const size_t max_len = static_cast<size_t>(-1);
  std::string demangled;
  while (true) {
    // Parse identifier length. The char is converted to unsigned char first
    // because passing a negative value to std::isdigit is undefined.
    size_t n = 0;
    while (std::isdigit(static_cast<unsigned char>(*c))) {
      const size_t digit = static_cast<size_t>(*c - '0');
      // An absurdly long digit run cannot describe a real identifier; bail
      // out instead of overflowing the accumulator.
      if (n > (max_len - digit) / 10) return "";
      n = n * 10 + digit;
      ++c;
    }
    if (!n) return "";  // Invalid or unsupported mangled name
    // Parse identifier, never stepping past the terminating NUL.
    const char* c0 = c;
    while (n && *c) {
      --n;
      ++c;
    }
    if (!*c) return "";  // Mangled name is truncated
    const std::string id(c0, c);
    // Identifiers starting with "_GLOBAL" are anonymous namespaces.
    demangled += (id.substr(0, 7) == "_GLOBAL" ? "(anonymous namespace)" : id);
    // Nested name specifiers end with 'E'.
    if (*c == 'E') break;
    // There are more identifiers to come, add join token.
    demangled += "::";
  }
  return demangled;
}
1050 
get_current_executable_path()1051 static const char* get_current_executable_path() {
1052   static const char* path = []() -> const char* {
1053     static char buffer[JITIFY_PATH_MAX] = {};
1054 #ifdef __linux__
1055     if (!::realpath("/proc/self/exe", buffer)) return nullptr;
1056 #elif defined(_WIN32) || defined(_WIN64)
1057     if (!GetModuleFileNameA(nullptr, buffer, JITIFY_PATH_MAX)) return nullptr;
1058 #endif
1059     return buffer;
1060   }();
1061   return path;
1062 }
1063 
// Returns true if `str` ends with `suffix` (an empty suffix always matches).
inline bool endswith(const std::string& str, const std::string& suffix) {
  if (suffix.size() > str.size()) return false;
  const size_t tail_start = str.size() - suffix.size();
  return str.compare(tail_start, suffix.size(), suffix) == 0;
}
1068 
1069 // Infers the JIT input type from the filename suffix. If no known suffix is
1070 // present, the filename is assumed to refer to a library, and the associated
1071 // suffix (and possibly prefix) is automatically added to the filename.
get_cuda_jit_input_type(std::string * filename)1072 inline CUjitInputType get_cuda_jit_input_type(std::string* filename) {
1073   if (endswith(*filename, ".ptx")) {
1074     return CU_JIT_INPUT_PTX;
1075   } else if (endswith(*filename, ".cubin")) {
1076     return CU_JIT_INPUT_CUBIN;
1077   } else if (endswith(*filename, ".fatbin")) {
1078     return CU_JIT_INPUT_FATBINARY;
1079   } else if (endswith(*filename,
1080 #if defined _WIN32 || defined _WIN64
1081                       ".obj"
1082 #else  // Linux
1083                       ".o"
1084 #endif
1085                       )) {
1086     return CU_JIT_INPUT_OBJECT;
1087   } else {  // Assume library
1088 #if defined _WIN32 || defined _WIN64
1089     if (!endswith(*filename, ".lib")) {
1090       *filename += ".lib";
1091     }
1092 #else  // Linux
1093     if (!endswith(*filename, ".a")) {
1094       *filename = "lib" + *filename + ".a";
1095     }
1096 #endif
1097     return CU_JIT_INPUT_LIBRARY;
1098   }
1099 }
1100 
1101 class CUDAKernel {
1102   std::vector<std::string> _link_files;
1103   std::vector<std::string> _link_paths;
1104   CUlinkState _link_state;
1105   CUmodule _module;
1106   CUfunction _kernel;
1107   std::string _func_name;
1108   std::string _ptx;
1109   std::map<std::string, std::string> _global_map;
1110   std::vector<CUjit_option> _opts;
1111   std::vector<void*> _optvals;
1112 #ifdef JITIFY_PRINT_LINKER_LOG
1113   static const unsigned int _log_size = 8192;
1114   char _error_log[_log_size];
1115   char _info_log[_log_size];
1116 #endif
1117 
cuda_safe_call(CUresult res) const1118   inline void cuda_safe_call(CUresult res) const {
1119     if (res != CUDA_SUCCESS) {
1120       const char* msg;
1121       cuGetErrorName(res, &msg);
1122       throw std::runtime_error(msg);
1123     }
1124   }
create_module(std::vector<std::string> link_files,std::vector<std::string> link_paths)1125   inline void create_module(std::vector<std::string> link_files,
1126                             std::vector<std::string> link_paths) {
1127     CUresult result;
1128 #ifndef JITIFY_PRINT_LINKER_LOG
1129     // WAR since linker log does not seem to be constructed using a single call
1130     // to cuModuleLoadDataEx.
1131     if (link_files.empty()) {
1132       result =
1133           cuModuleLoadDataEx(&_module, _ptx.c_str(), (unsigned)_opts.size(),
1134                              _opts.data(), _optvals.data());
1135     } else
1136 #endif
1137     {
1138       cuda_safe_call(cuLinkCreate((unsigned)_opts.size(), _opts.data(),
1139                                   _optvals.data(), &_link_state));
1140       cuda_safe_call(cuLinkAddData(_link_state, CU_JIT_INPUT_PTX,
1141                                    (void*)_ptx.c_str(), _ptx.size(),
1142                                    "jitified_source.ptx", 0, 0, 0));
1143       for (int i = 0; i < (int)link_files.size(); ++i) {
1144         std::string link_file = link_files[i];
1145         CUjitInputType jit_input_type;
1146         if (link_file == ".") {
1147           // Special case for linking to current executable.
1148           link_file = get_current_executable_path();
1149           jit_input_type = CU_JIT_INPUT_OBJECT;
1150         } else {
1151           // Infer based on filename.
1152           jit_input_type = get_cuda_jit_input_type(&link_file);
1153         }
1154         CUresult result = cuLinkAddFile(_link_state, jit_input_type,
1155                                         link_file.c_str(), 0, 0, 0);
1156         int path_num = 0;
1157         while (result == CUDA_ERROR_FILE_NOT_FOUND &&
1158                path_num < (int)link_paths.size()) {
1159           std::string filename = path_join(link_paths[path_num++], link_file);
1160           result = cuLinkAddFile(_link_state, jit_input_type, filename.c_str(),
1161                                  0, 0, 0);
1162         }
1163 #if JITIFY_PRINT_LINKER_LOG
1164         if (result == CUDA_ERROR_FILE_NOT_FOUND) {
1165           std::cerr << "Linker error: Device library not found: " << link_file
1166                     << std::endl;
1167         } else if (result != CUDA_SUCCESS) {
1168           std::cerr << "Linker error: Failed to add file: " << link_file
1169                     << std::endl;
1170           std::cerr << _error_log << std::endl;
1171         }
1172 #endif
1173         cuda_safe_call(result);
1174       }
1175       size_t cubin_size;
1176       void* cubin;
1177       result = cuLinkComplete(_link_state, &cubin, &cubin_size);
1178       if (result == CUDA_SUCCESS) {
1179         result = cuModuleLoadData(&_module, cubin);
1180       }
1181     }
1182 #ifdef JITIFY_PRINT_LINKER_LOG
1183     std::cout << "---------------------------------------" << std::endl;
1184     std::cout << "--- Linker for "
1185               << reflection::detail::demangle_cuda_symbol(_func_name.c_str())
1186               << " ---" << std::endl;
1187     std::cout << "---------------------------------------" << std::endl;
1188     std::cout << _info_log << std::endl;
1189     std::cout << std::endl;
1190     std::cout << _error_log << std::endl;
1191     std::cout << "---------------------------------------" << std::endl;
1192 #endif
1193     cuda_safe_call(result);
1194     // Allow _func_name to be empty to support cases where we want to generate
1195     // PTX containing extern symbol definitions but no kernels.
1196     if (!_func_name.empty()) {
1197       cuda_safe_call(
1198           cuModuleGetFunction(&_kernel, _module, _func_name.c_str()));
1199     }
1200   }
destroy_module()1201   inline void destroy_module() {
1202     if (_link_state) {
1203       cuda_safe_call(cuLinkDestroy(_link_state));
1204     }
1205     _link_state = 0;
1206     if (_module) {
1207       cuModuleUnload(_module);
1208     }
1209     _module = 0;
1210   }
1211 
1212   // create a map of __constant__ and __device__ variables in the ptx file
1213   // mapping demangled to mangled name
create_global_variable_map()1214   inline void create_global_variable_map() {
1215     size_t pos = 0;
1216     while (pos < _ptx.size()) {
1217       pos = std::min(_ptx.find(".const .align", pos),
1218                      _ptx.find(".global .align", pos));
1219       if (pos == std::string::npos) break;
1220       size_t end = _ptx.find_first_of(";=", pos);
1221       if (_ptx[end] == '=') --end;
1222       std::string line = _ptx.substr(pos, end - pos);
1223       pos = end;
1224       size_t symbol_start = line.find_last_of(" ") + 1;
1225       size_t symbol_end = line.find_last_of("[");
1226       std::string entry = line.substr(symbol_start, symbol_end - symbol_start);
1227       std::string key = detail::demangle_ptx_variable_name(entry.c_str());
1228       // Skip unsupported mangled names. E.g., a static variable defined inside
1229       // a function (such variables are not directly addressable from outside
1230       // the function, so skipping them is the correct behavior).
1231       if (key == "") continue;
1232       _global_map[key] = entry;
1233     }
1234   }
1235 
set_linker_log()1236   inline void set_linker_log() {
1237 #ifdef JITIFY_PRINT_LINKER_LOG
1238     _opts.push_back(CU_JIT_INFO_LOG_BUFFER);
1239     _optvals.push_back((void*)_info_log);
1240     _opts.push_back(CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES);
1241     _optvals.push_back((void*)(long)_log_size);
1242     _opts.push_back(CU_JIT_ERROR_LOG_BUFFER);
1243     _optvals.push_back((void*)_error_log);
1244     _opts.push_back(CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES);
1245     _optvals.push_back((void*)(long)_log_size);
1246     _opts.push_back(CU_JIT_LOG_VERBOSE);
1247     _optvals.push_back((void*)1);
1248 #endif
1249   }
1250 
1251  public:
CUDAKernel()1252   inline CUDAKernel() : _link_state(0), _module(0), _kernel(0) {}
1253   inline CUDAKernel(const CUDAKernel& other) = delete;
1254   inline CUDAKernel& operator=(const CUDAKernel& other) = delete;
1255   inline CUDAKernel(CUDAKernel&& other) = delete;
1256   inline CUDAKernel& operator=(CUDAKernel&& other) = delete;
CUDAKernel(const char * func_name,const char * ptx,std::vector<std::string> link_files,std::vector<std::string> link_paths,unsigned int nopts=0,CUjit_option * opts=0,void ** optvals=0)1257   inline CUDAKernel(const char* func_name, const char* ptx,
1258                     std::vector<std::string> link_files,
1259                     std::vector<std::string> link_paths, unsigned int nopts = 0,
1260                     CUjit_option* opts = 0, void** optvals = 0)
1261       : _link_files(link_files),
1262         _link_paths(link_paths),
1263         _link_state(0),
1264         _module(0),
1265         _kernel(0),
1266         _func_name(func_name),
1267         _ptx(ptx),
1268         _opts(opts, opts + nopts),
1269         _optvals(optvals, optvals + nopts) {
1270     this->set_linker_log();
1271     this->create_module(link_files, link_paths);
1272     this->create_global_variable_map();
1273   }
1274 
set(const char * func_name,const char * ptx,std::vector<std::string> link_files,std::vector<std::string> link_paths,unsigned int nopts=0,CUjit_option * opts=0,void ** optvals=0)1275   inline CUDAKernel& set(const char* func_name, const char* ptx,
1276                          std::vector<std::string> link_files,
1277                          std::vector<std::string> link_paths,
1278                          unsigned int nopts = 0, CUjit_option* opts = 0,
1279                          void** optvals = 0) {
1280     this->destroy_module();
1281     _func_name = func_name;
1282     _ptx = ptx;
1283     _link_files = link_files;
1284     _link_paths = link_paths;
1285     _opts.assign(opts, opts + nopts);
1286     _optvals.assign(optvals, optvals + nopts);
1287     this->set_linker_log();
1288     this->create_module(link_files, link_paths);
1289     this->create_global_variable_map();
1290     return *this;
1291   }
~CUDAKernel()1292   inline ~CUDAKernel() { this->destroy_module(); }
operator CUfunction() const1293   inline operator CUfunction() const { return _kernel; }
1294 
launch(dim3 grid,dim3 block,unsigned int smem,CUstream stream,std::vector<void * > arg_ptrs) const1295   inline CUresult launch(dim3 grid, dim3 block, unsigned int smem,
1296                          CUstream stream, std::vector<void*> arg_ptrs) const {
1297     return cuLaunchKernel(_kernel, grid.x, grid.y, grid.z, block.x, block.y,
1298                           block.z, smem, stream, arg_ptrs.data(), NULL);
1299   }
1300 
get_global_ptr(const char * name,size_t * size=nullptr) const1301   inline CUdeviceptr get_global_ptr(const char* name,
1302                                     size_t* size = nullptr) const {
1303     CUdeviceptr global_ptr = 0;
1304     auto global = _global_map.find(name);
1305     if (global != _global_map.end()) {
1306       cuda_safe_call(cuModuleGetGlobal(&global_ptr, size, _module,
1307                                        global->second.c_str()));
1308     } else {
1309       throw std::runtime_error(std::string("failed to look up global ") + name);
1310     }
1311     return global_ptr;
1312   }
1313 
1314   template <typename T>
get_global_data(const char * name,T * data,size_t count,CUstream stream=0) const1315   inline CUresult get_global_data(const char* name, T* data, size_t count,
1316                                   CUstream stream = 0) const {
1317     size_t size_bytes;
1318     CUdeviceptr ptr = get_global_ptr(name, &size_bytes);
1319     size_t given_size_bytes = count * sizeof(T);
1320     if (given_size_bytes != size_bytes) {
1321       throw std::runtime_error(
1322           std::string("Value for global variable ") + name +
1323           " has wrong size: got " + std::to_string(given_size_bytes) +
1324           " bytes, expected " + std::to_string(size_bytes));
1325     }
1326     return cuMemcpyDtoH(data, ptr, size_bytes);
1327   }
1328 
1329   template <typename T>
set_global_data(const char * name,const T * data,size_t count,CUstream stream=0) const1330   inline CUresult set_global_data(const char* name, const T* data, size_t count,
1331                                   CUstream stream = 0) const {
1332     size_t size_bytes;
1333     CUdeviceptr ptr = get_global_ptr(name, &size_bytes);
1334     size_t given_size_bytes = count * sizeof(T);
1335     if (given_size_bytes != size_bytes) {
1336       throw std::runtime_error(
1337           std::string("Value for global variable ") + name +
1338           " has wrong size: got " + std::to_string(given_size_bytes) +
1339           " bytes, expected " + std::to_string(size_bytes));
1340     }
1341     return cuMemcpyHtoD(ptr, data, size_bytes);
1342   }
1343 
function_name() const1344   const std::string& function_name() const { return _func_name; }
ptx() const1345   const std::string& ptx() const { return _ptx; }
link_files() const1346   const std::vector<std::string>& link_files() const { return _link_files; }
link_paths() const1347   const std::vector<std::string>& link_paths() const { return _link_paths; }
1348 };
1349 
// Source text containing workaround (WAR) macro definitions: it defines
// __GNUC__, blanks out THRUST_STATIC_ASSERT and __host__, and defines `try`
// and `catch(...)` as empty macros so that exception syntax can be parsed.
// NOTE(review): the name suggests this text is pre-included ahead of
// JIT-compiled sources; the use site is not visible here -- confirm.
static const char* jitsafe_header_preinclude_h = R"(
//// WAR for Thrust (which appears to have forgotten to include this in result_of_adaptable_function.h
//#include <type_traits>

//// WAR for Thrust (which appear to have forgotten to include this in error_code.h)
//#include <string>

// WAR for Thrust (which only supports gnuc, clang or msvc)
#define __GNUC__ 4

// WAR for generics/shfl.h
#define THRUST_STATIC_ASSERT(x)

// WAR for CUB
#ifdef __host__
#undef __host__
#endif
#define __host__

// WAR to allow exceptions to be parsed
#define try
#define catch(...)
)";
1373 
1374 
// Source text of a replacement <float.h> defining the floating-point
// characteristic macros for float and double.
// Every macro must expand to a bare expression (no trailing ';') so that it
// can be used inside larger expressions.
static const char* jitsafe_header_float_h = R"(
#pragma once

#define FLT_RADIX       2
#define FLT_MANT_DIG    24
#define DBL_MANT_DIG    53
#define FLT_DIG         6
#define DBL_DIG         15
#define FLT_MIN_EXP     -125
#define DBL_MIN_EXP     -1021
#define FLT_MIN_10_EXP  -37
#define DBL_MIN_10_EXP  -307
#define FLT_MAX_EXP     128
#define DBL_MAX_EXP     1024
#define FLT_MAX_10_EXP  38
#define DBL_MAX_10_EXP  308
#define FLT_MAX         3.4028234e38f
#define DBL_MAX         1.7976931348623157e308
#define FLT_EPSILON     1.19209289e-7f
#define DBL_EPSILON     2.2204460492503131e-16
#define FLT_MIN         1.1754943e-38f
#define DBL_MIN         2.2250738585072013e-308
#define FLT_ROUNDS      1
#if defined __cplusplus && __cplusplus >= 201103L
#define FLT_EVAL_METHOD 0
#define DECIMAL_DIG     21
#endif
)";
1403 
// Source text of a replacement <limits.h> defining the integer limit macros.
// __WORDSIZE selects the width of `long`: 32 on Windows, 64 on targets that
// define __LP64__ (any LP64 architecture, not only x86-64) or on x86-64
// without __ILP32__, and 32 otherwise.
static const char* jitsafe_header_limits_h = R"(
#pragma once

#if defined _WIN32 || defined _WIN64
 #define __WORDSIZE 32
#else
 #if defined __LP64__ || (defined __x86_64__ && !defined __ILP32__)
  #define __WORDSIZE 64
 #else
  #define __WORDSIZE 32
 #endif
#endif
#define MB_LEN_MAX  16
#define CHAR_BIT    8
#define SCHAR_MIN   (-128)
#define SCHAR_MAX   127
#define UCHAR_MAX   255
enum {
  _JITIFY_CHAR_IS_UNSIGNED = (char)-1 >= 0,
  CHAR_MIN = _JITIFY_CHAR_IS_UNSIGNED ? 0 : SCHAR_MIN,
  CHAR_MAX = _JITIFY_CHAR_IS_UNSIGNED ? UCHAR_MAX : SCHAR_MAX,
};
#define SHRT_MIN    (-32768)
#define SHRT_MAX    32767
#define USHRT_MAX   65535
#define INT_MIN     (-INT_MAX - 1)
#define INT_MAX     2147483647
#define UINT_MAX    4294967295U
#if __WORDSIZE == 64
 # define LONG_MAX  9223372036854775807L
#else
 # define LONG_MAX  2147483647L
#endif
#define LONG_MIN    (-LONG_MAX - 1L)
#if __WORDSIZE == 64
 #define ULONG_MAX  18446744073709551615UL
#else
 #define ULONG_MAX  4294967295UL
#endif
#define LLONG_MAX  9223372036854775807LL
#define LLONG_MIN  (-LLONG_MAX - 1LL)
#define ULLONG_MAX 18446744073709551615ULL
)";
1447 
// Source text of a replacement <iterator>: declares the iterator category
// tags and iterator_traits (generic, T* and T const* forms) in namespace
// __jitify_iterator_ns, then makes them visible in both std and the global
// namespace through using-directives.
// NOTE(review): the text uses ptrdiff_t without including a header that
// declares it -- presumably supplied by the compilation environment; confirm.
static const char* jitsafe_header_iterator = R"(
#pragma once

namespace __jitify_iterator_ns {
struct output_iterator_tag {};
struct input_iterator_tag {};
struct forward_iterator_tag {};
struct bidirectional_iterator_tag {};
struct random_access_iterator_tag {};
template<class Iterator>
struct iterator_traits {
  typedef typename Iterator::iterator_category iterator_category;
  typedef typename Iterator::value_type        value_type;
  typedef typename Iterator::difference_type   difference_type;
  typedef typename Iterator::pointer           pointer;
  typedef typename Iterator::reference         reference;
};
template<class T>
struct iterator_traits<T*> {
  typedef random_access_iterator_tag iterator_category;
  typedef T                          value_type;
  typedef ptrdiff_t                  difference_type;
  typedef T*                         pointer;
  typedef T&                         reference;
};
template<class T>
struct iterator_traits<T const*> {
  typedef random_access_iterator_tag iterator_category;
  typedef T                          value_type;
  typedef ptrdiff_t                  difference_type;
  typedef T const*                   pointer;
  typedef T const&                   reference;
};
} // namespace __jitify_iterator_ns
namespace std { using namespace __jitify_iterator_ns; }
using namespace __jitify_iterator_ns;
)";
1485 
// Source text of a replacement <limits> providing std::numeric_limits for
// the built-in integer types, float and double.
// TODO: This is incomplete; need floating point limits
//   Joe Eaton: added IEEE float and double types, none of the smaller types
//              using type specific structs since we can't template on floats.
// Under C++11 and later, min(), max() and lowest() are constexpr and noexcept
// for all provided specializations, so they can be used in constant
// expressions; before C++11 the JITIFY_CXX11_* macros expand to nothing.
static const char* jitsafe_header_limits = R"(
#pragma once
#include <climits>
#include <cfloat>
// TODO: epsilon(), infinity(), etc
namespace __jitify_detail {
#if __cplusplus >= 201103L
#define JITIFY_CXX11_CONSTEXPR constexpr
#define JITIFY_CXX11_NOEXCEPT noexcept
#else
#define JITIFY_CXX11_CONSTEXPR
#define JITIFY_CXX11_NOEXCEPT
#endif

struct FloatLimits {
#if __cplusplus >= 201103L
   static JITIFY_CXX11_CONSTEXPR inline __host__ __device__
          float lowest() JITIFY_CXX11_NOEXCEPT {   return -FLT_MAX;}
   static JITIFY_CXX11_CONSTEXPR inline __host__ __device__
          float min() JITIFY_CXX11_NOEXCEPT {      return FLT_MIN; }
   static JITIFY_CXX11_CONSTEXPR inline __host__ __device__
          float max() JITIFY_CXX11_NOEXCEPT {      return FLT_MAX; }
#endif  // __cplusplus >= 201103L
   enum {
   is_specialized    = true,
   is_signed         = true,
   is_integer        = false,
   is_exact          = false,
   has_infinity      = true,
   has_quiet_NaN     = true,
   has_signaling_NaN = true,
   has_denorm        = 1,
   has_denorm_loss   = true,
   round_style       = 1,
   is_iec559         = true,
   is_bounded        = true,
   is_modulo         = false,
   digits            = 24,
   digits10          = 6,
   max_digits10      = 9,
   radix             = 2,
   min_exponent      = -125,
   min_exponent10    = -37,
   max_exponent      = 128,
   max_exponent10    = 38,
   tinyness_before   = false,
   traps             = false
   };
};
struct DoubleLimits {
#if __cplusplus >= 201103L
   static JITIFY_CXX11_CONSTEXPR inline __host__ __device__
          double lowest() JITIFY_CXX11_NOEXCEPT { return -DBL_MAX; }
   static JITIFY_CXX11_CONSTEXPR inline __host__ __device__
          double min() JITIFY_CXX11_NOEXCEPT { return DBL_MIN; }
   static JITIFY_CXX11_CONSTEXPR inline __host__ __device__
          double max() JITIFY_CXX11_NOEXCEPT { return DBL_MAX; }
#endif  // __cplusplus >= 201103L
   enum {
   is_specialized    = true,
   is_signed         = true,
   is_integer        = false,
   is_exact          = false,
   has_infinity      = true,
   has_quiet_NaN     = true,
   has_signaling_NaN = true,
   has_denorm        = 1,
   has_denorm_loss   = true,
   round_style       = 1,
   is_iec559         = true,
   is_bounded        = true,
   is_modulo         = false,
   digits            = 53,
   digits10          = 15,
   max_digits10      = 17,
   radix             = 2,
   min_exponent      = -1021,
   min_exponent10    = -307,
   max_exponent      = 1024,
   max_exponent10    = 308,
   tinyness_before   = false,
   traps             = false
   };
};
template<class T, T Min, T Max, int Digits=-1>
struct IntegerLimits {
	static JITIFY_CXX11_CONSTEXPR inline __host__ __device__
	T min() JITIFY_CXX11_NOEXCEPT { return Min; }
	static JITIFY_CXX11_CONSTEXPR inline __host__ __device__
	T max() JITIFY_CXX11_NOEXCEPT { return Max; }
#if __cplusplus >= 201103L
	static constexpr inline __host__ __device__ T lowest() noexcept {
		return Min;
	}
#endif  // __cplusplus >= 201103L
	enum {
       is_specialized = true,
       digits = (Digits == -1) ? (int)(sizeof(T)*8 - (Min != 0)) : Digits,
       digits10   = (digits * 30103) / 100000,
       is_signed  = ((T)(-1)<0),
       is_integer = true,
       is_exact   = true,
       radix      = 2,
       is_bounded = true,
       is_modulo  = false
	};
};
} // namespace __jitify_detail
namespace std { using namespace __jitify_detail; }
namespace __jitify_limits_ns {
template<typename T> struct numeric_limits {
    enum { is_specialized = false };
};
template<> struct numeric_limits<bool>               : public
__jitify_detail::IntegerLimits<bool,              false,    true,1> {};
template<> struct numeric_limits<char>               : public
__jitify_detail::IntegerLimits<char,              CHAR_MIN, CHAR_MAX>
{};
template<> struct numeric_limits<signed char>        : public
__jitify_detail::IntegerLimits<signed char,       SCHAR_MIN,SCHAR_MAX>
{};
template<> struct numeric_limits<unsigned char>      : public
__jitify_detail::IntegerLimits<unsigned char,     0,        UCHAR_MAX>
{};
template<> struct numeric_limits<wchar_t>            : public
__jitify_detail::IntegerLimits<wchar_t,           INT_MIN,  INT_MAX> {};
template<> struct numeric_limits<short>              : public
__jitify_detail::IntegerLimits<short,             SHRT_MIN, SHRT_MAX>
{};
template<> struct numeric_limits<unsigned short>     : public
__jitify_detail::IntegerLimits<unsigned short,    0,        USHRT_MAX>
{};
template<> struct numeric_limits<int>                : public
__jitify_detail::IntegerLimits<int,               INT_MIN,  INT_MAX> {};
template<> struct numeric_limits<unsigned int>       : public
__jitify_detail::IntegerLimits<unsigned int,      0,        UINT_MAX>
{};
template<> struct numeric_limits<long>               : public
__jitify_detail::IntegerLimits<long,              LONG_MIN, LONG_MAX>
{};
template<> struct numeric_limits<unsigned long>      : public
__jitify_detail::IntegerLimits<unsigned long,     0,        ULONG_MAX>
{};
template<> struct numeric_limits<long long>          : public
__jitify_detail::IntegerLimits<long long,         LLONG_MIN,LLONG_MAX>
{};
template<> struct numeric_limits<unsigned long long> : public
__jitify_detail::IntegerLimits<unsigned long long,0,        ULLONG_MAX>
{};
//template<typename T> struct numeric_limits { static const bool
//is_signed = ((T)(-1)<0); };
template<> struct numeric_limits<float>              : public
__jitify_detail::FloatLimits
{};
template<> struct numeric_limits<double>             : public
__jitify_detail::DoubleLimits
{};
} // namespace __jitify_limits_ns
namespace std { using namespace __jitify_limits_ns; }
using namespace __jitify_limits_ns;
)";
1648 
// Stand-in <type_traits> header source handed to the runtime compiler
// (registered as "type_traits" in get_jitsafe_headers_map()). Everything is
// declared in __jitify_type_traits_ns and then exposed through using-directives
// in both namespace std and the global namespace. The contents are only
// compiled when __cplusplus >= 201103L.
//
// The private add_pointer helpers live in __jitify_type_traits_detail rather
// than __jitify_detail: because this namespace is re-exported to the global
// scope, a nested __jitify_detail would make the name ambiguous with the
// global __jitify_detail namespace opened by the <limits> stand-in whenever
// <type_traits> is included first.
// TODO: This is highly incomplete
static const char* jitsafe_header_type_traits = R"(
    #pragma once
    #if __cplusplus >= 201103L
    namespace __jitify_type_traits_ns {

    template<bool B, class T = void> struct enable_if {};
    template<class T>                struct enable_if<true, T> { typedef T type; };
    #if __cplusplus >= 201402L
    template< bool B, class T = void > using enable_if_t = typename enable_if<B,T>::type;
    #endif

    struct true_type  {
      enum { value = true };
      operator bool() const { return true; }
    };
    struct false_type {
      enum { value = false };
      operator bool() const { return false; }
    };

    template<typename T> struct is_floating_point    : false_type {};
    template<> struct is_floating_point<float>       :  true_type {};
    template<> struct is_floating_point<double>      :  true_type {};
    template<> struct is_floating_point<long double> :  true_type {};

    template<class T> struct is_integral              : false_type {};
    template<> struct is_integral<bool>               :  true_type {};
    template<> struct is_integral<char>               :  true_type {};
    template<> struct is_integral<signed char>        :  true_type {};
    template<> struct is_integral<unsigned char>      :  true_type {};
    template<> struct is_integral<short>              :  true_type {};
    template<> struct is_integral<unsigned short>     :  true_type {};
    template<> struct is_integral<int>                :  true_type {};
    template<> struct is_integral<unsigned int>       :  true_type {};
    template<> struct is_integral<long>               :  true_type {};
    template<> struct is_integral<unsigned long>      :  true_type {};
    template<> struct is_integral<long long>          :  true_type {};
    template<> struct is_integral<unsigned long long> :  true_type {};

    template<typename T> struct is_signed    : false_type {};
    template<> struct is_signed<float>       :  true_type {};
    template<> struct is_signed<double>      :  true_type {};
    template<> struct is_signed<long double> :  true_type {};
    template<> struct is_signed<signed char> :  true_type {};
    template<> struct is_signed<short>       :  true_type {};
    template<> struct is_signed<int>         :  true_type {};
    template<> struct is_signed<long>        :  true_type {};
    template<> struct is_signed<long long>   :  true_type {};

    template<typename T> struct is_unsigned             : false_type {};
    template<> struct is_unsigned<unsigned char>      :  true_type {};
    template<> struct is_unsigned<unsigned short>     :  true_type {};
    template<> struct is_unsigned<unsigned int>       :  true_type {};
    template<> struct is_unsigned<unsigned long>      :  true_type {};
    template<> struct is_unsigned<unsigned long long> :  true_type {};

    template<typename T, typename U> struct is_same      : false_type {};
    template<typename T>             struct is_same<T,T> :  true_type {};

    template<class T> struct is_array : false_type {};
    template<class T> struct is_array<T[]> : true_type {};
    template<class T, size_t N> struct is_array<T[N]> : true_type {};

    //partial implementation only of is_function
    template<class> struct is_function : false_type { };
    template<class Ret, class... Args> struct is_function<Ret(Args...)> : true_type {}; //regular
    template<class Ret, class... Args> struct is_function<Ret(Args......)> : true_type {}; // variadic

    template<class> struct result_of;
    template<class F, typename... Args>
    struct result_of<F(Args...)> {
    // TODO: This is a hack; a proper implem is quite complicated.
    typedef typename F::result_type type;
    };

    template <class T> struct remove_reference { typedef T type; };
    template <class T> struct remove_reference<T&> { typedef T type; };
    template <class T> struct remove_reference<T&&> { typedef T type; };
    #if __cplusplus >= 201402L
    template< class T > using remove_reference_t = typename remove_reference<T>::type;
    #endif

    template<class T> struct remove_extent { typedef T type; };
    template<class T> struct remove_extent<T[]> { typedef T type; };
    template<class T, size_t N> struct remove_extent<T[N]> { typedef T type; };
    #if __cplusplus >= 201402L
    template< class T > using remove_extent_t = typename remove_extent<T>::type;
    #endif

    template< class T > struct remove_const          { typedef T type; };
    template< class T > struct remove_const<const T> { typedef T type; };
    template< class T > struct remove_volatile             { typedef T type; };
    template< class T > struct remove_volatile<volatile T> { typedef T type; };
    template< class T > struct remove_cv { typedef typename remove_volatile<typename remove_const<T>::type>::type type; };
    #if __cplusplus >= 201402L
    template< class T > using remove_cv_t       = typename remove_cv<T>::type;
    template< class T > using remove_const_t    = typename remove_const<T>::type;
    template< class T > using remove_volatile_t = typename remove_volatile<T>::type;
    #endif

    template<bool B, class T, class F> struct conditional { typedef T type; };
    template<class T, class F> struct conditional<false, T, F> { typedef F type; };
    #if __cplusplus >= 201402L
    template< bool B, class T, class F > using conditional_t = typename conditional<B,T,F>::type;
    #endif

    namespace __jitify_type_traits_detail {
    template< class T, bool is_function_type = false > struct add_pointer { using type = typename remove_reference<T>::type*; };
    template< class T > struct add_pointer<T, true> { using type = T; };
    template< class T, class... Args > struct add_pointer<T(Args...), true> { using type = T(*)(Args...); };
    template< class T, class... Args > struct add_pointer<T(Args..., ...), true> { using type = T(*)(Args..., ...); };
    }
    template< class T > struct add_pointer : __jitify_type_traits_detail::add_pointer<T, is_function<T>::value> {};
    #if __cplusplus >= 201402L
    template< class T > using add_pointer_t = typename add_pointer<T>::type;
    #endif

    template< class T > struct decay {
    private:
      typedef typename remove_reference<T>::type U;
    public:
      typedef typename conditional<is_array<U>::value, typename remove_extent<U>::type*,
        typename conditional<is_function<U>::value,typename add_pointer<U>::type,typename remove_cv<U>::type
        >::type>::type type;
    };
    #if __cplusplus >= 201402L
    template< class T > using decay_t = typename decay<T>::type;
    #endif

    } // namespace __jitify_type_traits_ns
    namespace std { using namespace __jitify_type_traits_ns; }
    using namespace __jitify_type_traits_ns;
    #endif // c++11
)";
1784 
// Stand-in <stdint.h>/<cstdint> header source (registered under both names in
// get_jitsafe_headers_map()). It includes <climits>, declares the fixed-width,
// fast, least, max and pointer-sized integer typedefs in __jitify_stdint_ns,
// defines the matching *_MIN/*_MAX macros in terms of the <climits> macros,
// and exposes the typedefs through using-directives in std and the global
// namespace.
// Note: the #define lines sit between the namespace braces, but macros are
// not scoped by namespaces.
// NOTE(review): intptr_t/uintptr_t are typedef'd to the long types and
// SIZE_MAX to UINT64_MAX; confirm this matches every target data model.
// TODO: INT_FAST8_MAX et al. and a few other misc constants
static const char* jitsafe_header_stdint_h =
    "#pragma once\n"
    "#include <climits>\n"
    "namespace __jitify_stdint_ns {\n"
    "typedef signed char      int8_t;\n"
    "typedef signed short     int16_t;\n"
    "typedef signed int       int32_t;\n"
    "typedef signed long long int64_t;\n"
    "typedef signed char      int_fast8_t;\n"
    "typedef signed short     int_fast16_t;\n"
    "typedef signed int       int_fast32_t;\n"
    "typedef signed long long int_fast64_t;\n"
    "typedef signed char      int_least8_t;\n"
    "typedef signed short     int_least16_t;\n"
    "typedef signed int       int_least32_t;\n"
    "typedef signed long long int_least64_t;\n"
    "typedef signed long long intmax_t;\n"
    "typedef signed long      intptr_t; //optional\n"
    "typedef unsigned char      uint8_t;\n"
    "typedef unsigned short     uint16_t;\n"
    "typedef unsigned int       uint32_t;\n"
    "typedef unsigned long long uint64_t;\n"
    "typedef unsigned char      uint_fast8_t;\n"
    "typedef unsigned short     uint_fast16_t;\n"
    "typedef unsigned int       uint_fast32_t;\n"
    "typedef unsigned long long uint_fast64_t;\n"
    "typedef unsigned char      uint_least8_t;\n"
    "typedef unsigned short     uint_least16_t;\n"
    "typedef unsigned int       uint_least32_t;\n"
    "typedef unsigned long long uint_least64_t;\n"
    "typedef unsigned long long uintmax_t;\n"
    "typedef unsigned long      uintptr_t; //optional\n"
    "#define INT8_MIN    SCHAR_MIN\n"
    "#define INT16_MIN   SHRT_MIN\n"
    "#define INT32_MIN   INT_MIN\n"
    "#define INT64_MIN   LLONG_MIN\n"
    "#define INT8_MAX    SCHAR_MAX\n"
    "#define INT16_MAX   SHRT_MAX\n"
    "#define INT32_MAX   INT_MAX\n"
    "#define INT64_MAX   LLONG_MAX\n"
    "#define UINT8_MAX   UCHAR_MAX\n"
    "#define UINT16_MAX  USHRT_MAX\n"
    "#define UINT32_MAX  UINT_MAX\n"
    "#define UINT64_MAX  ULLONG_MAX\n"
    "#define INTPTR_MIN  LONG_MIN\n"
    "#define INTMAX_MIN  LLONG_MIN\n"
    "#define INTPTR_MAX  LONG_MAX\n"
    "#define INTMAX_MAX  LLONG_MAX\n"
    "#define UINTPTR_MAX ULONG_MAX\n"
    "#define UINTMAX_MAX ULLONG_MAX\n"
    "#define PTRDIFF_MIN INTPTR_MIN\n"
    "#define PTRDIFF_MAX INTPTR_MAX\n"
    "#define SIZE_MAX    UINT64_MAX\n"
    "} // namespace __jitify_stdint_ns\n"
    "namespace std { using namespace __jitify_stdint_ns; }\n"
    "using namespace __jitify_stdint_ns;\n";
1842 
// Stand-in <stddef.h>/<cstddef> header source (registered under both names in
// get_jitsafe_headers_map()). For C++11 and later it declares nullptr_t and a
// platform-selected max_align_t in __jitify_stddef_ns; for C++17 and later it
// also declares byte. size_t and ptrdiff_t are not declared here: they are
// pulled into std from the global namespace with using-declarations.
// TODO: offsetof
static const char* jitsafe_header_stddef_h =
    "#pragma once\n"
    "#include <climits>\n"
    "namespace __jitify_stddef_ns {\n"
    "#if __cplusplus >= 201103L\n"
    "typedef decltype(nullptr) nullptr_t;\n"
    "#if defined(_MSC_VER)\n"
    "  typedef double max_align_t;\n"
    "#elif defined(__APPLE__)\n"
    "  typedef long double max_align_t;\n"
    "#else\n"
    "  // Define max_align_t to match the GCC definition.\n"
    "  typedef struct {\n"
    "    long long __jitify_max_align_nonce1\n"
    "        __attribute__((__aligned__(__alignof__(long long))));\n"
    "    long double __jitify_max_align_nonce2\n"
    "        __attribute__((__aligned__(__alignof__(long double))));\n"
    "  } max_align_t;\n"
    "#endif\n"
    "#endif  // __cplusplus >= 201103L\n"
    "#if __cplusplus >= 201703L\n"
    "enum class byte : unsigned char {};\n"
    "#endif  // __cplusplus >= 201703L\n"
    "} // namespace __jitify_stddef_ns\n"
    "namespace std {\n"
    "  // NVRTC provides built-in definitions of ::size_t and ::ptrdiff_t.\n"
    "  using ::size_t;\n"
    "  using ::ptrdiff_t;\n"
    "  using namespace __jitify_stddef_ns;\n"
    "} // namespace std\n"
    "using namespace __jitify_stddef_ns;\n";
1875 
// Stand-in <stdlib.h>/<cstdlib> header source: declares nothing of its own and
// only forwards to <stddef.h>.
static const char* jitsafe_header_stdlib_h = R"(#pragma once
#include <stddef.h>
)";
// Stand-in <stdio.h>/<cstdio> header source: includes <stddef.h>, defines FILE
// as a macro for int, and declares fflush and fprintf (declarations only).
static const char* jitsafe_header_stdio_h = R"(#pragma once
#include <stddef.h>
#define FILE int
int fflush ( FILE * stream );
int fprintf ( FILE * stream, const char * format, ... );
)";
1885 
// Stand-in <string.h> header source: global-namespace declarations of strcpy,
// strcmp and strerror (declarations only).
static const char* jitsafe_header_string_h = R"(#pragma once
char* strcpy ( char * destination, const char * source );
int strcmp ( const char * str1, const char * str2 );
char* strerror( int errnum );
)";
1891 
// Stand-in <cstring> header source: the same three declarations as the
// <string.h> stand-in, but placed in __jitify_cstring_ns and exposed through
// using-directives in std and the global namespace.
static const char* jitsafe_header_cstring = R"(#pragma once

namespace __jitify_cstring_ns {
char* strcpy ( char * destination, const char * source );
int strcmp ( const char * str1, const char * str2 );
char* strerror( int errnum );
} // namespace __jitify_cstring_ns
namespace std { using namespace __jitify_cstring_ns; }
using namespace __jitify_cstring_ns;
)";
1902 
// Stand-in <iostream> header source: simply includes the <ostream> and
// <istream> stand-ins.
// HACK TESTING (WAR for cub)
static const char* jitsafe_header_iostream = R"(#pragma once
#include <ostream>
#include <istream>
)";
// Stand-in <ostream> header source: an empty basic_ostream class template, the
// ostream typedef, and declarations (no definitions) of endl and a few
// operator<< overloads, all in __jitify_ostream_ns and exposed through
// using-directives in std and the global namespace. The Traits parameter
// defaults to void instead of std::char_traits<CharT>.
// HACK TESTING (WAR for Thrust)
static const char* jitsafe_header_ostream =
    "#pragma once\n"
    "\n"
    "namespace __jitify_ostream_ns {\n"
    "template<class CharT,class Traits=void>\n"  // = std::char_traits<CharT>
                                                 // >\n"
    "struct basic_ostream {\n"
    "};\n"
    "typedef basic_ostream<char> ostream;\n"
    "ostream& endl(ostream& os);\n"
    "ostream& operator<<( ostream&, ostream& (*f)( ostream& ) );\n"
    "template< class CharT, class Traits > basic_ostream<CharT, Traits>& endl( "
    "basic_ostream<CharT, Traits>& os );\n"
    "template< class CharT, class Traits > basic_ostream<CharT, Traits>& "
    "operator<<( basic_ostream<CharT,Traits>& os, const char* c );\n"
    "#if __cplusplus >= 201103L\n"
    "template< class CharT, class Traits, class T > basic_ostream<CharT, "
    "Traits>& operator<<( basic_ostream<CharT,Traits>&& os, const T& value );\n"
    "#endif  // __cplusplus >= 201103L\n"
    "} // namespace __jitify_ostream_ns\n"
    "namespace std { using namespace __jitify_ostream_ns; }\n"
    "using namespace __jitify_ostream_ns;\n";
1931 
// Stand-in <istream> header source: an empty basic_istream class template and
// the istream typedef, declared in __jitify_istream_ns and exposed through
// using-directives in std and the global namespace. The Traits parameter
// defaults to void instead of std::char_traits<CharT>.
static const char* jitsafe_header_istream =
    "#pragma once\n"
    "\n"
    "namespace __jitify_istream_ns {\n"
    "template<class CharT,class Traits=void>\n"  // = std::char_traits<CharT>
                                                 // >\n"
    "struct basic_istream {\n"
    "};\n"
    "typedef basic_istream<char> istream;\n"
    "} // namespace __jitify_istream_ns\n"
    "namespace std { using namespace __jitify_istream_ns; }\n"
    "using namespace __jitify_istream_ns;\n";
1944 
// Stand-in <sstream> header source: declares no stream classes of its own and
// only includes the <ostream> and <istream> stand-ins.
static const char* jitsafe_header_sstream = R"(#pragma once
#include <ostream>
#include <istream>
)";
1949 
// Stand-in <utility> header source: a minimal pair<T1,T2> (default constructor
// and two-argument constructor only) and make_pair, declared in
// __jitify_utility_ns and exposed through using-directives in std and the
// global namespace.
static const char* jitsafe_header_utility =
    "#pragma once\n"
    "namespace __jitify_utility_ns {\n"
    "template<class T1, class T2>\n"
    "struct pair {\n"
    "	T1 first;\n"
    "	T2 second;\n"
    "	inline pair() {}\n"
    "	inline pair(T1 const& first_, T2 const& second_)\n"
    "		: first(first_), second(second_) {}\n"
    "	// TODO: Standard includes many more constructors...\n"
    "	// TODO: Comparison operators\n"
    "};\n"
    "template<class T1, class T2>\n"
    "pair<T1,T2> make_pair(T1 const& first, T2 const& second) {\n"
    "	return pair<T1,T2>(first, second);\n"
    "}\n"
    "} // namespace __jitify_utility_ns\n"
    "namespace std { using namespace __jitify_utility_ns; }\n"
    "using namespace __jitify_utility_ns;\n";
1970 
// Stand-in <vector> header source: an empty vector<T, Allocator> class
// template (Allocator defaults to void rather than std::allocator), declared
// in __jitify_vector_ns and exposed through using-directives in std and the
// global namespace.
// TODO: incomplete
static const char* jitsafe_header_vector =
    "#pragma once\n"
    "namespace __jitify_vector_ns {\n"
    "template<class T, class Allocator=void>\n"  // = std::allocator> \n"
    "struct vector {\n"
    "};\n"
    "} // namespace __jitify_vector_ns\n"
    "namespace std { using namespace __jitify_vector_ns; }\n"
    "using namespace __jitify_vector_ns;\n";
1981 
// Stand-in <string> header source: a basic_string class template that only
// declares (does not define) two constructors, c_str(), empty() and two
// operator+= overloads, plus the string typedef. Declared in
// __jitify_string_ns and exposed through using-directives in std and the
// global namespace. Traits and Allocator default to void.
// TODO: incomplete
static const char* jitsafe_header_string =
    "#pragma once\n"
    "namespace __jitify_string_ns {\n"
    "template<class CharT,class Traits=void,class Allocator=void>\n"
    "struct basic_string {\n"
    "basic_string();\n"
    "basic_string( const CharT* s );\n"  //, const Allocator& alloc =
                                         // Allocator() );\n"
    "const CharT* c_str() const;\n"
    "bool empty() const;\n"
    "void operator+=(const char *);\n"
    "void operator+=(const basic_string &);\n"
    "};\n"
    "typedef basic_string<char> string;\n"
    "} // namespace __jitify_string_ns\n"
    "namespace std { using namespace __jitify_string_ns; }\n"
    "using namespace __jitify_string_ns;\n";
2000 
// Stand-in <stdexcept> header source: declares (does not define) a
// runtime_error struct in __jitify_stdexcept_ns and exposes it through
// using-directives in std and the global namespace.
// The header includes <string> itself because the first constructor takes a
// std::string; without it the header only compiled when the including code
// happened to include <string> first.
// TODO: incomplete
static const char* jitsafe_header_stdexcept =
    "#pragma once\n"
    "#include <string>\n"
    "namespace __jitify_stdexcept_ns {\n"
    "struct runtime_error {\n"
    "explicit runtime_error( const std::string& what_arg );\n"
    "explicit runtime_error( const char* what_arg );\n"
    "virtual const char* what() const;\n"
    "};\n"
    "} // namespace __jitify_stdexcept_ns\n"
    "namespace std { using namespace __jitify_stdexcept_ns; }\n"
    "using namespace __jitify_stdexcept_ns;\n";
2013 
// Stand-in <complex> header source: a minimal complex<T> class (constructors,
// real()/imag() accessors and setters, operator+=) plus three operator*
// overloads (complex*complex, complex*scalar, scalar*complex). Declared in
// __jitify_complex_ns and exposed through using-directives in std and the
// global namespace.
// The scalar operator* overloads construct complex<T>; they previously named
// a non-existent "complexs<T>" and could not compile.
// TODO: incomplete
static const char* jitsafe_header_complex =
    "#pragma once\n"
    "namespace __jitify_complex_ns {\n"
    "template<typename T>\n"
    "class complex {\n"
    "	T _real;\n"
    "	T _imag;\n"
    "public:\n"
    "	complex() : _real(0), _imag(0) {}\n"
    "	complex(T const& real, T const& imag)\n"
    "		: _real(real), _imag(imag) {}\n"
    "	complex(T const& real)\n"
    "               : _real(real), _imag(static_cast<T>(0)) {}\n"
    "	T const& real() const { return _real; }\n"
    "	T&       real()       { return _real; }\n"
    "	void real(const T &r) { _real = r; }\n"
    "	T const& imag() const { return _imag; }\n"
    "	T&       imag()       { return _imag; }\n"
    "	void imag(const T &i) { _imag = i; }\n"
    "       complex<T>& operator+=(const complex<T> z)\n"
    "         { _real += z.real(); _imag += z.imag(); return *this; }\n"
    "};\n"
    "template<typename T>\n"
    "complex<T> operator*(const complex<T>& lhs, const complex<T>& rhs)\n"
    "  { return complex<T>(lhs.real()*rhs.real()-lhs.imag()*rhs.imag(),\n"
    "                      lhs.real()*rhs.imag()+lhs.imag()*rhs.real()); }\n"
    "template<typename T>\n"
    "complex<T> operator*(const complex<T>& lhs, const T & rhs)\n"
    "  { return complex<T>(lhs.real()*rhs,lhs.imag()*rhs); }\n"
    "template<typename T>\n"
    "complex<T> operator*(const T& lhs, const complex<T>& rhs)\n"
    "  { return complex<T>(rhs.real()*lhs,rhs.imag()*lhs); }\n"
    "} // namespace __jitify_complex_ns\n"
    "namespace std { using namespace __jitify_complex_ns; }\n"
    "using namespace __jitify_complex_ns;\n";
2050 
// Stand-in <math.h>/<cmath> header source (registered under both names in
// get_jitsafe_headers_map()). Inside __jitify_math_ns it defines thin inline
// wrappers that forward to the global-namespace math functions, then exposes
// them in std only; the global using-directive is deliberately left commented
// out because the global namespace already has the CUDA math functions.
// DEFINE_MATH_UNARY_FUNC_WRAPPER(f) generates the double and f##f wrappers,
// plus a float overload of f for C++11 and later; the macro is #undef'd at the
// end of the header. M_PI is defined unconditionally.
// scalbln forwards to ::scalbln so the long exponent is not narrowed to the
// int parameter of ::scalbn.
// TODO: This is incomplete (missing binary and integer funcs, macros,
// constants, types)
static const char* jitsafe_header_math =
    "#pragma once\n"
    "namespace __jitify_math_ns {\n"
    "#if __cplusplus >= 201103L\n"
    "#define DEFINE_MATH_UNARY_FUNC_WRAPPER(f) \\\n"
    "	inline double      f(double x)         { return ::f(x); } \\\n"
    "	inline float       f##f(float x)       { return ::f(x); } \\\n"
    "	/*inline long double f##l(long double x) { return ::f(x); }*/ \\\n"
    "	inline float       f(float x)          { return ::f(x); } \\\n"
    "	/*inline long double f(long double x)    { return ::f(x); }*/\n"
    "#else\n"
    "#define DEFINE_MATH_UNARY_FUNC_WRAPPER(f) \\\n"
    "	inline double      f(double x)         { return ::f(x); } \\\n"
    "	inline float       f##f(float x)       { return ::f(x); } \\\n"
    "	/*inline long double f##l(long double x) { return ::f(x); }*/\n"
    "#endif\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(cos)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(sin)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(tan)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(acos)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(asin)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(atan)\n"
    "template<typename T> inline T atan2(T y, T x) { return ::atan2(y, x); }\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(cosh)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(sinh)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(tanh)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(exp)\n"
    "template<typename T> inline T frexp(T x, int* exp) { return ::frexp(x, "
    "exp); }\n"
    "template<typename T> inline T ldexp(T x, int  exp) { return ::ldexp(x, "
    "exp); }\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(log)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(log10)\n"
    "template<typename T> inline T modf(T x, T* intpart) { return ::modf(x, "
    "intpart); }\n"
    "template<typename T> inline T pow(T x, T y) { return ::pow(x, y); }\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(sqrt)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(ceil)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(floor)\n"
    "template<typename T> inline T fmod(T n, T d) { return ::fmod(n, d); }\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(fabs)\n"
    "template<typename T> inline T abs(T x) { return ::abs(x); }\n"
    "#if __cplusplus >= 201103L\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(acosh)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(asinh)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(atanh)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(exp2)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(expm1)\n"
    "template<typename T> inline int ilogb(T x) { return ::ilogb(x); }\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(log1p)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(log2)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(logb)\n"
    "template<typename T> inline T scalbn (T x, int n)  { return ::scalbn(x, "
    "n); }\n"
    "template<typename T> inline T scalbln(T x, long n) { return ::scalbln(x, "
    "n); }\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(cbrt)\n"
    "template<typename T> inline T hypot(T x, T y) { return ::hypot(x, y); }\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(erf)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(erfc)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(tgamma)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(lgamma)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(trunc)\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(round)\n"
    "template<typename T> inline long lround(T x) { return ::lround(x); }\n"
    "template<typename T> inline long long llround(T x) { return ::llround(x); "
    "}\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(rint)\n"
    "template<typename T> inline long lrint(T x) { return ::lrint(x); }\n"
    "template<typename T> inline long long llrint(T x) { return ::llrint(x); "
    "}\n"
    "DEFINE_MATH_UNARY_FUNC_WRAPPER(nearbyint)\n"
    // TODO: remainder, remquo, copysign, nan, nextafter, nexttoward, fdim,
    // fmax, fmin, fma
    "#endif\n"
    "#undef DEFINE_MATH_UNARY_FUNC_WRAPPER\n"
    "} // namespace __jitify_math_ns\n"
    "namespace std { using namespace __jitify_math_ns; }\n"
    "#define M_PI 3.14159265358979323846\n"
    // Note: Global namespace already includes CUDA math funcs
    "//using namespace __jitify_math_ns;\n";
2134 
// Stand-in <memory.h> header source: declares nothing of its own and only
// includes <string.h>.
static const char* jitsafe_header_memory_h = R"(
    #pragma once
    #include <string.h>
 )";
2139 
// Stand-in <mutex> header source: for C++11 and later, declares a mutex class
// whose lock(), try_lock() and unlock() members are declared but not defined.
// Declared in __jitify_mutex_ns and exposed through using-directives in std
// and the global namespace. The header is empty before C++11.
// TODO: incomplete
static const char* jitsafe_header_mutex = R"(
    #pragma once
    #if __cplusplus >= 201103L
    namespace __jitify_mutex_ns {
    class mutex {
    public:
    void lock();
    bool try_lock();
    void unlock();
    };
    } // namespace __jitify_mutex_ns
    namespace std { using namespace __jitify_mutex_ns; }
    using namespace __jitify_mutex_ns;
    #endif
 )";
2156 
// Stand-in <algorithm> header source: for C++11 and later, defines two-argument
// max and min function templates in __jitify_algorithm_ns and exposes them
// through using-directives in std and the global namespace. The header is
// empty before C++11.
// JITIFY_CXX14_CONSTEXPR expands to nothing when __cplusplus is exactly
// 201103L and to constexpr otherwise, so the functions are constexpr only
// beyond C++11. The macro is not #undef'd at the end of the header.
static const char* jitsafe_header_algorithm = R"(
    #pragma once
    #if __cplusplus >= 201103L
    namespace __jitify_algorithm_ns {

    #if __cplusplus == 201103L
    #define JITIFY_CXX14_CONSTEXPR
    #else
    #define JITIFY_CXX14_CONSTEXPR constexpr
    #endif

    template<class T> JITIFY_CXX14_CONSTEXPR const T& max(const T& a, const T& b)
    {
      return (b > a) ? b : a;
    }
    template<class T> JITIFY_CXX14_CONSTEXPR const T& min(const T& a, const T& b)
    {
      return (b < a) ? b : a;
    }

    } // namespace __jitify_algorithm_ns
    namespace std { using namespace __jitify_algorithm_ns; }
    using namespace __jitify_algorithm_ns;
    #endif
 )";
2182 
// Stand-in <time.h>/<ctime> header source (registered under both names in
// get_jitsafe_headers_map()). It defines the NULL and CLOCKS_PER_SEC macros
// unconditionally, declares time_t and struct tm in __jitify_time_ns (plus
// struct timespec for C++17 and later), and exposes them through
// using-directives in std and the global namespace. size_t and clock_t are
// not declared here: they are pulled into std from the global namespace with
// using-declarations.
static const char* jitsafe_header_time_h = R"(
    #pragma once
    #define NULL 0
    #define CLOCKS_PER_SEC 1000000
    namespace __jitify_time_ns {
    typedef long time_t;
    struct tm {
      int tm_sec;
      int tm_min;
      int tm_hour;
      int tm_mday;
      int tm_mon;
      int tm_year;
      int tm_wday;
      int tm_yday;
      int tm_isdst;
    };
    #if __cplusplus >= 201703L
    struct timespec {
      time_t tv_sec;
      long tv_nsec;
    };
    #endif
    }  // namespace __jitify_time_ns
    namespace std {
      // NVRTC provides built-in definitions of ::size_t and ::clock_t.
      using ::size_t;
      using ::clock_t;
      using namespace __jitify_time_ns;
    }
    using namespace __jitify_time_ns;
 )";
2215 
// Names of the built-in headers that must be pre-included. Each entry is also
// a key of the table returned by get_jitsafe_headers_map().
// WAR: These need to be pre-included as a workaround for NVRTC implicitly using
// /usr/include as an include path. The other built-in headers will be included
// lazily as needed.
static const char* preinclude_jitsafe_header_names[] = {
    "jitify_preinclude.h",
    "limits.h",
    "math.h",
    "memory.h",
    "stdint.h",
    "stdlib.h",
    "stdio.h",
    "string.h",
    "time.h",
};
2230 
// Returns the number of elements of a built-in array. The count is deduced
// from the array type of the reference argument, so passing a pointer does
// not compile.
template <class Element, int Count>
int array_size(Element (&)[Count]) {
  return Count;
}
// Number of entries in preinclude_jitsafe_header_names, computed from the
// array itself so it cannot drift when entries are added or removed.
const int preinclude_jitsafe_headers_count =
    array_size(preinclude_jitsafe_header_names);
2237 
get_jitsafe_headers_map()2238 static const std::map<std::string, std::string>& get_jitsafe_headers_map() {
2239   static const std::map<std::string, std::string> jitsafe_headers_map = {
2240       {"jitify_preinclude.h", jitsafe_header_preinclude_h},
2241       {"float.h", jitsafe_header_float_h},
2242       {"cfloat", jitsafe_header_float_h},
2243       {"limits.h", jitsafe_header_limits_h},
2244       {"climits", jitsafe_header_limits_h},
2245       {"stdint.h", jitsafe_header_stdint_h},
2246       {"cstdint", jitsafe_header_stdint_h},
2247       {"stddef.h", jitsafe_header_stddef_h},
2248       {"cstddef", jitsafe_header_stddef_h},
2249       {"stdlib.h", jitsafe_header_stdlib_h},
2250       {"cstdlib", jitsafe_header_stdlib_h},
2251       {"stdio.h", jitsafe_header_stdio_h},
2252       {"cstdio", jitsafe_header_stdio_h},
2253       {"string.h", jitsafe_header_string_h},
2254       {"cstring", jitsafe_header_cstring},
2255       {"iterator", jitsafe_header_iterator},
2256       {"limits", jitsafe_header_limits},
2257       {"type_traits", jitsafe_header_type_traits},
2258       {"utility", jitsafe_header_utility},
2259       {"math.h", jitsafe_header_math},
2260       {"cmath", jitsafe_header_math},
2261       {"memory.h", jitsafe_header_memory_h},
2262       {"complex", jitsafe_header_complex},
2263       {"iostream", jitsafe_header_iostream},
2264       {"ostream", jitsafe_header_ostream},
2265       {"istream", jitsafe_header_istream},
2266       {"sstream", jitsafe_header_sstream},
2267       {"vector", jitsafe_header_vector},
2268       {"string", jitsafe_header_string},
2269       {"stdexcept", jitsafe_header_stdexcept},
2270       {"mutex", jitsafe_header_mutex},
2271       {"algorithm", jitsafe_header_algorithm},
2272       {"time.h", jitsafe_header_time_h},
2273       {"ctime", jitsafe_header_time_h},
2274   };
2275   return jitsafe_headers_map;
2276 }
2277 
// Appends extra compiler options to `options` from two optional sources, in
// this order:
//   1. the JITIFY_OPTIONS environment variable, read at run time;
//   2. the JITIFY_OPTIONS preprocessor macro, stringified at compile time.
// Each source is split on whitespace and every token becomes one option.
// Options already present in `options` are left untouched.
inline void add_options_from_env(std::vector<std::string>& options) {
  // Tokenises `text` on whitespace and appends each token to `options`.
  auto append_tokens = [&options](const char* text) {
    std::stringstream tokens;
    tokens << text;
    for (std::string token; tokens >> token;) {
      options.push_back(token);
    }
  };

  // Source 1: environment variable (skipped when it is not set).
  if (const char* from_env = std::getenv("JITIFY_OPTIONS")) {
    append_tokens(from_env);
  }

  // Source 2: compile-time macro. Two macro levels are needed so that
  // JITIFY_OPTIONS is expanded before being turned into a string literal.
#ifdef JITIFY_OPTIONS
#define JITIFY_TOSTRING_IMPL(x) #x
#define JITIFY_TOSTRING(x) JITIFY_TOSTRING_IMPL(x)
  append_tokens(JITIFY_TOSTRING(JITIFY_OPTIONS));
#undef JITIFY_TOSTRING
#undef JITIFY_TOSTRING_IMPL
#endif  // JITIFY_OPTIONS
}
2303 
detect_and_add_cuda_arch(std::vector<std::string> & options)2304 inline void detect_and_add_cuda_arch(std::vector<std::string>& options) {
2305   for (int i = 0; i < (int)options.size(); ++i) {
2306     // Note that this will also match the middle of "--gpu-architecture".
2307     if (options[i].find("-arch") != std::string::npos) {
2308       // Arch already specified in options
2309       return;
2310     }
2311   }
2312   // Use the compute capability of the current device
2313   // TODO: Check these API calls for errors
2314   cudaError_t status;
2315   int device;
2316   status = cudaGetDevice(&device);
2317   if (status != cudaSuccess) {
2318     throw std::runtime_error(
2319         std::string(
2320             "Failed to detect GPU architecture: cudaGetDevice failed: ") +
2321         cudaGetErrorString(status));
2322   }
2323   int cc_major;
2324   cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device);
2325   int cc_minor;
2326   cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor, device);
2327   int cc = cc_major * 10 + cc_minor;
2328   // Note: We must limit the architecture to the max supported by the current
2329   //         version of NVRTC, otherwise newer hardware will cause errors
2330   //         on older versions of CUDA.
2331   // TODO: It would be better to detect this somehow, rather than hard-coding it
2332 
2333   // Tegra chips do not have forwards compatibility so we need to special case
2334   // them.
2335   bool is_tegra = ((cc_major == 3 && cc_minor == 2) ||  // Logan
2336                    (cc_major == 5 && cc_minor == 3) ||  // Erista
2337                    (cc_major == 6 && cc_minor == 2) ||  // Parker
2338                    (cc_major == 7 && cc_minor == 2));   // Xavier
2339   if (!is_tegra) {
2340     // ensure that future CUDA versions just work (even if suboptimal)
2341     const int cuda_major = std::min(10, CUDA_VERSION / 1000);
2342     // clang-format off
2343     switch (cuda_major) {
2344       case 10: cc = std::min(cc, 75); break; // Turing
2345       case  9: cc = std::min(cc, 70); break; // Volta
2346       case  8: cc = std::min(cc, 61); break; // Pascal
2347       case  7: cc = std::min(cc, 52); break; // Maxwell
2348       default:
2349         throw std::runtime_error("Unexpected CUDA major version " +
2350                                  std::to_string(cuda_major));
2351     }
2352     // clang-format on
2353   }
2354 
2355   std::stringstream ss;
2356   ss << cc;
2357   options.push_back("-arch=compute_" + ss.str());
2358 }
2359 
// Makes sure the JIT compilation uses a usable C++ standard setting.
//
// The options are scanned from the back; the first option found that contains
// "-std" decides the outcome:
//   - if it contains "-std=c++98" it is erased (NVRTC doesn't support
//     specifying c++98 explicitly) and nothing is added;
//   - otherwise the options are left untouched.
// If no option contains "-std", "-std=c++11" is appended, because Jitify
// itself must be compiled with C++11 support.
inline void detect_and_add_cxx11_flag(std::vector<std::string>& options) {
  for (std::size_t idx = options.size(); idx-- > 0;) {
    const bool mentions_std =
        options[idx].find("-std") != std::string::npos;
    if (!mentions_std) {
      continue;
    }
    const bool is_cxx98 =
        options[idx].find("-std=c++98") != std::string::npos;
    if (is_cxx98) {
      options.erase(options.begin() + idx);
    }
    // Either way an explicit standard was given, so don't add a default.
    return;
  }
  options.push_back("-std=c++11");
}
2376 
// Sorts \p options into three output lists:
//   - options starting with "-l" go to \p linker_files (prefix stripped),
//   - options starting with "-L" go to \p linker_paths (prefix stripped),
//   - everything else is appended unchanged to \p compiler_options.
// The output vectors are appended to, not cleared. Options shorter than two
// characters (including empty strings) are treated as compiler options.
inline void split_compiler_and_linker_options(
    std::vector<std::string> options,
    std::vector<std::string>* compiler_options,
    std::vector<std::string>* linker_files,
    std::vector<std::string>* linker_paths) {
  for (std::string const& opt : options) {
    // compare() never throws for pos == 0, and substr(2) is only reached once
    // the two-character prefix has matched, so short options are safe.
    if (opt.compare(0, 2, "-l") == 0) {
      linker_files->push_back(opt.substr(2));
    } else if (opt.compare(0, 2, "-L") == 0) {
      linker_paths->push_back(opt.substr(2));
    } else {
      compiler_options->push_back(opt);
    }
  }
}
2395 
// Strips every option containing "-remove-unused-globals" from \p options.
// Returns true if at least one such option was present (and removed), false
// if the list was left unchanged. The relative order of the remaining options
// is preserved.
inline bool pop_remove_unused_globals_flag(std::vector<std::string>* options) {
  const auto is_flag = [](const std::string& opt) {
    return opt.find("-remove-unused-globals") != std::string::npos;
  };
  const auto kept_end =
      std::remove_if(options->begin(), options->end(), is_flag);
  const bool found = (kept_end != options->end());
  // No-op when nothing matched, since kept_end is then already end().
  options->erase(kept_end, options->end());
  return found;
}
2407 
// Extracts the declared variable name from a single-line PTX .global/.const
// declaration such as
//   ".global .align 4 .u32 x;"
//   ".global .align 4 .b8 arr[16];"
//   ".global .align 4 .u32 x = 5;"
// The name is the run of non-whitespace characters that ends at the first
// '[', '=' or ';' (ignoring whitespace in front of that delimiter) and starts
// after the preceding space or tab.
//
// Throws std::runtime_error if the line contains none of the delimiters, or
// if no whitespace precedes the name.
inline std::string ptx_parse_decl_name(const std::string& line) {
  // '=' is included so that an initializer is not mistaken for the name.
  size_t name_end = line.find_first_of("[;=");
  if (name_end == std::string::npos) {
    throw std::runtime_error(
        "Failed to parse .global/.const declaration in PTX: expected a "
        "semicolon");
  }
  // Step back over whitespace separating the name from the delimiter.
  while (name_end > 0 &&
         (line[name_end - 1] == ' ' || line[name_end - 1] == '\t')) {
    --name_end;
  }
  size_t name_start_minus1 =
      name_end > 0 ? line.find_last_of(" \t", name_end - 1) : std::string::npos;
  if (name_start_minus1 == std::string::npos) {
    throw std::runtime_error(
        "Failed to parse .global/.const declaration in PTX: expected "
        "whitespace");
  }
  size_t name_start = name_start_minus1 + 1;
  return line.substr(name_start, name_end - name_start);
}
2425 
ptx_remove_unused_globals(std::string * ptx)2426 inline void ptx_remove_unused_globals(std::string* ptx) {
2427   std::istringstream iss(*ptx);
2428   std::vector<std::string> lines;
2429   std::unordered_map<size_t, std::string> line_num_to_global_name;
2430   std::unordered_set<std::string> name_set;
2431   for (std::string line; std::getline(iss, line);) {
2432     size_t line_num = lines.size();
2433     lines.push_back(line);
2434     auto terms = split_string(line);
2435     if (terms.size() <= 1) continue;  // Ignore lines with no arguments
2436     if (terms[0].substr(0, 2) == "//") continue;  // Ignore comment lines
2437     if (terms[0].substr(0, 7) == ".global" ||
2438         terms[0].substr(0, 6) == ".const") {
2439       line_num_to_global_name.emplace(line_num, ptx_parse_decl_name(line));
2440       continue;
2441     }
2442     if (terms[0][0] == '.') continue;  // Ignore .version, .reg, .param etc.
2443     // Note: The first term will always be an instruction name; starting at 1
2444     // also allows unchecked inspection of the previous term.
2445     for (int i = 1; i < (int)terms.size(); ++i) {
2446       if (terms[i].substr(0, 2) == "//") break;  // Ignore comments
2447       // Note: The characters '.' and '%' are not treated as delimiters.
2448       const char* token_delims = " \t()[]{},;+-*/~&|^?:=!<>\"'\\";
2449       for (auto token : split_string(terms[i], -1, token_delims)) {
2450         if (  // Ignore non-names
2451             !(std::isalpha(token[0]) || token[0] == '_' || token[0] == '$') ||
2452             token.find('.') != std::string::npos ||
2453             // Ignore variable/parameter declarations
2454             terms[i - 1][0] == '.' ||
2455             // Ignore branch instructions
2456             (token == "bra" && terms[i - 1][0] == '@') ||
2457             // Ignore branch labels
2458             (token.substr(0, 2) == "BB" &&
2459              terms[i - 1].substr(0, 3) == "bra")) {
2460           continue;
2461         }
2462         name_set.insert(token);
2463       }
2464     }
2465   }
2466   std::ostringstream oss;
2467   for (size_t line_num = 0; line_num < lines.size(); ++line_num) {
2468     auto it = line_num_to_global_name.find(line_num);
2469     if (it != line_num_to_global_name.end()) {
2470       const std::string& name = it->second;
2471       if (!name_set.count(name)) {
2472         continue;  // Remove unused .global declaration.
2473       }
2474     }
2475     oss << lines[line_num] << '\n';
2476   }
2477   *ptx = oss.str();
2478 }
2479 
compile_kernel(std::string program_name,std::map<std::string,std::string> sources,std::vector<std::string> options,std::string instantiation="",std::string * log=0,std::string * ptx=0,std::string * mangled_instantiation=0)2480 inline nvrtcResult compile_kernel(std::string program_name,
2481                                   std::map<std::string, std::string> sources,
2482                                   std::vector<std::string> options,
2483                                   std::string instantiation = "",
2484                                   std::string* log = 0, std::string* ptx = 0,
2485                                   std::string* mangled_instantiation = 0) {
2486   std::string program_source = sources[program_name];
2487   // Build arrays of header names and sources
2488   std::vector<const char*> header_names_c;
2489   std::vector<const char*> header_sources_c;
2490   int num_headers = (int)(sources.size() - 1);
2491   header_names_c.reserve(num_headers);
2492   header_sources_c.reserve(num_headers);
2493   typedef std::map<std::string, std::string> source_map;
2494   for (source_map::const_iterator iter = sources.begin(); iter != sources.end();
2495        ++iter) {
2496     std::string const& name = iter->first;
2497     std::string const& code = iter->second;
2498     if (name == program_name) {
2499       continue;
2500     }
2501     header_names_c.push_back(name.c_str());
2502     header_sources_c.push_back(code.c_str());
2503   }
2504 
2505   // TODO: This WAR is expected to be unnecessary as of CUDA > 10.2.
2506   bool should_remove_unused_globals =
2507       detail::pop_remove_unused_globals_flag(&options);
2508 
2509   std::vector<const char*> options_c(options.size() + 2);
2510   options_c[0] = "--device-as-default-execution-space";
2511   options_c[1] = "--pre-include=jitify_preinclude.h";
2512   for (int i = 0; i < (int)options.size(); ++i) {
2513     options_c[i + 2] = options[i].c_str();
2514   }
2515 
2516 #if CUDA_VERSION < 8000
2517   std::string inst_dummy;
2518   if (!instantiation.empty()) {
2519     // WAR for no nvrtcAddNameExpression before CUDA 8.0
2520     // Force template instantiation by adding dummy reference to kernel
2521     inst_dummy = "__jitify_instantiation";
2522     program_source +=
2523         "\nvoid* " + inst_dummy + " = (void*)" + instantiation + ";\n";
2524   }
2525 #endif
2526 
2527 #define CHECK_NVRTC(call)       \
2528   do {                          \
2529     nvrtcResult ret = call;     \
2530     if (ret != NVRTC_SUCCESS) { \
2531       return ret;               \
2532     }                           \
2533   } while (0)
2534 
2535   nvrtcProgram nvrtc_program;
2536   CHECK_NVRTC(nvrtcCreateProgram(
2537       &nvrtc_program, program_source.c_str(), program_name.c_str(), num_headers,
2538       header_sources_c.data(), header_names_c.data()));
2539 
2540 #if CUDA_VERSION >= 8000
2541   if (!instantiation.empty()) {
2542     CHECK_NVRTC(nvrtcAddNameExpression(nvrtc_program, instantiation.c_str()));
2543   }
2544 #endif
2545 
2546   nvrtcResult ret = nvrtcCompileProgram(nvrtc_program, (int)options_c.size(),
2547                                         options_c.data());
2548   if (log) {
2549     size_t logsize;
2550     CHECK_NVRTC(nvrtcGetProgramLogSize(nvrtc_program, &logsize));
2551     std::vector<char> vlog(logsize, 0);
2552     CHECK_NVRTC(nvrtcGetProgramLog(nvrtc_program, vlog.data()));
2553     log->assign(vlog.data(), logsize);
2554   }
2555   if (ret != NVRTC_SUCCESS) {
2556     return ret;
2557   }
2558 
2559   if (ptx) {
2560     size_t ptxsize;
2561     CHECK_NVRTC(nvrtcGetPTXSize(nvrtc_program, &ptxsize));
2562     std::vector<char> vptx(ptxsize);
2563     CHECK_NVRTC(nvrtcGetPTX(nvrtc_program, vptx.data()));
2564     ptx->assign(vptx.data(), ptxsize);
2565     if (should_remove_unused_globals) {
2566       detail::ptx_remove_unused_globals(ptx);
2567     }
2568   }
2569 
2570   if (!instantiation.empty() && mangled_instantiation) {
2571 #if CUDA_VERSION >= 8000
2572     const char* mangled_instantiation_cstr;
2573     // Note: The returned string pointer becomes invalid after
2574     //         nvrtcDestroyProgram has been called, so we save it.
2575     CHECK_NVRTC(nvrtcGetLoweredName(nvrtc_program, instantiation.c_str(),
2576                                     &mangled_instantiation_cstr));
2577     *mangled_instantiation = mangled_instantiation_cstr;
2578 #else
2579     // Extract mangled kernel template instantiation from PTX
2580     inst_dummy += " = ";  // Note: This must match how the PTX is generated
2581     int mi_beg = ptx->find(inst_dummy) + inst_dummy.size();
2582     int mi_end = ptx->find(";", mi_beg);
2583     *mangled_instantiation = ptx->substr(mi_beg, mi_end - mi_beg);
2584 #endif
2585   }
2586 
2587   CHECK_NVRTC(nvrtcDestroyProgram(&nvrtc_program));
2588 #undef CHECK_NVRTC
2589   return NVRTC_SUCCESS;
2590 }
2591 
load_program(std::string const & cuda_source,std::vector<std::string> const & headers,file_callback_type file_callback,std::vector<std::string> * include_paths,std::map<std::string,std::string> * program_sources,std::vector<std::string> * program_options,std::string * program_name)2592 inline void load_program(std::string const& cuda_source,
2593                          std::vector<std::string> const& headers,
2594                          file_callback_type file_callback,
2595                          std::vector<std::string>* include_paths,
2596                          std::map<std::string, std::string>* program_sources,
2597                          std::vector<std::string>* program_options,
2598                          std::string* program_name) {
2599   // Extract include paths from compile options
2600   std::vector<std::string>::iterator iter = program_options->begin();
2601   while (iter != program_options->end()) {
2602     std::string const& opt = *iter;
2603     if (opt.substr(0, 2) == "-I") {
2604       include_paths->push_back(opt.substr(2));
2605       iter = program_options->erase(iter);
2606     } else {
2607       ++iter;
2608     }
2609   }
2610 
2611   // Load program source
2612   if (!detail::load_source(cuda_source, *program_sources, "", *include_paths,
2613                            file_callback)) {
2614     throw std::runtime_error("Source not found: " + cuda_source);
2615   }
2616   *program_name = program_sources->begin()->first;
2617 
2618   // Maps header include names to their full file paths.
2619   std::map<std::string, std::string> header_fullpaths;
2620 
2621   // Load header sources
2622   for (std::string const& header : headers) {
2623     if (!detail::load_source(header, *program_sources, "", *include_paths,
2624                              file_callback, &header_fullpaths)) {
2625       // **TODO: Deal with source not found
2626       throw std::runtime_error("Source not found: " + header);
2627     }
2628   }
2629 
2630 #if JITIFY_PRINT_SOURCE
2631   std::string& program_source = (*program_sources)[*program_name];
2632   std::cout << "---------------------------------------" << std::endl;
2633   std::cout << "--- Source of " << *program_name << " ---" << std::endl;
2634   std::cout << "---------------------------------------" << std::endl;
2635   detail::print_with_line_numbers(program_source);
2636   std::cout << "---------------------------------------" << std::endl;
2637 #endif
2638 
2639   std::vector<std::string> compiler_options, linker_files, linker_paths;
2640   detail::split_compiler_and_linker_options(*program_options, &compiler_options,
2641                                             &linker_files, &linker_paths);
2642 
2643   // If no arch is specified at this point we use whatever the current
2644   // context is. This ensures we pick up the correct internal headers
2645   // for arch-dependent compilation, e.g., some intrinsics are only
2646   // present for specific architectures.
2647   detail::detect_and_add_cuda_arch(compiler_options);
2648   detail::detect_and_add_cxx11_flag(compiler_options);
2649 
2650   // Iteratively try to compile the sources, and use the resulting errors to
2651   // identify missing headers.
2652   std::string log;
2653   nvrtcResult ret;
2654   while ((ret = detail::compile_kernel(*program_name, *program_sources,
2655                                        compiler_options, "", &log)) ==
2656          NVRTC_ERROR_COMPILATION) {
2657     std::string include_name;
2658     std::string include_parent;
2659     int line_num = 0;
2660     if (!detail::extract_include_info_from_compile_error(
2661             log, include_name, include_parent, line_num)) {
2662 #if JITIFY_PRINT_LOG
2663       detail::print_compile_log(*program_name, log);
2664 #endif
2665       // There was a non include-related compilation error
2666       // TODO: How to handle error?
2667       throw std::runtime_error("Runtime compilation failed");
2668     }
2669 
2670     bool is_included_with_quotes = false;
2671     if (program_sources->count(include_parent)) {
2672       const std::string& parent_source = (*program_sources)[include_parent];
2673       is_included_with_quotes =
2674           is_include_directive_with_quotes(parent_source, line_num);
2675     }
2676 
2677     // Try to load the new header
2678     // Note: This fullpath lookup is needed because the compiler error
2679     // messages have the include name of the header instead of its full path.
2680     std::string include_parent_fullpath = header_fullpaths[include_parent];
2681     std::string include_path = detail::path_base(include_parent_fullpath);
2682     if (detail::load_source(include_name, *program_sources, include_path,
2683                             *include_paths, file_callback, &header_fullpaths,
2684                             is_included_with_quotes)) {
2685 #if JITIFY_PRINT_HEADER_PATHS
2686       std::cout << "Found #include " << include_name << " from "
2687                 << include_parent << ":" << line_num << " ["
2688                 << include_parent_fullpath << "]"
2689                 << " at:\n  " << header_fullpaths[include_name] << std::endl;
2690 #endif
2691     } else {  // Failed to find header file.
2692       // Comment-out the include line and print a warning
2693       if (!program_sources->count(include_parent)) {
2694         // ***TODO: Unless there's another mechanism (e.g., potentially
2695         //            the parent path vs. filename problem), getting
2696         //            here means include_parent was found automatically
2697         //            in a system include path.
2698         //            We need a WAR to zap it from *its parent*.
2699 
2700         typedef std::map<std::string, std::string> source_map;
2701         for (source_map::const_iterator it = program_sources->begin();
2702              it != program_sources->end(); ++it) {
2703           std::cout << "  " << it->first << std::endl;
2704         }
2705         throw std::out_of_range(include_parent +
2706                                 " not in loaded sources!"
2707                                 " This may be due to a header being loaded by"
2708                                 " NVRTC without Jitify's knowledge.");
2709       }
2710       std::string& parent_source = (*program_sources)[include_parent];
2711       parent_source = detail::comment_out_code_line(line_num, parent_source);
2712 #if JITIFY_PRINT_LOG
2713       std::cout << include_parent << "(" << line_num
2714                 << "): warning: " << include_name << ": [jitify] File not found"
2715                 << std::endl;
2716 #endif
2717     }
2718   }
2719   if (ret != NVRTC_SUCCESS) {
2720 #if JITIFY_PRINT_LOG
2721     if (ret == NVRTC_ERROR_INVALID_OPTION) {
2722       std::cout << "Compiler options: ";
2723       for (int i = 0; i < (int)compiler_options.size(); ++i) {
2724         std::cout << compiler_options[i] << " ";
2725       }
2726       std::cout << std::endl;
2727     }
2728 #endif
2729     throw std::runtime_error(std::string("NVRTC error: ") +
2730                              nvrtcGetErrorString(ret));
2731   }
2732 }
2733 
instantiate_kernel(std::string const & program_name,std::map<std::string,std::string> const & program_sources,std::string const & instantiation,std::vector<std::string> const & options,std::string * log,std::string * ptx,std::string * mangled_instantiation,std::vector<std::string> * linker_files,std::vector<std::string> * linker_paths)2734 inline void instantiate_kernel(
2735     std::string const& program_name,
2736     std::map<std::string, std::string> const& program_sources,
2737     std::string const& instantiation, std::vector<std::string> const& options,
2738     std::string* log, std::string* ptx, std::string* mangled_instantiation,
2739     std::vector<std::string>* linker_files,
2740     std::vector<std::string>* linker_paths) {
2741   std::vector<std::string> compiler_options;
2742   detail::split_compiler_and_linker_options(options, &compiler_options,
2743                                             linker_files, linker_paths);
2744 
2745   nvrtcResult ret =
2746       detail::compile_kernel(program_name, program_sources, compiler_options,
2747                              instantiation, log, ptx, mangled_instantiation);
2748 #if JITIFY_PRINT_LOG
2749   if (log->size() > 1) {
2750     detail::print_compile_log(program_name, *log);
2751   }
2752 #endif
2753   if (ret != NVRTC_SUCCESS) {
2754     throw std::runtime_error(std::string("NVRTC error: ") +
2755                              nvrtcGetErrorString(ret));
2756   }
2757 
2758 #if JITIFY_PRINT_PTX
2759   std::cout << "---------------------------------------" << std::endl;
2760   std::cout << *mangled_instantiation << std::endl;
2761   std::cout << "---------------------------------------" << std::endl;
2762   std::cout << "--- PTX for " << mangled_instantiation << " in " << program_name
2763             << " ---" << std::endl;
2764   std::cout << "---------------------------------------" << std::endl;
2765   std::cout << *ptx << std::endl;
2766   std::cout << "---------------------------------------" << std::endl;
2767 #endif
2768 }
2769 
get_1d_max_occupancy(CUfunction func,CUoccupancyB2DSize smem_callback,unsigned int * smem,int max_block_size,unsigned int flags,int * grid,int * block)2770 inline void get_1d_max_occupancy(CUfunction func,
2771                                  CUoccupancyB2DSize smem_callback,
2772                                  unsigned int* smem, int max_block_size,
2773                                  unsigned int flags, int* grid, int* block) {
2774   if (!func) {
2775     throw std::runtime_error(
2776         "Kernel pointer is NULL; you may need to define JITIFY_THREAD_SAFE "
2777         "1");
2778   }
2779   CUresult res = cuOccupancyMaxPotentialBlockSizeWithFlags(
2780       grid, block, func, smem_callback, *smem, max_block_size, flags);
2781   if (res != CUDA_SUCCESS) {
2782     const char* msg;
2783     cuGetErrorName(res, &msg);
2784     throw std::runtime_error(msg);
2785   }
2786   if (smem_callback) {
2787     *smem = (unsigned int)smem_callback(*block);
2788   }
2789 }
2790 
2791 }  // namespace detail
2792 
2793 //! \endcond
2794 
2795 class KernelInstantiation;
2796 class Kernel;
2797 class Program;
2798 class JitCache;
2799 
// Plain data describing a loaded program: its options, include paths, name
// and sources. Program_impl exposes these fields through its accessors.
struct ProgramConfig {
  // Maps a source/header name to its source code.
  using source_map = std::map<std::string, std::string>;

  std::vector<std::string> options;
  std::vector<std::string> include_paths;
  std::string name;
  source_map sources;
};
2807 
2808 class JitCache_impl {
2809   friend class Program_impl;
2810   friend class KernelInstantiation_impl;
2811   friend class KernelLauncher_impl;
2812   typedef uint64_t key_type;
2813   jitify::ObjectCache<key_type, detail::CUDAKernel> _kernel_cache;
2814   jitify::ObjectCache<key_type, ProgramConfig> _program_config_cache;
2815   std::vector<std::string> _options;
2816 #if JITIFY_THREAD_SAFE
2817   std::mutex _kernel_cache_mutex;
2818   std::mutex _program_cache_mutex;
2819 #endif
2820  public:
JitCache_impl(size_t cache_size)2821   inline JitCache_impl(size_t cache_size)
2822       : _kernel_cache(cache_size), _program_config_cache(cache_size) {
2823     detail::add_options_from_env(_options);
2824 
2825     // Bootstrap the cuda context to avoid errors
2826     cudaFree(0);
2827   }
2828 };
2829 
2830 class Program_impl {
2831   // A friendly class
2832   friend class Kernel_impl;
2833   friend class KernelLauncher_impl;
2834   friend class KernelInstantiation_impl;
2835   // TODO: This can become invalid if JitCache is destroyed before the
2836   //         Program object is. However, this can't happen if JitCache
2837   //           instances are static.
2838   JitCache_impl& _cache;
2839   uint64_t _hash;
2840   ProgramConfig* _config;
2841   void load_sources(std::string source, std::vector<std::string> headers,
2842                     std::vector<std::string> options,
2843                     file_callback_type file_callback);
2844 
2845  public:
2846   inline Program_impl(JitCache_impl& cache, std::string source,
2847                       jitify::detail::vector<std::string> headers = 0,
2848                       jitify::detail::vector<std::string> options = 0,
2849                       file_callback_type file_callback = 0);
2850   inline Program_impl(Program_impl const&) = default;
2851   inline Program_impl(Program_impl&&) = default;
options() const2852   inline std::vector<std::string> const& options() const {
2853     return _config->options;
2854   }
name() const2855   inline std::string const& name() const { return _config->name; }
sources() const2856   inline ProgramConfig::source_map const& sources() const {
2857     return _config->sources;
2858   }
include_paths() const2859   inline std::vector<std::string> const& include_paths() const {
2860     return _config->include_paths;
2861   }
2862 };
2863 
2864 class Kernel_impl {
2865   friend class KernelLauncher_impl;
2866   friend class KernelInstantiation_impl;
2867   Program_impl _program;
2868   std::string _name;
2869   std::vector<std::string> _options;
2870   uint64_t _hash;
2871 
2872  public:
2873   inline Kernel_impl(Program_impl const& program, std::string name,
2874                      jitify::detail::vector<std::string> options = 0);
2875   inline Kernel_impl(Kernel_impl const&) = default;
2876   inline Kernel_impl(Kernel_impl&&) = default;
2877 };
2878 
2879 class KernelInstantiation_impl {
2880   friend class KernelLauncher_impl;
2881   Kernel_impl _kernel;
2882   uint64_t _hash;
2883   std::string _template_inst;
2884   std::vector<std::string> _options;
2885   detail::CUDAKernel* _cuda_kernel;
2886   inline void print() const;
2887   void build_kernel();
2888 
2889  public:
2890   inline KernelInstantiation_impl(
2891       Kernel_impl const& kernel, std::vector<std::string> const& template_args);
2892   inline KernelInstantiation_impl(KernelInstantiation_impl const&) = default;
2893   inline KernelInstantiation_impl(KernelInstantiation_impl&&) = default;
cuda_kernel() const2894   detail::CUDAKernel const& cuda_kernel() const { return *_cuda_kernel; }
2895 };
2896 
2897 class KernelLauncher_impl {
2898   KernelInstantiation_impl _kernel_inst;
2899   dim3 _grid;
2900   dim3 _block;
2901   unsigned int _smem;
2902   cudaStream_t _stream;
2903 
2904  public:
KernelLauncher_impl(KernelInstantiation_impl const & kernel_inst,dim3 grid,dim3 block,unsigned int smem=0,cudaStream_t stream=0)2905   inline KernelLauncher_impl(KernelInstantiation_impl const& kernel_inst,
2906                              dim3 grid, dim3 block, unsigned int smem = 0,
2907                              cudaStream_t stream = 0)
2908       : _kernel_inst(kernel_inst),
2909         _grid(grid),
2910         _block(block),
2911         _smem(smem),
2912         _stream(stream) {}
2913   inline KernelLauncher_impl(KernelLauncher_impl const&) = default;
2914   inline KernelLauncher_impl(KernelLauncher_impl&&) = default;
2915   inline CUresult launch(
2916       jitify::detail::vector<void*> arg_ptrs,
2917       jitify::detail::vector<std::string> arg_types = 0) const;
2918 };
2919 
2920 /*! An object representing a configured and instantiated kernel ready
2921  *    for launching.
2922  */
2923 class KernelLauncher {
2924   std::unique_ptr<KernelLauncher_impl const> _impl;
2925 
2926  public:
2927   inline KernelLauncher(KernelInstantiation const& kernel_inst, dim3 grid,
2928                         dim3 block, unsigned int smem = 0,
2929                         cudaStream_t stream = 0);
2930 
2931   // Note: It's important that there is no implicit conversion required
2932   //         for arg_ptrs, because otherwise the parameter pack version
2933   //         below gets called instead (probably resulting in a segfault).
2934   /*! Launch the kernel.
2935    *
2936    *  \param arg_ptrs  A vector of pointers to each function argument for the
2937    *    kernel.
2938    *  \param arg_types A vector of function argument types represented
2939    *    as code-strings. This parameter is optional and is only used to print
2940    *    out the function signature.
2941    */
launch(std::vector<void * > arg_ptrs=std::vector<void * > (),jitify::detail::vector<std::string> arg_types=0) const2942   inline CUresult launch(
2943       std::vector<void*> arg_ptrs = std::vector<void*>(),
2944       jitify::detail::vector<std::string> arg_types = 0) const {
2945     return _impl->launch(arg_ptrs, arg_types);
2946   }
2947   // Regular function call syntax
2948   /*! Launch the kernel.
2949    *
2950    *  \see launch
2951    */
2952   template <typename... ArgTypes>
operator ()(ArgTypes...args) const2953   inline CUresult operator()(ArgTypes... args) const {
2954     return this->launch(args...);
2955   }
2956   /*! Launch the kernel.
2957    *
2958    *  \param args Function arguments for the kernel.
2959    */
2960   template <typename... ArgTypes>
launch(ArgTypes...args) const2961   inline CUresult launch(ArgTypes... args) const {
2962     return this->launch(std::vector<void*>({(void*)&args...}),
2963                         {reflection::reflect<ArgTypes>()...});
2964   }
2965 };
2966 
2967 /*! An object representing a kernel instantiation made up of a Kernel and
2968  *    template arguments.
2969  */
2970 class KernelInstantiation {
2971   friend class KernelLauncher;
2972   std::unique_ptr<KernelInstantiation_impl const> _impl;
2973 
2974  public:
2975   inline KernelInstantiation(Kernel const& kernel,
2976                              std::vector<std::string> const& template_args);
2977 
2978   /*! Implicit conversion to the underlying CUfunction object.
2979    *
2980    * \note This allows use of CUDA APIs like
2981    *   cuOccupancyMaxActiveBlocksPerMultiprocessor.
2982    */
operator CUfunction() const2983   inline operator CUfunction() const { return _impl->cuda_kernel(); }
2984 
2985   /*! Configure the kernel launch.
2986    *
2987    *  \see configure
2988    */
operator ()(dim3 grid,dim3 block,unsigned int smem=0,cudaStream_t stream=0) const2989   inline KernelLauncher operator()(dim3 grid, dim3 block, unsigned int smem = 0,
2990                                    cudaStream_t stream = 0) const {
2991     return this->configure(grid, block, smem, stream);
2992   }
2993   /*! Configure the kernel launch.
2994    *
2995    *  \param grid   The thread grid dimensions for the launch.
2996    *  \param block  The thread block dimensions for the launch.
2997    *  \param smem   The amount of shared memory to dynamically allocate, in
2998    * bytes.
2999    *  \param stream The CUDA stream to launch the kernel in.
3000    */
configure(dim3 grid,dim3 block,unsigned int smem=0,cudaStream_t stream=0) const3001   inline KernelLauncher configure(dim3 grid, dim3 block, unsigned int smem = 0,
3002                                   cudaStream_t stream = 0) const {
3003     return KernelLauncher(*this, grid, block, smem, stream);
3004   }
3005   /*! Configure the kernel launch with a 1-dimensional block and grid chosen
3006    *  automatically to maximise occupancy.
3007    *
3008    * \param max_block_size  The upper limit on the block size, or 0 for no
3009    * limit.
3010    * \param smem  The amount of shared memory to dynamically allocate, in bytes.
3011    * \param smem_callback  A function returning smem for a given block size (overrides \p smem).
3012    * \param stream The CUDA stream to launch the kernel in.
3013    * \param flags The flags to pass to cuOccupancyMaxPotentialBlockSizeWithFlags.
3014    */
configure_1d_max_occupancy(int max_block_size=0,unsigned int smem=0,CUoccupancyB2DSize smem_callback=0,cudaStream_t stream=0,unsigned int flags=0) const3015   inline KernelLauncher configure_1d_max_occupancy(
3016       int max_block_size = 0, unsigned int smem = 0,
3017       CUoccupancyB2DSize smem_callback = 0, cudaStream_t stream = 0,
3018       unsigned int flags = 0) const {
3019     int grid;
3020     int block;
3021     CUfunction func = _impl->cuda_kernel();
3022     detail::get_1d_max_occupancy(func, smem_callback, &smem, max_block_size,
3023                                  flags, &grid, &block);
3024     return this->configure(grid, block, smem, stream);
3025   }
3026 
3027   /*
3028    * \deprecated Use \p get_global_ptr instead.
3029    */
get_constant_ptr(const char * name,size_t * size=nullptr) const3030   inline CUdeviceptr get_constant_ptr(const char* name,
3031                                       size_t* size = nullptr) const {
3032     return get_global_ptr(name, size);
3033   }
3034 
3035   /*
3036    * Get a device pointer to a global __constant__ or __device__ variable using
3037    * its un-mangled name. If provided, *size is set to the size of the variable
3038    * in bytes.
3039    */
get_global_ptr(const char * name,size_t * size=nullptr) const3040   inline CUdeviceptr get_global_ptr(const char* name,
3041                                     size_t* size = nullptr) const {
3042     return _impl->cuda_kernel().get_global_ptr(name, size);
3043   }
3044 
3045   /*
3046    * Copy data from a global __constant__ or __device__ array to the host using
3047    * its un-mangled name.
3048    */
3049   template <typename T>
get_global_array(const char * name,T * data,size_t count,CUstream stream=0) const3050   inline CUresult get_global_array(const char* name, T* data, size_t count,
3051                                    CUstream stream = 0) const {
3052     return _impl->cuda_kernel().get_global_data(name, data, count, stream);
3053   }
3054 
3055   /*
3056    * Copy a value from a global __constant__ or __device__ variable to the host
3057    * using its un-mangled name.
3058    */
3059   template <typename T>
get_global_value(const char * name,T * value,CUstream stream=0) const3060   inline CUresult get_global_value(const char* name, T* value,
3061                                    CUstream stream = 0) const {
3062     return get_global_array(name, value, 1, stream);
3063   }
3064 
3065   /*
3066    * Copy data from the host to a global __constant__ or __device__ array using
3067    * its un-mangled name.
3068    */
3069   template <typename T>
set_global_array(const char * name,const T * data,size_t count,CUstream stream=0) const3070   inline CUresult set_global_array(const char* name, const T* data,
3071                                    size_t count, CUstream stream = 0) const {
3072     return _impl->cuda_kernel().set_global_data(name, data, count, stream);
3073   }
3074 
3075   /*
3076    * Copy a value from the host to a global __constant__ or __device__ variable
3077    * using its un-mangled name.
3078    */
3079   template <typename T>
set_global_value(const char * name,const T & value,CUstream stream=0) const3080   inline CUresult set_global_value(const char* name, const T& value,
3081                                    CUstream stream = 0) const {
3082     return set_global_array(name, &value, 1, stream);
3083   }
3084 
mangled_name() const3085   const std::string& mangled_name() const {
3086     return _impl->cuda_kernel().function_name();
3087   }
3088 
ptx() const3089   const std::string& ptx() const { return _impl->cuda_kernel().ptx(); }
3090 
link_files() const3091   const std::vector<std::string>& link_files() const {
3092     return _impl->cuda_kernel().link_files();
3093   }
3094 
link_paths() const3095   const std::vector<std::string>& link_paths() const {
3096     return _impl->cuda_kernel().link_paths();
3097   }
3098 };
3099 
3100 /*! An object representing a kernel made up of a Program, a name and options.
3101  */
3102 class Kernel {
3103   friend class KernelInstantiation;
3104   std::unique_ptr<Kernel_impl const> _impl;
3105 
3106  public:
3107   Kernel(Program const& program, std::string name,
3108          jitify::detail::vector<std::string> options = 0);
3109 
3110   /*! Instantiate the kernel.
3111    *
3112    *  \param template_args A vector of template arguments represented as
3113    *    code-strings. These can be generated using
3114    *    \code{.cpp}jitify::reflection::reflect<type>()\endcode or
3115    *    \code{.cpp}jitify::reflection::reflect(value)\endcode
3116    *
3117    *  \note Template type deduction is not possible, so all types must be
3118    *    explicitly specified.
3119    */
3120   // inline KernelInstantiation instantiate(std::vector<std::string> const&
3121   // template_args) const {
instantiate(std::vector<std::string> const & template_args=std::vector<std::string> ()) const3122   inline KernelInstantiation instantiate(
3123       std::vector<std::string> const& template_args =
3124           std::vector<std::string>()) const {
3125     return KernelInstantiation(*this, template_args);
3126   }
3127 
3128   // Regular template instantiation syntax (note limited flexibility)
3129   /*! Instantiate the kernel.
3130    *
3131    *  \note The template arguments specified on this function are
3132    *    used to instantiate the kernel. Non-type template arguments must
3133    *    be wrapped with
3134    *    \code{.cpp}jitify::reflection::NonType<type,value>\endcode
3135    *
3136    *  \note Template type deduction is not possible, so all types must be
3137    *    explicitly specified.
3138    */
3139   template <typename... TemplateArgs>
instantiate() const3140   inline KernelInstantiation instantiate() const {
3141     return this->instantiate(
3142         std::vector<std::string>({reflection::reflect<TemplateArgs>()...}));
3143   }
3144   // Template-like instantiation syntax
3145   //   E.g., instantiate(myvar,Type<MyType>())(grid,block)
3146   /*! Instantiate the kernel.
3147    *
3148    *  \param targs The template arguments for the kernel, represented as
3149    *    values. Types must be wrapped with
3150    *    \code{.cpp}jitify::reflection::Type<type>()\endcode or
3151    *    \code{.cpp}jitify::reflection::type_of(value)\endcode
3152    *
3153    *  \note Template type deduction is not possible, so all types must be
3154    *    explicitly specified.
3155    */
3156   template <typename... TemplateArgs>
instantiate(TemplateArgs...targs) const3157   inline KernelInstantiation instantiate(TemplateArgs... targs) const {
3158     return this->instantiate(
3159         std::vector<std::string>({reflection::reflect(targs)...}));
3160   }
3161 };
3162 
3163 /*! An object representing a program made up of source code, headers
3164  *    and options.
3165  */
3166 class Program {
3167   friend class Kernel;
3168   std::unique_ptr<Program_impl const> _impl;
3169 
3170  public:
3171   Program(JitCache& cache, std::string source,
3172           jitify::detail::vector<std::string> headers = 0,
3173           jitify::detail::vector<std::string> options = 0,
3174           file_callback_type file_callback = 0);
3175 
3176   /*! Select a kernel.
3177    *
3178    * \param name The name of the kernel (unmangled and without
3179    * template arguments).
3180    * \param options A vector of options to be passed to the NVRTC
3181    * compiler when compiling this kernel.
3182    */
kernel(std::string name,jitify::detail::vector<std::string> options=0) const3183   inline Kernel kernel(std::string name,
3184                        jitify::detail::vector<std::string> options = 0) const {
3185     return Kernel(*this, name, options);
3186   }
3187   /*! Select a kernel.
3188    *
3189    *  \see kernel
3190    */
operator ()(std::string name,jitify::detail::vector<std::string> options=0) const3191   inline Kernel operator()(
3192       std::string name, jitify::detail::vector<std::string> options = 0) const {
3193     return this->kernel(name, options);
3194   }
3195 };
3196 
3197 /*! An object that manages a cache of JIT-compiled CUDA kernels.
3198  *
3199  */
3200 class JitCache {
3201   friend class Program;
3202   std::unique_ptr<JitCache_impl> _impl;
3203 
3204  public:
3205   /*! JitCache constructor.
3206    *  \param cache_size The number of kernels to hold in the cache
3207    *    before overwriting the least-recently-used ones.
3208    */
3209   enum { DEFAULT_CACHE_SIZE = 128 };
JitCache(size_t cache_size=DEFAULT_CACHE_SIZE)3210   JitCache(size_t cache_size = DEFAULT_CACHE_SIZE)
3211       : _impl(new JitCache_impl(cache_size)) {}
3212 
3213   /*! Create a program.
3214    *
3215    *  \param source A string containing either the source filename or
3216    *    the source itself; in the latter case, the first line must be
3217    *    the name of the program.
3218    *  \param headers A vector of strings representing the source of
3219    *    each header file required by the program. Each entry can be
3220    *    either the header filename or the header source itself; in
3221    *    the latter case, the first line must be the name of the header
3222    *    (i.e., the name by which the header is #included).
3223    *  \param options A vector of options to be passed to the
3224    *    NVRTC compiler. Include paths specified with \p -I
3225    *    are added to the search paths used by Jitify. The environment
3226    *    variable JITIFY_OPTIONS can also be used to define additional
3227    *    options.
3228    *  \param file_callback A pointer to a callback function that is
3229    *    invoked whenever a source file needs to be loaded. Inside this
3230    *    function, the user can either load/specify the source themselves
3231    *    or defer to Jitify's file-loading mechanisms.
3232    *  \note Program or header source files referenced by filename are
3233    *  looked-up using the following mechanisms (in this order):
3234    *  \note 1) By calling file_callback.
3235    *  \note 2) By looking for the file embedded in the executable via the GCC
3236    * linker.
3237    *  \note 3) By looking for the file in the filesystem.
3238    *
3239    *  \note Jitify recursively scans all source files for \p #include
3240    *  directives and automatically adds them to the set of headers needed
3241    *  by the program.
3242    *  If a \p #include directive references a header that cannot be found,
3243    *  the directive is automatically removed from the source code to prevent
3244    *  immediate compilation failure. This may result in compilation errors
3245    *  if the header was required by the program.
3246    *
3247    *  \note Jitify automatically includes NVRTC-safe versions of some
3248    *  standard library headers.
3249    */
program(std::string source,jitify::detail::vector<std::string> headers=0,jitify::detail::vector<std::string> options=0,file_callback_type file_callback=0)3250   inline Program program(std::string source,
3251                          jitify::detail::vector<std::string> headers = 0,
3252                          jitify::detail::vector<std::string> options = 0,
3253                          file_callback_type file_callback = 0) {
3254     return Program(*this, source, headers, options, file_callback);
3255   }
3256 };
3257 
Program(JitCache & cache,std::string source,jitify::detail::vector<std::string> headers,jitify::detail::vector<std::string> options,file_callback_type file_callback)3258 inline Program::Program(JitCache& cache, std::string source,
3259                         jitify::detail::vector<std::string> headers,
3260                         jitify::detail::vector<std::string> options,
3261                         file_callback_type file_callback)
3262     : _impl(new Program_impl(*cache._impl, source, headers, options,
3263                              file_callback)) {}
3264 
Kernel(Program const & program,std::string name,jitify::detail::vector<std::string> options)3265 inline Kernel::Kernel(Program const& program, std::string name,
3266                       jitify::detail::vector<std::string> options)
3267     : _impl(new Kernel_impl(*program._impl, name, options)) {}
3268 
KernelInstantiation(Kernel const & kernel,std::vector<std::string> const & template_args)3269 inline KernelInstantiation::KernelInstantiation(
3270     Kernel const& kernel, std::vector<std::string> const& template_args)
3271     : _impl(new KernelInstantiation_impl(*kernel._impl, template_args)) {}
3272 
KernelLauncher(KernelInstantiation const & kernel_inst,dim3 grid,dim3 block,unsigned int smem,cudaStream_t stream)3273 inline KernelLauncher::KernelLauncher(KernelInstantiation const& kernel_inst,
3274                                       dim3 grid, dim3 block, unsigned int smem,
3275                                       cudaStream_t stream)
3276     : _impl(new KernelLauncher_impl(*kernel_inst._impl, grid, block, smem,
3277                                     stream)) {}
3278 
operator <<(std::ostream & stream,dim3 d)3279 inline std::ostream& operator<<(std::ostream& stream, dim3 d) {
3280   if (d.y == 1 && d.z == 1) {
3281     stream << d.x;
3282   } else {
3283     stream << "(" << d.x << "," << d.y << "," << d.z << ")";
3284   }
3285   return stream;
3286 }
3287 
launch(jitify::detail::vector<void * > arg_ptrs,jitify::detail::vector<std::string> arg_types) const3288 inline CUresult KernelLauncher_impl::launch(
3289     jitify::detail::vector<void*> arg_ptrs,
3290     jitify::detail::vector<std::string> arg_types) const {
3291 #if JITIFY_PRINT_LAUNCH
3292   Kernel_impl const& kernel = _kernel_inst._kernel;
3293   std::string arg_types_string =
3294       (arg_types.empty() ? "..." : reflection::reflect_list(arg_types));
3295   std::cout << "Launching " << kernel._name << _kernel_inst._template_inst
3296             << "<<<" << _grid << "," << _block << "," << _smem << "," << _stream
3297             << ">>>"
3298             << "(" << arg_types_string << ")" << std::endl;
3299 #endif
3300   if (!_kernel_inst._cuda_kernel) {
3301     throw std::runtime_error(
3302         "Kernel pointer is NULL; you may need to define JITIFY_THREAD_SAFE 1");
3303   }
3304   return _kernel_inst._cuda_kernel->launch(_grid, _block, _smem, _stream,
3305                                            arg_ptrs);
3306 }
3307 
KernelInstantiation_impl(Kernel_impl const & kernel,std::vector<std::string> const & template_args)3308 inline KernelInstantiation_impl::KernelInstantiation_impl(
3309     Kernel_impl const& kernel, std::vector<std::string> const& template_args)
3310     : _kernel(kernel), _options(kernel._options) {
3311   _template_inst =
3312       (template_args.empty() ? ""
3313                              : reflection::reflect_template(template_args));
3314   using detail::hash_combine;
3315   using detail::hash_larson64;
3316   _hash = _kernel._hash;
3317   _hash = hash_combine(_hash, hash_larson64(_template_inst.c_str()));
3318   JitCache_impl& cache = _kernel._program._cache;
3319   uint64_t cache_key = _hash;
3320 #if JITIFY_THREAD_SAFE
3321   std::lock_guard<std::mutex> lock(cache._kernel_cache_mutex);
3322 #endif
3323   if (cache._kernel_cache.contains(cache_key)) {
3324 #if JITIFY_PRINT_INSTANTIATION
3325     std::cout << "Found ";
3326     this->print();
3327 #endif
3328     _cuda_kernel = &cache._kernel_cache.get(cache_key);
3329   } else {
3330 #if JITIFY_PRINT_INSTANTIATION
3331     std::cout << "Building ";
3332     this->print();
3333 #endif
3334     _cuda_kernel = &cache._kernel_cache.emplace(cache_key);
3335     this->build_kernel();
3336   }
3337 }
3338 
print() const3339 inline void KernelInstantiation_impl::print() const {
3340   std::string options_string = reflection::reflect_list(_options);
3341   std::cout << _kernel._name << _template_inst << " [" << options_string << "]"
3342             << std::endl;
3343 }
3344 
build_kernel()3345 inline void KernelInstantiation_impl::build_kernel() {
3346   Program_impl const& program = _kernel._program;
3347 
3348   std::string instantiation = _kernel._name + _template_inst;
3349 
3350   std::string log, ptx, mangled_instantiation;
3351   std::vector<std::string> linker_files, linker_paths;
3352   detail::instantiate_kernel(program.name(), program.sources(), instantiation,
3353                              _options, &log, &ptx, &mangled_instantiation,
3354                              &linker_files, &linker_paths);
3355 
3356   _cuda_kernel->set(mangled_instantiation.c_str(), ptx.c_str(), linker_files,
3357                     linker_paths);
3358 }
3359 
Kernel_impl(Program_impl const & program,std::string name,jitify::detail::vector<std::string> options)3360 Kernel_impl::Kernel_impl(Program_impl const& program, std::string name,
3361                          jitify::detail::vector<std::string> options)
3362     : _program(program), _name(name), _options(options) {
3363   // Merge options from parent
3364   _options.insert(_options.end(), _program.options().begin(),
3365                   _program.options().end());
3366   detail::detect_and_add_cuda_arch(_options);
3367   detail::detect_and_add_cxx11_flag(_options);
3368   std::string options_string = reflection::reflect_list(_options);
3369   using detail::hash_combine;
3370   using detail::hash_larson64;
3371   _hash = _program._hash;
3372   _hash = hash_combine(_hash, hash_larson64(_name.c_str()));
3373   _hash = hash_combine(_hash, hash_larson64(options_string.c_str()));
3374 }
3375 
Program_impl(JitCache_impl & cache,std::string source,jitify::detail::vector<std::string> headers,jitify::detail::vector<std::string> options,file_callback_type file_callback)3376 Program_impl::Program_impl(JitCache_impl& cache, std::string source,
3377                            jitify::detail::vector<std::string> headers,
3378                            jitify::detail::vector<std::string> options,
3379                            file_callback_type file_callback)
3380     : _cache(cache) {
3381   // Compute hash of source, headers and options
3382   std::string options_string = reflection::reflect_list(options);
3383   using detail::hash_combine;
3384   using detail::hash_larson64;
3385   _hash = hash_combine(hash_larson64(source.c_str()),
3386                        hash_larson64(options_string.c_str()));
3387   for (size_t i = 0; i < headers.size(); ++i) {
3388     _hash = hash_combine(_hash, hash_larson64(headers[i].c_str()));
3389   }
3390   _hash = hash_combine(_hash, (uint64_t)file_callback);
3391   // Add pre-include built-in JIT-safe headers
3392   for (int i = 0; i < detail::preinclude_jitsafe_headers_count; ++i) {
3393     const char* hdr_name = detail::preinclude_jitsafe_header_names[i];
3394     const std::string& hdr_source =
3395         detail::get_jitsafe_headers_map().at(hdr_name);
3396     headers.push_back(std::string(hdr_name) + "\n" + hdr_source);
3397   }
3398   // Merge options from parent
3399   options.insert(options.end(), _cache._options.begin(), _cache._options.end());
3400   // Load sources
3401 #if JITIFY_THREAD_SAFE
3402   std::lock_guard<std::mutex> lock(cache._program_cache_mutex);
3403 #endif
3404   if (!cache._program_config_cache.contains(_hash)) {
3405     _config = &cache._program_config_cache.insert(_hash);
3406     this->load_sources(source, headers, options, file_callback);
3407   } else {
3408     _config = &cache._program_config_cache.get(_hash);
3409   }
3410 }
3411 
load_sources(std::string source,std::vector<std::string> headers,std::vector<std::string> options,file_callback_type file_callback)3412 inline void Program_impl::load_sources(std::string source,
3413                                        std::vector<std::string> headers,
3414                                        std::vector<std::string> options,
3415                                        file_callback_type file_callback) {
3416   _config->options = options;
3417   detail::load_program(source, headers, file_callback, &_config->include_paths,
3418                        &_config->sources, &_config->options, &_config->name);
3419 }
3420 
/*! Where an algorithm executes; stored in ExecutionPolicy::location and
 *  tested by parallel_for. */
enum Location { HOST, DEVICE };
3422 
/*! Specifies location and parameters for execution of an algorithm.
 *  \param location      Whether to execute on the HOST or the DEVICE.
 *  \param headers       A vector of headers to include in the code.
 *  \param options       Options to pass to the NVRTC compiler.
 *  \param file_callback See jitify::Program.
 *  \param stream        The CUDA stream on which to execute.
 *  \param device        The CUDA device on which to execute.
 *  \param block_size    The size of the CUDA thread block with which to
 * execute.
 *  \param cache_size    The number of kernels to store in the cache
 * before overwriting the least-recently-used ones.
 */
struct ExecutionPolicy {
  /*! Location (HOST or DEVICE) on which to execute.*/
  Location location;
  /*! List of headers to include when compiling the algorithm.*/
  std::vector<std::string> headers;
  /*! List of compiler options.*/
  std::vector<std::string> options;
  /*! Optional callback for loading source files.*/
  file_callback_type file_callback;
  /*! CUDA stream on which to execute.*/
  cudaStream_t stream;
  /*! CUDA device on which to execute.*/
  int device;
  /*! CUDA block size with which to execute.*/
  int block_size;
  /*! The number of instantiations to store in the cache before overwriting
   *  the least-recently-used ones.*/
  size_t cache_size;
  /*! Construct a policy; every argument is optional and is copied into the
   *  member of the same name (without the trailing underscore).*/
  ExecutionPolicy(Location location_ = DEVICE,
                  jitify::detail::vector<std::string> headers_ = 0,
                  jitify::detail::vector<std::string> options_ = 0,
                  file_callback_type file_callback_ = 0,
                  cudaStream_t stream_ = 0, int device_ = 0,
                  int block_size_ = 256,
                  size_t cache_size_ = JitCache::DEFAULT_CACHE_SIZE)
      : location(location_),
        headers(headers_),
        options(options_),
        file_callback(file_callback_),
        stream(stream_),
        device(device_),
        block_size(block_size_),
        cache_size(cache_size_) {}
};
3467 
3468 template <class Func>
3469 class Lambda;
3470 
3471 /*! An object that captures a set of variables for use in a parallel_for
3472  *    expression. See JITIFY_CAPTURE().
3473  */
3474 class Capture {
3475  public:
3476   std::vector<std::string> _arg_decls;
3477   std::vector<void*> _arg_ptrs;
3478 
3479  public:
3480   template <typename... Args>
Capture(std::vector<std::string> arg_names,Args const &...args)3481   inline Capture(std::vector<std::string> arg_names, Args const&... args)
3482       : _arg_ptrs{(void*)&args...} {
3483     std::vector<std::string> arg_types = {reflection::reflect<Args>()...};
3484     _arg_decls.resize(arg_names.size());
3485     for (int i = 0; i < (int)arg_names.size(); ++i) {
3486       _arg_decls[i] = arg_types[i] + " " + arg_names[i];
3487     }
3488   }
3489 };
3490 
3491 /*! An object that captures the instantiated Lambda function for use
3492     in a parallel_for expression and the function string for NVRTC
3493     compilation
3494  */
3495 template <class Func>
3496 class Lambda {
3497  public:
3498   Capture _capture;
3499   std::string _func_string;
3500   Func _func;
3501 
3502  public:
Lambda(Capture const & capture,std::string func_string,Func func)3503   inline Lambda(Capture const& capture, std::string func_string, Func func)
3504       : _capture(capture), _func_string(func_string), _func(func) {}
3505 };
3506 
3507 template <typename T>
make_Lambda(Capture const & capture,std::string func,T lambda)3508 inline Lambda<T> make_Lambda(Capture const& capture, std::string func,
3509                              T lambda) {
3510   return Lambda<T>(capture, func, lambda);
3511 }
3512 
// Builds a jitify::Capture from a list of variables: the stringified
// argument list is split on "," to obtain the names, and the variables
// themselves are passed on so their types and addresses can be recorded.
#define JITIFY_CAPTURE(...)                                            \
  jitify::Capture(jitify::detail::split_string(#__VA_ARGS__, -1, ","), \
                  __VA_ARGS__)

// Builds a jitify::Lambda from a capture object, a C++ lambda capture list
// `x` and a function body. The body is used twice: stringified for JIT
// compilation, and verbatim as the body of a host lambda taking `int i`.
#define JITIFY_MAKE_LAMBDA(capture, x, ...)               \
  jitify::make_Lambda(capture, std::string(#__VA_ARGS__), \
                      [x](int i) { __VA_ARGS__; })

// Passes its arguments through unchanged; used to keep a comma-separated
// list together as it is forwarded between macros.
#define JITIFY_ARGS(...) __VA_ARGS__

// Implementation of JITIFY_LAMBDA once the parentheses around the capture
// list have been removed: `x` is used both for JITIFY_CAPTURE and as the
// host lambda's capture list.
#define JITIFY_LAMBDA_(x, ...) \
  JITIFY_MAKE_LAMBDA(JITIFY_CAPTURE(x), JITIFY_ARGS(x), __VA_ARGS__)

// macro sequence to strip surrounding brackets
// JITIFY_PASS_PARAMETERS((a, b)) expands JITIFY_ARGS (a, b), i.e. a, b.
#define JITIFY_STRIP_PARENS(X) X
#define JITIFY_PASS_PARAMETERS(X) JITIFY_STRIP_PARENS(JITIFY_ARGS X)
3529 
/*! Creates a Lambda object with captured variables and a function
 *    definition.
 *  \param capture A bracket-enclosed list of variables to capture.
 *  \param ...     The function definition. It may refer to the index
 *                 variable \p i, which is the \p int parameter of the host
 *                 lambda generated by JITIFY_MAKE_LAMBDA.
 *
 *  \code{.cpp}
 *  float* capture_me;
 *  int    capture_me_too;
 *  auto my_lambda = JITIFY_LAMBDA( (capture_me, capture_me_too),
 *                                  capture_me[i] = i*capture_me_too );
 *  \endcode
 */
#define JITIFY_LAMBDA(capture, ...)                            \
  JITIFY_LAMBDA_(JITIFY_ARGS(JITIFY_PASS_PARAMETERS(capture)), \
                 JITIFY_ARGS(__VA_ARGS__))
3545 
3546 // TODO: Try to implement for_each that accepts iterators instead of indices
3547 //       Add compile guard for NOCUDA compilation
3548 /*! Call a function for a range of indices
3549  *
3550  *  \param policy Determines the location and device parameters for
3551  *  execution of the parallel_for.
3552  *  \param begin  The starting index.
3553  *  \param end    The ending index.
3554  *  \param lambda A Lambda object created using the JITIFY_LAMBDA() macro.
3555  *
3556  *  \code{.cpp}
3557  *  char const* in;
3558  *  float*      out;
3559  *  parallel_for(0, 100, JITIFY_LAMBDA( (in, out), {char x = in[i]; out[i] =
3560  * x*x; } ); \endcode
3561  */
3562 template <typename IndexType, class Func>
parallel_for(ExecutionPolicy policy,IndexType begin,IndexType end,Lambda<Func> const & lambda)3563 CUresult parallel_for(ExecutionPolicy policy, IndexType begin, IndexType end,
3564                       Lambda<Func> const& lambda) {
3565   using namespace jitify;
3566 
3567   if (policy.location == HOST) {
3568 #ifdef _OPENMP
3569 #pragma omp parallel for
3570 #endif
3571     for (IndexType i = begin; i < end; i++) {
3572       lambda._func(i);
3573     }
3574     return CUDA_SUCCESS;  // FIXME - replace with non-CUDA enum type?
3575   }
3576 
3577   thread_local static JitCache kernel_cache(policy.cache_size);
3578 
3579   std::vector<std::string> arg_decls;
3580   arg_decls.push_back("I begin, I end");
3581   arg_decls.insert(arg_decls.end(), lambda._capture._arg_decls.begin(),
3582                    lambda._capture._arg_decls.end());
3583 
3584   std::stringstream source_ss;
3585   source_ss << "parallel_for_program\n";
3586   for (auto const& header : policy.headers) {
3587     std::string header_name = header.substr(0, header.find("\n"));
3588     source_ss << "#include <" << header_name << ">\n";
3589   }
3590   source_ss << "template<typename I>\n"
3591                "__global__\n"
3592                "void parallel_for_kernel("
3593             << reflection::reflect_list(arg_decls)
3594             << ") {\n"
3595                "	I i0 = threadIdx.x + blockDim.x*blockIdx.x;\n"
3596                "	for( I i=i0+begin; i<end; i+=blockDim.x*gridDim.x ) {\n"
3597                "	"
3598             << "\t" << lambda._func_string << ";\n"
3599             << "	}\n"
3600                "}\n";
3601 
3602   Program program = kernel_cache.program(source_ss.str(), policy.headers,
3603                                          policy.options, policy.file_callback);
3604 
3605   std::vector<void*> arg_ptrs;
3606   arg_ptrs.push_back(&begin);
3607   arg_ptrs.push_back(&end);
3608   arg_ptrs.insert(arg_ptrs.end(), lambda._capture._arg_ptrs.begin(),
3609                   lambda._capture._arg_ptrs.end());
3610 
3611   size_t n = end - begin;
3612   dim3 block(policy.block_size);
3613   dim3 grid((unsigned int)std::min((n - 1) / block.x + 1, size_t(65535)));
3614   cudaSetDevice(policy.device);
3615   return program.kernel("parallel_for_kernel")
3616       .instantiate<IndexType>()
3617       .configure(grid, block, 0, policy.stream)
3618       .launch(arg_ptrs);
3619 }
3620 
3621 namespace experimental {
3622 
3623 using jitify::file_callback_type;
3624 
3625 namespace serialization {
3626 
3627 namespace detail {
3628 
// This should be incremented whenever the serialization format changes in any
// incompatible way.
// It is written after the "JTFY" tag by serialize_magic_number and compared
// for equality by deserialize_magic_number.
static constexpr const size_t kSerializationVersion = 1;
3632 
// Writes a size as a fixed-width 64-bit value in the host's native byte
// order, independent of the width of size_t.
inline void serialize(std::ostream& stream, size_t u) {
  const uint64_t widened = static_cast<uint64_t>(u);
  const char* const raw_bytes = reinterpret_cast<const char*>(&widened);
  stream.write(raw_bytes, sizeof(widened));
}
3637 
// Reads a size previously written by serialize(std::ostream&, size_t).
//
// Returns true and stores the value in *size on success. Returns false,
// leaving *size untouched, if the stream could not supply all eight bytes
// or if the stored value does not fit in this platform's size_t.
inline bool deserialize(std::istream& stream, size_t* size) {
  // Initialised so that a short read never leaves an indeterminate value.
  uint64_t u64 = 0;
  stream.read(reinterpret_cast<char*>(&u64), sizeof(u64));
  if (!stream.good()) return false;
  const size_t narrowed = static_cast<size_t>(u64);
  // Reject values that would be truncated where size_t is narrower than
  // 64 bits.
  if (static_cast<uint64_t>(narrowed) != u64) return false;
  *size = narrowed;
  return true;
}
3644 
serialize(std::ostream & stream,std::string const & s)3645 inline void serialize(std::ostream& stream, std::string const& s) {
3646   serialize(stream, s.size());
3647   stream.write(s.data(), s.size());
3648 }
3649 
deserialize(std::istream & stream,std::string * s)3650 inline bool deserialize(std::istream& stream, std::string* s) {
3651   size_t size;
3652   if (!deserialize(stream, &size)) return false;
3653   s->resize(size);
3654   if (s->size()) {
3655     stream.read(&(*s)[0], s->size());
3656   }
3657   return stream.good();
3658 }
3659 
serialize(std::ostream & stream,std::vector<std::string> const & v)3660 inline void serialize(std::ostream& stream, std::vector<std::string> const& v) {
3661   serialize(stream, v.size());
3662   for (auto const& s : v) {
3663     serialize(stream, s);
3664   }
3665 }
3666 
deserialize(std::istream & stream,std::vector<std::string> * v)3667 inline bool deserialize(std::istream& stream, std::vector<std::string>* v) {
3668   size_t size;
3669   if (!deserialize(stream, &size)) return false;
3670   v->resize(size);
3671   for (auto& s : *v) {
3672     if (!deserialize(stream, &s)) return false;
3673   }
3674   return true;
3675 }
3676 
serialize(std::ostream & stream,std::map<std::string,std::string> const & m)3677 inline void serialize(std::ostream& stream,
3678                       std::map<std::string, std::string> const& m) {
3679   serialize(stream, m.size());
3680   for (auto const& kv : m) {
3681     serialize(stream, kv.first);
3682     serialize(stream, kv.second);
3683   }
3684 }
3685 
deserialize(std::istream & stream,std::map<std::string,std::string> * m)3686 inline bool deserialize(std::istream& stream,
3687                         std::map<std::string, std::string>* m) {
3688   size_t size;
3689   if (!deserialize(stream, &size)) return false;
3690   for (size_t i = 0; i < size; ++i) {
3691     std::string key;
3692     if (!deserialize(stream, &key)) return false;
3693     if (!deserialize(stream, &(*m)[key])) return false;
3694   }
3695   return true;
3696 }
3697 
// Serializes several values in order: the first through the matching
// single-value overload, the remainder recursively.
//
// The trailing values are taken by const reference so that strings, vectors
// and maps are not deep-copied at every level of the recursion.
template <typename T, typename... Rest>
inline void serialize(std::ostream& stream, T const& value,
                      Rest const&... rest) {
  serialize(stream, value);
  serialize(stream, rest...);
}
3703 
// Deserializes several values in order, stopping at (and reporting) the
// first failure.
template <typename T, typename... Rest>
inline bool deserialize(std::istream& stream, T* value, Rest... rest) {
  return deserialize(stream, value) && deserialize(stream, rest...);
}
3709 
serialize_magic_number(std::ostream & stream)3710 inline void serialize_magic_number(std::ostream& stream) {
3711   stream.write("JTFY", 4);
3712   serialize(stream, kSerializationVersion);
3713 }
3714 
deserialize_magic_number(std::istream & stream)3715 inline bool deserialize_magic_number(std::istream& stream) {
3716   char magic_number[4] = {0, 0, 0, 0};
3717   stream.read(&magic_number[0], 4);
3718   if (!(magic_number[0] == 'J' && magic_number[1] == 'T' &&
3719         magic_number[2] == 'F' && magic_number[3] == 'Y')) {
3720     return false;
3721   }
3722   size_t serialization_version;
3723   if (!deserialize(stream, &serialization_version)) return false;
3724   return serialization_version == kSerializationVersion;
3725 }
3726 
3727 }  // namespace detail
3728 
3729 template <typename... Values>
serialize(Values const &...values)3730 inline std::string serialize(Values const&... values) {
3731   std::ostringstream ss(std::stringstream::out | std::stringstream::binary);
3732   detail::serialize_magic_number(ss);
3733   detail::serialize(ss, values...);
3734   return ss.str();
3735 }
3736 
3737 template <typename... Values>
deserialize(std::string const & serialized,Values * ...values)3738 inline bool deserialize(std::string const& serialized, Values*... values) {
3739   std::istringstream ss(serialized,
3740                         std::stringstream::in | std::stringstream::binary);
3741   if (!detail::deserialize_magic_number(ss)) return false;
3742   return detail::deserialize(ss, values...);
3743 }
3744 
3745 }  // namespace serialization
3746 
3747 class Program;
3748 class Kernel;
3749 class KernelInstantiation;
3750 class KernelLauncher;
3751 
3752 /*! An object representing a program made up of source code, headers
3753  *    and options.
3754  */
3755 class Program {
3756  private:
3757   friend class KernelInstantiation;
3758   std::string _name;
3759   std::vector<std::string> _options;
3760   std::map<std::string, std::string> _sources;
3761 
3762   // Private constructor used by deserialize()
Program()3763   Program() {}
3764 
3765  public:
3766   /*! Create a program.
3767    *
3768    *  \param source A string containing either the source filename or
3769    *    the source itself; in the latter case, the first line must be
3770    *    the name of the program.
3771    *  \param headers A vector of strings representing the source of
3772    *    each header file required by the program. Each entry can be
3773    *    either the header filename or the header source itself; in
3774    *    the latter case, the first line must be the name of the header
3775    *    (i.e., the name by which the header is #included).
3776    *  \param options A vector of options to be passed to the
3777    *    NVRTC compiler. Include paths specified with \p -I
3778    *    are added to the search paths used by Jitify. The environment
3779    *    variable JITIFY_OPTIONS can also be used to define additional
3780    *    options.
3781    *  \param file_callback A pointer to a callback function that is
3782    *    invoked whenever a source file needs to be loaded. Inside this
3783    *    function, the user can either load/specify the source themselves
3784    *    or defer to Jitify's file-loading mechanisms.
3785    *  \note Program or header source files referenced by filename are
3786    *  looked-up using the following mechanisms (in this order):
3787    *  \note 1) By calling file_callback.
3788    *  \note 2) By looking for the file embedded in the executable via the GCC
3789    * linker.
3790    *  \note 3) By looking for the file in the filesystem.
3791    *
3792    *  \note Jitify recursively scans all source files for \p #include
3793    *  directives and automatically adds them to the set of headers needed
3794    *  by the program.
3795    *  If a \p #include directive references a header that cannot be found,
3796    *  the directive is automatically removed from the source code to prevent
3797    *  immediate compilation failure. This may result in compilation errors
3798    *  if the header was required by the program.
3799    *
3800    *  \note Jitify automatically includes NVRTC-safe versions of some
3801    *  standard library headers.
3802    */
Program(std::string const & cuda_source,std::vector<std::string> const & given_headers={},std::vector<std::string> const & given_options={},file_callback_type file_callback=nullptr)3803   Program(std::string const& cuda_source,
3804           std::vector<std::string> const& given_headers = {},
3805           std::vector<std::string> const& given_options = {},
3806           file_callback_type file_callback = nullptr) {
3807     // Add pre-include built-in JIT-safe headers
3808     std::vector<std::string> headers = given_headers;
3809     for (int i = 0; i < detail::preinclude_jitsafe_headers_count; ++i) {
3810       const char* hdr_name = detail::preinclude_jitsafe_header_names[i];
3811       const std::string& hdr_source =
3812           detail::get_jitsafe_headers_map().at(hdr_name);
3813       headers.push_back(std::string(hdr_name) + "\n" + hdr_source);
3814     }
3815 
3816     _options = given_options;
3817     detail::add_options_from_env(_options);
3818     std::vector<std::string> include_paths;
3819     detail::load_program(cuda_source, headers, file_callback, &include_paths,
3820                          &_sources, &_options, &_name);
3821   }
3822 
3823   /*! Restore a serialized program.
3824    *
3825    * \param serialized_program The serialized program to restore.
3826    *
3827    * \see serialize
3828    */
deserialize(std::string const & serialized_program)3829   static Program deserialize(std::string const& serialized_program) {
3830     Program program;
3831     if (!serialization::deserialize(serialized_program, &program._name,
3832                                     &program._options, &program._sources)) {
3833       throw std::runtime_error("Failed to deserialize program");
3834     }
3835     return program;
3836   }
3837 
3838   /*! Save the program.
3839    *
3840    * \see deserialize
3841    */
serialize() const3842   std::string serialize() const {
3843     // Note: Must update kSerializationVersion if this is changed.
3844     return serialization::serialize(_name, _options, _sources);
3845   };
3846 
3847   /*! Select a kernel.
3848    *
3849    * \param name The name of the kernel (unmangled and without
3850    * template arguments).
3851    * \param options A vector of options to be passed to the NVRTC
3852    * compiler when compiling this kernel.
3853    */
3854   Kernel kernel(std::string const& name,
3855                 std::vector<std::string> const& options = {}) const;
3856 };
3857 
3858 class Kernel {
3859   friend class KernelInstantiation;
3860   Program const* _program;
3861   std::string _name;
3862   std::vector<std::string> _options;
3863 
3864  public:
Kernel(Program const * program,std::string const & name,std::vector<std::string> const & options={})3865   Kernel(Program const* program, std::string const& name,
3866          std::vector<std::string> const& options = {})
3867       : _program(program), _name(name), _options(options) {}
3868 
3869   /*! Instantiate the kernel.
3870    *
3871    *  \param template_args A vector of template arguments represented as
3872    *    code-strings. These can be generated using
3873    *    \code{.cpp}jitify::reflection::reflect<type>()\endcode or
3874    *    \code{.cpp}jitify::reflection::reflect(value)\endcode
3875    *
3876    *  \note Template type deduction is not possible, so all types must be
3877    *    explicitly specified.
3878    */
3879   KernelInstantiation instantiate(
3880       std::vector<std::string> const& template_args =
3881           std::vector<std::string>()) const;
3882 
3883   // Regular template instantiation syntax (note limited flexibility)
3884   /*! Instantiate the kernel.
3885    *
3886    *  \note The template arguments specified on this function are
3887    *    used to instantiate the kernel. Non-type template arguments must
3888    *    be wrapped with
3889    *    \code{.cpp}jitify::reflection::NonType<type,value>\endcode
3890    *
3891    *  \note Template type deduction is not possible, so all types must be
3892    *    explicitly specified.
3893    */
3894   template <typename... TemplateArgs>
3895   KernelInstantiation instantiate() const;
3896 
3897   // Template-like instantiation syntax
3898   //   E.g., instantiate(myvar,Type<MyType>())(grid,block)
3899   /*! Instantiate the kernel.
3900    *
3901    *  \param targs The template arguments for the kernel, represented as
3902    *    values. Types must be wrapped with
3903    *    \code{.cpp}jitify::reflection::Type<type>()\endcode or
3904    *    \code{.cpp}jitify::reflection::type_of(value)\endcode
3905    *
3906    *  \note Template type deduction is not possible, so all types must be
3907    *    explicitly specified.
3908    */
3909   template <typename... TemplateArgs>
3910   KernelInstantiation instantiate(TemplateArgs... targs) const;
3911 };
3912 
3913 class KernelInstantiation {
3914   friend class KernelLauncher;
3915   std::unique_ptr<detail::CUDAKernel> _cuda_kernel;
3916 
3917   // Private constructor used by deserialize()
KernelInstantiation(std::string const & func_name,std::string const & ptx,std::vector<std::string> const & link_files,std::vector<std::string> const & link_paths)3918   KernelInstantiation(std::string const& func_name, std::string const& ptx,
3919                       std::vector<std::string> const& link_files,
3920                       std::vector<std::string> const& link_paths)
3921       : _cuda_kernel(new detail::CUDAKernel(func_name.c_str(), ptx.c_str(),
3922                                             link_files, link_paths)) {}
3923 
3924  public:
KernelInstantiation(Kernel const & kernel,std::vector<std::string> const & template_args)3925   KernelInstantiation(Kernel const& kernel,
3926                       std::vector<std::string> const& template_args) {
3927     Program const* program = kernel._program;
3928 
3929     std::string template_inst =
3930         (template_args.empty() ? ""
3931                                : reflection::reflect_template(template_args));
3932     std::string instantiation = kernel._name + template_inst;
3933 
3934     std::vector<std::string> options;
3935     options.insert(options.begin(), program->_options.begin(),
3936                    program->_options.end());
3937     options.insert(options.begin(), kernel._options.begin(),
3938                    kernel._options.end());
3939     detail::detect_and_add_cuda_arch(options);
3940     detail::detect_and_add_cxx11_flag(options);
3941 
3942     std::string log, ptx, mangled_instantiation;
3943     std::vector<std::string> linker_files, linker_paths;
3944     detail::instantiate_kernel(program->_name, program->_sources, instantiation,
3945                                options, &log, &ptx, &mangled_instantiation,
3946                                &linker_files, &linker_paths);
3947 
3948     _cuda_kernel.reset(new detail::CUDAKernel(mangled_instantiation.c_str(),
3949                                               ptx.c_str(), linker_files,
3950                                               linker_paths));
3951   }
3952 
3953   /*! Implicit conversion to the underlying CUfunction object.
3954    *
3955    * \note This allows use of CUDA APIs like
3956    *   cuOccupancyMaxActiveBlocksPerMultiprocessor.
3957    */
operator CUfunction() const3958   operator CUfunction() const { return *_cuda_kernel; }
3959 
3960   /*! Restore a serialized kernel instantiation.
3961    *
3962    * \param serialized_kernel_inst The serialized kernel instantiation to
3963    * restore.
3964    *
3965    * \see serialize
3966    */
deserialize(std::string const & serialized_kernel_inst)3967   static KernelInstantiation deserialize(
3968       std::string const& serialized_kernel_inst) {
3969     std::string func_name, ptx;
3970     std::vector<std::string> link_files, link_paths;
3971     if (!serialization::deserialize(serialized_kernel_inst, &func_name, &ptx,
3972                                     &link_files, &link_paths)) {
3973       throw std::runtime_error("Failed to deserialize kernel instantiation");
3974     }
3975     return KernelInstantiation(func_name, ptx, link_files, link_paths);
3976   }
3977 
3978   /*! Save the program.
3979    *
3980    * \see deserialize
3981    */
serialize() const3982   std::string serialize() const {
3983     // Note: Must update kSerializationVersion if this is changed.
3984     return serialization::serialize(
3985         _cuda_kernel->function_name(), _cuda_kernel->ptx(),
3986         _cuda_kernel->link_files(), _cuda_kernel->link_paths());
3987   }
3988 
3989   /*! Configure the kernel launch.
3990    *
3991    *  \param grid   The thread grid dimensions for the launch.
3992    *  \param block  The thread block dimensions for the launch.
3993    *  \param smem   The amount of shared memory to dynamically allocate, in
3994    * bytes.
3995    *  \param stream The CUDA stream to launch the kernel in.
3996    */
3997   KernelLauncher configure(dim3 grid, dim3 block, unsigned int smem = 0,
3998                            cudaStream_t stream = 0) const;
3999 
4000   /*! Configure the kernel launch with a 1-dimensional block and grid chosen
4001    *  automatically to maximise occupancy.
4002    *
4003    * \param max_block_size  The upper limit on the block size, or 0 for no
4004    * limit.
4005    * \param smem  The amount of shared memory to dynamically allocate, in bytes.
4006    * \param smem_callback  A function returning smem for a given block size
4007    * (overrides \p smem).
4008    * \param stream The CUDA stream to launch the kernel in.
4009    * \param flags The flags to pass to
4010    * cuOccupancyMaxPotentialBlockSizeWithFlags.
4011    */
4012   KernelLauncher configure_1d_max_occupancy(
4013       int max_block_size = 0, unsigned int smem = 0,
4014       CUoccupancyB2DSize smem_callback = 0, cudaStream_t stream = 0,
4015       unsigned int flags = 0) const;
4016 
4017   /*
4018    * \deprecated Use \p get_global_ptr instead.
4019    */
get_constant_ptr(const char * name,size_t * size=nullptr) const4020   CUdeviceptr get_constant_ptr(const char* name, size_t* size = nullptr) const {
4021     return get_global_ptr(name, size);
4022   }
4023 
4024   /*
4025    * Get a device pointer to a global __constant__ or __device__ variable using
4026    * its un-mangled name. If provided, *size is set to the size of the variable
4027    * in bytes.
4028    */
get_global_ptr(const char * name,size_t * size=nullptr) const4029   CUdeviceptr get_global_ptr(const char* name, size_t* size = nullptr) const {
4030     return _cuda_kernel->get_global_ptr(name, size);
4031   }
4032 
4033   /*
4034    * Copy data from a global __constant__ or __device__ array to the host using
4035    * its un-mangled name.
4036    */
4037   template <typename T>
get_global_array(const char * name,T * data,size_t count,CUstream stream=0) const4038   CUresult get_global_array(const char* name, T* data, size_t count,
4039                             CUstream stream = 0) const {
4040     return _cuda_kernel->get_global_data(name, data, count, stream);
4041   }
4042 
4043   /*
4044    * Copy a value from a global __constant__ or __device__ variable to the host
4045    * using its un-mangled name.
4046    */
4047   template <typename T>
get_global_value(const char * name,T * value,CUstream stream=0) const4048   CUresult get_global_value(const char* name, T* value,
4049                             CUstream stream = 0) const {
4050     return get_global_array(name, value, 1, stream);
4051   }
4052 
4053   /*
4054    * Copy data from the host to a global __constant__ or __device__ array using
4055    * its un-mangled name.
4056    */
4057   template <typename T>
set_global_array(const char * name,const T * data,size_t count,CUstream stream=0) const4058   CUresult set_global_array(const char* name, const T* data, size_t count,
4059                             CUstream stream = 0) const {
4060     return _cuda_kernel->set_global_data(name, data, count, stream);
4061   }
4062 
4063   /*
4064    * Copy a value from the host to a global __constant__ or __device__ variable
4065    * using its un-mangled name.
4066    */
4067   template <typename T>
set_global_value(const char * name,const T & value,CUstream stream=0) const4068   CUresult set_global_value(const char* name, const T& value,
4069                             CUstream stream = 0) const {
4070     return set_global_array(name, &value, 1, stream);
4071   }
4072 
mangled_name() const4073   const std::string& mangled_name() const {
4074     return _cuda_kernel->function_name();
4075   }
4076 
ptx() const4077   const std::string& ptx() const { return _cuda_kernel->ptx(); }
4078 
link_files() const4079   const std::vector<std::string>& link_files() const {
4080     return _cuda_kernel->link_files();
4081   }
4082 
link_paths() const4083   const std::vector<std::string>& link_paths() const {
4084     return _cuda_kernel->link_paths();
4085   }
4086 };
4087 
4088 class KernelLauncher {
4089   KernelInstantiation const* _kernel_inst;
4090   dim3 _grid;
4091   dim3 _block;
4092   unsigned int _smem;
4093   cudaStream_t _stream;
4094 
4095  public:
KernelLauncher(KernelInstantiation const * kernel_inst,dim3 grid,dim3 block,unsigned int smem=0,cudaStream_t stream=0)4096   KernelLauncher(KernelInstantiation const* kernel_inst, dim3 grid, dim3 block,
4097                  unsigned int smem = 0, cudaStream_t stream = 0)
4098       : _kernel_inst(kernel_inst),
4099         _grid(grid),
4100         _block(block),
4101         _smem(smem),
4102         _stream(stream) {}
4103 
4104   // Note: It's important that there is no implicit conversion required
4105   //         for arg_ptrs, because otherwise the parameter pack version
4106   //         below gets called instead (probably resulting in a segfault).
4107   /*! Launch the kernel.
4108    *
4109    *  \param arg_ptrs  A vector of pointers to each function argument for the
4110    *    kernel.
4111    *  \param arg_types A vector of function argument types represented
4112    *    as code-strings. This parameter is optional and is only used to print
4113    *    out the function signature.
4114    */
launch(std::vector<void * > arg_ptrs={},std::vector<std::string> arg_types={}) const4115   CUresult launch(std::vector<void*> arg_ptrs = {},
4116                   std::vector<std::string> arg_types = {}) const {
4117 #if JITIFY_PRINT_LAUNCH
4118     std::string arg_types_string =
4119         (arg_types.empty() ? "..." : reflection::reflect_list(arg_types));
4120     std::cout << "Launching " << _kernel_inst->_cuda_kernel->function_name()
4121               << "<<<" << _grid << "," << _block << "," << _smem << ","
4122               << _stream << ">>>"
4123               << "(" << arg_types_string << ")" << std::endl;
4124 #endif
4125     return _kernel_inst->_cuda_kernel->launch(_grid, _block, _smem, _stream,
4126                                               arg_ptrs);
4127   }
4128 
4129   /*! Launch the kernel.
4130    *
4131    *  \param args Function arguments for the kernel.
4132    */
4133   template <typename... ArgTypes>
launch(ArgTypes...args) const4134   CUresult launch(ArgTypes... args) const {
4135     return this->launch(std::vector<void*>({(void*)&args...}),
4136                         {reflection::reflect<ArgTypes>()...});
4137   }
4138 };
4139 
kernel(std::string const & name,std::vector<std::string> const & options) const4140 inline Kernel Program::kernel(std::string const& name,
4141                               std::vector<std::string> const& options) const {
4142   return Kernel(this, name, options);
4143 }
4144 
instantiate(std::vector<std::string> const & template_args) const4145 inline KernelInstantiation Kernel::instantiate(
4146     std::vector<std::string> const& template_args) const {
4147   return KernelInstantiation(*this, template_args);
4148 }
4149 
4150 template <typename... TemplateArgs>
instantiate() const4151 inline KernelInstantiation Kernel::instantiate() const {
4152   return this->instantiate(
4153       std::vector<std::string>({reflection::reflect<TemplateArgs>()...}));
4154 }
4155 
4156 template <typename... TemplateArgs>
instantiate(TemplateArgs...targs) const4157 inline KernelInstantiation Kernel::instantiate(TemplateArgs... targs) const {
4158   return this->instantiate(
4159       std::vector<std::string>({reflection::reflect(targs)...}));
4160 }
4161 
configure(dim3 grid,dim3 block,unsigned int smem,cudaStream_t stream) const4162 inline KernelLauncher KernelInstantiation::configure(
4163     dim3 grid, dim3 block, unsigned int smem, cudaStream_t stream) const {
4164   return KernelLauncher(this, grid, block, smem, stream);
4165 }
4166 
configure_1d_max_occupancy(int max_block_size,unsigned int smem,CUoccupancyB2DSize smem_callback,cudaStream_t stream,unsigned int flags) const4167 inline KernelLauncher KernelInstantiation::configure_1d_max_occupancy(
4168     int max_block_size, unsigned int smem, CUoccupancyB2DSize smem_callback,
4169     cudaStream_t stream, unsigned int flags) const {
4170   int grid;
4171   int block;
4172   CUfunction func = *_cuda_kernel;
4173   detail::get_1d_max_occupancy(func, smem_callback, &smem, max_block_size,
4174                                flags, &grid, &block);
4175   return this->configure(grid, block, smem, stream);
4176 }
4177 
4178 }  // namespace experimental
4179 
4180 }  // namespace jitify
4181 
4182 #if defined(_WIN32) || defined(_WIN64)
4183 #pragma pop_macro("max")
4184 #pragma pop_macro("min")
4185 #pragma pop_macro("strtok_r")
4186 #endif
4187