1 // SPDX-License-Identifier: Apache-2.0
2 /*
3 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions
7 * are met:
8 * * Redistributions of source code must retain the above copyright
9 * notice, this list of conditions and the following disclaimer.
10 * * Redistributions in binary form must reproduce the above copyright
11 * notice, this list of conditions and the following disclaimer in the
12 * documentation and/or other materials provided with the distribution.
13 * * Neither the name of NVIDIA CORPORATION nor the names of its
14 * contributors may be used to endorse or promote products derived
15 * from this software without specific prior written permission.
16 *
17 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
18 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
20 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
21 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
22 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
23 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
24 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
25 * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 */
29
30 /*
31 -----------
32 Jitify 0.9
33 -----------
34 A C++ library for easy integration of CUDA runtime compilation into
35 existing codes.
36
37 --------------
38 How to compile
39 --------------
40 Compiler dependencies: <jitify.hpp>, -std=c++11
41 Linker dependencies: dl cuda nvrtc
42
43 --------------------------------------
44 Embedding source files into executable
45 --------------------------------------
With g++:
  g++ ... -ldl -rdynamic -DJITIFY_ENABLE_EMBEDDED_FILES=1
      -Wl,-b,binary,my_kernel.cu,include/my_header.cuh,-b,default
With nvcc:
  nvcc ... -ldl -Xcompiler "-rdynamic
      -Wl\,-b\,binary\,my_kernel.cu\,include/my_header.cuh\,-b\,default"
Then declare the embedded files in the source code:
50 JITIFY_INCLUDE_EMBEDDED_FILE(my_kernel_cu);
51 JITIFY_INCLUDE_EMBEDDED_FILE(include_my_header_cuh);
52
53 ----
54 TODO
55 ----
56 Extract valid compile options and pass the rest to cuModuleLoadDataEx
See if we can have stringified headers looked up automatically
by having stringify add them to a (static) global map.
59 The global map can be updated by creating a static class instance
60 whose constructor performs the registration.
61 Can then remove all headers from JitCache constructor in example code
62 See other TODOs in code
63 */
64
65 /*! \file jitify.hpp
66 * \brief The Jitify library header
67 */
68
69 /*! \mainpage Jitify - A C++ library that simplifies the use of NVRTC
70 * \p Use class jitify::JitCache to manage and launch JIT-compiled CUDA
71 * kernels.
72 *
73 * \p Use namespace jitify::reflection to reflect types and values into
74 * code-strings.
75 *
76 * \p Use JITIFY_INCLUDE_EMBEDDED_FILE() to declare files that have been
77 * embedded into the executable using the GCC linker.
78 *
79 * \p Use jitify::parallel_for and JITIFY_LAMBDA() to generate and launch
80 * simple kernels.
81 */
82
83 #pragma once
84
// Thread-safety switch; defaults to enabled. When nonzero, <mutex> is included
// below. NOTE(review): presumably the caches elsewhere in this header lock with
// it — the uses are not visible in this excerpt.
#ifndef JITIFY_THREAD_SAFE
#define JITIFY_THREAD_SAFE 1
#endif
88
89 #if JITIFY_ENABLE_EMBEDDED_FILES
90 #include <dlfcn.h>
91 #endif
#include <stdint.h>
#include <algorithm>
#include <cctype>
#include <cstring>  // For strtok_r etc.
#include <deque>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
109 #if JITIFY_THREAD_SAFE
110 #include <mutex>
111 #endif
112
113 #include <cuda.h>
114 #include <cuda_runtime_api.h> // For dim3, cudaStream_t
115 #if CUDA_VERSION >= 8000
116 #define NVRTC_GET_TYPE_NAME 1
117 #endif
118 #include <nvrtc.h>
119
120 // For use by get_current_executable_path().
121 #ifdef __linux__
122 #include <linux/limits.h> // For PATH_MAX
123
124 #include <cstdlib> // For realpath
125 #define JITIFY_PATH_MAX PATH_MAX
126 #elif defined(_WIN32) || defined(_WIN64)
127 #include <windows.h>
128 #define JITIFY_PATH_MAX MAX_PATH
129 #else
130 #error "Unsupported platform"
131 #endif
132
133 #ifdef _MSC_VER // MSVC compiler
134 #include <dbghelp.h> // For UnDecorateSymbolName
135 #else
136 #include <cxxabi.h> // For abi::__cxa_demangle
137 #endif
138
139 #if defined(_WIN32) || defined(_WIN64)
140 // WAR for strtok_r being called strtok_s on Windows
141 #pragma push_macro("strtok_r")
142 #undef strtok_r
143 #define strtok_r strtok_s
144 // WAR for min and max possibly being macros defined by windows.h
145 #pragma push_macro("min")
146 #pragma push_macro("max")
147 #undef min
148 #undef max
149 #endif
150
// Diagnostic printing switches. Compile-log printing is on unless the user
// defines JITIFY_PRINT_LOG to 0; JITIFY_PRINT_ALL turns every JITIFY_PRINT_*
// switch on.
#ifndef JITIFY_PRINT_LOG
#define JITIFY_PRINT_LOG 1
#endif

#if JITIFY_PRINT_ALL
// #undef first so that JITIFY_PRINT_ALL cleanly overrides any value the user
// gave an individual switch: redefining a macro with a different replacement
// list is ill-formed and draws a compiler diagnostic.
#undef JITIFY_PRINT_INSTANTIATION
#undef JITIFY_PRINT_SOURCE
#undef JITIFY_PRINT_LOG
#undef JITIFY_PRINT_PTX
#undef JITIFY_PRINT_LINKER_LOG
#undef JITIFY_PRINT_LAUNCH
#undef JITIFY_PRINT_HEADER_PATHS
#define JITIFY_PRINT_INSTANTIATION 1
#define JITIFY_PRINT_SOURCE 1
#define JITIFY_PRINT_LOG 1
#define JITIFY_PRINT_PTX 1
#define JITIFY_PRINT_LINKER_LOG 1
#define JITIFY_PRINT_LAUNCH 1
#define JITIFY_PRINT_HEADER_PATHS 1
#endif
164
#if JITIFY_ENABLE_EMBEDDED_FILES
// Stores the address of symbol x in a global variable named x##_forced, which
// forces a reference to x. NOTE(review): presumably this keeps the linker from
// discarding the embedded data. Because x##_forced is a non-static global,
// expanding this for the same x in two translation units would likely cause a
// duplicate-definition link error — confirm.
#define JITIFY_FORCE_UNDEFINED_SYMBOL(x) void* x##_forced = (void*)&x
/*! Include a source file that has been embedded into the executable using the
 * GCC linker.
 * \param name The name of the source file (<b>not</b> as a string), which must
 * be sanitized by replacing non-alpha-numeric characters with underscores.
 * E.g., \code{.cpp}JITIFY_INCLUDE_EMBEDDED_FILE(my_header_h)\endcode will
 * include the embedded file "my_header.h".
 * \note Files declared with this macro can be referenced using
 * their original (unsanitized) filenames when creating a \p
 * jitify::Program instance.
 * \note Each use declares extern "C" arrays _jitify_binary_NAME_start and
 * _jitify_binary_NAME_end, bound via asm labels to the linker-generated
 * symbols _binary_NAME_start and _binary_NAME_end.
 */
