1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
2 //
3 // This file is part of the heyoka library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #include <cassert>
10 #include <charconv>
11 #include <cstdint>
12 #include <fstream>
13 #include <initializer_list>
14 #include <ios>
15 #include <memory>
16 #include <mutex>
17 #include <optional>
18 #include <ostream>
19 #include <regex>
20 #include <sstream>
21 #include <stdexcept>
22 #include <string>
23 #include <system_error>
24 #include <tuple>
25 #include <type_traits>
26 #include <utility>
27 #include <variant>
28 #include <vector>
29 
30 #include <boost/algorithm/string/predicate.hpp>
31 #include <boost/numeric/conversion/cast.hpp>
32 
33 #include <fmt/format.h>
34 
#include <llvm/ADT/SmallString.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/Triple.h>
#include <llvm/Analysis/TargetLibraryInfo.h>
#include <llvm/Analysis/TargetTransformInfo.h>
#include <llvm/CodeGen/TargetPassConfig.h>
#include <llvm/Config/llvm-config.h>
#include <llvm/ExecutionEngine/JITSymbol.h>
#include <llvm/ExecutionEngine/Orc/CompileUtils.h>
#include <llvm/ExecutionEngine/Orc/Core.h>
#include <llvm/ExecutionEngine/Orc/ExecutionUtils.h>
#include <llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h>
#include <llvm/ExecutionEngine/Orc/LLJIT.h>
#include <llvm/ExecutionEngine/Orc/ObjectTransformLayer.h>
#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
#include <llvm/IR/Attributes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/InstrTypes.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Operator.h>
#include <llvm/IR/Value.h>
#include <llvm/IR/Verifier.h>
#include <llvm/IRReader/IRReader.h>
#include <llvm/Pass.h>
#include <llvm/Support/Error.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SmallVectorMemoryBuffer.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/TargetRegistry.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Transforms/IPO.h>
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#include <llvm/Transforms/Vectorize.h>
72 
73 #if LLVM_VERSION_MAJOR == 10
74 
75 #include <llvm/CodeGen/CommandFlags.inc>
76 
77 #endif
78 
79 #include <heyoka/detail/llvm_fwd.hpp>
80 #include <heyoka/llvm_state.hpp>
81 #include <heyoka/number.hpp>
82 #include <heyoka/s11n.hpp>
83 #include <heyoka/variable.hpp>
84 
85 #if defined(_MSC_VER) && !defined(__clang__)
86 
87 // NOTE: MSVC has issues with the other "using"
88 // statement form.
89 using namespace fmt::literals;
90 
91 #else
92 
93 using fmt::literals::operator""_format;
94 
95 #endif
96 
97 namespace heyoka
98 {
99 
100 namespace detail
101 {
102 
103 namespace
104 {
105 
106 // Make sure our definition of ir_builder matches llvm::IRBuilder<>.
107 static_assert(std::is_same_v<ir_builder, llvm::IRBuilder<>>, "Inconsistent definition of the ir_builder type.");
108 
109 // LCOV_EXCL_START
110 
111 // Helper function to detect specific features
112 // on the host machine via LLVM's machinery.
get_target_features_impl()113 target_features get_target_features_impl()
114 {
115     auto jtmb = llvm::orc::JITTargetMachineBuilder::detectHost();
116     if (!jtmb) {
117         throw std::invalid_argument("Error creating a JITTargetMachineBuilder for the host system");
118     }
119 
120     auto tm = jtmb->createTargetMachine();
121     if (!tm) {
122         throw std::invalid_argument("Error creating the target machine");
123     }
124 
125     target_features retval;
126 
127     const auto target_name = std::string{(*tm)->getTarget().getName()};
128 
129     if (boost::starts_with(target_name, "x86")) {
130         const auto t_features = (*tm)->getTargetFeatureString();
131 
132         if (boost::algorithm::contains(t_features, "+avx512f")) {
133             retval.avx512f = true;
134         }
135 
136         if (boost::algorithm::contains(t_features, "+avx2")) {
137             retval.avx2 = true;
138         }
139 
140         if (boost::algorithm::contains(t_features, "+avx")) {
141             retval.avx = true;
142         }
143 
144         if (boost::algorithm::contains(t_features, "+sse2")) {
145             retval.sse2 = true;
146         }
147     }
148 
149     if (boost::starts_with(target_name, "aarch64")) {
150         retval.aarch64 = true;
151     }
152 
153     if (boost::starts_with(target_name, "ppc")) {
154         // On powerpc, detect the presence of the VSX
155         // instruction set from the CPU string.
156         const auto target_cpu = std::string{(*tm)->getTargetCPU()};
157 
158         // NOTE: the pattern reported by LLVM here seems to be pwrN
159         // (sample size of 1, on travis...).
160         std::regex pattern("pwr([1-9]*)");
161         std::cmatch m;
162 
163         if (std::regex_match(target_cpu.c_str(), m, pattern)) {
164             if (m.size() == 2u) {
165                 // The CPU name matches and contains a subgroup.
166                 // Extract the N from "pwrN".
167                 std::uint32_t pwr_idx{};
168                 auto ret = std::from_chars(m[1].first, m[1].second, pwr_idx);
169 
170                 // NOTE: it looks like VSX3 is supported from Power9,
171                 // VSX from Power7.
172                 // https://packages.gentoo.org/useflags/cpu_flags_ppc_vsx3
173                 if (ret.ec == std::errc{}) {
174                     if (pwr_idx >= 9) {
175                         retval.vsx3 = true;
176                     }
177 
178                     if (pwr_idx >= 7) {
179                         retval.vsx = true;
180                     }
181                 }
182             }
183         }
184     }
185 
186     return retval;
187 }
188 
189 // LCOV_EXCL_STOP
190 
191 } // namespace
192 
193 // Helper function to fetch a const ref to a global object
194 // containing info about the host machine.
get_target_features()195 const target_features &get_target_features()
196 {
197     static const target_features retval{get_target_features_impl()};
198 
199     return retval;
200 }
201 
202 namespace
203 {
204 
// Flag used in the constructor of the jit class, together with
// std::call_once(), to make sure that the initialisation of
// LLVM's native target is performed only once.
std::once_flag nt_inited;
206 
207 } // namespace
208 
209 } // namespace detail
210 
211 // Implementation of the jit class.
212 struct llvm_state::jit {
213     std::unique_ptr<llvm::orc::LLJIT> m_lljit;
214     std::unique_ptr<llvm::TargetMachine> m_tm;
215     std::unique_ptr<llvm::orc::ThreadSafeContext> m_ctx;
216 #if LLVM_VERSION_MAJOR == 10
217     std::unique_ptr<llvm::Triple> m_triple;
218 #endif
219     std::optional<std::string> m_object_file;
220 
jitheyoka::llvm_state::jit221     jit()
222     {
223         // NOTE: the native target initialization needs to be done only once
224         std::call_once(detail::nt_inited, []() {
225             llvm::InitializeNativeTarget();
226             llvm::InitializeNativeTargetAsmPrinter();
227             llvm::InitializeNativeTargetAsmParser();
228         });
229 
230         // Create the target machine builder.
231         auto jtmb = llvm::orc::JITTargetMachineBuilder::detectHost();
232         // LCOV_EXCL_START
233         if (!jtmb) {
234             throw std::invalid_argument("Error creating a JITTargetMachineBuilder for the host system");
235         }
236         // LCOV_EXCL_STOP
237         // Set the codegen optimisation level to aggressive.
238         jtmb->setCodeGenOptLevel(llvm::CodeGenOpt::Aggressive);
239 
240         // Create the jit builder.
241         llvm::orc::LLJITBuilder lljit_builder;
242         // NOTE: other settable properties may
243         // be of interest:
244         // https://www.llvm.org/doxygen/classllvm_1_1orc_1_1LLJITBuilder.html
245         lljit_builder.setJITTargetMachineBuilder(*jtmb);
246 
247         // Create the jit.
248         auto lljit = lljit_builder.create();
249         // LCOV_EXCL_START
250         if (!lljit) {
251             throw std::invalid_argument("Error creating an LLJIT object");
252         }
253         // LCOV_EXCL_STOP
254         m_lljit = std::move(*lljit);
255 
256         // Setup the machinery to cache the module's binary code
257         // when it is lazily generated.
258         m_lljit->getObjTransformLayer().setTransform([this](std::unique_ptr<llvm::MemoryBuffer> obj_buffer) {
259             assert(obj_buffer);
260             assert(!m_object_file);
261 
262             // Copy obj_buffer to the local m_object_file member.
263             m_object_file.emplace(obj_buffer->getBufferStart(), obj_buffer->getBufferEnd());
264 
265             return llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>(std::move(obj_buffer));
266         });
267 
268         // Setup the jit so that it can look up symbols from the current process.
269         auto dlsg = llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
270             m_lljit->getDataLayout().getGlobalPrefix());
271         // LCOV_EXCL_START
272         if (!dlsg) {
273             throw std::invalid_argument("Could not create the dynamic library search generator");
274         }
275         // LCOV_EXCL_STOP
276         m_lljit->getMainJITDylib().addGenerator(std::move(*dlsg));
277 
278         // Keep a target machine around to fetch various
279         // properties of the host CPU.
280         auto tm = jtmb->createTargetMachine();
281         // LCOV_EXCL_START
282         if (!tm) {
283             throw std::invalid_argument("Error creating the target machine");
284         }
285         // LCOV_EXCL_STOP
286         m_tm = std::move(*tm);
287 
288         // Create the context.
289         m_ctx = std::make_unique<llvm::orc::ThreadSafeContext>(std::make_unique<llvm::LLVMContext>());
290 
291 #if LLVM_VERSION_MAJOR == 10
292         // NOTE: on LLVM 10, we cannot fetch the target triple
293         // from the lljit class. Thus, we get it from the jtmb instead.
294         m_triple = std::make_unique<llvm::Triple>(jtmb->getTargetTriple());
295 #endif
296 
297         // NOTE: by default, errors in the execution session are printed
298         // to screen. A custom error reported can be specified, ideally
299         // we would like th throw here but I am not sure whether throwing
300         // here would disrupt LLVM's cleanup actions?
301         // https://llvm.org/doxygen/classllvm_1_1orc_1_1ExecutionSession.html
302     }
303 
304     jit(const jit &) = delete;
305     jit(jit &&) = delete;
306     jit &operator=(const jit &) = delete;
307     jit &operator=(jit &&) = delete;
308 
309     ~jit() = default;
310 
311     // Accessors.
get_contextheyoka::llvm_state::jit312     llvm::LLVMContext &get_context()
313     {
314         return *m_ctx->getContext();
315     }
get_contextheyoka::llvm_state::jit316     const llvm::LLVMContext &get_context() const
317     {
318         return *m_ctx->getContext();
319     }
get_target_cpuheyoka::llvm_state::jit320     std::string get_target_cpu() const
321     {
322         return m_tm->getTargetCPU().str();
323     }
get_target_featuresheyoka::llvm_state::jit324     std::string get_target_features() const
325     {
326         return m_tm->getTargetFeatureString().str();
327     }
get_target_ir_analysisheyoka::llvm_state::jit328     llvm::TargetIRAnalysis get_target_ir_analysis() const
329     {
330         return m_tm->getTargetIRAnalysis();
331     }
get_target_tripleheyoka::llvm_state::jit332     const llvm::Triple &get_target_triple() const
333     {
334 #if LLVM_VERSION_MAJOR == 10
335         return *m_triple;
336 #else
337         return m_lljit->getTargetTriple();
338 #endif
339     }
340 
add_moduleheyoka::llvm_state::jit341     void add_module(std::unique_ptr<llvm::Module> m)
342     {
343         auto err = m_lljit->addIRModule(llvm::orc::ThreadSafeModule(std::move(m), *m_ctx));
344 
345         // LCOV_EXCL_START
346         if (err) {
347             std::string err_report;
348             llvm::raw_string_ostream ostr(err_report);
349 
350             ostr << err;
351 
352             throw std::invalid_argument(
353                 "The function for adding a module to the jit failed. The full error message:\n{}"_format(ostr.str()));
354         }
355         // LCOV_EXCL_STOP
356     }
357 
358     // Symbol lookup.
lookupheyoka::llvm_state::jit359     llvm::Expected<llvm::JITEvaluatedSymbol> lookup(const std::string &name)
360     {
361         return m_lljit->lookup(name);
362     }
363 };
364 
365 // Small shared helper to setup the math flags in the builder at the
366 // end of a constructor or a deserialization.
ctor_setup_math_flags()367 void llvm_state::ctor_setup_math_flags()
368 {
369     assert(m_builder);
370 
371     llvm::FastMathFlags fmf;
372 
373     if (m_fast_math) {
374         // Set flags for faster math at the
375         // price of potential change of semantics.
376         fmf.setFast();
377     } else {
378         // By default, allow only fp contraction.
379         // NOTE: if we ever implement double-double
380         // arithmetic, we must either revisit this
381         // or make sure that fp contraction is off
382         // for the double-double primitives.
383         fmf.setAllowContract();
384     }
385 
386     m_builder->setFastMathFlags(fmf);
387 }
388 
389 namespace detail
390 {
391 
392 namespace
393 {
394 
395 // Helper to load object code into a jit.
396 template <typename Jit>
llvm_state_add_obj_to_jit(Jit & j,const std::string & obj)397 void llvm_state_add_obj_to_jit(Jit &j, const std::string &obj)
398 {
399     llvm::SmallVector<char, 0> buffer(obj.begin(), obj.end());
400     auto err = j.m_lljit->addObjectFile(std::make_unique<llvm::SmallVectorMemoryBuffer>(std::move(buffer)));
401 
402     // LCOV_EXCL_START
403     if (err) {
404         std::string err_report;
405         llvm::raw_string_ostream ostr(err_report);
406 
407         ostr << err;
408 
409         throw std::invalid_argument(
410             "The function for adding a compiled module to the jit failed. The full error message:\n{}"_format(
411                 ostr.str()));
412     }
413     // LCOV_EXCL_STOP
414 }
415 
416 // Helper to create an LLVM module from a IR in string representation.
llvm_state_ir_to_module(std::string && ir,llvm::LLVMContext & ctx)417 auto llvm_state_ir_to_module(std::string &&ir, llvm::LLVMContext &ctx)
418 {
419     // Create the corresponding memory buffer.
420     auto mb = llvm::MemoryBuffer::getMemBuffer(std::move(ir));
421 
422     // Construct a new module from the parsed IR.
423     llvm::SMDiagnostic err;
424     auto ret = llvm::parseIR(*mb, err, ctx);
425 
426     // LCOV_EXCL_START
427     if (!ret) {
428         std::string err_report;
429         llvm::raw_string_ostream ostr(err_report);
430 
431         err.print("", ostr);
432 
433         throw std::invalid_argument("IR parsing failed. The full error message:\n{}"_format(ostr.str()));
434     }
435     // LCOV_EXCL_STOP
436 
437     return ret;
438 }
439 
440 } // namespace
441 
442 } // namespace detail
443 
llvm_state(std::tuple<std::string,unsigned,bool,bool> && tup)444 llvm_state::llvm_state(std::tuple<std::string, unsigned, bool, bool> &&tup)
445     : m_jitter(std::make_unique<jit>()), m_opt_level(std::get<1>(tup)), m_fast_math(std::get<2>(tup)),
446       m_module_name(std::move(std::get<0>(tup))), m_inline_functions(std::get<3>(tup))
447 {
448     // Create the module.
449     m_module = std::make_unique<llvm::Module>(m_module_name, context());
450     // Setup the data layout and the target triple.
451     m_module->setDataLayout(m_jitter->m_lljit->getDataLayout());
452     m_module->setTargetTriple(m_jitter->get_target_triple().str());
453 
454     // Create a new builder for the module.
455     m_builder = std::make_unique<ir_builder>(context());
456 
457     // Setup the math flags in the builder.
458     ctor_setup_math_flags();
459 }
460 
// Default constructor. It delegates to the constructor from a
// tuple of options, using the tuple returned by an invocation
// of kw_args_ctor_impl() without arguments.
// NOTE: this will ensure that all kwargs
// are set to their default values.
llvm_state::llvm_state() : llvm_state(kw_args_ctor_impl()) {}
464 
// Copy constructor. A brand new jit is always created and the
// options are copied over from other. The rest of the state is
// then rebuilt either from the cached object code of other (if
// available) or from its IR.
llvm_state::llvm_state(const llvm_state &other)
    : m_jitter(std::make_unique<jit>()), m_opt_level(other.m_opt_level), m_fast_math(other.m_fast_math),
      m_module_name(other.m_module_name), m_inline_functions(other.m_inline_functions)
{
    const auto other_compiled = other.is_compiled();

    if (other_compiled && other.m_jitter->m_object_file) {
        // Fast path: 'other' was compiled and binary code was generated.
        // Module and builder stay empty, we only need the IR snapshot
        // and the cached object code, which is fed directly to the new jit.
        m_ir_snapshot = other.m_ir_snapshot;
        detail::llvm_state_add_obj_to_jit(*m_jitter, *other.m_jitter->m_object_file);

        return;
    }

    // Slow path: 'other' is either not compiled, or it is compiled
    // but no binary code has been lazily generated yet. Module and
    // builder are reconstructed from the IR of 'other'.
    // NOTE: get_ir() works regardless of the compiled status of other.
    auto ir_copy = other.get_ir();
    m_module = detail::llvm_state_ir_to_module(std::move(ir_copy), context());

    // New builder for the module, with the math flags set up.
    m_builder = std::make_unique<ir_builder>(context());
    ctor_setup_math_flags();

    // Mirror the compiled status of 'other'.
    // NOTE: compilation takes care of setting up m_ir_snapshot.
    // If no compilation happens, m_ir_snapshot is left empty after init.
    if (other_compiled) {
        compile();
    }
}
508 
// Move constructor, defaulted (i.e., member-wise move).
llvm_state::llvm_state(llvm_state &&) noexcept = default;
510 
// Copy assignment operator, implemented as a copy construction
// followed by a move assignment. Self-assignment is a no-op.
llvm_state &llvm_state::operator=(const llvm_state &other)
{
    if (this == &other) {
        return *this;
    }

    llvm_state tmp(other);
    *this = std::move(tmp);

    return *this;
}
519 
520 // NOTE: this cannot be defaulted because the moving of the LLVM objects
521 // needs to be done in a different order.
operator =(llvm_state && other)522 llvm_state &llvm_state::operator=(llvm_state &&other) noexcept
523 {
524     if (this != &other) {
525         // The LLVM bits.
526         m_builder = std::move(other.m_builder);
527         m_module = std::move(other.m_module);
528         m_jitter = std::move(other.m_jitter);
529 
530         // The remaining bits.
531         m_opt_level = other.m_opt_level;
532         m_ir_snapshot = std::move(other.m_ir_snapshot);
533         m_fast_math = other.m_fast_math;
534         m_module_name = std::move(other.m_module_name);
535         m_inline_functions = other.m_inline_functions;
536     }
537 
538     return *this;
539 }
540 
// Destructor, defaulted. It is defined in this file, after
// the definition of the jit class.
llvm_state::~llvm_state() = default;
542 
543 // NOTE: the save/load logic is essentially the same as in the
544 // copy constructor. Specifically, we have 2 different paths
545 // depending on whether the state is compiled AND object
546 // code was generated.
547 template <typename Archive>
save_impl(Archive & ar,unsigned) const548 void llvm_state::save_impl(Archive &ar, unsigned) const
549 {
550     // Start by establishing if the state is compiled and binary
551     // code has been emitted.
552     // NOTE: we need both flags when deserializing.
553     const auto cmp = is_compiled();
554     ar << cmp;
555 
556     const auto with_obj = static_cast<bool>(m_jitter->m_object_file);
557     ar << with_obj;
558 
559     assert(!with_obj || cmp);
560 
561     // Store the config options.
562     ar << m_opt_level;
563     ar << m_fast_math;
564     ar << m_module_name;
565     ar << m_inline_functions;
566 
567     // Store the IR.
568     // NOTE: avoid get_ir() if the module has been compiled,
569     // and use the snapshot directly, so that we don't make
570     // a useless copy.
571     if (cmp) {
572         ar << m_ir_snapshot;
573     } else {
574         ar << get_ir();
575     }
576 
577     if (with_obj) {
578         // Save the object file if available.
579         ar << *m_jitter->m_object_file;
580     }
581 }
582 
583 template <typename Archive>
load_impl(Archive & ar,unsigned)584 void llvm_state::load_impl(Archive &ar, unsigned)
585 {
586     // NOTE: all serialised objects in the archive
587     // are primitive types, no need to reset the
588     // addresses.
589 
590     // Load the status flags from the archive.
591     bool cmp{};
592     ar >> cmp;
593 
594     bool with_obj{};
595     ar >> with_obj;
596 
597     assert(!with_obj || cmp);
598 
599     // Load the config options.
600     unsigned opt_level{};
601     ar >> opt_level;
602 
603     bool fast_math{};
604     ar >> fast_math;
605 
606     std::string module_name;
607     ar >> module_name;
608 
609     bool inline_functions{};
610     ar >> inline_functions;
611 
612     // Load the ir
613     std::string ir;
614     ar >> ir;
615 
616     // Recover the object file, if available.
617     std::optional<std::string> obj_file;
618     if (with_obj) {
619         obj_file.emplace();
620         ar >> *obj_file;
621     }
622 
623     try {
624         // Set the config options.
625         m_opt_level = opt_level;
626         m_fast_math = fast_math;
627         m_module_name = module_name;
628         m_inline_functions = inline_functions;
629 
630         // Reset module and builder to the def-cted state.
631         m_module.reset();
632         m_builder.reset();
633 
634         // Reset the jit with a new one.
635         m_jitter = std::make_unique<jit>();
636 
637         if (cmp && with_obj) {
638             // Assign the ir snapshot.
639             m_ir_snapshot = std::move(ir);
640 
641             // Add the object code to the jit.
642             detail::llvm_state_add_obj_to_jit(*m_jitter, *obj_file);
643         } else {
644             // Clear the existing ir snapshot
645             // (it will be replaced with the
646             // actual ir if compilation is needed).
647             m_ir_snapshot.clear();
648 
649             // Create the module from the IR.
650             m_module = detail::llvm_state_ir_to_module(std::move(ir), context());
651 
652             // Create a new builder for the module.
653             m_builder = std::make_unique<ir_builder>(context());
654 
655             // Setup the math flags in the builder.
656             ctor_setup_math_flags();
657 
658             // Compile if needed.
659             // NOTE: compilation will take care of setting up m_ir_snapshot.
660             // If no compilation happens, m_ir_snapshot is left empty after
661             // clearing earlier.
662             if (cmp) {
663                 compile();
664             }
665         }
666         // LCOV_EXCL_START
667     } catch (...) {
668         // Reset to a def-cted state in case of error,
669         // as it looks like there's no way of recovering.
670         *this = []() noexcept { return llvm_state{}; }();
671 
672         throw;
673         // LCOV_EXCL_STOP
674     }
675 }
676 
// Serialisation into a Boost binary output archive.
// This is a thin wrapper forwarding ar and the version
// number v to save_impl().
void llvm_state::save(boost::archive::binary_oarchive &ar, unsigned v) const
{
    save_impl(ar, v);
}
681 
// Deserialisation from a Boost binary input archive.
// This is a thin wrapper forwarding ar and the version
// number v to load_impl().
void llvm_state::load(boost::archive::binary_iarchive &ar, unsigned v)
{
    load_impl(ar, v);
}
686 
// Fetch a reference to the LLVM module.
// An std::invalid_argument exception is thrown by
// check_uncompiled() if the module is not available.
llvm::Module &llvm_state::module()
{
    // NOTE: __func__ ends up in the error message.
    check_uncompiled(__func__);
    return *m_module;
}
692 
// Fetch a reference to the IR builder.
// An std::invalid_argument exception is thrown by
// check_uncompiled() if the module is not available.
ir_builder &llvm_state::builder()
{
    // NOTE: __func__ ends up in the error message.
    check_uncompiled(__func__);
    return *m_builder;
}
698 
// Fetch a reference to the LLVM context,
// which is owned by the jit.
llvm::LLVMContext &llvm_state::context()
{
    return m_jitter->get_context();
}
703 
// Fetch a mutable reference to the optimisation level.
unsigned &llvm_state::opt_level()
{
    return m_opt_level;
}
708 
// Fetch a mutable reference to the fast math flag.
bool &llvm_state::fast_math()
{
    return m_fast_math;
}
713 
// Fetch a mutable reference to the function inlining flag.
bool &llvm_state::inline_functions()
{
    return m_inline_functions;
}
718 
// Fetch a const reference to the LLVM module.
// An std::invalid_argument exception is thrown by
// check_uncompiled() if the module is not available.
const llvm::Module &llvm_state::module() const
{
    // NOTE: __func__ ends up in the error message.
    check_uncompiled(__func__);
    return *m_module;
}
724 
// Fetch a const reference to the IR builder.
// An std::invalid_argument exception is thrown by
// check_uncompiled() if the module is not available.
const ir_builder &llvm_state::builder() const
{
    // NOTE: __func__ ends up in the error message.
    check_uncompiled(__func__);
    return *m_builder;
}
730 
// Fetch a const reference to the LLVM context,
// which is owned by the jit.
const llvm::LLVMContext &llvm_state::context() const
{
    return m_jitter->get_context();
}
735 
// Fetch a const reference to the optimisation level.
const unsigned &llvm_state::opt_level() const
{
    return m_opt_level;
}
740 
// Fetch a const reference to the fast math flag.
const bool &llvm_state::fast_math() const
{
    return m_fast_math;
}
745 
// Fetch a const reference to the function inlining flag.
const bool &llvm_state::inline_functions() const
{
    return m_inline_functions;
}
750 
check_uncompiled(const char * f) const751 void llvm_state::check_uncompiled(const char *f) const
752 {
753     if (!m_module) {
754         throw std::invalid_argument(
755             "The function '{}' can be invoked only if the module has not been compiled yet"_format(f));
756     }
757 }
758 
check_compiled(const char * f) const759 void llvm_state::check_compiled(const char *f) const
760 {
761     if (m_module) {
762         throw std::invalid_argument(
763             "The function '{}' can be invoked only after the module has been compiled"_format(f));
764     }
765 }
766 
verify_function(llvm::Function * f)767 void llvm_state::verify_function(llvm::Function *f)
768 {
769     check_uncompiled(__func__);
770 
771     if (f == nullptr) {
772         throw std::invalid_argument("Cannot verify a null function pointer");
773     }
774 
775     std::string err_report;
776     llvm::raw_string_ostream ostr(err_report);
777     if (llvm::verifyFunction(*f, &ostr)) {
778         // Remove function before throwing.
779         const auto fname = std::string(f->getName());
780         f->eraseFromParent();
781 
782         throw std::invalid_argument(
783             "The verification of the function '{}' failed. The full error message:\n{}"_format(fname, ostr.str()));
784     }
785 }
786 
verify_function(const std::string & name)787 void llvm_state::verify_function(const std::string &name)
788 {
789     check_uncompiled(__func__);
790 
791     // Lookup the function in the module.
792     auto f = m_module->getFunction(name);
793     if (f == nullptr) {
794         throw std::invalid_argument("The function '{}' does not exist in the module"_format(name));
795     }
796 
797     // Run the actual check.
798     verify_function(f);
799 }
800 
// Run LLVM's optimisation passes on the module.
//
// This function can be invoked only if the module is available
// (otherwise check_uncompiled() throws an std::invalid_argument
// exception). If the optimisation level is zero, nothing is done.
// Otherwise, the function:
// - sets up the target CPU/features attributes of the functions
//   in the module from the values reported by the jit,
// - creates a module pass manager and a function pass manager,
//   adding target-specific info passes and an initial load/store
//   vectorization pass,
// - populates both pass managers via PassManagerBuilder, using
//   m_opt_level and (if m_inline_functions is set) an inliner,
// - runs the function passes on every function in the module and
//   then the module passes.
void llvm_state::optimise()
{
    check_uncompiled(__func__);

    if (m_opt_level > 0u) {
        // NOTE: the logic here largely mimics (with a lot of simplifications)
        // the implementation of the 'opt' tool. See:
        // https://github.com/llvm/llvm-project/blob/release/10.x/llvm/tools/opt/opt.cpp

        // For every function in the module, setup its attributes
        // so that the codegen uses all the features available on
        // the host CPU.
#if LLVM_VERSION_MAJOR == 10
        ::setFunctionAttributes(m_jitter->get_target_cpu(), m_jitter->get_target_features(), *m_module);
#else
        // NOTE: in LLVM > 10, the setFunctionAttributes() function is gone in favour of another
        // function in another namespace, which however does not seem to work out of the box
        // because (I think) it might be reading some non-existent command-line options. See:
        // https://llvm.org/doxygen/CommandFlags_8cpp_source.html#l00552
        // Here we are reproducing a trimmed-down version of the same function.
        const auto cpu = m_jitter->get_target_cpu();
        const auto features = m_jitter->get_target_features();

        for (auto &f : module()) {
            auto attrs = f.getAttributes();
            llvm::AttrBuilder new_attrs;

            // Set the target cpu only if the function
            // does not specify one already.
            if (!cpu.empty() && !f.hasFnAttribute("target-cpu")) {
                new_attrs.addAttribute("target-cpu", cpu);
            }

            if (!features.empty()) {
                auto old_features = f.getFnAttribute("target-features").getValueAsString();

                if (old_features.empty()) {
                    new_attrs.addAttribute("target-features", features);
                } else {
                    // The function has target features already:
                    // append the host features to the existing ones,
                    // using a comma as a separator.
                    llvm::SmallString<256> appended(old_features);
                    appended.push_back(',');
                    appended.append(features);
                    new_attrs.addAttribute("target-features", appended);
                }
            }

            f.setAttributes(attrs.addAttributes(context(), llvm::AttributeList::FunctionIndex, new_attrs));
        }
#endif

        // NOTE: currently LLVM forces 256-bit vector
        // width when AVX-512 is available, due to clock
        // frequency scaling concerns. We used to have the following
        // code here:
        // for (auto &f : *m_module) {
        //     f.addFnAttr("prefer-vector-width", "512");
        // }
        // in order to force 512-bit vector width, but it looks
        // like this can hurt performance in scalar mode.
        // Let's keep this in mind for the future, perhaps
        // we could consider enabling 512-bit vector width
        // only in batch mode?

        // Init the module pass manager.
        auto module_pm = std::make_unique<llvm::legacy::PassManager>();
        // These are passes which set up target-specific info
        // that are used by successive optimisation passes.
        // NOTE: the pass is handed over to the pass manager
        // as a raw pointer via release().
        auto tliwp = std::make_unique<llvm::TargetLibraryInfoWrapperPass>(
            llvm::TargetLibraryInfoImpl(m_jitter->get_target_triple()));
        module_pm->add(tliwp.release());
        module_pm->add(llvm::createTargetTransformInfoWrapperPass(m_jitter->get_target_ir_analysis()));

        // NOTE: not sure what this does, presumably some target-specific
        // configuration.
        module_pm->add(static_cast<llvm::LLVMTargetMachine &>(*m_jitter->m_tm).createPassConfig(*module_pm));

        // Init the function pass manager.
        auto f_pm = std::make_unique<llvm::legacy::FunctionPassManager>(m_module.get());
        f_pm->add(llvm::createTargetTransformInfoWrapperPass(m_jitter->get_target_ir_analysis()));

        // Add an initial pass to vectorize load/stores.
        // This is useful to ensure that the
        // pattern adopted in load_vector_from_memory() and
        // store_vector_to_memory() is translated to
        // vectorized store/load instructions.
        auto lsv_pass = std::unique_ptr<llvm::Pass>(llvm::createLoadStoreVectorizerPass());
        f_pm->add(lsv_pass.release());

        // We use the helper class PassManagerBuilder to populate the module
        // pass manager with standard options.
        llvm::PassManagerBuilder pm_builder;
        // See here for the defaults:
        // https://llvm.org/doxygen/PassManagerBuilder_8cpp_source.html
        // NOTE: we used to have the SLP vectorizer on here, but
        // we don't activate it any more in favour of explicit vectorization.
        // NOTE: perhaps in the future we can make the autovectorizer an
        // option like the fast math flag.
        pm_builder.OptLevel = m_opt_level;
        if (m_inline_functions) {
            // Enable function inlining if the inlining flag is enabled.
            pm_builder.Inliner = llvm::createFunctionInliningPass(m_opt_level, 0, false);
        }

        // Give the target machine the chance to
        // tweak the pass manager builder.
        m_jitter->m_tm->adjustPassManager(pm_builder);

        // Populate both the function pass manager and the module pass manager.
        pm_builder.populateFunctionPassManager(*f_pm);
        pm_builder.populateModulePassManager(*module_pm);

        // Run the function pass manager on all functions in the module.
        f_pm->doInitialization();
        for (auto &f : *m_module) {
            f_pm->run(f);
        }
        f_pm->doFinalization();

        // Run the module passes.
        module_pm->run(*m_module);
    }
}
919 
compile()920 void llvm_state::compile()
921 {
922     check_uncompiled(__func__);
923 
924     // Run a verification on the module before compiling.
925     {
926         std::string out;
927         llvm::raw_string_ostream ostr(out);
928 
929         if (llvm::verifyModule(*m_module, &ostr)) {
930             throw std::runtime_error(
931                 "The verification of the module '{}' produced an error:\n{}"_format(m_module_name, ostr.str()));
932         }
933     }
934 
935     try {
936         // Store a snapshot of the IR before compiling.
937         m_ir_snapshot = get_ir();
938 
939         // Add the module (this will clear out m_module).
940         m_jitter->add_module(std::move(m_module));
941 
942         // Clear out the builder, which won't be usable any more.
943         m_builder.reset();
944         // LCOV_EXCL_START
945     } catch (...) {
946         // Reset to a def-cted state in case of error,
947         // as it looks like there's no way of recovering.
948         *this = []() noexcept { return llvm_state{}; }();
949 
950         throw;
951         // LCOV_EXCL_STOP
952     }
953 }
954 
is_compiled() const955 bool llvm_state::is_compiled() const
956 {
957     return !m_module;
958 }
959 
960 // NOTE: this function will lookup symbol names,
961 // so it does not necessarily return a function
962 // pointer (could be, e.g., a global variable).
jit_lookup(const std::string & name)963 std::uintptr_t llvm_state::jit_lookup(const std::string &name)
964 {
965     check_compiled(__func__);
966 
967     auto sym = m_jitter->lookup(name);
968     if (!sym) {
969         throw std::invalid_argument("Could not find the symbol '{}' in the compiled module"_format(name));
970     }
971 
972     return static_cast<std::uintptr_t>((*sym).getAddress());
973 }
974 
get_ir() const975 std::string llvm_state::get_ir() const
976 {
977     if (m_module) {
978         // The module has not been compiled yet,
979         // get the IR from it.
980         std::string out;
981         llvm::raw_string_ostream ostr(out);
982         m_module->print(ostr, nullptr);
983         return ostr.str();
984     } else {
985         // The module has been compiled.
986         // Return the IR snapshot that
987         // was created before the compilation.
988         return m_ir_snapshot;
989     }
990 }
991 
992 // LCOV_EXCL_START
993 
dump_object_code(const std::string & filename) const994 void llvm_state::dump_object_code(const std::string &filename) const
995 {
996     const auto &oc = get_object_code();
997 
998     std::ofstream ofs;
999     // NOTE: turn on exceptions, and overwrite any existing content.
1000     ofs.exceptions(std::ifstream::failbit | std::ifstream::badbit);
1001     ofs.open(filename, std::ios_base::out | std::ios::trunc);
1002 
1003     // Write out the binary data to ofs.
1004     ofs.write(oc.data(), boost::numeric_cast<std::streamsize>(oc.size()));
1005 }
1006 
1007 // LCOV_EXCL_STOP
1008 
get_object_code() const1009 const std::string &llvm_state::get_object_code() const
1010 {
1011     if (!is_compiled()) {
1012         throw std::invalid_argument(
1013             "Cannot extract the object code from an llvm_state which has not been compiled yet");
1014     }
1015 
1016     if (!m_jitter->m_object_file) {
1017         throw std::invalid_argument(
1018             "Cannot extract the object code from an llvm_state if the binary code has not been generated yet");
1019     }
1020 
1021     return *m_jitter->m_object_file;
1022 }
1023 
module_name() const1024 const std::string &llvm_state::module_name() const
1025 {
1026     return m_module_name;
1027 }
1028 
1029 // A helper that returns a new llvm_state configured in the same
1030 // way as this (i.e., same module name, opt level, fast math flags, etc.),
1031 // but with no code defined in it.
make_similar() const1032 llvm_state llvm_state::make_similar() const
1033 {
1034     return llvm_state(kw::mname = m_module_name, kw::opt_level = m_opt_level, kw::fast_math = m_fast_math,
1035                       kw::inline_functions = m_inline_functions);
1036 }
1037 
operator <<(std::ostream & os,const llvm_state & s)1038 std::ostream &operator<<(std::ostream &os, const llvm_state &s)
1039 {
1040     std::ostringstream oss;
1041     oss << std::boolalpha;
1042 
1043     oss << "Module name        : " << s.m_module_name << '\n';
1044     oss << "Compiled           : " << s.is_compiled() << '\n';
1045     oss << "Fast math          : " << s.m_fast_math << '\n';
1046     oss << "Optimisation level : " << s.m_opt_level << '\n';
1047     oss << "Inline functions   : " << s.m_inline_functions << '\n';
1048     oss << "Target triple      : " << s.m_jitter->get_target_triple().str() << '\n';
1049     oss << "Target CPU         : " << s.m_jitter->get_target_cpu() << '\n';
1050     oss << "Target features    : " << s.m_jitter->get_target_features() << '\n';
1051     oss << "IR size            : " << s.get_ir().size() << '\n';
1052 
1053     return os << oss.str();
1054 }
1055 
1056 } // namespace heyoka
1057