//===- OpenMPIRBuilder.cpp - Builder for LLVM-IR for OpenMP directives ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
///
/// This file implements the OpenMPIRBuilder class, which is used as a
/// convenient way to create LLVM instructions for OpenMP directives.
///
//===----------------------------------------------------------------------===//

#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Value.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"
#include "llvm/Transforms/Utils/LoopPeel.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"

#include <cstdint>
#include <optional>

#define DEBUG_TYPE "openmp-ir-builder"

using namespace llvm;
using namespace omp;

static cl::opt<bool>
    OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
                         cl::desc("Use optimistic attributes describing "
                                  "'as-if' properties of runtime calls."),
                         cl::init(false));

static cl::opt<double> UnrollThresholdFactor(
    "openmp-ir-builder-unroll-threshold-factor", cl::Hidden,
    cl::desc("Factor for the unroll threshold to account for code "
             "simplifications still taking place"),
    cl::init(1.5));
#ifndef NDEBUG
/// Return whether IP1 and IP2 are ambiguous, i.e. that inserting instructions
/// at position IP1 may change the meaning of IP2 or vice-versa. This is
/// because an InsertPoint stores the instruction before something is
/// inserted. For instance, if both point to the same instruction, two
/// IRBuilders alternately creating instructions will cause them to be
/// interleaved.
static bool isConflictIP(IRBuilder<>::InsertPoint IP1,
                         IRBuilder<>::InsertPoint IP2) {
  if (!IP1.isSet() || !IP2.isSet())
    return false;
  return IP1.getBlock() == IP2.getBlock() && IP1.getPoint() == IP2.getPoint();
}
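// Illustrative sketch (BB and I are hypothetical, not part of the builder
// logic): two insert points naming the same block and instruction conflict,
// because emitting through either one changes what the other points at.
//
//   IRBuilder<>::InsertPoint IP1(BB, I->getIterator());
//   IRBuilder<>::InsertPoint IP2(BB, I->getIterator());
//   assert(isConflictIP(IP1, IP2) && "same block and point are ambiguous");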

static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
  // Valid ordered/unordered and base algorithm combinations.
  switch (SchedType & ~OMPScheduleType::MonotonicityMask) {
  case OMPScheduleType::UnorderedStaticChunked:
  case OMPScheduleType::UnorderedStatic:
  case OMPScheduleType::UnorderedDynamicChunked:
  case OMPScheduleType::UnorderedGuidedChunked:
  case OMPScheduleType::UnorderedRuntime:
  case OMPScheduleType::UnorderedAuto:
  case OMPScheduleType::UnorderedTrapezoidal:
  case OMPScheduleType::UnorderedGreedy:
  case OMPScheduleType::UnorderedBalanced:
  case OMPScheduleType::UnorderedGuidedIterativeChunked:
  case OMPScheduleType::UnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::UnorderedSteal:
  case OMPScheduleType::UnorderedStaticBalancedChunked:
  case OMPScheduleType::UnorderedGuidedSimd:
  case OMPScheduleType::UnorderedRuntimeSimd:
  case OMPScheduleType::OrderedStaticChunked:
  case OMPScheduleType::OrderedStatic:
  case OMPScheduleType::OrderedDynamicChunked:
  case OMPScheduleType::OrderedGuidedChunked:
  case OMPScheduleType::OrderedRuntime:
  case OMPScheduleType::OrderedAuto:
  case OMPScheduleType::OrderedTrapezoidal:
  case OMPScheduleType::NomergeUnorderedStaticChunked:
  case OMPScheduleType::NomergeUnorderedStatic:
  case OMPScheduleType::NomergeUnorderedDynamicChunked:
  case OMPScheduleType::NomergeUnorderedGuidedChunked:
  case OMPScheduleType::NomergeUnorderedRuntime:
  case OMPScheduleType::NomergeUnorderedAuto:
  case OMPScheduleType::NomergeUnorderedTrapezoidal:
  case OMPScheduleType::NomergeUnorderedGreedy:
  case OMPScheduleType::NomergeUnorderedBalanced:
  case OMPScheduleType::NomergeUnorderedGuidedIterativeChunked:
  case OMPScheduleType::NomergeUnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::NomergeUnorderedSteal:
  case OMPScheduleType::NomergeOrderedStaticChunked:
  case OMPScheduleType::NomergeOrderedStatic:
  case OMPScheduleType::NomergeOrderedDynamicChunked:
  case OMPScheduleType::NomergeOrderedGuidedChunked:
  case OMPScheduleType::NomergeOrderedRuntime:
  case OMPScheduleType::NomergeOrderedAuto:
  case OMPScheduleType::NomergeOrderedTrapezoidal:
    break;
  default:
    return false;
  }

  // Must not set both monotonicity modifiers at the same time.
  OMPScheduleType MonotonicityFlags =
      SchedType & OMPScheduleType::MonotonicityMask;
  if (MonotonicityFlags == OMPScheduleType::MonotonicityMask)
    return false;

  return true;
}
#endif

/// Determine which scheduling algorithm to use from the schedule clause
/// arguments.
static OMPScheduleType
getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier) {
  // Currently, the default schedule is static.
  switch (ClauseKind) {
  case OMP_SCHEDULE_Default:
  case OMP_SCHEDULE_Static:
    return HasChunks ? OMPScheduleType::BaseStaticChunked
                     : OMPScheduleType::BaseStatic;
  case OMP_SCHEDULE_Dynamic:
    return OMPScheduleType::BaseDynamicChunked;
  case OMP_SCHEDULE_Guided:
    return HasSimdModifier ? OMPScheduleType::BaseGuidedSimd
                           : OMPScheduleType::BaseGuidedChunked;
  case OMP_SCHEDULE_Auto:
    return llvm::omp::OMPScheduleType::BaseAuto;
  case OMP_SCHEDULE_Runtime:
    return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
                           : OMPScheduleType::BaseRuntime;
  }
  llvm_unreachable("unhandled schedule clause argument");
}

/// Adds ordering modifier flags to schedule type.
static OMPScheduleType
getOpenMPOrderingScheduleType(OMPScheduleType BaseScheduleType,
                              bool HasOrderedClause) {
  assert((BaseScheduleType & OMPScheduleType::ModifierMask) ==
             OMPScheduleType::None &&
         "Must not have ordering or monotonicity flags already set");

  OMPScheduleType OrderingModifier = HasOrderedClause
                                         ? OMPScheduleType::ModifierOrdered
                                         : OMPScheduleType::ModifierUnordered;
  OMPScheduleType OrderingScheduleType = BaseScheduleType | OrderingModifier;

  // Unsupported combinations.
  if (OrderingScheduleType ==
      (OMPScheduleType::BaseGuidedSimd | OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedGuidedChunked;
  else if (OrderingScheduleType == (OMPScheduleType::BaseRuntimeSimd |
                                    OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedRuntime;

  return OrderingScheduleType;
}

/// Adds monotonicity modifier flags to schedule type.
static OMPScheduleType
getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
                                  bool HasSimdModifier, bool HasMonotonic,
                                  bool HasNonmonotonic, bool HasOrderedClause) {
  assert((ScheduleType & OMPScheduleType::MonotonicityMask) ==
             OMPScheduleType::None &&
         "Must not have monotonicity flags already set");
  assert((!HasMonotonic || !HasNonmonotonic) &&
         "The monotonic and nonmonotonic modifiers are mutually exclusive");

  if (HasMonotonic) {
    return ScheduleType | OMPScheduleType::ModifierMonotonic;
  } else if (HasNonmonotonic) {
    return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
  } else {
    // OpenMP 5.1, 2.11.4 Worksharing-Loop Construct, Description.
    // If the static schedule kind is specified or if the ordered clause is
    // specified, and if the nonmonotonic modifier is not specified, the
    // effect is as if the monotonic modifier is specified. Otherwise, unless
    // the monotonic modifier is specified, the effect is as if the
    // nonmonotonic modifier is specified.
    OMPScheduleType BaseScheduleType =
        ScheduleType & ~OMPScheduleType::ModifierMask;
    if ((BaseScheduleType == OMPScheduleType::BaseStatic) ||
        (BaseScheduleType == OMPScheduleType::BaseStaticChunked) ||
        HasOrderedClause) {
      // The monotonic modifier is the default in the OpenMP runtime library,
      // so there is no need to set it.
      return ScheduleType;
    } else {
      return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
    }
  }
}

/// Determine the schedule type using schedule and ordering clause arguments.
static OMPScheduleType
computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier, bool HasMonotonicModifier,
                          bool HasNonmonotonicModifier, bool HasOrderedClause) {
  OMPScheduleType BaseSchedule =
      getOpenMPBaseScheduleType(ClauseKind, HasChunks, HasSimdModifier);
  OMPScheduleType OrderedSchedule =
      getOpenMPOrderingScheduleType(BaseSchedule, HasOrderedClause);
  OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
      OrderedSchedule, HasSimdModifier, HasMonotonicModifier,
      HasNonmonotonicModifier, HasOrderedClause);

  assert(isValidWorkshareLoopScheduleType(Result));
  return Result;
}
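// Worked example (follows from the helpers above, assuming the enum values
// compose as their names suggest): a `schedule(dynamic, 4)` clause without
// modifiers and without an ordered clause yields the nonmonotonic, unordered
// dynamic schedule:
//
//   // BaseDynamicChunked -> | ModifierUnordered -> | ModifierNonmonotonic
//   OMPScheduleType Ty = computeOpenMPScheduleType(
//       OMP_SCHEDULE_Dynamic, /*HasChunks=*/true, /*HasSimdModifier=*/false,
//       /*HasMonotonicModifier=*/false, /*HasNonmonotonicModifier=*/false,
//       /*HasOrderedClause=*/false);
//   assert(Ty == (OMPScheduleType::UnorderedDynamicChunked |
//                 OMPScheduleType::ModifierNonmonotonic));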

/// Make \p Source branch to \p Target.
///
/// Handles two situations:
/// * \p Source already has an unconditional branch.
/// * \p Source is a degenerate block (no terminator because the BB is
///             the current head of the IR construction).
static void redirectTo(BasicBlock *Source, BasicBlock *Target, DebugLoc DL) {
  if (Instruction *Term = Source->getTerminator()) {
    auto *Br = cast<BranchInst>(Term);
    assert(!Br->isConditional() &&
           "BB's terminator must be an unconditional branch (or degenerate)");
    BasicBlock *Succ = Br->getSuccessor(0);
    Succ->removePredecessor(Source, /*KeepOneInputPHIs=*/true);
    Br->setSuccessor(0, Target);
    return;
  }

  auto *NewBr = BranchInst::Create(Target, Source);
  NewBr->setDebugLoc(DL);
}

void llvm::spliceBB(IRBuilderBase::InsertPoint IP, BasicBlock *New,
                    bool CreateBranch) {
  assert(New->getFirstInsertionPt() == New->begin() &&
         "Target BB must not have PHI nodes");

  // Move instructions to new block.
  BasicBlock *Old = IP.getBlock();
  New->splice(New->begin(), Old, IP.getPoint(), Old->end());

  if (CreateBranch)
    BranchInst::Create(New, Old);
}

void llvm::spliceBB(IRBuilder<> &Builder, BasicBlock *New, bool CreateBranch) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *Old = Builder.GetInsertBlock();

  spliceBB(Builder.saveIP(), New, CreateBranch);
  if (CreateBranch)
    Builder.SetInsertPoint(Old->getTerminator());
  else
    Builder.SetInsertPoint(Old);

  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
}

BasicBlock *llvm::splitBB(IRBuilderBase::InsertPoint IP, bool CreateBranch,
                          llvm::Twine Name) {
  BasicBlock *Old = IP.getBlock();
  BasicBlock *New = BasicBlock::Create(
      Old->getContext(), Name.isTriviallyEmpty() ? Old->getName() : Name,
      Old->getParent(), Old->getNextNode());
  spliceBB(IP, New, CreateBranch);
  New->replaceSuccessorsPhiUsesWith(Old, New);
  return New;
}

BasicBlock *llvm::splitBB(IRBuilderBase &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}

BasicBlock *llvm::splitBB(IRBuilder<> &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}

BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
                                    llvm::Twine Suffix) {
  BasicBlock *Old = Builder.GetInsertBlock();
  return splitBB(Builder, CreateBranch, Old->getName() + Suffix);
}
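// Usage sketch (illustrative): splitting at the current insert point moves
// the rest of the block into a fresh successor while the Builder keeps
// inserting into the old block, before the newly created branch:
//
//   BasicBlock *ContBB =
//       splitBBWithSuffix(Builder, /*CreateBranch=*/true, ".cont");
//   // The old block now ends in 'br label %<old>.cont'; the Builder still
//   // inserts before that branch, with its debug location preserved.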

void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
  LLVMContext &Ctx = Fn.getContext();
  Triple T(M.getTargetTriple());

  // Get the function's current attributes.
  auto Attrs = Fn.getAttributes();
  auto FnAttrs = Attrs.getFnAttrs();
  auto RetAttrs = Attrs.getRetAttrs();
  SmallVector<AttributeSet, 4> ArgAttrs;
  for (size_t ArgNo = 0; ArgNo < Fn.arg_size(); ++ArgNo)
    ArgAttrs.emplace_back(Attrs.getParamAttrs(ArgNo));

  // Add AS to FnAS while taking special care with integer extensions.
  auto addAttrSet = [&](AttributeSet &FnAS, const AttributeSet &AS,
                        bool Param = true) -> void {
    bool HasSignExt = AS.hasAttribute(Attribute::SExt);
    bool HasZeroExt = AS.hasAttribute(Attribute::ZExt);
    if (HasSignExt || HasZeroExt) {
      assert(AS.getNumAttributes() == 1 &&
             "Currently not handling extension attr combined with others.");
      if (Param) {
        if (auto AK = TargetLibraryInfo::getExtAttrForI32Param(T, HasSignExt))
          FnAS = FnAS.addAttribute(Ctx, AK);
      } else if (auto AK =
                     TargetLibraryInfo::getExtAttrForI32Return(T, HasSignExt))
        FnAS = FnAS.addAttribute(Ctx, AK);
    } else {
      FnAS = FnAS.addAttributes(Ctx, AS);
    }
  };

#define OMP_ATTRS_SET(VarName, AttrSet) AttributeSet VarName = AttrSet;
#include "llvm/Frontend/OpenMP/OMPKinds.def"

  // Add attributes to the function declaration.
  switch (FnID) {
#define OMP_RTL_ATTRS(Enum, FnAttrSet, RetAttrSet, ArgAttrSets)                \
  case Enum:                                                                   \
    FnAttrs = FnAttrs.addAttributes(Ctx, FnAttrSet);                           \
    addAttrSet(RetAttrs, RetAttrSet, /*Param*/false);                          \
    for (size_t ArgNo = 0; ArgNo < ArgAttrSets.size(); ++ArgNo)                \
      addAttrSet(ArgAttrs[ArgNo], ArgAttrSets[ArgNo]);                         \
    Fn.setAttributes(AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs));    \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    // Attributes are optional.
    break;
  }
}

FunctionCallee
OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
  FunctionType *FnTy = nullptr;
  Function *Fn = nullptr;

  // Try to find the declaration in the module first.
  switch (FnID) {
#define OMP_RTL(Enum, Str, IsVarArg, ReturnType, ...)                          \
  case Enum:                                                                   \
    FnTy = FunctionType::get(ReturnType, ArrayRef<Type *>{__VA_ARGS__},        \
                             IsVarArg);                                        \
    Fn = M.getFunction(Str);                                                   \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  if (!Fn) {
    // Create a new declaration if we need one.
    switch (FnID) {
#define OMP_RTL(Enum, Str, ...)                                                \
  case Enum:                                                                   \
    Fn = Function::Create(FnTy, GlobalValue::ExternalLinkage, Str, M);         \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
    }

    // Add information if the runtime function takes a callback function.
    if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
      if (!Fn->hasMetadata(LLVMContext::MD_callback)) {
        LLVMContext &Ctx = Fn->getContext();
        MDBuilder MDB(Ctx);
        // Annotate the callback behavior of the runtime function:
        //  - The callback callee is argument number 2 (microtask).
        //  - The first two arguments of the callback callee are unknown (-1).
        //  - All variadic arguments to the runtime function are passed to the
        //    callback callee.
        Fn->addMetadata(
            LLVMContext::MD_callback,
            *MDNode::get(Ctx, {MDB.createCallbackEncoding(
                                  2, {-1, -1}, /* VarArgsArePassed */ true)}));
      }
    }
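    // For illustration, the printed IR for __kmpc_fork_call then looks
    // roughly like (exact parameter types depend on the module):
    //
    //   declare !callback !0 void @__kmpc_fork_call(ptr, i32, ptr, ...)
    //   !0 = !{!1}
    //   !1 = !{i64 2, i64 -1, i64 -1, i1 true}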

    LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
    addAttributes(FnID, *Fn);

  } else {
    LLVM_DEBUG(dbgs() << "Found OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
  }

  assert(Fn && "Failed to create OpenMP runtime function");

  // Cast the function to the expected type if necessary
  Constant *C = ConstantExpr::getBitCast(Fn, FnTy->getPointerTo());
  return {FnTy, C};
}

Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
  FunctionCallee RTLFn = getOrCreateRuntimeFunction(M, FnID);
  auto *Fn = dyn_cast<llvm::Function>(RTLFn.getCallee());
  assert(Fn && "Failed to create OpenMP runtime function pointer");
  return Fn;
}

void OpenMPIRBuilder::initialize() { initializeTypes(M); }

void OpenMPIRBuilder::finalize(Function *Fn) {
  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  SmallVector<OutlineInfo, 16> DeferredOutlines;
  for (OutlineInfo &OI : OutlineInfos) {
    // Skip functions that have not been finalized yet; this may happen with
    // nested function generation.
    if (Fn && OI.getFunction() != Fn) {
      DeferredOutlines.push_back(OI);
      continue;
    }

    ParallelRegionBlockSet.clear();
    Blocks.clear();
    OI.collectBlocks(ParallelRegionBlockSet, Blocks);

    Function *OuterFn = OI.getFunction();
    CodeExtractorAnalysisCache CEAC(*OuterFn);
    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                            /* AggregateArgs */ true,
                            /* BlockFrequencyInfo */ nullptr,
                            /* BranchProbabilityInfo */ nullptr,
                            /* AssumptionCache */ nullptr,
                            /* AllowVarArgs */ true,
                            /* AllowAlloca */ true,
                            /* AllocaBlock*/ OI.OuterAllocaBB,
                            /* Suffix */ ".omp_par");

    LLVM_DEBUG(dbgs() << "Before     outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
                      << " Exit: " << OI.ExitBB->getName() << "\n");
    assert(Extractor.isEligible() &&
           "Expected OpenMP outlining to be possible!");

    for (auto *V : OI.ExcludeArgsFromAggregate)
      Extractor.excludeArgFromAggregate(V);

    Function *OutlinedFn = Extractor.extractCodeRegion(CEAC);

    LLVM_DEBUG(dbgs() << "After      outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "   Outlined function: " << *OutlinedFn << "\n");
    assert(OutlinedFn->getReturnType()->isVoidTy() &&
           "OpenMP outlined functions should not return a value!");

    // For compatibility with the clang CG we move the outlined function after
    // the one with the parallel region.
    OutlinedFn->removeFromParent();
    M.getFunctionList().insertAfter(OuterFn->getIterator(), OutlinedFn);

    // Remove the artificial entry introduced by the extractor right away; we
    // made our own entry block after all.
    {
      BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
      assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
      assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
      // Move instructions from the to-be-deleted ArtificialEntry to the entry
      // basic block of the parallel region. CodeExtractor generates
      // instructions to unwrap the aggregate argument and may sink
      // allocas/bitcasts for values that are solely used in the outlined
      // region and do not escape.
      assert(!ArtificialEntry.empty() &&
             "Expected instructions to add in the outlined region entry");
      for (BasicBlock::reverse_iterator It = ArtificialEntry.rbegin(),
                                        End = ArtificialEntry.rend();
           It != End;) {
        Instruction &I = *It;
        It++;

        if (I.isTerminator())
          continue;

        I.moveBefore(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
      }

      OI.EntryBB->moveBefore(&ArtificialEntry);
      ArtificialEntry.eraseFromParent();
    }
    assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
    assert(OutlinedFn && OutlinedFn->getNumUses() == 1);

    // Run a user callback, e.g. to add attributes.
    if (OI.PostOutlineCB)
      OI.PostOutlineCB(*OutlinedFn);
  }

  // Remove work items that have been completed.
  OutlineInfos = std::move(DeferredOutlines);
}

OpenMPIRBuilder::~OpenMPIRBuilder() {
  assert(OutlineInfos.empty() && "There must be no outstanding outlinings");
}

GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
  IntegerType *I32Ty = Type::getInt32Ty(M.getContext());
  auto *GV =
      new GlobalVariable(M, I32Ty,
                         /* isConstant = */ true, GlobalValue::WeakODRLinkage,
                         ConstantInt::get(I32Ty, Value), Name);
  GV->setVisibility(GlobalValue::HiddenVisibility);

  return GV;
}

Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
                                            uint32_t SrcLocStrSize,
                                            IdentFlag LocFlags,
                                            unsigned Reserve2Flags) {
  // Enable "C-mode".
  LocFlags |= OMP_IDENT_FLAG_KMPC;

  Constant *&Ident =
      IdentMap[{SrcLocStr, uint64_t(LocFlags) << 31 | Reserve2Flags}];
  if (!Ident) {
    Constant *I32Null = ConstantInt::getNullValue(Int32);
    Constant *IdentData[] = {I32Null,
                             ConstantInt::get(Int32, uint32_t(LocFlags)),
                             ConstantInt::get(Int32, Reserve2Flags),
                             ConstantInt::get(Int32, SrcLocStrSize), SrcLocStr};
    Constant *Initializer =
        ConstantStruct::get(OpenMPIRBuilder::Ident, IdentData);

    // Look for existing encoding of the location + flags, not needed but
    // minimizes the difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.getGlobalList())
      if (GV.getValueType() == OpenMPIRBuilder::Ident && GV.hasInitializer())
        if (GV.getInitializer() == Initializer)
          Ident = &GV;

    if (!Ident) {
      auto *GV = new GlobalVariable(
          M, OpenMPIRBuilder::Ident,
          /* isConstant = */ true, GlobalValue::PrivateLinkage, Initializer, "",
          nullptr, GlobalValue::NotThreadLocal,
          M.getDataLayout().getDefaultGlobalsAddressSpace());
      GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
      GV->setAlignment(Align(8));
      Ident = GV;
    }
  }

  return ConstantExpr::getPointerBitCastOrAddrSpaceCast(Ident, IdentPtr);
}
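// For illustration (symbol names are examples; the flag value assumes
// OMP_IDENT_FLAG_KMPC == 2): with only the "C-mode" flag set, the emitted
// location global looks roughly like:
//
//   @0 = private unnamed_addr constant %struct.ident_t
//            { i32 0, i32 2, i32 0, i32 <SrcLocStrSize>, ptr @<srclocstr> },
//            align 8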

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef LocStr,
                                                uint32_t &SrcLocStrSize) {
  SrcLocStrSize = LocStr.size();
  Constant *&SrcLocStr = SrcLocStrMap[LocStr];
  if (!SrcLocStr) {
    Constant *Initializer =
        ConstantDataArray::getString(M.getContext(), LocStr);

    // Look for existing encoding of the location, not needed but minimizes the
    // difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.getGlobalList())
      if (GV.isConstant() && GV.hasInitializer() &&
          GV.getInitializer() == Initializer)
        return SrcLocStr = ConstantExpr::getPointerCast(&GV, Int8Ptr);

    SrcLocStr = Builder.CreateGlobalStringPtr(LocStr, /* Name */ "",
                                              /* AddressSpace */ 0, &M);
  }
  return SrcLocStr;
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef FunctionName,
                                                StringRef FileName,
                                                unsigned Line, unsigned Column,
                                                uint32_t &SrcLocStrSize) {
  SmallString<128> Buffer;
  Buffer.push_back(';');
  Buffer.append(FileName);
  Buffer.push_back(';');
  Buffer.append(FunctionName);
  Buffer.push_back(';');
  Buffer.append(std::to_string(Line));
  Buffer.push_back(';');
  Buffer.append(std::to_string(Column));
  Buffer.push_back(';');
  Buffer.push_back(';');
  return getOrCreateSrcLocStr(Buffer.str(), SrcLocStrSize);
}
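// For example, the components encode as a single string of the form
// ";<file>;<function>;<line>;<column>;;" (OMPBuilder being some
// OpenMPIRBuilder instance, hypothetical here):
//
//   uint32_t Size;
//   Constant *Str = OMPBuilder.getOrCreateSrcLocStr(
//       "foo", "bar.c", /*Line=*/3, /*Column=*/7, Size);
//   // Str points to ";bar.c;foo;3;7;;" and Size is its length, 16.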

Constant *
OpenMPIRBuilder::getOrCreateDefaultSrcLocStr(uint32_t &SrcLocStrSize) {
  StringRef UnknownLoc = ";unknown;unknown;0;0;;";
  return getOrCreateSrcLocStr(UnknownLoc, SrcLocStrSize);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(DebugLoc DL,
                                                uint32_t &SrcLocStrSize,
                                                Function *F) {
  DILocation *DIL = DL.get();
  if (!DIL)
    return getOrCreateDefaultSrcLocStr(SrcLocStrSize);
  StringRef FileName = M.getName();
  if (DIFile *DIF = DIL->getFile())
    if (std::optional<StringRef> Source = DIF->getSource())
      FileName = *Source;
  StringRef Function = DIL->getScope()->getSubprogram()->getName();
  if (Function.empty() && F)
    Function = F->getName();
  return getOrCreateSrcLocStr(Function, FileName, DIL->getLine(),
                              DIL->getColumn(), SrcLocStrSize);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(const LocationDescription &Loc,
                                                uint32_t &SrcLocStrSize) {
  return getOrCreateSrcLocStr(Loc.DL, SrcLocStrSize,
                              Loc.IP.getBlock()->getParent());
}

Value *OpenMPIRBuilder::getOrCreateThreadID(Value *Ident) {
  return Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num), Ident,
      "omp_global_thread_num");
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createBarrier(const LocationDescription &Loc, Directive DK,
                               bool ForceSimpleCall, bool CheckCancelFlag) {
  if (!updateToLocation(Loc))
    return Loc.IP;
  return emitBarrierImpl(Loc, DK, ForceSimpleCall, CheckCancelFlag);
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::emitBarrierImpl(const LocationDescription &Loc, Directive Kind,
                                 bool ForceSimpleCall, bool CheckCancelFlag) {
  // Build call __kmpc_cancel_barrier(loc, thread_id) or
  //            __kmpc_barrier(loc, thread_id);

  IdentFlag BarrierLocFlags;
  switch (Kind) {
  case OMPD_for:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_FOR;
    break;
  case OMPD_sections:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SECTIONS;
    break;
  case OMPD_single:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SINGLE;
    break;
  case OMPD_barrier:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_EXPL;
    break;
  default:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL;
    break;
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Args[] = {
      getOrCreateIdent(SrcLocStr, SrcLocStrSize, BarrierLocFlags),
      getOrCreateThreadID(getOrCreateIdent(SrcLocStr, SrcLocStrSize))};

  // If we are in a cancellable parallel region, barriers are cancellation
  // points.
  // TODO: Check why we would force simple calls or ignore the cancel flag.
  bool UseCancelBarrier =
      !ForceSimpleCall && isLastFinalizationInfoCancellable(OMPD_parallel);

  Value *Result =
      Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
                             UseCancelBarrier ? OMPRTL___kmpc_cancel_barrier
                                              : OMPRTL___kmpc_barrier),
                         Args);

  if (UseCancelBarrier && CheckCancelFlag)
    emitCancelationCheckImpl(Result, OMPD_parallel);

  return Builder.saveIP();
}
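// For illustration, in a cancellable parallel region this emits roughly:
//
//   %r = call i32 @__kmpc_cancel_barrier(ptr @loc, i32 %thread_id)
//   ; ... followed by a branch on %r to the cancellation block ...
//
// and otherwise a plain 'call void @__kmpc_barrier(ptr @loc, i32 %thread_id)'.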

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createCancel(const LocationDescription &Loc,
                              Value *IfCondition,
                              omp::Directive CanceledDirective) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // LLVM utilities expect blocks with terminators, so create a placeholder.
  auto *UI = Builder.CreateUnreachable();

  Instruction *ThenTI = UI, *ElseTI = nullptr;
  if (IfCondition)
    SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
  Builder.SetInsertPoint(ThenTI);

  Value *CancelKind = nullptr;
  switch (CanceledDirective) {
#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
  case DirectiveEnum:                                                          \
    CancelKind = Builder.getInt32(Value);                                      \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    llvm_unreachable("Unknown cancel kind!");
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
  Value *Result = Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancel), Args);
  auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) {
    if (CanceledDirective == OMPD_parallel) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                    omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
                    /* CheckCancelFlag */ false);
    }
  };

  // The actual cancel logic is shared with others, e.g., cancel_barriers.
  emitCancelationCheckImpl(Result, CanceledDirective, ExitCB);

  // Update the insertion point and remove the terminator we introduced.
  Builder.SetInsertPoint(UI->getParent());
  UI->eraseFromParent();

  return Builder.saveIP();
}

void OpenMPIRBuilder::emitOffloadingEntry(Constant *Addr, StringRef Name,
                                          uint64_t Size, int32_t Flags,
                                          StringRef SectionName) {
  Type *Int8PtrTy = Type::getInt8PtrTy(M.getContext());
  Type *Int32Ty = Type::getInt32Ty(M.getContext());
  Type *SizeTy = M.getDataLayout().getIntPtrType(M.getContext());

  Constant *AddrName = ConstantDataArray::getString(M.getContext(), Name);

  // Create the constant string used to look up the symbol in the device.
  auto *Str =
      new llvm::GlobalVariable(M, AddrName->getType(), /*isConstant=*/true,
                               llvm::GlobalValue::InternalLinkage, AddrName,
                               ".omp_offloading.entry_name");
  Str->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);

  // Construct the offloading entry.
  Constant *EntryData[] = {
      ConstantExpr::getPointerBitCastOrAddrSpaceCast(Addr, Int8PtrTy),
      ConstantExpr::getPointerBitCastOrAddrSpaceCast(Str, Int8PtrTy),
      ConstantInt::get(SizeTy, Size),
      ConstantInt::get(Int32Ty, Flags),
      ConstantInt::get(Int32Ty, 0),
  };
  Constant *EntryInitializer =
      ConstantStruct::get(OpenMPIRBuilder::OffloadEntry, EntryData);

  auto *Entry = new GlobalVariable(
      M, OpenMPIRBuilder::OffloadEntry,
      /* isConstant = */ true, GlobalValue::WeakAnyLinkage, EntryInitializer,
      ".omp_offloading.entry." + Name, nullptr, GlobalValue::NotThreadLocal,
      M.getDataLayout().getDefaultGlobalsAddressSpace());

  // The entry has to be created in the section the linker expects it to be.
  Entry->setSection(SectionName);
  Entry->setAlignment(Align(1));
}
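// For illustration (the struct and address names are assumptions, not emitted
// verbatim), an entry named "foo" yields globals roughly like:
//
//   @.omp_offloading.entry_name = internal unnamed_addr constant
//       [4 x i8] c"foo\00"
//   @.omp_offloading.entry.foo = weak constant %struct.__tgt_offload_entry
//       { ptr @<addr>, ptr @.omp_offloading.entry_name, i64 <size>,
//         i32 <flags>, i32 0 }, section "<SectionName>", align 1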

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
    const LocationDescription &Loc, Value *&Return, Value *Ident,
    Value *DeviceID, Value *NumTeams, Value *NumThreads, Value *HostPtr,
    ArrayRef<Value *> KernelArgs) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  auto *KernelArgsPtr =
      Builder.CreateAlloca(OpenMPIRBuilder::KernelArgs, nullptr, "kernel_args");
  for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) {
    llvm::Value *Arg =
        Builder.CreateStructGEP(OpenMPIRBuilder::KernelArgs, KernelArgsPtr, I);
    Builder.CreateAlignedStore(
        KernelArgs[I], Arg,
        M.getDataLayout().getPrefTypeAlign(KernelArgs[I]->getType()));
  }

  SmallVector<Value *> OffloadingArgs{Ident,      DeviceID, NumTeams,
                                      NumThreads, HostPtr,  KernelArgsPtr};

  Return = Builder.CreateCall(
      getOrCreateRuntimeFunction(M, OMPRTL___tgt_target_kernel),
      OffloadingArgs);

  return Builder.saveIP();
}

void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
                                               omp::Directive CanceledDirective,
                                               FinalizeCallbackTy ExitCB) {
  assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
         "Unexpected cancellation!");

  // For a cancel barrier we create two new blocks.
  BasicBlock *BB = Builder.GetInsertBlock();
  BasicBlock *NonCancellationBlock;
  if (Builder.GetInsertPoint() == BB->end()) {
    // TODO: This branch will not be needed once we have moved to the
    // OpenMPIRBuilder codegen completely.
    NonCancellationBlock = BasicBlock::Create(
        BB->getContext(), BB->getName() + ".cont", BB->getParent());
  } else {
    NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint());
    BB->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(BB);
  }
  BasicBlock *CancellationBlock = BasicBlock::Create(
      BB->getContext(), BB->getName() + ".cncl", BB->getParent());

  // Jump to them based on the return value.
  Value *Cmp = Builder.CreateIsNull(CancelFlag);
  Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock,
                       /* TODO weight */ nullptr, nullptr);

  // From the cancellation block we finalize all variables and go to the
  // post finalization block that is known to the FiniCB callback.
  Builder.SetInsertPoint(CancellationBlock);
  if (ExitCB)
    ExitCB(Builder.saveIP());
  auto &FI = FinalizationStack.back();
  FI.FiniCB(Builder.saveIP());

  // The continuation block is where code generation continues.
  Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
}

IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
    const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
    BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
    FinalizeCallbackTy FiniCB, Value *IfCondition, Value *NumThreads,
    omp::ProcBindKind ProcBind, bool IsCancellable) {
  assert(!isConflictIP(Loc.IP, OuterAllocaIP) && "IPs must not be ambiguous");

  if (!updateToLocation(Loc))
    return Loc.IP;

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadID = getOrCreateThreadID(Ident);

  if (NumThreads) {
    // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
    Value *Args[] = {
        Ident, ThreadID,
        Builder.CreateIntCast(NumThreads, Int32, /*isSigned*/ false)};
    Builder.CreateCall(
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_threads), Args);
  }

  if (ProcBind != OMP_PROC_BIND_default) {
    // Build call __kmpc_push_proc_bind(&Ident, global_tid, proc_bind)
    Value *Args[] = {
        Ident, ThreadID,
        ConstantInt::get(Int32, unsigned(ProcBind), /*isSigned=*/true)};
    Builder.CreateCall(
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_proc_bind), Args);
  }

  BasicBlock *InsertBB = Builder.GetInsertBlock();
  Function *OuterFn = InsertBB->getParent();

  // Save the outer alloca block because the insertion iterator may get
  // invalidated and we still need this later.
  BasicBlock *OuterAllocaBlock = OuterAllocaIP.getBlock();

  // Vector to remember instructions we used only during the modeling but which
  // we want to delete at the end.
  SmallVector<Instruction *, 4> ToBeDeleted;

  // Change the location to the outer alloca insertion point to create and
  // initialize the allocas we pass into the parallel region.
  Builder.restoreIP(OuterAllocaIP);
  AllocaInst *TIDAddr = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
  AllocaInst *ZeroAddr = Builder.CreateAlloca(Int32, nullptr, "zero.addr");

  // We only need TIDAddr and ZeroAddr for modeling purposes to get the
  // associated arguments in the outlined function, so we delete them later.
  ToBeDeleted.push_back(TIDAddr);
  ToBeDeleted.push_back(ZeroAddr);

  // Create an artificial insertion point that will also ensure the blocks we
  // are about to split are not degenerated.
  auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);

  BasicBlock *EntryBB = UI->getParent();
  BasicBlock *PRegEntryBB = EntryBB->splitBasicBlock(UI, "omp.par.entry");
  BasicBlock *PRegBodyBB = PRegEntryBB->splitBasicBlock(UI, "omp.par.region");
  BasicBlock *PRegPreFiniBB =
      PRegBodyBB->splitBasicBlock(UI, "omp.par.pre_finalize");
  BasicBlock *PRegExitBB = PRegPreFiniBB->splitBasicBlock(UI, "omp.par.exit");

  auto FiniCBWrapper = [&](InsertPointTy IP) {
    // Hide "open-ended" blocks from the given FiniCB by setting the right jump
    // target to the region exit block.
    if (IP.getBlock()->end() == IP.getPoint()) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      Instruction *I = Builder.CreateBr(PRegExitBB);
      IP = InsertPointTy(I->getParent(), I->getIterator());
    }
    assert(IP.getBlock()->getTerminator()->getNumSuccessors() == 1 &&
           IP.getBlock()->getTerminator()->getSuccessor(0) == PRegExitBB &&
           "Unexpected insertion point for finalization call!");
    return FiniCB(IP);
  };

  FinalizationStack.push_back({FiniCBWrapper, OMPD_parallel, IsCancellable});

  // Generate the privatization allocas in the block that will become the entry
  // of the outlined function.
  Builder.SetInsertPoint(PRegEntryBB->getTerminator());
  InsertPointTy InnerAllocaIP = Builder.saveIP();

  AllocaInst *PrivTIDAddr =
      Builder.CreateAlloca(Int32, nullptr, "tid.addr.local");
  Instruction *PrivTID = Builder.CreateLoad(Int32, PrivTIDAddr, "tid");

  // Add some fake uses for OpenMP provided arguments.
  ToBeDeleted.push_back(Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use"));
  Instruction *ZeroAddrUse =
      Builder.CreateLoad(Int32, ZeroAddr, "zero.addr.use");
  ToBeDeleted.push_back(ZeroAddrUse);

  // EntryBB
  //   |
  //   V
  // PRegionEntryBB         <- Privatization allocas are placed here.
  //   |
  //   V
  // PRegionBodyBB          <- BodyGen is invoked here.
  //   |
  //   V
  // PRegPreFiniBB          <- The block we will start finalization from.
  //   |
  //   V
  // PRegionExitBB          <- A common exit to simplify block collection.
  //

  LLVM_DEBUG(dbgs() << "Before body codegen: " << *OuterFn << "\n");

  // Let the caller create the body.
  assert(BodyGenCB && "Expected body generation callback!");
  InsertPointTy CodeGenIP(PRegBodyBB, PRegBodyBB->begin());
  BodyGenCB(InnerAllocaIP, CodeGenIP);

  LLVM_DEBUG(dbgs() << "After  body codegen: " << *OuterFn << "\n");
  FunctionCallee RTLFn;
  if (IfCondition)
    RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call_if);
  else
    RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);

  if (auto *F = dyn_cast<llvm::Function>(RTLFn.getCallee())) {
    if (!F->hasMetadata(llvm::LLVMContext::MD_callback)) {
      llvm::LLVMContext &Ctx = F->getContext();
      MDBuilder MDB(Ctx);
      // Annotate the callback behavior of the __kmpc_fork_call:
      //  - The callback callee is argument number 2 (microtask).
      //  - The first two arguments of the callback callee are unknown (-1).
      //  - All variadic arguments to the __kmpc_fork_call are passed to the
      //    callback callee.
      F->addMetadata(
          llvm::LLVMContext::MD_callback,
          *llvm::MDNode::get(
              Ctx, {MDB.createCallbackEncoding(2, {-1, -1},
                                               /* VarArgsArePassed */ true)}));
    }
  }

  OutlineInfo OI;
  OI.PostOutlineCB = [=](Function &OutlinedFn) {
    // Add some known attributes.
    OutlinedFn.addParamAttr(0, Attribute::NoAlias);
    OutlinedFn.addParamAttr(1, Attribute::NoAlias);
    OutlinedFn.addFnAttr(Attribute::NoUnwind);
    OutlinedFn.addFnAttr(Attribute::NoRecurse);

    assert(OutlinedFn.arg_size() >= 2 &&
           "Expected at least tid and bounded tid as arguments");
    unsigned NumCapturedVars =
        OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

    CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
    CI->getParent()->setName("omp_parallel");
    Builder.SetInsertPoint(CI);

    // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
    Value *ForkCallArgs[] = {
        Ident, Builder.getInt32(NumCapturedVars),
        Builder.CreateBitCast(&OutlinedFn, ParallelTaskPtr)};

    SmallVector<Value *, 16> RealArgs;
    RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
    if (IfCondition) {
      Value *Cond = Builder.CreateSExtOrTrunc(IfCondition,
                                              Type::getInt32Ty(M.getContext()));
      RealArgs.push_back(Cond);
    }
    RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());

    // __kmpc_fork_call_if always expects a void ptr as the last argument.
    // If there are no arguments, pass a null pointer.
    auto PtrTy = Type::getInt8PtrTy(M.getContext());
    if (IfCondition && NumCapturedVars == 0) {
      llvm::Value *Void = ConstantPointerNull::get(PtrTy);
      RealArgs.push_back(Void);
    }
    if (IfCondition && RealArgs.back()->getType() != PtrTy)
      RealArgs.back() = Builder.CreateBitCast(RealArgs.back(), PtrTy);

    Builder.CreateCall(RTLFn, RealArgs);

    LLVM_DEBUG(dbgs() << "With fork_call placed: "
                      << *Builder.GetInsertBlock()->getParent() << "\n");

    InsertPointTy ExitIP(PRegExitBB, PRegExitBB->end());

    // Initialize the local TID stack location with the argument value.
    Builder.SetInsertPoint(PrivTID);
    Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
    Builder.CreateStore(Builder.CreateLoad(Int32, OutlinedAI), PrivTIDAddr);

    CI->eraseFromParent();

    for (Instruction *I : ToBeDeleted)
      I->eraseFromParent();
  };

  // Adjust the finalization stack, verify the adjustment, and call the
  // finalize function a last time to finalize values between the pre-fini
  // block and the exit block if we left the parallel region "the normal way".
  auto FiniInfo = FinalizationStack.pop_back_val();
  (void)FiniInfo;
  assert(FiniInfo.DK == OMPD_parallel &&
         "Unexpected finalization stack state!");

  Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();

  InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
  FiniCB(PreFiniIP);

  OI.OuterAllocaBB = OuterAllocaBlock;
  OI.EntryBB = PRegEntryBB;
  OI.ExitBB = PRegExitBB;

  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  OI.collectBlocks(ParallelRegionBlockSet, Blocks);

  // Ensure a single exit node for the outlined region by creating one.
  // We might have multiple incoming edges to the exit now due to finalizations,
  // e.g., cancel calls that cause the control flow to leave the region.
  BasicBlock *PRegOutlinedExitBB = PRegExitBB;
  PRegExitBB = SplitBlock(PRegExitBB, &*PRegExitBB->getFirstInsertionPt());
  PRegOutlinedExitBB->setName("omp.par.outlined.exit");
  Blocks.push_back(PRegOutlinedExitBB);

  CodeExtractorAnalysisCache CEAC(*OuterFn);
  CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                          /* AggregateArgs */ false,
                          /* BlockFrequencyInfo */ nullptr,
                          /* BranchProbabilityInfo */ nullptr,
                          /* AssumptionCache */ nullptr,
                          /* AllowVarArgs */ true,
                          /* AllowAlloca */ true,
                          /* AllocationBlock */ OuterAllocaBlock,
                          /* Suffix */ ".omp_par");

  // Find the inputs to and the outputs from the code region.
  BasicBlock *CommonExit = nullptr;
  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
  Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
  Extractor.findInputsOutputs(Inputs, Outputs, SinkingCands);

  LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");

  FunctionCallee TIDRTLFn =
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);

  auto PrivHelper = [&](Value &V) {
    if (&V == TIDAddr || &V == ZeroAddr) {
      OI.ExcludeArgsFromAggregate.push_back(&V);
      return;
    }

    SetVector<Use *> Uses;
    for (Use &U : V.uses())
      if (auto *UserI = dyn_cast<Instruction>(U.getUser()))
        if (ParallelRegionBlockSet.count(UserI->getParent()))
          Uses.insert(&U);

    // __kmpc_fork_call expects extra arguments as pointers. If the input
    // already has a pointer type, everything is fine. Otherwise, store the
    // value onto stack and load it back inside the to-be-outlined region. This
    // will ensure only the pointer will be passed to the function.
    // FIXME: if there are more than 15 trailing arguments, they must be
    // additionally packed in a struct.
    Value *Inner = &V;
    if (!V.getType()->isPointerTy()) {
      IRBuilder<>::InsertPointGuard Guard(Builder);
      LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");

      Builder.restoreIP(OuterAllocaIP);
      Value *Ptr =
          Builder.CreateAlloca(V.getType(), nullptr, V.getName() + ".reloaded");

      // Store to stack at end of the block that currently branches to the
      // entry block of the to-be-outlined region.
      Builder.SetInsertPoint(InsertBB,
                             InsertBB->getTerminator()->getIterator());
      Builder.CreateStore(&V, Ptr);

      // Load back next to allocations in the to-be-outlined region.
      Builder.restoreIP(InnerAllocaIP);
      Inner = Builder.CreateLoad(V.getType(), Ptr);
    }

    Value *ReplacementValue = nullptr;
    CallInst *CI = dyn_cast<CallInst>(&V);
    if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
      ReplacementValue = PrivTID;
    } else {
      Builder.restoreIP(
          PrivCB(InnerAllocaIP, Builder.saveIP(), V, *Inner, ReplacementValue));
      assert(ReplacementValue &&
             "Expected copy/create callback to set replacement value!");
      if (ReplacementValue == &V)
        return;
    }

    for (Use *UPtr : Uses)
      UPtr->set(ReplacementValue);
  };

  // Reset the inner alloca insertion as it will be used for loading the values
  // wrapped into pointers before passing them into the to-be-outlined region.
  // Configure it to insert immediately after the fake use of the zero address
  // so that the loaded values are available in the generated body and the
  // OpenMP-related values (thread ID and zero address pointers) remain leading
  // in the argument list.
  InnerAllocaIP = IRBuilder<>::InsertPoint(
      ZeroAddrUse->getParent(), ZeroAddrUse->getNextNode()->getIterator());

  // Reset the outer alloca insertion point to the entry of the relevant block
  // in case it was invalidated.
  OuterAllocaIP = IRBuilder<>::InsertPoint(
      OuterAllocaBlock, OuterAllocaBlock->getFirstInsertionPt());

  for (Value *Input : Inputs) {
    LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
    PrivHelper(*Input);
  }
  LLVM_DEBUG({
    for (Value *Output : Outputs)
      LLVM_DEBUG(dbgs() << "Captured output: " << *Output << "\n");
  });
  assert(Outputs.empty() &&
         "OpenMP outlining should not produce live-out values!");

  LLVM_DEBUG(dbgs() << "After  privatization: " << *OuterFn << "\n");
  LLVM_DEBUG({
    for (auto *BB : Blocks)
      dbgs() << " PBR: " << BB->getName() << "\n";
  });

  // Register the outlined info.
  addOutlineInfo(std::move(OI));

  InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
  UI->eraseFromParent();

  return AfterIP;
}
1229 
1230 void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
1231   // Build call void __kmpc_flush(ident_t *loc)
1232   uint32_t SrcLocStrSize;
1233   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1234   Value *Args[] = {getOrCreateIdent(SrcLocStr, SrcLocStrSize)};
1235 
1236   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_flush), Args);
1237 }
1238 
1239 void OpenMPIRBuilder::createFlush(const LocationDescription &Loc) {
1240   if (!updateToLocation(Loc))
1241     return;
1242   emitFlush(Loc);
1243 }
1244 
1245 void OpenMPIRBuilder::emitTaskwaitImpl(const LocationDescription &Loc) {
1246   // Build call kmp_int32 __kmpc_omp_taskwait(ident_t *loc, kmp_int32
1247   // global_tid);
1248   uint32_t SrcLocStrSize;
1249   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1250   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1251   Value *Args[] = {Ident, getOrCreateThreadID(Ident)};
1252 
1253   // Ignore return result until untied tasks are supported.
1254   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskwait),
1255                      Args);
1256 }
1257 
1258 void OpenMPIRBuilder::createTaskwait(const LocationDescription &Loc) {
1259   if (!updateToLocation(Loc))
1260     return;
1261   emitTaskwaitImpl(Loc);
1262 }
1263 
1264 void OpenMPIRBuilder::emitTaskyieldImpl(const LocationDescription &Loc) {
1265   // Build call __kmpc_omp_taskyield(loc, thread_id, 0);
1266   uint32_t SrcLocStrSize;
1267   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1268   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1269   Constant *I32Null = ConstantInt::getNullValue(Int32);
1270   Value *Args[] = {Ident, getOrCreateThreadID(Ident), I32Null};
1271 
1272   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskyield),
1273                      Args);
1274 }
1275 
1276 void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
1277   if (!updateToLocation(Loc))
1278     return;
1279   emitTaskyieldImpl(Loc);
1280 }
1281 
1282 OpenMPIRBuilder::InsertPointTy
1283 OpenMPIRBuilder::createTask(const LocationDescription &Loc,
1284                             InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
1285                             bool Tied, Value *Final, Value *IfCondition,
1286                             SmallVector<DependData> Dependencies) {
1287   if (!updateToLocation(Loc))
1288     return InsertPointTy();
1289 
1290   uint32_t SrcLocStrSize;
1291   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1292   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1293   // The current basic block is split into four basic blocks. After outlining,
1294   // they will be mapped as follows:
1295   // ```
1296   // def current_fn() {
1297   //   current_basic_block:
1298   //     br label %task.exit
1299   //   task.exit:
1300   //     ; instructions after task
1301   // }
1302   // def outlined_fn() {
1303   //   task.alloca:
1304   //     br label %task.body
1305   //   task.body:
1306   //     ret void
1307   // }
1308   // ```
1309   BasicBlock *TaskExitBB = splitBB(Builder, /*CreateBranch=*/true, "task.exit");
1310   BasicBlock *TaskBodyBB = splitBB(Builder, /*CreateBranch=*/true, "task.body");
1311   BasicBlock *TaskAllocaBB =
1312       splitBB(Builder, /*CreateBranch=*/true, "task.alloca");
1313 
1314   OutlineInfo OI;
1315   OI.EntryBB = TaskAllocaBB;
1316   OI.OuterAllocaBB = AllocaIP.getBlock();
1317   OI.ExitBB = TaskExitBB;
1318   OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition,
1319                       Dependencies](Function &OutlinedFn) {
1320     // The input IR here looks like the following:
1321     // ```
1322     // func @current_fn() {
1323     //   outlined_fn(%args)
1324     // }
1325     // func @outlined_fn(%args) { ... }
1326     // ```
1327     //
1328     // This is changed to the following:
1329     //
1330     // ```
1331     // func @current_fn() {
1332     //   runtime_call(..., wrapper_fn, ...)
1333     // }
1334     // func @wrapper_fn(..., %args) {
1335     //   outlined_fn(%args)
1336     // }
1337     // func @outlined_fn(%args) { ... }
1338     // ```
1339 
1340     // The stale call instruction will be replaced with a new call instruction
1341     // for runtime call with a wrapper function.
1342     assert(OutlinedFn.getNumUses() == 1 &&
1343            "there must be a single user for the outlined function");
1344     CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
1345 
1346     // HasTaskData is true if any variables are captured in the outlined region,
1347     // false otherwise.
1348     bool HasTaskData = StaleCI->arg_size() > 0;
1349     Builder.SetInsertPoint(StaleCI);
1350 
1351     // Gather the arguments for emitting the runtime call for
1352     // @__kmpc_omp_task_alloc
1353     Function *TaskAllocFn =
1354         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
1355 
1356     // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID) for the runtime
1357     // call.
1358     Value *ThreadID = getOrCreateThreadID(Ident);
1359 
1360     // Argument - `flags`
1361     // Task is tied iff (Flags & 1) == 1.
1362     // Task is untied iff (Flags & 1) == 0.
1363     // Task is final iff (Flags & 2) == 2.
1364     // Task is not final iff (Flags & 2) == 0.
1365     // TODO: Handle the other flags.
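    // For example, a tied task whose final clause evaluates to true ends up
    // with Flags == 3 (tied bit | final bit), while an untied, non-final task
    // gets Flags == 0.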
1366     Value *Flags = Builder.getInt32(Tied);
1367     if (Final) {
1368       Value *FinalFlag =
1369           Builder.CreateSelect(Final, Builder.getInt32(2), Builder.getInt32(0));
1370       Flags = Builder.CreateOr(FinalFlag, Flags);
1371     }
1372 
1373     // Argument - `sizeof_kmp_task_t` (TaskSize)
1374     // TaskSize refers to the size in bytes of the kmp_task_t data structure
1375     // including private vars accessed in the task.
1376     Value *TaskSize = Builder.getInt64(0);
1377     if (HasTaskData) {
1378       AllocaInst *ArgStructAlloca =
1379           dyn_cast<AllocaInst>(StaleCI->getArgOperand(0));
1380       assert(ArgStructAlloca &&
1381              "Unable to find the alloca instruction corresponding to arguments "
1382              "for extracted function");
1383       StructType *ArgStructType =
1384           dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
1385       assert(ArgStructType && "Unable to find struct type corresponding to "
1386                               "arguments for extracted function");
1387       TaskSize =
1388           Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
1389     }
1390 
1391     // TODO: Argument - sizeof_shareds
1392 
1393     // Argument - task_entry (the wrapper function)
1394     // If the outlined function has some captured variables (i.e. HasTaskData is
1395     // true), then the wrapper function will have an additional argument (the
1396     // struct containing captured variables). Otherwise, no such argument will
1397     // be present.
1398     SmallVector<Type *> WrapperArgTys{Builder.getInt32Ty()};
1399     if (HasTaskData)
1400       WrapperArgTys.push_back(OutlinedFn.getArg(0)->getType());
1401     FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
1402         (Twine(OutlinedFn.getName()) + ".wrapper").str(),
1403         FunctionType::get(Builder.getInt32Ty(), WrapperArgTys, false));
1404     Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
1405     PointerType *WrapperFuncBitcastType =
1406         FunctionType::get(Builder.getInt32Ty(),
1407                           {Builder.getInt32Ty(), Builder.getInt8PtrTy()}, false)
1408             ->getPointerTo();
1409     Value *WrapperFuncBitcast =
1410         ConstantExpr::getBitCast(WrapperFunc, WrapperFuncBitcastType);
1411 
1412     // Emit the @__kmpc_omp_task_alloc runtime call
1413     // The runtime call returns a pointer to an area where the task captured
1414     // variables must be copied before the task is run (NewTaskData)
1415     CallInst *NewTaskData = Builder.CreateCall(
1416         TaskAllocFn,
1417         {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
1418          /*sizeof_task=*/TaskSize, /*sizeof_shared=*/Builder.getInt64(0),
1419          /*task_func=*/WrapperFuncBitcast});
1420 
1421     // Copy the arguments for outlined function
1422     if (HasTaskData) {
1423       Value *TaskData = StaleCI->getArgOperand(0);
1424       Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
1425       Builder.CreateMemCpy(NewTaskData, Alignment, TaskData, Alignment,
1426                            TaskSize);
1427     }
1428 
1429     Value *DepArrayPtr = nullptr;
1430     if (Dependencies.size()) {
1431       InsertPointTy OldIP = Builder.saveIP();
1432       Builder.SetInsertPoint(
1433           &OldIP.getBlock()->getParent()->getEntryBlock().back());
1434 
1435       Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
1436       Value *DepArray =
1437           Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
1438 
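      // Each array element is populated like a kmp_depend_info record; as a
      // sketch, the layout used below is { i64 base_addr, i64 len, i8 flags }
      // with the field order given by RTLDependInfoFields.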
1439       unsigned P = 0;
1440       for (const DependData &Dep : Dependencies) {
1441         Value *Base =
1442             Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, P);
1443         // Store the pointer to the variable
1444         Value *Addr = Builder.CreateStructGEP(
1445             DependInfo, Base,
1446             static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
1447         Value *DepValPtr =
1448             Builder.CreatePtrToInt(Dep.DepVal, Builder.getInt64Ty());
1449         Builder.CreateStore(DepValPtr, Addr);
1450         // Store the size of the variable
1451         Value *Size = Builder.CreateStructGEP(
1452             DependInfo, Base,
1453             static_cast<unsigned int>(RTLDependInfoFields::Len));
1454         Builder.CreateStore(Builder.getInt64(M.getDataLayout().getTypeStoreSize(
1455                                 Dep.DepValueType)),
1456                             Size);
1457         // Store the dependency kind
1458         Value *Flags = Builder.CreateStructGEP(
1459             DependInfo, Base,
1460             static_cast<unsigned int>(RTLDependInfoFields::Flags));
1461         Builder.CreateStore(
1462             ConstantInt::get(Builder.getInt8Ty(),
1463                              static_cast<unsigned int>(Dep.DepKind)),
1464             Flags);
1465         ++P;
1466       }
1467 
1468       DepArrayPtr = Builder.CreateBitCast(DepArray, Builder.getInt8PtrTy());
1469       Builder.restoreIP(OldIP);
1470     }
1471 
1472     // In the presence of the `if` clause, the following IR is generated:
1473     //    ...
1474     //    %data = call @__kmpc_omp_task_alloc(...)
1475     //    br i1 %if_condition, label %then, label %else
1476     //  then:
1477     //    call @__kmpc_omp_task(...)
1478     //    br label %exit
1479     //  else:
1480     //    call @__kmpc_omp_task_begin_if0(...)
1481     //    call @wrapper_fn(...)
1482     //    call @__kmpc_omp_task_complete_if0(...)
1483     //    br label %exit
1484     //  exit:
1485     //    ...
1486     if (IfCondition) {
1487       // `SplitBlockAndInsertIfThenElse` requires the block to have a
1488       // terminator.
1489       BasicBlock *NewBasicBlock =
1490           splitBB(Builder, /*CreateBranch=*/true, "if.end");
1491       Instruction *IfTerminator =
1492           NewBasicBlock->getSinglePredecessor()->getTerminator();
1493       Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
1494       Builder.SetInsertPoint(IfTerminator);
1495       SplitBlockAndInsertIfThenElse(IfCondition, IfTerminator, &ThenTI,
1496                                     &ElseTI);
1497       Builder.SetInsertPoint(ElseTI);
1498       Function *TaskBeginFn =
1499           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
1500       Function *TaskCompleteFn =
1501           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
1502       Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, NewTaskData});
1503       if (HasTaskData)
1504         Builder.CreateCall(WrapperFunc, {ThreadID, NewTaskData});
1505       else
1506         Builder.CreateCall(WrapperFunc, {ThreadID});
1507       Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, NewTaskData});
1508       Builder.SetInsertPoint(ThenTI);
1509     }
1510 
1511     if (Dependencies.size()) {
1512       Function *TaskFn =
1513           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
1514       Builder.CreateCall(
1515           TaskFn,
1516           {Ident, ThreadID, NewTaskData, Builder.getInt32(Dependencies.size()),
1517            DepArrayPtr, ConstantInt::get(Builder.getInt32Ty(), 0),
1518            ConstantPointerNull::get(Type::getInt8PtrTy(M.getContext()))});
1519 
1520     } else {
1521       // Emit the @__kmpc_omp_task runtime call to spawn the task
1522       Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
1523       Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
1524     }
1525 
1526     StaleCI->eraseFromParent();
1527 
1528     // Emit the body for the wrapper function.
1529     BasicBlock *WrapperEntryBB =
1530         BasicBlock::Create(M.getContext(), "", WrapperFunc);
1531     Builder.SetInsertPoint(WrapperEntryBB);
1532     if (HasTaskData)
1533       Builder.CreateCall(&OutlinedFn, {WrapperFunc->getArg(1)});
1534     else
1535       Builder.CreateCall(&OutlinedFn);
1536     Builder.CreateRet(Builder.getInt32(0));
1537   };
1538 
1539   addOutlineInfo(std::move(OI));
1540 
1541   InsertPointTy TaskAllocaIP =
1542       InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
1543   InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
1544   BodyGenCB(TaskAllocaIP, TaskBodyIP);
1545   Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin());
1546 
1547   return Builder.saveIP();
1548 }
1549 
1550 OpenMPIRBuilder::InsertPointTy
1551 OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
1552                                  InsertPointTy AllocaIP,
1553                                  BodyGenCallbackTy BodyGenCB) {
1554   if (!updateToLocation(Loc))
1555     return InsertPointTy();
1556 
1557   uint32_t SrcLocStrSize;
1558   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1559   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1560   Value *ThreadID = getOrCreateThreadID(Ident);
1561 
1562   // Emit the @__kmpc_taskgroup runtime call to start the taskgroup
1563   Function *TaskgroupFn =
1564       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup);
1565   Builder.CreateCall(TaskgroupFn, {Ident, ThreadID});
1566 
1567   BasicBlock *TaskgroupExitBB = splitBB(Builder, true, "taskgroup.exit");
1568   BodyGenCB(AllocaIP, Builder.saveIP());
1569 
1570   Builder.SetInsertPoint(TaskgroupExitBB);
1571   // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
1572   Function *EndTaskgroupFn =
1573       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup);
1574   Builder.CreateCall(EndTaskgroupFn, {Ident, ThreadID});
1575 
1576   return Builder.saveIP();
1577 }
1578 
1579 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSections(
1580     const LocationDescription &Loc, InsertPointTy AllocaIP,
1581     ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
1582     FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) {
1583   assert(!isConflictIP(AllocaIP, Loc.IP) && "Dedicated IP allocas required");
1584 
1585   if (!updateToLocation(Loc))
1586     return Loc.IP;
1587 
1588   auto FiniCBWrapper = [&](InsertPointTy IP) {
1589     if (IP.getBlock()->end() != IP.getPoint())
1590       return FiniCB(IP);
1591     // This must be done, otherwise any nested constructs using
1592     // FinalizeOMPRegion will fail because that function requires the
1593     // finalization basic block to have a terminator, which EmitOMPRegionBody
1594     // has already removed. IP is currently at the cancellation block.
1595     // We need to backtrack to the condition block to fetch
1596     // the exit block and create a branch from the cancellation
1597     // block to the exit block.
1598     IRBuilder<>::InsertPointGuard IPG(Builder);
1599     Builder.restoreIP(IP);
1600     auto *CaseBB = IP.getBlock()->getSinglePredecessor();
1601     auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
1602     auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
1603     Instruction *I = Builder.CreateBr(ExitBB);
1604     IP = InsertPointTy(I->getParent(), I->getIterator());
1605     return FiniCB(IP);
1606   };
1607 
1608   FinalizationStack.push_back({FiniCBWrapper, OMPD_sections, IsCancellable});
1609 
1610   // Each section is emitted as a switch case
1611   // Each finalization callback is handled from clang.EmitOMPSectionDirective()
1612   // -> OMP.createSection() which generates the IR for each section
1613   // Iterate through all sections and emit a switch construct:
1614   // switch (IV) {
1615   //   case 0:
1616   //     <SectionStmt[0]>;
1617   //     break;
1618   // ...
1619   //   case <NumSection> - 1:
1620   //     <SectionStmt[<NumSection> - 1]>;
1621   //     break;
1622   // }
1623   // ...
1624   // section_loop.after:
1625   // <FiniCB>;
1626   auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, Value *IndVar) {
1627     Builder.restoreIP(CodeGenIP);
1628     BasicBlock *Continue =
1629         splitBBWithSuffix(Builder, /*CreateBranch=*/false, ".sections.after");
1630     Function *CurFn = Continue->getParent();
1631     SwitchInst *SwitchStmt = Builder.CreateSwitch(IndVar, Continue);
1632 
1633     unsigned CaseNumber = 0;
1634     for (auto SectionCB : SectionCBs) {
1635       BasicBlock *CaseBB = BasicBlock::Create(
1636           M.getContext(), "omp_section_loop.body.case", CurFn, Continue);
1637       SwitchStmt->addCase(Builder.getInt32(CaseNumber), CaseBB);
1638       Builder.SetInsertPoint(CaseBB);
1639       BranchInst *CaseEndBr = Builder.CreateBr(Continue);
1640       SectionCB(InsertPointTy(),
1641                 {CaseEndBr->getParent(), CaseEndBr->getIterator()});
1642       CaseNumber++;
1643     }
1644     // Remove the existing terminator from the body BB since there can be no
1645     // terminators after the switch/case.
1646   };
1647   // Loop body ends here
1648   // LowerBound, UpperBound, and Stride for createCanonicalLoop.
1649   Type *I32Ty = Type::getInt32Ty(M.getContext());
1650   Value *LB = ConstantInt::get(I32Ty, 0);
1651   Value *UB = ConstantInt::get(I32Ty, SectionCBs.size());
1652   Value *ST = ConstantInt::get(I32Ty, 1);
1653   llvm::CanonicalLoopInfo *LoopInfo = createCanonicalLoop(
1654       Loc, LoopBodyGenCB, LB, UB, ST, true, false, AllocaIP, "section_loop");
1655   InsertPointTy AfterIP =
1656       applyStaticWorkshareLoop(Loc.DL, LoopInfo, AllocaIP, !IsNowait);
1657 
1658   // Apply the finalization callback in LoopAfterBB
1659   auto FiniInfo = FinalizationStack.pop_back_val();
1660   assert(FiniInfo.DK == OMPD_sections &&
1661          "Unexpected finalization stack state!");
1662   if (FinalizeCallbackTy &CB = FiniInfo.FiniCB) {
1663     Builder.restoreIP(AfterIP);
1664     BasicBlock *FiniBB =
1665         splitBBWithSuffix(Builder, /*CreateBranch=*/true, "sections.fini");
1666     CB(Builder.saveIP());
1667     AfterIP = {FiniBB, FiniBB->begin()};
1668   }
1669 
1670   return AfterIP;
1671 }
1672 
1673 OpenMPIRBuilder::InsertPointTy
1674 OpenMPIRBuilder::createSection(const LocationDescription &Loc,
1675                                BodyGenCallbackTy BodyGenCB,
1676                                FinalizeCallbackTy FiniCB) {
1677   if (!updateToLocation(Loc))
1678     return Loc.IP;
1679 
1680   auto FiniCBWrapper = [&](InsertPointTy IP) {
1681     if (IP.getBlock()->end() != IP.getPoint())
1682       return FiniCB(IP);
1683     // This must be done, otherwise any nested constructs using
1684     // FinalizeOMPRegion will fail because that function requires the
1685     // finalization basic block to have a terminator, which EmitOMPRegionBody
1686     // has already removed. IP is currently at the cancellation block.
1687     // We need to backtrack to the condition block to fetch
1688     // the exit block and create a branch from the cancellation
1689     // block to the exit block.
1690     IRBuilder<>::InsertPointGuard IPG(Builder);
1691     Builder.restoreIP(IP);
1692     auto *CaseBB = Loc.IP.getBlock();
1693     auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
1694     auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
1695     Instruction *I = Builder.CreateBr(ExitBB);
1696     IP = InsertPointTy(I->getParent(), I->getIterator());
1697     return FiniCB(IP);
1698   };
1699 
1700   Directive OMPD = Directive::OMPD_sections;
1701   // Since we are using Finalization Callback here, HasFinalize
1702   // and IsCancellable have to be true
1703   return EmitOMPInlinedRegion(OMPD, nullptr, nullptr, BodyGenCB, FiniCBWrapper,
1704                               /*Conditional*/ false, /*hasFinalize*/ true,
1705                               /*IsCancellable*/ true);
1706 }
1707 
1708 /// Create a function with a unique name and a "void (i8*, i8*)" signature in
1709 /// the given module and return it.
1710 Function *getFreshReductionFunc(Module &M) {
1711   Type *VoidTy = Type::getVoidTy(M.getContext());
1712   Type *Int8PtrTy = Type::getInt8PtrTy(M.getContext());
1713   auto *FuncTy =
1714       FunctionType::get(VoidTy, {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ false);
1715   return Function::Create(FuncTy, GlobalVariable::InternalLinkage,
1716                           M.getDataLayout().getDefaultGlobalsAddressSpace(),
1717                           ".omp.reduction.func", &M);
1718 }
1719 
1720 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions(
1721     const LocationDescription &Loc, InsertPointTy AllocaIP,
1722     ArrayRef<ReductionInfo> ReductionInfos, bool IsNoWait) {
1723   for (const ReductionInfo &RI : ReductionInfos) {
1724     (void)RI;
1725     assert(RI.Variable && "expected non-null variable");
1726     assert(RI.PrivateVariable && "expected non-null private variable");
1727     assert(RI.ReductionGen && "expected non-null reduction generator callback");
1728     assert(RI.Variable->getType() == RI.PrivateVariable->getType() &&
1729            "expected variables and their private equivalents to have the same "
1730            "type");
1731     assert(RI.Variable->getType()->isPointerTy() &&
1732            "expected variables to be pointers");
1733   }
1734 
1735   if (!updateToLocation(Loc))
1736     return InsertPointTy();
1737 
1738   BasicBlock *InsertBlock = Loc.IP.getBlock();
1739   BasicBlock *ContinuationBlock =
1740       InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
1741   InsertBlock->getTerminator()->eraseFromParent();
1742 
1743   // Create and populate array of type-erased pointers to private reduction
1744   // values.
1745   unsigned NumReductions = ReductionInfos.size();
1746   Type *RedArrayTy = ArrayType::get(Builder.getInt8PtrTy(), NumReductions);
1747   Builder.restoreIP(AllocaIP);
1748   Value *RedArray = Builder.CreateAlloca(RedArrayTy, nullptr, "red.array");
1749 
1750   Builder.SetInsertPoint(InsertBlock, InsertBlock->end());
1751 
1752   for (auto En : enumerate(ReductionInfos)) {
1753     unsigned Index = En.index();
1754     const ReductionInfo &RI = En.value();
1755     Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64(
1756         RedArrayTy, RedArray, 0, Index, "red.array.elem." + Twine(Index));
1757     Value *Casted =
1758         Builder.CreateBitCast(RI.PrivateVariable, Builder.getInt8PtrTy(),
1759                               "private.red.var." + Twine(Index) + ".casted");
1760     Builder.CreateStore(Casted, RedArrayElemPtr);
1761   }
1762 
1763   // Emit a call to the runtime function that orchestrates the reduction.
1764   // Declare the reduction function in the process.
1765   Function *Func = Builder.GetInsertBlock()->getParent();
1766   Module *Module = Func->getParent();
1767   Value *RedArrayPtr =
1768       Builder.CreateBitCast(RedArray, Builder.getInt8PtrTy(), "red.array.ptr");
1769   uint32_t SrcLocStrSize;
1770   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1771   bool CanGenerateAtomic =
1772       llvm::all_of(ReductionInfos, [](const ReductionInfo &RI) {
1773         return RI.AtomicReductionGen;
1774       });
1775   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize,
1776                                   CanGenerateAtomic
1777                                       ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
1778                                       : IdentFlag(0));
1779   Value *ThreadId = getOrCreateThreadID(Ident);
1780   Constant *NumVariables = Builder.getInt32(NumReductions);
1781   const DataLayout &DL = Module->getDataLayout();
1782   unsigned RedArrayByteSize = DL.getTypeStoreSize(RedArrayTy);
1783   Constant *RedArraySize = Builder.getInt64(RedArrayByteSize);
1784   Function *ReductionFunc = getFreshReductionFunc(*Module);
1785   Value *Lock = getOMPCriticalRegionLock(".reduction");
1786   Function *ReduceFunc = getOrCreateRuntimeFunctionPtr(
1787       IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
1788                : RuntimeFunction::OMPRTL___kmpc_reduce);
1789   CallInst *ReduceCall =
1790       Builder.CreateCall(ReduceFunc,
1791                          {Ident, ThreadId, NumVariables, RedArraySize,
1792                           RedArrayPtr, ReductionFunc, Lock},
1793                          "reduce");
1794 
1795   // Create final reduction entry blocks for the atomic and non-atomic case.
1796   // Emit IR that dispatches control flow to one of the blocks based on
1797   // whether the reduction supports the atomic mode.
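  // As a rough guide to the runtime contract: __kmpc_reduce is expected to
  // return 1 when this thread should perform the non-atomic reduction, 2 when
  // it should take the atomic path, and 0 when there is nothing to do, which
  // is why the switch default falls through to the continuation block.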
1798   BasicBlock *NonAtomicRedBlock =
1799       BasicBlock::Create(Module->getContext(), "reduce.switch.nonatomic", Func);
1800   BasicBlock *AtomicRedBlock =
1801       BasicBlock::Create(Module->getContext(), "reduce.switch.atomic", Func);
1802   SwitchInst *Switch =
1803       Builder.CreateSwitch(ReduceCall, ContinuationBlock, /* NumCases */ 2);
1804   Switch->addCase(Builder.getInt32(1), NonAtomicRedBlock);
1805   Switch->addCase(Builder.getInt32(2), AtomicRedBlock);
1806 
1807   // Populate the non-atomic reduction using the elementwise reduction function.
1808   // This loads the elements from the global and private variables and reduces
1809   // them before storing back the result to the global variable.
1810   Builder.SetInsertPoint(NonAtomicRedBlock);
1811   for (auto En : enumerate(ReductionInfos)) {
1812     const ReductionInfo &RI = En.value();
1813     Type *ValueType = RI.ElementType;
1814     Value *RedValue = Builder.CreateLoad(ValueType, RI.Variable,
1815                                          "red.value." + Twine(En.index()));
1816     Value *PrivateRedValue =
1817         Builder.CreateLoad(ValueType, RI.PrivateVariable,
1818                            "red.private.value." + Twine(En.index()));
1819     Value *Reduced;
1820     Builder.restoreIP(
1821         RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced));
1822     if (!Builder.GetInsertBlock())
1823       return InsertPointTy();
1824     Builder.CreateStore(Reduced, RI.Variable);
1825   }
1826   Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr(
1827       IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
1828                : RuntimeFunction::OMPRTL___kmpc_end_reduce);
1829   Builder.CreateCall(EndReduceFunc, {Ident, ThreadId, Lock});
1830   Builder.CreateBr(ContinuationBlock);
1831 
1832   // Populate the atomic reduction using the atomic elementwise reduction
1833   // function. There are no loads/stores here because they happen inside the
1834   // atomic elementwise reduction itself.
1835   Builder.SetInsertPoint(AtomicRedBlock);
1836   if (CanGenerateAtomic) {
1837     for (const ReductionInfo &RI : ReductionInfos) {
1838       Builder.restoreIP(RI.AtomicReductionGen(Builder.saveIP(), RI.ElementType,
1839                                               RI.Variable, RI.PrivateVariable));
1840       if (!Builder.GetInsertBlock())
1841         return InsertPointTy();
1842     }
1843     Builder.CreateBr(ContinuationBlock);
1844   } else {
1845     Builder.CreateUnreachable();
1846   }
1847 
1848   // Populate the outlined reduction function using the elementwise reduction
1849   // function. Partial values are extracted from the type-erased array of
1850   // pointers to private variables.
1851   BasicBlock *ReductionFuncBlock =
1852       BasicBlock::Create(Module->getContext(), "", ReductionFunc);
1853   Builder.SetInsertPoint(ReductionFuncBlock);
1854   Value *LHSArrayPtr = Builder.CreateBitCast(ReductionFunc->getArg(0),
1855                                              RedArrayTy->getPointerTo());
1856   Value *RHSArrayPtr = Builder.CreateBitCast(ReductionFunc->getArg(1),
1857                                              RedArrayTy->getPointerTo());
1858   for (auto En : enumerate(ReductionInfos)) {
1859     const ReductionInfo &RI = En.value();
1860     Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
1861         RedArrayTy, LHSArrayPtr, 0, En.index());
1862     Value *LHSI8Ptr = Builder.CreateLoad(Builder.getInt8PtrTy(), LHSI8PtrPtr);
1863     Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType());
1864     Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
1865     Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
1866         RedArrayTy, RHSArrayPtr, 0, En.index());
1867     Value *RHSI8Ptr = Builder.CreateLoad(Builder.getInt8PtrTy(), RHSI8PtrPtr);
1868     Value *RHSPtr =
1869         Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType());
1870     Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
1871     Value *Reduced;
1872     Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced));
1873     if (!Builder.GetInsertBlock())
1874       return InsertPointTy();
1875     Builder.CreateStore(Reduced, LHSPtr);
1876   }
1877   Builder.CreateRetVoid();
1878 
1879   Builder.SetInsertPoint(ContinuationBlock);
1880   return Builder.saveIP();
1881 }
1882 
1883 OpenMPIRBuilder::InsertPointTy
1884 OpenMPIRBuilder::createMaster(const LocationDescription &Loc,
1885                               BodyGenCallbackTy BodyGenCB,
1886                               FinalizeCallbackTy FiniCB) {
1887 
1888   if (!updateToLocation(Loc))
1889     return Loc.IP;
1890 
1891   Directive OMPD = Directive::OMPD_master;
1892   uint32_t SrcLocStrSize;
1893   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1894   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1895   Value *ThreadId = getOrCreateThreadID(Ident);
1896   Value *Args[] = {Ident, ThreadId};
1897 
1898   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_master);
1899   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
1900 
1901   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_master);
1902   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
1903 
1904   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
1905                               /*Conditional*/ true, /*hasFinalize*/ true);
1906 }
1907 
1908 OpenMPIRBuilder::InsertPointTy
1909 OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
1910                               BodyGenCallbackTy BodyGenCB,
1911                               FinalizeCallbackTy FiniCB, Value *Filter) {
1912   if (!updateToLocation(Loc))
1913     return Loc.IP;
1914 
1915   Directive OMPD = Directive::OMPD_masked;
1916   uint32_t SrcLocStrSize;
1917   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1918   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1919   Value *ThreadId = getOrCreateThreadID(Ident);
1920   Value *Args[] = {Ident, ThreadId, Filter};
1921   Value *ArgsEnd[] = {Ident, ThreadId};
1922 
1923   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_masked);
1924   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
1925 
1926   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_masked);
1927   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, ArgsEnd);
1928 
1929   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
1930                               /*Conditional*/ true, /*hasFinalize*/ true);
1931 }
1932 
1933 CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
1934     DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
1935     BasicBlock *PostInsertBefore, const Twine &Name) {
1936   Module *M = F->getParent();
1937   LLVMContext &Ctx = M->getContext();
1938   Type *IndVarTy = TripCount->getType();
1939 
1940   // Create the basic block structure.
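  // As a sketch, the skeleton produced below has the following shape:
  // ```
  //   preheader -> header -> cond --(iv < tripcount)--> body -> latch -> header
  //                               `--(otherwise)------> exit -> after
  // ```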
1941   BasicBlock *Preheader =
1942       BasicBlock::Create(Ctx, "omp_" + Name + ".preheader", F, PreInsertBefore);
1943   BasicBlock *Header =
1944       BasicBlock::Create(Ctx, "omp_" + Name + ".header", F, PreInsertBefore);
1945   BasicBlock *Cond =
1946       BasicBlock::Create(Ctx, "omp_" + Name + ".cond", F, PreInsertBefore);
1947   BasicBlock *Body =
1948       BasicBlock::Create(Ctx, "omp_" + Name + ".body", F, PreInsertBefore);
1949   BasicBlock *Latch =
1950       BasicBlock::Create(Ctx, "omp_" + Name + ".inc", F, PostInsertBefore);
1951   BasicBlock *Exit =
1952       BasicBlock::Create(Ctx, "omp_" + Name + ".exit", F, PostInsertBefore);
1953   BasicBlock *After =
1954       BasicBlock::Create(Ctx, "omp_" + Name + ".after", F, PostInsertBefore);
1955 
1956   // Use specified DebugLoc for new instructions.
1957   Builder.SetCurrentDebugLocation(DL);
1958 
1959   Builder.SetInsertPoint(Preheader);
1960   Builder.CreateBr(Header);
1961 
1962   Builder.SetInsertPoint(Header);
1963   PHINode *IndVarPHI = Builder.CreatePHI(IndVarTy, 2, "omp_" + Name + ".iv");
1964   IndVarPHI->addIncoming(ConstantInt::get(IndVarTy, 0), Preheader);
1965   Builder.CreateBr(Cond);
1966 
1967   Builder.SetInsertPoint(Cond);
1968   Value *Cmp =
1969       Builder.CreateICmpULT(IndVarPHI, TripCount, "omp_" + Name + ".cmp");
1970   Builder.CreateCondBr(Cmp, Body, Exit);
1971 
1972   Builder.SetInsertPoint(Body);
1973   Builder.CreateBr(Latch);
1974 
1975   Builder.SetInsertPoint(Latch);
1976   Value *Next = Builder.CreateAdd(IndVarPHI, ConstantInt::get(IndVarTy, 1),
1977                                   "omp_" + Name + ".next", /*HasNUW=*/true);
1978   Builder.CreateBr(Header);
1979   IndVarPHI->addIncoming(Next, Latch);
1980 
1981   Builder.SetInsertPoint(Exit);
1982   Builder.CreateBr(After);
1983 
1984   // Remember and return the canonical control flow.
1985   LoopInfos.emplace_front();
1986   CanonicalLoopInfo *CL = &LoopInfos.front();
1987 
1988   CL->Header = Header;
1989   CL->Cond = Cond;
1990   CL->Latch = Latch;
1991   CL->Exit = Exit;
1992 
1993 #ifndef NDEBUG
1994   CL->assertOK();
1995 #endif
1996   return CL;
1997 }
1998 
1999 CanonicalLoopInfo *
2000 OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
2001                                      LoopBodyGenCallbackTy BodyGenCB,
2002                                      Value *TripCount, const Twine &Name) {
2003   BasicBlock *BB = Loc.IP.getBlock();
2004   BasicBlock *NextBB = BB->getNextNode();
2005 
2006   CanonicalLoopInfo *CL = createLoopSkeleton(Loc.DL, TripCount, BB->getParent(),
2007                                              NextBB, NextBB, Name);
2008   BasicBlock *After = CL->getAfter();
2009 
2010   // If location is not set, don't connect the loop.
2011   if (updateToLocation(Loc)) {
2012     // Split the loop at the insertion point: Branch to the preheader and move
2013     // every following instruction to after the loop (the After BB). Also, the
2014     // new successor is the loop's after block.
2015     spliceBB(Builder, After, /*CreateBranch=*/false);
2016     Builder.CreateBr(CL->getPreheader());
2017   }
2018 
2019   // Emit the body content. We do it after connecting the loop to the CFG to
2020   // avoid that the callback encounters degenerate BBs.
2021   BodyGenCB(CL->getBodyIP(), CL->getIndVar());
2022 
2023 #ifndef NDEBUG
2024   CL->assertOK();
2025 #endif
2026   return CL;
2027 }
2028 
2029 CanonicalLoopInfo *OpenMPIRBuilder::createCanonicalLoop(
2030     const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
2031     Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
2032     InsertPointTy ComputeIP, const Twine &Name) {
2033 
2034   // Consider the following difficulties (assuming 8-bit signed integers):
2035   //  * Adding \p Step to the loop counter which passes \p Stop may overflow:
2036   //      DO I = 1, 100, 50
2037   //  * A \p Step of INT_MIN cannot be normalized to a positive direction:
2038   //      DO I = 100, 0, -128
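  //   For instance, with 8-bit signed integers, DO I = 1, 100, 50 visits
  //   I = 1 and I = 51; a naive "add Step, then compare against Stop" scheme
  //   would overflow for DO I = 1, 120, 100 (101 + 100 > 127), which is why a
  //   trip count is precomputed below instead.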
2039 
2040   // Start, Stop and Step must be of the same integer type.
2041   auto *IndVarTy = cast<IntegerType>(Start->getType());
2042   assert(IndVarTy == Stop->getType() && "Stop type mismatch");
2043   assert(IndVarTy == Step->getType() && "Step type mismatch");
2044 
2045   LocationDescription ComputeLoc =
2046       ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
2047   updateToLocation(ComputeLoc);
2048 
2049   ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
2050   ConstantInt *One = ConstantInt::get(IndVarTy, 1);
2051 
2052   // Like Step, but always positive.
2053   Value *Incr = Step;
2054 
2055   // Distance between Start and Stop; always positive.
2056   Value *Span;
2057 
2058   // Condition checking whether no iterations are executed at all, e.g.
2059   // because UB < LB.
2060   Value *ZeroCmp;
2061 
2062   if (IsSigned) {
2063     // Ensure that the increment is positive. If not, negate it and swap LB and UB.
2064     Value *IsNeg = Builder.CreateICmpSLT(Step, Zero);
2065     Incr = Builder.CreateSelect(IsNeg, Builder.CreateNeg(Step), Step);
2066     Value *LB = Builder.CreateSelect(IsNeg, Stop, Start);
2067     Value *UB = Builder.CreateSelect(IsNeg, Start, Stop);
2068     Span = Builder.CreateSub(UB, LB, "", false, true);
2069     ZeroCmp = Builder.CreateICmp(
2070         InclusiveStop ? CmpInst::ICMP_SLT : CmpInst::ICMP_SLE, UB, LB);
2071   } else {
2072     Span = Builder.CreateSub(Stop, Start, "", true);
2073     ZeroCmp = Builder.CreateICmp(
2074         InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, Stop, Start);
2075   }
2076 
2077   Value *CountIfLooping;
2078   if (InclusiveStop) {
2079     CountIfLooping = Builder.CreateAdd(Builder.CreateUDiv(Span, Incr), One);
2080   } else {
2081     // Avoid incrementing past stop since it could overflow.
2082     Value *CountIfTwo = Builder.CreateAdd(
2083         Builder.CreateUDiv(Builder.CreateSub(Span, One), Incr), One);
2084     Value *OneCmp = Builder.CreateICmp(
2085         InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, Span, Incr);
2086     CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
2087   }
2088   Value *TripCount = Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
2089                                           "omp_" + Name + ".tripcount");
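  // Worked example (positive Step, InclusiveStop): for Start=1, Stop=100,
  // Step=50 this computes Span = 99, Incr = 50 and TripCount = 99/50 + 1 = 2,
  // matching the two iterations I = 1 and I = 51.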
2090 
2091   auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
2092     Builder.restoreIP(CodeGenIP);
2093     Value *Span = Builder.CreateMul(IV, Step);
2094     Value *IndVar = Builder.CreateAdd(Span, Start);
2095     BodyGenCB(Builder.saveIP(), IndVar);
2096   };
2097   LocationDescription LoopLoc = ComputeIP.isSet() ? Loc.IP : Builder.saveIP();
2098   return createCanonicalLoop(LoopLoc, BodyGen, TripCount, Name);
2099 }
2100 
2101 /// Returns an LLVM function to call for initializing loop bounds using OpenMP
2102 /// static scheduling depending on `type`. Only i32 and i64 are supported by the
2103 /// runtime. Always interpret integers as unsigned similarly to
2104 /// CanonicalLoopInfo.
2105 static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
2106                                                   OpenMPIRBuilder &OMPBuilder) {
2107   unsigned Bitwidth = Ty->getIntegerBitWidth();
2108   if (Bitwidth == 32)
2109     return OMPBuilder.getOrCreateRuntimeFunction(
2110         M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u);
2111   if (Bitwidth == 64)
2112     return OMPBuilder.getOrCreateRuntimeFunction(
2113         M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u);
2114   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
2115 }
2116 
2117 OpenMPIRBuilder::InsertPointTy
2118 OpenMPIRBuilder::applyStaticWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
2119                                           InsertPointTy AllocaIP,
2120                                           bool NeedsBarrier) {
2121   assert(CLI->isValid() && "Requires a valid canonical loop");
2122   assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
2123          "Require dedicated allocate IP");
2124 
2125   // Set up the source location value for OpenMP runtime.
2126   Builder.restoreIP(CLI->getPreheaderIP());
2127   Builder.SetCurrentDebugLocation(DL);
2128 
2129   uint32_t SrcLocStrSize;
2130   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
2131   Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2132 
2133   // Declare useful OpenMP runtime functions.
2134   Value *IV = CLI->getIndVar();
2135   Type *IVTy = IV->getType();
2136   FunctionCallee StaticInit = getKmpcForStaticInitForType(IVTy, M, *this);
2137   FunctionCallee StaticFini =
2138       getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);
2139 
2140   // Allocate space for computed loop bounds as expected by the "init" function.
2141   Builder.restoreIP(AllocaIP);
2142   Type *I32Type = Type::getInt32Ty(M.getContext());
2143   Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
2144   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
2145   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
2146   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
2147 
2148   // At the end of the preheader, prepare for calling the "init" function by
2149   // storing the current loop bounds into the allocated space. A canonical loop
2150   // always iterates from 0 to trip-count with step 1. Note that "init" expects
2151   // and produces an inclusive upper bound.
2152   Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
2153   Constant *Zero = ConstantInt::get(IVTy, 0);
2154   Constant *One = ConstantInt::get(IVTy, 1);
2155   Builder.CreateStore(Zero, PLowerBound);
2156   Value *UpperBound = Builder.CreateSub(CLI->getTripCount(), One);
2157   Builder.CreateStore(UpperBound, PUpperBound);
2158   Builder.CreateStore(One, PStride);
2159 
2160   Value *ThreadNum = getOrCreateThreadID(SrcLoc);
2161 
2162   Constant *SchedulingType = ConstantInt::get(
2163       I32Type, static_cast<int>(OMPScheduleType::UnorderedStatic));
2164 
2165   // Call the "init" function and update the trip count of the loop with the
2166   // value it produced.
2167   Builder.CreateCall(StaticInit,
2168                      {SrcLoc, ThreadNum, SchedulingType, PLastIter, PLowerBound,
2169                       PUpperBound, PStride, One, Zero});
2170   Value *LowerBound = Builder.CreateLoad(IVTy, PLowerBound);
2171   Value *InclusiveUpperBound = Builder.CreateLoad(IVTy, PUpperBound);
2172   Value *TripCountMinusOne = Builder.CreateSub(InclusiveUpperBound, LowerBound);
2173   Value *TripCount = Builder.CreateAdd(TripCountMinusOne, One);
2174   CLI->setTripCount(TripCount);
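  // For illustration: with schedule(static), a trip count of 64 and 4 threads,
  // the runtime would typically hand thread 2 LowerBound = 32 and an inclusive
  // upper bound of 47, so the local trip count becomes 16 and the IV mapping
  // below shifts the thread-local IV by 32.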
2175 
2176   // Update all uses of the induction variable except the one in the condition
2177   // block that compares it with the actual upper bound, and the increment in
2178   // the latch block.
2179 
2180   CLI->mapIndVar([&](Instruction *OldIV) -> Value * {
2181     Builder.SetInsertPoint(CLI->getBody(),
2182                            CLI->getBody()->getFirstInsertionPt());
2183     Builder.SetCurrentDebugLocation(DL);
2184     return Builder.CreateAdd(OldIV, LowerBound);
2185   });
2186 
2187   // In the "exit" block, call the "fini" function.
2188   Builder.SetInsertPoint(CLI->getExit(),
2189                          CLI->getExit()->getTerminator()->getIterator());
2190   Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});
2191 
2192   // Add the barrier if requested.
2193   if (NeedsBarrier)
2194     createBarrier(LocationDescription(Builder.saveIP(), DL),
2195                   omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
2196                   /* CheckCancelFlag */ false);
2197 
2198   InsertPointTy AfterIP = CLI->getAfterIP();
2199   CLI->invalidate();
2200 
2201   return AfterIP;
2202 }
2203 
2204 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
2205     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
2206     bool NeedsBarrier, Value *ChunkSize) {
2207   assert(CLI->isValid() && "Requires a valid canonical loop");
2208   assert(ChunkSize && "Chunk size is required");
2209 
2210   LLVMContext &Ctx = CLI->getFunction()->getContext();
2211   Value *IV = CLI->getIndVar();
2212   Value *OrigTripCount = CLI->getTripCount();
2213   Type *IVTy = IV->getType();
2214   assert(IVTy->getIntegerBitWidth() <= 64 &&
2215          "Max supported tripcount bitwidth is 64 bits");
2216   Type *InternalIVTy = IVTy->getIntegerBitWidth() <= 32 ? Type::getInt32Ty(Ctx)
2217                                                         : Type::getInt64Ty(Ctx);
2218   Type *I32Type = Type::getInt32Ty(M.getContext());
2219   Constant *Zero = ConstantInt::get(InternalIVTy, 0);
2220   Constant *One = ConstantInt::get(InternalIVTy, 1);
2221 
2222   // Declare useful OpenMP runtime functions.
2223   FunctionCallee StaticInit =
2224       getKmpcForStaticInitForType(InternalIVTy, M, *this);
2225   FunctionCallee StaticFini =
2226       getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);
2227 
2228   // Allocate space for computed loop bounds as expected by the "init" function.
2229   Builder.restoreIP(AllocaIP);
2230   Builder.SetCurrentDebugLocation(DL);
2231   Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
2232   Value *PLowerBound =
2233       Builder.CreateAlloca(InternalIVTy, nullptr, "p.lowerbound");
2234   Value *PUpperBound =
2235       Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
2236   Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
2237 
2238   // Set up the source location value for the OpenMP runtime.
2239   Builder.restoreIP(CLI->getPreheaderIP());
2240   Builder.SetCurrentDebugLocation(DL);
2241 
2242   // TODO: Detect overflow in ubsan or max-out with current tripcount.
2243   Value *CastedChunkSize =
2244       Builder.CreateZExtOrTrunc(ChunkSize, InternalIVTy, "chunksize");
2245   Value *CastedTripCount =
2246       Builder.CreateZExt(OrigTripCount, InternalIVTy, "tripcount");
2247 
2248   Constant *SchedulingType = ConstantInt::get(
2249       I32Type, static_cast<int>(OMPScheduleType::UnorderedStaticChunked));
2250   Builder.CreateStore(Zero, PLowerBound);
2251   Value *OrigUpperBound = Builder.CreateSub(CastedTripCount, One);
2252   Builder.CreateStore(OrigUpperBound, PUpperBound);
2253   Builder.CreateStore(One, PStride);
2254 
2255   // Call the "init" function and update the trip count of the loop with the
2256   // value it produced.
2257   uint32_t SrcLocStrSize;
2258   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
2259   Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2260   Value *ThreadNum = getOrCreateThreadID(SrcLoc);
2261   Builder.CreateCall(StaticInit,
2262                      {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
2263                       /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
2264                       /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
2265                       /*pstride=*/PStride, /*incr=*/One,
2266                       /*chunk=*/CastedChunkSize});
2267 
2268   // Load values written by the "init" function.
2269   Value *FirstChunkStart =
2270       Builder.CreateLoad(InternalIVTy, PLowerBound, "omp_firstchunk.lb");
2271   Value *FirstChunkStop =
2272       Builder.CreateLoad(InternalIVTy, PUpperBound, "omp_firstchunk.ub");
2273   Value *FirstChunkEnd = Builder.CreateAdd(FirstChunkStop, One);
2274   Value *ChunkRange =
2275       Builder.CreateSub(FirstChunkEnd, FirstChunkStart, "omp_chunk.range");
2276   Value *NextChunkStride =
2277       Builder.CreateLoad(InternalIVTy, PStride, "omp_dispatch.stride");
2278 
2279   // Create outer "dispatch" loop for enumerating the chunks.
2280   BasicBlock *DispatchEnter = splitBB(Builder, true);
2281   Value *DispatchCounter;
2282   CanonicalLoopInfo *DispatchCLI = createCanonicalLoop(
2283       {Builder.saveIP(), DL},
2284       [&](InsertPointTy BodyIP, Value *Counter) { DispatchCounter = Counter; },
2285       FirstChunkStart, CastedTripCount, NextChunkStride,
2286       /*IsSigned=*/false, /*InclusiveStop=*/false, /*ComputeIP=*/{},
2287       "dispatch");
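  // Conceptually, the rewiring below produces a nest of this shape (a
  // pseudocode sketch):
  // ```
  //   for (dispatch = firstchunk.lb; dispatch < tripcount; dispatch += stride)
  //     for (iv = 0; iv < chunk.tripcount; ++iv)
  //       body(dispatch + iv);
  // ```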
2288 
2289   // Remember the BasicBlocks of the dispatch loop we need, then invalidate to
2290   // not have to preserve the canonical invariant.
2291   BasicBlock *DispatchBody = DispatchCLI->getBody();
2292   BasicBlock *DispatchLatch = DispatchCLI->getLatch();
2293   BasicBlock *DispatchExit = DispatchCLI->getExit();
2294   BasicBlock *DispatchAfter = DispatchCLI->getAfter();
2295   DispatchCLI->invalidate();
2296 
2297   // Rewire the original loop to become the chunk loop inside the dispatch loop.
2298   redirectTo(DispatchAfter, CLI->getAfter(), DL);
2299   redirectTo(CLI->getExit(), DispatchLatch, DL);
2300   redirectTo(DispatchBody, DispatchEnter, DL);
2301 
2302   // Prepare the prolog of the chunk loop.
2303   Builder.restoreIP(CLI->getPreheaderIP());
2304   Builder.SetCurrentDebugLocation(DL);
2305 
2306   // Compute the number of iterations of the chunk loop.
2307   Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
2308   Value *ChunkEnd = Builder.CreateAdd(DispatchCounter, ChunkRange);
2309   Value *IsLastChunk =
2310       Builder.CreateICmpUGE(ChunkEnd, CastedTripCount, "omp_chunk.is_last");
2311   Value *CountUntilOrigTripCount =
2312       Builder.CreateSub(CastedTripCount, DispatchCounter);
2313   Value *ChunkTripCount = Builder.CreateSelect(
2314       IsLastChunk, CountUntilOrigTripCount, ChunkRange, "omp_chunk.tripcount");
2315   Value *BackcastedChunkTC =
2316       Builder.CreateTrunc(ChunkTripCount, IVTy, "omp_chunk.tripcount.trunc");
2317   CLI->setTripCount(BackcastedChunkTC);
2318 
2319   // Update all uses of the induction variable except the one in the condition
2320   // block that compares it with the actual upper bound, and the increment in
2321   // the latch block.
2322   Value *BackcastedDispatchCounter =
2323       Builder.CreateTrunc(DispatchCounter, IVTy, "omp_dispatch.iv.trunc");
2324   CLI->mapIndVar([&](Instruction *) -> Value * {
2325     Builder.restoreIP(CLI->getBodyIP());
2326     return Builder.CreateAdd(IV, BackcastedDispatchCounter);
2327   });
2328 
2329   // In the "exit" block, call the "fini" function.
2330   Builder.SetInsertPoint(DispatchExit, DispatchExit->getFirstInsertionPt());
2331   Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});
2332 
2333   // Add the barrier if requested.
2334   if (NeedsBarrier)
2335     createBarrier(LocationDescription(Builder.saveIP(), DL), OMPD_for,
2336                   /*ForceSimpleCall=*/false, /*CheckCancelFlag=*/false);
2337 
2338 #ifndef NDEBUG
2339   // Even though we currently do not support applying additional methods to it,
2340   // the chunk loop should remain a canonical loop.
2341   CLI->assertOK();
2342 #endif
2343 
2344   return {DispatchAfter, DispatchAfter->getFirstInsertionPt()};
2345 }
2346 
2347 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop(
2348     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
2349     bool NeedsBarrier, llvm::omp::ScheduleKind SchedKind,
2350     llvm::Value *ChunkSize, bool HasSimdModifier, bool HasMonotonicModifier,
2351     bool HasNonmonotonicModifier, bool HasOrderedClause) {
2352   OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
2353       SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
2354       HasNonmonotonicModifier, HasOrderedClause);
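  // For example, schedule(static) without a chunk maps to BaseStatic and
  // schedule(static, N) to BaseStaticChunked, while an ordered clause sets
  // ModifierOrdered and thereby forces the dispatch-based lowering below. The
  // exact mapping is the business of computeOpenMPScheduleType.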
2355 
2356   bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
2357                    OMPScheduleType::ModifierOrdered;
2358   switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
2359   case OMPScheduleType::BaseStatic:
2360     assert(!ChunkSize && "No chunk size expected with non-chunked static schedule");
2361     if (IsOrdered)
2362       return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
2363                                        NeedsBarrier, ChunkSize);
2364     // FIXME: Monotonicity ignored?
2365     return applyStaticWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier);
2366 
2367   case OMPScheduleType::BaseStaticChunked:
2368     if (IsOrdered)
2369       return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
2370                                        NeedsBarrier, ChunkSize);
2371     // FIXME: Monotonicity ignored?
2372     return applyStaticChunkedWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier,
2373                                            ChunkSize);
2374 
2375   case OMPScheduleType::BaseRuntime:
2376   case OMPScheduleType::BaseAuto:
2377   case OMPScheduleType::BaseGreedy:
2378   case OMPScheduleType::BaseBalanced:
2379   case OMPScheduleType::BaseSteal:
2380   case OMPScheduleType::BaseGuidedSimd:
2381   case OMPScheduleType::BaseRuntimeSimd:
2382     assert(!ChunkSize &&
2383            "schedule type does not support user-defined chunk sizes");
2384     LLVM_FALLTHROUGH;
2385   case OMPScheduleType::BaseDynamicChunked:
2386   case OMPScheduleType::BaseGuidedChunked:
2387   case OMPScheduleType::BaseGuidedIterativeChunked:
2388   case OMPScheduleType::BaseGuidedAnalyticalChunked:
2389   case OMPScheduleType::BaseStaticBalancedChunked:
2390     return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
2391                                      NeedsBarrier, ChunkSize);
2392 
2393   default:
2394     llvm_unreachable("Unknown/unimplemented schedule kind");
2395   }
2396 }

/// Returns an LLVM function to call for initializing loop bounds using OpenMP
/// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
/// the runtime. Always interpret integers as unsigned similarly to
/// CanonicalLoopInfo.
static FunctionCallee
getKmpcForDynamicInitForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  if (Bitwidth == 32)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_4u);
  if (Bitwidth == 64)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_8u);
  llvm_unreachable("unknown OpenMP loop iterator bitwidth");
}

/// Returns an LLVM function to call for fetching the next chunk of iterations
/// using OpenMP dynamic scheduling, depending on `type`. Only i32 and i64 are
/// supported by the runtime. Always interpret integers as unsigned similarly
/// to CanonicalLoopInfo.
static FunctionCallee
getKmpcForDynamicNextForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  if (Bitwidth == 32)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_4u);
  if (Bitwidth == 64)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_8u);
  llvm_unreachable("unknown OpenMP loop iterator bitwidth");
}

/// Returns an LLVM function to call for finalizing the dynamic loop, depending
/// on `type`. Only i32 and i64 are supported by the runtime. Always interpret
/// integers as unsigned similarly to CanonicalLoopInfo.
static FunctionCallee
getKmpcForDynamicFiniForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  if (Bitwidth == 32)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_4u);
  if (Bitwidth == 64)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_8u);
  llvm_unreachable("unknown OpenMP loop iterator bitwidth");
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyDynamicWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    OMPScheduleType SchedType, bool NeedsBarrier, Value *Chunk) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
         "Require dedicated allocate IP");
  assert(isValidWorkshareLoopScheduleType(SchedType) &&
         "Require valid schedule type");

  bool Ordered = (SchedType & OMPScheduleType::ModifierOrdered) ==
                 OMPScheduleType::ModifierOrdered;

  // Set up the source location value for the OpenMP runtime.
  Builder.SetCurrentDebugLocation(DL);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  // Declare useful OpenMP runtime functions.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee DynamicInit = getKmpcForDynamicInitForType(IVTy, M, *this);
  FunctionCallee DynamicNext = getKmpcForDynamicNextForType(IVTy, M, *this);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.restoreIP(AllocaIP);
  Type *I32Type = Type::getInt32Ty(M.getContext());
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. A canonical loop
  // always iterates from 0 to trip-count with step 1. Note that "init" expects
  // and produces an inclusive upper bound.
  BasicBlock *PreHeader = CLI->getPreheader();
  Builder.SetInsertPoint(PreHeader->getTerminator());
  Constant *One = ConstantInt::get(IVTy, 1);
  Builder.CreateStore(One, PLowerBound);
  Value *UpperBound = CLI->getTripCount();
  Builder.CreateStore(UpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  BasicBlock *Header = CLI->getHeader();
  BasicBlock *Exit = CLI->getExit();
  BasicBlock *Cond = CLI->getCond();
  BasicBlock *Latch = CLI->getLatch();
  InsertPointTy AfterIP = CLI->getAfterIP();

  // The CLI will be "broken" in the code below, as the loop is no longer
  // a valid canonical loop.

  if (!Chunk)
    Chunk = One;

  Value *ThreadNum = getOrCreateThreadID(SrcLoc);

  Constant *SchedulingType =
      ConstantInt::get(I32Type, static_cast<int>(SchedType));

  // Call the "init" function.
  Builder.CreateCall(DynamicInit,
                     {SrcLoc, ThreadNum, SchedulingType, /* LowerBound */ One,
                      UpperBound, /* step */ One, Chunk});
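
  // For a 32-bit IV, the preheader now ends with roughly the following IR
  // (a sketch with illustrative value names, not the exact names emitted):
  //   store i32 1, ptr %p.lowerbound
  //   store i32 %tripcount, ptr %p.upperbound
  //   store i32 1, ptr %p.stride
  //   call void @__kmpc_dispatch_init_4u(ptr %srcloc, i32 %tid, i32 %sched,
  //                                      i32 1, i32 %tripcount, i32 1,
  //                                      i32 %chunk)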

  // An outer loop around the existing one.
  BasicBlock *OuterCond = BasicBlock::Create(
      PreHeader->getContext(), Twine(PreHeader->getName()) + ".outer.cond",
      PreHeader->getParent());
  Builder.SetInsertPoint(OuterCond, OuterCond->getFirstInsertionPt());
  Value *Res =
      Builder.CreateCall(DynamicNext, {SrcLoc, ThreadNum, PLastIter,
                                       PLowerBound, PUpperBound, PStride});
  // The result of dispatch_next is always 32-bit, so the comparison needs a
  // 32-bit zero rather than an IVTy-typed constant.
  Constant *Zero32 = ConstantInt::get(I32Type, 0);
  Value *MoreWork = Builder.CreateCmp(CmpInst::ICMP_NE, Res, Zero32);
  Value *LowerBound =
      Builder.CreateSub(Builder.CreateLoad(IVTy, PLowerBound), One, "lb");
  Builder.CreateCondBr(MoreWork, Header, Exit);

  // Change the PHI node in the loop header to use the outer cond rather than
  // the preheader, and set the IV to the LowerBound.
  Instruction *Phi = &Header->front();
  auto *PI = cast<PHINode>(Phi);
  PI->setIncomingBlock(0, OuterCond);
  PI->setIncomingValue(0, LowerBound);

  // Then set the preheader to jump to the OuterCond.
  Instruction *Term = PreHeader->getTerminator();
  auto *Br = cast<BranchInst>(Term);
  Br->setSuccessor(0, OuterCond);

  // Modify the inner condition:
  // * Use the UpperBound returned from the DynamicNext call.
  // * Jump to the outer loop when done with one of the inner (chunk) loops.
  Builder.SetInsertPoint(Cond, Cond->getFirstInsertionPt());
  UpperBound = Builder.CreateLoad(IVTy, PUpperBound, "ub");
  Instruction *Comp = &*Builder.GetInsertPoint();
  auto *CI = cast<CmpInst>(Comp);
  CI->setOperand(1, UpperBound);
  // Redirect the inner exit to branch to the outer condition.
  Instruction *Branch = &Cond->back();
  auto *BI = cast<BranchInst>(Branch);
  assert(BI->getSuccessor(1) == Exit);
  BI->setSuccessor(1, OuterCond);
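
  // The rewired control flow now forms a dispatch loop around the original
  // loop; schematically (a sketch, block names are illustrative):
  //   preheader:   __kmpc_dispatch_init_*(...); br %outer.cond
  //   outer.cond:  %more = __kmpc_dispatch_next_*(..., %p.lowerbound,
  //                                               %p.upperbound, %p.stride)
  //                %lb = load(%p.lowerbound) - 1
  //                br %more ? %header : %exit
  //   header..latch: the original loop, with its IV starting at %lb and
  //                  compared against the chunk's upper bound; its exit edge
  //                  returns to %outer.cond to fetch the next chunk.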

  // Call the "fini" function if "ordered" is present in the wsloop directive.
  if (Ordered) {
    Builder.SetInsertPoint(&Latch->back());
    FunctionCallee DynamicFini = getKmpcForDynamicFiniForType(IVTy, M, *this);
    Builder.CreateCall(DynamicFini, {SrcLoc, ThreadNum});
  }

  // Add the barrier if requested.
  if (NeedsBarrier) {
    Builder.SetInsertPoint(&Exit->back());
    createBarrier(LocationDescription(Builder.saveIP(), DL),
                  omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                  /* CheckCancelFlag */ false);
  }

  CLI->invalidate();
  return AfterIP;
}

/// Redirect all edges that branch to \p OldTarget to \p NewTarget. That is,
/// after this \p OldTarget will be orphaned.
static void redirectAllPredecessorsTo(BasicBlock *OldTarget,
                                      BasicBlock *NewTarget, DebugLoc DL) {
  for (BasicBlock *Pred : make_early_inc_range(predecessors(OldTarget)))
    redirectTo(Pred, NewTarget, DL);
}

/// Determine which blocks in \p BBs are reachable from outside and remove the
/// ones that are not reachable from the function.
static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
  SmallPtrSet<BasicBlock *, 6> BBsToErase{BBs.begin(), BBs.end()};
  auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) {
    for (Use &U : BB->uses()) {
      auto *UseInst = dyn_cast<Instruction>(U.getUser());
      if (!UseInst)
        continue;
      if (BBsToErase.count(UseInst->getParent()))
        continue;
      return true;
    }
    return false;
  };

  while (true) {
    bool Changed = false;
    for (BasicBlock *BB : make_early_inc_range(BBsToErase)) {
      if (HasRemainingUses(BB)) {
        BBsToErase.erase(BB);
        Changed = true;
      }
    }
    if (!Changed)
      break;
  }

  SmallVector<BasicBlock *, 7> BBVec(BBsToErase.begin(), BBsToErase.end());
  DeleteDeadBlocks(BBVec);
}

CanonicalLoopInfo *
OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
                               InsertPointTy ComputeIP) {
  assert(Loops.size() >= 1 && "At least one loop required");
  size_t NumLoops = Loops.size();

  // Nothing to do if there is already just one loop.
  if (NumLoops == 1)
    return Loops.front();

  CanonicalLoopInfo *Outermost = Loops.front();
  CanonicalLoopInfo *Innermost = Loops.back();
  BasicBlock *OrigPreheader = Outermost->getPreheader();
  BasicBlock *OrigAfter = Outermost->getAfter();
  Function *F = OrigPreheader->getParent();

  // Loop control blocks that may become orphaned later.
  SmallVector<BasicBlock *, 12> OldControlBBs;
  OldControlBBs.reserve(6 * Loops.size());
  for (CanonicalLoopInfo *Loop : Loops)
    Loop->collectControlBlocks(OldControlBBs);

  // Set up the IRBuilder for inserting the trip count computation.
  Builder.SetCurrentDebugLocation(DL);
  if (ComputeIP.isSet())
    Builder.restoreIP(ComputeIP);
  else
    Builder.restoreIP(Outermost->getPreheaderIP());

  // Derive the collapsed loop's trip count.
  // TODO: Find common/largest indvar type.
  Value *CollapsedTripCount = nullptr;
  for (CanonicalLoopInfo *L : Loops) {
    assert(L->isValid() &&
           "All loops to collapse must be valid canonical loops");
    Value *OrigTripCount = L->getTripCount();
    if (!CollapsedTripCount) {
      CollapsedTripCount = OrigTripCount;
      continue;
    }

    // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
    CollapsedTripCount = Builder.CreateMul(CollapsedTripCount, OrigTripCount,
                                           {}, /*HasNUW=*/true);
  }

  // Create the collapsed loop control flow.
  CanonicalLoopInfo *Result =
      createLoopSkeleton(DL, CollapsedTripCount, F,
                         OrigPreheader->getNextNode(), OrigAfter, "collapsed");

  // Build the collapsed loop body code.
  // Start with deriving the input loop induction variables from the collapsed
  // one, using a divmod scheme. To preserve the original loops' order, the
  // innermost loop uses the least significant bits.
  Builder.restoreIP(Result->getBodyIP());

  Value *Leftover = Result->getIndVar();
  SmallVector<Value *> NewIndVars;
  NewIndVars.resize(NumLoops);
  for (int i = NumLoops - 1; i >= 1; --i) {
    Value *OrigTripCount = Loops[i]->getTripCount();

    Value *NewIndVar = Builder.CreateURem(Leftover, OrigTripCount);
    NewIndVars[i] = NewIndVar;

    Leftover = Builder.CreateUDiv(Leftover, OrigTripCount);
  }
  // The outermost loop gets all the remaining bits.
  NewIndVars[0] = Leftover;
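
  // Worked example (a sketch): collapsing two loops with trip counts 5
  // (outer) and 3 (inner) yields one loop with trip count 15. For collapsed
  // IV = 7, the divmod scheme recovers
  //   inner IV = 7 % 3 = 1
  //   outer IV = 7 / 3 = 2
  // which corresponds to iteration (2, 1) of the original nest.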

  // Construct the loop body control flow.
  // We progressively construct the branch structure following the direction of
  // the control flow: from the leading in-between code, through the loop nest
  // body and the trailing in-between code, to rejoining the collapsed loop's
  // latch.
  // ContinueBlock and ContinuePred keep track of the source(s) of the next
  // edge. If ContinueBlock is set, continue with that block. If ContinuePred
  // is set, use its predecessors as sources.
  BasicBlock *ContinueBlock = Result->getBody();
  BasicBlock *ContinuePred = nullptr;
  auto ContinueWith = [&ContinueBlock, &ContinuePred, DL](BasicBlock *Dest,
                                                          BasicBlock *NextSrc) {
    if (ContinueBlock)
      redirectTo(ContinueBlock, Dest, DL);
    else
      redirectAllPredecessorsTo(ContinuePred, Dest, DL);

    ContinueBlock = nullptr;
    ContinuePred = NextSrc;
  };

  // The code before the nested loop of each level.
  // Because we are sinking it into the nest, it will be executed more often
  // than in the original loop. More sophisticated schemes could keep track of
  // what the in-between code is and instantiate it only once per thread.
  for (size_t i = 0; i < NumLoops - 1; ++i)
    ContinueWith(Loops[i]->getBody(), Loops[i + 1]->getHeader());

  // Connect the loop nest body.
  ContinueWith(Innermost->getBody(), Innermost->getLatch());

  // The code after the nested loop at each level.
  for (size_t i = NumLoops - 1; i > 0; --i)
    ContinueWith(Loops[i]->getAfter(), Loops[i - 1]->getLatch());

  // Connect the finished loop to the collapsed loop latch.
  ContinueWith(Result->getLatch(), nullptr);

  // Replace the input loops with the new collapsed loop.
  redirectTo(Outermost->getPreheader(), Result->getPreheader(), DL);
  redirectTo(Result->getAfter(), Outermost->getAfter(), DL);

  // Replace the input loop indvars with the derived ones.
  for (size_t i = 0; i < NumLoops; ++i)
    Loops[i]->getIndVar()->replaceAllUsesWith(NewIndVars[i]);

  // Remove unused parts of the input loops.
  removeUnusedBlocksFromParent(OldControlBBs);

  for (CanonicalLoopInfo *L : Loops)
    L->invalidate();

#ifndef NDEBUG
  Result->assertOK();
#endif
  return Result;
}

std::vector<CanonicalLoopInfo *>
OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
                           ArrayRef<Value *> TileSizes) {
  assert(TileSizes.size() == Loops.size() &&
         "Must pass as many tile sizes as there are loops");
  int NumLoops = Loops.size();
  assert(NumLoops >= 1 && "At least one loop to tile required");

  CanonicalLoopInfo *OutermostLoop = Loops.front();
  CanonicalLoopInfo *InnermostLoop = Loops.back();
  Function *F = OutermostLoop->getBody()->getParent();
  BasicBlock *InnerEnter = InnermostLoop->getBody();
  BasicBlock *InnerLatch = InnermostLoop->getLatch();

  // Loop control blocks that may become orphaned later.
  SmallVector<BasicBlock *, 12> OldControlBBs;
  OldControlBBs.reserve(6 * Loops.size());
  for (CanonicalLoopInfo *Loop : Loops)
    Loop->collectControlBlocks(OldControlBBs);

  // Collect the original trip counts and induction variables to be accessible
  // by index. Also, the structure of the original loops is not preserved
  // during the construction of the tiled loops, so do it before we scavenge
  // the BBs of any original CanonicalLoopInfo.
  SmallVector<Value *, 4> OrigTripCounts, OrigIndVars;
  for (CanonicalLoopInfo *L : Loops) {
    assert(L->isValid() && "All input loops must be valid canonical loops");
    OrigTripCounts.push_back(L->getTripCount());
    OrigIndVars.push_back(L->getIndVar());
  }

  // Collect the code between loop headers. These may contain SSA definitions
  // that are used in the loop nest body. To be usable within the innermost
  // body, these BasicBlocks will be sunk into the loop nest body. That is,
  // these instructions may be executed more often than before the tiling.
  // TODO: It would be sufficient to only sink them into the body of the
  // corresponding tile loop.
  SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> InbetweenCode;
  for (int i = 0; i < NumLoops - 1; ++i) {
    CanonicalLoopInfo *Surrounding = Loops[i];
    CanonicalLoopInfo *Nested = Loops[i + 1];

    BasicBlock *EnterBB = Surrounding->getBody();
    BasicBlock *ExitBB = Nested->getHeader();
    InbetweenCode.emplace_back(EnterBB, ExitBB);
  }

  // Compute the trip counts of the floor loops.
  Builder.SetCurrentDebugLocation(DL);
  Builder.restoreIP(OutermostLoop->getPreheaderIP());
  SmallVector<Value *, 4> FloorCount, FloorRems;
  for (int i = 0; i < NumLoops; ++i) {
    Value *TileSize = TileSizes[i];
    Value *OrigTripCount = OrigTripCounts[i];
    Type *IVType = OrigTripCount->getType();

    Value *FloorTripCount = Builder.CreateUDiv(OrigTripCount, TileSize);
    Value *FloorTripRem = Builder.CreateURem(OrigTripCount, TileSize);

    // 0 if the tilesize divides the tripcount, 1 otherwise.
    // 1 means we need an additional iteration for a partial tile.
    //
    // Unfortunately we cannot just use the roundup formula
    //   (tripcount + tilesize - 1) / tilesize
    // because the summation might overflow. We do not want to introduce
    // undefined behavior when the untiled loop nest did not have it.
    Value *FloorTripOverflow =
        Builder.CreateICmpNE(FloorTripRem, ConstantInt::get(IVType, 0));

    FloorTripOverflow = Builder.CreateZExt(FloorTripOverflow, IVType);
    FloorTripCount =
        Builder.CreateAdd(FloorTripCount, FloorTripOverflow,
                          "omp_floor" + Twine(i) + ".tripcount", /*HasNUW=*/true);

    // Remember some values for later use.
    FloorCount.push_back(FloorTripCount);
    FloorRems.push_back(FloorTripRem);
  }
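
  // Example (a sketch): for an original trip count of 10 and a tile size of
  // 4, FloorTripCount = 10 / 4 = 2 and FloorTripRem = 10 % 4 = 2. Since the
  // remainder is nonzero, one extra floor iteration is added, so the floor
  // loop runs 3 times: two full tiles of 4 and one partial tile of 2.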

  // Generate the new loop nest, from the outermost to the innermost.
  std::vector<CanonicalLoopInfo *> Result;
  Result.reserve(NumLoops * 2);

  // The basic block of the surrounding loop that enters the generated loop
  // nest.
  BasicBlock *Enter = OutermostLoop->getPreheader();

  // The basic block of the surrounding loop where the inner code should
  // continue.
  BasicBlock *Continue = OutermostLoop->getAfter();

  // Where the next loop basic block should be inserted.
  BasicBlock *OutroInsertBefore = InnermostLoop->getExit();

  auto EmbeddNewLoop =
      [this, DL, F, InnerEnter, &Enter, &Continue, &OutroInsertBefore](
          Value *TripCount, const Twine &Name) -> CanonicalLoopInfo * {
    CanonicalLoopInfo *EmbeddedLoop = createLoopSkeleton(
        DL, TripCount, F, InnerEnter, OutroInsertBefore, Name);
    redirectTo(Enter, EmbeddedLoop->getPreheader(), DL);
    redirectTo(EmbeddedLoop->getAfter(), Continue, DL);

    // Set up the position where the next embedded loop connects to this loop.
    Enter = EmbeddedLoop->getBody();
    Continue = EmbeddedLoop->getLatch();
    OutroInsertBefore = EmbeddedLoop->getLatch();
    return EmbeddedLoop;
  };

  auto EmbeddNewLoops = [&Result, &EmbeddNewLoop](ArrayRef<Value *> TripCounts,
                                                  const Twine &NameBase) {
    for (auto P : enumerate(TripCounts)) {
      CanonicalLoopInfo *EmbeddedLoop =
          EmbeddNewLoop(P.value(), NameBase + Twine(P.index()));
      Result.push_back(EmbeddedLoop);
    }
  };

  EmbeddNewLoops(FloorCount, "floor");

  // Within the innermost floor loop, emit the code that computes the tile
  // sizes.
  Builder.SetInsertPoint(Enter->getTerminator());
  SmallVector<Value *, 4> TileCounts;
  for (int i = 0; i < NumLoops; ++i) {
    CanonicalLoopInfo *FloorLoop = Result[i];
    Value *TileSize = TileSizes[i];

    Value *FloorIsEpilogue =
        Builder.CreateICmpEQ(FloorLoop->getIndVar(), FloorCount[i]);
    Value *TileTripCount =
        Builder.CreateSelect(FloorIsEpilogue, FloorRems[i], TileSize);

    TileCounts.push_back(TileTripCount);
  }

  // Create the tile loops.
  EmbeddNewLoops(TileCounts, "tile");

  // Insert the in-between code into the body.
  BasicBlock *BodyEnter = Enter;
  BasicBlock *BodyEntered = nullptr;
  for (std::pair<BasicBlock *, BasicBlock *> P : InbetweenCode) {
    BasicBlock *EnterBB = P.first;
    BasicBlock *ExitBB = P.second;

    if (BodyEnter)
      redirectTo(BodyEnter, EnterBB, DL);
    else
      redirectAllPredecessorsTo(BodyEntered, EnterBB, DL);

    BodyEnter = nullptr;
    BodyEntered = ExitBB;
  }

  // Append the original loop nest body into the generated loop nest body.
  if (BodyEnter)
    redirectTo(BodyEnter, InnerEnter, DL);
  else
    redirectAllPredecessorsTo(BodyEntered, InnerEnter, DL);
  redirectAllPredecessorsTo(InnerLatch, Continue, DL);

  // Replace the original induction variable with an induction variable
  // computed from the tile and floor induction variables.
  Builder.restoreIP(Result.back()->getBodyIP());
  for (int i = 0; i < NumLoops; ++i) {
    CanonicalLoopInfo *FloorLoop = Result[i];
    CanonicalLoopInfo *TileLoop = Result[NumLoops + i];
    Value *OrigIndVar = OrigIndVars[i];
    Value *Size = TileSizes[i];

    Value *Scale =
        Builder.CreateMul(Size, FloorLoop->getIndVar(), {}, /*HasNUW=*/true);
    Value *Shift =
        Builder.CreateAdd(Scale, TileLoop->getIndVar(), {}, /*HasNUW=*/true);
    OrigIndVar->replaceAllUsesWith(Shift);
  }
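
  // For instance (a sketch), with a tile size of 4, floor IV = 2, and tile
  // IV = 1, the reconstructed original IV is 4 * 2 + 1 = 9, i.e. the second
  // element of the third tile.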

  // Remove unused parts of the original loops.
  removeUnusedBlocksFromParent(OldControlBBs);

  for (CanonicalLoopInfo *L : Loops)
    L->invalidate();

#ifndef NDEBUG
  for (CanonicalLoopInfo *GenL : Result)
    GenL->assertOK();
#endif
  return Result;
}

/// Attach metadata \p Properties to the basic block described by \p BB. If the
/// basic block already has metadata, the basic block properties are appended.
static void addBasicBlockMetadata(BasicBlock *BB,
                                  ArrayRef<Metadata *> Properties) {
  // Nothing to do if there are no properties to attach.
  if (Properties.empty())
    return;

  LLVMContext &Ctx = BB->getContext();
  SmallVector<Metadata *> NewProperties;
  NewProperties.push_back(nullptr);

  // If the basic block already has metadata, prepend it to the new metadata.
  MDNode *Existing = BB->getTerminator()->getMetadata(LLVMContext::MD_loop);
  if (Existing)
    append_range(NewProperties, drop_begin(Existing->operands(), 1));

  append_range(NewProperties, Properties);
  MDNode *BasicBlockID = MDNode::getDistinct(Ctx, NewProperties);
  BasicBlockID->replaceOperandWith(0, BasicBlockID);

  BB->getTerminator()->setMetadata(LLVMContext::MD_loop, BasicBlockID);
}
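
// The placeholder first operand is replaced with the node itself, producing
// the self-referential form that llvm.loop metadata uses, e.g. (a sketch):
//   br label %latch, !llvm.loop !0
//   !0 = distinct !{!0, !1, !2}   ; !1, !2 are the attached properties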

/// Attach loop metadata \p Properties to the loop described by \p Loop. If the
/// loop already has metadata, the loop properties are appended.
static void addLoopMetadata(CanonicalLoopInfo *Loop,
                            ArrayRef<Metadata *> Properties) {
  assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");

  // Attach metadata to the loop's latch.
  BasicBlock *Latch = Loop->getLatch();
  assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
  addBasicBlockMetadata(Latch, Properties);
}

/// Attach llvm.access.group metadata to the memref instructions of \p Block.
static void addSimdMetadata(BasicBlock *Block, MDNode *AccessGroup,
                            LoopInfo &LI) {
  for (Instruction &I : *Block) {
    if (I.mayReadOrWriteMemory()) {
      // TODO: This instruction may already have an access group from other
      // pragmas, e.g. #pragma clang loop vectorize. Append so that the
      // existing metadata is not overwritten.
      I.setMetadata(LLVMContext::MD_access_group, AccessGroup);
    }
  }
}

void OpenMPIRBuilder::unrollLoopFull(DebugLoc, CanonicalLoopInfo *Loop) {
  LLVMContext &Ctx = Builder.getContext();
  addLoopMetadata(
      Loop, {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
             MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.full"))});
}
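
// After this, the latch terminator carries, roughly (a sketch):
//   !0 = distinct !{!0, !{!"llvm.loop.unroll.enable"},
//                       !{!"llvm.loop.unroll.full"}}
// which instructs the LoopUnrollPass to fully unroll the loop.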

void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
  LLVMContext &Ctx = Builder.getContext();
  addLoopMetadata(
      Loop, {
                MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
            });
}

void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop,
                                      Value *IfCond, ValueToValueMapTy &VMap,
                                      const Twine &NamePrefix) {
  Function *F = CanonicalLoop->getFunction();

  // Define where the if branch should be inserted.
  Instruction *SplitBefore;
  if (Instruction::classof(IfCond)) {
    SplitBefore = dyn_cast<Instruction>(IfCond);
  } else {
    SplitBefore = CanonicalLoop->getPreheader()->getTerminator();
  }

  // TODO: We should not rely on the pass manager. Currently we use it only for
  // getting the llvm::Loop which corresponds to the given CanonicalLoopInfo
  // object. We should have a method which returns all blocks between
  // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter().
  FunctionAnalysisManager FAM;
  FAM.registerPass([]() { return DominatorTreeAnalysis(); });
  FAM.registerPass([]() { return LoopAnalysis(); });
  FAM.registerPass([]() { return PassInstrumentationAnalysis(); });

  // Get the loop which needs to be cloned.
  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(*F, FAM);
  Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());

  // Create additional blocks for the if statement.
  BasicBlock *Head = SplitBefore->getParent();
  Instruction *HeadOldTerm = Head->getTerminator();
  llvm::LLVMContext &C = Head->getContext();
  llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create(
      C, NamePrefix + ".if.then", Head->getParent(), Head->getNextNode());
  llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create(
      C, NamePrefix + ".if.else", Head->getParent(), CanonicalLoop->getExit());

  // Create the if condition branch.
  Builder.SetInsertPoint(HeadOldTerm);
  Instruction *BrInstr =
      Builder.CreateCondBr(IfCond, ThenBlock, /*ifFalse*/ ElseBlock);
  InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()};
  // The then block contains the branch to the OpenMP loop that needs to be
  // vectorized.
  spliceBB(IP, ThenBlock, false);
  ThenBlock->replaceSuccessorsPhiUsesWith(Head, ThenBlock);

  Builder.SetInsertPoint(ElseBlock);

  // Clone the loop for the else branch.
  SmallVector<BasicBlock *, 8> NewBlocks;

  VMap[CanonicalLoop->getPreheader()] = ElseBlock;
  for (BasicBlock *Block : L->getBlocks()) {
    BasicBlock *NewBB = CloneBasicBlock(Block, VMap, "", F);
    NewBB->moveBefore(CanonicalLoop->getExit());
    VMap[Block] = NewBB;
    NewBlocks.push_back(NewBB);
  }
  remapInstructionsInBlocks(NewBlocks, VMap);
  Builder.CreateBr(NewBlocks.front());
}
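
// Conceptually, createIfVersion rewrites (a sketch; block names depend on
// NamePrefix, e.g. "simd" when called from applySimd):
//   head:          br i1 %ifcond, label %simd.if.then, label %simd.if.else
//   simd.if.then:  the original loop, later annotated for vectorization
//   simd.if.else:  a clone of the loop, annotated to stay scalar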

void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
                                MapVector<Value *, Value *> AlignedVars,
                                Value *IfCond, OrderKind Order,
                                ConstantInt *Simdlen, ConstantInt *Safelen) {
  LLVMContext &Ctx = Builder.getContext();

  Function *F = CanonicalLoop->getFunction();

  // TODO: We should not rely on the pass manager. Currently we use it only for
  // getting the llvm::Loop which corresponds to the given CanonicalLoopInfo
  // object. We should have a method which returns all blocks between
  // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter().
  FunctionAnalysisManager FAM;
  FAM.registerPass([]() { return DominatorTreeAnalysis(); });
  FAM.registerPass([]() { return LoopAnalysis(); });
  FAM.registerPass([]() { return PassInstrumentationAnalysis(); });

  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(*F, FAM);

  Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());
  if (AlignedVars.size()) {
    InsertPointTy IP = Builder.saveIP();
    Builder.SetInsertPoint(CanonicalLoop->getPreheader()->getTerminator());
    for (auto &AlignedItem : AlignedVars) {
      Value *AlignedPtr = AlignedItem.first;
      Value *Alignment = AlignedItem.second;
      Builder.CreateAlignmentAssumption(F->getParent()->getDataLayout(),
                                        AlignedPtr, Alignment);
    }
    Builder.restoreIP(IP);
  }

  if (IfCond) {
    ValueToValueMapTy VMap;
    createIfVersion(CanonicalLoop, IfCond, VMap, "simd");
    // Add metadata to the cloned loop that disables vectorization.
    Value *MappedLatch = VMap.lookup(CanonicalLoop->getLatch());
    assert(MappedLatch &&
           "Cannot find value which corresponds to original loop latch");
    assert(isa<BasicBlock>(MappedLatch) &&
           "Cannot cast mapped latch block value to BasicBlock");
    BasicBlock *NewLatchBlock = dyn_cast<BasicBlock>(MappedLatch);
    ConstantAsMetadata *BoolConst =
        ConstantAsMetadata::get(ConstantInt::getFalse(Type::getInt1Ty(Ctx)));
    addBasicBlockMetadata(
        NewLatchBlock,
        {MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"),
                           BoolConst})});
  }

  SmallSet<BasicBlock *, 8> Reachable;

  // Get the basic blocks from the loop in which memref instructions
  // can be found.
  // TODO: Generalize getting all blocks inside a CanonicalLoopInfo,
  // preferably without running any passes.
  for (BasicBlock *Block : L->getBlocks()) {
    if (Block == CanonicalLoop->getCond() ||
        Block == CanonicalLoop->getHeader())
      continue;
    Reachable.insert(Block);
  }

  SmallVector<Metadata *> LoopMDList;

  // In the presence of a finite 'safelen', it may be unsafe to mark all the
  // memory instructions parallel, because loop-carried dependences at a
  // distance of 'safelen' iterations or more are possible.
  // If the clause order(concurrent) is specified then the memory instructions
  // are marked parallel even if 'safelen' is finite.
  if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent)) {
    // Add access group metadata to memory-access instructions.
    MDNode *AccessGroup = MDNode::getDistinct(Ctx, {});
    for (BasicBlock *BB : Reachable)
      addSimdMetadata(BB, AccessGroup, LI);
    // TODO: If the loop has existing parallel access metadata, we have to
    // combine the two lists.
    LoopMDList.push_back(MDNode::get(
        Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"), AccessGroup}));
  }

  // Use the above access group metadata to create loop level
  // metadata, which should be distinct for each loop.
  ConstantAsMetadata *BoolConst =
      ConstantAsMetadata::get(ConstantInt::getTrue(Type::getInt1Ty(Ctx)));
  LoopMDList.push_back(MDNode::get(
      Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"), BoolConst}));

  if (Simdlen || Safelen) {
    // If both simdlen and safelen clauses are specified, the value of the
    // simdlen parameter must be less than or equal to the value of the safelen
    // parameter. Therefore, use safelen only in the absence of simdlen.
    ConstantInt *VectorizeWidth = Simdlen == nullptr ? Safelen : Simdlen;
    LoopMDList.push_back(
        MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.width"),
                          ConstantAsMetadata::get(VectorizeWidth)}));
  }
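
  // For 'simd simdlen(8)' with no safelen clause (a sketch), the list
  // attached below amounts to:
  //   !{!"llvm.loop.parallel_accesses", !<access group>}
  //   !{!"llvm.loop.vectorize.enable", i1 true}
  //   !{!"llvm.loop.vectorize.width", i32 8}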

  addLoopMetadata(CanonicalLoop, LoopMDList);
}

/// Create the TargetMachine object to query the backend for optimization
/// preferences.
///
/// Ideally, this would be passed from the front-end to the OpenMPBuilder, but
/// e.g. Clang does not pass it to its CodeGen layer and creates it only when
/// needed for the LLVM pass pipeline. We use some default options to avoid
/// having to pass too many settings from the frontend that probably do not
/// matter.
///
/// Currently, TargetMachine is only used sometimes by the unrollLoopPartial
/// method. If we are going to use TargetMachine for more purposes, especially
/// those that are sensitive to TargetOptions, RelocModel and CodeModel, it
/// might become worth requiring front-ends to pass on their TargetMachine, or
/// at least cache it between methods. Note that while front-ends such as Clang
/// have just a single main TargetMachine per translation unit, "target-cpu"
/// and "target-features" that determine the TargetMachine are per-function and
/// can be overridden using __attribute__((target("OPTIONS"))).
static std::unique_ptr<TargetMachine>
createTargetMachine(Function *F, CodeGenOpt::Level OptLevel) {
  Module *M = F->getParent();

  StringRef CPU = F->getFnAttribute("target-cpu").getValueAsString();
  StringRef Features = F->getFnAttribute("target-features").getValueAsString();
  const std::string &Triple = M->getTargetTriple();

  std::string Error;
  const llvm::Target *TheTarget = TargetRegistry::lookupTarget(Triple, Error);
  if (!TheTarget)
    return {};

  llvm::TargetOptions Options;
  return std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
      Triple, CPU, Features, Options, /*RelocModel=*/std::nullopt,
      /*CodeModel=*/std::nullopt, OptLevel));
}

/// Heuristically determine the best-performing unroll factor for \p CLI. This
/// depends on the target processor. We are re-using the same heuristics as the
/// LoopUnrollPass.
static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
  Function *F = CLI->getFunction();

  // Assume the user requests the most aggressive unrolling, even if the rest
  // of the code is optimized using a lower setting.
  CodeGenOpt::Level OptLevel = CodeGenOpt::Aggressive;
  std::unique_ptr<TargetMachine> TM = createTargetMachine(F, OptLevel);

  FunctionAnalysisManager FAM;
  FAM.registerPass([]() { return TargetLibraryAnalysis(); });
  FAM.registerPass([]() { return AssumptionAnalysis(); });
  FAM.registerPass([]() { return DominatorTreeAnalysis(); });
  FAM.registerPass([]() { return LoopAnalysis(); });
  FAM.registerPass([]() { return ScalarEvolutionAnalysis(); });
  FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
  TargetIRAnalysis TIRA;
  if (TM)
    TIRA = TargetIRAnalysis(
        [&](const Function &F) { return TM->getTargetTransformInfo(F); });
  FAM.registerPass([&]() { return TIRA; });

  TargetIRAnalysis::Result &&TTI = TIRA.run(*F, FAM);
  ScalarEvolutionAnalysis SEA;
  ScalarEvolution &&SE = SEA.run(*F, FAM);
  DominatorTreeAnalysis DTA;
  DominatorTree &&DT = DTA.run(*F, FAM);
  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(*F, FAM);
  AssumptionAnalysis ACT;
  AssumptionCache &&AC = ACT.run(*F, FAM);
  OptimizationRemarkEmitter ORE{F};

  Loop *L = LI.getLoopFor(CLI->getHeader());
  assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");

  TargetTransformInfo::UnrollingPreferences UP =
      gatherUnrollingPreferences(L, SE, TTI,
                                 /*BlockFrequencyInfo=*/nullptr,
                                 /*ProfileSummaryInfo=*/nullptr, ORE, OptLevel,
                                 /*UserThreshold=*/std::nullopt,
                                 /*UserCount=*/std::nullopt,
                                 /*UserAllowPartial=*/true,
                                 /*UserAllowRuntime=*/true,
                                 /*UserUpperBound=*/std::nullopt,
                                 /*UserFullUnrollMaxCount=*/std::nullopt);

  UP.Force = true;

  // Account for additional optimizations taking place before the
  // LoopUnrollPass would unroll the loop.
  UP.Threshold *= UnrollThresholdFactor;
  UP.PartialThreshold *= UnrollThresholdFactor;
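
  // For example (a sketch; the base values actually come from
  // gatherUnrollingPreferences): assuming an aggressive default threshold of
  // 300 and the default -openmp-ir-builder-unroll-threshold-factor of 1.5,
  // the effective Threshold becomes 450.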

  // Use normal unroll factors even if the rest of the code is optimized for
  // size.
  UP.OptSizeThreshold = UP.Threshold;
  UP.PartialOptSizeThreshold = UP.PartialThreshold;

  LLVM_DEBUG(dbgs() << "Unroll heuristic thresholds:\n"
                    << "  Threshold=" << UP.Threshold << "\n"
                    << "  PartialThreshold=" << UP.PartialThreshold << "\n"
                    << "  OptSizeThreshold=" << UP.OptSizeThreshold << "\n"
                    << "  PartialOptSizeThreshold="
                    << UP.PartialOptSizeThreshold << "\n");

  // Disable peeling.
  TargetTransformInfo::PeelingPreferences PP =
      gatherPeelingPreferences(L, SE, TTI,
                               /*UserAllowPeeling=*/false,
                               /*UserAllowProfileBasedPeeling=*/false,
                               /*UnrollingSpecficValues=*/false);

  SmallPtrSet<const Value *, 32> EphValues;
  CodeMetrics::collectEphemeralValues(L, &AC, EphValues);

  // Assume that reads and writes to stack variables can be eliminated by
  // Mem2Reg, SROA or LICM. That is, don't count them towards the loop body's
  // size.
  for (BasicBlock *BB : L->blocks()) {
    for (Instruction &I : *BB) {
      Value *Ptr;
      if (auto *Load = dyn_cast<LoadInst>(&I)) {
        Ptr = Load->getPointerOperand();
      } else if (auto *Store = dyn_cast<StoreInst>(&I)) {
        Ptr = Store->getPointerOperand();
      } else
        continue;

      Ptr = Ptr->stripPointerCasts();

      if (auto *Alloca = dyn_cast<AllocaInst>(Ptr)) {
        if (Alloca->getParent() == &F->getEntryBlock())
          EphValues.insert(&I);
      }
    }
  }

  unsigned NumInlineCandidates;
  bool NotDuplicatable;
  bool Convergent;
  InstructionCost LoopSizeIC =
      ApproximateLoopSize(L, NumInlineCandidates, NotDuplicatable, Convergent,
                          TTI, EphValues, UP.BEInsns);
  LLVM_DEBUG(dbgs() << "Estimated loop size is " << LoopSizeIC << "\n");

  // The loop is not unrollable if it contains certain instructions.
  if (NotDuplicatable || Convergent || !LoopSizeIC.isValid()) {
    LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
    return 1;
  }
  unsigned LoopSize = *LoopSizeIC.getValue();

  // TODO: Determine the trip count of \p CLI if constant, computeUnrollCount
  // might be able to use it.
  int TripCount = 0;
  int MaxTripCount = 0;
  bool MaxOrZero = false;
  unsigned TripMultiple = 0;

  bool UseUpperBound = false;
  computeUnrollCount(L, TTI, DT, &LI, &AC, SE, EphValues, &ORE, TripCount,
                     MaxTripCount, MaxOrZero, TripMultiple, LoopSize, UP, PP,
                     UseUpperBound);
  unsigned Factor = UP.Count;
  LLVM_DEBUG(dbgs() << "Suggesting unroll factor of " << Factor << "\n");

  // This function returns 1 to signal that the loop should not be unrolled.
  if (Factor == 0)
    return 1;
  return Factor;
}

void OpenMPIRBuilder::unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop,
                                        int32_t Factor,
                                        CanonicalLoopInfo **UnrolledCLI) {
  assert(Factor >= 0 && "Unroll factor must not be negative");

  Function *F = Loop->getFunction();
  LLVMContext &Ctx = F->getContext();

  // If the unrolled loop is not used for another loop-associated directive, it
  // is sufficient to add metadata for the LoopUnrollPass.
  if (!UnrolledCLI) {
    SmallVector<Metadata *, 2> LoopMetadata;
    LoopMetadata.push_back(
        MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")));

    if (Factor >= 1) {
      ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
          ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
      LoopMetadata.push_back(MDNode::get(
          Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst}));
    }

    addLoopMetadata(Loop, LoopMetadata);
    return;
  }

  // Heuristically determine the unroll factor.
  if (Factor == 0)
    Factor = computeHeuristicUnrollFactor(Loop);

  // No change required with unroll factor 1.
  if (Factor == 1) {
    *UnrolledCLI = Loop;
    return;
  }

  assert(Factor >= 2 &&
         "unrolling only makes sense with a factor of 2 or larger");

  Type *IndVarTy = Loop->getIndVarType();

  // Apply partial unrolling by tiling the loop by the unroll-factor, then fully
  // unroll the inner loop.
  Value *FactorVal =
      ConstantInt::get(IndVarTy, APInt(IndVarTy->getIntegerBitWidth(), Factor,
                                       /*isSigned=*/false));
  std::vector<CanonicalLoopInfo *> LoopNest =
      tileLoops(DL, {Loop}, {FactorVal});
  assert(LoopNest.size() == 2 && "Expect 2 loops after tiling");
  *UnrolledCLI = LoopNest[0];
  CanonicalLoopInfo *InnerLoop = LoopNest[1];

  // LoopUnrollPass can only fully unroll loops with constant trip count.
  // Unroll by the unroll factor with a fallback epilog for the remainder
  // iterations if necessary.
  ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
      ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
  addLoopMetadata(
      InnerLoop,
      {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
       MDNode::get(
           Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst})});
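
  // For Factor == 4 (a sketch), the single loop is turned into a floor/tile
  // nest roughly equivalent to
  //   for (f = 0; f < ceil(n / 4); ++f)      // stays a canonical loop
  //     for (t = 0; t < tile_count(f); ++t)  // annotated to unroll by 4
  //       body(4 * f + t);
  // where tile_count(f) is 4 for full tiles and the remainder for a partial
  // last tile.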

#ifndef NDEBUG
  (*UnrolledCLI)->assertOK();
#endif
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
                                   llvm::Value *BufSize, llvm::Value *CpyBuf,
                                   llvm::Value *CpyFn, llvm::Value *DidIt) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);

  llvm::Value *DidItLD = Builder.CreateLoad(Builder.getInt32Ty(), DidIt);

  Value *Args[] = {Ident, ThreadId, BufSize, CpyBuf, CpyFn, DidItLD};

  Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_copyprivate);
  Builder.CreateCall(Fn, Args);

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSingle(
    const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
    FinalizeCallbackTy FiniCB, bool IsNowait, llvm::Value *DidIt) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  // If needed (i.e. not null), initialize `DidIt` with 0.
  if (DidIt) {
    Builder.CreateStore(Builder.getInt32(0), DidIt);
  }

  Directive OMPD = Directive::OMPD_single;
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {Ident, ThreadId};

  Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_single);
  Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);

  Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_single);
  Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);

  // Generates the following:
  //   if (__kmpc_single()) {
  //     ... single region ...
  //     __kmpc_end_single();
  //   }
  //   __kmpc_barrier();  // unless 'nowait' is specified

  EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
                       /*Conditional*/ true,
                       /*hasFinalize*/ true);
  if (!IsNowait)
    createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                  omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
                  /* CheckCancelFlag */ false);
  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCritical(
    const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
    FinalizeCallbackTy FiniCB, StringRef CriticalName, Value *HintInst) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  Directive OMPD = Directive::OMPD_critical;
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *LockVar = getOMPCriticalRegionLock(CriticalName);
  Value *Args[] = {Ident, ThreadId, LockVar};

  SmallVector<llvm::Value *, 4> EnterArgs(std::begin(Args), std::end(Args));
  Function *RTFn = nullptr;
  if (HintInst) {
    // Add the hint to the entry args and create the call.
    EnterArgs.push_back(HintInst);
    RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical_with_hint);
  } else {
    RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical);
  }
  Instruction *EntryCall = Builder.CreateCall(RTFn, EnterArgs);

  Function *ExitRTLFn =
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_critical);
  Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);

  return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
                              /*Conditional*/ false, /*hasFinalize*/ true);
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createOrderedDepend(const LocationDescription &Loc,
                                     InsertPointTy AllocaIP, unsigned NumLoops,
                                     ArrayRef<llvm::Value *> StoreValues,
                                     const Twine &Name, bool IsDependSource) {
  assert(
      llvm::all_of(StoreValues,
                   [](Value *SV) { return SV->getType()->isIntegerTy(64); }) &&
      "OpenMP runtime requires depend vec with i64 type");

  if (!updateToLocation(Loc))
    return Loc.IP;

  // Allocate space for the vector and generate the alloca instruction.
  auto *ArrI64Ty = ArrayType::get(Int64, NumLoops);
  Builder.restoreIP(AllocaIP);
  AllocaInst *ArgsBase = Builder.CreateAlloca(ArrI64Ty, nullptr, Name);
  ArgsBase->setAlignment(Align(8));
  Builder.restoreIP(Loc.IP);

  // Store the index values with offsets into the depend vector.
  for (unsigned I = 0; I < NumLoops; ++I) {
    Value *DependAddrGEPIter = Builder.CreateInBoundsGEP(
        ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(I)});
    StoreInst *STInst = Builder.CreateStore(StoreValues[I], DependAddrGEPIter);
    STInst->setAlignment(Align(8));
  }

  Value *DependBaseAddrGEP = Builder.CreateInBoundsGEP(
      ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(0)});

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {Ident, ThreadId, DependBaseAddrGEP};

  Function *RTLFn = nullptr;
  if (IsDependSource)
    RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_post);
  else
    RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_wait);
  Builder.CreateCall(RTLFn, Args);

  return Builder.saveIP();
}
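
// For 'ordered depend(source)' in a two-dimensional doacross loop nest
// (a sketch with illustrative value names), this emits roughly:
//   %vec = alloca [2 x i64], align 8
//   store i64 %iv0, ...
//   store i64 %iv1, ...
//   call void @__kmpc_doacross_post(ptr %loc, i32 %tid, ptr %vec)
// 'depend(sink: ...)' emits the same pattern but calls
// @__kmpc_doacross_wait instead.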

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createOrderedThreadsSimd(
    const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
    FinalizeCallbackTy FiniCB, bool IsThreads) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  Directive OMPD = Directive::OMPD_ordered;
  Instruction *EntryCall = nullptr;
  Instruction *ExitCall = nullptr;

  if (IsThreads) {
    uint32_t SrcLocStrSize;
    Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
    Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
    Value *ThreadId = getOrCreateThreadID(Ident);
    Value *Args[] = {Ident, ThreadId};

    Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_ordered);
    EntryCall = Builder.CreateCall(EntryRTLFn, Args);

    Function *ExitRTLFn =
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_ordered);
    ExitCall = Builder.CreateCall(ExitRTLFn, Args);
  }

  return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
                              /*Conditional*/ false, /*hasFinalize*/ true);
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::EmitOMPInlinedRegion(
    Directive OMPD, Instruction *EntryCall, Instruction *ExitCall,
    BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, bool Conditional,
    bool HasFinalize, bool IsCancellable) {

  if (HasFinalize)
    FinalizationStack.push_back({FiniCB, OMPD, IsCancellable});

  // Create the inlined region's entry and body blocks, in preparation
  // for conditional creation.
  BasicBlock *EntryBB = Builder.GetInsertBlock();
  Instruction *SplitPos = EntryBB->getTerminator();
  if (!isa_and_nonnull<BranchInst>(SplitPos))
    SplitPos = new UnreachableInst(Builder.getContext(), EntryBB);
  BasicBlock *ExitBB = EntryBB->splitBasicBlock(SplitPos, "omp_region.end");
  BasicBlock *FiniBB =
      EntryBB->splitBasicBlock(EntryBB->getTerminator(), "omp_region.finalize");
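
  // After the two splits (a sketch): EntryBB branches to FiniBB
  // ("omp_region.finalize"), which branches to ExitBB ("omp_region.end").
  // The directive body is generated at the end of EntryBB and the
  // finalization code into FiniBB.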

  Builder.SetInsertPoint(EntryBB->getTerminator());
  emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, Conditional);

  // Generate the body.
  BodyGenCB(/* AllocaIP */ InsertPointTy(),
            /* CodeGenIP */ Builder.saveIP());

  // Emit the exit call and do any needed finalization.
  auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt());
  assert(FiniBB->getTerminator()->getNumSuccessors() == 1 &&
         FiniBB->getTerminator()->getSuccessor(0) == ExitBB &&
         "Unexpected control flow graph state!!");
  emitCommonDirectiveExit(OMPD, FinIP, ExitCall, HasFinalize);
  assert(FiniBB->getUniquePredecessor()->getUniqueSuccessor() == FiniBB &&
         "Unexpected Control Flow State!");
  MergeBlockIntoPredecessor(FiniBB);

  // If we are skipping the region of a non-conditional, remove the exit
  // block, and clear the builder's insertion point.
  assert(SplitPos->getParent() == ExitBB &&
         "Unexpected Insertion point location!");
  auto merged = MergeBlockIntoPredecessor(ExitBB);
  BasicBlock *ExitPredBB = SplitPos->getParent();
  auto InsertBB = merged ? ExitPredBB : ExitBB;
  if (!isa_and_nonnull<BranchInst>(SplitPos))
    SplitPos->eraseFromParent();
  Builder.SetInsertPoint(InsertBB);

  return Builder.saveIP();
}
3619 
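// Explanatory note (added): when 'Conditional' is set, the result of the
// runtime entry call guards the region body, i.e. the emitted IR behaves like
// 'if (EntryCall != 0) { body } else goto ExitBB;'. Otherwise the entry call
// (if any) is emitted straight-line and the insertion point is unchanged.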
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry(
    Directive OMPD, Value *EntryCall, BasicBlock *ExitBB, bool Conditional) {
  // If there is nothing to do, return the current insertion point.
  if (!Conditional || !EntryCall)
    return Builder.saveIP();

  BasicBlock *EntryBB = Builder.GetInsertBlock();
  Value *CallBool = Builder.CreateIsNotNull(EntryCall);
  auto *ThenBB = BasicBlock::Create(M.getContext(), "omp_region.body");
  auto *UI = new UnreachableInst(Builder.getContext(), ThenBB);

  // Emit ThenBB and set the Builder's insertion point there for
  // body generation next. Place the block after the current block.
  Function *CurFn = EntryBB->getParent();
  CurFn->insert(std::next(EntryBB->getIterator()), ThenBB);

  // Move the entry branch to the end of ThenBB, and replace it with a
  // conditional branch (if-stmt).
  Instruction *EntryBBTI = EntryBB->getTerminator();
  Builder.CreateCondBr(CallBool, ThenBB, ExitBB);
  EntryBBTI->removeFromParent();
  Builder.SetInsertPoint(UI);
  Builder.Insert(EntryBBTI);
  UI->eraseFromParent();
  Builder.SetInsertPoint(ThenBB->getTerminator());

  // Return an insertion point to ExitBB.
  return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt());
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveExit(
    omp::Directive OMPD, InsertPointTy FinIP, Instruction *ExitCall,
    bool HasFinalize) {

  Builder.restoreIP(FinIP);

  // If there is finalization to do, emit it before the exit call.
  if (HasFinalize) {
    assert(!FinalizationStack.empty() &&
           "Unexpected finalization stack state!");

    FinalizationInfo Fi = FinalizationStack.pop_back_val();
    assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");

    Fi.FiniCB(FinIP);

    BasicBlock *FiniBB = FinIP.getBlock();
    Instruction *FiniBBTI = FiniBB->getTerminator();

    // Set the Builder's IP for call creation.
    Builder.SetInsertPoint(FiniBBTI);
  }

  if (!ExitCall)
    return Builder.saveIP();

  // Place the ExitCall as the last instruction before the finalization block
  // terminator.
  ExitCall->removeFromParent();
  Builder.Insert(ExitCall);

  return IRBuilder<>::InsertPoint(ExitCall->getParent(),
                                  ExitCall->getIterator());
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCopyinClauseBlocks(
    InsertPointTy IP, Value *MasterAddr, Value *PrivateAddr,
    llvm::IntegerType *IntPtrTy, bool BranchtoEnd) {
  if (!IP.isSet())
    return IP;

  IRBuilder<>::InsertPointGuard IPG(Builder);

  // Creates the following CFG structure:
  //    OMP_Entry : (MasterAddr != PrivateAddr)?
  //       F     T
  //       |      \
  //       |     copyin.not.master
  //       |      /
  //       v     /
  //   copyin.not.master.end
  //         |
  //         v
  //   OMP.Entry.Next

  BasicBlock *OMP_Entry = IP.getBlock();
  Function *CurFn = OMP_Entry->getParent();
  BasicBlock *CopyBegin =
      BasicBlock::Create(M.getContext(), "copyin.not.master", CurFn);
  BasicBlock *CopyEnd = nullptr;

  // If the entry block is terminated, split to preserve the branch to the
  // following basic block (i.e. OMP.Entry.Next); otherwise, leave everything
  // as is.
  if (isa_and_nonnull<BranchInst>(OMP_Entry->getTerminator())) {
    CopyEnd = OMP_Entry->splitBasicBlock(OMP_Entry->getTerminator(),
                                         "copyin.not.master.end");
    OMP_Entry->getTerminator()->eraseFromParent();
  } else {
    CopyEnd =
        BasicBlock::Create(M.getContext(), "copyin.not.master.end", CurFn);
  }

  Builder.SetInsertPoint(OMP_Entry);
  Value *MasterPtr = Builder.CreatePtrToInt(MasterAddr, IntPtrTy);
  Value *PrivatePtr = Builder.CreatePtrToInt(PrivateAddr, IntPtrTy);
  Value *cmp = Builder.CreateICmpNE(MasterPtr, PrivatePtr);
  Builder.CreateCondBr(cmp, CopyBegin, CopyEnd);

  Builder.SetInsertPoint(CopyBegin);
  if (BranchtoEnd)
    Builder.SetInsertPoint(Builder.CreateBr(CopyEnd));

  return Builder.saveIP();
}

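// Usage sketch (illustrative only; 'OMPBuilder', 'Loc', 'Size', and
// 'Allocator' are assumed to be provided by the caller):
//   CallInst *Ptr = OMPBuilder.createOMPAlloc(Loc, Size, Allocator, "buf");
//   ... use Ptr ...
//   OMPBuilder.createOMPFree(Loc, Ptr, Allocator, "free");
// Both helpers thread the source-location ident and thread id through to the
// __kmpc_alloc/__kmpc_free runtime entry points shown below.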
CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc,
                                          Value *Size, Value *Allocator,
                                          std::string Name) {
  IRBuilder<>::InsertPointGuard IPG(Builder);
  Builder.restoreIP(Loc.IP);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {ThreadId, Size, Allocator};

  Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc);

  return Builder.CreateCall(Fn, Args, Name);
}

CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
                                         Value *Addr, Value *Allocator,
                                         std::string Name) {
  IRBuilder<>::InsertPointGuard IPG(Builder);
  Builder.restoreIP(Loc.IP);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {ThreadId, Addr, Allocator};
  Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free);
  return Builder.CreateCall(Fn, Args, Name);
}

CallInst *OpenMPIRBuilder::createOMPInteropInit(
    const LocationDescription &Loc, Value *InteropVar,
    omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
    Value *DependenceAddress, bool HaveNowaitClause) {
  IRBuilder<>::InsertPointGuard IPG(Builder);
  Builder.restoreIP(Loc.IP);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  if (Device == nullptr)
    Device = ConstantInt::get(Int32, -1);
  Constant *InteropTypeVal = ConstantInt::get(Int32, (int)InteropType);
  if (NumDependences == nullptr) {
    NumDependences = ConstantInt::get(Int64, 0);
    PointerType *PointerTypeVar = Type::getInt8PtrTy(M.getContext());
    DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
  }
  Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
  Value *Args[] = {
      Ident,  ThreadId,       InteropVar,        InteropTypeVal,
      Device, NumDependences, DependenceAddress, HaveNowaitClauseVal};

  Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_init);

  return Builder.CreateCall(Fn, Args);
}

CallInst *OpenMPIRBuilder::createOMPInteropDestroy(
    const LocationDescription &Loc, Value *InteropVar, Value *Device,
    Value *NumDependences, Value *DependenceAddress, bool HaveNowaitClause) {
  IRBuilder<>::InsertPointGuard IPG(Builder);
  Builder.restoreIP(Loc.IP);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  if (Device == nullptr)
    Device = ConstantInt::get(Int32, -1);
  if (NumDependences == nullptr) {
    NumDependences = ConstantInt::get(Int32, 0);
    PointerType *PointerTypeVar = Type::getInt8PtrTy(M.getContext());
    DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
  }
  Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
  Value *Args[] = {
      Ident,          ThreadId,          InteropVar,         Device,
      NumDependences, DependenceAddress, HaveNowaitClauseVal};

  Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_destroy);

  return Builder.CreateCall(Fn, Args);
}

CallInst *OpenMPIRBuilder::createOMPInteropUse(const LocationDescription &Loc,
                                               Value *InteropVar, Value *Device,
                                               Value *NumDependences,
                                               Value *DependenceAddress,
                                               bool HaveNowaitClause) {
  IRBuilder<>::InsertPointGuard IPG(Builder);
  Builder.restoreIP(Loc.IP);
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  if (Device == nullptr)
    Device = ConstantInt::get(Int32, -1);
  if (NumDependences == nullptr) {
    NumDependences = ConstantInt::get(Int32, 0);
    PointerType *PointerTypeVar = Type::getInt8PtrTy(M.getContext());
    DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
  }
  Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
  Value *Args[] = {
      Ident,          ThreadId,          InteropVar,         Device,
      NumDependences, DependenceAddress, HaveNowaitClauseVal};

  Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_use);

  return Builder.CreateCall(Fn, Args);
}

CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
    const LocationDescription &Loc, llvm::Value *Pointer,
    llvm::ConstantInt *Size, const llvm::Twine &Name) {
  IRBuilder<>::InsertPointGuard IPG(Builder);
  Builder.restoreIP(Loc.IP);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Constant *ThreadPrivateCache =
      getOrCreateInternalVariable(Int8PtrPtr, Name.str());
  llvm::Value *Args[] = {Ident, ThreadId, Pointer, Size, ThreadPrivateCache};

  Function *Fn =
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_threadprivate_cached);

  return Builder.CreateCall(Fn, Args);
}

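// Explanatory note (added): createTargetInit emits the device-side kernel
// prologue. __kmpc_target_init returns -1 for a thread that should run the
// user code; the guard emitted below sends every other thread straight to the
// 'worker.exit' return block. In SPMD mode no generic state machine is
// requested, so all threads are expected to take the user-code path.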
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  ConstantInt *IsSPMDVal = ConstantInt::getSigned(
      IntegerType::getInt8Ty(Int8->getContext()),
      IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
  ConstantInt *UseGenericStateMachine =
      ConstantInt::getBool(Int32->getContext(), !IsSPMD);

  Function *Fn = getOrCreateRuntimeFunctionPtr(
      omp::RuntimeFunction::OMPRTL___kmpc_target_init);

  CallInst *ThreadKind = Builder.CreateCall(
      Fn, {Ident, IsSPMDVal, UseGenericStateMachine});

  Value *ExecUserCode = Builder.CreateICmpEQ(
      ThreadKind, ConstantInt::get(ThreadKind->getType(), -1),
      "exec_user_code");

  // ThreadKind = __kmpc_target_init(...)
  // if (ThreadKind == -1)
  //   user_code
  // else
  //   return;

  auto *UI = Builder.CreateUnreachable();
  BasicBlock *CheckBB = UI->getParent();
  BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(UI, "user_code.entry");

  BasicBlock *WorkerExitBB = BasicBlock::Create(
      CheckBB->getContext(), "worker.exit", CheckBB->getParent());
  Builder.SetInsertPoint(WorkerExitBB);
  Builder.CreateRetVoid();

  auto *CheckBBTI = CheckBB->getTerminator();
  Builder.SetInsertPoint(CheckBBTI);
  Builder.CreateCondBr(ExecUserCode, UI->getParent(), WorkerExitBB);

  CheckBBTI->eraseFromParent();
  UI->eraseFromParent();

  // Continue in the "user_code" block, see diagram above and in
  // openmp/libomptarget/deviceRTLs/common/include/target.h .
  return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
}

void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
                                         bool IsSPMD) {
  if (!updateToLocation(Loc))
    return;

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  ConstantInt *IsSPMDVal = ConstantInt::getSigned(
      IntegerType::getInt8Ty(Int8->getContext()),
      IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);

  Function *Fn = getOrCreateRuntimeFunctionPtr(
      omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);

  Builder.CreateCall(Fn, {Ident, IsSPMDVal});
}

void OpenMPIRBuilder::setOutlinedTargetRegionFunctionAttributes(
    Function *OutlinedFn, int32_t NumTeams, int32_t NumThreads) {
  if (Config.isEmbedded()) {
    OutlinedFn->setLinkage(GlobalValue::WeakODRLinkage);
    // TODO: Determine if DSO local can be set to true.
    OutlinedFn->setDSOLocal(false);
    OutlinedFn->setVisibility(GlobalValue::ProtectedVisibility);
    if (Triple(M.getTargetTriple()).isAMDGCN())
      OutlinedFn->setCallingConv(CallingConv::AMDGPU_KERNEL);
  }

  if (NumTeams > 0)
    OutlinedFn->addFnAttr("omp_target_num_teams", std::to_string(NumTeams));
  if (NumThreads > 0)
    OutlinedFn->addFnAttr("omp_target_thread_limit",
                          std::to_string(NumThreads));
}

Constant *OpenMPIRBuilder::createOutlinedFunctionID(Function *OutlinedFn,
                                                    StringRef EntryFnIDName) {
  if (Config.isEmbedded()) {
    assert(OutlinedFn && "The outlined function must exist if embedded");
    return ConstantExpr::getBitCast(OutlinedFn, Builder.getInt8PtrTy());
  }

  return new GlobalVariable(
      M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
      Constant::getNullValue(Builder.getInt8Ty()), EntryFnIDName);
}

Constant *OpenMPIRBuilder::createTargetRegionEntryAddr(Function *OutlinedFn,
                                                       StringRef EntryFnName) {
  if (OutlinedFn)
    return OutlinedFn;

  assert(!M.getGlobalVariable(EntryFnName, true) &&
         "Named kernel already exists?");
  return new GlobalVariable(
      M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::InternalLinkage,
      Constant::getNullValue(Builder.getInt8Ty()), EntryFnName);
}

void OpenMPIRBuilder::emitTargetRegionFunction(
    OffloadEntriesInfoManager &InfoManager, TargetRegionEntryInfo &EntryInfo,
    FunctionGenCallback &GenerateFunctionCallback, int32_t NumTeams,
    int32_t NumThreads, bool IsOffloadEntry, Function *&OutlinedFn,
    Constant *&OutlinedFnID) {

  SmallString<64> EntryFnName;
  InfoManager.getTargetRegionEntryFnName(EntryFnName, EntryInfo);

  OutlinedFn = Config.isEmbedded() || !Config.openMPOffloadMandatory()
                   ? GenerateFunctionCallback(EntryFnName)
                   : nullptr;

  // If this target outline function is not an offload entry, we don't need to
  // register it. This may be the case for a false 'if' clause, or if there
  // are no OpenMP targets.
  if (!IsOffloadEntry)
    return;

  std::string EntryFnIDName =
      Config.isEmbedded()
          ? std::string(EntryFnName)
          : createPlatformSpecificName({EntryFnName, "region_id"});

  OutlinedFnID = registerTargetRegionFunction(
      InfoManager, EntryInfo, OutlinedFn, EntryFnName, EntryFnIDName, NumTeams,
      NumThreads);
}

Constant *OpenMPIRBuilder::registerTargetRegionFunction(
    OffloadEntriesInfoManager &InfoManager, TargetRegionEntryInfo &EntryInfo,
    Function *OutlinedFn, StringRef EntryFnName, StringRef EntryFnIDName,
    int32_t NumTeams, int32_t NumThreads) {
  if (OutlinedFn)
    setOutlinedTargetRegionFunctionAttributes(OutlinedFn, NumTeams, NumThreads);
  auto OutlinedFnID = createOutlinedFunctionID(OutlinedFn, EntryFnIDName);
  auto EntryAddr = createTargetRegionEntryAddr(OutlinedFn, EntryFnName);
  InfoManager.registerTargetRegionEntryInfo(
      EntryInfo, EntryAddr, OutlinedFnID,
      OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion);
  return OutlinedFnID;
}

std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
                                                   StringRef FirstSeparator,
                                                   StringRef Separator) {
  SmallString<128> Buffer;
  llvm::raw_svector_ostream OS(Buffer);
  StringRef Sep = FirstSeparator;
  for (StringRef Part : Parts) {
    OS << Sep << Part;
    Sep = Separator;
  }
  return OS.str().str();
}

std::string
OpenMPIRBuilder::createPlatformSpecificName(ArrayRef<StringRef> Parts) const {
  return OpenMPIRBuilder::getNameWithSeparators(Parts, Config.firstSeparator(),
                                                Config.separator());
}

GlobalVariable *
OpenMPIRBuilder::getOrCreateInternalVariable(Type *Ty, const StringRef &Name,
                                             unsigned AddressSpace) {
  auto &Elem = *InternalVars.try_emplace(Name, nullptr).first;
  if (Elem.second) {
    assert(cast<PointerType>(Elem.second->getType())
               ->isOpaqueOrPointeeTypeMatches(Ty) &&
           "OMP internal variable has different type than requested");
  } else {
    // TODO: investigate the appropriate linkage type used for the global
    // variable for possibly changing that to internal or private, or maybe
    // create different versions of the function for different OMP internal
    // variables.
    Elem.second = new GlobalVariable(
        M, Ty, /*IsConstant=*/false, GlobalValue::CommonLinkage,
        Constant::getNullValue(Ty), Elem.first(),
        /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal, AddressSpace);
  }

  return cast<GlobalVariable>(&*Elem.second);
}

Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) {
  std::string Prefix = Twine("gomp_critical_user_", CriticalName).str();
  std::string Name = getNameWithSeparators({Prefix, "var"}, ".", ".");
  return getOrCreateInternalVariable(KmpCriticalNameTy, Name);
}

GlobalVariable *
OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
                                       std::string VarName) {
  llvm::Constant *MaptypesArrayInit =
      llvm::ConstantDataArray::get(M.getContext(), Mappings);
  auto *MaptypesArrayGlobal = new llvm::GlobalVariable(
      M, MaptypesArrayInit->getType(),
      /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MaptypesArrayInit,
      VarName);
  MaptypesArrayGlobal->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
  return MaptypesArrayGlobal;
}

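// Explanatory note (added): createMapperAllocas reserves, at the given alloca
// insertion point, the three per-call argument arrays used by the offloading
// runtime (base pointers, pointers, and sizes). emitMapperCall then GEPs to
// the first element of each array and passes those addresses to the mapper
// runtime function, together with the map-types and map-names arguments.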
void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc,
                                          InsertPointTy AllocaIP,
                                          unsigned NumOperands,
                                          struct MapperAllocas &MapperAllocas) {
  if (!updateToLocation(Loc))
    return;

  auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
  auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
  Builder.restoreIP(AllocaIP);
  AllocaInst *ArgsBase = Builder.CreateAlloca(ArrI8PtrTy);
  AllocaInst *Args = Builder.CreateAlloca(ArrI8PtrTy);
  AllocaInst *ArgSizes = Builder.CreateAlloca(ArrI64Ty);
  Builder.restoreIP(Loc.IP);
  MapperAllocas.ArgsBase = ArgsBase;
  MapperAllocas.Args = Args;
  MapperAllocas.ArgSizes = ArgSizes;
}

void OpenMPIRBuilder::emitMapperCall(const LocationDescription &Loc,
                                     Function *MapperFunc, Value *SrcLocInfo,
                                     Value *MaptypesArg, Value *MapnamesArg,
                                     struct MapperAllocas &MapperAllocas,
                                     int64_t DeviceID, unsigned NumOperands) {
  if (!updateToLocation(Loc))
    return;

  auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
  auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
  Value *ArgsBaseGEP =
      Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.ArgsBase,
                                {Builder.getInt32(0), Builder.getInt32(0)});
  Value *ArgsGEP =
      Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.Args,
                                {Builder.getInt32(0), Builder.getInt32(0)});
  Value *ArgSizesGEP =
      Builder.CreateInBoundsGEP(ArrI64Ty, MapperAllocas.ArgSizes,
                                {Builder.getInt32(0), Builder.getInt32(0)});
  Value *NullPtr = Constant::getNullValue(Int8Ptr->getPointerTo());
  Builder.CreateCall(MapperFunc,
                     {SrcLocInfo, Builder.getInt64(DeviceID),
                      Builder.getInt32(NumOperands), ArgsBaseGEP, ArgsGEP,
                      ArgSizesGEP, MaptypesArg, MapnamesArg, NullPtr});
}

void OpenMPIRBuilder::emitOffloadingArraysArgument(IRBuilderBase &Builder,
                                                   TargetDataRTArgs &RTArgs,
                                                   TargetDataInfo &Info,
                                                   bool EmitDebug,
                                                   bool ForEndCall) {
  assert((!ForEndCall || Info.separateBeginEndCalls()) &&
         "expected region end call to runtime only when end call is separate");
  auto VoidPtrTy = Type::getInt8PtrTy(M.getContext());
  auto VoidPtrPtrTy = VoidPtrTy->getPointerTo(0);
  auto Int64Ty = Type::getInt64Ty(M.getContext());
  auto Int64PtrTy = Type::getInt64PtrTy(M.getContext());

  if (!Info.NumberOfPtrs) {
    RTArgs.BasePointersArray = ConstantPointerNull::get(VoidPtrPtrTy);
    RTArgs.PointersArray = ConstantPointerNull::get(VoidPtrPtrTy);
    RTArgs.SizesArray = ConstantPointerNull::get(Int64PtrTy);
    RTArgs.MapTypesArray = ConstantPointerNull::get(Int64PtrTy);
    RTArgs.MapNamesArray = ConstantPointerNull::get(VoidPtrPtrTy);
    RTArgs.MappersArray = ConstantPointerNull::get(VoidPtrPtrTy);
    return;
  }

  RTArgs.BasePointersArray = Builder.CreateConstInBoundsGEP2_32(
      ArrayType::get(VoidPtrTy, Info.NumberOfPtrs),
      Info.RTArgs.BasePointersArray,
      /*Idx0=*/0, /*Idx1=*/0);
  RTArgs.PointersArray = Builder.CreateConstInBoundsGEP2_32(
      ArrayType::get(VoidPtrTy, Info.NumberOfPtrs), Info.RTArgs.PointersArray,
      /*Idx0=*/0,
      /*Idx1=*/0);
  RTArgs.SizesArray = Builder.CreateConstInBoundsGEP2_32(
      ArrayType::get(Int64Ty, Info.NumberOfPtrs), Info.RTArgs.SizesArray,
      /*Idx0=*/0, /*Idx1=*/0);
  RTArgs.MapTypesArray = Builder.CreateConstInBoundsGEP2_32(
      ArrayType::get(Int64Ty, Info.NumberOfPtrs),
      ForEndCall && Info.RTArgs.MapTypesArrayEnd ? Info.RTArgs.MapTypesArrayEnd
                                                 : Info.RTArgs.MapTypesArray,
      /*Idx0=*/0,
      /*Idx1=*/0);

  // Only emit the mapper information arrays if debug information is
  // requested.
  if (!EmitDebug)
    RTArgs.MapNamesArray = ConstantPointerNull::get(VoidPtrPtrTy);
  else
    RTArgs.MapNamesArray = Builder.CreateConstInBoundsGEP2_32(
        ArrayType::get(VoidPtrTy, Info.NumberOfPtrs), Info.RTArgs.MapNamesArray,
        /*Idx0=*/0,
        /*Idx1=*/0);
  // If there is no user-defined mapper, set the mapper array to nullptr to
  // avoid an unnecessary data privatization.
  if (!Info.HasMapper)
    RTArgs.MappersArray = ConstantPointerNull::get(VoidPtrPtrTy);
  else
    RTArgs.MappersArray =
        Builder.CreatePointerCast(Info.RTArgs.MappersArray, VoidPtrPtrTy);
}

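// Summary of the flush decisions implemented below (added for readability;
// derived from the switch that follows):
//   Read:                  flush for acquire, acq_rel, and seq_cst orderings.
//   Write/Update/Compare:  flush for release, acq_rel, and seq_cst orderings.
//   Capture:               flush for acquire, release, acq_rel, and seq_cst.
// Monotonic and all remaining combinations emit no flush.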
bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
    const LocationDescription &Loc, llvm::AtomicOrdering AO, AtomicKind AK) {
  assert(!(AO == AtomicOrdering::NotAtomic ||
           AO == llvm::AtomicOrdering::Unordered) &&
         "Unexpected Atomic Ordering.");

  bool Flush = false;
  llvm::AtomicOrdering FlushAO = AtomicOrdering::Monotonic;

  switch (AK) {
  case Read:
    if (AO == AtomicOrdering::Acquire || AO == AtomicOrdering::AcquireRelease ||
        AO == AtomicOrdering::SequentiallyConsistent) {
      FlushAO = AtomicOrdering::Acquire;
      Flush = true;
    }
    break;
  case Write:
  case Compare:
  case Update:
    if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
        AO == AtomicOrdering::SequentiallyConsistent) {
      FlushAO = AtomicOrdering::Release;
      Flush = true;
    }
    break;
  case Capture:
    switch (AO) {
    case AtomicOrdering::Acquire:
      FlushAO = AtomicOrdering::Acquire;
      Flush = true;
      break;
    case AtomicOrdering::Release:
      FlushAO = AtomicOrdering::Release;
      Flush = true;
      break;
    case AtomicOrdering::AcquireRelease:
    case AtomicOrdering::SequentiallyConsistent:
      FlushAO = AtomicOrdering::AcquireRelease;
      Flush = true;
      break;
    default:
      // Do nothing - leave silently.
      break;
    }
  }

  if (Flush) {
    // The flush runtime call does not take a memory ordering yet, so we
    // resolve which ordering to use here but emit only the plain flush call.
    // TODO: pass `FlushAO` once memory ordering support is added.
    (void)FlushAO;
    emitFlush(Loc);
  }

  // For AO == AtomicOrdering::Monotonic and all other case combinations,
  // do nothing.
  return Flush;
}

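// Explanatory note (added): for non-integer element types, the atomic
// read/write helpers below cannot emit an atomic load/store directly. They
// bitcast the location to an integer pointer of the same width, perform the
// atomic access on the integer, and cast the value back afterwards.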
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
                                  AtomicOpValue &X, AtomicOpValue &V,
                                  AtomicOrdering AO) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  Type *XTy = X.Var->getType();
  assert(XTy->isPointerTy() && "OMP Atomic expects a pointer to target memory");
  Type *XElemTy = X.ElemTy;
  assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
          XElemTy->isPointerTy()) &&
         "OMP atomic read expected a scalar type");

  Value *XRead = nullptr;

  if (XElemTy->isIntegerTy()) {
    LoadInst *XLD =
        Builder.CreateLoad(XElemTy, X.Var, X.IsVolatile, "omp.atomic.read");
    XLD->setAtomic(AO);
    XRead = cast<Value>(XLD);
  } else {
    // We need to bitcast and perform the atomic op as an integer.
    unsigned Addrspace = cast<PointerType>(XTy)->getAddressSpace();
    IntegerType *IntCastTy =
        IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
    Value *XBCast = Builder.CreateBitCast(
        X.Var, IntCastTy->getPointerTo(Addrspace), "atomic.src.int.cast");
    LoadInst *XLoad =
        Builder.CreateLoad(IntCastTy, XBCast, X.IsVolatile, "omp.atomic.load");
    XLoad->setAtomic(AO);
    if (XElemTy->isFloatingPointTy()) {
      XRead = Builder.CreateBitCast(XLoad, XElemTy, "atomic.flt.cast");
    } else {
      XRead = Builder.CreateIntToPtr(XLoad, XElemTy, "atomic.ptr.cast");
    }
  }
  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Read);
  Builder.CreateStore(XRead, V.Var, V.IsVolatile);
  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
                                   AtomicOpValue &X, Value *Expr,
                                   AtomicOrdering AO) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  Type *XTy = X.Var->getType();
  assert(XTy->isPointerTy() && "OMP Atomic expects a pointer to target memory");
  Type *XElemTy = X.ElemTy;
  assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
          XElemTy->isPointerTy()) &&
         "OMP atomic write expected a scalar type");

  if (XElemTy->isIntegerTy()) {
    StoreInst *XSt = Builder.CreateStore(Expr, X.Var, X.IsVolatile);
    XSt->setAtomic(AO);
  } else {
    // We need to bitcast and perform the atomic op as integers.
    unsigned Addrspace = cast<PointerType>(XTy)->getAddressSpace();
    IntegerType *IntCastTy =
        IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
    Value *XBCast = Builder.CreateBitCast(
        X.Var, IntCastTy->getPointerTo(Addrspace), "atomic.dst.int.cast");
    Value *ExprCast =
        Builder.CreateBitCast(Expr, IntCastTy, "atomic.src.int.cast");
    StoreInst *XSt = Builder.CreateStore(ExprCast, XBCast, X.IsVolatile);
    XSt->setAtomic(AO);
  }

  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Write);
  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicUpdate(
    const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
    Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
    AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr) {
  assert(!isConflictIP(Loc.IP, AllocaIP) && "IPs must not be ambiguous");
  if (!updateToLocation(Loc))
    return Loc.IP;

  LLVM_DEBUG({
    Type *XTy = X.Var->getType();
    assert(XTy->isPointerTy() &&
           "OMP Atomic expects a pointer to target memory");
    Type *XElemTy = X.ElemTy;
    assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
            XElemTy->isPointerTy()) &&
           "OMP atomic update expected a scalar type");
    assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
           (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
           "OpenMP atomic does not support LT or GT operations");
  });

  emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, RMWOp, UpdateOp,
                   X.IsVolatile, IsXBinopExpr);
  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Update);
  return Builder.saveIP();
}

// FIXME: Duplicating AtomicExpand
Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
                                               AtomicRMWInst::BinOp RMWOp) {
  switch (RMWOp) {
  case AtomicRMWInst::Add:
    return Builder.CreateAdd(Src1, Src2);
  case AtomicRMWInst::Sub:
    return Builder.CreateSub(Src1, Src2);
  case AtomicRMWInst::And:
    return Builder.CreateAnd(Src1, Src2);
  case AtomicRMWInst::Nand:
    return Builder.CreateNeg(Builder.CreateAnd(Src1, Src2));
  case AtomicRMWInst::Or:
    return Builder.CreateOr(Src1, Src2);
  case AtomicRMWInst::Xor:
    return Builder.CreateXor(Src1, Src2);
  case AtomicRMWInst::Xchg:
  case AtomicRMWInst::FAdd:
  case AtomicRMWInst::FSub:
  case AtomicRMWInst::BAD_BINOP:
  case AtomicRMWInst::Max:
  case AtomicRMWInst::Min:
  case AtomicRMWInst::UMax:
  case AtomicRMWInst::UMin:
  case AtomicRMWInst::FMax:
  case AtomicRMWInst::FMin:
  case AtomicRMWInst::UIncWrap:
  case AtomicRMWInst::UDecWrap:
    llvm_unreachable("Unsupported atomic update operation");
  }
  llvm_unreachable("Unsupported atomic update operation");
}

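// Explanatory note (added): emitAtomicUpdate has two strategies. If the
// update is expressible as a native integer 'atomicrmw' (see the switch
// below), a single instruction is emitted. Otherwise it falls back to a
// load/compute/cmpxchg retry loop: the current value flows into a PHI, the
// user callback computes the updated value, and a cmpxchg either commits it
// or feeds the observed value back for another iteration.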
std::pair<Value *, Value *> OpenMPIRBuilder::emitAtomicUpdate(
    InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
    AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
    AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr) {
  // TODO: handle the case where XElemTy is not byte-sized or not a power of 2
  // or a complex datatype.
  bool emitRMWOp = false;
  switch (RMWOp) {
  case AtomicRMWInst::Add:
  case AtomicRMWInst::And:
  case AtomicRMWInst::Nand:
  case AtomicRMWInst::Or:
  case AtomicRMWInst::Xor:
  case AtomicRMWInst::Xchg:
    emitRMWOp = XElemTy;
    break;
  case AtomicRMWInst::Sub:
    emitRMWOp = (IsXBinopExpr && XElemTy);
    break;
  default:
    emitRMWOp = false;
  }
  emitRMWOp &= XElemTy->isIntegerTy();

  std::pair<Value *, Value *> Res;
  if (emitRMWOp) {
    Res.first = Builder.CreateAtomicRMW(RMWOp, X, Expr, llvm::MaybeAlign(), AO);
    // Not needed except in the case of postfix captures. Generate it anyway
    // for consistency with the else branch; any DCE pass will remove it.
    // AtomicRMWInst::Xchg does not have a corresponding instruction.
    if (RMWOp == AtomicRMWInst::Xchg)
      Res.second = Res.first;
    else
      Res.second = emitRMWOpAsInstruction(Res.first, Expr, RMWOp);
  } else {
    unsigned Addrspace = cast<PointerType>(X->getType())->getAddressSpace();
    IntegerType *IntCastTy =
        IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
    Value *XBCast =
        Builder.CreateBitCast(X, IntCastTy->getPointerTo(Addrspace));
    LoadInst *OldVal =
        Builder.CreateLoad(IntCastTy, XBCast, X->getName() + ".atomic.load");
    OldVal->setAtomic(AO);
    // CurBB
    // |     /---\
    // ContBB    |
    // |     \---/
    // ExitBB
    BasicBlock *CurBB = Builder.GetInsertBlock();
    Instruction *CurBBTI = CurBB->getTerminator();
    CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
    BasicBlock *ExitBB =
        CurBB->splitBasicBlock(CurBBTI, X->getName() + ".atomic.exit");
    BasicBlock *ContBB = CurBB->splitBasicBlock(CurBB->getTerminator(),
                                                X->getName() + ".atomic.cont");
    ContBB->getTerminator()->eraseFromParent();
    Builder.restoreIP(AllocaIP);
    AllocaInst *NewAtomicAddr = Builder.CreateAlloca(XElemTy);
    NewAtomicAddr->setName(X->getName() + "x.new.val");
    Builder.SetInsertPoint(ContBB);
    llvm::PHINode *PHI = Builder.CreatePHI(OldVal->getType(), 2);
    PHI->addIncoming(OldVal, CurBB);
    IntegerType *NewAtomicCastTy =
        IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
    bool IsIntTy = XElemTy->isIntegerTy();
    Value *NewAtomicIntAddr =
        (IsIntTy)
            ? NewAtomicAddr
            : Builder.CreateBitCast(NewAtomicAddr,
                                    NewAtomicCastTy->getPointerTo(Addrspace));
    Value *OldExprVal = PHI;
    if (!IsIntTy) {
      if (XElemTy->isFloatingPointTy()) {
        OldExprVal = Builder.CreateBitCast(PHI, XElemTy,
                                           X->getName() + ".atomic.fltCast");
      } else {
        OldExprVal = Builder.CreateIntToPtr(PHI, XElemTy,
                                            X->getName() + ".atomic.ptrCast");
      }
    }

    Value *Upd = UpdateOp(OldExprVal, Builder);
    Builder.CreateStore(Upd, NewAtomicAddr);
    LoadInst *DesiredVal = Builder.CreateLoad(IntCastTy, NewAtomicIntAddr);
    Value *XAddr =
        (IsIntTy)
            ? X
            : Builder.CreateBitCast(X, IntCastTy->getPointerTo(Addrspace));
    AtomicOrdering Failure =
        llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
    AtomicCmpXchgInst *Result = Builder.CreateAtomicCmpXchg(
        XAddr, PHI, DesiredVal, llvm::MaybeAlign(), AO, Failure);
    Result->setVolatile(VolatileX);
    Value *PreviousVal = Builder.CreateExtractValue(Result, /*Idxs=*/0);
    Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
    PHI->addIncoming(PreviousVal, Builder.GetInsertBlock());
    Builder.CreateCondBr(SuccessFailureVal, ExitBB, ContBB);

    Res.first = OldExprVal;
    Res.second = Upd;

    // Set the insertion point in the exit block.
    if (UnreachableInst *ExitTI =
            dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
      CurBBTI->eraseFromParent();
      Builder.SetInsertPoint(ExitBB);
    } else {
      Builder.SetInsertPoint(ExitTI);
    }
  }

  return Res;
}

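// Explanatory note (added): for 'atomic capture' the pair returned by
// emitAtomicUpdate selects the captured value. Result.first is the value of
// 'x' before the update (postfix capture, 'v = x; x = x op expr;'), and
// Result.second is the value after the update ('x = x op expr; v = x;').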
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCapture(
    const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
    AtomicOpValue &V, Value *Expr, AtomicOrdering AO,
    AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp,
    bool UpdateExpr, bool IsPostfixUpdate, bool IsXBinopExpr) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  LLVM_DEBUG({
    Type *XTy = X.Var->getType();
    assert(XTy->isPointerTy() &&
           "OMP Atomic expects a pointer to target memory");
    Type *XElemTy = X.ElemTy;
    assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
            XElemTy->isPointerTy()) &&
           "OMP atomic capture expected a scalar type");
    assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
           "OpenMP atomic does not support LT or GT operations");
  });

  // If UpdateExpr is not set, 'x' is updated with some 'expr' that is not
  // based on 'x', so 'x' is simply atomically overwritten ('Xchg') with
  // 'expr'.
  AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? RMWOp : AtomicRMWInst::Xchg);
  std::pair<Value *, Value *> Result =
      emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, AtomicOp, UpdateOp,
                       X.IsVolatile, IsXBinopExpr);

  Value *CapturedVal = (IsPostfixUpdate ? Result.first : Result.second);
  Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile);

  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Capture);
  return Builder.saveIP();
}

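// Explanatory note (added): createAtomicCompare handles two families of
// forms. The '==' form lowers to a cmpxchg, with optional capture of the old
// value and of the success flag. The min/max forms lower to an 'atomicrmw',
// with the comparison direction flipped where the OpenMP form and the LLVM
// form disagree (see the comment inside the else branch below).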
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
    const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
    AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
    omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
    bool IsFailOnly) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  assert(X.Var->getType()->isPointerTy() &&
         "OMP atomic expects a pointer to target memory");
  // Compare capture.
  if (V.Var) {
    assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
    assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
  }

  bool IsInteger = E->getType()->isIntegerTy();

  if (Op == OMPAtomicCompareOp::EQ) {
    AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
    AtomicCmpXchgInst *Result = nullptr;
    if (!IsInteger) {
      unsigned Addrspace =
          cast<PointerType>(X.Var->getType())->getAddressSpace();
      IntegerType *IntCastTy =
          IntegerType::get(M.getContext(), X.ElemTy->getScalarSizeInBits());
      Value *XBCast =
          Builder.CreateBitCast(X.Var, IntCastTy->getPointerTo(Addrspace));
      Value *EBCast = Builder.CreateBitCast(E, IntCastTy);
      Value *DBCast = Builder.CreateBitCast(D, IntCastTy);
      Result = Builder.CreateAtomicCmpXchg(XBCast, EBCast, DBCast, MaybeAlign(),
                                           AO, Failure);
    } else {
      Result =
          Builder.CreateAtomicCmpXchg(X.Var, E, D, MaybeAlign(), AO, Failure);
    }

    if (V.Var) {
      Value *OldValue = Builder.CreateExtractValue(Result, /*Idxs=*/0);
      if (!IsInteger)
        OldValue = Builder.CreateBitCast(OldValue, X.ElemTy);
      assert(OldValue->getType() == V.ElemTy &&
             "OldValue and V must be of same type");
      if (IsPostfixUpdate) {
        Builder.CreateStore(OldValue, V.Var, V.IsVolatile);
      } else {
        Value *SuccessOrFail = Builder.CreateExtractValue(Result, /*Idxs=*/1);
        if (IsFailOnly) {
          // CurBB----
          //   |     |
          //   v     |
          // ContBB  |
          //   |     |
          //   v     |
          // ExitBB <-
          //
          // where ContBB only contains the store of the old value to 'v'.
          BasicBlock *CurBB = Builder.GetInsertBlock();
          Instruction *CurBBTI = CurBB->getTerminator();
          CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
          BasicBlock *ExitBB = CurBB->splitBasicBlock(
              CurBBTI, X.Var->getName() + ".atomic.exit");
          BasicBlock *ContBB = CurBB->splitBasicBlock(
              CurBB->getTerminator(), X.Var->getName() + ".atomic.cont");
          ContBB->getTerminator()->eraseFromParent();
          CurBB->getTerminator()->eraseFromParent();

          Builder.CreateCondBr(SuccessOrFail, ExitBB, ContBB);

          Builder.SetInsertPoint(ContBB);
          Builder.CreateStore(OldValue, V.Var);
          Builder.CreateBr(ExitBB);

          if (UnreachableInst *ExitTI =
                  dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
            CurBBTI->eraseFromParent();
            Builder.SetInsertPoint(ExitBB);
          } else {
            Builder.SetInsertPoint(ExitTI);
          }
        } else {
          Value *CapturedValue =
              Builder.CreateSelect(SuccessOrFail, E, OldValue);
          Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
        }
      }
    }
    // The comparison result has to be stored.
    if (R.Var) {
      assert(R.Var->getType()->isPointerTy() &&
             "r.var must be of pointer type");
      assert(R.ElemTy->isIntegerTy() && "r must be of integral type");

      Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
      Value *ResultCast = R.IsSigned
                              ? Builder.CreateSExt(SuccessFailureVal, R.ElemTy)
                              : Builder.CreateZExt(SuccessFailureVal, R.ElemTy);
      Builder.CreateStore(ResultCast, R.Var, R.IsVolatile);
    }
  } else {
    assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
           "Op should be either max or min at this point");
    assert(!IsFailOnly && "IsFailOnly is only valid when the comparison is ==");

    // Reverse the ordop as the OpenMP forms are different from the LLVM forms.
    // Let's take max as an example.
    // OpenMP form:
    // x = x > expr ? expr : x;
    // LLVM form:
    // *ptr = *ptr > val ? *ptr : val;
    // We need to transform to the LLVM form:
    // x = x <= expr ? x : expr;
    AtomicRMWInst::BinOp NewOp;
    if (IsXBinopExpr) {
      if (IsInteger) {
        if (X.IsSigned)
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
                                                : AtomicRMWInst::Max;
        else
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
                                                : AtomicRMWInst::UMax;
      } else {
        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMin
                                              : AtomicRMWInst::FMax;
      }
    } else {
      if (IsInteger) {
        if (X.IsSigned)
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
                                                : AtomicRMWInst::Min;
        else
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
                                                : AtomicRMWInst::UMin;
      } else {
        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMax
                                              : AtomicRMWInst::FMin;
      }
    }

    AtomicRMWInst *OldValue =
        Builder.CreateAtomicRMW(NewOp, X.Var, E, MaybeAlign(), AO);
    if (V.Var) {
      Value *CapturedValue = nullptr;
      if (IsPostfixUpdate) {
        CapturedValue = OldValue;
      } else {
        CmpInst::Predicate Pred;
        switch (NewOp) {
        case AtomicRMWInst::Max:
          Pred = CmpInst::ICMP_SGT;
          break;
        case AtomicRMWInst::UMax:
          Pred = CmpInst::ICMP_UGT;
          break;
        case AtomicRMWInst::FMax:
          Pred = CmpInst::FCMP_OGT;
          break;
        case AtomicRMWInst::Min:
          Pred = CmpInst::ICMP_SLT;
          break;
        case AtomicRMWInst::UMin:
          Pred = CmpInst::ICMP_ULT;
          break;
        case AtomicRMWInst::FMin:
          Pred = CmpInst::FCMP_OLT;
          break;
        default:
          llvm_unreachable("unexpected comparison op");
        }
        Value *NonAtomicCmp = Builder.CreateCmp(Pred, OldValue, E);
        CapturedValue = Builder.CreateSelect(NonAtomicCmp, E, OldValue);
      }
      Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
    }
  }

  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Compare);

  return Builder.saveIP();
}

GlobalVariable *
OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
                                       std::string VarName) {
  llvm::Constant *MapNamesArrayInit = llvm::ConstantArray::get(
      llvm::ArrayType::get(
          llvm::Type::getInt8Ty(M.getContext())->getPointerTo(), Names.size()),
      Names);
  auto *MapNamesArrayGlobal = new llvm::GlobalVariable(
      M, MapNamesArrayInit->getType(),
      /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MapNamesArrayInit,
      VarName);
  return MapNamesArrayGlobal;
}

// Create all simple and struct types exposed by the runtime and remember
// the llvm::PointerTypes of them for easy access later.
void OpenMPIRBuilder::initializeTypes(Module &M) {
  LLVMContext &Ctx = M.getContext();
  StructType *T;
#define OMP_TYPE(VarName, InitValue) VarName = InitValue;
#define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize)                             \
  VarName##Ty = ArrayType::get(ElemTy, ArraySize);                             \
  VarName##PtrTy = PointerType::getUnqual(VarName##Ty);
#define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...)                  \
  VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg);            \
  VarName##Ptr = PointerType::getUnqual(VarName);
#define OMP_STRUCT_TYPE(VarName, StructName, Packed, ...)                      \
  T = StructType::getTypeByName(Ctx, StructName);                              \
  if (!T)                                                                      \
    T = StructType::create(Ctx, {__VA_ARGS__}, StructName, Packed);            \
  VarName = T;                                                                 \
  VarName##Ptr = PointerType::getUnqual(T);
#include "llvm/Frontend/OpenMP/OMPKinds.def"
}

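// Explanatory note (added): collectBlocks gathers every block reachable from
// EntryBB via a worklist traversal of successor edges. ExitBB is pre-inserted
// into the visited set, so the traversal records the region's blocks without
// walking through or past the exit block.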
void OpenMPIRBuilder::OutlineInfo::collectBlocks(
    SmallPtrSetImpl<BasicBlock *> &BlockSet,
    SmallVectorImpl<BasicBlock *> &BlockVector) {
  SmallVector<BasicBlock *, 32> Worklist;
  BlockSet.insert(EntryBB);
  BlockSet.insert(ExitBB);

  Worklist.push_back(EntryBB);
  while (!Worklist.empty()) {
    BasicBlock *BB = Worklist.pop_back_val();
    BlockVector.push_back(BB);
    for (BasicBlock *SuccBB : successors(BB))
      if (BlockSet.insert(SuccBB).second)
        Worklist.push_back(SuccBB);
  }
}

void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr,
                                         uint64_t Size, int32_t Flags,
                                         GlobalValue::LinkageTypes) {
  if (!Config.isTargetCodegen()) {
    emitOffloadingEntry(ID, Addr->getName(), Size, Flags);
    return;
  }
  // TODO: Add support for global variables on the device after declare target
  // support.
  Function *Fn = dyn_cast<Function>(Addr);
  if (!Fn)
    return;

  Module &M = *(Fn->getParent());
  LLVMContext &Ctx = M.getContext();

  // Get the "nvvm.annotations" metadata node.
  NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");

  Metadata *MDVals[] = {
      ConstantAsMetadata::get(Fn), MDString::get(Ctx, "kernel"),
      ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Ctx), 1))};
  // Append metadata to nvvm.annotations.
  MD->addOperand(MDNode::get(Ctx, MDVals));

  // Add a function attribute for the kernel.
  Fn->addFnAttr(Attribute::get(Ctx, "kernel"));
}

// We only generate metadata for functions that contain target regions.
void OpenMPIRBuilder::createOffloadEntriesAndInfoMetadata(
    OffloadEntriesInfoManager &OffloadEntriesInfoManager,
    EmitMetadataErrorReportFunctionTy &ErrorFn) {

  // If there are no entries, we don't need to do anything.
  if (OffloadEntriesInfoManager.empty())
    return;

  LLVMContext &C = M.getContext();
  SmallVector<std::pair<const OffloadEntriesInfoManager::OffloadEntryInfo *,
                        TargetRegionEntryInfo>,
              16>
      OrderedEntries(OffloadEntriesInfoManager.size());

  // Auxiliary methods to create metadata values and strings.
  auto &&GetMDInt = [this](unsigned V) {
    return ConstantAsMetadata::get(ConstantInt::get(Builder.getInt32Ty(), V));
  };

  auto &&GetMDString = [&C](StringRef V) { return MDString::get(C, V); };

  // Create the offloading info metadata node.
  NamedMDNode *MD = M.getOrInsertNamedMetadata("omp_offload.info");
  auto &&TargetRegionMetadataEmitter =
      [&C, MD, &OrderedEntries, &GetMDInt, &GetMDString](
          const TargetRegionEntryInfo &EntryInfo,
          const OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion &E) {
        // Generate metadata for target regions. Each entry of this metadata
        // contains:
        // - Entry 0 -> Kind of this type of metadata (0).
        // - Entry 1 -> Device ID of the file where the entry was identified.
        // - Entry 2 -> File ID of the file where the entry was identified.
        // - Entry 3 -> Mangled name of the function where the entry was
        //   identified.
        // - Entry 4 -> Line in the file where the entry was identified.
        // - Entry 5 -> Count of regions at this DeviceID/FileID/Line.
        // - Entry 6 -> Order the entry was created.
        // The first element of the metadata node is the kind.
        Metadata *Ops[] = {
            GetMDInt(E.getKind()),      GetMDInt(EntryInfo.DeviceID),
            GetMDInt(EntryInfo.FileID), GetMDString(EntryInfo.ParentName),
            GetMDInt(EntryInfo.Line),   GetMDInt(EntryInfo.Count),
            GetMDInt(E.getOrder())};

        // Save this entry in the right position of the ordered entries array.
        OrderedEntries[E.getOrder()] = std::make_pair(&E, EntryInfo);

        // Add metadata to the named metadata node.
        MD->addOperand(MDNode::get(C, Ops));
      };
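  // Each emitted target region node therefore looks like (illustrative):
  //   !{i32 0, i32 <DeviceID>, i32 <FileID>, !"<ParentName>", i32 <Line>,
  //     i32 <Count>, i32 <Order>}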

  OffloadEntriesInfoManager.actOnTargetRegionEntriesInfo(
      TargetRegionMetadataEmitter);

  // Create a function that emits metadata for each device global variable
  // entry.
  auto &&DeviceGlobalVarMetadataEmitter =
      [&C, &OrderedEntries, &GetMDInt, &GetMDString, MD](
          StringRef MangledName,
          const OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar &E) {
        // Generate metadata for global variables. Each entry of this metadata
        // contains:
        // - Entry 0 -> Kind of this type of metadata (1).
        // - Entry 1 -> Mangled name of the variable.
        // - Entry 2 -> Declare target kind.
        // - Entry 3 -> Order the entry was created.
        // The first element of the metadata node is the kind.
        Metadata *Ops[] = {GetMDInt(E.getKind()), GetMDString(MangledName),
                           GetMDInt(E.getFlags()), GetMDInt(E.getOrder())};

        // Save this entry in the right position of the ordered entries array.
        TargetRegionEntryInfo varInfo(MangledName, 0, 0, 0);
        OrderedEntries[E.getOrder()] = std::make_pair(&E, varInfo);

        // Add metadata to the named metadata node.
        MD->addOperand(MDNode::get(C, Ops));
      };
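  // Each emitted declare target variable node therefore looks like
  // (illustrative):
  //   !{i32 1, !"<MangledName>", i32 <Flags>, i32 <Order>}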

  OffloadEntriesInfoManager.actOnDeviceGlobalVarEntriesInfo(
      DeviceGlobalVarMetadataEmitter);

  for (const auto &E : OrderedEntries) {
    assert(E.first && "All ordered entries must exist!");
    if (const auto *CE =
            dyn_cast<OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion>(
                E.first)) {
      if (!CE->getID() || !CE->getAddress()) {
        // Do not blame the entry if the parent function is not emitted.
        TargetRegionEntryInfo EntryInfo = E.second;
        StringRef FnName = EntryInfo.ParentName;
        if (!M.getNamedValue(FnName))
          continue;
        ErrorFn(EMIT_MD_TARGET_REGION_ERROR, EntryInfo);
        continue;
      }
      createOffloadEntry(CE->getID(), CE->getAddress(),
                         /*Size=*/0, CE->getFlags(),
                         GlobalValue::WeakAnyLinkage);
    } else if (const auto *CE = dyn_cast<
                   OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar>(
                   E.first)) {
      OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags =
          static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
              CE->getFlags());
      switch (Flags) {
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo: {
        if (Config.isEmbedded() && Config.hasRequiresUnifiedSharedMemory())
          continue;
        if (!CE->getAddress()) {
          ErrorFn(EMIT_MD_DECLARE_TARGET_ERROR, E.second);
          continue;
        }
        // The variable has no definition - no need to add the entry.
        if (CE->getVarSize() == 0)
          continue;
        break;
      }
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink:
        assert(((Config.isEmbedded() && !CE->getAddress()) ||
                (!Config.isEmbedded() && CE->getAddress())) &&
               "Declare target link address is set.");
        if (Config.isEmbedded())
          continue;
        if (!CE->getAddress()) {
          ErrorFn(EMIT_MD_GLOBAL_VAR_LINK_ERROR, TargetRegionEntryInfo());
          continue;
        }
        break;
      }

      // Hidden or internal symbols on the device are not externally visible.
      // We should not attempt to register them by creating an offloading
      // entry.
      if (auto *GV = dyn_cast<GlobalValue>(CE->getAddress()))
        if (GV->hasLocalLinkage() || GV->hasHiddenVisibility())
          continue;

      createOffloadEntry(CE->getAddress(), CE->getAddress(), CE->getVarSize(),
                         Flags, CE->getLinkage());

    } else {
      llvm_unreachable("Unsupported entry kind.");
    }
  }
}

void TargetRegionEntryInfo::getTargetRegionEntryFnName(
    SmallVectorImpl<char> &Name, StringRef ParentName, unsigned DeviceID,
    unsigned FileID, unsigned Line, unsigned Count) {
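  // Produces names of the form (illustrative):
  //   __omp_offloading_<device-id>_<file-id>_<parent-name>_l<line>[_<count>]
  // with the IDs printed in hex, e.g. "__omp_offloading_801_2f42ab_main_l42".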
  raw_svector_ostream OS(Name);
  OS << "__omp_offloading" << llvm::format("_%x", DeviceID)
     << llvm::format("_%x_", FileID) << ParentName << "_l" << Line;
  if (Count)
    OS << "_" << Count;
}

void OffloadEntriesInfoManager::getTargetRegionEntryFnName(
    SmallVectorImpl<char> &Name, const TargetRegionEntryInfo &EntryInfo) {
  unsigned NewCount = getTargetRegionEntryInfoCount(EntryInfo);
  TargetRegionEntryInfo::getTargetRegionEntryFnName(
      Name, EntryInfo.ParentName, EntryInfo.DeviceID, EntryInfo.FileID,
      EntryInfo.Line, NewCount);
}

/// Loads all the offload entries information from the host IR
/// metadata.
void OpenMPIRBuilder::loadOffloadInfoMetadata(
    Module &M, OffloadEntriesInfoManager &OffloadEntriesInfoManager) {
  // If we are in target mode, load the metadata from the host IR. This code
  // has to match the metadata creation in
  // createOffloadEntriesAndInfoMetadata().

  NamedMDNode *MD = M.getNamedMetadata(ompOffloadInfoName);
  if (!MD)
    return;

  for (MDNode *MN : MD->operands()) {
    auto &&GetMDInt = [MN](unsigned Idx) {
      auto *V = cast<ConstantAsMetadata>(MN->getOperand(Idx));
      return cast<ConstantInt>(V->getValue())->getZExtValue();
    };

    auto &&GetMDString = [MN](unsigned Idx) {
      auto *V = cast<MDString>(MN->getOperand(Idx));
      return V->getString();
    };

    switch (GetMDInt(0)) {
    default:
      llvm_unreachable("Unexpected metadata!");
      break;
    case OffloadEntriesInfoManager::OffloadEntryInfo::
        OffloadingEntryInfoTargetRegion: {
      TargetRegionEntryInfo EntryInfo(/*ParentName=*/GetMDString(3),
                                      /*DeviceID=*/GetMDInt(1),
                                      /*FileID=*/GetMDInt(2),
                                      /*Line=*/GetMDInt(4),
                                      /*Count=*/GetMDInt(5));
      OffloadEntriesInfoManager.initializeTargetRegionEntryInfo(
          EntryInfo, /*Order=*/GetMDInt(6));
      break;
    }
    case OffloadEntriesInfoManager::OffloadEntryInfo::
        OffloadingEntryInfoDeviceGlobalVar:
      OffloadEntriesInfoManager.initializeDeviceGlobalVarEntryInfo(
          /*MangledName=*/GetMDString(1),
          static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
              /*Flags=*/GetMDInt(2)),
          /*Order=*/GetMDInt(3));
      break;
    }
  }
}

bool OffloadEntriesInfoManager::empty() const {
  return OffloadEntriesTargetRegion.empty() &&
         OffloadEntriesDeviceGlobalVar.empty();
}

unsigned OffloadEntriesInfoManager::getTargetRegionEntryInfoCount(
    const TargetRegionEntryInfo &EntryInfo) const {
  auto It = OffloadEntriesTargetRegionCount.find(
      getTargetRegionEntryCountKey(EntryInfo));
  if (It == OffloadEntriesTargetRegionCount.end())
    return 0;
  return It->second;
}

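// The count accessor above and the increment below maintain a per-location
// counter keyed by DeviceID/FileID/Line, so two target regions that share a
// source line get distinct Counts (0, 1, ...) and hence distinct entry
// function names.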
void OffloadEntriesInfoManager::incrementTargetRegionEntryInfoCount(
    const TargetRegionEntryInfo &EntryInfo) {
  OffloadEntriesTargetRegionCount[getTargetRegionEntryCountKey(EntryInfo)] =
      EntryInfo.Count + 1;
}

/// Initialize target region entry.
void OffloadEntriesInfoManager::initializeTargetRegionEntryInfo(
    const TargetRegionEntryInfo &EntryInfo, unsigned Order) {
  OffloadEntriesTargetRegion[EntryInfo] =
      OffloadEntryInfoTargetRegion(Order, /*Addr=*/nullptr, /*ID=*/nullptr,
                                   OMPTargetRegionEntryTargetRegion);
  ++OffloadingEntriesNum;
}

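// Typical host-side use of the registration below (a sketch; OutlinedFnAddr
// and RegionID are placeholder names, actual call sites vary):
//   TargetRegionEntryInfo Info(ParentName, DeviceID, FileID, Line);
//   InfoManager.registerTargetRegionEntryInfo(
//       Info, OutlinedFnAddr, RegionID,
//       OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion);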
void OffloadEntriesInfoManager::registerTargetRegionEntryInfo(
    TargetRegionEntryInfo EntryInfo, Constant *Addr, Constant *ID,
    OMPTargetRegionEntryKind Flags) {
  assert(EntryInfo.Count == 0 && "expected default EntryInfo");

  // Update the EntryInfo with the next available count for this location.
  EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);

  // If we are emitting code for a target, the entry is already initialized;
  // it only has to be registered.
  if (Config.isEmbedded()) {
    // This could happen if the device compilation is invoked standalone.
    if (!hasTargetRegionEntryInfo(EntryInfo))
      return;
    auto &Entry = OffloadEntriesTargetRegion[EntryInfo];
    Entry.setAddress(Addr);
    Entry.setID(ID);
    Entry.setFlags(Flags);
  } else {
    if (Flags == OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion &&
        hasTargetRegionEntryInfo(EntryInfo, /*IgnoreAddressId=*/true))
      return;
    assert(!hasTargetRegionEntryInfo(EntryInfo) &&
           "Target region entry already registered!");
    OffloadEntryInfoTargetRegion Entry(OffloadingEntriesNum, Addr, ID, Flags);
    OffloadEntriesTargetRegion[EntryInfo] = Entry;
    ++OffloadingEntriesNum;
  }
  incrementTargetRegionEntryInfoCount(EntryInfo);
}

bool OffloadEntriesInfoManager::hasTargetRegionEntryInfo(
    TargetRegionEntryInfo EntryInfo, bool IgnoreAddressId) const {

  // Update the EntryInfo with the next available count for this location.
  EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);

  auto It = OffloadEntriesTargetRegion.find(EntryInfo);
  if (It == OffloadEntriesTargetRegion.end())
    return false;
  // Fail if this entry is already registered.
  if (!IgnoreAddressId && (It->second.getAddress() || It->second.getID()))
    return false;
  return true;
}

void OffloadEntriesInfoManager::actOnTargetRegionEntriesInfo(
    const OffloadTargetRegionEntryInfoActTy &Action) {
  // Scan all target region entries and perform the provided action.
  for (const auto &It : OffloadEntriesTargetRegion)
    Action(It.first, It.second);
}

void OffloadEntriesInfoManager::initializeDeviceGlobalVarEntryInfo(
    StringRef Name, OMPTargetGlobalVarEntryKind Flags, unsigned Order) {
  OffloadEntriesDeviceGlobalVar.try_emplace(Name, Order, Flags);
  ++OffloadingEntriesNum;
}

void OffloadEntriesInfoManager::registerDeviceGlobalVarEntryInfo(
    StringRef VarName, Constant *Addr, int64_t VarSize,
    OMPTargetGlobalVarEntryKind Flags, GlobalValue::LinkageTypes Linkage) {
  if (Config.isEmbedded()) {
    // This could happen if the device compilation is invoked standalone.
    if (!hasDeviceGlobalVarEntryInfo(VarName))
      return;
    auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
    if (Entry.getAddress() && hasDeviceGlobalVarEntryInfo(VarName)) {
      if (Entry.getVarSize() == 0) {
        Entry.setVarSize(VarSize);
        Entry.setLinkage(Linkage);
      }
      return;
    }
    Entry.setVarSize(VarSize);
    Entry.setLinkage(Linkage);
    Entry.setAddress(Addr);
  } else {
    if (hasDeviceGlobalVarEntryInfo(VarName)) {
      auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
      assert(Entry.isValid() && Entry.getFlags() == Flags &&
             "Entry not initialized!");
      if (Entry.getVarSize() == 0) {
        Entry.setVarSize(VarSize);
        Entry.setLinkage(Linkage);
      }
      return;
    }
    OffloadEntriesDeviceGlobalVar.try_emplace(VarName, OffloadingEntriesNum,
                                              Addr, VarSize, Flags, Linkage);
    ++OffloadingEntriesNum;
  }
}

void OffloadEntriesInfoManager::actOnDeviceGlobalVarEntriesInfo(
    const OffloadDeviceGlobalVarEntryInfoActTy &Action) {
  // Scan all device global variable entries and perform the provided action.
  for (const auto &E : OffloadEntriesDeviceGlobalVar)
    Action(E.getKey(), E.getValue());
}

void CanonicalLoopInfo::collectControlBlocks(
    SmallVectorImpl<BasicBlock *> &BBs) {
  // We only count those BBs as control blocks that we do not need to discover
  // by traversing the CFG, i.e. not the loop body, which can contain arbitrary
  // control flow. For consistency, this also means we do not add the Body
  // block, which is just the entry to the body code.
  BBs.reserve(BBs.size() + 6);
  BBs.append({getPreheader(), Header, Cond, Latch, Exit, getAfter()});
}

BasicBlock *CanonicalLoopInfo::getPreheader() const {
  assert(isValid() && "Requires a valid canonical loop");
  for (BasicBlock *Pred : predecessors(Header)) {
    if (Pred != Latch)
      return Pred;
  }
  llvm_unreachable("Missing preheader");
}

void CanonicalLoopInfo::setTripCount(Value *TripCount) {
  assert(isValid() && "Requires a valid canonical loop");

  Instruction *CmpI = &getCond()->front();
  assert(isa<CmpInst>(CmpI) && "First inst must compare IV with TripCount");
  CmpI->setOperand(1, TripCount);

#ifndef NDEBUG
  assertOK();
#endif
}

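// Example use of mapIndVar below (a sketch; Builder and Factor are assumed to
// be in scope at the call site): rewrite every body use of the induction
// variable to a scaled value, leaving the loop-control uses untouched.
//   CLI->mapIndVar([&](Instruction *OldIV) -> Value * {
//     Builder.SetInsertPoint(OldIV->getParent()->getFirstNonPHI());
//     return Builder.CreateMul(OldIV, Factor, "scaled.iv");
//   });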
void CanonicalLoopInfo::mapIndVar(
    llvm::function_ref<Value *(Instruction *)> Updater) {
  assert(isValid() && "Requires a valid canonical loop");

  Instruction *OldIV = getIndVar();

  // Record all uses excluding those introduced by the updater. Uses by the
  // CanonicalLoopInfo itself to keep track of the number of iterations are
  // excluded.
  SmallVector<Use *> ReplaceableUses;
  for (Use &U : OldIV->uses()) {
    auto *User = dyn_cast<Instruction>(U.getUser());
    if (!User)
      continue;
    if (User->getParent() == getCond())
      continue;
    if (User->getParent() == getLatch())
      continue;
    ReplaceableUses.push_back(&U);
  }

  // Run the updater that may introduce new uses.
  Value *NewIV = Updater(OldIV);

  // Replace the old uses with the value returned by the updater.
  for (Use *U : ReplaceableUses)
    U->set(NewIV);

#ifndef NDEBUG
  assertOK();
#endif
}

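// assertOK() below verifies the canonical loop control flow, which has this
// shape (a sketch derived from the assertions):
//
//   Preheader -> Header -> Cond --+--> Body -> ... -> Latch -> Header
//                                 |
//                                 +--> Exit -> After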
void CanonicalLoopInfo::assertOK() const {
#ifndef NDEBUG
  // No constraints if this object currently does not describe a loop.
  if (!isValid())
    return;

  BasicBlock *Preheader = getPreheader();
  BasicBlock *Body = getBody();
  BasicBlock *After = getAfter();

  // Verify standard control-flow we use for OpenMP loops.
  assert(Preheader);
  assert(isa<BranchInst>(Preheader->getTerminator()) &&
         "Preheader must terminate with unconditional branch");
  assert(Preheader->getSingleSuccessor() == Header &&
         "Preheader must jump to header");

  assert(Header);
  assert(isa<BranchInst>(Header->getTerminator()) &&
         "Header must terminate with unconditional branch");
  assert(Header->getSingleSuccessor() == Cond &&
         "Header must jump to exiting block");

  assert(Cond);
  assert(Cond->getSinglePredecessor() == Header &&
         "Exiting block only reachable from header");

  assert(isa<BranchInst>(Cond->getTerminator()) &&
         "Exiting block must terminate with conditional branch");
  assert(size(successors(Cond)) == 2 &&
         "Exiting block must have two successors");
  assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(0) == Body &&
         "Exiting block's first successor must jump to the body");
  assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(1) == Exit &&
         "Exiting block's second successor must exit the loop");

  assert(Body);
  assert(Body->getSinglePredecessor() == Cond &&
         "Body only reachable from exiting block");
  assert(!isa<PHINode>(Body->front()));

  assert(Latch);
  assert(isa<BranchInst>(Latch->getTerminator()) &&
         "Latch must terminate with unconditional branch");
  assert(Latch->getSingleSuccessor() == Header && "Latch must jump to header");
  // TODO: To support simple redirecting of the end of the body code when it
  // has multiple predecessors, introduce another auxiliary basic block like
  // preheader and after.
  assert(Latch->getSinglePredecessor() != nullptr);
  assert(!isa<PHINode>(Latch->front()));

  assert(Exit);
  assert(isa<BranchInst>(Exit->getTerminator()) &&
         "Exit block must terminate with unconditional branch");
  assert(Exit->getSingleSuccessor() == After &&
         "Exit block must jump to after block");

  assert(After);
  assert(After->getSinglePredecessor() == Exit &&
         "After block only reachable from exit block");
  assert(After->empty() || !isa<PHINode>(After->front()));

  Instruction *IndVar = getIndVar();
  assert(IndVar && "Canonical induction variable not found?");
  assert(isa<IntegerType>(IndVar->getType()) &&
         "Induction variable must be an integer");
  assert(cast<PHINode>(IndVar)->getParent() == Header &&
         "Induction variable must be a PHI in the loop header");
  assert(cast<PHINode>(IndVar)->getIncomingBlock(0) == Preheader);
  assert(
      cast<ConstantInt>(cast<PHINode>(IndVar)->getIncomingValue(0))->isZero());
  assert(cast<PHINode>(IndVar)->getIncomingBlock(1) == Latch);

  auto *NextIndVar = cast<PHINode>(IndVar)->getIncomingValue(1);
  assert(cast<Instruction>(NextIndVar)->getParent() == Latch);
  assert(cast<BinaryOperator>(NextIndVar)->getOpcode() == BinaryOperator::Add);
  assert(cast<BinaryOperator>(NextIndVar)->getOperand(0) == IndVar);
  assert(cast<ConstantInt>(cast<BinaryOperator>(NextIndVar)->getOperand(1))
             ->isOne());

  Value *TripCount = getTripCount();
  assert(TripCount && "Loop trip count not found?");
  assert(IndVar->getType() == TripCount->getType() &&
         "Trip count and induction variable must have the same type");

  auto *CmpI = cast<CmpInst>(&Cond->front());
  assert(CmpI->getPredicate() == CmpInst::ICMP_ULT &&
         "Exit condition must be an unsigned less-than comparison");
  assert(CmpI->getOperand(0) == IndVar &&
         "Exit condition must compare the induction variable");
  assert(CmpI->getOperand(1) == TripCount &&
         "Exit condition must compare with the trip count");
#endif
}

void CanonicalLoopInfo::invalidate() {
  Header = nullptr;
  Cond = nullptr;
  Latch = nullptr;
  Exit = nullptr;
}