#define JITIFY_INCLUDE_EMBEDDED_FILE(name)                                \
  extern "C" uint8_t _jitify_binary_##name##_start[] asm("_binary_" #name \
                                                         "_start");       \
  extern "C" uint8_t _jitify_binary_##name##_end[] asm("_binary_" #name   \
                                                       "_end");           \
  JITIFY_FORCE_UNDEFINED_SYMBOL(_jitify_binary_##name##_start);           \
  JITIFY_FORCE_UNDEFINED_SYMBOL(_jitify_binary_##name##_end)
#endif  // JITIFY_ENABLE_EMBEDDED_FILES
185
186 /*! Jitify library namespace
187 */
188 namespace jitify {
189
/*! Signature of a user-supplied source-file loader.
 *
 * \param filename The name of the requested source file.
 * \param tmp_stream A scratch stream the callback may fill with the source
 * code and then return.
 * \return A pointer to an input stream holding the source code, or NULL to
 * fall back to Jitify's own file-loading mechanisms.
 */
using file_callback_type = std::istream* (*)(std::string filename,
                                             std::iostream& tmp_stream);
199
200 // Exclude from Doxygen
201 //! \cond
202
// Forward declaration; presumably defined later in this header (not visible in
// this excerpt).
class JitCache;
204
// Simple cache using LRU discard policy.
//
// _objects maps keys to cached values. _ranked_keys lists every cached key
// exactly once, most-recently-used first, so its back is the next eviction.
// Re-inserting a key that is already cached keeps the existing value
// (std::map::insert semantics) and only refreshes its rank.
template <typename KeyType, typename ValueType>
class ObjectCache {
 public:
  typedef KeyType key_type;
  typedef ValueType value_type;

 private:
  typedef std::map<key_type, value_type> object_map;
  typedef std::deque<key_type> key_rank;
  typedef typename key_rank::iterator rank_iterator;
  object_map _objects;
  key_rank _ranked_keys;
  size_t _capacity;

  // Evicts least-recently-used entries until at most (_capacity - n) remain,
  // making room for n new entries.
  // Throws std::runtime_error if n exceeds the capacity.
  inline void discard_old(size_t n = 0) {
    if (n > _capacity) {
      throw std::runtime_error("Insufficient capacity in cache");
    }
    while (_objects.size() > _capacity - n) {
      key_type discard_key = _ranked_keys.back();
      _ranked_keys.pop_back();
      _objects.erase(discard_key);
    }
  }

 public:
  inline ObjectCache(size_t capacity = 8) : _capacity(capacity) {}

  // Changes the capacity, evicting LRU entries if the cache is now too full.
  inline void resize(size_t capacity) {
    _capacity = capacity;
    this->discard_old();
  }

  inline bool contains(const key_type& k) const {
    return (bool)_objects.count(k);
  }

  // Marks k as most recently used. Throws std::runtime_error if absent.
  inline void touch(const key_type& k) {
    if (!this->contains(k)) {
      throw std::runtime_error("Key not found in cache");
    }
    rank_iterator rank = std::find(_ranked_keys.begin(), _ranked_keys.end(), k);
    if (rank != _ranked_keys.begin()) {
      // Move key to front of ranks
      _ranked_keys.erase(rank);
      _ranked_keys.push_front(k);
    }
  }

  // Returns the value cached for k and marks it most recently used.
  // Throws std::runtime_error if absent.
  inline value_type& get(const key_type& k) {
    if (!this->contains(k)) {
      throw std::runtime_error("Key not found in cache");
    }
    this->touch(k);
    // find() rather than operator[] so value_type need not be
    // default-constructible.
    return _objects.find(k)->second;
  }

  // Caches (k, v), evicting the LRU entry if full, and returns the stored
  // value. If k is already cached its value is left unchanged and it is only
  // marked most recently used: no eviction and no duplicate rank entry (a
  // duplicate would later evict k even though it was just used).
  inline value_type& insert(const key_type& k,
                            const value_type& v = value_type()) {
    if (this->contains(k)) {
      this->touch(k);
      return _objects.find(k)->second;
    }
    this->discard_old(1);
    _ranked_keys.push_front(k);
    return _objects.insert(std::make_pair(k, v)).first->second;
  }

  // Constructs the value for k in place from args, evicting the LRU entry if
  // full. If k is already cached, args are ignored and the existing value is
  // returned after marking it most recently used.
  template <typename... Args>
  inline value_type& emplace(const key_type& k, Args&&... args) {
    if (this->contains(k)) {
      this->touch(k);
      return _objects.find(k)->second;
    }
    this->discard_old(1);
    // Note: Use of piecewise_construct allows non-movable non-copyable types
    auto iter = _objects
                    .emplace(std::piecewise_construct, std::forward_as_tuple(k),
                             std::forward_as_tuple(std::forward<Args>(args)...))
                    .first;
    _ranked_keys.push_front(iter->first);
    return iter->second;
  }
};
276
277 namespace detail {
278
279 // Convenience wrapper for std::vector that provides handy constructors
280 template <typename T>
281 class vector : public std::vector<T> {
282 typedef std::vector<T> super_type;
283
284 public:
vector()285 vector() : super_type() {}
vector(size_t n)286 vector(size_t n) : super_type(n) {} // Note: Not explicit, allows =0
vector(std::vector<T> const & vals)287 vector(std::vector<T> const& vals) : super_type(vals) {}
288 template <int N>
vector(T const (& vals)[N])289 vector(T const (&vals)[N]) : super_type(vals, vals + N) {}
vector(std::vector<T> && vals)290 vector(std::vector<T>&& vals) : super_type(vals) {}
vector(std::initializer_list<T> vals)291 vector(std::initializer_list<T> vals) : super_type(vals) {}
292 };
293
294 // Helper functions for parsing/manipulating source code
295
replace_characters(std::string str,std::string const & oldchars,char newchar)296 inline std::string replace_characters(std::string str,
297 std::string const& oldchars,
298 char newchar) {
299 size_t i = str.find_first_of(oldchars);
300 while (i != std::string::npos) {
301 str[i] = newchar;
302 i = str.find_first_of(oldchars, i + 1);
303 }
304 return str;
305 }
sanitize_filename(std::string name)306 inline std::string sanitize_filename(std::string name) {
307 return replace_characters(name, "/\\.-: ?%*|\"<>", '_');
308 }
309
310 #if JITIFY_ENABLE_EMBEDDED_FILES
311 class EmbeddedData {
312 void* _app;
313 EmbeddedData(EmbeddedData const&);
314 EmbeddedData& operator=(EmbeddedData const&);
315
316 public:
EmbeddedData()317 EmbeddedData() {
318 _app = dlopen(NULL, RTLD_LAZY);
319 if (!_app) {
320 throw std::runtime_error(std::string("dlopen failed: ") + dlerror());
321 }
322 dlerror(); // Clear any existing error
323 }
~EmbeddedData()324 ~EmbeddedData() {
325 if (_app) {
326 dlclose(_app);
327 }
328 }
operator [](std::string key) const329 const uint8_t* operator[](std::string key) const {
330 key = sanitize_filename(key);
331 key = "_binary_" + key;
332 uint8_t const* data = (uint8_t const*)dlsym(_app, key.c_str());
333 if (!data) {
334 throw std::runtime_error(std::string("dlsym failed: ") + dlerror());
335 }
336 return data;
337 }
begin(std::string key) const338 const uint8_t* begin(std::string key) const {
339 return (*this)[key + "_start"];
340 }
end(std::string key) const341 const uint8_t* end(std::string key) const { return (*this)[key + "_end"]; }
342 };
343 #endif // JITIFY_ENABLE_EMBEDDED_FILES
344
is_tokenchar(char c)345 inline bool is_tokenchar(char c) {
346 return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
347 (c >= '0' && c <= '9') || c == '_';
348 }
replace_token(std::string src,std::string token,std::string replacement)349 inline std::string replace_token(std::string src, std::string token,
350 std::string replacement) {
351 size_t i = src.find(token);
352 while (i != std::string::npos) {
353 if (i == 0 || i == src.size() - token.size() ||
354 (!is_tokenchar(src[i - 1]) && !is_tokenchar(src[i + token.size()]))) {
355 src.replace(i, token.size(), replacement);
356 i += replacement.size();
357 } else {
358 i += token.size();
359 }
360 i = src.find(token, i);
361 }
362 return src;
363 }
path_base(std::string p)364 inline std::string path_base(std::string p) {
365 // "/usr/local/myfile.dat" -> "/usr/local"
366 // "foo/bar" -> "foo"
367 // "foo/bar/" -> "foo/bar"
368 #if defined _WIN32 || defined _WIN64
369 char sep = '\\';
370 #else
371 char sep = '/';
372 #endif
373 size_t i = p.find_last_of(sep);
374 if (i != std::string::npos) {
375 return p.substr(0, i);
376 } else {
377 return "";
378 }
379 }
path_join(std::string p1,std::string p2)380 inline std::string path_join(std::string p1, std::string p2) {
381 #ifdef _WIN32
382 char sep = '\\';
383 #else
384 char sep = '/';
385 #endif
386 if (p1.size() && p2.size() && p2[0] == sep) {
387 throw std::invalid_argument("Cannot join to absolute path");
388 }
389 if (p1.size() && p1[p1.size() - 1] != sep) {
390 p1 += sep;
391 }
392 return p1 + p2;
393 }
394 // Elides "/." and "/.." tokens from path.
path_simplify(const std::string & path)395 inline std::string path_simplify(const std::string& path) {
396 std::vector<std::string> dirs;
397 std::string cur_dir;
398 bool after_slash = false;
399 for (int i = 0; i < (int)path.size(); ++i) {
400 if (path[i] == '/') {
401 if (after_slash) continue; // Ignore repeat slashes
402 after_slash = true;
403 if (cur_dir == ".." && !dirs.empty() && dirs.back() != "..") {
404 if (dirs.size() == 1 && dirs.front().empty()) {
405 throw std::runtime_error(
406 "Invalid path: back-traversals exceed depth of absolute path");
407 }
408 dirs.pop_back();
409 } else if (cur_dir != ".") { // Ignore /./
410 dirs.push_back(cur_dir);
411 }
412 cur_dir.clear();
413 } else {
414 after_slash = false;
415 cur_dir.push_back(path[i]);
416 }
417 }
418 if (!after_slash) {
419 dirs.push_back(cur_dir);
420 }
421 std::stringstream ss;
422 for (int i = 0; i < (int)dirs.size() - 1; ++i) {
423 ss << dirs[i] << "/";
424 }
425 if (!dirs.empty()) ss << dirs.back();
426 if (after_slash) ss << "/";
427 return ss.str();
428 }
hash_larson64(const char * s,unsigned long long seed=0)429 inline unsigned long long hash_larson64(const char* s,
430 unsigned long long seed = 0) {
431 unsigned long long hash = seed;
432 while (*s) {
433 hash = hash * 101 + *s++;
434 }
435 return hash;
436 }
437
hash_combine(uint64_t a,uint64_t b)438 inline uint64_t hash_combine(uint64_t a, uint64_t b) {
439 // Note: The magic number comes from the golden ratio
440 return a ^ (0x9E3779B97F4A7C17ull + b + (b >> 2) + (a << 6));
441 }
442
extract_include_info_from_compile_error(std::string log,std::string & name,std::string & parent,int & line_num)443 inline bool extract_include_info_from_compile_error(std::string log,
444 std::string& name,
445 std::string& parent,
446 int& line_num) {
447 static const std::vector<std::string> pattern = {
448 "could not open source file \"", "cannot open source file \""};
449
450 for (auto& p : pattern) {
451 size_t beg = log.find(p);
452 if (beg != std::string::npos) {
453 beg += p.size();
454 size_t end = log.find("\"", beg);
455 name = log.substr(beg, end - beg);
456
457 size_t line_beg = log.rfind("\n", beg);
458 if (line_beg == std::string::npos) {
459 line_beg = 0;
460 } else {
461 line_beg += 1;
462 }
463
464 size_t split = log.find("(", line_beg);
465 parent = log.substr(line_beg, split - line_beg);
466 line_num =
467 atoi(log.substr(split + 1, log.find(")", split + 1) - (split + 1))
468 .c_str());
469
470 return true;
471 }
472 }
473
474 return false;
475 }
476
is_include_directive_with_quotes(const std::string & source,int line_num)477 inline bool is_include_directive_with_quotes(const std::string& source,
478 int line_num) {
479 // TODO: Check each find() for failure.
480 size_t beg = 0;
481 for (int i = 1; i < line_num; ++i) {
482 beg = source.find("\n", beg) + 1;
483 }
484 beg = source.find("include", beg) + 7;
485 beg = source.find_first_of("\"<", beg);
486 return source[beg] == '"';
487 }
488
comment_out_code_line(int line_num,std::string source)489 inline std::string comment_out_code_line(int line_num, std::string source) {
490 size_t beg = 0;
491 for (int i = 1; i < line_num; ++i) {
492 beg = source.find("\n", beg) + 1;
493 }
494 return (source.substr(0, beg) + "//" + source.substr(beg));
495 }
496
print_with_line_numbers(std::string const & source)497 inline void print_with_line_numbers(std::string const& source) {
498 int linenum = 1;
499 std::stringstream source_ss(source);
500 for (std::string line; std::getline(source_ss, line); ++linenum) {
501 std::cout << std::setfill(' ') << std::setw(3) << linenum << " " << line
502 << std::endl;
503 }
504 }
505
print_compile_log(std::string program_name,std::string const & log)506 inline void print_compile_log(std::string program_name,
507 std::string const& log) {
508 std::cout << "---------------------------------------------------"
509 << std::endl;
510 std::cout << "--- JIT compile log for " << program_name << " ---"
511 << std::endl;
512 std::cout << "---------------------------------------------------"
513 << std::endl;
514 std::cout << log << std::endl;
515 std::cout << "---------------------------------------------------"
516 << std::endl;
517 }
518
split_string(std::string str,long maxsplit=-1,std::string delims=" \\t")519 inline std::vector<std::string> split_string(std::string str,
520 long maxsplit = -1,
521 std::string delims = " \t") {
522 std::vector<std::string> results;
523 if (maxsplit == 0) {
524 results.push_back(str);
525 return results;
526 }
527 // Note: +1 to include NULL-terminator
528 std::vector<char> v_str(str.c_str(), str.c_str() + (str.size() + 1));
529 char* c_str = v_str.data();
530 char* saveptr = c_str;
531 char* token = nullptr;
532 for (long i = 0; i != maxsplit; ++i) {
533 token = ::strtok_r(c_str, delims.c_str(), &saveptr);
534 c_str = 0;
535 if (!token) {
536 return results;
537 }
538 results.push_back(token);
539 }
540 // Check if there's a final piece
541 token += ::strlen(token) + 1;
542 if (token - v_str.data() < (ptrdiff_t)str.size()) {
543 // Find the start of the final piece
544 token += ::strspn(token, delims.c_str());
545 if (*token) {
546 results.push_back(token);
547 }
548 }
549 return results;
550 }
551
552 static const std::map<std::string, std::string>& get_jitsafe_headers_map();
553
load_source(std::string filename,std::map<std::string,std::string> & sources,std::string current_dir="",std::vector<std::string> include_paths=std::vector<std::string> (),file_callback_type file_callback=0,std::map<std::string,std::string> * fullpaths=nullptr,bool search_current_dir=true)554 inline bool load_source(
555 std::string filename, std::map<std::string, std::string>& sources,
556 std::string current_dir = "",
557 std::vector<std::string> include_paths = std::vector<std::string>(),
558 file_callback_type file_callback = 0,
559 std::map<std::string, std::string>* fullpaths = nullptr,
560 bool search_current_dir = true) {
561 std::istream* source_stream = 0;
562 std::stringstream string_stream;
563 std::ifstream file_stream;
564 // First detect direct source-code string ("my_program\nprogram_code...")
565 size_t newline_pos = filename.find("\n");
566 if (newline_pos != std::string::npos) {
567 std::string source = filename.substr(newline_pos + 1);
568 filename = filename.substr(0, newline_pos);
569 string_stream << source;
570 source_stream = &string_stream;
571 }
572 if (sources.count(filename)) {
573 // Already got this one
574 return true;
575 }
576 if (!source_stream) {
577 std::string fullpath = path_join(current_dir, filename);
578 // Try loading from callback
579 if (!file_callback ||
580 !(source_stream = file_callback(fullpath, string_stream))) {
581 #if JITIFY_ENABLE_EMBEDDED_FILES
582 // Try loading as embedded file
583 EmbeddedData embedded;
584 std::string source;
585 try {
586 source.assign(embedded.begin(fullpath), embedded.end(fullpath));
587 string_stream << source;
588 source_stream = &string_stream;
589 } catch (std::runtime_error const&)
590 #endif // JITIFY_ENABLE_EMBEDDED_FILES
591 {
592 // Try loading from filesystem
593 bool found_file = false;
594 if (search_current_dir) {
595 file_stream.open(fullpath.c_str());
596 if (file_stream) {
597 source_stream = &file_stream;
598 found_file = true;
599 }
600 }
601 // Search include directories
602 if (!found_file) {
603 for (int i = 0; i < (int)include_paths.size(); ++i) {
604 fullpath = path_join(include_paths[i], filename);
605 file_stream.open(fullpath.c_str());
606 if (file_stream) {
607 source_stream = &file_stream;
608 found_file = true;
609 break;
610 }
611 }
612 if (!found_file) {
613 // Try loading from builtin headers
614 fullpath = path_join("__jitify_builtin", filename);
615 auto it = get_jitsafe_headers_map().find(filename);
616 if (it != get_jitsafe_headers_map().end()) {
617 string_stream << it->second;
618 source_stream = &string_stream;
619 } else {
620 return false;
621 }
622 }
623 }
624 }
625 }
626 if (fullpaths) {
627 // Record the full file path corresponding to this include name.
628 (*fullpaths)[filename] = path_simplify(fullpath);
629 }
630 }
631 sources[filename] = std::string();
632 std::string& source = sources[filename];
633 std::string line;
634 size_t linenum = 0;
635 unsigned long long hash = 0;
636 bool pragma_once = false;
637 bool remove_next_blank_line = false;
638 while (std::getline(*source_stream, line)) {
639 ++linenum;
640
641 // HACK WAR for static variables not allowed on the device (unless
642 // __shared__)
643 // TODO: This breaks static member variables
644 // line = replace_token(line, "static const", "/*static*/ const");
645
646 // TODO: Need to watch out for /* */ comments too
647 std::string cleanline =
648 line.substr(0, line.find("//")); // Strip line comments
649 // if( cleanline.back() == "\r" ) { // Remove Windows line ending
650 // cleanline = cleanline.substr(0, cleanline.size()-1);
651 //}
652 // TODO: Should trim whitespace before checking .empty()
653 if (cleanline.empty() && remove_next_blank_line) {
654 remove_next_blank_line = false;
655 continue;
656 }
657 // Maintain a file hash for use in #pragma once WAR
658 hash = hash_larson64(line.c_str(), hash);
659 if (cleanline.find("#pragma once") != std::string::npos) {
660 pragma_once = true;
661 // Note: This is an attempt to recover the original line numbering,
662 // which otherwise gets off-by-one due to the include guard.
663 remove_next_blank_line = true;
664 // line = "//" + line; // Comment out the #pragma once line
665 continue;
666 }
667
668 // HACK WAR for Thrust using "#define FOO #pragma bar"
669 size_t pragma_beg = cleanline.find("#pragma ");
670 if (pragma_beg != std::string::npos) {
671 std::string line_after_pragma = line.substr(pragma_beg);
672 std::vector<std::string> pragma_split =
673 split_string(line_after_pragma, 2);
674 line =
675 (line.substr(0, pragma_beg) + "_Pragma(\"" + pragma_split[1] + "\")");
676 if (pragma_split.size() == 3) {
677 line += " " + pragma_split[2];
678 }
679 }
680
681 source += line + "\n";
682 }
683 // HACK TESTING (WAR for cub)
684 // source = "#define cudaDeviceSynchronize() cudaSuccess\n" + source;
685 ////source = "cudaError_t cudaDeviceSynchronize() { return cudaSuccess; }\n" +
686 /// source;
687
688 // WAR for #pragma once causing problems when there are multiple inclusions
689 // of the same header from different paths.
690 if (pragma_once) {
691 std::stringstream ss;
692 ss << std::uppercase << std::hex << std::setw(8) << std::setfill('0')
693 << hash;
694 std::string include_guard_name = "_JITIFY_INCLUDE_GUARD_" + ss.str() + "\n";
695 std::string include_guard_header;
696 include_guard_header += "#ifndef " + include_guard_name;
697 include_guard_header += "#define " + include_guard_name;
698 std::string include_guard_footer;
699 include_guard_footer += "#endif // " + include_guard_name;
700 source = include_guard_header + source + "\n" + include_guard_footer;
701 }
702 // return filename;
703 return true;
704 }
705
706 } // namespace detail
707
708 //! \endcond
709
710 /*! Jitify reflection utilities namespace
711 */
712 namespace reflection {
713
// Provides type and value reflection via a function 'reflect':
// reflect<Type>() -> "Type"
// reflect(value) -> "(T)value"
// reflect<VAL>() -> "(int64_t)VAL"
// reflect<Type,VAL>() -> "(Type)VAL"
// reflect_template<float,NonType<int,7>,char>() -> "<float,7,char>"
// reflect_template({"float", "7", "char"}) -> "<float,7,char>"
721
/*! Wraps the compile-time constant VALUE_ (of type T) in a type, so that a
 * non-type template argument can be passed where a type is expected, e.g. to
 * reflect<>().
 */
template <typename T, T VALUE_>
struct NonType {
  static constexpr T VALUE = VALUE_;
};
728
// Forward declaration: detail::type_reflection<NonType<T, VALUE> >::name()
// calls reflect(value) before its definition appears later in this namespace.
template <typename T>
inline std::string reflect(T const& value);
732
733 //! \cond
734
735 namespace detail {
736
737 template <typename T>
value_string(const T & x)738 inline std::string value_string(const T& x) {
739 std::stringstream ss;
740 ss << x;
741 return ss.str();
742 }
743 // WAR for non-printable characters
744 template <>
value_string(const char & x)745 inline std::string value_string<char>(const char& x) {
746 std::stringstream ss;
747 ss << (int)x;
748 return ss.str();
749 }
750 template <>
value_string(const signed char & x)751 inline std::string value_string<signed char>(const signed char& x) {
752 std::stringstream ss;
753 ss << (int)x;
754 return ss.str();
755 }
756 template <>
value_string(const unsigned char & x)757 inline std::string value_string<unsigned char>(const unsigned char& x) {
758 std::stringstream ss;
759 ss << (int)x;
760 return ss.str();
761 }
762 template <>
value_string(const wchar_t & x)763 inline std::string value_string<wchar_t>(const wchar_t& x) {
764 std::stringstream ss;
765 ss << (long)x;
766 return ss.str();
767 }
768 // Specialisation for bool true/false literals
769 template <>
value_string(const bool & x)770 inline std::string value_string<bool>(const bool& x) {
771 return x ? "true" : "false";
772 }
773
774 // Removes all tokens that start with double underscores.
strip_double_underscore_tokens(char * s)775 inline void strip_double_underscore_tokens(char* s) {
776 using jitify::detail::is_tokenchar;
777 char* w = s;
778 do {
779 if (*s == '_' && *(s + 1) == '_') {
780 while (is_tokenchar(*++s))
781 ;
782 }
783 } while ((*w++ = *s++));
784 }
785
786 //#if CUDA_VERSION < 8000
787 #ifdef _MSC_VER // MSVC compiler
demangle_cuda_symbol(const char * mangled_name)788 inline std::string demangle_cuda_symbol(const char* mangled_name) {
789 // We don't have a way to demangle CUDA symbol names under MSVC.
790 return mangled_name;
791 }
demangle_native_type(const std::type_info & typeinfo)792 inline std::string demangle_native_type(const std::type_info& typeinfo) {
793 // Get the decorated name and skip over the leading '.'.
794 const char* decorated_name = typeinfo.raw_name() + 1;
795 char undecorated_name[4096];
796 if (UnDecorateSymbolName(
797 decorated_name, undecorated_name,
798 sizeof(undecorated_name) / sizeof(*undecorated_name),
799 UNDNAME_NO_ARGUMENTS | // Treat input as a type name
800 UNDNAME_NAME_ONLY // No "class" and "struct" prefixes
801 /*UNDNAME_NO_MS_KEYWORDS*/)) { // No "__cdecl", "__ptr64" etc.
802 // WAR for UNDNAME_NO_MS_KEYWORDS messing up function types.
803 strip_double_underscore_tokens(undecorated_name);
804 return undecorated_name;
805 }
806 throw std::runtime_error("UnDecorateSymbolName failed");
807 }
808 #else // not MSVC
demangle_cuda_symbol(const char * mangled_name)809 inline std::string demangle_cuda_symbol(const char* mangled_name) {
810 size_t bufsize = 0;
811 char* buf = nullptr;
812 std::string demangled_name;
813 int status;
814 auto demangled_ptr = std::unique_ptr<char, decltype(free)*>(
815 abi::__cxa_demangle(mangled_name, buf, &bufsize, &status), free);
816 if (status == 0) {
817 demangled_name = demangled_ptr.get(); // all worked as expected
818 } else if (status == -2) {
819 demangled_name = mangled_name; // we interpret this as plain C name
820 } else if (status == -1) {
821 throw std::runtime_error(
822 std::string("memory allocation failure in __cxa_demangle"));
823 } else if (status == -3) {
824 throw std::runtime_error(std::string("invalid argument to __cxa_demangle"));
825 }
826 return demangled_name;
827 }
demangle_native_type(const std::type_info & typeinfo)828 inline std::string demangle_native_type(const std::type_info& typeinfo) {
829 return demangle_cuda_symbol(typeinfo.name());
830 }
831 #endif // not MSVC
832 //#endif // CUDA_VERSION < 8000
833
834 template <typename>
835 class JitifyTypeNameWrapper_ {};
836
837 template <typename T>
838 struct type_reflection {
namejitify::reflection::detail::type_reflection839 inline static std::string name() {
840 //#if CUDA_VERSION < 8000
841 // TODO: Use nvrtcGetTypeName once it has the same behavior as this.
842 // WAR for typeid discarding cv qualifiers on value-types
843 // Wrap type in dummy template class to preserve cv-qualifiers, then strip
844 // off the wrapper from the resulting string.
845 std::string wrapped_name =
846 demangle_native_type(typeid(JitifyTypeNameWrapper_<T>));
847 // Note: The reflected name of this class also has namespace prefixes.
848 const std::string wrapper_class_name = "JitifyTypeNameWrapper_<";
849 size_t start = wrapped_name.find(wrapper_class_name);
850 if (start == std::string::npos) {
851 throw std::runtime_error("Type reflection failed: " + wrapped_name);
852 }
853 start += wrapper_class_name.size();
854 std::string name =
855 wrapped_name.substr(start, wrapped_name.size() - (start + 1));
856 return name;
857 //#else
858 // std::string ret;
859 // nvrtcResult status = nvrtcGetTypeName<T>(&ret);
860 // if( status != NVRTC_SUCCESS ) {
861 // throw std::runtime_error(std::string("nvrtcGetTypeName
862 // failed:
863 //")+ nvrtcGetErrorString(status));
864 // }
865 // return ret;
866 //#endif
867 }
868 }; // namespace detail
869 template <typename T, T VALUE>
870 struct type_reflection<NonType<T, VALUE> > {
namejitify::reflection::detail::type_reflection871 inline static std::string name() {
872 return jitify::reflection::reflect(VALUE);
873 }
874 };
875
876 } // namespace detail
877
878 //! \endcond
879
/*! Holds a const reference to a value so that reflect() can report the value's
 * run-time type, e.g. the derived type of an object that is only known at
 * compile time through its abstract base type.
 */
template <typename T>
struct Instance {
  const T& value;
  Instance(const T& v) : value(v) {}
};

/*! Wrap \p value in an Instance so that its run-time type can be reflected.
 * \param value The const value to be captured. It is held by reference, so it
 * must outlive the returned object.
 */
template <typename T>
inline Instance<T const> instance_of(T const& value) {
  Instance<T const> wrapped(value);
  return wrapped;
}
900
/*! An empty tag that carries the type T as a value, so that a type can be
 * passed as a function argument (e.g. reflect(Type<float>())).
 */
template <class T>
struct Type {};
905
906 // Type reflection
907 // E.g., reflect<float>() -> "float"
908 // Note: This strips trailing const and volatile qualifiers
909 /*! Generate a code-string for a type.
910 * \code{.cpp}reflect<float>() --> "float"\endcode
911 */
912 template <typename T>
reflect()913 inline std::string reflect() {
914 return detail::type_reflection<T>::name();
915 }
916 // Value reflection
917 // E.g., reflect(3.14f) -> "(float)3.14"
918 /*! Generate a code-string for a value.
919 * \code{.cpp}reflect(3.14f) --> "(float)3.14"\endcode
920 */
921 template <typename T>
reflect(T const & value)922 inline std::string reflect(T const& value) {
923 return "(" + reflect<T>() + ")" + detail::value_string(value);
924 }
925 // Non-type template arg reflection (implicit conversion to int64_t)
926 // E.g., reflect<7>() -> "(int64_t)7"
927 /*! Generate a code-string for an integer non-type template argument.
928 * \code{.cpp}reflect<7>() --> "(int64_t)7"\endcode
929 */
930 template <int64_t N>
reflect()931 inline std::string reflect() {
932 return reflect<NonType<int64_t, N> >();
933 }
934 // Non-type template arg reflection (explicit type)
935 // E.g., reflect<int,7>() -> "(int)7"
936 /*! Generate a code-string for a generic non-type template argument.
937 * \code{.cpp} reflect<int,7>() --> "(int)7" \endcode
938 */
939 template <typename T, T N>
reflect()940 inline std::string reflect() {
941 return reflect<NonType<T, N> >();
942 }
943 // Type reflection via value
944 // E.g., reflect(Type<float>()) -> "float"
945 /*! Generate a code-string for a type wrapped as a Type instance.
946 * \code{.cpp}reflect(Type<float>()) --> "float"\endcode
947 */
948 template <typename T>
reflect(jitify::reflection::Type<T>)949 inline std::string reflect(jitify::reflection::Type<T>) {
950 return reflect<T>();
951 }
952
953 /*! Generate a code-string for a type wrapped as an Instance instance.
954 * \code{.cpp}reflect(Instance<float>(3.1f)) --> "float"\endcode
955 * or more simply when passed to a instance_of helper
956 * \code{.cpp}reflect(instance_of(3.1f)) --> "float"\endcodei
957 * This is specifically for the case where we want to extract the run-time
958 * type, e.g., derived type, of an object pointer.
959 */
960 template <typename T>
reflect(jitify::reflection::Instance<T> & value)961 inline std::string reflect(jitify::reflection::Instance<T>& value) {
962 return detail::demangle_native_type(typeid(value.value));
963 }
964
965 // Type from value
966 // E.g., type_of(3.14f) -> Type<float>()
967 /*! Create a Type object representing a value's type.
968 * \param value The value whose type is to be captured.
969 */
970 template <typename T>
type_of(T & value)971 inline Type<T> type_of(T& value) {
972 return Type<T>();
973 }
974 /*! Create a Type object representing a value's type.
975 * \param value The const value whose type is to be captured.
976 */
977 template <typename T>
type_of(T const & value)978 inline Type<T const> type_of(T const& value) {
979 return Type<T const>();
980 }
981
// Reflect several values in one call: returns one code-string per argument,
// in argument order.
template <typename... Args>
inline std::vector<std::string> reflect_all(Args... args) {
  std::vector<std::string> reflected{reflect(args)...};
  return reflected;
}
987
reflect_list(jitify::detail::vector<std::string> const & args,std::string opener="",std::string closer="")988 inline std::string reflect_list(jitify::detail::vector<std::string> const& args,
989 std::string opener = "",
990 std::string closer = "") {
991 std::stringstream ss;
992 ss << opener;
993 for (int i = 0; i < (int)args.size(); ++i) {
994 if (i > 0) ss << ",";
995 ss << args[i];
996 }
997 ss << closer;
998 return ss.str();
999 }
1000
1001 // Template instantiation reflection
1002 // inline std::string reflect_template(std::vector<std::string> const& args) {
reflect_template(jitify::detail::vector<std::string> const & args)1003 inline std::string reflect_template(
1004 jitify::detail::vector<std::string> const& args) {
1005 // Note: The space in " >" is a WAR to avoid '>>' appearing
1006 return reflect_list(args, "<", " >");
1007 }
1008 // TODO: See if can make this evaluate completely at compile-time
1009 template <typename... Ts>
reflect_template()1010 inline std::string reflect_template() {
1011 return reflect_template({reflect<Ts>()...});
1012 // return reflect_template<sizeof...(Ts)>({reflect<Ts>()...});
1013 }
1014
1015 } // namespace reflection
1016
1017 //! \cond
1018
1019 namespace detail {
1020
// Demangles nested variable names using the PTX name mangling scheme
// (which follows the Itanium64 ABI). E.g., _ZN1a3Foo2bcE -> a::Foo::bc.
//
// Returns:
//  - `name` unchanged if it does not begin with the "_Z" mangling prefix;
//  - "" if `name` is null, or is mangled but not a nested name ("_ZN..."),
//    or is truncated/malformed/uses unsupported constructs (anything other
//    than a sequence of <length><identifier> components terminated by 'E');
//  - otherwise the components joined with "::", where components beginning
//    with "_GLOBAL" (anonymous namespaces) are printed as
//    "(anonymous namespace)".
inline std::string demangle_ptx_variable_name(const char* name) {
  // Maximum number of digits accepted in a length prefix. This keeps the
  // parsed length far below INT_MAX (real identifiers are much shorter).
  const int kMaxLengthDigits = 9;
  if (!name) return "";
  std::stringstream ss;
  const char* c = name;
  if (*c++ != '_' || *c++ != 'Z') return name;  // Non-mangled name
  if (*c++ != 'N') return "";  // Not a nested name, unsupported
  while (true) {
    // Parse identifier length. An explicit range check is used instead of
    // std::isdigit: passing a negative char to std::isdigit is undefined, and
    // the result of isdigit can depend on the current locale.
    int n = 0;
    int num_digits = 0;
    while (*c >= '0' && *c <= '9') {
      if (++num_digits > kMaxLengthDigits) return "";  // Would overflow int
      n = n * 10 + (*c - '0');
      c++;
    }
    if (!n) return "";  // Invalid or unsupported mangled name
    // Parse identifier.
    const char* c0 = c;
    while (n-- && *c) c++;
    if (!*c) return "";  // Mangled name is truncated
    std::string id(c0, c);
    // Identifiers starting with "_GLOBAL" are anonymous namespaces.
    ss << (id.substr(0, 7) == "_GLOBAL" ? "(anonymous namespace)" : id);
    // Nested name specifiers end with 'E'.
    if (*c == 'E') break;
    // There are more identifiers to come, add join token.
    ss << "::";
  }
  return ss.str();
}
1050
get_current_executable_path()1051 static const char* get_current_executable_path() {
1052 static const char* path = []() -> const char* {
1053 static char buffer[JITIFY_PATH_MAX] = {};
1054 #ifdef __linux__
1055 if (!::realpath("/proc/self/exe", buffer)) return nullptr;
1056 #elif defined(_WIN32) || defined(_WIN64)
1057 if (!GetModuleFileNameA(nullptr, buffer, JITIFY_PATH_MAX)) return nullptr;
1058 #endif
1059 return buffer;
1060 }();
1061 return path;
1062 }
1063
// Returns true if `str` ends with `suffix`; an empty suffix always matches.
inline bool endswith(const std::string& str, const std::string& suffix) {
  if (suffix.size() > str.size()) return false;
  const size_t offset = str.size() - suffix.size();
  return str.compare(offset, suffix.size(), suffix) == 0;
}
1068
1069 // Infers the JIT input type from the filename suffix. If no known suffix is
1070 // present, the filename is assumed to refer to a library, and the associated
1071 // suffix (and possibly prefix) is automatically added to the filename.
get_cuda_jit_input_type(std::string * filename)1072 inline CUjitInputType get_cuda_jit_input_type(std::string* filename) {
1073 if (endswith(*filename, ".ptx")) {
1074 return CU_JIT_INPUT_PTX;
1075 } else if (endswith(*filename, ".cubin")) {
1076 return CU_JIT_INPUT_CUBIN;
1077 } else if (endswith(*filename, ".fatbin")) {
1078 return CU_JIT_INPUT_FATBINARY;
1079 } else if (endswith(*filename,
1080 #if defined _WIN32 || defined _WIN64
1081 ".obj"
1082 #else // Linux
1083 ".o"
1084 #endif
1085 )) {
1086 return CU_JIT_INPUT_OBJECT;
1087 } else { // Assume library
1088 #if defined _WIN32 || defined _WIN64
1089 if (!endswith(*filename, ".lib")) {
1090 *filename += ".lib";
1091 }
1092 #else // Linux
1093 if (!endswith(*filename, ".a")) {
1094 *filename = "lib" + *filename + ".a";
1095 }
1096 #endif
1097 return CU_JIT_INPUT_LIBRARY;
1098 }
1099 }
1100
1101 class CUDAKernel {
1102 std::vector<std::string> _link_files;
1103 std::vector<std::string> _link_paths;
1104 CUlinkState _link_state;
1105 CUmodule _module;
1106 CUfunction _kernel;
1107 std::string _func_name;
1108 std::string _ptx;
1109 std::map<std::string, std::string> _global_map;
1110 std::vector<CUjit_option> _opts;
1111 std::vector<void*> _optvals;
1112 #ifdef JITIFY_PRINT_LINKER_LOG
1113 static const unsigned int _log_size = 8192;
1114 char _error_log[_log_size];
1115 char _info_log[_log_size];
1116 #endif
1117
cuda_safe_call(CUresult res) const1118 inline void cuda_safe_call(CUresult res) const {
1119 if (res != CUDA_SUCCESS) {
1120 const char* msg;
1121 cuGetErrorName(res, &msg);
1122 throw std::runtime_error(msg);
1123 }
1124 }
create_module(std::vector<std::string> link_files,std::vector<std::string> link_paths)1125 inline void create_module(std::vector<std::string> link_files,
1126 std::vector<std::string> link_paths) {
1127 CUresult result;
1128 #ifndef JITIFY_PRINT_LINKER_LOG
1129 // WAR since linker log does not seem to be constructed using a single call
1130 // to cuModuleLoadDataEx.
1131 if (link_files.empty()) {
1132 result =
1133 cuModuleLoadDataEx(&_module, _ptx.c_str(), (unsigned)_opts.size(),
1134 _opts.data(), _optvals.data());
1135 } else
1136 #endif
1137 {
1138 cuda_safe_call(cuLinkCreate((unsigned)_opts.size(), _opts.data(),
1139 _optvals.data(), &_link_state));
1140 cuda_safe_call(cuLinkAddData(_link_state, CU_JIT_INPUT_PTX,
1141 (void*)_ptx.c_str(), _ptx.size(),
1142 "jitified_source.ptx", 0, 0, 0));
1143 for (int i = 0; i < (int)link_files.size(); ++i) {
1144 std::string link_file = link_files[i];
1145 CUjitInputType jit_input_type;
1146 if (link_file == ".") {
1147 // Special case for linking to current executable.
1148 link_file = get_current_executable_path();
1149 jit_input_type = CU_JIT_INPUT_OBJECT;
1150 } else {
1151 // Infer based on filename.
1152 jit_input_type = get_cuda_jit_input_type(&link_file);
1153 }
1154 CUresult result = cuLinkAddFile(_link_state, jit_input_type,
1155 link_file.c_str(), 0, 0, 0);
1156 int path_num = 0;
1157 while (result == CUDA_ERROR_FILE_NOT_FOUND &&
1158 path_num < (int)link_paths.size()) {
1159 std::string filename = path_join(link_paths[path_num++], link_file);
1160 result = cuLinkAddFile(_link_state, jit_input_type, filename.c_str(),
1161 0, 0, 0);
1162 }
1163 #if JITIFY_PRINT_LINKER_LOG
1164 if (result == CUDA_ERROR_FILE_NOT_FOUND) {
1165 std::cerr << "Linker error: Device library not found: " << link_file
1166 << std::endl;
1167 } else if (result != CUDA_SUCCESS) {
1168 std::cerr << "Linker error: Failed to add file: " << link_file
1169 << std::endl;
1170 std::cerr << _error_log << std::endl;
1171 }
1172 #endif
1173 cuda_safe_call(result);
1174 }
1175 size_t cubin_size;
1176 void* cubin;
1177 result = cuLinkComplete(_link_state, &cubin, &cubin_size);
1178 if (result == CUDA_SUCCESS) {
1179 result = cuModuleLoadData(&_module, cubin);
1180 }
1181 }
1182 #ifdef JITIFY_PRINT_LINKER_LOG
1183 std::cout << "---------------------------------------" << std::endl;
1184 std::cout << "--- Linker for "
1185 << reflection::detail::demangle_cuda_symbol(_func_name.c_str())
1186 << " ---" << std::endl;
1187 std::cout << "---------------------------------------" << std::endl;
1188 std::cout << _info_log << std::endl;
1189 std::cout << std::endl;
1190 std::cout << _error_log << std::endl;
1191 std::cout << "---------------------------------------" << std::endl;
1192 #endif
1193 cuda_safe_call(result);
1194 // Allow _func_name to be empty to support cases where we want to generate
1195 // PTX containing extern symbol definitions but no kernels.
1196 if (!_func_name.empty()) {
1197 cuda_safe_call(
1198 cuModuleGetFunction(&_kernel, _module, _func_name.c_str()));
1199 }
1200 }
destroy_module()1201 inline void destroy_module() {
1202 if (_link_state) {
1203 cuda_safe_call(cuLinkDestroy(_link_state));
1204 }
1205 _link_state = 0;
1206 if (_module) {
1207 cuModuleUnload(_module);
1208 }
1209 _module = 0;
1210 }
1211
1212 // create a map of __constant__ and __device__ variables in the ptx file
1213 // mapping demangled to mangled name
create_global_variable_map()1214 inline void create_global_variable_map() {
1215 size_t pos = 0;
1216 while (pos < _ptx.size()) {
1217 pos = std::min(_ptx.find(".const .align", pos),
1218 _ptx.find(".global .align", pos));
1219 if (pos == std::string::npos) break;
1220 size_t end = _ptx.find_first_of(";=", pos);
1221 if (_ptx[end] == '=') --end;
1222 std::string line = _ptx.substr(pos, end - pos);
1223 pos = end;
1224 size_t symbol_start = line.find_last_of(" ") + 1;
1225 size_t symbol_end = line.find_last_of("[");
1226 std::string entry = line.substr(symbol_start, symbol_end - symbol_start);
1227 std::string key = detail::demangle_ptx_variable_name(entry.c_str());
1228 // Skip unsupported mangled names. E.g., a static variable defined inside
1229 // a function (such variables are not directly addressable from outside
1230 // the function, so skipping them is the correct behavior).
1231 if (key == "") continue;
1232 _global_map[key] = entry;
1233 }
1234 }
1235
set_linker_log()1236 inline void set_linker_log() {
1237 #ifdef JITIFY_PRINT_LINKER_LOG
1238 _opts.push_back(CU_JIT_INFO_LOG_BUFFER);
1239 _optvals.push_back((void*)_info_log);
1240 _opts.push_back(CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES);
1241 _optvals.push_back((void*)(long)_log_size);
1242 _opts.push_back(CU_JIT_ERROR_LOG_BUFFER);
1243 _optvals.push_back((void*)_error_log);
1244 _opts.push_back(CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES);
1245 _optvals.push_back((void*)(long)_log_size);
1246 _opts.push_back(CU_JIT_LOG_VERBOSE);
1247 _optvals.push_back((void*)1);
1248 #endif
1249 }
1250
1251 public:
CUDAKernel()1252 inline CUDAKernel() : _link_state(0), _module(0), _kernel(0) {}
1253 inline CUDAKernel(const CUDAKernel& other) = delete;
1254 inline CUDAKernel& operator=(const CUDAKernel& other) = delete;
1255 inline CUDAKernel(CUDAKernel&& other) = delete;
1256 inline CUDAKernel& operator=(CUDAKernel&& other) = delete;
CUDAKernel(const char * func_name,const char * ptx,std::vector<std::string> link_files,std::vector<std::string> link_paths,unsigned int nopts=0,CUjit_option * opts=0,void ** optvals=0)1257 inline CUDAKernel(const char* func_name, const char* ptx,
1258 std::vector<std::string> link_files,
1259 std::vector<std::string> link_paths, unsigned int nopts = 0,
1260 CUjit_option* opts = 0, void** optvals = 0)
1261 : _link_files(link_files),
1262 _link_paths(link_paths),
1263 _link_state(0),
1264 _module(0),
1265 _kernel(0),
1266 _func_name(func_name),
1267 _ptx(ptx),
1268 _opts(opts, opts + nopts),
1269 _optvals(optvals, optvals + nopts) {
1270 this->set_linker_log();
1271 this->create_module(link_files, link_paths);
1272 this->create_global_variable_map();
1273 }
1274
set(const char * func_name,const char * ptx,std::vector<std::string> link_files,std::vector<std::string> link_paths,unsigned int nopts=0,CUjit_option * opts=0,void ** optvals=0)1275 inline CUDAKernel& set(const char* func_name, const char* ptx,
1276 std::vector<std::string> link_files,
1277 std::vector<std::string> link_paths,
1278 unsigned int nopts = 0, CUjit_option* opts = 0,
1279 void** optvals = 0) {
1280 this->destroy_module();
1281 _func_name = func_name;
1282 _ptx = ptx;
1283 _link_files = link_files;
1284 _link_paths = link_paths;
1285 _opts.assign(opts, opts + nopts);
1286 _optvals.assign(optvals, optvals + nopts);
1287 this->set_linker_log();
1288 this->create_module(link_files, link_paths);
1289 this->create_global_variable_map();
1290 return *this;
1291 }
~CUDAKernel()1292 inline ~CUDAKernel() { this->destroy_module(); }
operator CUfunction() const1293 inline operator CUfunction() const { return _kernel; }
1294
launch(dim3 grid,dim3 block,unsigned int smem,CUstream stream,std::vector<void * > arg_ptrs) const1295 inline CUresult launch(dim3 grid, dim3 block, unsigned int smem,
1296 CUstream stream, std::vector<void*> arg_ptrs) const {
1297 return cuLaunchKernel(_kernel, grid.x, grid.y, grid.z, block.x, block.y,
1298 block.z, smem, stream, arg_ptrs.data(), NULL);
1299 }
1300
get_global_ptr(const char * name,size_t * size=nullptr) const1301 inline CUdeviceptr get_global_ptr(const char* name,
1302 size_t* size = nullptr) const {
1303 CUdeviceptr global_ptr = 0;
1304 auto global = _global_map.find(name);
1305 if (global != _global_map.end()) {
1306 cuda_safe_call(cuModuleGetGlobal(&global_ptr, size, _module,
1307 global->second.c_str()));
1308 } else {
1309 throw std::runtime_error(std::string("failed to look up global ") + name);
1310 }
1311 return global_ptr;
1312 }
1313
1314 template <typename T>
get_global_data(const char * name,T * data,size_t count,CUstream stream=0) const1315 inline CUresult get_global_data(const char* name, T* data, size_t count,
1316 CUstream stream = 0) const {
1317 size_t size_bytes;
1318 CUdeviceptr ptr = get_global_ptr(name, &size_bytes);
1319 size_t given_size_bytes = count * sizeof(T);
1320 if (given_size_bytes != size_bytes) {
1321 throw std::runtime_error(
1322 std::string("Value for global variable ") + name +
1323 " has wrong size: got " + std::to_string(given_size_bytes) +
1324 " bytes, expected " + std::to_string(size_bytes));
1325 }
1326 return cuMemcpyDtoH(data, ptr, size_bytes);
1327 }
1328
1329 template <typename T>
set_global_data(const char * name,const T * data,size_t count,CUstream stream=0) const1330 inline CUresult set_global_data(const char* name, const T* data, size_t count,
1331 CUstream stream = 0) const {
1332 size_t size_bytes;
1333 CUdeviceptr ptr = get_global_ptr(name, &size_bytes);
1334 size_t given_size_bytes = count * sizeof(T);
1335 if (given_size_bytes != size_bytes) {
1336 throw std::runtime_error(
1337 std::string("Value for global variable ") + name +
1338 " has wrong size: got " + std::to_string(given_size_bytes) +
1339 " bytes, expected " + std::to_string(size_bytes));
1340 }
1341 return cuMemcpyHtoD(ptr, data, size_bytes);
1342 }
1343
function_name() const1344 const std::string& function_name() const { return _func_name; }
ptx() const1345 const std::string& ptx() const { return _ptx; }
link_files() const1346 const std::vector<std::string>& link_files() const { return _link_files; }
link_paths() const1347 const std::vector<std::string>& link_paths() const { return _link_paths; }
1348 };
1349
// Source text of a jit-safe "preinclude" header (NOTE(review): presumably
// injected ahead of user sources; confirm at the use site). It holds
// workarounds (WARs) for Thrust and CUB:
//  - identifies the compiler as GCC 4 (__GNUC__), because Thrust supports
//    only gnuc, clang or msvc;
//  - stubs out THRUST_STATIC_ASSERT;
//  - redefines __host__ as empty;
//  - defines away try/catch so that exception-using code still parses.
static const char* jitsafe_header_preinclude_h = R"(
//// WAR for Thrust (which appears to have forgotten to include this in result_of_adaptable_function.h
//#include <type_traits>

//// WAR for Thrust (which appear to have forgotten to include this in error_code.h)
//#include <string>

// WAR for Thrust (which only supports gnuc, clang or msvc)
#define __GNUC__ 4

// WAR for generics/shfl.h
#define THRUST_STATIC_ASSERT(x)

// WAR for CUB
#ifdef __host__
#undef __host__
#endif
#define __host__

// WAR to allow exceptions to be parsed
#define try
#define catch(...)
)";
1373
1374
// Jit-safe replacement for <float.h>/<cfloat>: IEEE-754 single- and
// double-precision characteristics. Every macro must expand to a bare
// expression (no trailing ';'), so it can be used inside larger expressions
// such as `x < FLT_MIN`.
static const char* jitsafe_header_float_h = R"(
#pragma once

#define FLT_RADIX 2
#define FLT_MANT_DIG 24
#define DBL_MANT_DIG 53
#define FLT_DIG 6
#define DBL_DIG 15
#define FLT_MIN_EXP -125
#define DBL_MIN_EXP -1021
#define FLT_MIN_10_EXP -37
#define DBL_MIN_10_EXP -307
#define FLT_MAX_EXP 128
#define DBL_MAX_EXP 1024
#define FLT_MAX_10_EXP 38
#define DBL_MAX_10_EXP 308
#define FLT_MAX 3.4028234e38f
#define DBL_MAX 1.7976931348623157e308
#define FLT_EPSILON 1.19209289e-7f
#define DBL_EPSILON 2.220440492503130e-16
#define FLT_MIN 1.1754943e-38f
#define DBL_MIN 2.2250738585072013e-308
#define FLT_ROUNDS 1
#if defined __cplusplus && __cplusplus >= 201103L
#define FLT_EVAL_METHOD 0
#define DECIMAL_DIG 21
#endif
)";
1403
// Jit-safe replacement for <limits.h>/<climits>: integer range macros.
// __WORDSIZE selects the width of `long`. Windows is LLP64 (32-bit long).
// Elsewhere, 64-bit is chosen for x86_64 (non-x32), and also whenever the
// compiler predefines __LP64__/_LP64. That covers other LP64 hosts such as
// aarch64 and ppc64le, which previously got a 32-bit LONG_MAX/ULONG_MAX even
// though long is 64 bits there.
static const char* jitsafe_header_limits_h = R"(
#pragma once

#if defined _WIN32 || defined _WIN64
#define __WORDSIZE 32
#else
#if (defined __x86_64__ && !defined __ILP32__) || defined __LP64__ || defined _LP64
#define __WORDSIZE 64
#else
#define __WORDSIZE 32
#endif
#endif
#define MB_LEN_MAX 16
#define CHAR_BIT 8
#define SCHAR_MIN (-128)
#define SCHAR_MAX 127
#define UCHAR_MAX 255
enum {
_JITIFY_CHAR_IS_UNSIGNED = (char)-1 >= 0,
CHAR_MIN = _JITIFY_CHAR_IS_UNSIGNED ? 0 : SCHAR_MIN,
CHAR_MAX = _JITIFY_CHAR_IS_UNSIGNED ? UCHAR_MAX : SCHAR_MAX,
};
#define SHRT_MIN (-32768)
#define SHRT_MAX 32767
#define USHRT_MAX 65535
#define INT_MIN (-INT_MAX - 1)
#define INT_MAX 2147483647
#define UINT_MAX 4294967295U
#if __WORDSIZE == 64
# define LONG_MAX 9223372036854775807L
#else
# define LONG_MAX 2147483647L
#endif
#define LONG_MIN (-LONG_MAX - 1L)
#if __WORDSIZE == 64
#define ULONG_MAX 18446744073709551615UL
#else
#define ULONG_MAX 4294967295UL
#endif
#define LLONG_MAX 9223372036854775807LL
#define LLONG_MIN (-LLONG_MAX - 1LL)
#define ULLONG_MAX 18446744073709551615ULL
)";
1447
// Jit-safe stand-in for <iterator>. It provides only the five iterator
// category tags and iterator_traits: the primary template (reading the
// iterator's nested typedefs) plus specializations for T* and T const*.
// Everything is made visible both in namespace std and globally.
// NOTE(review): uses ptrdiff_t without declaring it here; this assumes it is
// already available in the compilation (confirm).
static const char* jitsafe_header_iterator = R"(
#pragma once

namespace __jitify_iterator_ns {
struct output_iterator_tag {};
struct input_iterator_tag {};
struct forward_iterator_tag {};
struct bidirectional_iterator_tag {};
struct random_access_iterator_tag {};
template<class Iterator>
struct iterator_traits {
typedef typename Iterator::iterator_category iterator_category;
typedef typename Iterator::value_type value_type;
typedef typename Iterator::difference_type difference_type;
typedef typename Iterator::pointer pointer;
typedef typename Iterator::reference reference;
};
template<class T>
struct iterator_traits<T*> {
typedef random_access_iterator_tag iterator_category;
typedef T value_type;
typedef ptrdiff_t difference_type;
typedef T* pointer;
typedef T& reference;
};
template<class T>
struct iterator_traits<T const*> {
typedef random_access_iterator_tag iterator_category;
typedef T value_type;
typedef ptrdiff_t difference_type;
typedef T const* pointer;
typedef T const& reference;
};
} // namespace __jitify_iterator_ns
namespace std { using namespace __jitify_iterator_ns; }
using namespace __jitify_iterator_ns;
)";
1485
// TODO: This is incomplete; need floating point limits
// Joe Eaton: added IEEE float and double types, none of the smaller types
// using type specific structs since we can't template on floats.
//
// Jit-safe replacement for <limits>. It provides std::numeric_limits for the
// built-in integer types (via IntegerLimits) and for float and double. It
// also provides pass-through specializations for const/volatile-qualified
// types, as the standard requires. Under C++11, min()/max()/lowest() are
// constexpr and noexcept, as in the real std::numeric_limits, so they can be
// used in constant expressions.
static const char* jitsafe_header_limits = R"(
#pragma once
#include <climits>
#include <cfloat>
// TODO: epsilon(), infinity(), etc
namespace __jitify_detail {
#if __cplusplus >= 201103L
#define JITIFY_CXX11_CONSTEXPR constexpr
#define JITIFY_CXX11_NOEXCEPT noexcept
#else
#define JITIFY_CXX11_CONSTEXPR
#define JITIFY_CXX11_NOEXCEPT
#endif

struct FloatLimits {
#if __cplusplus >= 201103L
static JITIFY_CXX11_CONSTEXPR inline __host__ __device__
float lowest() JITIFY_CXX11_NOEXCEPT { return -FLT_MAX;}
static JITIFY_CXX11_CONSTEXPR inline __host__ __device__
float min() JITIFY_CXX11_NOEXCEPT { return FLT_MIN; }
static JITIFY_CXX11_CONSTEXPR inline __host__ __device__
float max() JITIFY_CXX11_NOEXCEPT { return FLT_MAX; }
#endif // __cplusplus >= 201103L
enum {
is_specialized = true,
is_signed = true,
is_integer = false,
is_exact = false,
has_infinity = true,
has_quiet_NaN = true,
has_signaling_NaN = true,
has_denorm = 1,
has_denorm_loss = true,
round_style = 1,
is_iec559 = true,
is_bounded = true,
is_modulo = false,
digits = 24,
digits10 = 6,
max_digits10 = 9,
radix = 2,
min_exponent = -125,
min_exponent10 = -37,
max_exponent = 128,
max_exponent10 = 38,
tinyness_before = false,
traps = false
};
};
struct DoubleLimits {
#if __cplusplus >= 201103L
static JITIFY_CXX11_CONSTEXPR inline __host__ __device__
double lowest() JITIFY_CXX11_NOEXCEPT { return -DBL_MAX; }
static JITIFY_CXX11_CONSTEXPR inline __host__ __device__
double min() JITIFY_CXX11_NOEXCEPT { return DBL_MIN; }
static JITIFY_CXX11_CONSTEXPR inline __host__ __device__
double max() JITIFY_CXX11_NOEXCEPT { return DBL_MAX; }
#endif // __cplusplus >= 201103L
enum {
is_specialized = true,
is_signed = true,
is_integer = false,
is_exact = false,
has_infinity = true,
has_quiet_NaN = true,
has_signaling_NaN = true,
has_denorm = 1,
has_denorm_loss = true,
round_style = 1,
is_iec559 = true,
is_bounded = true,
is_modulo = false,
digits = 53,
digits10 = 15,
max_digits10 = 17,
radix = 2,
min_exponent = -1021,
min_exponent10 = -307,
max_exponent = 1024,
max_exponent10 = 308,
tinyness_before = false,
traps = false
};
};
template<class T, T Min, T Max, int Digits=-1>
struct IntegerLimits {
static JITIFY_CXX11_CONSTEXPR inline __host__ __device__ T min() JITIFY_CXX11_NOEXCEPT { return Min; }
static JITIFY_CXX11_CONSTEXPR inline __host__ __device__ T max() JITIFY_CXX11_NOEXCEPT { return Max; }
#if __cplusplus >= 201103L
static constexpr inline __host__ __device__ T lowest() noexcept {
return Min;
}
#endif // __cplusplus >= 201103L
enum {
is_specialized = true,
digits = (Digits == -1) ? (int)(sizeof(T)*8 - (Min != 0)) : Digits,
digits10 = (digits * 30103) / 100000,
is_signed = ((T)(-1)<0),
is_integer = true,
is_exact = true,
radix = 2,
is_bounded = true,
is_modulo = false
};
};
} // namespace __jitify_detail
namespace std { using namespace __jitify_detail; }
namespace __jitify_limits_ns {
template<typename T> struct numeric_limits {
enum { is_specialized = false };
};
template<> struct numeric_limits<bool> : public
__jitify_detail::IntegerLimits<bool, false, true,1> {};
template<> struct numeric_limits<char> : public
__jitify_detail::IntegerLimits<char, CHAR_MIN, CHAR_MAX>
{};
template<> struct numeric_limits<signed char> : public
__jitify_detail::IntegerLimits<signed char, SCHAR_MIN,SCHAR_MAX>
{};
template<> struct numeric_limits<unsigned char> : public
__jitify_detail::IntegerLimits<unsigned char, 0, UCHAR_MAX>
{};
template<> struct numeric_limits<wchar_t> : public
__jitify_detail::IntegerLimits<wchar_t, INT_MIN, INT_MAX> {};
template<> struct numeric_limits<short> : public
__jitify_detail::IntegerLimits<short, SHRT_MIN, SHRT_MAX>
{};
template<> struct numeric_limits<unsigned short> : public
__jitify_detail::IntegerLimits<unsigned short, 0, USHRT_MAX>
{};
template<> struct numeric_limits<int> : public
__jitify_detail::IntegerLimits<int, INT_MIN, INT_MAX> {};
template<> struct numeric_limits<unsigned int> : public
__jitify_detail::IntegerLimits<unsigned int, 0, UINT_MAX>
{};
template<> struct numeric_limits<long> : public
__jitify_detail::IntegerLimits<long, LONG_MIN, LONG_MAX>
{};
template<> struct numeric_limits<unsigned long> : public
__jitify_detail::IntegerLimits<unsigned long, 0, ULONG_MAX>
{};
template<> struct numeric_limits<long long> : public
__jitify_detail::IntegerLimits<long long, LLONG_MIN,LLONG_MAX>
{};
template<> struct numeric_limits<unsigned long long> : public
__jitify_detail::IntegerLimits<unsigned long long,0, ULLONG_MAX>
{};
//template<typename T> struct numeric_limits { static const bool
//is_signed = ((T)(-1)<0); };
template<> struct numeric_limits<float> : public
__jitify_detail::FloatLimits
{};
template<> struct numeric_limits<double> : public
__jitify_detail::DoubleLimits
{};
// cv-qualified types share the limits of their unqualified type.
template<typename T> struct numeric_limits<const T> : public numeric_limits<T> {};
template<typename T> struct numeric_limits<volatile T> : public numeric_limits<T> {};
template<typename T> struct numeric_limits<const volatile T> : public numeric_limits<T> {};
} // namespace __jitify_limits_ns
namespace std { using namespace __jitify_limits_ns; }
using namespace __jitify_limits_ns;
)";
1648
// TODO: This is highly incomplete
// Jit-safe subset of <type_traits> (C++11 and later only). The
// classification traits (is_floating_point, is_integral, is_signed,
// is_unsigned) also accept const/volatile-qualified types, as the standard
// requires. is_integral also covers wchar_t, char16_t and char32_t.
static const char* jitsafe_header_type_traits = R"(
#pragma once
#if __cplusplus >= 201103L
namespace __jitify_type_traits_ns {

template<bool B, class T = void> struct enable_if {};
template<class T> struct enable_if<true, T> { typedef T type; };
#if __cplusplus >= 201402L
template< bool B, class T = void > using enable_if_t = typename enable_if<B,T>::type;
#endif

struct true_type {
enum { value = true };
operator bool() const { return true; }
};
struct false_type {
enum { value = false };
operator bool() const { return false; }
};

template<typename T> struct is_floating_point : false_type {};
template<> struct is_floating_point<float> : true_type {};
template<> struct is_floating_point<double> : true_type {};
template<> struct is_floating_point<long double> : true_type {};

template<class T> struct is_integral : false_type {};
template<> struct is_integral<bool> : true_type {};
template<> struct is_integral<char> : true_type {};
template<> struct is_integral<signed char> : true_type {};
template<> struct is_integral<unsigned char> : true_type {};
template<> struct is_integral<wchar_t> : true_type {};
template<> struct is_integral<char16_t> : true_type {};
template<> struct is_integral<char32_t> : true_type {};
template<> struct is_integral<short> : true_type {};
template<> struct is_integral<unsigned short> : true_type {};
template<> struct is_integral<int> : true_type {};
template<> struct is_integral<unsigned int> : true_type {};
template<> struct is_integral<long> : true_type {};
template<> struct is_integral<unsigned long> : true_type {};
template<> struct is_integral<long long> : true_type {};
template<> struct is_integral<unsigned long long> : true_type {};

template<typename T> struct is_signed : false_type {};
template<> struct is_signed<float> : true_type {};
template<> struct is_signed<double> : true_type {};
template<> struct is_signed<long double> : true_type {};
template<> struct is_signed<signed char> : true_type {};
template<> struct is_signed<short> : true_type {};
template<> struct is_signed<int> : true_type {};
template<> struct is_signed<long> : true_type {};
template<> struct is_signed<long long> : true_type {};

template<typename T> struct is_unsigned : false_type {};
template<> struct is_unsigned<unsigned char> : true_type {};
template<> struct is_unsigned<unsigned short> : true_type {};
template<> struct is_unsigned<unsigned int> : true_type {};
template<> struct is_unsigned<unsigned long> : true_type {};
template<> struct is_unsigned<unsigned long long> : true_type {};

// cv-qualified types are classified like their unqualified counterparts.
template<typename T> struct is_floating_point<const T> : is_floating_point<T> {};
template<typename T> struct is_floating_point<volatile T> : is_floating_point<T> {};
template<typename T> struct is_floating_point<const volatile T> : is_floating_point<T> {};
template<typename T> struct is_integral<const T> : is_integral<T> {};
template<typename T> struct is_integral<volatile T> : is_integral<T> {};
template<typename T> struct is_integral<const volatile T> : is_integral<T> {};
template<typename T> struct is_signed<const T> : is_signed<T> {};
template<typename T> struct is_signed<volatile T> : is_signed<T> {};
template<typename T> struct is_signed<const volatile T> : is_signed<T> {};
template<typename T> struct is_unsigned<const T> : is_unsigned<T> {};
template<typename T> struct is_unsigned<volatile T> : is_unsigned<T> {};
template<typename T> struct is_unsigned<const volatile T> : is_unsigned<T> {};

template<typename T, typename U> struct is_same : false_type {};
template<typename T> struct is_same<T,T> : true_type {};

template<class T> struct is_array : false_type {};
template<class T> struct is_array<T[]> : true_type {};
template<class T, size_t N> struct is_array<T[N]> : true_type {};

//partial implementation only of is_function
template<class> struct is_function : false_type { };
template<class Ret, class... Args> struct is_function<Ret(Args...)> : true_type {}; //regular
template<class Ret, class... Args> struct is_function<Ret(Args......)> : true_type {}; // variadic

template<class> struct result_of;
template<class F, typename... Args>
struct result_of<F(Args...)> {
// TODO: This is a hack; a proper implem is quite complicated.
typedef typename F::result_type type;
};

template <class T> struct remove_reference { typedef T type; };
template <class T> struct remove_reference<T&> { typedef T type; };
template <class T> struct remove_reference<T&&> { typedef T type; };
#if __cplusplus >= 201402L
template< class T > using remove_reference_t = typename remove_reference<T>::type;
#endif

template<class T> struct remove_extent { typedef T type; };
template<class T> struct remove_extent<T[]> { typedef T type; };
template<class T, size_t N> struct remove_extent<T[N]> { typedef T type; };
#if __cplusplus >= 201402L
template< class T > using remove_extent_t = typename remove_extent<T>::type;
#endif

template< class T > struct remove_const { typedef T type; };
template< class T > struct remove_const<const T> { typedef T type; };
template< class T > struct remove_volatile { typedef T type; };
template< class T > struct remove_volatile<volatile T> { typedef T type; };
template< class T > struct remove_cv { typedef typename remove_volatile<typename remove_const<T>::type>::type type; };
#if __cplusplus >= 201402L
template< class T > using remove_cv_t = typename remove_cv<T>::type;
template< class T > using remove_const_t = typename remove_const<T>::type;
template< class T > using remove_volatile_t = typename remove_volatile<T>::type;
#endif

template<bool B, class T, class F> struct conditional { typedef T type; };
template<class T, class F> struct conditional<false, T, F> { typedef F type; };
#if __cplusplus >= 201402L
template< bool B, class T, class F > using conditional_t = typename conditional<B,T,F>::type;
#endif

namespace __jitify_detail {
template< class T, bool is_function_type = false > struct add_pointer { using type = typename remove_reference<T>::type*; };
template< class T > struct add_pointer<T, true> { using type = T; };
template< class T, class... Args > struct add_pointer<T(Args...), true> { using type = T(*)(Args...); };
template< class T, class... Args > struct add_pointer<T(Args..., ...), true> { using type = T(*)(Args..., ...); };
}
template< class T > struct add_pointer : __jitify_detail::add_pointer<T, is_function<T>::value> {};
#if __cplusplus >= 201402L
template< class T > using add_pointer_t = typename add_pointer<T>::type;
#endif

template< class T > struct decay {
private:
typedef typename remove_reference<T>::type U;
public:
typedef typename conditional<is_array<U>::value, typename remove_extent<U>::type*,
typename conditional<is_function<U>::value,typename add_pointer<U>::type,typename remove_cv<U>::type
>::type>::type type;
};
#if __cplusplus >= 201402L
template< class T > using decay_t = typename decay<T>::type;
#endif

} // namespace __jtiify_type_traits_ns
namespace std { using namespace __jitify_type_traits_ns; }
using namespace __jitify_type_traits_ns;
#endif // c++11
)";
1784
// Jit-safe replacement for <stdint.h>/<cstdint>.
// The exact-width, fast and least integer types are typedef'd to fixed
// built-in types, and every limit macro is spelled in terms of the <climits>
// constants. Everything lives in __jitify_stdint_ns and is exposed through
// both namespace std and the global namespace.
// Note: intptr_t/uintptr_t are typedef'd to (unsigned) long and SIZE_MAX is
// UINT64_MAX, i.e. a 64-bit size_t is assumed.
// TODO: INTn_C/UINTn_C literal macros, WCHAR_*, WINT_*, SIG_ATOMIC_* limits.
static const char* jitsafe_header_stdint_h =
    "#pragma once\n"
    "#include <climits>\n"
    "namespace __jitify_stdint_ns {\n"
    "typedef signed char int8_t;\n"
    "typedef signed short int16_t;\n"
    "typedef signed int int32_t;\n"
    "typedef signed long long int64_t;\n"
    "typedef signed char int_fast8_t;\n"
    "typedef signed short int_fast16_t;\n"
    "typedef signed int int_fast32_t;\n"
    "typedef signed long long int_fast64_t;\n"
    "typedef signed char int_least8_t;\n"
    "typedef signed short int_least16_t;\n"
    "typedef signed int int_least32_t;\n"
    "typedef signed long long int_least64_t;\n"
    "typedef signed long long intmax_t;\n"
    "typedef signed long intptr_t; //optional\n"
    "typedef unsigned char uint8_t;\n"
    "typedef unsigned short uint16_t;\n"
    "typedef unsigned int uint32_t;\n"
    "typedef unsigned long long uint64_t;\n"
    "typedef unsigned char uint_fast8_t;\n"
    "typedef unsigned short uint_fast16_t;\n"
    "typedef unsigned int uint_fast32_t;\n"
    "typedef unsigned long long uint_fast64_t;\n"
    "typedef unsigned char uint_least8_t;\n"
    "typedef unsigned short uint_least16_t;\n"
    "typedef unsigned int uint_least32_t;\n"
    "typedef unsigned long long uint_least64_t;\n"
    "typedef unsigned long long uintmax_t;\n"
    "typedef unsigned long uintptr_t; //optional\n"
    "#define INT8_MIN SCHAR_MIN\n"
    "#define INT16_MIN SHRT_MIN\n"
    "#define INT32_MIN INT_MIN\n"
    "#define INT64_MIN LLONG_MIN\n"
    "#define INT8_MAX SCHAR_MAX\n"
    "#define INT16_MAX SHRT_MAX\n"
    "#define INT32_MAX INT_MAX\n"
    "#define INT64_MAX LLONG_MAX\n"
    "#define UINT8_MAX UCHAR_MAX\n"
    "#define UINT16_MAX USHRT_MAX\n"
    "#define UINT32_MAX UINT_MAX\n"
    "#define UINT64_MAX ULLONG_MAX\n"
    "#define INTPTR_MIN LONG_MIN\n"
    "#define INTMAX_MIN LLONG_MIN\n"
    "#define INTPTR_MAX LONG_MAX\n"
    "#define INTMAX_MAX LLONG_MAX\n"
    "#define UINTPTR_MAX ULONG_MAX\n"
    "#define UINTMAX_MAX ULLONG_MAX\n"
    "#define PTRDIFF_MIN INTPTR_MIN\n"
    "#define PTRDIFF_MAX INTPTR_MAX\n"
    "#define SIZE_MAX UINT64_MAX\n"
    // Limits of the fast/least types; they mirror the typedefs above.
    "#define INT_FAST8_MIN SCHAR_MIN\n"
    "#define INT_FAST16_MIN SHRT_MIN\n"
    "#define INT_FAST32_MIN INT_MIN\n"
    "#define INT_FAST64_MIN LLONG_MIN\n"
    "#define INT_FAST8_MAX SCHAR_MAX\n"
    "#define INT_FAST16_MAX SHRT_MAX\n"
    "#define INT_FAST32_MAX INT_MAX\n"
    "#define INT_FAST64_MAX LLONG_MAX\n"
    "#define UINT_FAST8_MAX UCHAR_MAX\n"
    "#define UINT_FAST16_MAX USHRT_MAX\n"
    "#define UINT_FAST32_MAX UINT_MAX\n"
    "#define UINT_FAST64_MAX ULLONG_MAX\n"
    "#define INT_LEAST8_MIN SCHAR_MIN\n"
    "#define INT_LEAST16_MIN SHRT_MIN\n"
    "#define INT_LEAST32_MIN INT_MIN\n"
    "#define INT_LEAST64_MIN LLONG_MIN\n"
    "#define INT_LEAST8_MAX SCHAR_MAX\n"
    "#define INT_LEAST16_MAX SHRT_MAX\n"
    "#define INT_LEAST32_MAX INT_MAX\n"
    "#define INT_LEAST64_MAX LLONG_MAX\n"
    "#define UINT_LEAST8_MAX UCHAR_MAX\n"
    "#define UINT_LEAST16_MAX USHRT_MAX\n"
    "#define UINT_LEAST32_MAX UINT_MAX\n"
    "#define UINT_LEAST64_MAX ULLONG_MAX\n"
    "} // namespace __jitify_stdint_ns\n"
    "namespace std { using namespace __jitify_stdint_ns; }\n"
    "using namespace __jitify_stdint_ns;\n";
