1 //===-- clang-linker-wrapper/ClangLinkerWrapper.cpp - wrapper over linker-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 //
9 // This tool works as a wrapper over a linking job. This tool is used to create
10 // linked device images for offloading. It scans the linker's input for embedded
11 // device offloading data stored in sections `.llvm.offloading` and extracts it
12 // as a temporary file. The extracted device files will then be passed to a
13 // device linking job to create a final device image.
14 //
15 //===---------------------------------------------------------------------===//
16 
17 #include "OffloadWrapper.h"
18 #include "clang/Basic/Version.h"
19 #include "llvm/BinaryFormat/Magic.h"
20 #include "llvm/Bitcode/BitcodeWriter.h"
21 #include "llvm/CodeGen/CommandFlags.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/DiagnosticPrinter.h"
24 #include "llvm/IR/Module.h"
25 #include "llvm/IRReader/IRReader.h"
26 #include "llvm/LTO/LTO.h"
27 #include "llvm/MC/TargetRegistry.h"
28 #include "llvm/Object/Archive.h"
29 #include "llvm/Object/ArchiveWriter.h"
30 #include "llvm/Object/Binary.h"
31 #include "llvm/Object/ELFObjectFile.h"
32 #include "llvm/Object/IRObjectFile.h"
33 #include "llvm/Object/ObjectFile.h"
34 #include "llvm/Object/OffloadBinary.h"
35 #include "llvm/Option/ArgList.h"
36 #include "llvm/Option/OptTable.h"
37 #include "llvm/Option/Option.h"
38 #include "llvm/Support/CommandLine.h"
39 #include "llvm/Support/Errc.h"
40 #include "llvm/Support/FileOutputBuffer.h"
41 #include "llvm/Support/FileSystem.h"
42 #include "llvm/Support/Host.h"
43 #include "llvm/Support/InitLLVM.h"
44 #include "llvm/Support/MemoryBuffer.h"
45 #include "llvm/Support/Parallel.h"
46 #include "llvm/Support/Path.h"
47 #include "llvm/Support/Program.h"
48 #include "llvm/Support/Signals.h"
49 #include "llvm/Support/SourceMgr.h"
50 #include "llvm/Support/StringSaver.h"
51 #include "llvm/Support/TargetSelect.h"
52 #include "llvm/Support/WithColor.h"
53 #include "llvm/Support/raw_ostream.h"
54 #include "llvm/Target/TargetMachine.h"
55 #include <atomic>
56 #include <optional>
57 
58 using namespace llvm;
59 using namespace llvm::opt;
60 using namespace llvm::object;
61 
/// Path of the current binary. Used as the prefix of emitted diagnostics.
static const char *LinkerExecutable;

/// Save intermediary results.
static bool SaveTemps = false;

/// Print arguments without executing.
static bool DryRun = false;

/// Print verbose output.
static bool Verbose = false;

/// Filename of the executable being created.
static StringRef ExecutableName;

/// Binary path for the CUDA installation.
static std::string CudaBinaryPath;

/// Mutex lock to protect writes to shared TempFiles in parallel.
static std::mutex TempFilesMutex;

/// Temporary files created by the linker wrapper. Stored in a std::list so
/// that references to existing entries stay valid when new files are added.
static std::list<SmallString<128>> TempFiles;

/// Codegen flags for LTO backend.
static codegen::RegisterCodeGenFlags CodeGenFlags;

/// Global flag to indicate that the LTO pipeline threw an error.
static std::atomic<bool> LTOError;