1842
// Jit-safe replacement for <stddef.h>/<cstddef>.
// Defines nullptr_t and max_align_t (C++11 and later) and std::byte (C++17 and
// later) in __jitify_stddef_ns. It then exposes them, together with NVRTC's
// built-in ::size_t and ::ptrdiff_t, in namespace std, and also exposes them
// in the global namespace.
// The max_align_t definition is chosen by _MSC_VER / __APPLE__ checks; the
// fallback mirrors GCC's definition.
// NOTE(review): presumably those macros mirror the host toolchain under
// NVRTC; confirm.
// TODO: offsetof
static const char* jitsafe_header_stddef_h =
"#pragma once\n"
"#include <climits>\n"
"namespace __jitify_stddef_ns {\n"
"#if __cplusplus >= 201103L\n"
"typedef decltype(nullptr) nullptr_t;\n"
"#if defined(_MSC_VER)\n"
" typedef double max_align_t;\n"
"#elif defined(__APPLE__)\n"
" typedef long double max_align_t;\n"
"#else\n"
" // Define max_align_t to match the GCC definition.\n"
" typedef struct {\n"
" long long __jitify_max_align_nonce1\n"
" __attribute__((__aligned__(__alignof__(long long))));\n"
" long double __jitify_max_align_nonce2\n"
" __attribute__((__aligned__(__alignof__(long double))));\n"
" } max_align_t;\n"
"#endif\n"
"#endif // __cplusplus >= 201103L\n"
"#if __cplusplus >= 201703L\n"
"enum class byte : unsigned char {};\n"
"#endif // __cplusplus >= 201703L\n"
"} // namespace __jitify_stddef_ns\n"
"namespace std {\n"
" // NVRTC provides built-in definitions of ::size_t and ::ptrdiff_t.\n"
" using ::size_t;\n"
" using ::ptrdiff_t;\n"
" using namespace __jitify_stddef_ns;\n"
"} // namespace std\n"
"using namespace __jitify_stddef_ns;\n";
1875
// Jit-safe replacement for <stdlib.h>/<cstdlib>. It declares nothing itself;
// it only pulls in the jit-safe <stddef.h>.
static const char* jitsafe_header_stdlib_h =
"#pragma once\n"
"#include <stddef.h>\n";
// Jit-safe replacement for <stdio.h>/<cstdio>.
// FILE is #defined as int so that the fflush and fprintf declarations parse.
// Only these two functions are declared, and no definitions are provided here.
static const char* jitsafe_header_stdio_h =
"#pragma once\n"
"#include <stddef.h>\n"
"#define FILE int\n"
"int fflush ( FILE * stream );\n"
"int fprintf ( FILE * stream, const char * format, ... );\n";
1885
// Jit-safe replacement for <string.h>: declares strcpy, strcmp and strerror in
// the global namespace (declarations only, no definitions).
static const char* jitsafe_header_string_h =
"#pragma once\n"
"char* strcpy ( char * destination, const char * source );\n"
"int strcmp ( const char * str1, const char * str2 );\n"
"char* strerror( int errnum );\n";
1891
// Jit-safe replacement for <cstring>.
// Declares strcpy, strcmp and strerror (declarations only) in
// __jitify_cstring_ns. Those names are then made visible in both namespace std
// and the global namespace.
static const char* jitsafe_header_cstring =
"#pragma once\n"
"\n"
"namespace __jitify_cstring_ns {\n"
"char* strcpy ( char * destination, const char * source );\n"
"int strcmp ( const char * str1, const char * str2 );\n"
"char* strerror( int errnum );\n"
"} // namespace __jitify_cstring_ns\n"
"namespace std { using namespace __jitify_cstring_ns; }\n"
"using namespace __jitify_cstring_ns;\n";
1902
// HACK (workaround for CUB): stub <iostream> that only pulls in the stub
// <ostream> and <istream> headers; it declares nothing of its own.
static const char* jitsafe_header_iostream =
"#pragma once\n"
"#include <ostream>\n"
"#include <istream>\n";
// HACK (workaround for Thrust): stub <ostream>.
// basic_ostream is an empty class template, and ostream, endl and the
// operator<< overloads are declared but not defined here. This lets code name
// and reference them.
// NOTE(review): presumably no code path actually streams through them on the
// device; confirm.
static const char* jitsafe_header_ostream =
"#pragma once\n"
"\n"
"namespace __jitify_ostream_ns {\n"
"template<class CharT,class Traits=void>\n" // Traits: void stands in for std::char_traits<CharT>.
"struct basic_ostream {\n"
"};\n"
"typedef basic_ostream<char> ostream;\n"
"ostream& endl(ostream& os);\n"
"ostream& operator<<( ostream&, ostream& (*f)( ostream& ) );\n"
"template< class CharT, class Traits > basic_ostream<CharT, Traits>& endl( "
"basic_ostream<CharT, Traits>& os );\n"
"template< class CharT, class Traits > basic_ostream<CharT, Traits>& "
"operator<<( basic_ostream<CharT,Traits>& os, const char* c );\n"
"#if __cplusplus >= 201103L\n"
"template< class CharT, class Traits, class T > basic_ostream<CharT, "
"Traits>& operator<<( basic_ostream<CharT,Traits>&& os, const T& value );\n"
"#endif // __cplusplus >= 201103L\n"
"} // namespace __jitify_ostream_ns\n"
"namespace std { using namespace __jitify_ostream_ns; }\n"
"using namespace __jitify_ostream_ns;\n";
1931
// Stub <istream>: an empty basic_istream class template and the istream typedef,
// visible in both namespace std and the global namespace.
static const char* jitsafe_header_istream =
"#pragma once\n"
"\n"
"namespace __jitify_istream_ns {\n"
"template<class CharT,class Traits=void>\n" // Traits: void stands in for std::char_traits<CharT>.
"struct basic_istream {\n"
"};\n"
"typedef basic_istream<char> istream;\n"
"} // namespace __jitify_istream_ns\n"
"namespace std { using namespace __jitify_istream_ns; }\n"
"using namespace __jitify_istream_ns;\n";
1944
// Stub <sstream>: only pulls in the stub <ostream> and <istream> headers; no
// string-stream types are declared.
static const char* jitsafe_header_sstream =
"#pragma once\n"
"#include <ostream>\n"
"#include <istream>\n";
1949
// Jit-safe replacement for <utility>: a minimal std::pair and std::make_pair.
// The default constructor value-initializes both members, as std::pair does.
// For example, pair<int,int>() holds zeros rather than indeterminate values.
static const char* jitsafe_header_utility =
"#pragma once\n"
"namespace __jitify_utility_ns {\n"
"template<class T1, class T2>\n"
"struct pair {\n"
" T1 first;\n"
" T2 second;\n"
" inline pair() : first(), second() {}\n"
" inline pair(T1 const& first_, T2 const& second_)\n"
" : first(first_), second(second_) {}\n"
" // TODO: Standard includes many more constructors...\n"
" // TODO: Comparison operators\n"
"};\n"
"template<class T1, class T2>\n"
"pair<T1,T2> make_pair(T1 const& first, T2 const& second) {\n"
" return pair<T1,T2>(first, second);\n"
"}\n"
"} // namespace __jitify_utility_ns\n"
"namespace std { using namespace __jitify_utility_ns; }\n"
"using namespace __jitify_utility_ns;\n";
1970
// TODO: incomplete
// Stub <vector>: an empty vector class template, visible in both namespace std
// and the global namespace. The type can be named but has no members.
// Allocator defaults to void rather than std::allocator.
static const char* jitsafe_header_vector =
"#pragma once\n"
"namespace __jitify_vector_ns {\n"
"template<class T, class Allocator=void>\n" // Allocator: void stands in for std::allocator<T>.
"struct vector {\n"
"};\n"
"} // namespace __jitify_vector_ns\n"
"namespace std { using namespace __jitify_vector_ns; }\n"
"using namespace __jitify_vector_ns;\n";
1981
// TODO: incomplete
// Stub <string>. basic_string declares the following, with no definitions:
//   - a default constructor and a const CharT* constructor,
//   - c_str() and empty(),
//   - operator+= taking a const char* or another basic_string.
// Traits and Allocator default to void, and string is basic_string<char>.
static const char* jitsafe_header_string =
"#pragma once\n"
"namespace __jitify_string_ns {\n"
"template<class CharT,class Traits=void,class Allocator=void>\n"
"struct basic_string {\n"
"basic_string();\n"
"basic_string( const CharT* s );\n" // The standard overload also takes
// "const Allocator& alloc = Allocator()"; omitted here.
"const CharT* c_str() const;\n"
"bool empty() const;\n"
"void operator+=(const char *);\n"
"void operator+=(const basic_string &);\n"
"};\n"
"typedef basic_string<char> string;\n"
"} // namespace __jitify_string_ns\n"
"namespace std { using namespace __jitify_string_ns; }\n"
"using namespace __jitify_string_ns;\n";
2000
// TODO: incomplete
// Stub <stdexcept>: declares std::runtime_error with string and C-string
// constructors and a virtual what(), with no definitions.
// <string> is included because the first constructor names std::string.
// Without it, including <stdexcept> alone would not compile.
static const char* jitsafe_header_stdexcept =
"#pragma once\n"
"#include <string>\n"
"namespace __jitify_stdexcept_ns {\n"
"struct runtime_error {\n"
"explicit runtime_error( const std::string& what_arg );\n"
"explicit runtime_error( const char* what_arg );\n"
"virtual const char* what() const;\n"
"};\n"
"} // namespace __jitify_stdexcept_ns\n"
"namespace std { using namespace __jitify_stdexcept_ns; }\n"
"using namespace __jitify_stdexcept_ns;\n";
2013
// TODO: incomplete
// Jit-safe replacement for <complex>. complex<T> provides:
//   - real/imaginary accessors and setters,
//   - operator+=,
//   - operator* for complex*complex, complex*T and T*complex.
// The scalar overloads must construct complex<T>; the previous "complexs<T>"
// typo named an undeclared template.
static const char* jitsafe_header_complex =
"#pragma once\n"
"namespace __jitify_complex_ns {\n"
"template<typename T>\n"
"class complex {\n"
" T _real;\n"
" T _imag;\n"
"public:\n"
" complex() : _real(0), _imag(0) {}\n"
" complex(T const& real, T const& imag)\n"
" : _real(real), _imag(imag) {}\n"
" complex(T const& real)\n"
" : _real(real), _imag(static_cast<T>(0)) {}\n"
" T const& real() const { return _real; }\n"
" T& real() { return _real; }\n"
" void real(const T &r) { _real = r; }\n"
" T const& imag() const { return _imag; }\n"
" T& imag() { return _imag; }\n"
" void imag(const T &i) { _imag = i; }\n"
" complex<T>& operator+=(const complex<T> z)\n"
" { _real += z.real(); _imag += z.imag(); return *this; }\n"
"};\n"
"template<typename T>\n"
"complex<T> operator*(const complex<T>& lhs, const complex<T>& rhs)\n"
" { return complex<T>(lhs.real()*rhs.real()-lhs.imag()*rhs.imag(),\n"
" lhs.real()*rhs.imag()+lhs.imag()*rhs.real()); }\n"
"template<typename T>\n"
"complex<T> operator*(const complex<T>& lhs, const T & rhs)\n"
" { return complex<T>(lhs.real()*rhs,lhs.imag()*rhs); }\n"
"template<typename T>\n"
"complex<T> operator*(const T& lhs, const complex<T>& rhs)\n"
" { return complex<T>(rhs.real()*lhs,rhs.imag()*lhs); }\n"
"} // namespace __jitify_complex_ns\n"
"namespace std { using namespace __jitify_complex_ns; }\n"
"using namespace __jitify_complex_ns;\n";
2050
// Jit-safe replacement for <math.h>/<cmath>.
// It wraps the CUDA math functions (which NVRTC already provides in the
// global namespace) into __jitify_math_ns and exposes them through namespace
// std. It also defines M_PI. No using-directive is added to the global
// namespace, because that namespace already has the CUDA math functions.
// DEFINE_MATH_UNARY_FUNC_WRAPPER(f) emits double f(double) and float ff(float),
// plus a float f(float) overload in C++11 and later. The long double variants
// are commented out.
// scalbln forwards to ::scalbln so its long exponent is not narrowed to int.
// TODO: This is incomplete (missing binary and integer funcs, macros,
// constants, types)
static const char* jitsafe_header_math =
"#pragma once\n"
"namespace __jitify_math_ns {\n"
"#if __cplusplus >= 201103L\n"
"#define DEFINE_MATH_UNARY_FUNC_WRAPPER(f) \\\n"
" inline double f(double x) { return ::f(x); } \\\n"
" inline float f##f(float x) { return ::f(x); } \\\n"
" /*inline long double f##l(long double x) { return ::f(x); }*/ \\\n"
" inline float f(float x) { return ::f(x); } \\\n"
" /*inline long double f(long double x) { return ::f(x); }*/\n"
"#else\n"
"#define DEFINE_MATH_UNARY_FUNC_WRAPPER(f) \\\n"
" inline double f(double x) { return ::f(x); } \\\n"
" inline float f##f(float x) { return ::f(x); } \\\n"
" /*inline long double f##l(long double x) { return ::f(x); }*/\n"
"#endif\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(cos)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(sin)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(tan)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(acos)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(asin)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(atan)\n"
"template<typename T> inline T atan2(T y, T x) { return ::atan2(y, x); }\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(cosh)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(sinh)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(tanh)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(exp)\n"
"template<typename T> inline T frexp(T x, int* exp) { return ::frexp(x, "
"exp); }\n"
"template<typename T> inline T ldexp(T x, int exp) { return ::ldexp(x, "
"exp); }\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(log)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(log10)\n"
"template<typename T> inline T modf(T x, T* intpart) { return ::modf(x, "
"intpart); }\n"
"template<typename T> inline T pow(T x, T y) { return ::pow(x, y); }\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(sqrt)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(ceil)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(floor)\n"
"template<typename T> inline T fmod(T n, T d) { return ::fmod(n, d); }\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(fabs)\n"
"template<typename T> inline T abs(T x) { return ::abs(x); }\n"
"#if __cplusplus >= 201103L\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(acosh)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(asinh)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(atanh)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(exp2)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(expm1)\n"
"template<typename T> inline int ilogb(T x) { return ::ilogb(x); }\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(log1p)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(log2)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(logb)\n"
"template<typename T> inline T scalbn (T x, int n) { return ::scalbn(x, "
"n); }\n"
"template<typename T> inline T scalbln(T x, long n) { return ::scalbln(x, "
"n); }\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(cbrt)\n"
"template<typename T> inline T hypot(T x, T y) { return ::hypot(x, y); }\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(erf)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(erfc)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(tgamma)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(lgamma)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(trunc)\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(round)\n"
"template<typename T> inline long lround(T x) { return ::lround(x); }\n"
"template<typename T> inline long long llround(T x) { return ::llround(x); "
"}\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(rint)\n"
"template<typename T> inline long lrint(T x) { return ::lrint(x); }\n"
"template<typename T> inline long long llrint(T x) { return ::llrint(x); "
"}\n"
"DEFINE_MATH_UNARY_FUNC_WRAPPER(nearbyint)\n"
// TODO: remainder, remquo, copysign, nan, nextafter, nexttoward, fdim,
// fmax, fmin, fma
"#endif\n"
"#undef DEFINE_MATH_UNARY_FUNC_WRAPPER\n"
"} // namespace __jitify_math_ns\n"
"namespace std { using namespace __jitify_math_ns; }\n"
"#define M_PI 3.14159265358979323846\n"
// Note: Global namespace already includes CUDA math funcs
"//using namespace __jitify_math_ns;\n";
2134
// Jit-safe replacement for the legacy <memory.h>: it only includes <string.h>.
static const char* jitsafe_header_memory_h = R"(
#pragma once
#include <string.h>
)";
2139
// TODO: incomplete
// Stub <mutex> (C++11 and later only): std::mutex with lock(), try_lock() and
// unlock() declared but not defined here. It is also visible in the global
// namespace.
static const char* jitsafe_header_mutex = R"(
#pragma once
#if __cplusplus >= 201103L
namespace __jitify_mutex_ns {
class mutex {
public:
void lock();
bool try_lock();
void unlock();
};
} // namespace __jitify_mutex_ns
namespace std { using namespace __jitify_mutex_ns; }
using namespace __jitify_mutex_ns;
#endif
)";
2156
// Jit-safe replacement for <algorithm> (C++11 and later only).
// It provides just the two-argument std::max and std::min. Each returns its
// first argument unless the second compares greater (max) or less (min).
// They are constexpr only from C++14 on, via JITIFY_CXX14_CONSTEXPR.
// NOTE(review): JITIFY_CXX14_CONSTEXPR is never #undef'd in this header, so it
// remains defined in code that includes <algorithm>.
static const char* jitsafe_header_algorithm = R"(
#pragma once
#if __cplusplus >= 201103L
namespace __jitify_algorithm_ns {

#if __cplusplus == 201103L
#define JITIFY_CXX14_CONSTEXPR
#else
#define JITIFY_CXX14_CONSTEXPR constexpr
#endif

template<class T> JITIFY_CXX14_CONSTEXPR const T& max(const T& a, const T& b)
{
return (b > a) ? b : a;
}
template<class T> JITIFY_CXX14_CONSTEXPR const T& min(const T& a, const T& b)
{
return (b < a) ? b : a;
}

} // namespace __jitify_algorithm_ns
namespace std { using namespace __jitify_algorithm_ns; }
using namespace __jitify_algorithm_ns;
#endif
)";
2182
// Jit-safe replacement for <time.h>/<ctime>. It defines:
//   - the macros NULL (as 0) and CLOCKS_PER_SEC (as 1000000),
//   - time_t (as long) and struct tm,
//   - struct timespec, in C++17 and later.
// These are visible in the global namespace. They are also visible in
// namespace std, together with NVRTC's built-in ::size_t and ::clock_t.
// No time functions are declared.
static const char* jitsafe_header_time_h = R"(
#pragma once
#define NULL 0
#define CLOCKS_PER_SEC 1000000
namespace __jitify_time_ns {
typedef long time_t;
struct tm {
int tm_sec;
int tm_min;
int tm_hour;
int tm_mday;
int tm_mon;
int tm_year;
int tm_wday;
int tm_yday;
int tm_isdst;
};
#if __cplusplus >= 201703L
struct timespec {
time_t tv_sec;
long tv_nsec;
};
#endif
} // namespace __jitify_time_ns
namespace std {
// NVRTC provides built-in definitions of ::size_t and ::clock_t.
using ::size_t;
using ::clock_t;
using namespace __jitify_time_ns;
}
using namespace __jitify_time_ns;
)";
2215
// WAR: These need to be pre-included as a workaround for NVRTC implicitly using
// /usr/include as an include path.
// NOTE(review): presumably this stops the host system versions from being found
// first; confirm.
// The other built-in headers will be included lazily as needed. Every name
// listed here is also a key in get_jitsafe_headers_map().
static const char* preinclude_jitsafe_header_names[] = {
"jitify_preinclude.h",
"limits.h",
"math.h",
"memory.h",
"stdint.h",
"stdlib.h",
"stdio.h",
"string.h",
"time.h",
};
2230
// Number of elements in a built-in array, deduced at compile time from the
// array's type and returned as an int.
template <typename Elem, int Count>
int array_size(Elem (&/*array*/)[Count]) {
  return Count;
}
2235 const int preinclude_jitsafe_headers_count =
2236 array_size(preinclude_jitsafe_header_names);
2237
get_jitsafe_headers_map()2238 static const std::map<std::string, std::string>& get_jitsafe_headers_map() {
2239 static const std::map<std::string, std::string> jitsafe_headers_map = {
2240 {"jitify_preinclude.h", jitsafe_header_preinclude_h},
2241 {"float.h", jitsafe_header_float_h},
2242 {"cfloat", jitsafe_header_float_h},
2243 {"limits.h", jitsafe_header_limits_h},
2244 {"climits", jitsafe_header_limits_h},
2245 {"stdint.h", jitsafe_header_stdint_h},
2246 {"cstdint", jitsafe_header_stdint_h},
2247 {"stddef.h", jitsafe_header_stddef_h},
2248 {"cstddef", jitsafe_header_stddef_h},
2249 {"stdlib.h", jitsafe_header_stdlib_h},
2250 {"cstdlib", jitsafe_header_stdlib_h},
2251 {"stdio.h", jitsafe_header_stdio_h},
2252 {"cstdio", jitsafe_header_stdio_h},
2253 {"string.h", jitsafe_header_string_h},
2254 {"cstring", jitsafe_header_cstring},
2255 {"iterator", jitsafe_header_iterator},
2256 {"limits", jitsafe_header_limits},
2257 {"type_traits", jitsafe_header_type_traits},
2258 {"utility", jitsafe_header_utility},
2259 {"math.h", jitsafe_header_math},
2260 {"cmath", jitsafe_header_math},
2261 {"memory.h", jitsafe_header_memory_h},
2262 {"complex", jitsafe_header_complex},
2263 {"iostream", jitsafe_header_iostream},
2264 {"ostream", jitsafe_header_ostream},
2265 {"istream", jitsafe_header_istream},
2266 {"sstream", jitsafe_header_sstream},
2267 {"vector", jitsafe_header_vector},
2268 {"string", jitsafe_header_string},
2269 {"stdexcept", jitsafe_header_stdexcept},
2270 {"mutex", jitsafe_header_mutex},
2271 {"algorithm", jitsafe_header_algorithm},
2272 {"time.h", jitsafe_header_time_h},
2273 {"ctime", jitsafe_header_time_h},
2274 };
2275 return jitsafe_headers_map;
2276 }
2277
add_options_from_env(std::vector<std::string> & options)2278 inline void add_options_from_env(std::vector<std::string>& options) {
2279 // Add options from environment variable
2280 const char* env_options = std::getenv("JITIFY_OPTIONS");
2281 if (env_options) {
2282 std::stringstream ss;
2283 ss << env_options;
2284 std::string opt;
2285 while (!(ss >> opt).fail()) {
2286 options.push_back(opt);
2287 }
2288 }
2289 // Add options from JITIFY_OPTIONS macro
2290 #ifdef JITIFY_OPTIONS
2291 #define JITIFY_TOSTRING_IMPL(x) #x
2292 #define JITIFY_TOSTRING(x) JITIFY_TOSTRING_IMPL(x)
2293 std::stringstream ss;
2294 ss << JITIFY_TOSTRING(JITIFY_OPTIONS);
2295 std::string opt;
2296 while (!(ss >> opt).fail()) {
2297 options.push_back(opt);
2298 }
2299 #undef JITIFY_TOSTRING
2300 #undef JITIFY_TOSTRING_IMPL
2301 #endif // JITIFY_OPTIONS
2302 }
2303
detect_and_add_cuda_arch(std::vector<std::string> & options)2304 inline void detect_and_add_cuda_arch(std::vector<std::string>& options) {
2305 for (int i = 0; i < (int)options.size(); ++i) {
2306 // Note that this will also match the middle of "--gpu-architecture".
2307 if (options[i].find("-arch") != std::string::npos) {
2308 // Arch already specified in options
2309 return;
2310 }
2311 }
2312 // Use the compute capability of the current device
2313 // TODO: Check these API calls for errors
2314 cudaError_t status;
2315 int device;
2316 status = cudaGetDevice(&device);
2317 if (status != cudaSuccess) {
2318 throw std::runtime_error(
2319 std::string(
2320 "Failed to detect GPU architecture: cudaGetDevice failed: ") +
2321 cudaGetErrorString(status));
2322 }
2323 int cc_major;
2324 cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device);
2325 int cc_minor;
2326 cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor, device);
2327 int cc = cc_major * 10 + cc_minor;
2328 // Note: We must limit the architecture to the max supported by the current
2329 // version of NVRTC, otherwise newer hardware will cause errors
2330 // on older versions of CUDA.
2331 // TODO: It would be better to detect this somehow, rather than hard-coding it
2332
2333 // Tegra chips do not have forwards compatibility so we need to special case
2334 // them.
2335 bool is_tegra = ((cc_major == 3 && cc_minor == 2) || // Logan
2336 (cc_major == 5 && cc_minor == 3) || // Erista
2337 (cc_major == 6 && cc_minor == 2) || // Parker
2338 (cc_major == 7 && cc_minor == 2)); // Xavier
2339 if (!is_tegra) {
2340 // ensure that future CUDA versions just work (even if suboptimal)
2341 const int cuda_major = std::min(10, CUDA_VERSION / 1000);
2342 // clang-format off
2343 switch (cuda_major) {
2344 case 10: cc = std::min(cc, 75); break; // Turing
2345 case 9: cc = std::min(cc, 70); break; // Volta
2346 case 8: cc = std::min(cc, 61); break; // Pascal
2347 case 7: cc = std::min(cc, 52); break; // Maxwell
2348 default:
2349 throw std::runtime_error("Unexpected CUDA major version " +
2350 std::to_string(cuda_major));
2351 }
2352 // clang-format on
2353 }
2354
2355 std::stringstream ss;
2356 ss << cc;
2357 options.push_back("-arch=compute_" + ss.str());
2358 }
2359
detect_and_add_cxx11_flag(std::vector<std::string> & options)2360 inline void detect_and_add_cxx11_flag(std::vector<std::string>& options) {
2361 // Reverse loop so we can erase on the fly.
2362 for (int i = (int)options.size() - 1; i >= 0; --i) {
2363 if (options[i].find("-std=c++98") != std::string::npos) {
2364 // NVRTC doesn't support specifying c++98 explicitly, so we remove it.
2365 options.erase(options.begin() + i);
2366 return;
2367 } else if (options[i].find("-std") != std::string::npos) {
2368 // Some other standard was explicitly specified, don't change anything.
2369 return;
2370 }
2371 }
2372 // Jitify must be compiled with C++11 support, so we default to enabling it
2373 // for the JIT-compiled code too.
2374 options.push_back("-std=c++11");
2375 }
2376
split_compiler_and_linker_options(std::vector<std::string> options,std::vector<std::string> * compiler_options,std::vector<std::string> * linker_files,std::vector<std::string> * linker_paths)2377 inline void split_compiler_and_linker_options(
2378 std::vector<std::string> options,
2379 std::vector<std::string>* compiler_options,
2380 std::vector<std::string>* linker_files,
2381 std::vector<std::string>* linker_paths) {
2382 for (int i = 0; i < (int)options.size(); ++i) {
2383 std::string opt = options[i];
2384 std::string flag = opt.substr(0, 2);
2385 std::string value = opt.substr(2);
2386 if (flag == "-l") {
2387 linker_files->push_back(value);
2388 } else if (flag == "-L") {
2389 linker_paths->push_back(value);
2390 } else {
2391 compiler_options->push_back(opt);
2392 }
2393 }
2394 }
2395
pop_remove_unused_globals_flag(std::vector<std::string> * options)2396 inline bool pop_remove_unused_globals_flag(std::vector<std::string>* options) {
2397 auto it = std::remove_if(
2398 options->begin(), options->end(), [](const std::string& opt) {
2399 return opt.find("-remove-unused-globals") != std::string::npos;
2400 });
2401 if (it != options->end()) {
2402 options->resize(it - options->begin());
2403 return true;
2404 }
2405 return false;
2406 }
2407
ptx_parse_decl_name(const std::string & line)2408 inline std::string ptx_parse_decl_name(const std::string& line) {
2409 size_t name_end = line.find_first_of("[;");
2410 if (name_end == std::string::npos) {
2411 throw std::runtime_error(
2412 "Failed to parse .global/.const declaration in PTX: expected a "
2413 "semicolon");
2414 }
2415 size_t name_start_minus1 = line.find_last_of(" \t", name_end);
2416 if (name_start_minus1 == std::string::npos) {
2417 throw std::runtime_error(
2418 "Failed to parse .global/.const declaration in PTX: expected "
2419 "whitespace");
2420 }
2421 size_t name_start = name_start_minus1 + 1;
2422 std::string name = line.substr(name_start, name_end - name_start);
2423 return name;
2424 }
2425
ptx_remove_unused_globals(std::string * ptx)2426 inline void ptx_remove_unused_globals(std::string* ptx) {
2427 std::istringstream iss(*ptx);
2428 std::vector<std::string> lines;
2429 std::unordered_map<size_t, std::string> line_num_to_global_name;
2430 std::unordered_set<std::string> name_set;
2431 for (std::string line; std::getline(iss, line);) {
2432 size_t line_num = lines.size();
2433 lines.push_back(line);
2434 auto terms = split_string(line);
2435 if (terms.size() <= 1) continue; // Ignore lines with no arguments
2436 if (terms[0].substr(0, 2) == "//") continue; // Ignore comment lines
2437 if (terms[0].substr(0, 7) == ".global" ||
2438 terms[0].substr(0, 6) == ".const") {
2439 line_num_to_global_name.emplace(line_num, ptx_parse_decl_name(line));
2440 continue;
2441 }
2442 if (terms[0][0] == '.') continue; // Ignore .version, .reg, .param etc.
2443 // Note: The first term will always be an instruction name; starting at 1
2444 // also allows unchecked inspection of the previous term.
2445 for (int i = 1; i < (int)terms.size(); ++i) {
2446 if (terms[i].substr(0, 2) == "//") break; // Ignore comments
2447 // Note: The characters '.' and '%' are not treated as delimiters.
2448 const char* token_delims = " \t()[]{},;+-*/~&|^?:=!<>\"'\\";
2449 for (auto token : split_string(terms[i], -1, token_delims)) {
2450 if ( // Ignore non-names
2451 !(std::isalpha(token[0]) || token[0] == '_' || token[0] == '$') ||
2452 token.find('.') != std::string::npos ||
2453 // Ignore variable/parameter declarations
2454 terms[i - 1][0] == '.' ||
2455 // Ignore branch instructions
2456 (token == "bra" && terms[i - 1][0] == '@') ||
2457 // Ignore branch labels
2458 (token.substr(0, 2) == "BB" &&
2459 terms[i - 1].substr(0, 3) == "bra")) {
2460 continue;
2461 }
2462 name_set.insert(token);
2463 }
2464 }
2465 }
2466 std::ostringstream oss;
2467 for (size_t line_num = 0; line_num < lines.size(); ++line_num) {
2468 auto it = line_num_to_global_name.find(line_num);
2469 if (it != line_num_to_global_name.end()) {
2470 const std::string& name = it->second;
2471 if (!name_set.count(name)) {
2472 continue; // Remove unused .global declaration.
2473 }
2474 }
2475 oss << lines[line_num] << '\n';
2476 }
2477 *ptx = oss.str();
2478 }
2479
compile_kernel(std::string program_name,std::map<std::string,std::string> sources,std::vector<std::string> options,std::string instantiation="",std::string * log=0,std::string * ptx=0,std::string * mangled_instantiation=0)2480 inline nvrtcResult compile_kernel(std::string program_name,
2481 std::map<std::string, std::string> sources,
2482 std::vector<std::string> options,
2483 std::string instantiation = "",
2484 std::string* log = 0, std::string* ptx = 0,
2485 std::string* mangled_instantiation = 0) {
2486 std::string program_source = sources[program_name];
2487 // Build arrays of header names and sources
2488 std::vector<const char*> header_names_c;
2489 std::vector<const char*> header_sources_c;
2490 int num_headers = (int)(sources.size() - 1);
2491 header_names_c.reserve(num_headers);
2492 header_sources_c.reserve(num_headers);
2493 typedef std::map<std::string, std::string> source_map;
2494 for (source_map::const_iterator iter = sources.begin(); iter != sources.end();
2495 ++iter) {
2496 std::string const& name = iter->first;
2497 std::string const& code = iter->second;
2498 if (name == program_name) {
2499 continue;
2500 }
2501 header_names_c.push_back(name.c_str());
2502 header_sources_c.push_back(code.c_str());
2503 }
2504
2505 // TODO: This WAR is expected to be unnecessary as of CUDA > 10.2.
2506 bool should_remove_unused_globals =
2507 detail::pop_remove_unused_globals_flag(&options);
2508
2509 std::vector<const char*> options_c(options.size() + 2);
2510 options_c[0] = "--device-as-default-execution-space";
2511 options_c[1] = "--pre-include=jitify_preinclude.h";
2512 for (int i = 0; i < (int)options.size(); ++i) {
2513 options_c[i + 2] = options[i].c_str();
2514 }
2515
2516 #if CUDA_VERSION < 8000
2517 std::string inst_dummy;
2518 if (!instantiation.empty()) {
2519 // WAR for no nvrtcAddNameExpression before CUDA 8.0
2520 // Force template instantiation by adding dummy reference to kernel
2521 inst_dummy = "__jitify_instantiation";
2522 program_source +=
2523 "\nvoid* " + inst_dummy + " = (void*)" + instantiation + ";\n";
2524 }
2525 #endif
2526
2527 #define CHECK_NVRTC(call) \
2528 do { \
2529 nvrtcResult ret = call; \
2530 if (ret != NVRTC_SUCCESS) { \
2531 return ret; \
2532 } \
2533 } while (0)
2534
2535 nvrtcProgram nvrtc_program;
2536 CHECK_NVRTC(nvrtcCreateProgram(
2537 &nvrtc_program, program_source.c_str(), program_name.c_str(), num_headers,
2538 header_sources_c.data(), header_names_c.data()));
2539
2540 #if CUDA_VERSION >= 8000
2541 if (!instantiation.empty()) {
2542 CHECK_NVRTC(nvrtcAddNameExpression(nvrtc_program, instantiation.c_str()));
2543 }
2544 #endif
2545
2546 nvrtcResult ret = nvrtcCompileProgram(nvrtc_program, (int)options_c.size(),
2547 options_c.data());
2548 if (log) {
2549 size_t logsize;
2550 CHECK_NVRTC(nvrtcGetProgramLogSize(nvrtc_program, &logsize));
2551 std::vector<char> vlog(logsize, 0);
2552 CHECK_NVRTC(nvrtcGetProgramLog(nvrtc_program, vlog.data()));
2553 log->assign(vlog.data(), logsize);
2554 }
2555 if (ret != NVRTC_SUCCESS) {
2556 return ret;
2557 }
2558
2559 if (ptx) {
2560 size_t ptxsize;
2561 CHECK_NVRTC(nvrtcGetPTXSize(nvrtc_program, &ptxsize));
2562 std::vector<char> vptx(ptxsize);
2563 CHECK_NVRTC(nvrtcGetPTX(nvrtc_program, vptx.data()));
2564 ptx->assign(vptx.data(), ptxsize);
2565 if (should_remove_unused_globals) {
2566 detail::ptx_remove_unused_globals(ptx);
2567 }
2568 }
2569
2570 if (!instantiation.empty() && mangled_instantiation) {
2571 #if CUDA_VERSION >= 8000
2572 const char* mangled_instantiation_cstr;
2573 // Note: The returned string pointer becomes invalid after
2574 // nvrtcDestroyProgram has been called, so we save it.
2575 CHECK_NVRTC(nvrtcGetLoweredName(nvrtc_program, instantiation.c_str(),
2576 &mangled_instantiation_cstr));
2577 *mangled_instantiation = mangled_instantiation_cstr;
2578 #else
2579 // Extract mangled kernel template instantiation from PTX
2580 inst_dummy += " = "; // Note: This must match how the PTX is generated
2581 int mi_beg = ptx->find(inst_dummy) + inst_dummy.size();
2582 int mi_end = ptx->find(";", mi_beg);
2583 *mangled_instantiation = ptx->substr(mi_beg, mi_end - mi_beg);
2584 #endif
2585 }
2586
2587 CHECK_NVRTC(nvrtcDestroyProgram(&nvrtc_program));
2588 #undef CHECK_NVRTC
2589 return NVRTC_SUCCESS;
2590 }
2591
// Loads a CUDA program plus the headers it needs, discovering missing headers
// by compiling with NVRTC and parsing the resulting error log.
//
// Steps:
//  1. Every option in *program_options that starts with "-I" is removed from
//     *program_options, and the text after "-I" is appended to
//     *include_paths.
//  2. The main source is loaded into *program_sources via detail::load_source,
//     and *program_name is set to the first key of *program_sources.
//     NOTE(review): this assumes *program_sources is empty on entry; otherwise
//     begin() (map key order) may not be the program that was just loaded.
//  3. Each entry of `headers` is loaded the same way. The resolved full path
//     of each loaded header is recorded in a local include-name -> full-path
//     map.
//  4. The program is compiled (with no kernel instantiation) in a loop. While
//     NVRTC reports NVRTC_ERROR_COMPILATION caused by a missing #include:
//     - the missing header is loaded, searching from the including file's
//       base path and *include_paths;
//     - if it cannot be found, the #include line is commented out of the
//       parent source instead.
//     The loop ends on any other NVRTC result.
//
// The arch and C++11 flags added by detect_and_add_cuda_arch and
// detect_and_add_cxx11_flag go into a local list of compiler options. They
// are not written back to *program_options.
//
// Parameters:
//   cuda_source     - Main program (as accepted by detail::load_source).
//   headers         - Extra headers to load up front.
//   file_callback   - Passed through to detail::load_source.
//   include_paths   - [in,out] Header search paths; receives the "-I" paths.
//   program_sources - [in,out] Map from source/header name to its text.
//   program_options - [in,out] Compiler options; "-I" options are removed.
//   program_name    - [out] Name of the main program source.
//
// Throws std::runtime_error in these cases:
//   - the program or a listed header cannot be loaded;
//   - compilation fails for a reason other than a missing #include;
//   - NVRTC returns an error other than NVRTC_ERROR_COMPILATION.
// Throws std::out_of_range if a missing header's including file is not among
// the loaded sources.
inline void load_program(std::string const& cuda_source,
                         std::vector<std::string> const& headers,
                         file_callback_type file_callback,
                         std::vector<std::string>* include_paths,
                         std::map<std::string, std::string>* program_sources,
                         std::vector<std::string>* program_options,
                         std::string* program_name) {
  // Extract include paths from compile options
  // (erase() returns the next valid iterator, so only advance when keeping).
  std::vector<std::string>::iterator iter = program_options->begin();
  while (iter != program_options->end()) {
    std::string const& opt = *iter;
    if (opt.substr(0, 2) == "-I") {
      include_paths->push_back(opt.substr(2));
      iter = program_options->erase(iter);
    } else {
      ++iter;
    }
  }

  // Load program source
  if (!detail::load_source(cuda_source, *program_sources, "", *include_paths,
                           file_callback)) {
    throw std::runtime_error("Source not found: " + cuda_source);
  }
  *program_name = program_sources->begin()->first;

  // Maps header include names to their full file paths.
  std::map<std::string, std::string> header_fullpaths;

  // Load header sources
  for (std::string const& header : headers) {
    if (!detail::load_source(header, *program_sources, "", *include_paths,
                             file_callback, &header_fullpaths)) {
      // **TODO: Deal with source not found
      throw std::runtime_error("Source not found: " + header);
    }
  }

#if JITIFY_PRINT_SOURCE
  std::string& program_source = (*program_sources)[*program_name];
  std::cout << "---------------------------------------" << std::endl;
  std::cout << "--- Source of " << *program_name << " ---" << std::endl;
  std::cout << "---------------------------------------" << std::endl;
  detail::print_with_line_numbers(program_source);
  std::cout << "---------------------------------------" << std::endl;
#endif

  // Linker files/paths are not needed for header discovery; only the
  // compiler options are used below.
  std::vector<std::string> compiler_options, linker_files, linker_paths;
  detail::split_compiler_and_linker_options(*program_options, &compiler_options,
                                            &linker_files, &linker_paths);

  // If no arch is specified at this point we use whatever the current
  // context is. This ensures we pick up the correct internal headers
  // for arch-dependent compilation, e.g., some intrinsics are only
  // present for specific architectures.
  detail::detect_and_add_cuda_arch(compiler_options);
  detail::detect_and_add_cxx11_flag(compiler_options);

  // Iteratively try to compile the sources, and use the resulting errors to
  // identify missing headers.
  // Each iteration either loads one header or comments out one #include line
  // before recompiling.
  std::string log;
  nvrtcResult ret;
  while ((ret = detail::compile_kernel(*program_name, *program_sources,
                                       compiler_options, "", &log)) ==
         NVRTC_ERROR_COMPILATION) {
    std::string include_name;
    std::string include_parent;
    int line_num = 0;
    if (!detail::extract_include_info_from_compile_error(
            log, include_name, include_parent, line_num)) {
#if JITIFY_PRINT_LOG
      detail::print_compile_log(*program_name, log);
#endif
      // There was a non include-related compilation error
      // TODO: How to handle error?
      throw std::runtime_error("Runtime compilation failed");
    }

    // Record whether the parent used the quoted #include form; this flag is
    // forwarded to load_source below.
    bool is_included_with_quotes = false;
    if (program_sources->count(include_parent)) {
      const std::string& parent_source = (*program_sources)[include_parent];
      is_included_with_quotes =
          is_include_directive_with_quotes(parent_source, line_num);
    }

    // Try to load the new header
    // Note: This fullpath lookup is needed because the compiler error
    // messages have the include name of the header instead of its full path.
    // (operator[] yields an empty path if the parent's full path is unknown.)
    std::string include_parent_fullpath = header_fullpaths[include_parent];
    std::string include_path = detail::path_base(include_parent_fullpath);
    if (detail::load_source(include_name, *program_sources, include_path,
                            *include_paths, file_callback, &header_fullpaths,
                            is_included_with_quotes)) {
#if JITIFY_PRINT_HEADER_PATHS
      std::cout << "Found #include " << include_name << " from "
                << include_parent << ":" << line_num << " ["
                << include_parent_fullpath << "]"
                << " at:\n  " << header_fullpaths[include_name] << std::endl;
#endif
    } else {  // Failed to find header file.
      // Comment-out the include line and print a warning
      if (!program_sources->count(include_parent)) {
        // ***TODO: Unless there's another mechanism (e.g., potentially
        //            the parent path vs. filename problem), getting
        //            here means include_parent was found automatically
        //            in a system include path.
        //            We need a WAR to zap it from *its parent*.

        // Diagnostic: list every loaded source name on stdout (regardless of
        // any JITIFY_PRINT_* setting) before throwing.
        typedef std::map<std::string, std::string> source_map;
        for (source_map::const_iterator it = program_sources->begin();
             it != program_sources->end(); ++it) {
          std::cout << "  " << it->first << std::endl;
        }
        throw std::out_of_range(include_parent +
                                " not in loaded sources!"
                                " This may be due to a header being loaded by"
                                " NVRTC without Jitify's knowledge.");
      }
      std::string& parent_source = (*program_sources)[include_parent];
      parent_source = detail::comment_out_code_line(line_num, parent_source);
#if JITIFY_PRINT_LOG
      std::cout << include_parent << "(" << line_num
                << "): warning: " << include_name << ": [jitify] File not found"
                << std::endl;
#endif
    }
  }
  // Any result other than success here is a non-compilation NVRTC failure.
  if (ret != NVRTC_SUCCESS) {
#if JITIFY_PRINT_LOG
    if (ret == NVRTC_ERROR_INVALID_OPTION) {
      std::cout << "Compiler options: ";
      for (int i = 0; i < (int)compiler_options.size(); ++i) {
        std::cout << compiler_options[i] << " ";
      }
      std::cout << std::endl;
    }
#endif
    throw std::runtime_error(std::string("NVRTC error: ") +
                             nvrtcGetErrorString(ret));
  }
}
2733
instantiate_kernel(std::string const & program_name,std::map<std::string,std::string> const & program_sources,std::string const & instantiation,std::vector<std::string> const & options,std::string * log,std::string * ptx,std::string * mangled_instantiation,std::vector<std::string> * linker_files,std::vector<std::string> * linker_paths)2734 inline void instantiate_kernel(
2735 std::string const& program_name,
2736 std::map<std::string, std::string> const& program_sources,
2737 std::string const& instantiation, std::vector<std::string> const& options,
2738 std::string* log, std::string* ptx, std::string* mangled_instantiation,
2739 std::vector<std::string>* linker_files,
2740 std::vector<std::string>* linker_paths) {
2741 std::vector<std::string> compiler_options;
2742 detail::split_compiler_and_linker_options(options, &compiler_options,
2743 linker_files, linker_paths);
2744
2745 nvrtcResult ret =
2746 detail::compile_kernel(program_name, program_sources, compiler_options,
2747 instantiation, log, ptx, mangled_instantiation);
2748 #if JITIFY_PRINT_LOG
2749 if (log->size() > 1) {
2750 detail::print_compile_log(program_name, *log);
2751 }
2752 #endif
2753 if (ret != NVRTC_SUCCESS) {
2754 throw std::runtime_error(std::string("NVRTC error: ") +
2755 nvrtcGetErrorString(ret));
2756 }
2757
2758 #if JITIFY_PRINT_PTX
2759 std::cout << "---------------------------------------" << std::endl;
2760 std::cout << *mangled_instantiation << std::endl;
2761 std::cout << "---------------------------------------" << std::endl;
2762 std::cout << "--- PTX for " << mangled_instantiation << " in " << program_name
2763 << " ---" << std::endl;
2764 std::cout << "---------------------------------------" << std::endl;
2765 std::cout << *ptx << std::endl;
2766 std::cout << "---------------------------------------" << std::endl;
2767 #endif
2768 }
2769
get_1d_max_occupancy(CUfunction func,CUoccupancyB2DSize smem_callback,unsigned int * smem,int max_block_size,unsigned int flags,int * grid,int * block)2770 inline void get_1d_max_occupancy(CUfunction func,
2771 CUoccupancyB2DSize smem_callback,
2772 unsigned int* smem, int max_block_size,
2773 unsigned int flags, int* grid, int* block) {
2774 if (!func) {
2775 throw std::runtime_error(
2776 "Kernel pointer is NULL; you may need to define JITIFY_THREAD_SAFE "
2777 "1");
2778 }
2779 CUresult res = cuOccupancyMaxPotentialBlockSizeWithFlags(
2780 grid, block, func, smem_callback, *smem, max_block_size, flags);
2781 if (res != CUDA_SUCCESS) {
2782 const char* msg;
2783 cuGetErrorName(res, &msg);
2784 throw std::runtime_error(msg);
2785 }
2786 if (smem_callback) {
2787 *smem = (unsigned int)smem_callback(*block);
2788 }
2789 }
2790
2791 } // namespace detail
2792
2793 //! \endcond
2794
// Forward declarations of the public API classes defined below. They refer to
// one another before their definitions, e.g. KernelLauncher's constructor
// takes a KernelInstantiation const&.
class KernelInstantiation;
class Kernel;
class Program;
class JitCache;
2799
// Plain aggregate describing a loaded program:
//   options       - compiler options;
//   include_paths - header search paths;
//   name          - name of the main source;
//   sources       - source text of every loaded file, keyed by name.
struct ProgramConfig {
  using source_map = std::map<std::string, std::string>;

  std::vector<std::string> options;
  std::vector<std::string> include_paths;
  std::string name;
  source_map sources;
};
2807
2808 class JitCache_impl {
2809 friend class Program_impl;
2810 friend class KernelInstantiation_impl;
2811 friend class KernelLauncher_impl;
2812 typedef uint64_t key_type;
2813 jitify::ObjectCache<key_type, detail::CUDAKernel> _kernel_cache;
2814 jitify::ObjectCache<key_type, ProgramConfig> _program_config_cache;
2815 std::vector<std::string> _options;
2816 #if JITIFY_THREAD_SAFE
2817 std::mutex _kernel_cache_mutex;
2818 std::mutex _program_cache_mutex;
2819 #endif
2820 public:
JitCache_impl(size_t cache_size)2821 inline JitCache_impl(size_t cache_size)
2822 : _kernel_cache(cache_size), _program_config_cache(cache_size) {
2823 detail::add_options_from_env(_options);
2824
2825 // Bootstrap the cuda context to avoid errors
2826 cudaFree(0);
2827 }
2828 };
2829
2830 class Program_impl {
2831 // A friendly class
2832 friend class Kernel_impl;
2833 friend class KernelLauncher_impl;
2834 friend class KernelInstantiation_impl;
2835 // TODO: This can become invalid if JitCache is destroyed before the
2836 // Program object is. However, this can't happen if JitCache
2837 // instances are static.
2838 JitCache_impl& _cache;
2839 uint64_t _hash;
2840 ProgramConfig* _config;
2841 void load_sources(std::string source, std::vector<std::string> headers,
2842 std::vector<std::string> options,
2843 file_callback_type file_callback);
2844
2845 public:
2846 inline Program_impl(JitCache_impl& cache, std::string source,
2847 jitify::detail::vector<std::string> headers = 0,
2848 jitify::detail::vector<std::string> options = 0,
2849 file_callback_type file_callback = 0);
2850 inline Program_impl(Program_impl const&) = default;
2851 inline Program_impl(Program_impl&&) = default;
options() const2852 inline std::vector<std::string> const& options() const {
2853 return _config->options;
2854 }
name() const2855 inline std::string const& name() const { return _config->name; }
sources() const2856 inline ProgramConfig::source_map const& sources() const {
2857 return _config->sources;
2858 }
include_paths() const2859 inline std::vector<std::string> const& include_paths() const {
2860 return _config->include_paths;
2861 }
2862 };
2863
// Internal representation of a Kernel. It holds:
//   - the owning program (a copy of its Program_impl);
//   - the kernel's unmangled name;
//   - the kernel's options, which KernelInstantiation_impl copies for each
//     instantiation;
//   - a hash that instantiations combine with their template arguments to
//     form the kernel-cache key.
class Kernel_impl {
  friend class KernelLauncher_impl;
  friend class KernelInstantiation_impl;
  Program_impl _program;
  std::string _name;
  std::vector<std::string> _options;
  uint64_t _hash;