/// Shorthand for the image description consumed by OffloadBinary::write.
using OffloadingImage = OffloadBinary::OffloadingImage;
93 
94 namespace llvm {
95 // Provide DenseMapInfo so that OffloadKind can be used in a DenseMap.
96 template <> struct DenseMapInfo<OffloadKind> {
97   static inline OffloadKind getEmptyKey() { return OFK_LAST; }
98   static inline OffloadKind getTombstoneKey() {
99     return static_cast<OffloadKind>(OFK_LAST + 1);
100   }
101   static unsigned getHashValue(const OffloadKind &Val) { return Val; }
102 
103   static bool isEqual(const OffloadKind &LHS, const OffloadKind &RHS) {
104     return LHS == RHS;
105   }
106 };
107 } // namespace llvm
108 
109 namespace {
110 using std::error_code;
111 
/// Extra option flags specific to the linker wrapper. They are queried through
/// Option::hasFlag, e.g. when deciding which arguments to forward to the host
/// linker. Must not overlap with llvm::opt::DriverFlag.
enum WrapperFlags {
  WrapperOnlyOption = (1 << 4), // Options only used by the linker wrapper.
  DeviceOnlyOption = (1 << 5),  // Options only used for device linking.
};
117 
/// Identifiers for every option known to the linker wrapper. One OPT_<ID>
/// enumerator is generated for each OPTION entry of LinkerWrapperOpts.inc.
enum ID {
  OPT_INVALID = 0, // This is not an option ID.
#define OPTION(PREFIX, NAME, ID, KIND, GROUP, ALIAS, ALIASARGS, FLAGS, PARAM,  \
               HELPTEXT, METAVAR, VALUES)                                      \
  OPT_##ID,
#include "LinkerWrapperOpts.inc"
  LastOption
#undef OPTION
};
127 
// For every PREFIX entry of LinkerWrapperOpts.inc define a constexpr array
// NAME_init holding the prefix strings and an ArrayRef NAME viewing all but
// its last element (presumably a terminating sentinel -- verify against the
// generated .inc file).
#define PREFIX(NAME, VALUE)                                                    \
  static constexpr StringLiteral NAME##_init[] = VALUE;                        \
  static constexpr ArrayRef<StringLiteral> NAME(NAME##_init,                   \
                                                std::size(NAME##_init) - 1);
#include "LinkerWrapperOpts.inc"
#undef PREFIX
134 
/// Table describing every option, with one OptTable::Info record generated per
/// OPTION entry of LinkerWrapperOpts.inc. It backs the WrapperOptTable.
static constexpr OptTable::Info InfoTable[] = {
#define OPTION(PREFIX, NAME, ID, KIND, GROUP, ALIAS, ALIASARGS, FLAGS, PARAM,  \
               HELPTEXT, METAVAR, VALUES)                                      \
  {PREFIX, NAME,  HELPTEXT,    METAVAR,     OPT_##ID,  Option::KIND##Class,    \
   PARAM,  FLAGS, OPT_##GROUP, OPT_##ALIAS, ALIASARGS, VALUES},
#include "LinkerWrapperOpts.inc"
#undef OPTION
};
143 
144 class WrapperOptTable : public opt::GenericOptTable {
145 public:
146   WrapperOptTable() : opt::GenericOptTable(InfoTable) {}
147 };
148 
149 const OptTable &getOptTable() {
150   static const WrapperOptTable *Table = []() {
151     auto Result = std::make_unique<WrapperOptTable>();
152     return Result.release();
153   }();
154   return *Table;
155 }
156 
157 void printCommands(ArrayRef<StringRef> CmdArgs) {
158   if (CmdArgs.empty())
159     return;
160 
161   llvm::errs() << " \"" << CmdArgs.front() << "\" ";
162   for (auto IC = std::next(CmdArgs.begin()), IE = CmdArgs.end(); IC != IE; ++IC)
163     llvm::errs() << *IC << (std::next(IC) != IE ? " " : "\n");
164 }
165 
166 [[noreturn]] void reportError(Error E) {
167   outs().flush();
168   logAllUnhandledErrors(std::move(E),
169                         WithColor::error(errs(), LinkerExecutable));
170   exit(EXIT_FAILURE);
171 }
172 
173 /// Create an extra user-specified \p OffloadFile.
174 /// TODO: We should find a way to wrap these as libraries instead.
175 Expected<OffloadFile> getInputBitcodeLibrary(StringRef Input) {
176   auto [Device, Path] = StringRef(Input).split('=');
177   auto [String, Arch] = Device.rsplit('-');
178   auto [Kind, Triple] = String.split('-');
179 
180   llvm::ErrorOr<std::unique_ptr<MemoryBuffer>> ImageOrError =
181       llvm::MemoryBuffer::getFileOrSTDIN(Path);
182   if (std::error_code EC = ImageOrError.getError())
183     return createFileError(Path, EC);
184 
185   OffloadingImage Image{};
186   Image.TheImageKind = IMG_Bitcode;
187   Image.TheOffloadKind = getOffloadKind(Kind);
188   Image.StringData = {{"triple", Triple}, {"arch", Arch}};
189   Image.Image = std::move(*ImageOrError);
190 
191   std::unique_ptr<MemoryBuffer> Binary = OffloadBinary::write(Image);
192   auto NewBinaryOrErr = OffloadBinary::create(*Binary);
193   if (!NewBinaryOrErr)
194     return NewBinaryOrErr.takeError();
195   return OffloadFile(std::move(*NewBinaryOrErr), std::move(Binary));
196 }
197 
198 std::string getMainExecutable(const char *Name) {
199   void *Ptr = (void *)(intptr_t)&getMainExecutable;
200   auto COWPath = sys::fs::getMainExecutable(Name, Ptr);
201   return sys::path::parent_path(COWPath).str();
202 }
203 
204 /// Get a temporary filename suitable for output.
205 Expected<StringRef> createOutputFile(const Twine &Prefix, StringRef Extension) {
206   std::scoped_lock<decltype(TempFilesMutex)> Lock(TempFilesMutex);
207   SmallString<128> OutputFile;
208   if (SaveTemps) {
209     (Prefix + "." + Extension).toNullTerminatedStringRef(OutputFile);
210   } else {
211     if (std::error_code EC =
212             sys::fs::createTemporaryFile(Prefix, Extension, OutputFile))
213       return createFileError(OutputFile, EC);
214   }
215 
216   TempFiles.emplace_back(std::move(OutputFile));
217   return TempFiles.back();
218 }
219 
220 /// Execute the command \p ExecutablePath with the arguments \p Args.
221 Error executeCommands(StringRef ExecutablePath, ArrayRef<StringRef> Args) {
222   if (Verbose || DryRun)
223     printCommands(Args);
224 
225   if (!DryRun)
226     if (sys::ExecuteAndWait(ExecutablePath, Args))
227       return createStringError(inconvertibleErrorCode(),
228                                "'" + sys::path::filename(ExecutablePath) + "'" +
229                                    " failed");
230   return Error::success();
231 }
232 
233 Expected<std::string> findProgram(StringRef Name, ArrayRef<StringRef> Paths) {
234 
235   ErrorOr<std::string> Path = sys::findProgramByName(Name, Paths);
236   if (!Path)
237     Path = sys::findProgramByName(Name);
238   if (!Path && DryRun)
239     return Name.str();
240   if (!Path)
241     return createStringError(Path.getError(),
242                              "Unable to find '" + Name + "' in path");
243   return *Path;
244 }
245 
246 /// Runs the wrapped linker job with the newly created input.
247 Error runLinker(ArrayRef<StringRef> Files, const ArgList &Args) {
248   llvm::TimeTraceScope TimeScope("Execute host linker");
249 
250   // Render the linker arguments and add the newly created image. We add it
251   // after the output file to ensure it is linked with the correct libraries.
252   StringRef LinkerPath = Args.getLastArgValue(OPT_linker_path_EQ);
253   ArgStringList NewLinkerArgs;
254   for (const opt::Arg *Arg : Args) {
255     // Do not forward arguments only intended for the linker wrapper.
256     if (Arg->getOption().hasFlag(WrapperOnlyOption))
257       continue;
258 
259     Arg->render(Args, NewLinkerArgs);
260     if (Arg->getOption().matches(OPT_o))
261       llvm::transform(Files, std::back_inserter(NewLinkerArgs),
262                       [&](StringRef Arg) { return Args.MakeArgString(Arg); });
263   }
264 
265   SmallVector<StringRef> LinkerArgs({LinkerPath});
266   for (StringRef Arg : NewLinkerArgs)
267     LinkerArgs.push_back(Arg);
268   if (Error Err = executeCommands(LinkerPath, LinkerArgs))
269     return Err;
270   return Error::success();
271 }
272 
273 void printVersion(raw_ostream &OS) {
274   OS << clang::getClangToolFullVersion("clang-linker-wrapper") << '\n';
275 }
276 
277 namespace nvptx {
278 Expected<StringRef>
279 fatbinary(ArrayRef<std::pair<StringRef, StringRef>> InputFiles,
280           const ArgList &Args) {
281   llvm::TimeTraceScope TimeScope("NVPTX fatbinary");
282   // NVPTX uses the fatbinary program to bundle the linked images.
283   Expected<std::string> FatBinaryPath =
284       findProgram("fatbinary", {CudaBinaryPath + "/bin"});
285   if (!FatBinaryPath)
286     return FatBinaryPath.takeError();
287 
288   llvm::Triple Triple(
289       Args.getLastArgValue(OPT_host_triple_EQ, sys::getDefaultTargetTriple()));
290 
291   // Create a new file to write the linked device image to.
292   auto TempFileOrErr =
293       createOutputFile(sys::path::filename(ExecutableName), "fatbin");
294   if (!TempFileOrErr)
295     return TempFileOrErr.takeError();
296 
297   SmallVector<StringRef, 16> CmdArgs;
298   CmdArgs.push_back(*FatBinaryPath);
299   CmdArgs.push_back(Triple.isArch64Bit() ? "-64" : "-32");
300   CmdArgs.push_back("--create");
301   CmdArgs.push_back(*TempFileOrErr);
302   for (const auto &[File, Arch] : InputFiles)
303     CmdArgs.push_back(
304         Args.MakeArgString("--image=profile=" + Arch + ",file=" + File));
305 
306   if (Error Err = executeCommands(*FatBinaryPath, CmdArgs))
307     return std::move(Err);
308 
309   return *TempFileOrErr;
310 }
311 } // namespace nvptx
312 
313 namespace amdgcn {
314 Expected<StringRef>
315 fatbinary(ArrayRef<std::pair<StringRef, StringRef>> InputFiles,
316           const ArgList &Args) {
317   llvm::TimeTraceScope TimeScope("AMDGPU Fatbinary");
318 
319   // AMDGPU uses the clang-offload-bundler to bundle the linked images.
320   Expected<std::string> OffloadBundlerPath = findProgram(
321       "clang-offload-bundler", {getMainExecutable("clang-offload-bundler")});
322   if (!OffloadBundlerPath)
323     return OffloadBundlerPath.takeError();
324 
325   llvm::Triple Triple(
326       Args.getLastArgValue(OPT_host_triple_EQ, sys::getDefaultTargetTriple()));
327 
328   // Create a new file to write the linked device image to.
329   auto TempFileOrErr =
330       createOutputFile(sys::path::filename(ExecutableName), "hipfb");
331   if (!TempFileOrErr)
332     return TempFileOrErr.takeError();
333 
334   BumpPtrAllocator Alloc;
335   StringSaver Saver(Alloc);
336 
337   SmallVector<StringRef, 16> CmdArgs;
338   CmdArgs.push_back(*OffloadBundlerPath);
339   CmdArgs.push_back("-type=o");
340   CmdArgs.push_back("-bundle-align=4096");
341 
342   SmallVector<StringRef> Targets = {"-targets=host-x86_64-unknown-linux"};
343   for (const auto &[File, Arch] : InputFiles)
344     Targets.push_back(Saver.save("hipv4-amdgcn-amd-amdhsa--" + Arch));
345   CmdArgs.push_back(Saver.save(llvm::join(Targets, ",")));
346 
347   CmdArgs.push_back("-input=/dev/null");
348   for (const auto &[File, Arch] : InputFiles)
349     CmdArgs.push_back(Saver.save("-input=" + File));
350 
351   CmdArgs.push_back(Saver.save("-output=" + *TempFileOrErr));
352 
353   if (Error Err = executeCommands(*OffloadBundlerPath, CmdArgs))
354     return std::move(Err);
355 
356   return *TempFileOrErr;
357 }
358 } // namespace amdgcn
359 
360 namespace generic {
361 Expected<StringRef> clang(ArrayRef<StringRef> InputFiles, const ArgList &Args) {
362   llvm::TimeTraceScope TimeScope("Clang");
363   // Use `clang` to invoke the appropriate device tools.
364   Expected<std::string> ClangPath =
365       findProgram("clang", {getMainExecutable("clang")});
366   if (!ClangPath)
367     return ClangPath.takeError();
368 
369   const llvm::Triple Triple(Args.getLastArgValue(OPT_triple_EQ));
370   StringRef Arch = Args.getLastArgValue(OPT_arch_EQ);
371   if (Arch.empty())
372     Arch = "native";
373   // Create a new file to write the linked device image to. Assume that the
374   // input filename already has the device and architecture.
375   auto TempFileOrErr =
376       createOutputFile(sys::path::filename(ExecutableName) + "." +
377                            Triple.getArchName() + "." + Arch,
378                        "img");
379   if (!TempFileOrErr)
380     return TempFileOrErr.takeError();
381 
382   StringRef OptLevel = Args.getLastArgValue(OPT_opt_level, "O2");
383   SmallVector<StringRef, 16> CmdArgs{
384       *ClangPath,
385       "-o",
386       *TempFileOrErr,
387       Args.MakeArgString("--target=" + Triple.getTriple()),
388       Triple.isAMDGPU() ? Args.MakeArgString("-mcpu=" + Arch)
389                         : Args.MakeArgString("-march=" + Arch),
390       Args.MakeArgString("-" + OptLevel),
391       "-Wl,--no-undefined",
392   };
393 
394   // If this is CPU offloading we copy the input libraries.
395   if (!Triple.isAMDGPU() && !Triple.isNVPTX()) {
396     CmdArgs.push_back("-Bsymbolic");
397     CmdArgs.push_back("-shared");
398     ArgStringList LinkerArgs;
399     for (const opt::Arg *Arg :
400          Args.filtered(OPT_library, OPT_rpath, OPT_library_path))
401       Arg->render(Args, LinkerArgs);
402     llvm::copy(LinkerArgs, std::back_inserter(CmdArgs));
403   }
404 
405   if (Args.hasArg(OPT_debug))
406     CmdArgs.push_back("-g");
407 
408   if (SaveTemps)
409     CmdArgs.push_back("-save-temps");
410 
411   if (Verbose)
412     CmdArgs.push_back("-v");
413 
414   if (!CudaBinaryPath.empty())
415     CmdArgs.push_back(Args.MakeArgString("--cuda-path=" + CudaBinaryPath));
416 
417   for (StringRef Arg : Args.getAllArgValues(OPT_ptxas_arg))
418     llvm::copy(SmallVector<StringRef>({"-Xcuda-ptxas", Arg}),
419                std::back_inserter(CmdArgs));
420 
421   for (StringRef Arg : Args.getAllArgValues(OPT_linker_arg_EQ))
422     CmdArgs.push_back(Args.MakeArgString("-Wl," + Arg));
423 
424   for (StringRef InputFile : InputFiles)
425     CmdArgs.push_back(InputFile);
426 
427   if (Error Err = executeCommands(*ClangPath, CmdArgs))
428     return std::move(Err);
429 
430   return *TempFileOrErr;
431 }
432 } // namespace generic
433 
434 Expected<StringRef> linkDevice(ArrayRef<StringRef> InputFiles,
435                                const ArgList &Args) {
436   const llvm::Triple Triple(Args.getLastArgValue(OPT_triple_EQ));
437   switch (Triple.getArch()) {
438   case Triple::nvptx:
439   case Triple::nvptx64:
440   case Triple::amdgcn:
441   case Triple::x86:
442   case Triple::x86_64:
443   case Triple::aarch64:
444   case Triple::aarch64_be:
445   case Triple::ppc64:
446   case Triple::ppc64le:
447     return generic::clang(InputFiles, Args);
448   default:
449     return createStringError(inconvertibleErrorCode(),
450                              Triple.getArchName() +
451                                  " linking is not supported");
452   }
453 }
454 
455 void diagnosticHandler(const DiagnosticInfo &DI) {
456   std::string ErrStorage;
457   raw_string_ostream OS(ErrStorage);
458   DiagnosticPrinterRawOStream DP(OS);
459   DI.print(DP);
460 
461   switch (DI.getSeverity()) {
462   case DS_Error:
463     WithColor::error(errs(), LinkerExecutable) << ErrStorage << "\n";
464     LTOError = true;
465     break;
466   case DS_Warning:
467     WithColor::warning(errs(), LinkerExecutable) << ErrStorage << "\n";
468     break;
469   case DS_Note:
470     WithColor::note(errs(), LinkerExecutable) << ErrStorage << "\n";
471     break;
472   case DS_Remark:
473     WithColor::remark(errs()) << ErrStorage << "\n";
474     break;
475   }
476 }
477 
478 // Get the list of target features from the input file and unify them such that
479 // if there are multiple +xxx or -xxx features we only keep the last one.
480 std::vector<std::string> getTargetFeatures(ArrayRef<OffloadFile> InputFiles) {
481   SmallVector<StringRef> Features;
482   for (const OffloadFile &File : InputFiles) {
483     for (auto Arg : llvm::split(File.getBinary()->getString("feature"), ","))
484       Features.emplace_back(Arg);
485   }
486 
487   // Only add a feature if it hasn't been seen before starting from the end.
488   std::vector<std::string> UnifiedFeatures;
489   DenseSet<StringRef> UsedFeatures;
490   for (StringRef Feature : llvm::reverse(Features)) {
491     if (UsedFeatures.insert(Feature.drop_front()).second)
492       UnifiedFeatures.push_back(Feature.str());
493   }
494 
495   return UnifiedFeatures;
496 }
497 
498 template <typename ModuleHook = function_ref<bool(size_t, const Module &)>>
499 std::unique_ptr<lto::LTO> createLTO(
500     const ArgList &Args, const std::vector<std::string> &Features,
501     ModuleHook Hook = [](size_t, const Module &) { return true; }) {
502   const llvm::Triple Triple(Args.getLastArgValue(OPT_triple_EQ));
503   StringRef Arch = Args.getLastArgValue(OPT_arch_EQ);
504   lto::Config Conf;
505   lto::ThinBackend Backend;
506   // TODO: Handle index-only thin-LTO
507   Backend =
508       lto::createInProcessThinBackend(llvm::heavyweight_hardware_concurrency());
509 
510   Conf.CPU = Arch.str();
511   Conf.Options = codegen::InitTargetOptionsFromCodeGenFlags(Triple);
512 
513   StringRef OptLevel = Args.getLastArgValue(OPT_opt_level, "O2");
514   Conf.MAttrs = Features;
515   std::optional<CodeGenOpt::Level> CGOptLevelOrNone =
516       CodeGenOpt::parseLevel(OptLevel[1]);
517   assert(CGOptLevelOrNone && "Invalid optimization level");
518   Conf.CGOptLevel = *CGOptLevelOrNone;
519   Conf.OptLevel = OptLevel[1] - '0';
520   Conf.DefaultTriple = Triple.getTriple();
521 
522   LTOError = false;
523   Conf.DiagHandler = diagnosticHandler;
524 
525   Conf.PTO.LoopVectorization = Conf.OptLevel > 1;
526   Conf.PTO.SLPVectorization = Conf.OptLevel > 1;
527 
528   if (SaveTemps) {
529     std::string TempName = (sys::path::filename(ExecutableName) + "." +
530                             Triple.getTriple() + "." + Arch)
531                                .str();
532     Conf.PostInternalizeModuleHook = [=](size_t Task, const Module &M) {
533       std::string File =
534           !Task ? TempName + ".postlink.bc"
535                 : TempName + "." + std::to_string(Task) + ".postlink.bc";
536       error_code EC;
537       raw_fd_ostream LinkedBitcode(File, EC, sys::fs::OF_None);
538       if (EC)
539         reportError(errorCodeToError(EC));
540       WriteBitcodeToFile(M, LinkedBitcode);
541       return true;
542     };
543     Conf.PreCodeGenModuleHook = [=](size_t Task, const Module &M) {
544       std::string File =
545           !Task ? TempName + ".postopt.bc"
546                 : TempName + "." + std::to_string(Task) + ".postopt.bc";
547       error_code EC;
548       raw_fd_ostream LinkedBitcode(File, EC, sys::fs::OF_None);
549       if (EC)
550         reportError(errorCodeToError(EC));
551       WriteBitcodeToFile(M, LinkedBitcode);
552       return true;
553     };
554   }
555   Conf.PostOptModuleHook = Hook;
556   Conf.CGFileType =
557       (Triple.isNVPTX() || SaveTemps) ? CGFT_AssemblyFile : CGFT_ObjectFile;
558 
559   // TODO: Handle remark files
560   Conf.HasWholeProgramVisibility = Args.hasArg(OPT_whole_program);
561 
562   return std::make_unique<lto::LTO>(std::move(Conf), Backend);
563 }
564 
565 // Returns true if \p S is valid as a C language identifier and will be given
566 // `__start_` and `__stop_` symbols.
567 bool isValidCIdentifier(StringRef S) {
568   return !S.empty() && (isAlpha(S[0]) || S[0] == '_') &&
569          llvm::all_of(llvm::drop_begin(S),
570                       [](char C) { return C == '_' || isAlnum(C); });
571 }
572 
573 Error linkBitcodeFiles(SmallVectorImpl<OffloadFile> &InputFiles,
574                        SmallVectorImpl<StringRef> &OutputFiles,
575                        const ArgList &Args) {
576   llvm::TimeTraceScope TimeScope("Link bitcode files");
577   const llvm::Triple Triple(Args.getLastArgValue(OPT_triple_EQ));
578   StringRef Arch = Args.getLastArgValue(OPT_arch_EQ);
579 
580   SmallVector<OffloadFile, 4> BitcodeInputFiles;
581   DenseSet<StringRef> UsedInRegularObj;
582   DenseSet<StringRef> UsedInSharedLib;
583   BumpPtrAllocator Alloc;
584   StringSaver Saver(Alloc);
585 
586   // Search for bitcode files in the input and create an LTO input file. If it
587   // is not a bitcode file, scan its symbol table for symbols we need to save.
588   for (OffloadFile &File : InputFiles) {
589     MemoryBufferRef Buffer = MemoryBufferRef(File.getBinary()->getImage(), "");
590 
591     file_magic Type = identify_magic(Buffer.getBuffer());
592     switch (Type) {
593     case file_magic::bitcode: {
594       BitcodeInputFiles.emplace_back(std::move(File));
595       continue;
596     }
597     case file_magic::elf_relocatable:
598     case file_magic::elf_shared_object: {
599       Expected<std::unique_ptr<ObjectFile>> ObjFile =
600           ObjectFile::createObjectFile(Buffer);
601       if (!ObjFile)
602         continue;
603 
604       for (SymbolRef Sym : (*ObjFile)->symbols()) {
605         Expected<StringRef> Name = Sym.getName();
606         if (!Name)
607           return Name.takeError();
608 
609         // Record if we've seen these symbols in any object or shared libraries.
610         if ((*ObjFile)->isRelocatableObject())
611           UsedInRegularObj.insert(Saver.save(*Name));
612         else
613           UsedInSharedLib.insert(Saver.save(*Name));
614       }
615       continue;
616     }
617     default:
618       continue;
619     }
620   }
621 
622   if (BitcodeInputFiles.empty())
623     return Error::success();
624 
625   // Remove all the bitcode files that we moved from the original input.
626   llvm::erase_if(InputFiles, [](OffloadFile &F) { return !F.getBinary(); });
627 
628   // LTO Module hook to output bitcode without running the backend.
629   SmallVector<StringRef, 4> BitcodeOutput;
630   auto OutputBitcode = [&](size_t, const Module &M) {
631     auto TempFileOrErr = createOutputFile(sys::path::filename(ExecutableName) +
632                                               "-jit-" + Triple.getTriple(),
633                                           "bc");
634     if (!TempFileOrErr)
635       reportError(TempFileOrErr.takeError());
636 
637     std::error_code EC;
638     raw_fd_ostream LinkedBitcode(*TempFileOrErr, EC, sys::fs::OF_None);
639     if (EC)
640       reportError(errorCodeToError(EC));
641     WriteBitcodeToFile(M, LinkedBitcode);
642     BitcodeOutput.push_back(*TempFileOrErr);
643     return false;
644   };
645 
646   // We assume visibility of the whole program if every input file was bitcode.
647   auto Features = getTargetFeatures(BitcodeInputFiles);
648   auto LTOBackend = Args.hasArg(OPT_embed_bitcode)
649                         ? createLTO(Args, Features, OutputBitcode)
650                         : createLTO(Args, Features);
651 
652   // We need to resolve the symbols so the LTO backend knows which symbols need
653   // to be kept or can be internalized. This is a simplified symbol resolution
654   // scheme to approximate the full resolution a linker would do.
655   uint64_t Idx = 0;
656   DenseSet<StringRef> PrevailingSymbols;
657   for (auto &BitcodeInput : BitcodeInputFiles) {
658     // Get a semi-unique buffer identifier for Thin-LTO.
659     StringRef Identifier = Saver.save(
660         std::to_string(Idx++) + "." +
661         BitcodeInput.getBinary()->getMemoryBufferRef().getBufferIdentifier());
662     MemoryBufferRef Buffer =
663         MemoryBufferRef(BitcodeInput.getBinary()->getImage(), Identifier);
664     Expected<std::unique_ptr<lto::InputFile>> BitcodeFileOrErr =
665         llvm::lto::InputFile::create(Buffer);
666     if (!BitcodeFileOrErr)
667       return BitcodeFileOrErr.takeError();
668 
669     // Save the input file and the buffer associated with its memory.
670     const auto Symbols = (*BitcodeFileOrErr)->symbols();
671     SmallVector<lto::SymbolResolution, 16> Resolutions(Symbols.size());
672     size_t Idx = 0;
673     for (auto &Sym : Symbols) {
674       lto::SymbolResolution &Res = Resolutions[Idx++];
675 
676       // We will use this as the prevailing symbol definition in LTO unless
677       // it is undefined or another definition has already been used.
678       Res.Prevailing =
679           !Sym.isUndefined() &&
680           PrevailingSymbols.insert(Saver.save(Sym.getName())).second;
681 
682       // We need LTO to preseve the following global symbols:
683       // 1) Symbols used in regular objects.
684       // 2) Sections that will be given a __start/__stop symbol.
685       // 3) Prevailing symbols that are needed visible to external libraries.
686       Res.VisibleToRegularObj =
687           UsedInRegularObj.contains(Sym.getName()) ||
688           isValidCIdentifier(Sym.getSectionName()) ||
689           (Res.Prevailing &&
690            (Sym.getVisibility() != GlobalValue::HiddenVisibility &&
691             !Sym.canBeOmittedFromSymbolTable()));
692 
693       // Identify symbols that must be exported dynamically and can be
694       // referenced by other files.
695       Res.ExportDynamic =
696           Sym.getVisibility() != GlobalValue::HiddenVisibility &&
697           (UsedInSharedLib.contains(Sym.getName()) ||
698            !Sym.canBeOmittedFromSymbolTable());
699 
700       // The final definition will reside in this linkage unit if the symbol is
701       // defined and local to the module. This only checks for bitcode files,
702       // full assertion will require complete symbol resolution.
703       Res.FinalDefinitionInLinkageUnit =
704           Sym.getVisibility() != GlobalValue::DefaultVisibility &&
705           (!Sym.isUndefined() && !Sym.isCommon());
706 
707       // We do not support linker redefined symbols (e.g. --wrap) for device
708       // image linking, so the symbols will not be changed after LTO.
709       Res.LinkerRedefined = false;
710     }
711 
712     // Add the bitcode file with its resolved symbols to the LTO job.
713     if (Error Err = LTOBackend->add(std::move(*BitcodeFileOrErr), Resolutions))
714       return Err;
715   }
716 
717   // Run the LTO job to compile the bitcode.
718   size_t MaxTasks = LTOBackend->getMaxTasks();
719   SmallVector<StringRef> Files(MaxTasks);
720   auto AddStream =
721       [&](size_t Task,
722           const Twine &ModuleName) -> std::unique_ptr<CachedFileStream> {
723     int FD = -1;
724     auto &TempFile = Files[Task];
725     StringRef Extension = (Triple.isNVPTX() || SaveTemps) ? "s" : "o";
726     std::string TaskStr = Task ? "." + std::to_string(Task) : "";
727     auto TempFileOrErr =
728         createOutputFile(sys::path::filename(ExecutableName) + "." +
729                              Triple.getTriple() + "." + Arch + TaskStr,
730                          Extension);
731     if (!TempFileOrErr)
732       reportError(TempFileOrErr.takeError());
733     TempFile = *TempFileOrErr;
734     if (std::error_code EC = sys::fs::openFileForWrite(TempFile, FD))
735       reportError(errorCodeToError(EC));
736     return std::make_unique<CachedFileStream>(
737         std::make_unique<llvm::raw_fd_ostream>(FD, true));
738   };
739 
740   if (Error Err = LTOBackend->run(AddStream))
741     return Err;
742 
743   if (LTOError)
744     return createStringError(inconvertibleErrorCode(),
745                              "Errors encountered inside the LTO pipeline.");
746 
747   // If we are embedding bitcode we only need the intermediate output.
748   bool SingleOutput = Files.size() == 1;
749   if (Args.hasArg(OPT_embed_bitcode)) {
750     if (BitcodeOutput.size() != 1 || !SingleOutput)
751       return createStringError(inconvertibleErrorCode(),
752                                "Cannot embed bitcode with multiple files.");
753     OutputFiles.push_back(Args.MakeArgString(BitcodeOutput.front()));
754     return Error::success();
755   }
756 
757   // Append the new inputs to the device linker input.
758   for (StringRef File : Files)
759     OutputFiles.push_back(File);
760 
761   return Error::success();
762 }
763 
764 Expected<StringRef> writeOffloadFile(const OffloadFile &File) {
765   const OffloadBinary &Binary = *File.getBinary();
766 
767   StringRef Prefix =
768       sys::path::stem(Binary.getMemoryBufferRef().getBufferIdentifier());
769   StringRef Suffix = getImageKindName(Binary.getImageKind());
770 
771   auto TempFileOrErr = createOutputFile(
772       Prefix + "-" + Binary.getTriple() + "-" + Binary.getArch(), Suffix);
773   if (!TempFileOrErr)
774     return TempFileOrErr.takeError();
775 
776   Expected<std::unique_ptr<FileOutputBuffer>> OutputOrErr =
777       FileOutputBuffer::create(*TempFileOrErr, Binary.getImage().size());
778   if (!OutputOrErr)
779     return OutputOrErr.takeError();
780   std::unique_ptr<FileOutputBuffer> Output = std::move(*OutputOrErr);
781   llvm::copy(Binary.getImage(), Output->getBufferStart());
782   if (Error E = Output->commit())
783     return std::move(E);
784 
785   return *TempFileOrErr;
786 }
787 
788 // Compile the module to an object file using the appropriate target machine for
789 // the host triple.
790 Expected<StringRef> compileModule(Module &M) {
791   llvm::TimeTraceScope TimeScope("Compile module");
792   std::string Msg;
793   const Target *T = TargetRegistry::lookupTarget(M.getTargetTriple(), Msg);
794   if (!T)
795     return createStringError(inconvertibleErrorCode(), Msg);
796 
797   auto Options =
798       codegen::InitTargetOptionsFromCodeGenFlags(Triple(M.getTargetTriple()));
799   StringRef CPU = "";
800   StringRef Features = "";
801   std::unique_ptr<TargetMachine> TM(
802       T->createTargetMachine(M.getTargetTriple(), CPU, Features, Options,
803                              Reloc::PIC_, M.getCodeModel()));
804 
805   if (M.getDataLayout().isDefault())
806     M.setDataLayout(TM->createDataLayout());
807 
808   int FD = -1;
809   auto TempFileOrErr = createOutputFile(
810       sys::path::filename(ExecutableName) + ".image.wrapper", "o");
811   if (!TempFileOrErr)
812     return TempFileOrErr.takeError();
813   if (std::error_code EC = sys::fs::openFileForWrite(*TempFileOrErr, FD))
814     return errorCodeToError(EC);
815 
816   auto OS = std::make_unique<llvm::raw_fd_ostream>(FD, true);
817 
818   legacy::PassManager CodeGenPasses;
819   TargetLibraryInfoImpl TLII(Triple(M.getTargetTriple()));
820   CodeGenPasses.add(new TargetLibraryInfoWrapperPass(TLII));
821   if (TM->addPassesToEmitFile(CodeGenPasses, *OS, nullptr, CGFT_ObjectFile))
822     return createStringError(inconvertibleErrorCode(),
823                              "Failed to execute host backend");
824   CodeGenPasses.run(M);
825 
826   return *TempFileOrErr;
827 }
828 
829 /// Creates the object file containing the device image and runtime
830 /// registration code from the device images stored in \p Images.
831 Expected<StringRef>
832 wrapDeviceImages(ArrayRef<std::unique_ptr<MemoryBuffer>> Buffers,
833                  const ArgList &Args, OffloadKind Kind) {
834   llvm::TimeTraceScope TimeScope("Wrap bundled images");
835 
836   SmallVector<ArrayRef<char>, 4> BuffersToWrap;
837   for (const auto &Buffer : Buffers)
838     BuffersToWrap.emplace_back(
839         ArrayRef<char>(Buffer->getBufferStart(), Buffer->getBufferSize()));
840 
841   LLVMContext Context;
842   Module M("offload.wrapper.module", Context);
843   M.setTargetTriple(
844       Args.getLastArgValue(OPT_host_triple_EQ, sys::getDefaultTargetTriple()));
845 
846   switch (Kind) {
847   case OFK_OpenMP:
848     if (Error Err = wrapOpenMPBinaries(M, BuffersToWrap))
849       return std::move(Err);
850     break;
851   case OFK_Cuda:
852     if (Error Err = wrapCudaBinary(M, BuffersToWrap.front()))
853       return std::move(Err);
854     break;
855   case OFK_HIP:
856     if (Error Err = wrapHIPBinary(M, BuffersToWrap.front()))
857       return std::move(Err);
858     break;
859   default:
860     return createStringError(inconvertibleErrorCode(),
861                              getOffloadKindName(Kind) +
862                                  " wrapping is not supported");
863   }
864 
865   if (Args.hasArg(OPT_print_wrapped_module))
866     errs() << M;
867 
868   auto FileOrErr = compileModule(M);
869   if (!FileOrErr)
870     return FileOrErr.takeError();
871   return *FileOrErr;
872 }
873 
874 Expected<SmallVector<std::unique_ptr<MemoryBuffer>>>
875 bundleOpenMP(ArrayRef<OffloadingImage> Images) {
876   SmallVector<std::unique_ptr<MemoryBuffer>> Buffers;
877   for (const OffloadingImage &Image : Images)
878     Buffers.emplace_back(OffloadBinary::write(Image));
879 
880   return std::move(Buffers);
881 }
882 
883 Expected<SmallVector<std::unique_ptr<MemoryBuffer>>>
884 bundleCuda(ArrayRef<OffloadingImage> Images, const ArgList &Args) {
885   SmallVector<std::pair<StringRef, StringRef>, 4> InputFiles;
886   for (const OffloadingImage &Image : Images)
887     InputFiles.emplace_back(std::make_pair(Image.Image->getBufferIdentifier(),
888                                            Image.StringData.lookup("arch")));
889 
890   Triple TheTriple = Triple(Images.front().StringData.lookup("triple"));
891   auto FileOrErr = nvptx::fatbinary(InputFiles, Args);
892   if (!FileOrErr)
893     return FileOrErr.takeError();
894 
895   llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> ImageOrError =
896       llvm::MemoryBuffer::getFileOrSTDIN(*FileOrErr);
897 
898   SmallVector<std::unique_ptr<MemoryBuffer>> Buffers;
899   if (std::error_code EC = ImageOrError.getError())
900     return createFileError(*FileOrErr, EC);
901   Buffers.emplace_back(std::move(*ImageOrError));
902 
903   return std::move(Buffers);
904 }
905 
906 Expected<SmallVector<std::unique_ptr<MemoryBuffer>>>
907 bundleHIP(ArrayRef<OffloadingImage> Images, const ArgList &Args) {
908   SmallVector<std::pair<StringRef, StringRef>, 4> InputFiles;
909   for (const OffloadingImage &Image : Images)
910     InputFiles.emplace_back(std::make_pair(Image.Image->getBufferIdentifier(),
911                                            Image.StringData.lookup("arch")));
912 
913   Triple TheTriple = Triple(Images.front().StringData.lookup("triple"));
914   auto FileOrErr = amdgcn::fatbinary(InputFiles, Args);
915   if (!FileOrErr)
916     return FileOrErr.takeError();
917 
918   llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> ImageOrError =
919       llvm::MemoryBuffer::getFileOrSTDIN(*FileOrErr);
920 
921   SmallVector<std::unique_ptr<MemoryBuffer>> Buffers;
922   if (std::error_code EC = ImageOrError.getError())
923     return createFileError(*FileOrErr, EC);
924   Buffers.emplace_back(std::move(*ImageOrError));
925 
926   return std::move(Buffers);
927 }
928 
929 /// Transforms the input \p Images into the binary format the runtime expects
930 /// for the given \p Kind.
931 Expected<SmallVector<std::unique_ptr<MemoryBuffer>>>
932 bundleLinkedOutput(ArrayRef<OffloadingImage> Images, const ArgList &Args,
933                    OffloadKind Kind) {
934   llvm::TimeTraceScope TimeScope("Bundle linked output");
935   switch (Kind) {
936   case OFK_OpenMP:
937     return bundleOpenMP(Images);
938   case OFK_Cuda:
939     return bundleCuda(Images, Args);
940   case OFK_HIP:
941     return bundleHIP(Images, Args);
942   default:
943     return createStringError(inconvertibleErrorCode(),
944                              getOffloadKindName(Kind) +
945                                  " bundling is not supported");
946   }
947 }
948 
949 /// Returns a new ArgList containg arguments used for the device linking phase.
950 DerivedArgList getLinkerArgs(ArrayRef<OffloadFile> Input,
951                              const InputArgList &Args) {
952   DerivedArgList DAL = DerivedArgList(DerivedArgList(Args));
953   for (Arg *A : Args)
954     DAL.append(A);
955 
956   // Set the subarchitecture and target triple for this compilation.
957   const OptTable &Tbl = getOptTable();
958   DAL.AddJoinedArg(nullptr, Tbl.getOption(OPT_arch_EQ),
959                    Args.MakeArgString(Input.front().getBinary()->getArch()));
960   DAL.AddJoinedArg(nullptr, Tbl.getOption(OPT_triple_EQ),
961                    Args.MakeArgString(Input.front().getBinary()->getTriple()));
962 
963   // If every input file is bitcode we have whole program visibility as we do
964   // only support static linking with bitcode.
965   auto ContainsBitcode = [](const OffloadFile &F) {
966     return identify_magic(F.getBinary()->getImage()) == file_magic::bitcode;
967   };
968   if (llvm::all_of(Input, ContainsBitcode))
969     DAL.AddFlagArg(nullptr, Tbl.getOption(OPT_whole_program));
970 
971   // Forward '-Xoffload-linker' options to the appropriate backend.
972   for (StringRef Arg : Args.getAllArgValues(OPT_device_linker_args_EQ)) {
973     auto [Triple, Value] = Arg.split('=');
974     if (Value.empty())
975       DAL.AddJoinedArg(nullptr, Tbl.getOption(OPT_linker_arg_EQ),
976                        Args.MakeArgString(Triple));
977     else if (Triple == DAL.getLastArgValue(OPT_triple_EQ))
978       DAL.AddJoinedArg(nullptr, Tbl.getOption(OPT_linker_arg_EQ),
979                        Args.MakeArgString(Value));
980   }
981 
982   return DAL;
983 }
984 
985 /// Transforms all the extracted offloading input files into an image that can
986 /// be registered by the runtime.
987 Expected<SmallVector<StringRef>>
988 linkAndWrapDeviceFiles(SmallVectorImpl<OffloadFile> &LinkerInputFiles,
989                        const InputArgList &Args, char **Argv, int Argc) {
990   llvm::TimeTraceScope TimeScope("Handle all device input");
991 
992   DenseMap<OffloadFile::TargetID, SmallVector<OffloadFile>> InputMap;
993   for (auto &File : LinkerInputFiles)
994     InputMap[File].emplace_back(std::move(File));
995   LinkerInputFiles.clear();
996 
997   SmallVector<SmallVector<OffloadFile>> InputsForTarget;
998   for (auto &[ID, Input] : InputMap)
999     InputsForTarget.emplace_back(std::move(Input));
1000   InputMap.clear();
1001 
1002   std::mutex ImageMtx;
1003   DenseMap<OffloadKind, SmallVector<OffloadingImage>> Images;
1004   auto Err = parallelForEachError(InputsForTarget, [&](auto &Input) -> Error {
1005     llvm::TimeTraceScope TimeScope("Link device input");
1006 
1007     // Each thread needs its own copy of the base arguments to maintain
1008     // per-device argument storage of synthetic strings.
1009     const OptTable &Tbl = getOptTable();
1010     BumpPtrAllocator Alloc;
1011     StringSaver Saver(Alloc);
1012     auto BaseArgs =
1013         Tbl.parseArgs(Argc, Argv, OPT_INVALID, Saver, [](StringRef Err) {
1014           reportError(createStringError(inconvertibleErrorCode(), Err));
1015         });
1016     auto LinkerArgs = getLinkerArgs(Input, BaseArgs);
1017 
1018     DenseSet<OffloadKind> ActiveOffloadKinds;
1019     for (const auto &File : Input)
1020       if (File.getBinary()->getOffloadKind() != OFK_None)
1021         ActiveOffloadKinds.insert(File.getBinary()->getOffloadKind());
1022 
1023     // First link and remove all the input files containing bitcode.
1024     SmallVector<StringRef> InputFiles;
1025     if (Error Err = linkBitcodeFiles(Input, InputFiles, LinkerArgs))
1026       return Err;
1027 
1028     // Write any remaining device inputs to an output file for the linker.
1029     for (const OffloadFile &File : Input) {
1030       auto FileNameOrErr = writeOffloadFile(File);
1031       if (!FileNameOrErr)
1032         return FileNameOrErr.takeError();
1033       InputFiles.emplace_back(*FileNameOrErr);
1034     }
1035 
1036     // Link the remaining device files using the device linker.
1037     auto OutputOrErr = !Args.hasArg(OPT_embed_bitcode)
1038                            ? linkDevice(InputFiles, LinkerArgs)
1039                            : InputFiles.front();
1040     if (!OutputOrErr)
1041       return OutputOrErr.takeError();
1042 
1043     // Store the offloading image for each linked output file.
1044     for (OffloadKind Kind : ActiveOffloadKinds) {
1045       llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
1046           llvm::MemoryBuffer::getFileOrSTDIN(*OutputOrErr);
1047       if (std::error_code EC = FileOrErr.getError()) {
1048         if (DryRun)
1049           FileOrErr = MemoryBuffer::getMemBuffer("");
1050         else
1051           return createFileError(*OutputOrErr, EC);
1052       }
1053 
1054       std::scoped_lock<decltype(ImageMtx)> Guard(ImageMtx);
1055       OffloadingImage TheImage{};
1056       TheImage.TheImageKind =
1057           Args.hasArg(OPT_embed_bitcode) ? IMG_Bitcode : IMG_Object;
1058       TheImage.TheOffloadKind = Kind;
1059       TheImage.StringData = {
1060           {"triple",
1061            Args.MakeArgString(LinkerArgs.getLastArgValue(OPT_triple_EQ))},
1062           {"arch",
1063            Args.MakeArgString(LinkerArgs.getLastArgValue(OPT_arch_EQ))}};
1064       TheImage.Image = std::move(*FileOrErr);
1065 
1066       Images[Kind].emplace_back(std::move(TheImage));
1067     }
1068     return Error::success();
1069   });
1070   if (Err)
1071     return std::move(Err);
1072 
1073   // Create a binary image of each offloading image and embed it into a new
1074   // object file.
1075   SmallVector<StringRef> WrappedOutput;
1076   for (auto &[Kind, Input] : Images) {
1077     // We sort the entries before bundling so they appear in a deterministic
1078     // order in the final binary.
1079     llvm::sort(Input, [](OffloadingImage &A, OffloadingImage &B) {
1080       return A.StringData["triple"] > B.StringData["triple"] ||
1081              A.StringData["arch"] > B.StringData["arch"] ||
1082              A.TheOffloadKind < B.TheOffloadKind;
1083     });
1084     auto BundledImagesOrErr = bundleLinkedOutput(Input, Args, Kind);
1085     if (!BundledImagesOrErr)
1086       return BundledImagesOrErr.takeError();
1087     auto OutputOrErr = wrapDeviceImages(*BundledImagesOrErr, Args, Kind);
1088     if (!OutputOrErr)
1089       return OutputOrErr.takeError();
1090     WrappedOutput.push_back(*OutputOrErr);
1091   }
1092 
1093   return WrappedOutput;
1094 }
1095 
1096 std::optional<std::string> findFile(StringRef Dir, StringRef Root,
1097                                     const Twine &Name) {
1098   SmallString<128> Path;
1099   if (Dir.startswith("="))
1100     sys::path::append(Path, Root, Dir.substr(1), Name);
1101   else
1102     sys::path::append(Path, Dir, Name);
1103 
1104   if (sys::fs::exists(Path))
1105     return static_cast<std::string>(Path);
1106   return std::nullopt;
1107 }
1108 
1109 std::optional<std::string>
1110 findFromSearchPaths(StringRef Name, StringRef Root,
1111                     ArrayRef<StringRef> SearchPaths) {
1112   for (StringRef Dir : SearchPaths)
1113     if (std::optional<std::string> File = findFile(Dir, Root, Name))
1114       return File;
1115   return std::nullopt;
1116 }
1117 
1118 std::optional<std::string>
1119 searchLibraryBaseName(StringRef Name, StringRef Root,
1120                       ArrayRef<StringRef> SearchPaths) {
1121   for (StringRef Dir : SearchPaths) {
1122     if (std::optional<std::string> File =
1123             findFile(Dir, Root, "lib" + Name + ".so"))
1124       return File;
1125     if (std::optional<std::string> File =
1126             findFile(Dir, Root, "lib" + Name + ".a"))
1127       return File;
1128   }
1129   return std::nullopt;
1130 }
1131 
1132 /// Search for static libraries in the linker's library path given input like
1133 /// `-lfoo` or `-l:libfoo.a`.
1134 std::optional<std::string> searchLibrary(StringRef Input, StringRef Root,
1135                                          ArrayRef<StringRef> SearchPaths) {
1136   if (Input.startswith(":"))
1137     return findFromSearchPaths(Input.drop_front(), Root, SearchPaths);
1138   return searchLibraryBaseName(Input, Root, SearchPaths);
1139 }
1140 
/// Common redeclaration of needed symbol flags. The enumerators are distinct
/// bits that are combined with bitwise operations.
enum Symbol : uint32_t {
  /// No flag set.
  Sym_None = 0U,
  /// Bit 1.
  Sym_Undefined = 0x2U,
  /// Bit 2.
  Sym_Weak = 0x4U,
};
1147 
1148 /// Scan the symbols from a BitcodeFile \p Buffer and record if we need to
1149 /// extract any symbols from it.
1150 Expected<bool> getSymbolsFromBitcode(MemoryBufferRef Buffer, StringSaver &Saver,
1151                                      DenseMap<StringRef, Symbol> &Syms) {
1152   Expected<IRSymtabFile> IRSymtabOrErr = readIRSymtab(Buffer);
1153   if (!IRSymtabOrErr)
1154     return IRSymtabOrErr.takeError();
1155 
1156   bool ShouldExtract = false;
1157   for (unsigned I = 0; I != IRSymtabOrErr->Mods.size(); ++I) {
1158     for (const auto &Sym : IRSymtabOrErr->TheReader.module_symbols(I)) {
1159       if (Sym.isFormatSpecific() || !Sym.isGlobal())
1160         continue;
1161 
1162       bool NewSymbol = Syms.count(Sym.getName()) == 0;
1163       auto &OldSym = Syms[Saver.save(Sym.getName())];
1164 
1165       // We will extract if it defines a currenlty undefined non-weak symbol.
1166       bool ResolvesStrongReference =
1167           ((OldSym & Sym_Undefined && !(OldSym & Sym_Weak)) &&
1168            !Sym.isUndefined());
1169       // We will extract if it defines a new global symbol visible to the host.
1170       bool NewGlobalSymbol =
1171           ((NewSymbol || (OldSym & Sym_Undefined)) && !Sym.isUndefined() &&
1172            !Sym.canBeOmittedFromSymbolTable() &&
1173            (Sym.getVisibility() != GlobalValue::HiddenVisibility));
1174       ShouldExtract |= ResolvesStrongReference | NewGlobalSymbol;
1175 
1176       // Update this symbol in the "table" with the new information.
1177       if (OldSym & Sym_Undefined && !Sym.isUndefined())
1178         OldSym = static_cast<Symbol>(OldSym & ~Sym_Undefined);
1179       if (Sym.isUndefined() && NewSymbol)
1180         OldSym = static_cast<Symbol>(OldSym | Sym_Undefined);
1181       if (Sym.isWeak())
1182         OldSym = static_cast<Symbol>(OldSym | Sym_Weak);
1183     }
1184   }
1185 
1186   return ShouldExtract;
1187 }
1188 
1189 /// Scan the symbols from an ObjectFile \p Obj and record if we need to extract
1190 /// any symbols from it.
1191 Expected<bool> getSymbolsFromObject(const ObjectFile &Obj, StringSaver &Saver,
1192                                     DenseMap<StringRef, Symbol> &Syms) {
1193   bool ShouldExtract = false;
1194   for (SymbolRef Sym : Obj.symbols()) {
1195     auto FlagsOrErr = Sym.getFlags();
1196     if (!FlagsOrErr)
1197       return FlagsOrErr.takeError();
1198 
1199     if (!(*FlagsOrErr & SymbolRef::SF_Global) ||
1200         (*FlagsOrErr & SymbolRef::SF_FormatSpecific))
1201       continue;
1202 
1203     auto NameOrErr = Sym.getName();
1204     if (!NameOrErr)
1205       return NameOrErr.takeError();
1206 
1207     bool NewSymbol = Syms.count(*NameOrErr) == 0;
1208     auto &OldSym = Syms[Saver.save(*NameOrErr)];
1209 
1210     // We will extract if it defines a currenlty undefined non-weak symbol.
1211     bool ResolvesStrongReference = (OldSym & Sym_Undefined) &&
1212                                    !(OldSym & Sym_Weak) &&
1213                                    !(*FlagsOrErr & SymbolRef::SF_Undefined);
1214 
1215     // We will extract if it defines a new global symbol visible to the host.
1216     bool NewGlobalSymbol = ((NewSymbol || (OldSym & Sym_Undefined)) &&
1217                             !(*FlagsOrErr & SymbolRef::SF_Undefined) &&
1218                             !(*FlagsOrErr & SymbolRef::SF_Hidden));
1219     ShouldExtract |= ResolvesStrongReference | NewGlobalSymbol;
1220 
1221     // Update this symbol in the "table" with the new information.
1222     if (OldSym & Sym_Undefined && !(*FlagsOrErr & SymbolRef::SF_Undefined))
1223       OldSym = static_cast<Symbol>(OldSym & ~Sym_Undefined);
1224     if (*FlagsOrErr & SymbolRef::SF_Undefined && NewSymbol)
1225       OldSym = static_cast<Symbol>(OldSym | Sym_Undefined);
1226     if (*FlagsOrErr & SymbolRef::SF_Weak)
1227       OldSym = static_cast<Symbol>(OldSym | Sym_Weak);
1228   }
1229   return ShouldExtract;
1230 }
1231 
1232 /// Attempt to 'resolve' symbols found in input files. We use this to
1233 /// determine if an archive member needs to be extracted. An archive member
1234 /// will be extracted if any of the following is true.
1235 ///   1) It defines an undefined symbol in a regular object filie.
1236 ///   2) It defines a global symbol without hidden visibility that has not
1237 ///      yet been defined.
1238 Expected<bool> getSymbols(StringRef Image, StringSaver &Saver,
1239                           DenseMap<StringRef, Symbol> &Syms) {
1240   MemoryBufferRef Buffer = MemoryBufferRef(Image, "");
1241   switch (identify_magic(Image)) {
1242   case file_magic::bitcode:
1243     return getSymbolsFromBitcode(Buffer, Saver, Syms);
1244   case file_magic::elf_relocatable: {
1245     Expected<std::unique_ptr<ObjectFile>> ObjFile =
1246         ObjectFile::createObjectFile(Buffer);
1247     if (!ObjFile)
1248       return ObjFile.takeError();
1249     return getSymbolsFromObject(**ObjFile, Saver, Syms);
1250   }
1251   default:
1252     return false;
1253   }
1254 }
1255 
1256 /// Search the input files and libraries for embedded device offloading code
1257 /// and add it to the list of files to be linked. Files coming from static
1258 /// libraries are only added to the input if they are used by an existing
1259 /// input file.
1260 Expected<SmallVector<OffloadFile>> getDeviceInput(const ArgList &Args) {
1261   llvm::TimeTraceScope TimeScope("ExtractDeviceCode");
1262 
1263   StringRef Root = Args.getLastArgValue(OPT_sysroot_EQ);
1264   SmallVector<StringRef> LibraryPaths;
1265   for (const opt::Arg *Arg : Args.filtered(OPT_library_path))
1266     LibraryPaths.push_back(Arg->getValue());
1267 
1268   BumpPtrAllocator Alloc;
1269   StringSaver Saver(Alloc);
1270 
1271   // Try to extract device code from the linker input files.
1272   SmallVector<OffloadFile> InputFiles;
1273   DenseMap<OffloadFile::TargetID, DenseMap<StringRef, Symbol>> Syms;
1274   for (const opt::Arg *Arg : Args.filtered(OPT_INPUT, OPT_library)) {
1275     std::optional<std::string> Filename =
1276         Arg->getOption().matches(OPT_library)
1277             ? searchLibrary(Arg->getValue(), Root, LibraryPaths)
1278             : std::string(Arg->getValue());
1279 
1280     if (!Filename && Arg->getOption().matches(OPT_library))
1281       reportError(createStringError(inconvertibleErrorCode(),
1282                                     "unable to find library -l%s",
1283                                     Arg->getValue()));
1284 
1285     if (!Filename || !sys::fs::exists(*Filename) ||
1286         sys::fs::is_directory(*Filename))
1287       continue;
1288 
1289     ErrorOr<std::unique_ptr<MemoryBuffer>> BufferOrErr =
1290         MemoryBuffer::getFileOrSTDIN(*Filename);
1291     if (std::error_code EC = BufferOrErr.getError())
1292       return createFileError(*Filename, EC);
1293 
1294     MemoryBufferRef Buffer = **BufferOrErr;
1295     if (identify_magic(Buffer.getBuffer()) == file_magic::elf_shared_object)
1296       continue;
1297 
1298     SmallVector<OffloadFile> Binaries;
1299     if (Error Err = extractOffloadBinaries(Buffer, Binaries))
1300       return std::move(Err);
1301 
1302     // We only extract archive members that are needed.
1303     bool IsArchive = identify_magic(Buffer.getBuffer()) == file_magic::archive;
1304     bool Extracted = true;
1305     while (Extracted) {
1306       Extracted = false;
1307       for (OffloadFile &Binary : Binaries) {
1308         if (!Binary.getBinary())
1309           continue;
1310 
1311         // If we don't have an object file for this architecture do not
1312         // extract.
1313         if (IsArchive && !Syms.count(Binary))
1314           continue;
1315 
1316         Expected<bool> ExtractOrErr =
1317             getSymbols(Binary.getBinary()->getImage(), Saver, Syms[Binary]);
1318         if (!ExtractOrErr)
1319           return ExtractOrErr.takeError();
1320 
1321         Extracted = IsArchive && *ExtractOrErr;
1322 
1323         if (!IsArchive || Extracted)
1324           InputFiles.emplace_back(std::move(Binary));
1325 
1326         // If we extracted any files we need to check all the symbols again.
1327         if (Extracted)
1328           break;
1329       }
1330     }
1331   }
1332 
1333   for (StringRef Library : Args.getAllArgValues(OPT_bitcode_library_EQ)) {
1334     auto FileOrErr = getInputBitcodeLibrary(Library);
1335     if (!FileOrErr)
1336       return FileOrErr.takeError();
1337     InputFiles.push_back(std::move(*FileOrErr));
1338   }
1339 
1340   return std::move(InputFiles);
1341 }
1342 
1343 } // namespace
1344 
1345 int main(int Argc, char **Argv) {
1346   InitLLVM X(Argc, Argv);
1347   InitializeAllTargetInfos();
1348   InitializeAllTargets();
1349   InitializeAllTargetMCs();
1350   InitializeAllAsmParsers();
1351   InitializeAllAsmPrinters();
1352 
1353   LinkerExecutable = Argv[0];
1354   sys::PrintStackTraceOnErrorSignal(Argv[0]);
1355 
1356   const OptTable &Tbl = getOptTable();
1357   BumpPtrAllocator Alloc;
1358   StringSaver Saver(Alloc);
1359   auto Args = Tbl.parseArgs(Argc, Argv, OPT_INVALID, Saver, [&](StringRef Err) {
1360     reportError(createStringError(inconvertibleErrorCode(), Err));
1361   });
1362 
1363   if (Args.hasArg(OPT_help) || Args.hasArg(OPT_help_hidden)) {
1364     Tbl.printHelp(
1365         outs(),
1366         "clang-linker-wrapper [options] -- <options to passed to the linker>",
1367         "\nA wrapper utility over the host linker. It scans the input files\n"
1368         "for sections that require additional processing prior to linking.\n"
1369         "The will then transparently pass all arguments and input to the\n"
1370         "specified host linker to create the final binary.\n",
1371         Args.hasArg(OPT_help_hidden), Args.hasArg(OPT_help_hidden));
1372     return EXIT_SUCCESS;
1373   }
1374   if (Args.hasArg(OPT_v)) {
1375     printVersion(outs());
1376     return EXIT_SUCCESS;
1377   }
1378 
1379   // This forwards '-mllvm' arguments to LLVM if present.
1380   SmallVector<const char *> NewArgv = {Argv[0]};
1381   for (const opt::Arg *Arg : Args.filtered(OPT_mllvm))
1382     NewArgv.push_back(Arg->getValue());
1383   for (const opt::Arg *Arg : Args.filtered(OPT_offload_opt_eq_minus))
1384     NewArgv.push_back(Args.MakeArgString(StringRef("-") + Arg->getValue()));
1385   cl::ParseCommandLineOptions(NewArgv.size(), &NewArgv[0]);
1386 
1387   Verbose = Args.hasArg(OPT_verbose);
1388   DryRun = Args.hasArg(OPT_dry_run);
1389   SaveTemps = Args.hasArg(OPT_save_temps);
1390   ExecutableName = Args.getLastArgValue(OPT_o, "a.out");
1391   CudaBinaryPath = Args.getLastArgValue(OPT_cuda_path_EQ).str();
1392 
1393   parallel::strategy = hardware_concurrency(1);
1394   if (auto *Arg = Args.getLastArg(OPT_wrapper_jobs)) {
1395     unsigned Threads = 0;
1396     if (!llvm::to_integer(Arg->getValue(), Threads) || Threads == 0)
1397       reportError(createStringError(
1398           inconvertibleErrorCode(), "%s: expected a positive integer, got '%s'",
1399           Arg->getSpelling().data(), Arg->getValue()));
1400     parallel::strategy = hardware_concurrency(Threads);
1401   }
1402 
1403   if (Args.hasArg(OPT_wrapper_time_trace_eq)) {
1404     unsigned Granularity;
1405     Args.getLastArgValue(OPT_wrapper_time_trace_granularity, "500")
1406         .getAsInteger(10, Granularity);
1407     timeTraceProfilerInitialize(Granularity, Argv[0]);
1408   }
1409 
1410   {
1411     llvm::TimeTraceScope TimeScope("Execute linker wrapper");
1412 
1413     // Extract the device input files stored in the host fat binary.
1414     auto DeviceInputFiles = getDeviceInput(Args);
1415     if (!DeviceInputFiles)
1416       reportError(DeviceInputFiles.takeError());
1417 
1418     // Link and wrap the device images extracted from the linker input.
1419     auto FilesOrErr =
1420         linkAndWrapDeviceFiles(*DeviceInputFiles, Args, Argv, Argc);
1421     if (!FilesOrErr)
1422       reportError(FilesOrErr.takeError());
1423 
1424     // Run the host linking job with the rendered arguments.
1425     if (Error Err = runLinker(*FilesOrErr, Args))
1426       reportError(std::move(Err));
1427   }
1428 
1429   if (const opt::Arg *Arg = Args.getLastArg(OPT_wrapper_time_trace_eq)) {
1430     if (Error Err = timeTraceProfilerWrite(Arg->getValue(), ExecutableName))
1431       reportError(std::move(Err));
1432     timeTraceProfilerCleanup();
1433   }
1434 
1435   // Remove the temporary files created.
1436   if (!SaveTemps)
1437     for (const auto &TempFile : TempFiles)
1438       if (std::error_code EC = sys::fs::remove(TempFile))
1439         reportError(createFileError(TempFile, EC));
1440 
1441   return EXIT_SUCCESS;
1442 }
1443