 public:
  inline Kernel_impl(Program_impl const& program, std::string name,
                     jitify::detail::vector<std::string> options = 0);
  inline Kernel_impl(Kernel_impl const&) = default;
  inline Kernel_impl(Kernel_impl&&) = default;
};
2878
2879 class KernelInstantiation_impl {
2880 friend class KernelLauncher_impl;
2881 Kernel_impl _kernel;
2882 uint64_t _hash;
2883 std::string _template_inst;
2884 std::vector<std::string> _options;
2885 detail::CUDAKernel* _cuda_kernel;
2886 inline void print() const;
2887 void build_kernel();
2888
2889 public:
2890 inline KernelInstantiation_impl(
2891 Kernel_impl const& kernel, std::vector<std::string> const& template_args);
2892 inline KernelInstantiation_impl(KernelInstantiation_impl const&) = default;
2893 inline KernelInstantiation_impl(KernelInstantiation_impl&&) = default;
cuda_kernel() const2894 detail::CUDAKernel const& cuda_kernel() const { return *_cuda_kernel; }
2895 };
2896
2897 class KernelLauncher_impl {
2898 KernelInstantiation_impl _kernel_inst;
2899 dim3 _grid;
2900 dim3 _block;
2901 unsigned int _smem;
2902 cudaStream_t _stream;
2903
2904 public:
KernelLauncher_impl(KernelInstantiation_impl const & kernel_inst,dim3 grid,dim3 block,unsigned int smem=0,cudaStream_t stream=0)2905 inline KernelLauncher_impl(KernelInstantiation_impl const& kernel_inst,
2906 dim3 grid, dim3 block, unsigned int smem = 0,
2907 cudaStream_t stream = 0)
2908 : _kernel_inst(kernel_inst),
2909 _grid(grid),
2910 _block(block),
2911 _smem(smem),
2912 _stream(stream) {}
2913 inline KernelLauncher_impl(KernelLauncher_impl const&) = default;
2914 inline KernelLauncher_impl(KernelLauncher_impl&&) = default;
2915 inline CUresult launch(
2916 jitify::detail::vector<void*> arg_ptrs,
2917 jitify::detail::vector<std::string> arg_types = 0) const;
2918 };
2919
/*! An object representing a configured and instantiated kernel ready
 *  for launching.
 */
class KernelLauncher {
  // Immutable implementation holding the instantiation and the launch
  // configuration.
  std::unique_ptr<KernelLauncher_impl const> _impl;

 public:
  inline KernelLauncher(KernelInstantiation const& kernel_inst, dim3 grid,
                        dim3 block, unsigned int smem = 0,
                        cudaStream_t stream = 0);

  // Note: It's important that there is no implicit conversion required
  //         for arg_ptrs, because otherwise the parameter pack version
  //         below gets called instead (probably resulting in a segfault).
  /*! Launch the kernel.
   *
   *  \param arg_ptrs  A vector of pointers to each function argument for the
   *   kernel.
   *  \param arg_types A vector of function argument types represented
   *   as code-strings. This parameter is optional and is only used to print
   *   out the function signature.
   *  \return The CUresult produced by the underlying launch.
   *  \throws std::runtime_error if the underlying kernel pointer is NULL.
   */
  inline CUresult launch(
      std::vector<void*> arg_ptrs = std::vector<void*>(),
      jitify::detail::vector<std::string> arg_types = 0) const {
    return _impl->launch(arg_ptrs, arg_types);
  }
  // Regular function call syntax
  /*! Launch the kernel.
   *
   *  \see launch
   */
  template <typename... ArgTypes>
  inline CUresult operator()(ArgTypes... args) const {
    return this->launch(args...);
  }
  /*! Launch the kernel.
   *
   *  \param args Function arguments for the kernel.
   */
  template <typename... ArgTypes>
  inline CUresult launch(ArgTypes... args) const {
    // Pass the address of each by-value parameter copy (valid for the
    // duration of this call) and the reflected type names for diagnostics.
    return this->launch(std::vector<void*>({(void*)&args...}),
                        {reflection::reflect<ArgTypes>()...});
  }
};
2966
2967 /*! An object representing a kernel instantiation made up of a Kernel and
2968 * template arguments.
2969 */
2970 class KernelInstantiation {
2971 friend class KernelLauncher;
2972 std::unique_ptr<KernelInstantiation_impl const> _impl;
2973
2974 public:
2975 inline KernelInstantiation(Kernel const& kernel,
2976 std::vector<std::string> const& template_args);
2977
2978 /*! Implicit conversion to the underlying CUfunction object.
2979 *
2980 * \note This allows use of CUDA APIs like
2981 * cuOccupancyMaxActiveBlocksPerMultiprocessor.
2982 */
operator CUfunction() const2983 inline operator CUfunction() const { return _impl->cuda_kernel(); }
2984
2985 /*! Configure the kernel launch.
2986 *
2987 * \see configure
2988 */
operator ()(dim3 grid,dim3 block,unsigned int smem=0,cudaStream_t stream=0) const2989 inline KernelLauncher operator()(dim3 grid, dim3 block, unsigned int smem = 0,
2990 cudaStream_t stream = 0) const {
2991 return this->configure(grid, block, smem, stream);
2992 }
2993 /*! Configure the kernel launch.
2994 *
2995 * \param grid The thread grid dimensions for the launch.
2996 * \param block The thread block dimensions for the launch.
2997 * \param smem The amount of shared memory to dynamically allocate, in
2998 * bytes.
2999 * \param stream The CUDA stream to launch the kernel in.
3000 */
configure(dim3 grid,dim3 block,unsigned int smem=0,cudaStream_t stream=0) const3001 inline KernelLauncher configure(dim3 grid, dim3 block, unsigned int smem = 0,
3002 cudaStream_t stream = 0) const {
3003 return KernelLauncher(*this, grid, block, smem, stream);
3004 }
3005 /*! Configure the kernel launch with a 1-dimensional block and grid chosen
3006 * automatically to maximise occupancy.
3007 *
3008 * \param max_block_size The upper limit on the block size, or 0 for no
3009 * limit.
3010 * \param smem The amount of shared memory to dynamically allocate, in bytes.
3011 * \param smem_callback A function returning smem for a given block size (overrides \p smem).
3012 * \param stream The CUDA stream to launch the kernel in.
3013 * \param flags The flags to pass to cuOccupancyMaxPotentialBlockSizeWithFlags.
3014 */
configure_1d_max_occupancy(int max_block_size=0,unsigned int smem=0,CUoccupancyB2DSize smem_callback=0,cudaStream_t stream=0,unsigned int flags=0) const3015 inline KernelLauncher configure_1d_max_occupancy(
3016 int max_block_size = 0, unsigned int smem = 0,
3017 CUoccupancyB2DSize smem_callback = 0, cudaStream_t stream = 0,
3018 unsigned int flags = 0) const {
3019 int grid;
3020 int block;
3021 CUfunction func = _impl->cuda_kernel();
3022 detail::get_1d_max_occupancy(func, smem_callback, &smem, max_block_size,
3023 flags, &grid, &block);
3024 return this->configure(grid, block, smem, stream);
3025 }
3026
3027 /*
3028 * \deprecated Use \p get_global_ptr instead.
3029 */
get_constant_ptr(const char * name,size_t * size=nullptr) const3030 inline CUdeviceptr get_constant_ptr(const char* name,
3031 size_t* size = nullptr) const {
3032 return get_global_ptr(name, size);
3033 }
3034
3035 /*
3036 * Get a device pointer to a global __constant__ or __device__ variable using
3037 * its un-mangled name. If provided, *size is set to the size of the variable
3038 * in bytes.
3039 */
get_global_ptr(const char * name,size_t * size=nullptr) const3040 inline CUdeviceptr get_global_ptr(const char* name,
3041 size_t* size = nullptr) const {
3042 return _impl->cuda_kernel().get_global_ptr(name, size);
3043 }
3044
3045 /*
3046 * Copy data from a global __constant__ or __device__ array to the host using
3047 * its un-mangled name.
3048 */
3049 template <typename T>
get_global_array(const char * name,T * data,size_t count,CUstream stream=0) const3050 inline CUresult get_global_array(const char* name, T* data, size_t count,
3051 CUstream stream = 0) const {
3052 return _impl->cuda_kernel().get_global_data(name, data, count, stream);
3053 }
3054
3055 /*
3056 * Copy a value from a global __constant__ or __device__ variable to the host
3057 * using its un-mangled name.
3058 */
3059 template <typename T>
get_global_value(const char * name,T * value,CUstream stream=0) const3060 inline CUresult get_global_value(const char* name, T* value,
3061 CUstream stream = 0) const {
3062 return get_global_array(name, value, 1, stream);
3063 }
3064
3065 /*
3066 * Copy data from the host to a global __constant__ or __device__ array using
3067 * its un-mangled name.
3068 */
3069 template <typename T>
set_global_array(const char * name,const T * data,size_t count,CUstream stream=0) const3070 inline CUresult set_global_array(const char* name, const T* data,
3071 size_t count, CUstream stream = 0) const {
3072 return _impl->cuda_kernel().set_global_data(name, data, count, stream);
3073 }
3074
3075 /*
3076 * Copy a value from the host to a global __constant__ or __device__ variable
3077 * using its un-mangled name.
3078 */
3079 template <typename T>
set_global_value(const char * name,const T & value,CUstream stream=0) const3080 inline CUresult set_global_value(const char* name, const T& value,
3081 CUstream stream = 0) const {
3082 return set_global_array(name, &value, 1, stream);
3083 }
3084
mangled_name() const3085 const std::string& mangled_name() const {
3086 return _impl->cuda_kernel().function_name();
3087 }
3088
ptx() const3089 const std::string& ptx() const { return _impl->cuda_kernel().ptx(); }
3090
link_files() const3091 const std::vector<std::string>& link_files() const {
3092 return _impl->cuda_kernel().link_files();
3093 }
3094
link_paths() const3095 const std::vector<std::string>& link_paths() const {
3096 return _impl->cuda_kernel().link_paths();
3097 }
3098 };
3099
3100 /*! An object representing a kernel made up of a Program, a name and options.
3101 */
3102 class Kernel {
3103 friend class KernelInstantiation;
3104 std::unique_ptr<Kernel_impl const> _impl;
3105
3106 public:
3107 Kernel(Program const& program, std::string name,
3108 jitify::detail::vector<std::string> options = 0);
3109
3110 /*! Instantiate the kernel.
3111 *
3112 * \param template_args A vector of template arguments represented as
3113 * code-strings. These can be generated using
3114 * \code{.cpp}jitify::reflection::reflect<type>()\endcode or
3115 * \code{.cpp}jitify::reflection::reflect(value)\endcode
3116 *
3117 * \note Template type deduction is not possible, so all types must be
3118 * explicitly specified.
3119 */
3120 // inline KernelInstantiation instantiate(std::vector<std::string> const&
3121 // template_args) const {
instantiate(std::vector<std::string> const & template_args=std::vector<std::string> ()) const3122 inline KernelInstantiation instantiate(
3123 std::vector<std::string> const& template_args =
3124 std::vector<std::string>()) const {
3125 return KernelInstantiation(*this, template_args);
3126 }
3127
3128 // Regular template instantiation syntax (note limited flexibility)
3129 /*! Instantiate the kernel.
3130 *
3131 * \note The template arguments specified on this function are
3132 * used to instantiate the kernel. Non-type template arguments must
3133 * be wrapped with
3134 * \code{.cpp}jitify::reflection::NonType<type,value>\endcode
3135 *
3136 * \note Template type deduction is not possible, so all types must be
3137 * explicitly specified.
3138 */
3139 template <typename... TemplateArgs>
instantiate() const3140 inline KernelInstantiation instantiate() const {
3141 return this->instantiate(
3142 std::vector<std::string>({reflection::reflect<TemplateArgs>()...}));
3143 }
3144 // Template-like instantiation syntax
3145 // E.g., instantiate(myvar,Type<MyType>())(grid,block)
3146 /*! Instantiate the kernel.
3147 *
3148 * \param targs The template arguments for the kernel, represented as
3149 * values. Types must be wrapped with
3150 * \code{.cpp}jitify::reflection::Type<type>()\endcode or
3151 * \code{.cpp}jitify::reflection::type_of(value)\endcode
3152 *
3153 * \note Template type deduction is not possible, so all types must be
3154 * explicitly specified.
3155 */
3156 template <typename... TemplateArgs>
instantiate(TemplateArgs...targs) const3157 inline KernelInstantiation instantiate(TemplateArgs... targs) const {
3158 return this->instantiate(
3159 std::vector<std::string>({reflection::reflect(targs)...}));
3160 }
3161 };
3162
3163 /*! An object representing a program made up of source code, headers
3164 * and options.
3165 */
3166 class Program {
3167 friend class Kernel;
3168 std::unique_ptr<Program_impl const> _impl;
3169
3170 public:
3171 Program(JitCache& cache, std::string source,
3172 jitify::detail::vector<std::string> headers = 0,
3173 jitify::detail::vector<std::string> options = 0,
3174 file_callback_type file_callback = 0);
3175
3176 /*! Select a kernel.
3177 *
3178 * \param name The name of the kernel (unmangled and without
3179 * template arguments).
3180 * \param options A vector of options to be passed to the NVRTC
3181 * compiler when compiling this kernel.
3182 */
kernel(std::string name,jitify::detail::vector<std::string> options=0) const3183 inline Kernel kernel(std::string name,
3184 jitify::detail::vector<std::string> options = 0) const {
3185 return Kernel(*this, name, options);
3186 }
3187 /*! Select a kernel.
3188 *
3189 * \see kernel
3190 */
operator ()(std::string name,jitify::detail::vector<std::string> options=0) const3191 inline Kernel operator()(
3192 std::string name, jitify::detail::vector<std::string> options = 0) const {
3193 return this->kernel(name, options);
3194 }
3195 };
3196
3197 /*! An object that manages a cache of JIT-compiled CUDA kernels.
3198 *
3199 */
3200 class JitCache {
3201 friend class Program;
3202 std::unique_ptr<JitCache_impl> _impl;
3203
3204 public:
3205 /*! JitCache constructor.
3206 * \param cache_size The number of kernels to hold in the cache
3207 * before overwriting the least-recently-used ones.
3208 */
3209 enum { DEFAULT_CACHE_SIZE = 128 };
JitCache(size_t cache_size=DEFAULT_CACHE_SIZE)3210 JitCache(size_t cache_size = DEFAULT_CACHE_SIZE)
3211 : _impl(new JitCache_impl(cache_size)) {}
3212
3213 /*! Create a program.
3214 *
3215 * \param source A string containing either the source filename or
3216 * the source itself; in the latter case, the first line must be
3217 * the name of the program.
3218 * \param headers A vector of strings representing the source of
3219 * each header file required by the program. Each entry can be
3220 * either the header filename or the header source itself; in
3221 * the latter case, the first line must be the name of the header
3222 * (i.e., the name by which the header is #included).
3223 * \param options A vector of options to be passed to the
3224 * NVRTC compiler. Include paths specified with \p -I
3225 * are added to the search paths used by Jitify. The environment
3226 * variable JITIFY_OPTIONS can also be used to define additional
3227 * options.
3228 * \param file_callback A pointer to a callback function that is
3229 * invoked whenever a source file needs to be loaded. Inside this
3230 * function, the user can either load/specify the source themselves
3231 * or defer to Jitify's file-loading mechanisms.
3232 * \note Program or header source files referenced by filename are
3233 * looked-up using the following mechanisms (in this order):
3234 * \note 1) By calling file_callback.
3235 * \note 2) By looking for the file embedded in the executable via the GCC
3236 * linker.
3237 * \note 3) By looking for the file in the filesystem.
3238 *
3239 * \note Jitify recursively scans all source files for \p #include
3240 * directives and automatically adds them to the set of headers needed
3241 * by the program.
3242 * If a \p #include directive references a header that cannot be found,
3243 * the directive is automatically removed from the source code to prevent
3244 * immediate compilation failure. This may result in compilation errors
3245 * if the header was required by the program.
3246 *
3247 * \note Jitify automatically includes NVRTC-safe versions of some
3248 * standard library headers.
3249 */
program(std::string source,jitify::detail::vector<std::string> headers=0,jitify::detail::vector<std::string> options=0,file_callback_type file_callback=0)3250 inline Program program(std::string source,
3251 jitify::detail::vector<std::string> headers = 0,
3252 jitify::detail::vector<std::string> options = 0,
3253 file_callback_type file_callback = 0) {
3254 return Program(*this, source, headers, options, file_callback);
3255 }
3256 };
3257
Program(JitCache & cache,std::string source,jitify::detail::vector<std::string> headers,jitify::detail::vector<std::string> options,file_callback_type file_callback)3258 inline Program::Program(JitCache& cache, std::string source,
3259 jitify::detail::vector<std::string> headers,
3260 jitify::detail::vector<std::string> options,
3261 file_callback_type file_callback)
3262 : _impl(new Program_impl(*cache._impl, source, headers, options,
3263 file_callback)) {}
3264
Kernel(Program const & program,std::string name,jitify::detail::vector<std::string> options)3265 inline Kernel::Kernel(Program const& program, std::string name,
3266 jitify::detail::vector<std::string> options)
3267 : _impl(new Kernel_impl(*program._impl, name, options)) {}
3268
KernelInstantiation(Kernel const & kernel,std::vector<std::string> const & template_args)3269 inline KernelInstantiation::KernelInstantiation(
3270 Kernel const& kernel, std::vector<std::string> const& template_args)
3271 : _impl(new KernelInstantiation_impl(*kernel._impl, template_args)) {}
3272
KernelLauncher(KernelInstantiation const & kernel_inst,dim3 grid,dim3 block,unsigned int smem,cudaStream_t stream)3273 inline KernelLauncher::KernelLauncher(KernelInstantiation const& kernel_inst,
3274 dim3 grid, dim3 block, unsigned int smem,
3275 cudaStream_t stream)
3276 : _impl(new KernelLauncher_impl(*kernel_inst._impl, grid, block, smem,
3277 stream)) {}
3278
operator <<(std::ostream & stream,dim3 d)3279 inline std::ostream& operator<<(std::ostream& stream, dim3 d) {
3280 if (d.y == 1 && d.z == 1) {
3281 stream << d.x;
3282 } else {
3283 stream << "(" << d.x << "," << d.y << "," << d.z << ")";
3284 }
3285 return stream;
3286 }
3287
launch(jitify::detail::vector<void * > arg_ptrs,jitify::detail::vector<std::string> arg_types) const3288 inline CUresult KernelLauncher_impl::launch(
3289 jitify::detail::vector<void*> arg_ptrs,
3290 jitify::detail::vector<std::string> arg_types) const {
3291 #if JITIFY_PRINT_LAUNCH
3292 Kernel_impl const& kernel = _kernel_inst._kernel;
3293 std::string arg_types_string =
3294 (arg_types.empty() ? "..." : reflection::reflect_list(arg_types));
3295 std::cout << "Launching " << kernel._name << _kernel_inst._template_inst
3296 << "<<<" << _grid << "," << _block << "," << _smem << "," << _stream
3297 << ">>>"
3298 << "(" << arg_types_string << ")" << std::endl;
3299 #endif
3300 if (!_kernel_inst._cuda_kernel) {
3301 throw std::runtime_error(
3302 "Kernel pointer is NULL; you may need to define JITIFY_THREAD_SAFE 1");
3303 }
3304 return _kernel_inst._cuda_kernel->launch(_grid, _block, _smem, _stream,
3305 arg_ptrs);
3306 }
3307
KernelInstantiation_impl(Kernel_impl const & kernel,std::vector<std::string> const & template_args)3308 inline KernelInstantiation_impl::KernelInstantiation_impl(
3309 Kernel_impl const& kernel, std::vector<std::string> const& template_args)
3310 : _kernel(kernel), _options(kernel._options) {
3311 _template_inst =
3312 (template_args.empty() ? ""
3313 : reflection::reflect_template(template_args));
3314 using detail::hash_combine;
3315 using detail::hash_larson64;
3316 _hash = _kernel._hash;
3317 _hash = hash_combine(_hash, hash_larson64(_template_inst.c_str()));
3318 JitCache_impl& cache = _kernel._program._cache;
3319 uint64_t cache_key = _hash;
3320 #if JITIFY_THREAD_SAFE
3321 std::lock_guard<std::mutex> lock(cache._kernel_cache_mutex);
3322 #endif
3323 if (cache._kernel_cache.contains(cache_key)) {
3324 #if JITIFY_PRINT_INSTANTIATION
3325 std::cout << "Found ";
3326 this->print();
3327 #endif
3328 _cuda_kernel = &cache._kernel_cache.get(cache_key);
3329 } else {
3330 #if JITIFY_PRINT_INSTANTIATION
3331 std::cout << "Building ";
3332 this->print();
3333 #endif
3334 _cuda_kernel = &cache._kernel_cache.emplace(cache_key);
3335 this->build_kernel();
3336 }
3337 }
3338
print() const3339 inline void KernelInstantiation_impl::print() const {
3340 std::string options_string = reflection::reflect_list(_options);
3341 std::cout << _kernel._name << _template_inst << " [" << options_string << "]"
3342 << std::endl;
3343 }
3344
build_kernel()3345 inline void KernelInstantiation_impl::build_kernel() {
3346 Program_impl const& program = _kernel._program;
3347
3348 std::string instantiation = _kernel._name + _template_inst;
3349
3350 std::string log, ptx, mangled_instantiation;
3351 std::vector<std::string> linker_files, linker_paths;
3352 detail::instantiate_kernel(program.name(), program.sources(), instantiation,
3353 _options, &log, &ptx, &mangled_instantiation,
3354 &linker_files, &linker_paths);
3355
3356 _cuda_kernel->set(mangled_instantiation.c_str(), ptx.c_str(), linker_files,
3357 linker_paths);
3358 }
3359
Kernel_impl(Program_impl const & program,std::string name,jitify::detail::vector<std::string> options)3360 Kernel_impl::Kernel_impl(Program_impl const& program, std::string name,
3361 jitify::detail::vector<std::string> options)
3362 : _program(program), _name(name), _options(options) {
3363 // Merge options from parent
3364 _options.insert(_options.end(), _program.options().begin(),
3365 _program.options().end());
3366 detail::detect_and_add_cuda_arch(_options);
3367 detail::detect_and_add_cxx11_flag(_options);
3368 std::string options_string = reflection::reflect_list(_options);
3369 using detail::hash_combine;
3370 using detail::hash_larson64;
3371 _hash = _program._hash;
3372 _hash = hash_combine(_hash, hash_larson64(_name.c_str()));
3373 _hash = hash_combine(_hash, hash_larson64(options_string.c_str()));
3374 }
3375
Program_impl(JitCache_impl & cache,std::string source,jitify::detail::vector<std::string> headers,jitify::detail::vector<std::string> options,file_callback_type file_callback)3376 Program_impl::Program_impl(JitCache_impl& cache, std::string source,
3377 jitify::detail::vector<std::string> headers,
3378 jitify::detail::vector<std::string> options,
3379 file_callback_type file_callback)
3380 : _cache(cache) {
3381 // Compute hash of source, headers and options
3382 std::string options_string = reflection::reflect_list(options);
3383 using detail::hash_combine;
3384 using detail::hash_larson64;
3385 _hash = hash_combine(hash_larson64(source.c_str()),
3386 hash_larson64(options_string.c_str()));
3387 for (size_t i = 0; i < headers.size(); ++i) {
3388 _hash = hash_combine(_hash, hash_larson64(headers[i].c_str()));
3389 }
3390 _hash = hash_combine(_hash, (uint64_t)file_callback);
3391 // Add pre-include built-in JIT-safe headers
3392 for (int i = 0; i < detail::preinclude_jitsafe_headers_count; ++i) {
3393 const char* hdr_name = detail::preinclude_jitsafe_header_names[i];
3394 const std::string& hdr_source =
3395 detail::get_jitsafe_headers_map().at(hdr_name);
3396 headers.push_back(std::string(hdr_name) + "\n" + hdr_source);
3397 }
3398 // Merge options from parent
3399 options.insert(options.end(), _cache._options.begin(), _cache._options.end());
3400 // Load sources
3401 #if JITIFY_THREAD_SAFE
3402 std::lock_guard<std::mutex> lock(cache._program_cache_mutex);
3403 #endif
3404 if (!cache._program_config_cache.contains(_hash)) {
3405 _config = &cache._program_config_cache.insert(_hash);
3406 this->load_sources(source, headers, options, file_callback);
3407 } else {
3408 _config = &cache._program_config_cache.get(_hash);
3409 }
3410 }
3411
load_sources(std::string source,std::vector<std::string> headers,std::vector<std::string> options,file_callback_type file_callback)3412 inline void Program_impl::load_sources(std::string source,
3413 std::vector<std::string> headers,
3414 std::vector<std::string> options,
3415 file_callback_type file_callback) {
3416 _config->options = options;
3417 detail::load_program(source, headers, file_callback, &_config->include_paths,
3418 &_config->sources, &_config->options, &_config->name);
3419 }
3420
3421 enum Location { HOST, DEVICE };
3422
3423 /*! Specifies location and parameters for execution of an algorithm.
3424 * \param stream The CUDA stream on which to execute.
3425 * \param headers A vector of headers to include in the code.
3426 * \param options Options to pass to the NVRTC compiler.
3427 * \param file_callback See jitify::Program.
3428 * \param block_size The size of the CUDA thread block with which to
3429 * execute.
3430 * \param cache_size The number of kernels to store in the cache
3431 * before overwriting the least-recently-used ones.
3432 */
3433 struct ExecutionPolicy {
3434 /*! Location (HOST or DEVICE) on which to execute.*/
3435 Location location;
3436 /*! List of headers to include when compiling the algorithm.*/
3437 std::vector<std::string> headers;
3438 /*! List of compiler options.*/
3439 std::vector<std::string> options;
3440 /*! Optional callback for loading source files.*/
3441 file_callback_type file_callback;
3442 /*! CUDA stream on which to execute.*/
3443 cudaStream_t stream;
3444 /*! CUDA device on which to execute.*/
3445 int device;
3446 /*! CUDA block size with which to execute.*/
3447 int block_size;
3448 /*! The number of instantiations to store in the cache before overwriting
3449 * the least-recently-used ones.*/
3450 size_t cache_size;
ExecutionPolicyjitify::ExecutionPolicy3451 ExecutionPolicy(Location location_ = DEVICE,
3452 jitify::detail::vector<std::string> headers_ = 0,
3453 jitify::detail::vector<std::string> options_ = 0,
3454 file_callback_type file_callback_ = 0,
3455 cudaStream_t stream_ = 0, int device_ = 0,
3456 int block_size_ = 256,
3457 size_t cache_size_ = JitCache::DEFAULT_CACHE_SIZE)
3458 : location(location_),
3459 headers(headers_),
3460 options(options_),
3461 file_callback(file_callback_),
3462 stream(stream_),
3463 device(device_),
3464 block_size(block_size_),
3465 cache_size(cache_size_) {}
3466 };
3467
3468 template <class Func>
3469 class Lambda;
3470
3471 /*! An object that captures a set of variables for use in a parallel_for
3472 * expression. See JITIFY_CAPTURE().
3473 */
3474 class Capture {
3475 public:
3476 std::vector<std::string> _arg_decls;
3477 std::vector<void*> _arg_ptrs;
3478
3479 public:
3480 template <typename... Args>
Capture(std::vector<std::string> arg_names,Args const &...args)3481 inline Capture(std::vector<std::string> arg_names, Args const&... args)
3482 : _arg_ptrs{(void*)&args...} {
3483 std::vector<std::string> arg_types = {reflection::reflect<Args>()...};
3484 _arg_decls.resize(arg_names.size());
3485 for (int i = 0; i < (int)arg_names.size(); ++i) {
3486 _arg_decls[i] = arg_types[i] + " " + arg_names[i];
3487 }
3488 }
3489 };
3490
3491 /*! An object that captures the instantiated Lambda function for use
3492 in a parallel_for expression and the function string for NVRTC
3493 compilation
3494 */
3495 template <class Func>
3496 class Lambda {
3497 public:
3498 Capture _capture;
3499 std::string _func_string;
3500 Func _func;
3501
3502 public:
Lambda(Capture const & capture,std::string func_string,Func func)3503 inline Lambda(Capture const& capture, std::string func_string, Func func)
3504 : _capture(capture), _func_string(func_string), _func(func) {}
3505 };
3506
3507 template <typename T>
make_Lambda(Capture const & capture,std::string func,T lambda)3508 inline Lambda<T> make_Lambda(Capture const& capture, std::string func,
3509 T lambda) {
3510 return Lambda<T>(capture, func, lambda);
3511 }
3512
// Re-emits its arguments unchanged. It forces an extra expansion pass and,
// applied to a parenthesised list, removes the surrounding parentheses.
#define JITIFY_ARGS(...) __VA_ARGS__

// Macro sequence to strip surrounding brackets:
// JITIFY_PASS_PARAMETERS((a, b)) expands to a, b.
#define JITIFY_STRIP_PARENS(X) X
#define JITIFY_PASS_PARAMETERS(X) JITIFY_STRIP_PARENS(JITIFY_ARGS X)

// Builds a jitify::Capture from a comma-separated list of variables. The
// stringified list supplies the names; the variables themselves supply the
// types and addresses.
#define JITIFY_CAPTURE(...)                                            \
  jitify::Capture(jitify::detail::split_string(#__VA_ARGS__, -1, ","), \
                  __VA_ARGS__)

// Pairs a Capture with the body's source text and a host lambda, built from
// the same body, that captures x by value and takes the index as `int i`.
#define JITIFY_MAKE_LAMBDA(capture, x, ...)               \
  jitify::make_Lambda(capture, std::string(#__VA_ARGS__), \
                      [x](int i) { __VA_ARGS__; })

#define JITIFY_LAMBDA_(x, ...) \
  JITIFY_MAKE_LAMBDA(JITIFY_CAPTURE(x), JITIFY_ARGS(x), __VA_ARGS__)

/*! Creates a Lambda object with captured variables and a function
 * definition.
 * \param capture A bracket-enclosed list of variables to capture.
 * \param ... The function definition (the loop index is available as i).
 *
 * \code{.cpp}
 * float* capture_me;
 * int capture_me_too;
 * auto my_lambda = JITIFY_LAMBDA( (capture_me, capture_me_too),
 *                                 capture_me[i] = i*capture_me_too );
 * \endcode
 */
#define JITIFY_LAMBDA(capture, ...)                            \
  JITIFY_LAMBDA_(JITIFY_ARGS(JITIFY_PASS_PARAMETERS(capture)), \
                 JITIFY_ARGS(__VA_ARGS__))
3545
3546 // TODO: Try to implement for_each that accepts iterators instead of indices
3547 // Add compile guard for NOCUDA compilation
3548 /*! Call a function for a range of indices
3549 *
3550 * \param policy Determines the location and device parameters for
3551 * execution of the parallel_for.
3552 * \param begin The starting index.
3553 * \param end The ending index.
3554 * \param lambda A Lambda object created using the JITIFY_LAMBDA() macro.
3555 *
3556 * \code{.cpp}
3557 * char const* in;
3558 * float* out;
3559 * parallel_for(0, 100, JITIFY_LAMBDA( (in, out), {char x = in[i]; out[i] =
3560 * x*x; } ); \endcode
3561 */
3562 template <typename IndexType, class Func>
parallel_for(ExecutionPolicy policy,IndexType begin,IndexType end,Lambda<Func> const & lambda)3563 CUresult parallel_for(ExecutionPolicy policy, IndexType begin, IndexType end,
3564 Lambda<Func> const& lambda) {
3565 using namespace jitify;
3566
3567 if (policy.location == HOST) {
3568 #ifdef _OPENMP
3569 #pragma omp parallel for
3570 #endif
3571 for (IndexType i = begin; i < end; i++) {
3572 lambda._func(i);
3573 }
3574 return CUDA_SUCCESS; // FIXME - replace with non-CUDA enum type?
3575 }
3576
3577 thread_local static JitCache kernel_cache(policy.cache_size);
3578
3579 std::vector<std::string> arg_decls;
3580 arg_decls.push_back("I begin, I end");
3581 arg_decls.insert(arg_decls.end(), lambda._capture._arg_decls.begin(),
3582 lambda._capture._arg_decls.end());
3583
3584 std::stringstream source_ss;
3585 source_ss << "parallel_for_program\n";
3586 for (auto const& header : policy.headers) {
3587 std::string header_name = header.substr(0, header.find("\n"));
3588 source_ss << "#include <" << header_name << ">\n";
3589 }
3590 source_ss << "template<typename I>\n"
3591 "__global__\n"
3592 "void parallel_for_kernel("
3593 << reflection::reflect_list(arg_decls)
3594 << ") {\n"
3595 " I i0 = threadIdx.x + blockDim.x*blockIdx.x;\n"
3596 " for( I i=i0+begin; i<end; i+=blockDim.x*gridDim.x ) {\n"
3597 " "
3598 << "\t" << lambda._func_string << ";\n"
3599 << " }\n"
3600 "}\n";
3601
3602 Program program = kernel_cache.program(source_ss.str(), policy.headers,
3603 policy.options, policy.file_callback);
3604
3605 std::vector<void*> arg_ptrs;
3606 arg_ptrs.push_back(&begin);
3607 arg_ptrs.push_back(&end);
3608 arg_ptrs.insert(arg_ptrs.end(), lambda._capture._arg_ptrs.begin(),
3609 lambda._capture._arg_ptrs.end());
3610
3611 size_t n = end - begin;
3612 dim3 block(policy.block_size);
3613 dim3 grid((unsigned int)std::min((n - 1) / block.x + 1, size_t(65535)));
3614 cudaSetDevice(policy.device);
3615 return program.kernel("parallel_for_kernel")
3616 .instantiate<IndexType>()
3617 .configure(grid, block, 0, policy.stream)
3618 .launch(arg_ptrs);
3619 }
3620
3621 namespace experimental {
3622
3623 using jitify::file_callback_type;
3624
namespace serialization {

namespace detail {

// This should be incremented whenever the serialization format changes in any
// incompatible way.
static constexpr const size_t kSerializationVersion = 1;

// Maximum number of bytes allocated in one step while reading a
// length-prefixed string. A corrupt or truncated length prefix then makes the
// read fail (so deserialize() returns false) instead of first attempting a
// huge allocation that throws std::bad_alloc / std::length_error.
static constexpr const size_t kMaxDeserializeChunkBytes = size_t(1) << 20;

// Writes u as a 64-bit unsigned integer in the host's native byte order (raw
// bytes of a uint64_t), so the data is only portable between hosts with the
// same endianness.
inline void serialize(std::ostream& stream, size_t u) {
  uint64_t u64 = u;
  stream.write(reinterpret_cast<char*>(&u64), sizeof(u64));
}

// Reads a 64-bit value written by serialize(stream, size_t). Returns false if
// the stream could not supply all 8 bytes.
inline bool deserialize(std::istream& stream, size_t* size) {
  uint64_t u64 = 0;  // Initialized so a failed read never yields garbage.
  stream.read(reinterpret_cast<char*>(&u64), sizeof(u64));
  *size = static_cast<size_t>(u64);
  return stream.good();
}

// Writes a length-prefixed string.
inline void serialize(std::ostream& stream, std::string const& s) {
  serialize(stream, s.size());
  stream.write(s.data(), s.size());
}

// Reads a length-prefixed string into *s, replacing its contents. The string
// grows in bounded chunks, so a bogus length fails the read instead of
// triggering an enormous allocation.
inline bool deserialize(std::istream& stream, std::string* s) {
  size_t size;
  if (!deserialize(stream, &size)) return false;
  s->clear();
  while (s->size() < size) {
    size_t const offset = s->size();
    size_t const chunk = std::min(size - offset, kMaxDeserializeChunkBytes);
    s->resize(offset + chunk);
    stream.read(&(*s)[offset], static_cast<std::streamsize>(chunk));
    if (!stream.good()) return false;
  }
  return stream.good();
}

// Writes a count followed by each string.
inline void serialize(std::ostream& stream, std::vector<std::string> const& v) {
  serialize(stream, v.size());
  for (auto const& s : v) {
    serialize(stream, s);
  }
}

// Reads a count followed by that many strings into *v, replacing its
// contents. Elements are appended one at a time, so a bogus count fails on
// the first missing element rather than pre-allocating.
inline bool deserialize(std::istream& stream, std::vector<std::string>* v) {
  size_t size;
  if (!deserialize(stream, &size)) return false;
  v->clear();
  for (size_t i = 0; i < size; ++i) {
    std::string s;
    if (!deserialize(stream, &s)) return false;
    v->push_back(std::move(s));
  }
  return true;
}

// Writes a count followed by each key/value pair (in map order).
inline void serialize(std::ostream& stream,
                      std::map<std::string, std::string> const& m) {
  serialize(stream, m.size());
  for (auto const& kv : m) {
    serialize(stream, kv.first);
    serialize(stream, kv.second);
  }
}

// Reads a count followed by that many key/value pairs into *m. Existing
// contents are discarded first, matching the vector overload, so the result
// equals what was serialized.
inline bool deserialize(std::istream& stream,
                        std::map<std::string, std::string>* m) {
  size_t size;
  if (!deserialize(stream, &size)) return false;
  m->clear();
  for (size_t i = 0; i < size; ++i) {
    std::string key;
    if (!deserialize(stream, &key)) return false;
    if (!deserialize(stream, &(*m)[key])) return false;
  }
  return true;
}

// Serializes each value in order. Trailing values are taken by const
// reference to avoid copying potentially large strings/containers.
template <typename T, typename... Rest>
inline void serialize(std::ostream& stream, T const& value,
                      Rest const&... rest) {
  serialize(stream, value);
  serialize(stream, rest...);
}

// Deserializes into each pointed-to value in order, stopping at the first
// failure.
template <typename T, typename... Rest>
inline bool deserialize(std::istream& stream, T* value, Rest... rest) {
  if (!deserialize(stream, value)) return false;
  return deserialize(stream, rest...);
}

// Writes the "JTFY" magic bytes followed by kSerializationVersion.
inline void serialize_magic_number(std::ostream& stream) {
  stream.write("JTFY", 4);
  serialize(stream, kSerializationVersion);
}

// Returns true only if the stream starts with "JTFY" followed by a version
// equal to kSerializationVersion.
inline bool deserialize_magic_number(std::istream& stream) {
  char magic_number[4] = {0, 0, 0, 0};
  stream.read(&magic_number[0], 4);
  if (!(magic_number[0] == 'J' && magic_number[1] == 'T' &&
        magic_number[2] == 'F' && magic_number[3] == 'Y')) {
    return false;
  }
  size_t serialization_version;
  if (!deserialize(stream, &serialization_version)) return false;
  return serialization_version == kSerializationVersion;
}

}  // namespace detail

/*! Serializes \p values (size_t, std::string, std::vector<std::string> or
 *  std::map<std::string, std::string>) into a binary string prefixed with a
 *  magic number and format version.
 */
template <typename... Values>
inline std::string serialize(Values const&... values) {
  std::ostringstream ss(std::stringstream::out | std::stringstream::binary);
  detail::serialize_magic_number(ss);
  detail::serialize(ss, values...);
  return ss.str();
}

/*! Restores values written by serialize() into the pointed-to objects, in
 *  the same order. Returns false on a wrong magic number/version or on
 *  truncated or corrupt data.
 */
template <typename... Values>
inline bool deserialize(std::string const& serialized, Values*... values) {
  std::istringstream ss(serialized,
                        std::stringstream::in | std::stringstream::binary);
  if (!detail::deserialize_magic_number(ss)) return false;
  return detail::deserialize(ss, values...);
}

}  // namespace serialization
3746
3747 class Program;
3748 class Kernel;
3749 class KernelInstantiation;
3750 class KernelLauncher;
3751
3752 /*! An object representing a program made up of source code, headers
3753 * and options.
3754 */
3755 class Program {
3756 private:
3757 friend class KernelInstantiation;
3758 std::string _name;
3759 std::vector<std::string> _options;
3760 std::map<std::string, std::string> _sources;
3761
3762 // Private constructor used by deserialize()
Program()3763 Program() {}
3764
3765 public:
3766 /*! Create a program.
3767 *
3768 * \param source A string containing either the source filename or
3769 * the source itself; in the latter case, the first line must be
3770 * the name of the program.
3771 * \param headers A vector of strings representing the source of
3772 * each header file required by the program. Each entry can be
3773 * either the header filename or the header source itself; in
3774 * the latter case, the first line must be the name of the header
3775 * (i.e., the name by which the header is #included).
3776 * \param options A vector of options to be passed to the
3777 * NVRTC compiler. Include paths specified with \p -I
3778 * are added to the search paths used by Jitify. The environment
3779 * variable JITIFY_OPTIONS can also be used to define additional
3780 * options.
3781 * \param file_callback A pointer to a callback function that is
3782 * invoked whenever a source file needs to be loaded. Inside this
3783 * function, the user can either load/specify the source themselves
3784 * or defer to Jitify's file-loading mechanisms.
3785 * \note Program or header source files referenced by filename are
3786 * looked-up using the following mechanisms (in this order):
3787 * \note 1) By calling file_callback.
3788 * \note 2) By looking for the file embedded in the executable via the GCC
3789 * linker.
3790 * \note 3) By looking for the file in the filesystem.
3791 *
3792 * \note Jitify recursively scans all source files for \p #include
3793 * directives and automatically adds them to the set of headers needed
3794 * by the program.
3795 * If a \p #include directive references a header that cannot be found,
3796 * the directive is automatically removed from the source code to prevent
3797 * immediate compilation failure. This may result in compilation errors
3798 * if the header was required by the program.
3799 *
3800 * \note Jitify automatically includes NVRTC-safe versions of some
3801 * standard library headers.
3802 */
Program(std::string const & cuda_source,std::vector<std::string> const & given_headers={},std::vector<std::string> const & given_options={},file_callback_type file_callback=nullptr)3803 Program(std::string const& cuda_source,
3804 std::vector<std::string> const& given_headers = {},
3805 std::vector<std::string> const& given_options = {},
3806 file_callback_type file_callback = nullptr) {
3807 // Add pre-include built-in JIT-safe headers
3808 std::vector<std::string> headers = given_headers;
3809 for (int i = 0; i < detail::preinclude_jitsafe_headers_count; ++i) {
3810 const char* hdr_name = detail::preinclude_jitsafe_header_names[i];
3811 const std::string& hdr_source =
3812 detail::get_jitsafe_headers_map().at(hdr_name);
3813 headers.push_back(std::string(hdr_name) + "\n" + hdr_source);
3814 }
3815
3816 _options = given_options;
3817 detail::add_options_from_env(_options);
3818 std::vector<std::string> include_paths;
3819 detail::load_program(cuda_source, headers, file_callback, &include_paths,
3820 &_sources, &_options, &_name);
3821 }
3822
3823 /*! Restore a serialized program.
3824 *
3825 * \param serialized_program The serialized program to restore.
3826 *
3827 * \see serialize
3828 */
deserialize(std::string const & serialized_program)3829 static Program deserialize(std::string const& serialized_program) {
3830 Program program;
3831 if (!serialization::deserialize(serialized_program, &program._name,
3832 &program._options, &program._sources)) {
3833 throw std::runtime_error("Failed to deserialize program");
3834 }
3835 return program;
3836 }
3837
3838 /*! Save the program.
3839 *
3840 * \see deserialize
3841 */
serialize() const3842 std::string serialize() const {
3843 // Note: Must update kSerializationVersion if this is changed.
3844 return serialization::serialize(_name, _options, _sources);
3845 };
3846
3847 /*! Select a kernel.
3848 *
3849 * \param name The name of the kernel (unmangled and without
3850 * template arguments).
3851 * \param options A vector of options to be passed to the NVRTC
3852 * compiler when compiling this kernel.
3853 */
3854 Kernel kernel(std::string const& name,
3855 std::vector<std::string> const& options = {}) const;
3856 };
3857
3858 class Kernel {
3859 friend class KernelInstantiation;
3860 Program const* _program;
3861 std::string _name;
3862 std::vector<std::string> _options;
3863
3864 public:
Kernel(Program const * program,std::string const & name,std::vector<std::string> const & options={})3865 Kernel(Program const* program, std::string const& name,
3866 std::vector<std::string> const& options = {})
3867 : _program(program), _name(name), _options(options) {}
3868
3869 /*! Instantiate the kernel.
3870 *
3871 * \param template_args A vector of template arguments represented as
3872 * code-strings. These can be generated using
3873 * \code{.cpp}jitify::reflection::reflect<type>()\endcode or
3874 * \code{.cpp}jitify::reflection::reflect(value)\endcode
3875 *
3876 * \note Template type deduction is not possible, so all types must be
3877 * explicitly specified.
3878 */
3879 KernelInstantiation instantiate(
3880 std::vector<std::string> const& template_args =
3881 std::vector<std::string>()) const;
3882
3883 // Regular template instantiation syntax (note limited flexibility)
3884 /*! Instantiate the kernel.
3885 *
3886 * \note The template arguments specified on this function are
3887 * used to instantiate the kernel. Non-type template arguments must
3888 * be wrapped with
3889 * \code{.cpp}jitify::reflection::NonType<type,value>\endcode
3890 *
3891 * \note Template type deduction is not possible, so all types must be
3892 * explicitly specified.
3893 */
3894 template <typename... TemplateArgs>
3895 KernelInstantiation instantiate() const;
3896
3897 // Template-like instantiation syntax
3898 // E.g., instantiate(myvar,Type<MyType>())(grid,block)
3899 /*! Instantiate the kernel.
3900 *
3901 * \param targs The template arguments for the kernel, represented as
3902 * values. Types must be wrapped with
3903 * \code{.cpp}jitify::reflection::Type<type>()\endcode or
3904 * \code{.cpp}jitify::reflection::type_of(value)\endcode
3905 *
3906 * \note Template type deduction is not possible, so all types must be
3907 * explicitly specified.
3908 */
3909 template <typename... TemplateArgs>
3910 KernelInstantiation instantiate(TemplateArgs... targs) const;
3911 };
3912
3913 class KernelInstantiation {
3914 friend class KernelLauncher;
3915 std::unique_ptr<detail::CUDAKernel> _cuda_kernel;
3916
3917 // Private constructor used by deserialize()
KernelInstantiation(std::string const & func_name,std::string const & ptx,std::vector<std::string> const & link_files,std::vector<std::string> const & link_paths)3918 KernelInstantiation(std::string const& func_name, std::string const& ptx,
3919 std::vector<std::string> const& link_files,
3920 std::vector<std::string> const& link_paths)
3921 : _cuda_kernel(new detail::CUDAKernel(func_name.c_str(), ptx.c_str(),
3922 link_files, link_paths)) {}
3923
3924 public:
KernelInstantiation(Kernel const & kernel,std::vector<std::string> const & template_args)3925 KernelInstantiation(Kernel const& kernel,
3926 std::vector<std::string> const& template_args) {
3927 Program const* program = kernel._program;
3928
3929 std::string template_inst =
3930 (template_args.empty() ? ""
3931 : reflection::reflect_template(template_args));
3932 std::string instantiation = kernel._name + template_inst;
3933
3934 std::vector<std::string> options;
3935 options.insert(options.begin(), program->_options.begin(),
3936 program->_options.end());
3937 options.insert(options.begin(), kernel._options.begin(),
3938 kernel._options.end());
3939 detail::detect_and_add_cuda_arch(options);
3940 detail::detect_and_add_cxx11_flag(options);
3941
3942 std::string log, ptx, mangled_instantiation;
3943 std::vector<std::string> linker_files, linker_paths;
3944 detail::instantiate_kernel(program->_name, program->_sources, instantiation,
3945 options, &log, &ptx, &mangled_instantiation,
3946 &linker_files, &linker_paths);
3947
3948 _cuda_kernel.reset(new detail::CUDAKernel(mangled_instantiation.c_str(),
3949 ptx.c_str(), linker_files,
3950 linker_paths));
3951 }
3952
3953 /*! Implicit conversion to the underlying CUfunction object.
3954 *
3955 * \note This allows use of CUDA APIs like
3956 * cuOccupancyMaxActiveBlocksPerMultiprocessor.
3957 */
operator CUfunction() const3958 operator CUfunction() const { return *_cuda_kernel; }
3959
3960 /*! Restore a serialized kernel instantiation.
3961 *
3962 * \param serialized_kernel_inst The serialized kernel instantiation to
3963 * restore.
3964 *
3965 * \see serialize
3966 */
deserialize(std::string const & serialized_kernel_inst)3967 static KernelInstantiation deserialize(
3968 std::string const& serialized_kernel_inst) {
3969 std::string func_name, ptx;
3970 std::vector<std::string> link_files, link_paths;
3971 if (!serialization::deserialize(serialized_kernel_inst, &func_name, &ptx,
3972 &link_files, &link_paths)) {
3973 throw std::runtime_error("Failed to deserialize kernel instantiation");
3974 }
3975 return KernelInstantiation(func_name, ptx, link_files, link_paths);
3976 }
3977
3978 /*! Save the program.
3979 *
3980 * \see deserialize
3981 */
serialize() const3982 std::string serialize() const {
3983 // Note: Must update kSerializationVersion if this is changed.
3984 return serialization::serialize(
3985 _cuda_kernel->function_name(), _cuda_kernel->ptx(),
3986 _cuda_kernel->link_files(), _cuda_kernel->link_paths());
3987 }
3988
3989 /*! Configure the kernel launch.
3990 *
3991 * \param grid The thread grid dimensions for the launch.
3992 * \param block The thread block dimensions for the launch.
3993 * \param smem The amount of shared memory to dynamically allocate, in
3994 * bytes.
3995 * \param stream The CUDA stream to launch the kernel in.
3996 */
3997 KernelLauncher configure(dim3 grid, dim3 block, unsigned int smem = 0,
3998 cudaStream_t stream = 0) const;
3999
4000 /*! Configure the kernel launch with a 1-dimensional block and grid chosen
4001 * automatically to maximise occupancy.
4002 *
4003 * \param max_block_size The upper limit on the block size, or 0 for no
4004 * limit.
4005 * \param smem The amount of shared memory to dynamically allocate, in bytes.
4006 * \param smem_callback A function returning smem for a given block size
4007 * (overrides \p smem).
4008 * \param stream The CUDA stream to launch the kernel in.
4009 * \param flags The flags to pass to
4010 * cuOccupancyMaxPotentialBlockSizeWithFlags.
4011 */
4012 KernelLauncher configure_1d_max_occupancy(
4013 int max_block_size = 0, unsigned int smem = 0,
4014 CUoccupancyB2DSize smem_callback = 0, cudaStream_t stream = 0,
4015 unsigned int flags = 0) const;
4016
4017 /*
4018 * \deprecated Use \p get_global_ptr instead.
4019 */
get_constant_ptr(const char * name,size_t * size=nullptr) const4020 CUdeviceptr get_constant_ptr(const char* name, size_t* size = nullptr) const {
4021 return get_global_ptr(name, size);
4022 }
4023
4024 /*
4025 * Get a device pointer to a global __constant__ or __device__ variable using
4026 * its un-mangled name. If provided, *size is set to the size of the variable
4027 * in bytes.
4028 */
get_global_ptr(const char * name,size_t * size=nullptr) const4029 CUdeviceptr get_global_ptr(const char* name, size_t* size = nullptr) const {
4030 return _cuda_kernel->get_global_ptr(name, size);
4031 }
4032
4033 /*
4034 * Copy data from a global __constant__ or __device__ array to the host using
4035 * its un-mangled name.
4036 */
4037 template <typename T>
get_global_array(const char * name,T * data,size_t count,CUstream stream=0) const4038 CUresult get_global_array(const char* name, T* data, size_t count,
4039 CUstream stream = 0) const {
4040 return _cuda_kernel->get_global_data(name, data, count, stream);
4041 }
4042
4043 /*
4044 * Copy a value from a global __constant__ or __device__ variable to the host
4045 * using its un-mangled name.
4046 */
4047 template <typename T>
get_global_value(const char * name,T * value,CUstream stream=0) const4048 CUresult get_global_value(const char* name, T* value,
4049 CUstream stream = 0) const {
4050 return get_global_array(name, value, 1, stream);
4051 }
4052
4053 /*
4054 * Copy data from the host to a global __constant__ or __device__ array using
4055 * its un-mangled name.
4056 */
4057 template <typename T>
set_global_array(const char * name,const T * data,size_t count,CUstream stream=0) const4058 CUresult set_global_array(const char* name, const T* data, size_t count,
4059 CUstream stream = 0) const {
4060 return _cuda_kernel->set_global_data(name, data, count, stream);
4061 }
4062
4063 /*
4064 * Copy a value from the host to a global __constant__ or __device__ variable
4065 * using its un-mangled name.
4066 */
4067 template <typename T>
set_global_value(const char * name,const T & value,CUstream stream=0) const4068 CUresult set_global_value(const char* name, const T& value,
4069 CUstream stream = 0) const {
4070 return set_global_array(name, &value, 1, stream);
4071 }
4072
mangled_name() const4073 const std::string& mangled_name() const {
4074 return _cuda_kernel->function_name();
4075 }
4076
ptx() const4077 const std::string& ptx() const { return _cuda_kernel->ptx(); }
4078
link_files() const4079 const std::vector<std::string>& link_files() const {
4080 return _cuda_kernel->link_files();
4081 }
4082
link_paths() const4083 const std::vector<std::string>& link_paths() const {
4084 return _cuda_kernel->link_paths();
4085 }
4086 };
4087
4088 class KernelLauncher {
4089 KernelInstantiation const* _kernel_inst;
4090 dim3 _grid;
4091 dim3 _block;
4092 unsigned int _smem;
4093 cudaStream_t _stream;
4094
4095 public:
KernelLauncher(KernelInstantiation const * kernel_inst,dim3 grid,dim3 block,unsigned int smem=0,cudaStream_t stream=0)4096 KernelLauncher(KernelInstantiation const* kernel_inst, dim3 grid, dim3 block,
4097 unsigned int smem = 0, cudaStream_t stream = 0)
4098 : _kernel_inst(kernel_inst),
4099 _grid(grid),
4100 _block(block),
4101 _smem(smem),
4102 _stream(stream) {}
4103
4104 // Note: It's important that there is no implicit conversion required
4105 // for arg_ptrs, because otherwise the parameter pack version
4106 // below gets called instead (probably resulting in a segfault).
4107 /*! Launch the kernel.
4108 *
4109 * \param arg_ptrs A vector of pointers to each function argument for the
4110 * kernel.
4111 * \param arg_types A vector of function argument types represented
4112 * as code-strings. This parameter is optional and is only used to print
4113 * out the function signature.
4114 */
launch(std::vector<void * > arg_ptrs={},std::vector<std::string> arg_types={}) const4115 CUresult launch(std::vector<void*> arg_ptrs = {},
4116 std::vector<std::string> arg_types = {}) const {
4117 #if JITIFY_PRINT_LAUNCH
4118 std::string arg_types_string =
4119 (arg_types.empty() ? "..." : reflection::reflect_list(arg_types));
4120 std::cout << "Launching " << _kernel_inst->_cuda_kernel->function_name()
4121 << "<<<" << _grid << "," << _block << "," << _smem << ","
4122 << _stream << ">>>"
4123 << "(" << arg_types_string << ")" << std::endl;
4124 #endif
4125 return _kernel_inst->_cuda_kernel->launch(_grid, _block, _smem, _stream,
4126 arg_ptrs);
4127 }
4128
4129 /*! Launch the kernel.
4130 *
4131 * \param args Function arguments for the kernel.
4132 */
4133 template <typename... ArgTypes>
launch(ArgTypes...args) const4134 CUresult launch(ArgTypes... args) const {
4135 return this->launch(std::vector<void*>({(void*)&args...}),
4136 {reflection::reflect<ArgTypes>()...});
4137 }
4138 };
4139
kernel(std::string const & name,std::vector<std::string> const & options) const4140 inline Kernel Program::kernel(std::string const& name,
4141 std::vector<std::string> const& options) const {
4142 return Kernel(this, name, options);
4143 }
4144
instantiate(std::vector<std::string> const & template_args) const4145 inline KernelInstantiation Kernel::instantiate(
4146 std::vector<std::string> const& template_args) const {
4147 return KernelInstantiation(*this, template_args);
4148 }
4149
4150 template <typename... TemplateArgs>
instantiate() const4151 inline KernelInstantiation Kernel::instantiate() const {
4152 return this->instantiate(
4153 std::vector<std::string>({reflection::reflect<TemplateArgs>()...}));
4154 }
4155
4156 template <typename... TemplateArgs>
instantiate(TemplateArgs...targs) const4157 inline KernelInstantiation Kernel::instantiate(TemplateArgs... targs) const {
4158 return this->instantiate(
4159 std::vector<std::string>({reflection::reflect(targs)...}));
4160 }
4161
configure(dim3 grid,dim3 block,unsigned int smem,cudaStream_t stream) const4162 inline KernelLauncher KernelInstantiation::configure(
4163 dim3 grid, dim3 block, unsigned int smem, cudaStream_t stream) const {
4164 return KernelLauncher(this, grid, block, smem, stream);
4165 }
4166
configure_1d_max_occupancy(int max_block_size,unsigned int smem,CUoccupancyB2DSize smem_callback,cudaStream_t stream,unsigned int flags) const4167 inline KernelLauncher KernelInstantiation::configure_1d_max_occupancy(
4168 int max_block_size, unsigned int smem, CUoccupancyB2DSize smem_callback,
4169 cudaStream_t stream, unsigned int flags) const {
4170 int grid;
4171 int block;
4172 CUfunction func = *_cuda_kernel;
4173 detail::get_1d_max_occupancy(func, smem_callback, &smem, max_block_size,
4174 flags, &grid, &block);
4175 return this->configure(grid, block, smem, stream);
4176 }
4177
4178 } // namespace experimental
4179
4180 } // namespace jitify
4181
4182 #if defined(_WIN32) || defined(_WIN64)
4183 #pragma pop_macro("max")
4184 #pragma pop_macro("min")
4185 #pragma pop_macro("strtok_r")
4186 #endif
4187