//===- OpenMPIRBuilder.cpp - Builder for LLVM-IR for OpenMP directives ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
///
/// This file implements the OpenMPIRBuilder class, which is used as a
/// convenient way to create LLVM instructions for OpenMP directives.
///
//===----------------------------------------------------------------------===//

#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Frontend/Offloading/Utility.h"
#include "llvm/Frontend/OpenMP/OMPGridValues.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Value.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"
#include "llvm/Transforms/Utils/LoopPeel.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"

#include <cstdint>
#include <optional>

#define DEBUG_TYPE "openmp-ir-builder"

using namespace llvm;
using namespace omp;

static cl::opt<bool>
    OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
                         cl::desc("Use optimistic attributes describing "
                                  "'as-if' properties of runtime calls."),
                         cl::init(false));

static cl::opt<double> UnrollThresholdFactor(
    "openmp-ir-builder-unroll-threshold-factor", cl::Hidden,
    cl::desc("Factor for the unroll threshold to account for code "
             "simplifications still taking place"),
    cl::init(1.5));
#ifndef NDEBUG
/// Return whether IP1 and IP2 are ambiguous, i.e. whether inserting
/// instructions at position IP1 may change the meaning of IP2, or vice versa.
/// This is because an InsertPoint stores the instruction before which
/// something is inserted. For instance, if both point to the same instruction,
/// two IRBuilders alternately creating instructions will cause those
/// instructions to be interleaved.
static bool isConflictIP(IRBuilder<>::InsertPoint IP1,
                         IRBuilder<>::InsertPoint IP2) {
  if (!IP1.isSet() || !IP2.isSet())
    return false;
  return IP1.getBlock() == IP2.getBlock() && IP1.getPoint() == IP2.getPoint();
}

static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
  // Valid ordered/unordered and base algorithm combinations.
  switch (SchedType & ~OMPScheduleType::MonotonicityMask) {
  case OMPScheduleType::UnorderedStaticChunked:
  case OMPScheduleType::UnorderedStatic:
  case OMPScheduleType::UnorderedDynamicChunked:
  case OMPScheduleType::UnorderedGuidedChunked:
  case OMPScheduleType::UnorderedRuntime:
  case OMPScheduleType::UnorderedAuto:
  case OMPScheduleType::UnorderedTrapezoidal:
  case OMPScheduleType::UnorderedGreedy:
  case OMPScheduleType::UnorderedBalanced:
  case OMPScheduleType::UnorderedGuidedIterativeChunked:
  case OMPScheduleType::UnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::UnorderedSteal:
  case OMPScheduleType::UnorderedStaticBalancedChunked:
  case OMPScheduleType::UnorderedGuidedSimd:
  case OMPScheduleType::UnorderedRuntimeSimd:
  case OMPScheduleType::OrderedStaticChunked:
  case OMPScheduleType::OrderedStatic:
  case OMPScheduleType::OrderedDynamicChunked:
  case OMPScheduleType::OrderedGuidedChunked:
  case OMPScheduleType::OrderedRuntime:
  case OMPScheduleType::OrderedAuto:
  case OMPScheduleType::OrderdTrapezoidal:
  case OMPScheduleType::NomergeUnorderedStaticChunked:
  case OMPScheduleType::NomergeUnorderedStatic:
  case OMPScheduleType::NomergeUnorderedDynamicChunked:
  case OMPScheduleType::NomergeUnorderedGuidedChunked:
  case OMPScheduleType::NomergeUnorderedRuntime:
  case OMPScheduleType::NomergeUnorderedAuto:
  case OMPScheduleType::NomergeUnorderedTrapezoidal:
  case OMPScheduleType::NomergeUnorderedGreedy:
  case OMPScheduleType::NomergeUnorderedBalanced:
  case OMPScheduleType::NomergeUnorderedGuidedIterativeChunked:
  case OMPScheduleType::NomergeUnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::NomergeUnorderedSteal:
  case OMPScheduleType::NomergeOrderedStaticChunked:
  case OMPScheduleType::NomergeOrderedStatic:
  case OMPScheduleType::NomergeOrderedDynamicChunked:
  case OMPScheduleType::NomergeOrderedGuidedChunked:
  case OMPScheduleType::NomergeOrderedRuntime:
  case OMPScheduleType::NomergeOrderedAuto:
  case OMPScheduleType::NomergeOrderedTrapezoidal:
    break;
  default:
    return false;
  }

  // Must not set both monotonicity modifiers at the same time.
  OMPScheduleType MonotonicityFlags =
      SchedType & OMPScheduleType::MonotonicityMask;
  if (MonotonicityFlags == OMPScheduleType::MonotonicityMask)
    return false;

  return true;
}
#endif

static const omp::GV &getGridValue(const Triple &T, Function *Kernel) {
  if (T.isAMDGPU()) {
    StringRef Features =
        Kernel->getFnAttribute("target-features").getValueAsString();
    if (Features.count("+wavefrontsize64"))
      return omp::getAMDGPUGridValues<64>();
    return omp::getAMDGPUGridValues<32>();
  }
  if (T.isNVPTX())
    return omp::NVPTXGridValues;
  llvm_unreachable("No grid value available for this architecture!");
}

/// Determine which scheduling algorithm to use from the schedule clause
/// arguments.
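/// For example (illustrative): `schedule(static, 4)` maps to
/// BaseStaticChunked, `schedule(static)` to BaseStatic, and
/// `schedule(simd: runtime)` to BaseRuntimeSimd.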
static OMPScheduleType
getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier) {
  // Currently, the default schedule is static.
  switch (ClauseKind) {
  case OMP_SCHEDULE_Default:
  case OMP_SCHEDULE_Static:
    return HasChunks ? OMPScheduleType::BaseStaticChunked
                     : OMPScheduleType::BaseStatic;
  case OMP_SCHEDULE_Dynamic:
    return OMPScheduleType::BaseDynamicChunked;
  case OMP_SCHEDULE_Guided:
    return HasSimdModifier ? OMPScheduleType::BaseGuidedSimd
                           : OMPScheduleType::BaseGuidedChunked;
  case OMP_SCHEDULE_Auto:
    return llvm::omp::OMPScheduleType::BaseAuto;
  case OMP_SCHEDULE_Runtime:
    return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
                           : OMPScheduleType::BaseRuntime;
  }
  llvm_unreachable("unhandled schedule clause argument");
}

/// Adds ordering modifier flags to the schedule type.
static OMPScheduleType
getOpenMPOrderingScheduleType(OMPScheduleType BaseScheduleType,
                              bool HasOrderedClause) {
  assert((BaseScheduleType & OMPScheduleType::ModifierMask) ==
             OMPScheduleType::None &&
         "Must not have ordering nor monotonicity flags already set");

  OMPScheduleType OrderingModifier = HasOrderedClause
                                         ? OMPScheduleType::ModifierOrdered
                                         : OMPScheduleType::ModifierUnordered;
  OMPScheduleType OrderingScheduleType = BaseScheduleType | OrderingModifier;

  // Unsupported combinations are mapped to supported equivalents; the simd
  // schedules have no ordered variant.
  if (OrderingScheduleType ==
      (OMPScheduleType::BaseGuidedSimd | OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedGuidedChunked;
  else if (OrderingScheduleType == (OMPScheduleType::BaseRuntimeSimd |
                                    OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedRuntime;

  return OrderingScheduleType;
}

/// Adds monotonicity modifier flags to the schedule type.
static OMPScheduleType
getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
                                  bool HasSimdModifier, bool HasMonotonic,
                                  bool HasNonmonotonic, bool HasOrderedClause) {
  assert((ScheduleType & OMPScheduleType::MonotonicityMask) ==
             OMPScheduleType::None &&
         "Must not have monotonicity flags already set");
  assert((!HasMonotonic || !HasNonmonotonic) &&
         "Monotonic and Nonmonotonic are contradicting each other");

  if (HasMonotonic) {
    return ScheduleType | OMPScheduleType::ModifierMonotonic;
  } else if (HasNonmonotonic) {
    return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
  } else {
    // OpenMP 5.1, 2.11.4 Worksharing-Loop Construct, Description.
    // If the static schedule kind is specified or if the ordered clause is
    // specified, and if the nonmonotonic modifier is not specified, the
    // effect is as if the monotonic modifier is specified. Otherwise, unless
    // the monotonic modifier is specified, the effect is as if the
    // nonmonotonic modifier is specified.
    OMPScheduleType BaseScheduleType =
        ScheduleType & ~OMPScheduleType::ModifierMask;
    if ((BaseScheduleType == OMPScheduleType::BaseStatic) ||
        (BaseScheduleType == OMPScheduleType::BaseStaticChunked) ||
        HasOrderedClause) {
      // Monotonic is the default in the OpenMP runtime library, so there is
      // no need to set the flag.
      return ScheduleType;
    } else {
      return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
    }
  }
}

/// Determine the schedule type using schedule and ordering clause arguments.
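/// For example (illustrative): `schedule(dynamic, 4)` without further
/// modifiers first yields BaseDynamicChunked, then gains ModifierUnordered
/// (no ordered clause), and finally ModifierNonmonotonic, so the runtime sees
/// a nonmonotonic, unordered, dynamic chunked schedule.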
static OMPScheduleType
computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier, bool HasMonotonicModifier,
                          bool HasNonmonotonicModifier, bool HasOrderedClause) {
  OMPScheduleType BaseSchedule =
      getOpenMPBaseScheduleType(ClauseKind, HasChunks, HasSimdModifier);
  OMPScheduleType OrderedSchedule =
      getOpenMPOrderingScheduleType(BaseSchedule, HasOrderedClause);
  OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
      OrderedSchedule, HasSimdModifier, HasMonotonicModifier,
      HasNonmonotonicModifier, HasOrderedClause);

  assert(isValidWorkshareLoopScheduleType(Result));
  return Result;
}

/// Make \p Source branch to \p Target.
///
/// Handles two situations:
/// * \p Source already has an unconditional branch.
/// * \p Source is a degenerate block (no terminator because the BB is
///             the current head of the IR construction).
static void redirectTo(BasicBlock *Source, BasicBlock *Target, DebugLoc DL) {
  if (Instruction *Term = Source->getTerminator()) {
    auto *Br = cast<BranchInst>(Term);
    assert(!Br->isConditional() &&
           "BB's terminator must be an unconditional branch (or degenerate)");
    BasicBlock *Succ = Br->getSuccessor(0);
    Succ->removePredecessor(Source, /*KeepOneInputPHIs=*/true);
    Br->setSuccessor(0, Target);
    return;
  }

  auto *NewBr = BranchInst::Create(Target, Source);
  NewBr->setDebugLoc(DL);
}

void llvm::spliceBB(IRBuilderBase::InsertPoint IP, BasicBlock *New,
                    bool CreateBranch) {
  assert(New->getFirstInsertionPt() == New->begin() &&
         "Target BB must not have PHI nodes");

  // Move instructions to the new block.
  BasicBlock *Old = IP.getBlock();
  New->splice(New->begin(), Old, IP.getPoint(), Old->end());

  if (CreateBranch)
    BranchInst::Create(New, Old);
}

void llvm::spliceBB(IRBuilder<> &Builder, BasicBlock *New, bool CreateBranch) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *Old = Builder.GetInsertBlock();

  spliceBB(Builder.saveIP(), New, CreateBranch);
  if (CreateBranch)
    Builder.SetInsertPoint(Old->getTerminator());
  else
    Builder.SetInsertPoint(Old);

  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
}

BasicBlock *llvm::splitBB(IRBuilderBase::InsertPoint IP, bool CreateBranch,
                          llvm::Twine Name) {
  BasicBlock *Old = IP.getBlock();
  BasicBlock *New = BasicBlock::Create(
      Old->getContext(), Name.isTriviallyEmpty() ? Old->getName() : Name,
      Old->getParent(), Old->getNextNode());
  spliceBB(IP, New, CreateBranch);
  New->replaceSuccessorsPhiUsesWith(Old, New);
  return New;
}

BasicBlock *llvm::splitBB(IRBuilderBase &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}

BasicBlock *llvm::splitBB(IRBuilder<> &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}

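// Example (illustrative): with the insertion point inside a block %region,
// splitBBWithSuffix(Builder, /*CreateBranch=*/true, ".split") creates a new
// block %region.split holding everything after the insertion point, and
// leaves the Builder positioned before the branch that now terminates
// %region.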
BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
                                    llvm::Twine Suffix) {
  BasicBlock *Old = Builder.GetInsertBlock();
  return splitBB(Builder, CreateBranch, Old->getName() + Suffix);
}

// This function creates a fake integer value and a fake use for the integer
// value. It returns the fake value created. This is useful in modeling the
// extra arguments to the outlined functions.
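// For instance (illustrative): with AsPtr=true the returned value is the
// alloca itself and the fake use is a load of it; with AsPtr=false the value
// is a load of the alloca and the fake use is an add. Everything created here
// is pushed onto ToBeDeleted so the caller can erase it after outlining.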
Value *createFakeIntVal(IRBuilder<> &Builder,
                        OpenMPIRBuilder::InsertPointTy OuterAllocaIP,
                        std::stack<Instruction *> &ToBeDeleted,
                        OpenMPIRBuilder::InsertPointTy InnerAllocaIP,
                        const Twine &Name = "", bool AsPtr = true) {
  Builder.restoreIP(OuterAllocaIP);
  Instruction *FakeVal;
  AllocaInst *FakeValAddr =
      Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, Name + ".addr");
  ToBeDeleted.push(FakeValAddr);

  if (AsPtr) {
    FakeVal = FakeValAddr;
  } else {
    FakeVal =
        Builder.CreateLoad(Builder.getInt32Ty(), FakeValAddr, Name + ".val");
    ToBeDeleted.push(FakeVal);
  }

  // Generate a fake use of this value.
  Builder.restoreIP(InnerAllocaIP);
  Instruction *UseFakeVal;
  if (AsPtr) {
    UseFakeVal =
        Builder.CreateLoad(Builder.getInt32Ty(), FakeVal, Name + ".use");
  } else {
    UseFakeVal =
        cast<BinaryOperator>(Builder.CreateAdd(FakeVal, Builder.getInt32(10)));
  }
  ToBeDeleted.push(UseFakeVal);
  return FakeVal;
}

//===----------------------------------------------------------------------===//
// OpenMPIRBuilderConfig
//===----------------------------------------------------------------------===//

namespace {
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
/// Values for bit flags for marking which requires clauses have been used.
enum OpenMPOffloadingRequiresDirFlags {
  /// flag undefined.
  OMP_REQ_UNDEFINED = 0x000,
  /// no requires directive present.
  OMP_REQ_NONE = 0x001,
  /// reverse_offload clause.
  OMP_REQ_REVERSE_OFFLOAD = 0x002,
  /// unified_address clause.
  OMP_REQ_UNIFIED_ADDRESS = 0x004,
  /// unified_shared_memory clause.
  OMP_REQ_UNIFIED_SHARED_MEMORY = 0x008,
  /// dynamic_allocators clause.
  OMP_REQ_DYNAMIC_ALLOCATORS = 0x010,
  LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS)
};

} // anonymous namespace

OpenMPIRBuilderConfig::OpenMPIRBuilderConfig()
    : RequiresFlags(OMP_REQ_UNDEFINED) {}

OpenMPIRBuilderConfig::OpenMPIRBuilderConfig(
    bool IsTargetDevice, bool IsGPU, bool OpenMPOffloadMandatory,
    bool HasRequiresReverseOffload, bool HasRequiresUnifiedAddress,
    bool HasRequiresUnifiedSharedMemory, bool HasRequiresDynamicAllocators)
    : IsTargetDevice(IsTargetDevice), IsGPU(IsGPU),
      OpenMPOffloadMandatory(OpenMPOffloadMandatory),
      RequiresFlags(OMP_REQ_UNDEFINED) {
  if (HasRequiresReverseOffload)
    RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
  if (HasRequiresUnifiedAddress)
    RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
  if (HasRequiresUnifiedSharedMemory)
    RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
  if (HasRequiresDynamicAllocators)
    RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
}

bool OpenMPIRBuilderConfig::hasRequiresReverseOffload() const {
  return RequiresFlags & OMP_REQ_REVERSE_OFFLOAD;
}

bool OpenMPIRBuilderConfig::hasRequiresUnifiedAddress() const {
  return RequiresFlags & OMP_REQ_UNIFIED_ADDRESS;
}

bool OpenMPIRBuilderConfig::hasRequiresUnifiedSharedMemory() const {
  return RequiresFlags & OMP_REQ_UNIFIED_SHARED_MEMORY;
}

bool OpenMPIRBuilderConfig::hasRequiresDynamicAllocators() const {
  return RequiresFlags & OMP_REQ_DYNAMIC_ALLOCATORS;
}

int64_t OpenMPIRBuilderConfig::getRequiresFlags() const {
  return hasRequiresFlags() ? RequiresFlags
                            : static_cast<int64_t>(OMP_REQ_NONE);
}

void OpenMPIRBuilderConfig::setHasRequiresReverseOffload(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
  else
    RequiresFlags &= ~OMP_REQ_REVERSE_OFFLOAD;
}

void OpenMPIRBuilderConfig::setHasRequiresUnifiedAddress(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
  else
    RequiresFlags &= ~OMP_REQ_UNIFIED_ADDRESS;
}

void OpenMPIRBuilderConfig::setHasRequiresUnifiedSharedMemory(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
  else
    RequiresFlags &= ~OMP_REQ_UNIFIED_SHARED_MEMORY;
}

void OpenMPIRBuilderConfig::setHasRequiresDynamicAllocators(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
  else
    RequiresFlags &= ~OMP_REQ_DYNAMIC_ALLOCATORS;
}

//===----------------------------------------------------------------------===//
// OpenMPIRBuilder
//===----------------------------------------------------------------------===//

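// (Illustrative) The vector built below mirrors, in order, the fields of the
// runtime's kernel-arguments structure as used by this builder; note that
// NumTeams and NumThreads are each widened to a three-element i32 array with
// the unused dimensions zeroed.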
void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
                                          IRBuilderBase &Builder,
                                          SmallVector<Value *> &ArgsVector) {
  Value *Version = Builder.getInt32(OMP_KERNEL_ARG_VERSION);
  Value *PointerNum = Builder.getInt32(KernelArgs.NumTargetItems);
  auto Int32Ty = Type::getInt32Ty(Builder.getContext());
  Value *ZeroArray = Constant::getNullValue(ArrayType::get(Int32Ty, 3));
  Value *Flags = Builder.getInt64(KernelArgs.HasNoWait);

  Value *NumTeams3D =
      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumTeams, {0});
  Value *NumThreads3D =
      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumThreads, {0});

  ArgsVector = {Version,
                PointerNum,
                KernelArgs.RTArgs.BasePointersArray,
                KernelArgs.RTArgs.PointersArray,
                KernelArgs.RTArgs.SizesArray,
                KernelArgs.RTArgs.MapTypesArray,
                KernelArgs.RTArgs.MapNamesArray,
                KernelArgs.RTArgs.MappersArray,
                KernelArgs.NumIterations,
                Flags,
                NumTeams3D,
                NumThreads3D,
                KernelArgs.DynCGGroupMem};
}

void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
  LLVMContext &Ctx = Fn.getContext();

  // Get the function's current attributes.
  auto Attrs = Fn.getAttributes();
  auto FnAttrs = Attrs.getFnAttrs();
  auto RetAttrs = Attrs.getRetAttrs();
  SmallVector<AttributeSet, 4> ArgAttrs;
  for (size_t ArgNo = 0; ArgNo < Fn.arg_size(); ++ArgNo)
    ArgAttrs.emplace_back(Attrs.getParamAttrs(ArgNo));

  // Add AS to FnAS while taking special care with integer extensions.
  auto addAttrSet = [&](AttributeSet &FnAS, const AttributeSet &AS,
                        bool Param = true) -> void {
    bool HasSignExt = AS.hasAttribute(Attribute::SExt);
    bool HasZeroExt = AS.hasAttribute(Attribute::ZExt);
    if (HasSignExt || HasZeroExt) {
      assert(AS.getNumAttributes() == 1 &&
             "Currently not handling extension attr combined with others.");
      if (Param) {
        if (auto AK = TargetLibraryInfo::getExtAttrForI32Param(T, HasSignExt))
          FnAS = FnAS.addAttribute(Ctx, AK);
      } else if (auto AK =
                     TargetLibraryInfo::getExtAttrForI32Return(T, HasSignExt))
        FnAS = FnAS.addAttribute(Ctx, AK);
    } else {
      FnAS = FnAS.addAttributes(Ctx, AS);
    }
  };

#define OMP_ATTRS_SET(VarName, AttrSet) AttributeSet VarName = AttrSet;
#include "llvm/Frontend/OpenMP/OMPKinds.def"

  // Add attributes to the function declaration.
  switch (FnID) {
#define OMP_RTL_ATTRS(Enum, FnAttrSet, RetAttrSet, ArgAttrSets)                \
  case Enum:                                                                   \
    FnAttrs = FnAttrs.addAttributes(Ctx, FnAttrSet);                           \
    addAttrSet(RetAttrs, RetAttrSet, /*Param*/ false);                         \
    for (size_t ArgNo = 0; ArgNo < ArgAttrSets.size(); ++ArgNo)                \
      addAttrSet(ArgAttrs[ArgNo], ArgAttrSets[ArgNo]);                         \
    Fn.setAttributes(AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs));    \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    // Attributes are optional.
    break;
  }
}

FunctionCallee
OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
  FunctionType *FnTy = nullptr;
  Function *Fn = nullptr;

  // Try to find the declaration in the module first.
  switch (FnID) {
#define OMP_RTL(Enum, Str, IsVarArg, ReturnType, ...)                          \
  case Enum:                                                                   \
    FnTy = FunctionType::get(ReturnType, ArrayRef<Type *>{__VA_ARGS__},        \
                             IsVarArg);                                        \
    Fn = M.getFunction(Str);                                                   \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  if (!Fn) {
    // Create a new declaration if we need one.
    switch (FnID) {
#define OMP_RTL(Enum, Str, ...)                                                \
  case Enum:                                                                   \
    Fn = Function::Create(FnTy, GlobalValue::ExternalLinkage, Str, M);         \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
    }

    // Add information if the runtime function takes a callback function.
    if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
      if (!Fn->hasMetadata(LLVMContext::MD_callback)) {
        LLVMContext &Ctx = Fn->getContext();
        MDBuilder MDB(Ctx);
        // Annotate the callback behavior of the runtime function:
        //  - The callback callee is argument number 2 (microtask).
        //  - The first two arguments of the callback callee are unknown (-1).
        //  - All variadic arguments to the runtime function are passed to the
        //    callback callee.
        Fn->addMetadata(
            LLVMContext::MD_callback,
            *MDNode::get(Ctx, {MDB.createCallbackEncoding(
                                  2, {-1, -1}, /* VarArgsArePassed */ true)}));
      }
    }

    LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
    addAttributes(FnID, *Fn);

  } else {
    LLVM_DEBUG(dbgs() << "Found OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
  }

  assert(Fn && "Failed to create OpenMP runtime function");

  return {FnTy, Fn};
}

Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
  FunctionCallee RTLFn = getOrCreateRuntimeFunction(M, FnID);
  auto *Fn = dyn_cast<llvm::Function>(RTLFn.getCallee());
  assert(Fn && "Failed to create OpenMP runtime function pointer");
  return Fn;
}

void OpenMPIRBuilder::initialize() { initializeTypes(M); }

void OpenMPIRBuilder::finalize(Function *Fn) {
  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  SmallVector<OutlineInfo, 16> DeferredOutlines;
  for (OutlineInfo &OI : OutlineInfos) {
    // Skip functions that have not been finalized yet; this may happen with
    // nested function generation.
    if (Fn && OI.getFunction() != Fn) {
      DeferredOutlines.push_back(OI);
      continue;
    }

    ParallelRegionBlockSet.clear();
    Blocks.clear();
    OI.collectBlocks(ParallelRegionBlockSet, Blocks);

    Function *OuterFn = OI.getFunction();
    CodeExtractorAnalysisCache CEAC(*OuterFn);
    // If we generate code for the target device, we need to allocate the
    // struct for aggregate params in the device default alloca address space.
    // The OpenMP runtime requires that the params of the extracted functions
    // are passed as zero address space pointers. This flag ensures that
    // CodeExtractor generates correct code for extracted functions
    // which are used by the OpenMP runtime.
    bool ArgsInZeroAddressSpace = Config.isTargetDevice();
    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                            /* AggregateArgs */ true,
                            /* BlockFrequencyInfo */ nullptr,
                            /* BranchProbabilityInfo */ nullptr,
                            /* AssumptionCache */ nullptr,
                            /* AllowVarArgs */ true,
                            /* AllowAlloca */ true,
                            /* AllocaBlock*/ OI.OuterAllocaBB,
                            /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);

    LLVM_DEBUG(dbgs() << "Before     outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
                      << " Exit: " << OI.ExitBB->getName() << "\n");
    assert(Extractor.isEligible() &&
           "Expected OpenMP outlining to be possible!");

    for (auto *V : OI.ExcludeArgsFromAggregate)
      Extractor.excludeArgFromAggregate(V);

    Function *OutlinedFn = Extractor.extractCodeRegion(CEAC);

    LLVM_DEBUG(dbgs() << "After      outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "   Outlined function: " << *OutlinedFn << "\n");
    assert(OutlinedFn->getReturnType()->isVoidTy() &&
           "OpenMP outlined functions should not return a value!");

    // For compatibility with the clang CG we move the outlined function after
    // the one with the parallel region.
    OutlinedFn->removeFromParent();
    M.getFunctionList().insertAfter(OuterFn->getIterator(), OutlinedFn);

    // Remove the artificial entry introduced by the extractor right away; we
    // made our own entry block after all.
    {
      BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
      assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
      assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
      // Move instructions from the to-be-deleted ArtificialEntry to the entry
      // basic block of the parallel region. CodeExtractor generates
      // instructions to unwrap the aggregate argument and may sink
      // allocas/bitcasts for values that are solely used in the outlined
      // region and do not escape.
      assert(!ArtificialEntry.empty() &&
             "Expected instructions to add in the outlined region entry");
      for (BasicBlock::reverse_iterator It = ArtificialEntry.rbegin(),
                                        End = ArtificialEntry.rend();
           It != End;) {
        Instruction &I = *It;
        It++;

        if (I.isTerminator())
          continue;

        I.moveBeforePreserving(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
      }

      OI.EntryBB->moveBefore(&ArtificialEntry);
      ArtificialEntry.eraseFromParent();
    }
    assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
    assert(OutlinedFn && OutlinedFn->getNumUses() == 1);

    // Run a user callback, e.g. to add attributes.
    if (OI.PostOutlineCB)
      OI.PostOutlineCB(*OutlinedFn);
  }

  // Remove work items that have been completed.
  OutlineInfos = std::move(DeferredOutlines);

  EmitMetadataErrorReportFunctionTy &&ErrorReportFn =
      [](EmitMetadataErrorKind Kind,
         const TargetRegionEntryInfo &EntryInfo) -> void {
    errs() << "Error of kind: " << Kind
           << " when emitting offload entries and metadata during "
              "OMPIRBuilder finalization\n";
  };

  if (!OffloadInfoManager.empty())
    createOffloadEntriesAndInfoMetadata(ErrorReportFn);
}

OpenMPIRBuilder::~OpenMPIRBuilder() {
  assert(OutlineInfos.empty() && "There must be no outstanding outlinings");
}

GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
  IntegerType *I32Ty = Type::getInt32Ty(M.getContext());
  auto *GV =
      new GlobalVariable(M, I32Ty,
                         /* isConstant = */ true, GlobalValue::WeakODRLinkage,
                         ConstantInt::get(I32Ty, Value), Name);
  GV->setVisibility(GlobalValue::HiddenVisibility);

  return GV;
}

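// (Illustrative sketch of the initializer built below) the ident_t value has
// the layout { i32 0, i32 flags, i32 reserved, i32 loc-string-size,
// ptr loc-string }, which is what the KMPC runtime entry points consume.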
Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
                                            uint32_t SrcLocStrSize,
                                            IdentFlag LocFlags,
                                            unsigned Reserve2Flags) {
  // Enable "C-mode".
  LocFlags |= OMP_IDENT_FLAG_KMPC;

  Constant *&Ident =
      IdentMap[{SrcLocStr, uint64_t(LocFlags) << 31 | Reserve2Flags}];
  if (!Ident) {
    Constant *I32Null = ConstantInt::getNullValue(Int32);
    Constant *IdentData[] = {I32Null,
                             ConstantInt::get(Int32, uint32_t(LocFlags)),
                             ConstantInt::get(Int32, Reserve2Flags),
                             ConstantInt::get(Int32, SrcLocStrSize), SrcLocStr};
    Constant *Initializer =
        ConstantStruct::get(OpenMPIRBuilder::Ident, IdentData);

    // Look for an existing encoding of the location + flags; not needed but
    // minimizes the difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.getValueType() == OpenMPIRBuilder::Ident && GV.hasInitializer())
        if (GV.getInitializer() == Initializer)
          Ident = &GV;

    if (!Ident) {
      auto *GV = new GlobalVariable(
          M, OpenMPIRBuilder::Ident,
          /* isConstant = */ true, GlobalValue::PrivateLinkage, Initializer, "",
          nullptr, GlobalValue::NotThreadLocal,
          M.getDataLayout().getDefaultGlobalsAddressSpace());
      GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
      GV->setAlignment(Align(8));
      Ident = GV;
    }
  }

  return ConstantExpr::getPointerBitCastOrAddrSpaceCast(Ident, IdentPtr);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef LocStr,
                                                uint32_t &SrcLocStrSize) {
  SrcLocStrSize = LocStr.size();
  Constant *&SrcLocStr = SrcLocStrMap[LocStr];
  if (!SrcLocStr) {
    Constant *Initializer =
        ConstantDataArray::getString(M.getContext(), LocStr);

    // Look for an existing encoding of the location; not needed but minimizes
    // the difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.isConstant() && GV.hasInitializer() &&
          GV.getInitializer() == Initializer)
        return SrcLocStr = ConstantExpr::getPointerCast(&GV, Int8Ptr);

    SrcLocStr = Builder.CreateGlobalStringPtr(LocStr, /* Name */ "",
                                              /* AddressSpace */ 0, &M);
  }
  return SrcLocStr;
}

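// Example (illustrative): a location in function `foo` at line 3, column 7 of
// file `bar.c` is encoded below as ";bar.c;foo;3;7;;".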
Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef FunctionName,
                                                StringRef FileName,
                                                unsigned Line, unsigned Column,
                                                uint32_t &SrcLocStrSize) {
  SmallString<128> Buffer;
  Buffer.push_back(';');
  Buffer.append(FileName);
  Buffer.push_back(';');
  Buffer.append(FunctionName);
  Buffer.push_back(';');
  Buffer.append(std::to_string(Line));
  Buffer.push_back(';');
  Buffer.append(std::to_string(Column));
  Buffer.push_back(';');
  Buffer.push_back(';');
  return getOrCreateSrcLocStr(Buffer.str(), SrcLocStrSize);
}

Constant *
OpenMPIRBuilder::getOrCreateDefaultSrcLocStr(uint32_t &SrcLocStrSize) {
  StringRef UnknownLoc = ";unknown;unknown;0;0;;";
  return getOrCreateSrcLocStr(UnknownLoc, SrcLocStrSize);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(DebugLoc DL,
                                                uint32_t &SrcLocStrSize,
                                                Function *F) {
  DILocation *DIL = DL.get();
  if (!DIL)
    return getOrCreateDefaultSrcLocStr(SrcLocStrSize);
  StringRef FileName = M.getName();
  if (DIFile *DIF = DIL->getFile())
    if (std::optional<StringRef> Source = DIF->getSource())
      FileName = *Source;
  StringRef Function = DIL->getScope()->getSubprogram()->getName();
  if (Function.empty() && F)
    Function = F->getName();
  return getOrCreateSrcLocStr(Function, FileName, DIL->getLine(),
                              DIL->getColumn(), SrcLocStrSize);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(const LocationDescription &Loc,
                                                uint32_t &SrcLocStrSize) {
  return getOrCreateSrcLocStr(Loc.DL, SrcLocStrSize,
                              Loc.IP.getBlock()->getParent());
}

Value *OpenMPIRBuilder::getOrCreateThreadID(Value *Ident) {
  return Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num), Ident,
      "omp_global_thread_num");
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createBarrier(const LocationDescription &Loc, Directive DK,
                               bool ForceSimpleCall, bool CheckCancelFlag) {
  if (!updateToLocation(Loc))
    return Loc.IP;
  return emitBarrierImpl(Loc, DK, ForceSimpleCall, CheckCancelFlag);
}

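// (Illustrative) the emitted IR is a single runtime call such as
//   %r = call i32 @__kmpc_cancel_barrier(ptr @loc, i32 %thread_id)
// followed, in cancellable regions, by a branch on the returned flag.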
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::emitBarrierImpl(const LocationDescription &Loc, Directive Kind,
                                 bool ForceSimpleCall, bool CheckCancelFlag) {
  // Build call __kmpc_cancel_barrier(loc, thread_id) or
  //            __kmpc_barrier(loc, thread_id);

  IdentFlag BarrierLocFlags;
  switch (Kind) {
  case OMPD_for:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_FOR;
    break;
  case OMPD_sections:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SECTIONS;
    break;
  case OMPD_single:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SINGLE;
    break;
  case OMPD_barrier:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_EXPL;
    break;
  default:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL;
    break;
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Args[] = {
      getOrCreateIdent(SrcLocStr, SrcLocStrSize, BarrierLocFlags),
      getOrCreateThreadID(getOrCreateIdent(SrcLocStr, SrcLocStrSize))};

  // If we are in a cancellable parallel region, barriers are cancellation
  // points.
  // TODO: Check why we would force simple calls or to ignore the cancel flag.
  bool UseCancelBarrier =
      !ForceSimpleCall && isLastFinalizationInfoCancellable(OMPD_parallel);

  Value *Result =
      Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
                             UseCancelBarrier ? OMPRTL___kmpc_cancel_barrier
                                              : OMPRTL___kmpc_barrier),
                         Args);

  if (UseCancelBarrier && CheckCancelFlag)
    emitCancelationCheckImpl(Result, OMPD_parallel);

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createCancel(const LocationDescription &Loc,
                              Value *IfCondition,
                              omp::Directive CanceledDirective) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // LLVM utilities like blocks with terminators.
  auto *UI = Builder.CreateUnreachable();

  Instruction *ThenTI = UI, *ElseTI = nullptr;
  if (IfCondition)
    SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
  Builder.SetInsertPoint(ThenTI);

  Value *CancelKind = nullptr;
  switch (CanceledDirective) {
#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
  case DirectiveEnum:                                                          \
    CancelKind = Builder.getInt32(Value);                                      \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    llvm_unreachable("Unknown cancel kind!");
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
  Value *Result = Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancel), Args);
  auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) {
    if (CanceledDirective == OMPD_parallel) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                    omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
                    /* CheckCancelFlag */ false);
    }
  };

  // The actual cancel logic is shared with others, e.g., cancel_barriers.
  emitCancelationCheckImpl(Result, CanceledDirective, ExitCB);

  // Update the insertion point and remove the terminator we introduced.
  Builder.SetInsertPoint(UI->getParent());
  UI->eraseFromParent();

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
    const LocationDescription &Loc, InsertPointTy AllocaIP, Value *&Return,
    Value *Ident, Value *DeviceID, Value *NumTeams, Value *NumThreads,
    Value *HostPtr, ArrayRef<Value *> KernelArgs) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  Builder.restoreIP(AllocaIP);
  auto *KernelArgsPtr =
      Builder.CreateAlloca(OpenMPIRBuilder::KernelArgs, nullptr, "kernel_args");
  Builder.restoreIP(Loc.IP);

  for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) {
    llvm::Value *Arg =
        Builder.CreateStructGEP(OpenMPIRBuilder::KernelArgs, KernelArgsPtr, I);
    Builder.CreateAlignedStore(
        KernelArgs[I], Arg,
        M.getDataLayout().getPrefTypeAlign(KernelArgs[I]->getType()));
  }

  SmallVector<Value *> OffloadingArgs{Ident,      DeviceID, NumTeams,
                                      NumThreads, HostPtr,  KernelArgsPtr};

  Return = Builder.CreateCall(
      getOrCreateRuntimeFunction(M, OMPRTL___tgt_target_kernel),
      OffloadingArgs);

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
    const LocationDescription &Loc, Function *OutlinedFn, Value *OutlinedFnID,
    EmitFallbackCallbackTy emitTargetCallFallbackCB, TargetKernelArgs &Args,
    Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  Builder.restoreIP(Loc.IP);
  // On top of the arrays that were filled up, the target offloading call
  // takes as arguments the device id as well as the host pointer. The host
  // pointer is used by the runtime library to identify the current target
  // region, so it only has to be unique and not necessarily point to
  // anything. It could be the pointer to the outlined function that
  // implements the target region, but we aren't using that so that the
  // compiler doesn't need to keep that, and could therefore inline the host
  // function if proven worthwhile during optimization.

  // From this point on, we need to have an ID of the target region defined.
  assert(OutlinedFnID && "Invalid outlined function ID!");
  (void)OutlinedFnID;

  // Return value of the runtime offloading call.
  Value *Return = nullptr;

  // Arguments for the target kernel.
  SmallVector<Value *> ArgsVector;
  getKernelArgsVector(Args, Builder, ArgsVector);

  // The target region is an outlined function launched by the runtime
  // via calls to __tgt_target_kernel().
  //
  // Note that on the host and CPU targets, the runtime implementation of
  // these calls simply call the outlined function without forking threads.
  // The outlined functions themselves have runtime calls to
  // __kmpc_fork_teams() and __kmpc_fork() for this purpose, codegen'd by
  // the compiler in emitTeamsCall() and emitParallelCall().
  //
  // In contrast, on the NVPTX target, the implementation of
  // __tgt_target_teams() launches a GPU kernel with the requested number
  // of teams and threads so no additional calls to the runtime are required.
  // Check the error code and execute the host version if required.
  Builder.restoreIP(emitTargetKernel(Builder, AllocaIP, Return, RTLoc, DeviceID,
                                     Args.NumTeams, Args.NumThreads,
                                     OutlinedFnID, ArgsVector));

  BasicBlock *OffloadFailedBlock =
      BasicBlock::Create(Builder.getContext(), "omp_offload.failed");
  BasicBlock *OffloadContBlock =
      BasicBlock::Create(Builder.getContext(), "omp_offload.cont");
  Value *Failed = Builder.CreateIsNotNull(Return);
  Builder.CreateCondBr(Failed, OffloadFailedBlock, OffloadContBlock);

  auto CurFn = Builder.GetInsertBlock()->getParent();
  emitBlock(OffloadFailedBlock, CurFn);
  Builder.restoreIP(emitTargetCallFallbackCB(Builder.saveIP()));
  emitBranch(OffloadContBlock);
  emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
  return Builder.saveIP();
}

void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
                                               omp::Directive CanceledDirective,
                                               FinalizeCallbackTy ExitCB) {
  assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
         "Unexpected cancellation!");

  // For a cancel barrier we create two new blocks.
  BasicBlock *BB = Builder.GetInsertBlock();
  BasicBlock *NonCancellationBlock;
  if (Builder.GetInsertPoint() == BB->end()) {
    // TODO: This branch will not be needed once we have moved to the
    // OpenMPIRBuilder codegen completely.
    NonCancellationBlock = BasicBlock::Create(
        BB->getContext(), BB->getName() + ".cont", BB->getParent());
  } else {
    NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint());
    BB->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(BB);
  }
  BasicBlock *CancellationBlock = BasicBlock::Create(
      BB->getContext(), BB->getName() + ".cncl", BB->getParent());

  // Jump to them based on the return value.
  Value *Cmp = Builder.CreateIsNull(CancelFlag);
  Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock,
                       /* TODO weight */ nullptr, nullptr);

  // From the cancellation block we finalize all variables and go to the
  // post finalization block that is known to the FiniCB callback.
  Builder.SetInsertPoint(CancellationBlock);
  if (ExitCB)
    ExitCB(Builder.saveIP());
  auto &FI = FinalizationStack.back();
  FI.FiniCB(Builder.saveIP());

  // The continuation block is where code generation continues.
  Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
}

// Callback used to create OpenMP runtime calls to support the
// omp parallel clause for the device.
// We use this callback to replace the call to the OutlinedFn in OuterFn
// with a call to the OpenMP DeviceRTL runtime function (kmpc_parallel_51).
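// (Illustrative; names such as %ident are placeholders) the generated
// replacement call has the shape:
//   call void @__kmpc_parallel_51(ptr %ident, i32 %tid, i32 %if_cond,
//                                 i32 %num_threads, i32 -1, ptr @outlined,
//                                 ptr null, ptr %args, i64 %nargs)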
static void targetParallelCallback(
    OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, Function *OuterFn,
    BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
    Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
    Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
  // Add some known attributes.
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  OutlinedFn.addParamAttr(0, Attribute::NoAlias);
  OutlinedFn.addParamAttr(1, Attribute::NoAlias);
  OutlinedFn.addParamAttr(0, Attribute::NoUndef);
  OutlinedFn.addParamAttr(1, Attribute::NoUndef);
  OutlinedFn.addFnAttr(Attribute::NoUnwind);

  assert(OutlinedFn.arg_size() >= 2 &&
         "Expected at least tid and bounded tid as arguments");
  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

  CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
  assert(CI && "Expected call instruction to outlined function");
  CI->getParent()->setName("omp_parallel");

  Builder.SetInsertPoint(CI);
  Type *PtrTy = OMPIRBuilder->VoidPtr;
  Value *NullPtrValue = Constant::getNullValue(PtrTy);

  // Add an alloca for the kernel args.
  OpenMPIRBuilder::InsertPointTy CurrentIP = Builder.saveIP();
  Builder.SetInsertPoint(OuterAllocaBB, OuterAllocaBB->getFirstInsertionPt());
  AllocaInst *ArgsAlloca =
      Builder.CreateAlloca(ArrayType::get(PtrTy, NumCapturedVars));
  Value *Args = ArgsAlloca;
  // Add an address space cast if the array for storing arguments is not
  // allocated in address space 0.
  if (ArgsAlloca->getAddressSpace())
    Args = Builder.CreatePointerCast(ArgsAlloca, PtrTy);
  Builder.restoreIP(CurrentIP);

  // Store the captured vars which are used by kmpc_parallel_51.
  for (unsigned Idx = 0; Idx < NumCapturedVars; Idx++) {
    Value *V = *(CI->arg_begin() + 2 + Idx);
    Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64(
        ArrayType::get(PtrTy, NumCapturedVars), Args, 0, Idx);
    Builder.CreateStore(V, StoreAddress);
  }

  Value *Cond =
      IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
                  : Builder.getInt32(1);

  // Build the kmpc_parallel_51 call.
  Value *Parallel51CallArgs[] = {
      /* identifier */ Ident,
      /* global thread num */ ThreadID,
      /* if expression */ Cond,
      /* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
      /* proc bind */ Builder.getInt32(-1),
      /* outlined function */
      Builder.CreateBitCast(&OutlinedFn, OMPIRBuilder->ParallelTaskPtr),
      /* wrapper function */ NullPtrValue,
      /* arguments of the outlined function */ Args,
      /* number of arguments */ Builder.getInt64(NumCapturedVars)};

  FunctionCallee RTLFn =
      OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_parallel_51);

  Builder.CreateCall(RTLFn, Parallel51CallArgs);

  LLVM_DEBUG(dbgs() << "With kmpc_parallel_51 placed: "
                    << *Builder.GetInsertBlock()->getParent() << "\n");

  // Initialize the local TID stack location with the argument value.
  Builder.SetInsertPoint(PrivTID);
  Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
  Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
                      PrivTIDAddr);

  // Remove the redundant call to the outlined function.
  CI->eraseFromParent();

  for (Instruction *I : ToBeDeleted) {
    I->eraseFromParent();
  }
}

// Callback used to create OpenMP runtime calls to support the
// omp parallel clause for the host.
// We use this callback to replace the call to the OutlinedFn in OuterFn
// with a call to the OpenMP host runtime function (__kmpc_fork_call[_if]).
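// (Illustrative; names such as %ident are placeholders) without an if-clause
// the replacement call has the shape:
//   call void (ptr, i32, ptr, ...)
//       @__kmpc_fork_call(ptr %ident, i32 %nargs, ptr @outlined, ...)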
1203 static void
1204 hostParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
1205                      Function *OuterFn, Value *Ident, Value *IfCondition,
1206                      Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1207                      const SmallVector<Instruction *, 4> &ToBeDeleted) {
1208   IRBuilder<> &Builder = OMPIRBuilder->Builder;
1209   FunctionCallee RTLFn;
1210   if (IfCondition) {
1211     RTLFn =
1212         OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call_if);
1213   } else {
1214     RTLFn =
1215         OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
1216   }
1217   if (auto *F = dyn_cast<Function>(RTLFn.getCallee())) {
1218     if (!F->hasMetadata(LLVMContext::MD_callback)) {
1219       LLVMContext &Ctx = F->getContext();
1220       MDBuilder MDB(Ctx);
1221       // Annotate the callback behavior of the __kmpc_fork_call:
1222       //  - The callback callee is argument number 2 (microtask).
1223       //  - The first two arguments of the callback callee are unknown (-1).
1224       //  - All variadic arguments to the __kmpc_fork_call are passed to the
1225       //    callback callee.
1226       F->addMetadata(LLVMContext::MD_callback,
1227                      *MDNode::get(Ctx, {MDB.createCallbackEncoding(
1228                                            2, {-1, -1},
1229                                            /* VarArgsArePassed */ true)}));
1230     }
1231   }
1232   // Add some known attributes.
1233   OutlinedFn.addParamAttr(0, Attribute::NoAlias);
1234   OutlinedFn.addParamAttr(1, Attribute::NoAlias);
1235   OutlinedFn.addFnAttr(Attribute::NoUnwind);
1236 
1237   assert(OutlinedFn.arg_size() >= 2 &&
1238          "Expected at least tid and bounded tid as arguments");
1239   unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1240 
1241   CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
1242   CI->getParent()->setName("omp_parallel");
1243   Builder.SetInsertPoint(CI);
1244 
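  // For reference, the host entry points used below have roughly these shapes
  // (a sketch; the declarations in OMPKinds.def are authoritative):
  //   void __kmpc_fork_call(ident_t *loc, kmp_int32 argc,
  //                         kmpc_micro microtask, ...);
  //   void __kmpc_fork_call_if(ident_t *loc, kmp_int32 argc,
  //                            kmpc_micro microtask, kmp_int32 cond,
  //                            void *args);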
1245   // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
1246   Value *ForkCallArgs[] = {
1247       Ident, Builder.getInt32(NumCapturedVars),
1248       Builder.CreateBitCast(&OutlinedFn, OMPIRBuilder->ParallelTaskPtr)};
1249 
1250   SmallVector<Value *, 16> RealArgs;
1251   RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
1252   if (IfCondition) {
1253     Value *Cond = Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32);
1254     RealArgs.push_back(Cond);
1255   }
1256   RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());
1257 
1258   // __kmpc_fork_call_if always expects a void ptr as the last argument.
1259   // If there are no arguments, pass a null pointer.
1260   auto PtrTy = OMPIRBuilder->VoidPtr;
1261   if (IfCondition && NumCapturedVars == 0) {
1262     Value *NullPtrValue = Constant::getNullValue(PtrTy);
1263     RealArgs.push_back(NullPtrValue);
1264   }
1265   if (IfCondition && RealArgs.back()->getType() != PtrTy)
1266     RealArgs.back() = Builder.CreateBitCast(RealArgs.back(), PtrTy);
1267 
1268   Builder.CreateCall(RTLFn, RealArgs);
1269 
1270   LLVM_DEBUG(dbgs() << "With fork_call placed: "
1271                     << *Builder.GetInsertBlock()->getParent() << "\n");
1272 
1273   // Initialize the local TID stack location with the argument value.
1274   Builder.SetInsertPoint(PrivTID);
1275   Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
1276   Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
1277                       PrivTIDAddr);
1278 
1279   // Remove redundant call to the outlined function.
1280   CI->eraseFromParent();
1281 
1282   for (Instruction *I : ToBeDeleted) {
1283     I->eraseFromParent();
1284   }
1285 }
1286 
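// Rough illustration of the transformation performed by createParallel for
//   #pragma omp parallel
//   { <body> }
// on the host (a sketch, not the exact IR): the region's blocks are outlined
// into a function such as
//   define internal void @foo..omp_par(ptr %tid.addr, ptr %zero.addr, ...)
// and the inlined region is replaced by
//   call void (...) @__kmpc_fork_call(ptr @ident, i32 <n>,
//                                     ptr @foo..omp_par, <captured args...>)
// On the device, __kmpc_parallel_51 is called instead, with the captured
// arguments passed through a pointer array (see targetParallelCallback above).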
1287 IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
1288     const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
1289     BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
1290     FinalizeCallbackTy FiniCB, Value *IfCondition, Value *NumThreads,
1291     omp::ProcBindKind ProcBind, bool IsCancellable) {
1292   assert(!isConflictIP(Loc.IP, OuterAllocaIP) && "IPs must not be ambiguous");
1293 
1294   if (!updateToLocation(Loc))
1295     return Loc.IP;
1296 
1297   uint32_t SrcLocStrSize;
1298   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1299   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1300   Value *ThreadID = getOrCreateThreadID(Ident);
1301   // If we generate code for the target device, we need to allocate the
1302   // struct for aggregate params in the device's default alloca address space.
1303   // The OpenMP runtime requires that the params of the extracted functions
1304   // are passed as zero address space pointers. This flag ensures that the
1305   // extracted function arguments are declared in the zero address space.
1306   bool ArgsInZeroAddressSpace = Config.isTargetDevice();
1307 
1308   // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
1309   // only if we compile for host side.
1310   if (NumThreads && !Config.isTargetDevice()) {
1311     Value *Args[] = {
1312         Ident, ThreadID,
1313         Builder.CreateIntCast(NumThreads, Int32, /*isSigned*/ false)};
1314     Builder.CreateCall(
1315         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_threads), Args);
1316   }
1317 
1318   if (ProcBind != OMP_PROC_BIND_default) {
1319     // Build call __kmpc_push_proc_bind(&Ident, global_tid, proc_bind)
1320     Value *Args[] = {
1321         Ident, ThreadID,
1322         ConstantInt::get(Int32, unsigned(ProcBind), /*isSigned=*/true)};
1323     Builder.CreateCall(
1324         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_proc_bind), Args);
1325   }
1326 
1327   BasicBlock *InsertBB = Builder.GetInsertBlock();
1328   Function *OuterFn = InsertBB->getParent();
1329 
1330   // Save the outer alloca block because the insertion iterator may get
1331   // invalidated and we still need this later.
1332   BasicBlock *OuterAllocaBlock = OuterAllocaIP.getBlock();
1333 
1334   // Vector to remember instructions we used only during the modeling but which
1335   // we want to delete at the end.
1336   SmallVector<Instruction *, 4> ToBeDeleted;
1337 
1338   // Change the location to the outer alloca insertion point to create and
1339   // initialize the allocas we pass into the parallel region.
1340   Builder.restoreIP(OuterAllocaIP);
1341   AllocaInst *TIDAddrAlloca = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
1342   AllocaInst *ZeroAddrAlloca =
1343       Builder.CreateAlloca(Int32, nullptr, "zero.addr");
1344   Instruction *TIDAddr = TIDAddrAlloca;
1345   Instruction *ZeroAddr = ZeroAddrAlloca;
1346   if (ArgsInZeroAddressSpace && M.getDataLayout().getAllocaAddrSpace() != 0) {
1347     // Add additional casts to enforce pointers in zero address space
1348     TIDAddr = new AddrSpaceCastInst(
1349         TIDAddrAlloca, PointerType::get(M.getContext(), 0), "tid.addr.ascast");
1350     TIDAddr->insertAfter(TIDAddrAlloca);
1351     ToBeDeleted.push_back(TIDAddr);
1352     ZeroAddr = new AddrSpaceCastInst(ZeroAddrAlloca,
1353                                      PointerType::get(M.getContext(), 0),
1354                                      "zero.addr.ascast");
1355     ZeroAddr->insertAfter(ZeroAddrAlloca);
1356     ToBeDeleted.push_back(ZeroAddr);
1357   }
1358 
1359   // We only need TIDAddr and ZeroAddr for modeling purposes to get the
1360   // associated arguments in the outlined function, so we delete them later.
1361   ToBeDeleted.push_back(TIDAddrAlloca);
1362   ToBeDeleted.push_back(ZeroAddrAlloca);
1363 
1364   // Create an artificial insertion point that will also ensure the blocks we
1365   // are about to split are not degenerate.
1366   auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);
1367 
1368   BasicBlock *EntryBB = UI->getParent();
1369   BasicBlock *PRegEntryBB = EntryBB->splitBasicBlock(UI, "omp.par.entry");
1370   BasicBlock *PRegBodyBB = PRegEntryBB->splitBasicBlock(UI, "omp.par.region");
1371   BasicBlock *PRegPreFiniBB =
1372       PRegBodyBB->splitBasicBlock(UI, "omp.par.pre_finalize");
1373   BasicBlock *PRegExitBB = PRegPreFiniBB->splitBasicBlock(UI, "omp.par.exit");
1374 
1375   auto FiniCBWrapper = [&](InsertPointTy IP) {
1376     // Hide "open-ended" blocks from the given FiniCB by setting the right jump
1377     // target to the region exit block.
1378     if (IP.getBlock()->end() == IP.getPoint()) {
1379       IRBuilder<>::InsertPointGuard IPG(Builder);
1380       Builder.restoreIP(IP);
1381       Instruction *I = Builder.CreateBr(PRegExitBB);
1382       IP = InsertPointTy(I->getParent(), I->getIterator());
1383     }
1384     assert(IP.getBlock()->getTerminator()->getNumSuccessors() == 1 &&
1385            IP.getBlock()->getTerminator()->getSuccessor(0) == PRegExitBB &&
1386            "Unexpected insertion point for finalization call!");
1387     return FiniCB(IP);
1388   };
1389 
1390   FinalizationStack.push_back({FiniCBWrapper, OMPD_parallel, IsCancellable});
1391 
1392   // Generate the privatization allocas in the block that will become the entry
1393   // of the outlined function.
1394   Builder.SetInsertPoint(PRegEntryBB->getTerminator());
1395   InsertPointTy InnerAllocaIP = Builder.saveIP();
1396 
1397   AllocaInst *PrivTIDAddr =
1398       Builder.CreateAlloca(Int32, nullptr, "tid.addr.local");
1399   Instruction *PrivTID = Builder.CreateLoad(Int32, PrivTIDAddr, "tid");
1400 
1401   // Add some fake uses for OpenMP provided arguments.
1402   ToBeDeleted.push_back(Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use"));
1403   Instruction *ZeroAddrUse =
1404       Builder.CreateLoad(Int32, ZeroAddr, "zero.addr.use");
1405   ToBeDeleted.push_back(ZeroAddrUse);
1406 
1407   // EntryBB
1408   //   |
1409   //   V
1410   // PRegionEntryBB         <- Privatization allocas are placed here.
1411   //   |
1412   //   V
1413   // PRegionBodyBB          <- BodyGen is invoked here.
1414   //   |
1415   //   V
1416   // PRegPreFiniBB          <- The block we will start finalization from.
1417   //   |
1418   //   V
1419   // PRegionExitBB          <- A common exit to simplify block collection.
1420   //
1421 
1422   LLVM_DEBUG(dbgs() << "Before body codegen: " << *OuterFn << "\n");
1423 
1424   // Let the caller create the body.
1425   assert(BodyGenCB && "Expected body generation callback!");
1426   InsertPointTy CodeGenIP(PRegBodyBB, PRegBodyBB->begin());
1427   BodyGenCB(InnerAllocaIP, CodeGenIP);
1428 
1429   LLVM_DEBUG(dbgs() << "After  body codegen: " << *OuterFn << "\n");
1430 
1431   OutlineInfo OI;
1432   if (Config.isTargetDevice()) {
1433     // Generate OpenMP target specific runtime call
1434     OI.PostOutlineCB = [=, ToBeDeletedVec =
1435                                std::move(ToBeDeleted)](Function &OutlinedFn) {
1436       targetParallelCallback(this, OutlinedFn, OuterFn, OuterAllocaBlock, Ident,
1437                              IfCondition, NumThreads, PrivTID, PrivTIDAddr,
1438                              ThreadID, ToBeDeletedVec);
1439     };
1440   } else {
1441     // Generate OpenMP host runtime call
1442     OI.PostOutlineCB = [=, ToBeDeletedVec =
1443                                std::move(ToBeDeleted)](Function &OutlinedFn) {
1444       hostParallelCallback(this, OutlinedFn, OuterFn, Ident, IfCondition,
1445                            PrivTID, PrivTIDAddr, ToBeDeletedVec);
1446     };
1447   }
1448 
1449   // Adjust the finalization stack, verify the adjustment, and call the
1450   // finalize function one last time to finalize values between the pre-fini
1451   // block and the exit block if we left the parallel region "the normal way".
1452   auto FiniInfo = FinalizationStack.pop_back_val();
1453   (void)FiniInfo;
1454   assert(FiniInfo.DK == OMPD_parallel &&
1455          "Unexpected finalization stack state!");
1456 
1457   Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();
1458 
1459   InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
1460   FiniCB(PreFiniIP);
1461 
1462   OI.OuterAllocaBB = OuterAllocaBlock;
1463   OI.EntryBB = PRegEntryBB;
1464   OI.ExitBB = PRegExitBB;
1465 
1466   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
1467   SmallVector<BasicBlock *, 32> Blocks;
1468   OI.collectBlocks(ParallelRegionBlockSet, Blocks);
1469 
1470   // Ensure a single exit node for the outlined region by creating one.
1471   // We might have multiple incoming edges to the exit now due to finalizations,
1472   // e.g., cancel calls that cause the control flow to leave the region.
1473   BasicBlock *PRegOutlinedExitBB = PRegExitBB;
1474   PRegExitBB = SplitBlock(PRegExitBB, &*PRegExitBB->getFirstInsertionPt());
1475   PRegOutlinedExitBB->setName("omp.par.outlined.exit");
1476   Blocks.push_back(PRegOutlinedExitBB);
1477 
1478   CodeExtractorAnalysisCache CEAC(*OuterFn);
1479   CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
1480                           /* AggregateArgs */ false,
1481                           /* BlockFrequencyInfo */ nullptr,
1482                           /* BranchProbabilityInfo */ nullptr,
1483                           /* AssumptionCache */ nullptr,
1484                           /* AllowVarArgs */ true,
1485                           /* AllowAlloca */ true,
1486                           /* AllocationBlock */ OuterAllocaBlock,
1487                           /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
1488 
1489   // Find inputs to, outputs from the code region.
1490   BasicBlock *CommonExit = nullptr;
1491   SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
1492   Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
1493   Extractor.findInputsOutputs(Inputs, Outputs, SinkingCands);
1494 
1495   LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");
1496 
1497   FunctionCallee TIDRTLFn =
1498       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);
1499 
1500   auto PrivHelper = [&](Value &V) {
1501     if (&V == TIDAddr || &V == ZeroAddr) {
1502       OI.ExcludeArgsFromAggregate.push_back(&V);
1503       return;
1504     }
1505 
1506     SetVector<Use *> Uses;
1507     for (Use &U : V.uses())
1508       if (auto *UserI = dyn_cast<Instruction>(U.getUser()))
1509         if (ParallelRegionBlockSet.count(UserI->getParent()))
1510           Uses.insert(&U);
1511 
1512     // __kmpc_fork_call expects extra arguments as pointers. If the input
1513     // already has a pointer type, everything is fine. Otherwise, store the
1514     // value onto stack and load it back inside the to-be-outlined region. This
1515     // will ensure only the pointer will be passed to the function.
1516     // FIXME: if there are more than 15 trailing arguments, they must be
1517     // additionally packed in a struct.
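    // For example, a captured i32 %x is forwarded roughly as follows (a
    // sketch; %x.reloaded and %x.inner are illustrative names):
    //   %x.reloaded = alloca i32                        ; at OuterAllocaIP
    //   store i32 %x, ptr %x.reloaded                   ; before the region
    //   %x.inner = load i32, ptr %x.reloaded            ; at InnerAllocaIP
    // so that only the pointer %x.reloaded is captured by the outliner.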
1518     Value *Inner = &V;
1519     if (!V.getType()->isPointerTy()) {
1520       IRBuilder<>::InsertPointGuard Guard(Builder);
1521       LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");
1522 
1523       Builder.restoreIP(OuterAllocaIP);
1524       Value *Ptr =
1525           Builder.CreateAlloca(V.getType(), nullptr, V.getName() + ".reloaded");
1526 
1527       // Store to stack at end of the block that currently branches to the entry
1528       // block of the to-be-outlined region.
1529       Builder.SetInsertPoint(InsertBB,
1530                              InsertBB->getTerminator()->getIterator());
1531       Builder.CreateStore(&V, Ptr);
1532 
1533       // Load back next to allocations in the to-be-outlined region.
1534       Builder.restoreIP(InnerAllocaIP);
1535       Inner = Builder.CreateLoad(V.getType(), Ptr);
1536     }
1537 
1538     Value *ReplacementValue = nullptr;
1539     CallInst *CI = dyn_cast<CallInst>(&V);
1540     if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
1541       ReplacementValue = PrivTID;
1542     } else {
1543       Builder.restoreIP(
1544           PrivCB(InnerAllocaIP, Builder.saveIP(), V, *Inner, ReplacementValue));
1545       assert(ReplacementValue &&
1546              "Expected copy/create callback to set replacement value!");
1547       if (ReplacementValue == &V)
1548         return;
1549     }
1550 
1551     for (Use *UPtr : Uses)
1552       UPtr->set(ReplacementValue);
1553   };
1554 
1555   // Reset the inner alloca insertion as it will be used for loading the values
1556   // wrapped into pointers before passing them into the to-be-outlined region.
1557   // Configure it to insert immediately after the fake use of zero address so
1558   // that they are available in the generated body and so that the
1559   // OpenMP-related values (thread ID and zero address pointers) remain leading
1560   // in the argument list.
1561   InnerAllocaIP = IRBuilder<>::InsertPoint(
1562       ZeroAddrUse->getParent(), ZeroAddrUse->getNextNode()->getIterator());
1563 
1564   // Reset the outer alloca insertion point to the entry of the relevant block
1565   // in case it was invalidated.
1566   OuterAllocaIP = IRBuilder<>::InsertPoint(
1567       OuterAllocaBlock, OuterAllocaBlock->getFirstInsertionPt());
1568 
1569   for (Value *Input : Inputs) {
1570     LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
1571     PrivHelper(*Input);
1572   }
1573   LLVM_DEBUG({
1574     for (Value *Output : Outputs)
1575       dbgs() << "Captured output: " << *Output << "\n";
1576   });
1577   assert(Outputs.empty() &&
1578          "OpenMP outlining should not produce live-out values!");
1579 
1580   LLVM_DEBUG(dbgs() << "After  privatization: " << *OuterFn << "\n");
1581   LLVM_DEBUG({
1582     for (auto *BB : Blocks)
1583       dbgs() << " PBR: " << BB->getName() << "\n";
1584   });
1585 
1586   // Register the outlined info.
1587   addOutlineInfo(std::move(OI));
1588 
1589   InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
1590   UI->eraseFromParent();
1591 
1592   return AfterIP;
1593 }
1594 
1595 void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
1596   // Build call void __kmpc_flush(ident_t *loc)
1597   uint32_t SrcLocStrSize;
1598   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1599   Value *Args[] = {getOrCreateIdent(SrcLocStr, SrcLocStrSize)};
1600 
1601   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_flush), Args);
1602 }
1603 
1604 void OpenMPIRBuilder::createFlush(const LocationDescription &Loc) {
1605   if (!updateToLocation(Loc))
1606     return;
1607   emitFlush(Loc);
1608 }
1609 
1610 void OpenMPIRBuilder::emitTaskwaitImpl(const LocationDescription &Loc) {
1611   // Build call kmp_int32 __kmpc_omp_taskwait(ident_t *loc, kmp_int32
1612   // global_tid);
1613   uint32_t SrcLocStrSize;
1614   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1615   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1616   Value *Args[] = {Ident, getOrCreateThreadID(Ident)};
1617 
1618   // Ignore return result until untied tasks are supported.
1619   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskwait),
1620                      Args);
1621 }
1622 
1623 void OpenMPIRBuilder::createTaskwait(const LocationDescription &Loc) {
1624   if (!updateToLocation(Loc))
1625     return;
1626   emitTaskwaitImpl(Loc);
1627 }
1628 
1629 void OpenMPIRBuilder::emitTaskyieldImpl(const LocationDescription &Loc) {
1630   // Build call __kmpc_omp_taskyield(loc, thread_id, 0);
1631   uint32_t SrcLocStrSize;
1632   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1633   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1634   Constant *I32Null = ConstantInt::getNullValue(Int32);
1635   Value *Args[] = {Ident, getOrCreateThreadID(Ident), I32Null};
1636 
1637   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskyield),
1638                      Args);
1639 }
1640 
1641 void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
1642   if (!updateToLocation(Loc))
1643     return;
1644   emitTaskyieldImpl(Loc);
1645 }
1646 
1647 OpenMPIRBuilder::InsertPointTy
1648 OpenMPIRBuilder::createTask(const LocationDescription &Loc,
1649                             InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
1650                             bool Tied, Value *Final, Value *IfCondition,
1651                             SmallVector<DependData> Dependencies) {
1652 
1653   if (!updateToLocation(Loc))
1654     return InsertPointTy();
1655 
1656   uint32_t SrcLocStrSize;
1657   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1658   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1659   // The current basic block is split into four basic blocks. After outlining,
1660   // they will be mapped as follows:
1661   // ```
1662   // def current_fn() {
1663   //   current_basic_block:
1664   //     br label %task.exit
1665   //   task.exit:
1666   //     ; instructions after task
1667   // }
1668   // def outlined_fn() {
1669   //   task.alloca:
1670   //     br label %task.body
1671   //   task.body:
1672   //     ret void
1673   // }
1674   // ```
1675   BasicBlock *TaskExitBB = splitBB(Builder, /*CreateBranch=*/true, "task.exit");
1676   BasicBlock *TaskBodyBB = splitBB(Builder, /*CreateBranch=*/true, "task.body");
1677   BasicBlock *TaskAllocaBB =
1678       splitBB(Builder, /*CreateBranch=*/true, "task.alloca");
1679 
1680   InsertPointTy TaskAllocaIP =
1681       InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
1682   InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
1683   BodyGenCB(TaskAllocaIP, TaskBodyIP);
1684 
1685   OutlineInfo OI;
1686   OI.EntryBB = TaskAllocaBB;
1687   OI.OuterAllocaBB = AllocaIP.getBlock();
1688   OI.ExitBB = TaskExitBB;
1689 
1690   // Add the thread ID argument.
1691   std::stack<Instruction *> ToBeDeleted;
1692   OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
1693       Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false));
1694 
1695   OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
1696                       TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) mutable {
1697     // Replace the stale CI with the appropriate RTL function call.
1698     assert(OutlinedFn.getNumUses() == 1 &&
1699            "there must be a single user for the outlined function");
1700     CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
1701 
1702     // HasShareds is true if any variables are captured in the outlined region,
1703     // false otherwise.
1704     bool HasShareds = StaleCI->arg_size() > 1;
1705     Builder.SetInsertPoint(StaleCI);
1706 
1707     // Gather the arguments for emitting the runtime call for
1708     // @__kmpc_omp_task_alloc
1709     Function *TaskAllocFn =
1710         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
1711 
1712     // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
1713     // call.
1714     Value *ThreadID = getOrCreateThreadID(Ident);
1715 
1716     // Argument - `flags`
1717     // Task is tied iff (Flags & 1) == 1.
1718     // Task is untied iff (Flags & 1) == 0.
1719     // Task is final iff (Flags & 2) == 2.
1720     // Task is not final iff (Flags & 2) == 0.
1721     // TODO: Handle the other flags.
1722     Value *Flags = Builder.getInt32(Tied);
1723     if (Final) {
1724       Value *FinalFlag =
1725           Builder.CreateSelect(Final, Builder.getInt32(2), Builder.getInt32(0));
1726       Flags = Builder.CreateOr(FinalFlag, Flags);
1727     }
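    // Worked example: a tied task whose 'final' clause evaluates to true ends
    // up with Flags == (1 | 2) == 3; an untied task without 'final' gets 0.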
1728 
1729     // Argument - `sizeof_kmp_task_t` (TaskSize)
1730     // TaskSize refers to the size in bytes of the kmp_task_t data structure,
1731     // including private vars accessed in the task.
1732     // TODO: add kmp_task_t_with_privates (privates)
1733     Value *TaskSize = Builder.getInt64(
1734         divideCeil(M.getDataLayout().getTypeSizeInBits(Task), 8));
1735 
1736     // Argument - `sizeof_shareds` (SharedsSize)
1737     // SharedsSize refers to the shareds array size in the kmp_task_t data
1738     // structure.
1739     Value *SharedsSize = Builder.getInt64(0);
1740     if (HasShareds) {
1741       AllocaInst *ArgStructAlloca =
1742           dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
1743       assert(ArgStructAlloca &&
1744              "Unable to find the alloca instruction corresponding to arguments "
1745              "for extracted function");
1746       StructType *ArgStructType =
1747           dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
1748       assert(ArgStructType && "Unable to find struct type corresponding to "
1749                               "arguments for extracted function");
1750       SharedsSize =
1751           Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
1752     }
1753     // Emit the @__kmpc_omp_task_alloc runtime call. It returns a pointer to
1754     // an area (TaskData) into which the task's captured variables must be
1755     // copied before the task is run.
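    // For reference, the allocator has roughly this shape (a sketch; kmp.h in
    // the OpenMP runtime is authoritative):
    //   kmp_task_t *__kmpc_omp_task_alloc(ident_t *loc, kmp_int32 gtid,
    //                                     kmp_int32 flags,
    //                                     size_t sizeof_kmp_task_t,
    //                                     size_t sizeof_shareds,
    //                                     kmp_routine_entry_t task_entry);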
1756     CallInst *TaskData = Builder.CreateCall(
1757         TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
1758                       /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
1759                       /*task_func=*/&OutlinedFn});
1760 
1761     // Copy the arguments for outlined function
1762     if (HasShareds) {
1763       Value *Shareds = StaleCI->getArgOperand(1);
1764       Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
1765       Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
1766       Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
1767                            SharedsSize);
1768     }
1769 
1770     Value *DepArray = nullptr;
1771     if (Dependencies.size()) {
1772       InsertPointTy OldIP = Builder.saveIP();
1773       Builder.SetInsertPoint(
1774           &OldIP.getBlock()->getParent()->getEntryBlock().back());
1775 
1776       Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
1777       DepArray = Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
1778 
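      // Each element is laid out as a kmp_depend_info record, roughly (a
      // sketch mirroring RTLDependInfoFields):
      //   struct kmp_depend_info { kmp_intptr_t base_addr; size_t len;
      //                            kmp_uint8 flags; };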
1779       unsigned P = 0;
1780       for (const DependData &Dep : Dependencies) {
1781         Value *Base =
1782             Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, P);
1783         // Store the pointer to the variable
1784         Value *Addr = Builder.CreateStructGEP(
1785             DependInfo, Base,
1786             static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
1787         Value *DepValPtr =
1788             Builder.CreatePtrToInt(Dep.DepVal, Builder.getInt64Ty());
1789         Builder.CreateStore(DepValPtr, Addr);
1790         // Store the size of the variable
1791         Value *Size = Builder.CreateStructGEP(
1792             DependInfo, Base,
1793             static_cast<unsigned int>(RTLDependInfoFields::Len));
1794         Builder.CreateStore(Builder.getInt64(M.getDataLayout().getTypeStoreSize(
1795                                 Dep.DepValueType)),
1796                             Size);
1797         // Store the dependency kind
1798         Value *Flags = Builder.CreateStructGEP(
1799             DependInfo, Base,
1800             static_cast<unsigned int>(RTLDependInfoFields::Flags));
1801         Builder.CreateStore(
1802             ConstantInt::get(Builder.getInt8Ty(),
1803                              static_cast<unsigned int>(Dep.DepKind)),
1804             Flags);
1805         ++P;
1806       }
1807 
1808       Builder.restoreIP(OldIP);
1809     }
1810 
1811     // In the presence of the `if` clause, the following IR is generated:
1812     //    ...
1813     //    %data = call @__kmpc_omp_task_alloc(...)
1814     //    br i1 %if_condition, label %then, label %else
1815     //  then:
1816     //    call @__kmpc_omp_task(...)
1817     //    br label %exit
1818     //  else:
1819     //    call @__kmpc_omp_task_begin_if0(...)
1820     //    call @outlined_fn(...)
1821     //    call @__kmpc_omp_task_complete_if0(...)
1822     //    br label %exit
1823     //  exit:
1824     //    ...
1825     if (IfCondition) {
1826       // `SplitBlockAndInsertIfThenElse` requires the block to have a
1827       // terminator.
1828       splitBB(Builder, /*CreateBranch=*/true, "if.end");
1829       Instruction *IfTerminator =
1830           Builder.GetInsertPoint()->getParent()->getTerminator();
1831       Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
1832       Builder.SetInsertPoint(IfTerminator);
1833       SplitBlockAndInsertIfThenElse(IfCondition, IfTerminator, &ThenTI,
1834                                     &ElseTI);
1835       Builder.SetInsertPoint(ElseTI);
1836       Function *TaskBeginFn =
1837           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
1838       Function *TaskCompleteFn =
1839           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
1840       Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
1841       CallInst *CI = nullptr;
1842       if (HasShareds)
1843         CI = Builder.CreateCall(&OutlinedFn, {ThreadID, TaskData});
1844       else
1845         CI = Builder.CreateCall(&OutlinedFn, {ThreadID});
1846       CI->setDebugLoc(StaleCI->getDebugLoc());
1847       Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
1848       Builder.SetInsertPoint(ThenTI);
1849     }
1850 
1851     if (Dependencies.size()) {
1852       Function *TaskFn =
1853           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
1854       Builder.CreateCall(
1855           TaskFn,
1856           {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
1857            DepArray, ConstantInt::get(Builder.getInt32Ty(), 0),
1858            ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
1859 
1860     } else {
1861       // Emit the @__kmpc_omp_task runtime call to spawn the task
1862       Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
1863       Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData});
1864     }
1865 
1866     StaleCI->eraseFromParent();
1867 
1868     Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin());
1869     if (HasShareds) {
1870       LoadInst *Shareds = Builder.CreateLoad(VoidPtr, OutlinedFn.getArg(1));
1871       OutlinedFn.getArg(1)->replaceUsesWithIf(
1872           Shareds, [Shareds](Use &U) { return U.getUser() != Shareds; });
1873     }
1874 
1875     while (!ToBeDeleted.empty()) {
1876       ToBeDeleted.top()->eraseFromParent();
1877       ToBeDeleted.pop();
1878     }
1879   };
1880 
1881   addOutlineInfo(std::move(OI));
1882   Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin());
1883 
1884   return Builder.saveIP();
1885 }
1886 
1887 OpenMPIRBuilder::InsertPointTy
1888 OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
1889                                  InsertPointTy AllocaIP,
1890                                  BodyGenCallbackTy BodyGenCB) {
1891   if (!updateToLocation(Loc))
1892     return InsertPointTy();
1893 
1894   uint32_t SrcLocStrSize;
1895   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1896   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1897   Value *ThreadID = getOrCreateThreadID(Ident);
1898 
1899   // Emit the @__kmpc_taskgroup runtime call to start the taskgroup
1900   Function *TaskgroupFn =
1901       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup);
1902   Builder.CreateCall(TaskgroupFn, {Ident, ThreadID});
1903 
1904   BasicBlock *TaskgroupExitBB = splitBB(Builder, true, "taskgroup.exit");
1905   BodyGenCB(AllocaIP, Builder.saveIP());
1906 
1907   Builder.SetInsertPoint(TaskgroupExitBB);
1908   // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
1909   Function *EndTaskgroupFn =
1910       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup);
1911   Builder.CreateCall(EndTaskgroupFn, {Ident, ThreadID});
1912 
1913   return Builder.saveIP();
1914 }
1915 
1916 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSections(
1917     const LocationDescription &Loc, InsertPointTy AllocaIP,
1918     ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
1919     FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) {
1920   assert(!isConflictIP(AllocaIP, Loc.IP) && "Dedicated IP allocas required");
1921 
1922   if (!updateToLocation(Loc))
1923     return Loc.IP;
1924 
1925   auto FiniCBWrapper = [&](InsertPointTy IP) {
1926     if (IP.getBlock()->end() != IP.getPoint())
1927       return FiniCB(IP);
1928     // This must be done; otherwise, any nested constructs using
1929     // FinalizeOMPRegion will fail because that function requires the
1930     // finalization basic block to have a terminator, which was already
1931     // removed by EmitOMPRegionBody.
1932     // IP is currently at the cancellation block.
1933     // We need to backtrack to the condition block to fetch the exit block
1934     // and create a branch from the cancellation block to the exit block.
1935     IRBuilder<>::InsertPointGuard IPG(Builder);
1936     Builder.restoreIP(IP);
1937     auto *CaseBB = IP.getBlock()->getSinglePredecessor();
1938     auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
1939     auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
1940     Instruction *I = Builder.CreateBr(ExitBB);
1941     IP = InsertPointTy(I->getParent(), I->getIterator());
1942     return FiniCB(IP);
1943   };
1944 
1945   FinalizationStack.push_back({FiniCBWrapper, OMPD_sections, IsCancellable});
1946 
1947   // Each section is emitted as a switch case
1948   // Each finalization callback is handled from clang.EmitOMPSectionDirective()
1949   // -> OMP.createSection() which generates the IR for each section
1950   // Iterate through all sections and emit a switch construct:
1951   // switch (IV) {
1952   //   case 0:
1953   //     <SectionStmt[0]>;
1954   //     break;
1955   // ...
1956   //   case <NumSection> - 1:
1957   //     <SectionStmt[<NumSection> - 1]>;
1958   //     break;
1959   // }
1960   // ...
1961   // section_loop.after:
1962   // <FiniCB>;
1963   auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, Value *IndVar) {
1964     Builder.restoreIP(CodeGenIP);
1965     BasicBlock *Continue =
1966         splitBBWithSuffix(Builder, /*CreateBranch=*/false, ".sections.after");
1967     Function *CurFn = Continue->getParent();
1968     SwitchInst *SwitchStmt = Builder.CreateSwitch(IndVar, Continue);
1969 
1970     unsigned CaseNumber = 0;
1971     for (auto SectionCB : SectionCBs) {
1972       BasicBlock *CaseBB = BasicBlock::Create(
1973           M.getContext(), "omp_section_loop.body.case", CurFn, Continue);
1974       SwitchStmt->addCase(Builder.getInt32(CaseNumber), CaseBB);
1975       Builder.SetInsertPoint(CaseBB);
1976       BranchInst *CaseEndBr = Builder.CreateBr(Continue);
1977       SectionCB(InsertPointTy(),
1978                 {CaseEndBr->getParent(), CaseEndBr->getIterator()});
1979       CaseNumber++;
1980     }
1981     // remove the existing terminator from body BB since there can be no
1982     // terminators after switch/case
1983   };
1984   // Loop body ends here.
1985   // LowerBound, UpperBound, and Stride for createCanonicalLoop.
1986   Type *I32Ty = Type::getInt32Ty(M.getContext());
1987   Value *LB = ConstantInt::get(I32Ty, 0);
1988   Value *UB = ConstantInt::get(I32Ty, SectionCBs.size());
1989   Value *ST = ConstantInt::get(I32Ty, 1);
1990   llvm::CanonicalLoopInfo *LoopInfo = createCanonicalLoop(
1991       Loc, LoopBodyGenCB, LB, UB, ST, true, false, AllocaIP, "section_loop");
1992   InsertPointTy AfterIP =
1993       applyStaticWorkshareLoop(Loc.DL, LoopInfo, AllocaIP, !IsNowait);
1994 
1995   // Apply the finalization callback in LoopAfterBB
1996   auto FiniInfo = FinalizationStack.pop_back_val();
1997   assert(FiniInfo.DK == OMPD_sections &&
1998          "Unexpected finalization stack state!");
1999   if (FinalizeCallbackTy &CB = FiniInfo.FiniCB) {
2000     Builder.restoreIP(AfterIP);
2001     BasicBlock *FiniBB =
2002         splitBBWithSuffix(Builder, /*CreateBranch=*/true, "sections.fini");
2003     CB(Builder.saveIP());
2004     AfterIP = {FiniBB, FiniBB->begin()};
2005   }
2006 
2007   return AfterIP;
2008 }
2009 
2010 OpenMPIRBuilder::InsertPointTy
2011 OpenMPIRBuilder::createSection(const LocationDescription &Loc,
2012                                BodyGenCallbackTy BodyGenCB,
2013                                FinalizeCallbackTy FiniCB) {
2014   if (!updateToLocation(Loc))
2015     return Loc.IP;
2016 
2017   auto FiniCBWrapper = [&](InsertPointTy IP) {
2018     if (IP.getBlock()->end() != IP.getPoint())
2019       return FiniCB(IP);
2020     // This must be done; otherwise, any nested constructs using
2021     // FinalizeOMPRegion will fail because that function requires the
2022     // finalization basic block to have a terminator, which was already
2023     // removed by EmitOMPRegionBody.
2024     // IP is currently at the cancellation block.
2025     // We need to backtrack to the condition block to fetch the exit block
2026     // and create a branch from the cancellation block to the exit block.
2027     IRBuilder<>::InsertPointGuard IPG(Builder);
2028     Builder.restoreIP(IP);
2029     auto *CaseBB = Loc.IP.getBlock();
2030     auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
2031     auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
2032     Instruction *I = Builder.CreateBr(ExitBB);
2033     IP = InsertPointTy(I->getParent(), I->getIterator());
2034     return FiniCB(IP);
2035   };
2036 
2037   Directive OMPD = Directive::OMPD_sections;
2038   // Since we are using Finalization Callback here, HasFinalize
2039   // and IsCancellable have to be true
2040   return EmitOMPInlinedRegion(OMPD, nullptr, nullptr, BodyGenCB, FiniCBWrapper,
2041                               /*Conditional*/ false, /*hasFinalize*/ true,
2042                               /*IsCancellable*/ true);
2043 }
2044 
2045 /// Create a function with a unique name and a "void (i8*, i8*)" signature in
2046 /// the given module and return it.
2047 Function *getFreshReductionFunc(Module &M) {
2048   Type *VoidTy = Type::getVoidTy(M.getContext());
2049   Type *Int8PtrTy = PointerType::getUnqual(M.getContext());
2050   auto *FuncTy =
2051       FunctionType::get(VoidTy, {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ false);
2052   return Function::Create(FuncTy, GlobalVariable::InternalLinkage,
2053                           M.getDataLayout().getDefaultGlobalsAddressSpace(),
2054                           ".omp.reduction.func", &M);
2055 }
2056 
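// Rough shape of the dispatch emitted by createReductions below (a sketch):
//   %res = call i32 @__kmpc_reduce[_nowait](ptr @ident, i32 %tid, i32 <n>,
//                                           i64 <array size>, ptr %red.array,
//                                           ptr @.omp.reduction.func, ptr @lock)
//   switch i32 %res, label %reduce.finalize [
//     i32 1, label %reduce.switch.nonatomic
//     i32 2, label %reduce.switch.atomic ]
// The non-atomic path reduces element-wise and calls __kmpc_end_reduce[_nowait];
// the atomic path uses the per-reduction atomic generators, when available.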
2057 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions(
2058     const LocationDescription &Loc, InsertPointTy AllocaIP,
2059     ArrayRef<ReductionInfo> ReductionInfos, bool IsNoWait) {
2060   for (const ReductionInfo &RI : ReductionInfos) {
2061     (void)RI;
2062     assert(RI.Variable && "expected non-null variable");
2063     assert(RI.PrivateVariable && "expected non-null private variable");
2064     assert(RI.ReductionGen && "expected non-null reduction generator callback");
2065     assert(RI.Variable->getType() == RI.PrivateVariable->getType() &&
2066            "expected variables and their private equivalents to have the same "
2067            "type");
2068     assert(RI.Variable->getType()->isPointerTy() &&
2069            "expected variables to be pointers");
2070   }
2071 
2072   if (!updateToLocation(Loc))
2073     return InsertPointTy();
2074 
2075   BasicBlock *InsertBlock = Loc.IP.getBlock();
2076   BasicBlock *ContinuationBlock =
2077       InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
2078   InsertBlock->getTerminator()->eraseFromParent();
2079 
2080   // Create and populate array of type-erased pointers to private reduction
2081   // values.
2082   unsigned NumReductions = ReductionInfos.size();
2083   Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions);
2084   Builder.restoreIP(AllocaIP);
2085   Value *RedArray = Builder.CreateAlloca(RedArrayTy, nullptr, "red.array");
2086 
2087   Builder.SetInsertPoint(InsertBlock, InsertBlock->end());
2088 
2089   for (auto En : enumerate(ReductionInfos)) {
2090     unsigned Index = En.index();
2091     const ReductionInfo &RI = En.value();
2092     Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64(
2093         RedArrayTy, RedArray, 0, Index, "red.array.elem." + Twine(Index));
2094     Builder.CreateStore(RI.PrivateVariable, RedArrayElemPtr);
2095   }
2096 
2097   // Emit a call to the runtime function that orchestrates the reduction.
2098   // Declare the reduction function in the process.
2099   Function *Func = Builder.GetInsertBlock()->getParent();
2100   Module *Module = Func->getParent();
2101   uint32_t SrcLocStrSize;
2102   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2103   bool CanGenerateAtomic =
2104       llvm::all_of(ReductionInfos, [](const ReductionInfo &RI) {
2105         return RI.AtomicReductionGen;
2106       });
2107   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize,
2108                                   CanGenerateAtomic
2109                                       ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
2110                                       : IdentFlag(0));
2111   Value *ThreadId = getOrCreateThreadID(Ident);
2112   Constant *NumVariables = Builder.getInt32(NumReductions);
2113   const DataLayout &DL = Module->getDataLayout();
2114   unsigned RedArrayByteSize = DL.getTypeStoreSize(RedArrayTy);
2115   Constant *RedArraySize = Builder.getInt64(RedArrayByteSize);
2116   Function *ReductionFunc = getFreshReductionFunc(*Module);
2117   Value *Lock = getOMPCriticalRegionLock(".reduction");
2118   Function *ReduceFunc = getOrCreateRuntimeFunctionPtr(
2119       IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
2120                : RuntimeFunction::OMPRTL___kmpc_reduce);
2121   CallInst *ReduceCall =
2122       Builder.CreateCall(ReduceFunc,
2123                          {Ident, ThreadId, NumVariables, RedArraySize, RedArray,
2124                           ReductionFunc, Lock},
2125                          "reduce");
2126 
2127   // Create final reduction entry blocks for the atomic and non-atomic case.
2128   // Emit IR that dispatches control flow to one of the blocks based on the
2129   // reduction supporting the atomic mode.
2130   BasicBlock *NonAtomicRedBlock =
2131       BasicBlock::Create(Module->getContext(), "reduce.switch.nonatomic", Func);
2132   BasicBlock *AtomicRedBlock =
2133       BasicBlock::Create(Module->getContext(), "reduce.switch.atomic", Func);
2134   SwitchInst *Switch =
2135       Builder.CreateSwitch(ReduceCall, ContinuationBlock, /* NumCases */ 2);
2136   Switch->addCase(Builder.getInt32(1), NonAtomicRedBlock);
2137   Switch->addCase(Builder.getInt32(2), AtomicRedBlock);
2138 
2139   // Populate the non-atomic reduction using the elementwise reduction function.
2140   // This loads the elements from the global and private variables and reduces
2141   // them before storing back the result to the global variable.
2142   Builder.SetInsertPoint(NonAtomicRedBlock);
2143   for (auto En : enumerate(ReductionInfos)) {
2144     const ReductionInfo &RI = En.value();
2145     Type *ValueType = RI.ElementType;
2146     Value *RedValue = Builder.CreateLoad(ValueType, RI.Variable,
2147                                          "red.value." + Twine(En.index()));
2148     Value *PrivateRedValue =
2149         Builder.CreateLoad(ValueType, RI.PrivateVariable,
2150                            "red.private.value." + Twine(En.index()));
2151     Value *Reduced;
2152     Builder.restoreIP(
2153         RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced));
2154     if (!Builder.GetInsertBlock())
2155       return InsertPointTy();
2156     Builder.CreateStore(Reduced, RI.Variable);
2157   }
2158   Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr(
2159       IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
2160                : RuntimeFunction::OMPRTL___kmpc_end_reduce);
2161   Builder.CreateCall(EndReduceFunc, {Ident, ThreadId, Lock});
2162   Builder.CreateBr(ContinuationBlock);
2163 
2164   // Populate the atomic reduction using the atomic elementwise reduction
2165   // function. There are no loads/stores here because they will be happening
2166   // inside the atomic elementwise reduction.
2167   Builder.SetInsertPoint(AtomicRedBlock);
2168   if (CanGenerateAtomic) {
2169     for (const ReductionInfo &RI : ReductionInfos) {
2170       Builder.restoreIP(RI.AtomicReductionGen(Builder.saveIP(), RI.ElementType,
2171                                               RI.Variable, RI.PrivateVariable));
2172       if (!Builder.GetInsertBlock())
2173         return InsertPointTy();
2174     }
2175     Builder.CreateBr(ContinuationBlock);
2176   } else {
2177     Builder.CreateUnreachable();
2178   }
2179 
2180   // Populate the outlined reduction function using the elementwise reduction
2181   // function. Partial values are extracted from the type-erased array of
2182   // pointers to private variables.
2183   BasicBlock *ReductionFuncBlock =
2184       BasicBlock::Create(Module->getContext(), "", ReductionFunc);
2185   Builder.SetInsertPoint(ReductionFuncBlock);
2186   Value *LHSArrayPtr = ReductionFunc->getArg(0);
2187   Value *RHSArrayPtr = ReductionFunc->getArg(1);
2188 
2189   for (auto En : enumerate(ReductionInfos)) {
2190     const ReductionInfo &RI = En.value();
2191     Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
2192         RedArrayTy, LHSArrayPtr, 0, En.index());
2193     Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
2194     Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType());
2195     Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
2196     Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
2197         RedArrayTy, RHSArrayPtr, 0, En.index());
2198     Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
2199     Value *RHSPtr =
2200         Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType());
2201     Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
2202     Value *Reduced;
2203     Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced));
2204     if (!Builder.GetInsertBlock())
2205       return InsertPointTy();
2206     Builder.CreateStore(Reduced, LHSPtr);
2207   }
2208   Builder.CreateRetVoid();
2209 
2210   Builder.SetInsertPoint(ContinuationBlock);
2211   return Builder.saveIP();
2212 }
2213 
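// createMaster and createMasked below emit a conditional inlined region of
// roughly this shape (a sketch; EmitOMPInlinedRegion handles the details):
//   %res  = call i32 @__kmpc_master(ptr @ident, i32 %tid)
//   %cond = icmp ne i32 %res, 0
//   br i1 %cond, label %<region body>, label %<region end>
// with @__kmpc_end_master executed on the body path before reaching the end
// block.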
2214 OpenMPIRBuilder::InsertPointTy
2215 OpenMPIRBuilder::createMaster(const LocationDescription &Loc,
2216                               BodyGenCallbackTy BodyGenCB,
2217                               FinalizeCallbackTy FiniCB) {
2218 
2219   if (!updateToLocation(Loc))
2220     return Loc.IP;
2221 
2222   Directive OMPD = Directive::OMPD_master;
2223   uint32_t SrcLocStrSize;
2224   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2225   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2226   Value *ThreadId = getOrCreateThreadID(Ident);
2227   Value *Args[] = {Ident, ThreadId};
2228 
2229   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_master);
2230   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
2231 
2232   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_master);
2233   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
2234 
2235   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
2236                               /*Conditional*/ true, /*hasFinalize*/ true);
2237 }
2238 
2239 OpenMPIRBuilder::InsertPointTy
2240 OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
2241                               BodyGenCallbackTy BodyGenCB,
2242                               FinalizeCallbackTy FiniCB, Value *Filter) {
2243   if (!updateToLocation(Loc))
2244     return Loc.IP;
2245 
2246   Directive OMPD = Directive::OMPD_masked;
2247   uint32_t SrcLocStrSize;
2248   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2249   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2250   Value *ThreadId = getOrCreateThreadID(Ident);
2251   Value *Args[] = {Ident, ThreadId, Filter};
2252   Value *ArgsEnd[] = {Ident, ThreadId};
2253 
2254   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_masked);
2255   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
2256 
2257   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_masked);
2258   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, ArgsEnd);
2259 
2260   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
2261                               /*Conditional*/ true, /*hasFinalize*/ true);
2262 }
2263 
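// createLoopSkeleton below builds the canonical loop CFG used throughout this
// file; schematically (a sketch):
//   preheader -> header -> cond --(iv ult tripcount)--> body -> latch -> header
//                            \---(otherwise)----------> exit -> after
// The induction variable starts at 0 and is incremented by 1 in the latch.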
2264 CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
2265     DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
2266     BasicBlock *PostInsertBefore, const Twine &Name) {
2267   Module *M = F->getParent();
2268   LLVMContext &Ctx = M->getContext();
2269   Type *IndVarTy = TripCount->getType();
2270 
2271   // Create the basic block structure.
2272   BasicBlock *Preheader =
2273       BasicBlock::Create(Ctx, "omp_" + Name + ".preheader", F, PreInsertBefore);
2274   BasicBlock *Header =
2275       BasicBlock::Create(Ctx, "omp_" + Name + ".header", F, PreInsertBefore);
2276   BasicBlock *Cond =
2277       BasicBlock::Create(Ctx, "omp_" + Name + ".cond", F, PreInsertBefore);
2278   BasicBlock *Body =
2279       BasicBlock::Create(Ctx, "omp_" + Name + ".body", F, PreInsertBefore);
2280   BasicBlock *Latch =
2281       BasicBlock::Create(Ctx, "omp_" + Name + ".inc", F, PostInsertBefore);
2282   BasicBlock *Exit =
2283       BasicBlock::Create(Ctx, "omp_" + Name + ".exit", F, PostInsertBefore);
2284   BasicBlock *After =
2285       BasicBlock::Create(Ctx, "omp_" + Name + ".after", F, PostInsertBefore);
2286 
2287   // Use specified DebugLoc for new instructions.
2288   Builder.SetCurrentDebugLocation(DL);
2289 
2290   Builder.SetInsertPoint(Preheader);
2291   Builder.CreateBr(Header);
2292 
2293   Builder.SetInsertPoint(Header);
2294   PHINode *IndVarPHI = Builder.CreatePHI(IndVarTy, 2, "omp_" + Name + ".iv");
2295   IndVarPHI->addIncoming(ConstantInt::get(IndVarTy, 0), Preheader);
2296   Builder.CreateBr(Cond);
2297 
2298   Builder.SetInsertPoint(Cond);
2299   Value *Cmp =
2300       Builder.CreateICmpULT(IndVarPHI, TripCount, "omp_" + Name + ".cmp");
2301   Builder.CreateCondBr(Cmp, Body, Exit);
2302 
2303   Builder.SetInsertPoint(Body);
2304   Builder.CreateBr(Latch);
2305 
2306   Builder.SetInsertPoint(Latch);
2307   Value *Next = Builder.CreateAdd(IndVarPHI, ConstantInt::get(IndVarTy, 1),
2308                                   "omp_" + Name + ".next", /*HasNUW=*/true);
2309   Builder.CreateBr(Header);
2310   IndVarPHI->addIncoming(Next, Latch);
2311 
2312   Builder.SetInsertPoint(Exit);
2313   Builder.CreateBr(After);
2314 
2315   // Remember and return the canonical control flow.
2316   LoopInfos.emplace_front();
2317   CanonicalLoopInfo *CL = &LoopInfos.front();
2318 
2319   CL->Header = Header;
2320   CL->Cond = Cond;
2321   CL->Latch = Latch;
2322   CL->Exit = Exit;
2323 
2324 #ifndef NDEBUG
2325   CL->assertOK();
2326 #endif
2327   return CL;
2328 }
2329 
2330 CanonicalLoopInfo *
2331 OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
2332                                      LoopBodyGenCallbackTy BodyGenCB,
2333                                      Value *TripCount, const Twine &Name) {
2334   BasicBlock *BB = Loc.IP.getBlock();
2335   BasicBlock *NextBB = BB->getNextNode();
2336 
2337   CanonicalLoopInfo *CL = createLoopSkeleton(Loc.DL, TripCount, BB->getParent(),
2338                                              NextBB, NextBB, Name);
2339   BasicBlock *After = CL->getAfter();
2340 
2341   // If location is not set, don't connect the loop.
2342   if (updateToLocation(Loc)) {
2343     // Split the loop at the insertion point: Branch to the preheader and move
2344     // every following instruction to after the loop (the After BB). Also, the
2345     // new successor is the loop's after block.
2346     spliceBB(Builder, After, /*CreateBranch=*/false);
2347     Builder.CreateBr(CL->getPreheader());
2348   }
2349 
2350   // Emit the body content. We do it after connecting the loop to the CFG to
2351   // avoid that the callback encounters degenerate BBs.
2352   BodyGenCB(CL->getBodyIP(), CL->getIndVar());
2353 
2354 #ifndef NDEBUG
2355   CL->assertOK();
2356 #endif
2357   return CL;
2358 }
2359 
2360 CanonicalLoopInfo *OpenMPIRBuilder::createCanonicalLoop(
2361     const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
2362     Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
2363     InsertPointTy ComputeIP, const Twine &Name) {
2364 
2365   // Consider the following difficulties (assuming 8-bit signed integers):
2366   //  * Adding \p Step to the loop counter which passes \p Stop may overflow:
2367   //      DO I = 1, 100, 50
2368   //  * A \p Step of INT_MIN cannot be normalized to a positive direction:
2369   //      DO I = 100, 0, -128
2370 
2371   // Start, Stop and Step must be of the same integer type.
2372   auto *IndVarTy = cast<IntegerType>(Start->getType());
2373   assert(IndVarTy == Stop->getType() && "Stop type mismatch");
2374   assert(IndVarTy == Step->getType() && "Step type mismatch");
2375 
2376   LocationDescription ComputeLoc =
2377       ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
2378   updateToLocation(ComputeLoc);
2379 
2380   ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
2381   ConstantInt *One = ConstantInt::get(IndVarTy, 1);
2382 
2383   // Like Step, but always positive.
2384   Value *Incr = Step;
2385 
2386   // Distance between Start and Stop; always positive.
2387   Value *Span;
2388 
2389   // Condition checking whether no iterations are executed at all, e.g.,
2390   // because UB < LB.
2391   Value *ZeroCmp;
2392 
2393   if (IsSigned) {
2394     // Ensure that increment is positive. If not, negate and invert LB and UB.
2395     Value *IsNeg = Builder.CreateICmpSLT(Step, Zero);
2396     Incr = Builder.CreateSelect(IsNeg, Builder.CreateNeg(Step), Step);
2397     Value *LB = Builder.CreateSelect(IsNeg, Stop, Start);
2398     Value *UB = Builder.CreateSelect(IsNeg, Start, Stop);
2399     Span = Builder.CreateSub(UB, LB, "", false, true);
2400     ZeroCmp = Builder.CreateICmp(
2401         InclusiveStop ? CmpInst::ICMP_SLT : CmpInst::ICMP_SLE, UB, LB);
2402   } else {
2403     Span = Builder.CreateSub(Stop, Start, "", true);
2404     ZeroCmp = Builder.CreateICmp(
2405         InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, Stop, Start);
2406   }
2407 
2408   Value *CountIfLooping;
2409   if (InclusiveStop) {
2410     CountIfLooping = Builder.CreateAdd(Builder.CreateUDiv(Span, Incr), One);
2411   } else {
2412     // Avoid incrementing past stop since it could overflow.
2413     Value *CountIfTwo = Builder.CreateAdd(
2414         Builder.CreateUDiv(Builder.CreateSub(Span, One), Incr), One);
2415     Value *OneCmp = Builder.CreateICmp(CmpInst::ICMP_ULE, Span, Incr);
2416     CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
2417   }
2418   Value *TripCount = Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
2419                                           "omp_" + Name + ".tripcount");
2420 
2421   auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
2422     Builder.restoreIP(CodeGenIP);
2423     Value *Span = Builder.CreateMul(IV, Step);
2424     Value *IndVar = Builder.CreateAdd(Span, Start);
2425     BodyGenCB(Builder.saveIP(), IndVar);
2426   };
2427   LocationDescription LoopLoc = ComputeIP.isSet() ? Loc.IP : Builder.saveIP();
2428   return createCanonicalLoop(LoopLoc, BodyGen, TripCount, Name);
2429 }

// Returns an LLVM function to call for initializing loop bounds using OpenMP
// static scheduling depending on `type`. Only i32 and i64 are supported by the
// runtime. Always interpret integers as unsigned similarly to
// CanonicalLoopInfo.
static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
                                                  OpenMPIRBuilder &OMPBuilder) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  if (Bitwidth == 32)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u);
  if (Bitwidth == 64)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u);
  llvm_unreachable("unknown OpenMP loop iterator bitwidth");
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::applyStaticWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
                                          InsertPointTy AllocaIP,
                                          bool NeedsBarrier) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
         "Require dedicated allocate IP");

  // Set up the source location value for OpenMP runtime.
  Builder.restoreIP(CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  // Declare useful OpenMP runtime functions.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee StaticInit = getKmpcForStaticInitForType(IVTy, M, *this);
  FunctionCallee StaticFini =
      getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.restoreIP(AllocaIP);
  Type *I32Type = Type::getInt32Ty(M.getContext());
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. A canonical loop
  // always iterates from 0 to trip-count with step 1. Note that "init" expects
  // and produces an inclusive upper bound.
  Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
  Constant *Zero = ConstantInt::get(IVTy, 0);
  Constant *One = ConstantInt::get(IVTy, 1);
  Builder.CreateStore(Zero, PLowerBound);
  Value *UpperBound = Builder.CreateSub(CLI->getTripCount(), One);
  Builder.CreateStore(UpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  Value *ThreadNum = getOrCreateThreadID(SrcLoc);

  Constant *SchedulingType = ConstantInt::get(
      I32Type, static_cast<int>(OMPScheduleType::UnorderedStatic));

  // Call the "init" function and update the trip count of the loop with the
  // value it produced.
  Builder.CreateCall(StaticInit,
                     {SrcLoc, ThreadNum, SchedulingType, PLastIter, PLowerBound,
                      PUpperBound, PStride, One, Zero});
  Value *LowerBound = Builder.CreateLoad(IVTy, PLowerBound);
  Value *InclusiveUpperBound = Builder.CreateLoad(IVTy, PUpperBound);
  Value *TripCountMinusOne = Builder.CreateSub(InclusiveUpperBound, LowerBound);
  Value *TripCount = Builder.CreateAdd(TripCountMinusOne, One);
  CLI->setTripCount(TripCount);

  // Update all uses of the induction variable except the one in the condition
  // block that compares it with the actual upper bound, and the increment in
  // the latch block.

  CLI->mapIndVar([&](Instruction *OldIV) -> Value * {
    Builder.SetInsertPoint(CLI->getBody(),
                           CLI->getBody()->getFirstInsertionPt());
    Builder.SetCurrentDebugLocation(DL);
    return Builder.CreateAdd(OldIV, LowerBound);
  });

  // In the "exit" block, call the "fini" function.
  Builder.SetInsertPoint(CLI->getExit(),
                         CLI->getExit()->getTerminator()->getIterator());
  Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});

  // Add the barrier if requested.
  if (NeedsBarrier)
    createBarrier(LocationDescription(Builder.saveIP(), DL),
                  omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                  /* CheckCancelFlag */ false);

  InsertPointTy AfterIP = CLI->getAfterIP();
  CLI->invalidate();

  return AfterIP;
}
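
// For illustration (a sketch of the runtime semantics, not code emitted
// here): with a trip count of 100 and 4 threads, __kmpc_for_static_init_4u
// hands thread 2 the inclusive bounds [50, 74]. The code above then shrinks
// that thread's trip count to 74 - 50 + 1 = 25 and offsets the induction
// variable by the lower bound, so each thread executes only its own block of
// 25 iterations of the canonical loop.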

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    bool NeedsBarrier, Value *ChunkSize) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(ChunkSize && "Chunk size is required");

  LLVMContext &Ctx = CLI->getFunction()->getContext();
  Value *IV = CLI->getIndVar();
  Value *OrigTripCount = CLI->getTripCount();
  Type *IVTy = IV->getType();
  assert(IVTy->getIntegerBitWidth() <= 64 &&
         "Max supported tripcount bitwidth is 64 bits");
  Type *InternalIVTy = IVTy->getIntegerBitWidth() <= 32 ? Type::getInt32Ty(Ctx)
                                                        : Type::getInt64Ty(Ctx);
  Type *I32Type = Type::getInt32Ty(M.getContext());
  Constant *Zero = ConstantInt::get(InternalIVTy, 0);
  Constant *One = ConstantInt::get(InternalIVTy, 1);

  // Declare useful OpenMP runtime functions.
  FunctionCallee StaticInit =
      getKmpcForStaticInitForType(InternalIVTy, M, *this);
  FunctionCallee StaticFini =
      getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.restoreIP(AllocaIP);
  Builder.SetCurrentDebugLocation(DL);
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound =
      Builder.CreateAlloca(InternalIVTy, nullptr, "p.lowerbound");
  Value *PUpperBound =
      Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");

  // Set up the source location value for the OpenMP runtime.
  Builder.restoreIP(CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  // TODO: Detect overflow in ubsan or max-out with current tripcount.
  Value *CastedChunkSize =
      Builder.CreateZExtOrTrunc(ChunkSize, InternalIVTy, "chunksize");
  Value *CastedTripCount =
      Builder.CreateZExt(OrigTripCount, InternalIVTy, "tripcount");

  Constant *SchedulingType = ConstantInt::get(
      I32Type, static_cast<int>(OMPScheduleType::UnorderedStaticChunked));
  Builder.CreateStore(Zero, PLowerBound);
  Value *OrigUpperBound = Builder.CreateSub(CastedTripCount, One);
  Builder.CreateStore(OrigUpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  // Call the "init" function and update the trip count of the loop with the
  // value it produced.
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadNum = getOrCreateThreadID(SrcLoc);
  Builder.CreateCall(StaticInit,
                     {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
                      /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
                      /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
                      /*pstride=*/PStride, /*incr=*/One,
                      /*chunk=*/CastedChunkSize});

  // Load values written by the "init" function.
  Value *FirstChunkStart =
      Builder.CreateLoad(InternalIVTy, PLowerBound, "omp_firstchunk.lb");
  Value *FirstChunkStop =
      Builder.CreateLoad(InternalIVTy, PUpperBound, "omp_firstchunk.ub");
  Value *FirstChunkEnd = Builder.CreateAdd(FirstChunkStop, One);
  Value *ChunkRange =
      Builder.CreateSub(FirstChunkEnd, FirstChunkStart, "omp_chunk.range");
  Value *NextChunkStride =
      Builder.CreateLoad(InternalIVTy, PStride, "omp_dispatch.stride");

  // Create outer "dispatch" loop for enumerating the chunks.
  BasicBlock *DispatchEnter = splitBB(Builder, true);
  Value *DispatchCounter;
  CanonicalLoopInfo *DispatchCLI = createCanonicalLoop(
      {Builder.saveIP(), DL},
      [&](InsertPointTy BodyIP, Value *Counter) { DispatchCounter = Counter; },
      FirstChunkStart, CastedTripCount, NextChunkStride,
      /*IsSigned=*/false, /*InclusiveStop=*/false, /*ComputeIP=*/{},
      "dispatch");

  // Remember the BasicBlocks of the dispatch loop we need, then invalidate it
  // so we do not have to preserve the canonical invariant.
  BasicBlock *DispatchBody = DispatchCLI->getBody();
  BasicBlock *DispatchLatch = DispatchCLI->getLatch();
  BasicBlock *DispatchExit = DispatchCLI->getExit();
  BasicBlock *DispatchAfter = DispatchCLI->getAfter();
  DispatchCLI->invalidate();

  // Rewire the original loop to become the chunk loop inside the dispatch loop.
  redirectTo(DispatchAfter, CLI->getAfter(), DL);
  redirectTo(CLI->getExit(), DispatchLatch, DL);
  redirectTo(DispatchBody, DispatchEnter, DL);

  // Prepare the prolog of the chunk loop.
  Builder.restoreIP(CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  // Compute the number of iterations of the chunk loop.
  Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
  Value *ChunkEnd = Builder.CreateAdd(DispatchCounter, ChunkRange);
  Value *IsLastChunk =
      Builder.CreateICmpUGE(ChunkEnd, CastedTripCount, "omp_chunk.is_last");
  Value *CountUntilOrigTripCount =
      Builder.CreateSub(CastedTripCount, DispatchCounter);
  Value *ChunkTripCount = Builder.CreateSelect(
      IsLastChunk, CountUntilOrigTripCount, ChunkRange, "omp_chunk.tripcount");
  Value *BackcastedChunkTC =
      Builder.CreateTrunc(ChunkTripCount, IVTy, "omp_chunk.tripcount.trunc");
  CLI->setTripCount(BackcastedChunkTC);

  // Update all uses of the induction variable except the one in the condition
  // block that compares it with the actual upper bound, and the increment in
  // the latch block.
  Value *BackcastedDispatchCounter =
      Builder.CreateTrunc(DispatchCounter, IVTy, "omp_dispatch.iv.trunc");
  CLI->mapIndVar([&](Instruction *) -> Value * {
    Builder.restoreIP(CLI->getBodyIP());
    return Builder.CreateAdd(IV, BackcastedDispatchCounter);
  });

  // In the "exit" block, call the "fini" function.
  Builder.SetInsertPoint(DispatchExit, DispatchExit->getFirstInsertionPt());
  Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});

  // Add the barrier if requested.
  if (NeedsBarrier)
    createBarrier(LocationDescription(Builder.saveIP(), DL), OMPD_for,
                  /*ForceSimpleCall=*/false, /*CheckCancelFlag=*/false);

#ifndef NDEBUG
  // Even though we currently do not support applying additional methods to it,
  // the chunk loop should remain a canonical loop.
  CLI->assertOK();
#endif

  return {DispatchAfter, DispatchAfter->getFirstInsertionPt()};
}
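
// For illustration (a sketch of the resulting structure, assuming a trip
// count of 100, 4 threads, and a chunk size of 10): the runtime gives
// thread 0 the first chunk [0, 9] and a stride of 40, so the dispatch loop
// above enumerates the chunk start offsets 0, 40 and 80 for that thread, and
// the inner chunk loop runs 10 iterations per chunk. The IsLastChunk select
// only shortens the final chunk when the trip count is not a multiple of the
// chunk size.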

// Returns an LLVM function to call for executing an OpenMP static worksharing
// for loop depending on `type`. Only i32 and i64 are supported by the runtime.
// Always interpret integers as unsigned similarly to CanonicalLoopInfo.
static FunctionCallee
getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
                            WorksharingLoopType LoopType) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  Module &M = OMPBuilder->M;
  switch (LoopType) {
  case WorksharingLoopType::ForStaticLoop:
    if (Bitwidth == 32)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
    if (Bitwidth == 64)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
    break;
  case WorksharingLoopType::DistributeStaticLoop:
    if (Bitwidth == 32)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
    if (Bitwidth == 64)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
    break;
  case WorksharingLoopType::DistributeForStaticLoop:
    if (Bitwidth == 32)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
    if (Bitwidth == 64)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
    break;
  }
  if (Bitwidth != 32 && Bitwidth != 64) {
    llvm_unreachable("Unknown OpenMP loop iterator bitwidth");
  }
  llvm_unreachable("Unknown type of OpenMP worksharing loop");
}

// Inserts a call to the proper OpenMP Device RTL function which handles
// loop worksharing.
static void createTargetLoopWorkshareCall(
    OpenMPIRBuilder *OMPBuilder, WorksharingLoopType LoopType,
    BasicBlock *InsertBlock, Value *Ident, Value *LoopBodyArg,
    Type *ParallelTaskPtr, Value *TripCount, Function &LoopBodyFn) {
  Type *TripCountTy = TripCount->getType();
  Module &M = OMPBuilder->M;
  IRBuilder<> &Builder = OMPBuilder->Builder;
  FunctionCallee RTLFn =
      getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType);
  SmallVector<Value *, 8> RealArgs;
  RealArgs.push_back(Ident);
  RealArgs.push_back(Builder.CreateBitCast(&LoopBodyFn, ParallelTaskPtr));
  RealArgs.push_back(LoopBodyArg);
  RealArgs.push_back(TripCount);
  if (LoopType == WorksharingLoopType::DistributeStaticLoop) {
    RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
    Builder.CreateCall(RTLFn, RealArgs);
    return;
  }
  FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
      M, omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
  Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
  Value *NumThreads = Builder.CreateCall(RTLNumThreads, {});

  RealArgs.push_back(
      Builder.CreateZExtOrTrunc(NumThreads, TripCountTy, "num.threads.cast"));
  RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
  if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
    RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
  }

  Builder.CreateCall(RTLFn, RealArgs);
}
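
// For illustration, the call emitted for WorksharingLoopType::ForStaticLoop
// conceptually looks like the following (pseudo-C; the argument names are
// illustrative only):
//
//   __kmpc_for_static_loop_4u(ident, (void *)loop_body_fn, loop_body_arg,
//                             trip_count, omp_get_num_threads(),
//                             /*block_chunk=*/0);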

static void
workshareLoopTargetCallback(OpenMPIRBuilder *OMPIRBuilder,
                            CanonicalLoopInfo *CLI, Value *Ident,
                            Function &OutlinedFn, Type *ParallelTaskPtr,
                            const SmallVector<Instruction *, 4> &ToBeDeleted,
                            WorksharingLoopType LoopType) {
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  BasicBlock *Preheader = CLI->getPreheader();
  Value *TripCount = CLI->getTripCount();

  // After loop body outlining, the loop body contains only the setup of the
  // loop body argument structure and the call to the outlined loop body
  // function. First, move the setup of the loop body arguments into the loop
  // preheader.
  Preheader->splice(std::prev(Preheader->end()), CLI->getBody(),
                    CLI->getBody()->begin(), std::prev(CLI->getBody()->end()));

  // The next step is to remove the whole loop; it is not needed anymore.
  // Therefore, make an unconditional branch from the loop preheader to the
  // loop exit block.
  Builder.restoreIP({Preheader, Preheader->end()});
  Preheader->getTerminator()->eraseFromParent();
  Builder.CreateBr(CLI->getExit());

  // Delete dead loop blocks.
  OpenMPIRBuilder::OutlineInfo CleanUpInfo;
  SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
  SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
  CleanUpInfo.EntryBB = CLI->getHeader();
  CleanUpInfo.ExitBB = CLI->getExit();
  CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved);
  DeleteDeadBlocks(BlocksToBeRemoved);

  // Find the instruction which corresponds to the loop body argument structure
  // and remove the call to the outlined loop body function.
  Value *LoopBodyArg;
  User *OutlinedFnUser = OutlinedFn.getUniqueUndroppableUser();
  assert(OutlinedFnUser &&
         "Expected unique undroppable user of outlined function");
  CallInst *OutlinedFnCallInstruction = dyn_cast<CallInst>(OutlinedFnUser);
  assert(OutlinedFnCallInstruction && "Expected outlined function call");
  assert((OutlinedFnCallInstruction->getParent() == Preheader) &&
         "Expected outlined function call to be located in loop preheader");
  // Check in case no argument structure has been passed.
  if (OutlinedFnCallInstruction->arg_size() > 1)
    LoopBodyArg = OutlinedFnCallInstruction->getArgOperand(1);
  else
    LoopBodyArg = Constant::getNullValue(Builder.getPtrTy());
  OutlinedFnCallInstruction->eraseFromParent();

  createTargetLoopWorkshareCall(OMPIRBuilder, LoopType, Preheader, Ident,
                                LoopBodyArg, ParallelTaskPtr, TripCount,
                                OutlinedFn);

  for (auto &ToBeDeletedItem : ToBeDeleted)
    ToBeDeletedItem->eraseFromParent();
  CLI->invalidate();
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
                                          InsertPointTy AllocaIP,
                                          WorksharingLoopType LoopType) {
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  OutlineInfo OI;
  OI.OuterAllocaBB = CLI->getPreheader();
  Function *OuterFn = CLI->getPreheader()->getParent();

  // Instructions which need to be deleted at the end of code generation.
  SmallVector<Instruction *, 4> ToBeDeleted;

  OI.OuterAllocaBB = AllocaIP.getBlock();

  // Mark the loop body as the region which needs to be extracted.
  OI.EntryBB = CLI->getBody();
  OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(),
                                               "omp.prelatch", true);

  // Prepare the loop body for extraction.
  Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});

  // Insert a new loop counter variable which will be used only in the loop
  // body.
  AllocaInst *NewLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, "");
  Instruction *NewLoopCntLoad =
      Builder.CreateLoad(CLI->getIndVarType(), NewLoopCnt);
  // The new loop counter instructions are redundant in the loop preheader once
  // code generation for the workshare loop is finished. That's why we mark
  // them as ready for deletion.
  ToBeDeleted.push_back(NewLoopCntLoad);
  ToBeDeleted.push_back(NewLoopCnt);

  // Analyse the loop body region: find all input variables which are used
  // inside the loop body region.
  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  OI.collectBlocks(ParallelRegionBlockSet, Blocks);
  SmallVector<BasicBlock *, 32> BlocksT(ParallelRegionBlockSet.begin(),
                                        ParallelRegionBlockSet.end());

  CodeExtractorAnalysisCache CEAC(*OuterFn);
  CodeExtractor Extractor(Blocks,
                          /* DominatorTree */ nullptr,
                          /* AggregateArgs */ true,
                          /* BlockFrequencyInfo */ nullptr,
                          /* BranchProbabilityInfo */ nullptr,
                          /* AssumptionCache */ nullptr,
                          /* AllowVarArgs */ true,
                          /* AllowAlloca */ true,
                          /* AllocationBlock */ CLI->getPreheader(),
                          /* Suffix */ ".omp_wsloop",
                          /* AggrArgsIn0AddrSpace */ true);

  BasicBlock *CommonExit = nullptr;
  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;

  // Find allocas outside the loop body region which are used inside the loop
  // body.
  Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);

  // We need to model the loop body region as the function f(cnt, loop_arg).
  // That's why we replace the loop induction variable with the new counter,
  // which will become one of the loop body function's arguments.
  SmallVector<User *> Users(CLI->getIndVar()->user_begin(),
                            CLI->getIndVar()->user_end());
  for (auto Use : Users) {
    if (Instruction *Inst = dyn_cast<Instruction>(Use)) {
      if (ParallelRegionBlockSet.count(Inst->getParent())) {
        Inst->replaceUsesOfWith(CLI->getIndVar(), NewLoopCntLoad);
      }
    }
  }
  // Make sure that the loop counter variable is not merged into the loop body
  // function's argument structure and that it is passed as a separate
  // variable.
  OI.ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);

  // The PostOutline callback is invoked when the loop body function has been
  // outlined and the loop body replaced by a call to the outlined function.
  // We need to add a call to the OpenMP device RTL in the loop preheader;
  // the OpenMP device RTL function will handle the loop control logic.
  //
  OI.PostOutlineCB = [=, ToBeDeletedVec =
                             std::move(ToBeDeleted)](Function &OutlinedFn) {
    workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ParallelTaskPtr,
                                ToBeDeletedVec, LoopType);
  };
  addOutlineInfo(std::move(OI));
  return CLI->getAfterIP();
}
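
// For illustration (a sketch of the overall transformation this sets up; the
// names are illustrative, not the exact IR):
//
//   before:  preheader -> canonical loop { body(iv) } -> exit
//   after:   preheader: args = <packed captures>;
//            __kmpc_for_static_loop_*(ident, outlined_body, args,
//                                     tripcount, ...);
//            br exit
//
// where outlined_body(cnt, args) executes a single iteration and the device
// runtime owns all loop control.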

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
    bool HasSimdModifier, bool HasMonotonicModifier,
    bool HasNonmonotonicModifier, bool HasOrderedClause,
    WorksharingLoopType LoopType) {
  if (Config.isTargetDevice())
    return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType);
  OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
      SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
      HasNonmonotonicModifier, HasOrderedClause);

  bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
                   OMPScheduleType::ModifierOrdered;
  switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
  case OMPScheduleType::BaseStatic:
    assert(!ChunkSize && "No chunk size with static-chunked schedule");
    if (IsOrdered)
      return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
                                       NeedsBarrier, ChunkSize);
    // FIXME: Monotonicity ignored?
    return applyStaticWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier);

  case OMPScheduleType::BaseStaticChunked:
    if (IsOrdered)
      return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
                                       NeedsBarrier, ChunkSize);
    // FIXME: Monotonicity ignored?
    return applyStaticChunkedWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier,
                                           ChunkSize);

  case OMPScheduleType::BaseRuntime:
  case OMPScheduleType::BaseAuto:
  case OMPScheduleType::BaseGreedy:
  case OMPScheduleType::BaseBalanced:
  case OMPScheduleType::BaseSteal:
  case OMPScheduleType::BaseGuidedSimd:
  case OMPScheduleType::BaseRuntimeSimd:
    assert(!ChunkSize &&
           "schedule type does not support user-defined chunk sizes");
    [[fallthrough]];
  case OMPScheduleType::BaseDynamicChunked:
  case OMPScheduleType::BaseGuidedChunked:
  case OMPScheduleType::BaseGuidedIterativeChunked:
  case OMPScheduleType::BaseGuidedAnalyticalChunked:
  case OMPScheduleType::BaseStaticBalancedChunked:
    return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
                                     NeedsBarrier, ChunkSize);

  default:
    llvm_unreachable("Unknown/unimplemented schedule kind");
  }
}
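
// For example (a sketch of the dispatch above): schedule(static) maps to
// BaseStatic and takes the applyStaticWorkshareLoop path, schedule(static, 4)
// maps to BaseStaticChunked, while schedule(dynamic, 4), schedule(guided) and
// schedule(runtime) all end up in applyDynamicWorkshareLoop. An ordered
// clause forces the dynamic path even for the static schedule kinds.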

/// Returns an LLVM function to call for initializing loop bounds using OpenMP
/// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
/// the runtime. Always interpret integers as unsigned similarly to
/// CanonicalLoopInfo.
static FunctionCallee
getKmpcForDynamicInitForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  if (Bitwidth == 32)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_4u);
  if (Bitwidth == 64)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_8u);
  llvm_unreachable("unknown OpenMP loop iterator bitwidth");
}

/// Returns an LLVM function to call for updating the next loop using OpenMP
/// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
/// the runtime. Always interpret integers as unsigned similarly to
/// CanonicalLoopInfo.
static FunctionCallee
getKmpcForDynamicNextForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  if (Bitwidth == 32)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_4u);
  if (Bitwidth == 64)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_8u);
  llvm_unreachable("unknown OpenMP loop iterator bitwidth");
}

/// Returns an LLVM function to call for finalizing the dynamic loop,
/// depending on `type`. Only i32 and i64 are supported by the runtime. Always
/// interpret integers as unsigned similarly to CanonicalLoopInfo.
static FunctionCallee
getKmpcForDynamicFiniForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  if (Bitwidth == 32)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_4u);
  if (Bitwidth == 64)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_8u);
  llvm_unreachable("unknown OpenMP loop iterator bitwidth");
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyDynamicWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    OMPScheduleType SchedType, bool NeedsBarrier, Value *Chunk) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
         "Require dedicated allocate IP");
  assert(isValidWorkshareLoopScheduleType(SchedType) &&
         "Require valid schedule type");

  bool Ordered = (SchedType & OMPScheduleType::ModifierOrdered) ==
                 OMPScheduleType::ModifierOrdered;

  // Set up the source location value for OpenMP runtime.
  Builder.SetCurrentDebugLocation(DL);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  // Declare useful OpenMP runtime functions.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee DynamicInit = getKmpcForDynamicInitForType(IVTy, M, *this);
  FunctionCallee DynamicNext = getKmpcForDynamicNextForType(IVTy, M, *this);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.restoreIP(AllocaIP);
  Type *I32Type = Type::getInt32Ty(M.getContext());
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. A canonical loop
  // always iterates from 0 to trip-count with step 1. Note that "init" expects
  // and produces an inclusive upper bound.
  BasicBlock *PreHeader = CLI->getPreheader();
  Builder.SetInsertPoint(PreHeader->getTerminator());
  Constant *One = ConstantInt::get(IVTy, 1);
  Builder.CreateStore(One, PLowerBound);
  Value *UpperBound = CLI->getTripCount();
  Builder.CreateStore(UpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  BasicBlock *Header = CLI->getHeader();
  BasicBlock *Exit = CLI->getExit();
  BasicBlock *Cond = CLI->getCond();
  BasicBlock *Latch = CLI->getLatch();
  InsertPointTy AfterIP = CLI->getAfterIP();

  // The CLI will be "broken" in the code below, as the loop is no longer
  // a valid canonical loop.

  if (!Chunk)
    Chunk = One;

  Value *ThreadNum = getOrCreateThreadID(SrcLoc);

  Constant *SchedulingType =
      ConstantInt::get(I32Type, static_cast<int>(SchedType));

  // Call the "init" function.
  Builder.CreateCall(DynamicInit,
                     {SrcLoc, ThreadNum, SchedulingType, /* LowerBound */ One,
                      UpperBound, /* step */ One, Chunk});

  // An outer loop around the existing one.
  BasicBlock *OuterCond = BasicBlock::Create(
      PreHeader->getContext(), Twine(PreHeader->getName()) + ".outer.cond",
      PreHeader->getParent());
  // This needs to be 32-bit always, so can't use the IVTy Zero above.
  Builder.SetInsertPoint(OuterCond, OuterCond->getFirstInsertionPt());
  Value *Res =
      Builder.CreateCall(DynamicNext, {SrcLoc, ThreadNum, PLastIter,
                                       PLowerBound, PUpperBound, PStride});
  Constant *Zero32 = ConstantInt::get(I32Type, 0);
  Value *MoreWork = Builder.CreateCmp(CmpInst::ICMP_NE, Res, Zero32);
  Value *LowerBound =
      Builder.CreateSub(Builder.CreateLoad(IVTy, PLowerBound), One, "lb");
  Builder.CreateCondBr(MoreWork, Header, Exit);

  // Change the PHI node in the loop header to use the outer cond rather than
  // the preheader, and set the IV to the LowerBound.
  Instruction *Phi = &Header->front();
  auto *PI = cast<PHINode>(Phi);
  PI->setIncomingBlock(0, OuterCond);
  PI->setIncomingValue(0, LowerBound);

  // Then set the pre-header to jump to the OuterCond.
  Instruction *Term = PreHeader->getTerminator();
  auto *Br = cast<BranchInst>(Term);
  Br->setSuccessor(0, OuterCond);

  // Modify the inner condition:
  // * Use the UpperBound returned from the DynamicNext call.
  // * Jump to the outer loop when done with one of the inner loops.
  Builder.SetInsertPoint(Cond, Cond->getFirstInsertionPt());
  UpperBound = Builder.CreateLoad(IVTy, PUpperBound, "ub");
  Instruction *Comp = &*Builder.GetInsertPoint();
  auto *CI = cast<CmpInst>(Comp);
  CI->setOperand(1, UpperBound);
  // Redirect the inner exit to branch to the outer condition.
  Instruction *Branch = &Cond->back();
  auto *BI = cast<BranchInst>(Branch);
  assert(BI->getSuccessor(1) == Exit);
  BI->setSuccessor(1, OuterCond);

  // Call the "fini" function if "ordered" is present in the wsloop directive.
  if (Ordered) {
    Builder.SetInsertPoint(&Latch->back());
    FunctionCallee DynamicFini = getKmpcForDynamicFiniForType(IVTy, M, *this);
    Builder.CreateCall(DynamicFini, {SrcLoc, ThreadNum});
  }

  // Add the barrier if requested.
  if (NeedsBarrier) {
    Builder.SetInsertPoint(&Exit->back());
    createBarrier(LocationDescription(Builder.saveIP(), DL),
                  omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                  /* CheckCancelFlag */ false);
  }

  CLI->invalidate();
  return AfterIP;
}
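
// For illustration, the rewritten control flow corresponds to the following
// pseudo-code (a sketch rather than the exact IR; bounds from the runtime are
// 1-based inclusive, hence the "lb - 1"):
//
//   __kmpc_dispatch_init_4u(loc, tid, sched, /*lb=*/1, tripcount, 1, chunk);
//   while (__kmpc_dispatch_next_4u(loc, tid, &last, &lb, &ub, &st)) {
//     for (iv = lb - 1; iv < ub; ++iv)   // original loop, rebounded
//       body(iv);
//   }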

/// Redirect all edges that branch to \p OldTarget to \p NewTarget. That is,
/// after this \p OldTarget will be orphaned.
static void redirectAllPredecessorsTo(BasicBlock *OldTarget,
                                      BasicBlock *NewTarget, DebugLoc DL) {
  for (BasicBlock *Pred : make_early_inc_range(predecessors(OldTarget)))
    redirectTo(Pred, NewTarget, DL);
}

/// Determine which blocks in \p BBs are reachable from outside and remove the
/// ones that are not reachable from the function.
static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
  SmallPtrSet<BasicBlock *, 6> BBsToErase{BBs.begin(), BBs.end()};
  auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) {
    for (Use &U : BB->uses()) {
      auto *UseInst = dyn_cast<Instruction>(U.getUser());
      if (!UseInst)
        continue;
      if (BBsToErase.count(UseInst->getParent()))
        continue;
      return true;
    }
    return false;
  };

  while (true) {
    bool Changed = false;
    for (BasicBlock *BB : make_early_inc_range(BBsToErase)) {
      if (HasRemainingUses(BB)) {
        BBsToErase.erase(BB);
        Changed = true;
      }
    }
    if (!Changed)
      break;
  }

  SmallVector<BasicBlock *, 7> BBVec(BBsToErase.begin(), BBsToErase.end());
  DeleteDeadBlocks(BBVec);
}

CanonicalLoopInfo *
OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
                               InsertPointTy ComputeIP) {
  assert(Loops.size() >= 1 && "At least one loop required");
  size_t NumLoops = Loops.size();

  // Nothing to do if there is already just one loop.
  if (NumLoops == 1)
    return Loops.front();

  CanonicalLoopInfo *Outermost = Loops.front();
  CanonicalLoopInfo *Innermost = Loops.back();
  BasicBlock *OrigPreheader = Outermost->getPreheader();
  BasicBlock *OrigAfter = Outermost->getAfter();
  Function *F = OrigPreheader->getParent();

  // Loop control blocks that may become orphaned later.
  SmallVector<BasicBlock *, 12> OldControlBBs;
  OldControlBBs.reserve(6 * Loops.size());
  for (CanonicalLoopInfo *Loop : Loops)
    Loop->collectControlBlocks(OldControlBBs);

  // Setup the IRBuilder for inserting the trip count computation.
  Builder.SetCurrentDebugLocation(DL);
  if (ComputeIP.isSet())
    Builder.restoreIP(ComputeIP);
  else
    Builder.restoreIP(Outermost->getPreheaderIP());

  // Derive the collapsed loop's trip count.
  // TODO: Find common/largest indvar type.
  Value *CollapsedTripCount = nullptr;
  for (CanonicalLoopInfo *L : Loops) {
    assert(L->isValid() &&
           "All loops to collapse must be valid canonical loops");
    Value *OrigTripCount = L->getTripCount();
    if (!CollapsedTripCount) {
      CollapsedTripCount = OrigTripCount;
      continue;
    }

    // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
    CollapsedTripCount = Builder.CreateMul(CollapsedTripCount, OrigTripCount,
                                           {}, /*HasNUW=*/true);
  }

  // Create the collapsed loop control flow.
  CanonicalLoopInfo *Result =
      createLoopSkeleton(DL, CollapsedTripCount, F,
                         OrigPreheader->getNextNode(), OrigAfter, "collapsed");

  // Build the collapsed loop body code.
  // Start with deriving the input loop induction variables from the collapsed
  // one, using a divmod scheme. To preserve the original loops' order, the
  // innermost loop uses the least significant bits.
  Builder.restoreIP(Result->getBodyIP());

  Value *Leftover = Result->getIndVar();
  SmallVector<Value *> NewIndVars;
  NewIndVars.resize(NumLoops);
  for (int i = NumLoops - 1; i >= 1; --i) {
    Value *OrigTripCount = Loops[i]->getTripCount();

    Value *NewIndVar = Builder.CreateURem(Leftover, OrigTripCount);
    NewIndVars[i] = NewIndVar;

    Leftover = Builder.CreateUDiv(Leftover, OrigTripCount);
  }
  // Outermost loop gets all the remaining bits.
  NewIndVars[0] = Leftover;

  // Construct the loop body control flow.
  // We progressively construct the branch structure following the direction of
  // the control flow: the leading in-between code, the loop nest body, the
  // trailing in-between code, and finally rejoining the collapsed loop's
  // latch. ContinueBlock and ContinuePred keep track of the source(s) of the
  // next edge. If ContinueBlock is set, continue with that block. If
  // ContinuePred is set, use its predecessors as sources.
  BasicBlock *ContinueBlock = Result->getBody();
  BasicBlock *ContinuePred = nullptr;
  auto ContinueWith = [&ContinueBlock, &ContinuePred, DL](BasicBlock *Dest,
                                                          BasicBlock *NextSrc) {
    if (ContinueBlock)
      redirectTo(ContinueBlock, Dest, DL);
    else
      redirectAllPredecessorsTo(ContinuePred, Dest, DL);

    ContinueBlock = nullptr;
    ContinuePred = NextSrc;
  };

  // The code before the nested loop of each level.
  // Because we are sinking it into the nest, it will be executed more often
  // than in the original loop nest. More sophisticated schemes could keep
  // track of what the in-between code is and instantiate it only once per
  // thread.
  for (size_t i = 0; i < NumLoops - 1; ++i)
    ContinueWith(Loops[i]->getBody(), Loops[i + 1]->getHeader());

  // Connect the loop nest body.
  ContinueWith(Innermost->getBody(), Innermost->getLatch());

  // The code after the nested loop at each level.
  for (size_t i = NumLoops - 1; i > 0; --i)
    ContinueWith(Loops[i]->getAfter(), Loops[i - 1]->getLatch());

  // Connect the finished loop to the collapsed loop latch.
  ContinueWith(Result->getLatch(), nullptr);

  // Replace the input loops with the new collapsed loop.
  redirectTo(Outermost->getPreheader(), Result->getPreheader(), DL);
  redirectTo(Result->getAfter(), Outermost->getAfter(), DL);

  // Replace the input loop indvars with the derived ones.
  for (size_t i = 0; i < NumLoops; ++i)
    Loops[i]->getIndVar()->replaceAllUsesWith(NewIndVars[i]);

  // Remove unused parts of the input loops.
  removeUnusedBlocksFromParent(OldControlBBs);

  for (CanonicalLoopInfo *L : Loops)
    L->invalidate();

#ifndef NDEBUG
  Result->assertOK();
#endif
  return Result;
}
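
// For example (a sketch of the result): collapsing two loops with trip counts
// 10 and 20 yields a single loop with trip count 200, and the body recomputes
// the original induction variables from the collapsed one as
//
//   j = collapsed_iv % 20;  // innermost, least significant
//   i = collapsed_iv / 20;  // outermost gets the remaining "bits"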

std::vector<CanonicalLoopInfo *>
OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
                           ArrayRef<Value *> TileSizes) {
  assert(TileSizes.size() == Loops.size() &&
         "Must pass as many tile sizes as there are loops");
  int NumLoops = Loops.size();
  assert(NumLoops >= 1 && "At least one loop to tile required");

  CanonicalLoopInfo *OutermostLoop = Loops.front();
  CanonicalLoopInfo *InnermostLoop = Loops.back();
  Function *F = OutermostLoop->getBody()->getParent();
  BasicBlock *InnerEnter = InnermostLoop->getBody();
  BasicBlock *InnerLatch = InnermostLoop->getLatch();

  // Loop control blocks that may become orphaned later.
  SmallVector<BasicBlock *, 12> OldControlBBs;
  OldControlBBs.reserve(6 * Loops.size());
  for (CanonicalLoopInfo *Loop : Loops)
    Loop->collectControlBlocks(OldControlBBs);

  // Collect the original trip counts and induction variables to be accessible
  // by index. Also, the structure of the original loops is not preserved
  // during the construction of the tiled loops, so do it before we scavenge
  // the BBs of any original CanonicalLoopInfo.
  SmallVector<Value *, 4> OrigTripCounts, OrigIndVars;
  for (CanonicalLoopInfo *L : Loops) {
    assert(L->isValid() && "All input loops must be valid canonical loops");
    OrigTripCounts.push_back(L->getTripCount());
    OrigIndVars.push_back(L->getIndVar());
  }

  // Collect the code between loop headers. These may contain SSA definitions
  // that are used in the loop nest body. To be usable within the innermost
  // body, these BasicBlocks will be sunk into the loop nest body. That is,
  // these instructions may be executed more often than before the tiling.
  // TODO: It would be sufficient to only sink them into the body of the
  // corresponding tile loop.
  SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> InbetweenCode;
  for (int i = 0; i < NumLoops - 1; ++i) {
    CanonicalLoopInfo *Surrounding = Loops[i];
    CanonicalLoopInfo *Nested = Loops[i + 1];

    BasicBlock *EnterBB = Surrounding->getBody();
    BasicBlock *ExitBB = Nested->getHeader();
    InbetweenCode.emplace_back(EnterBB, ExitBB);
  }

  // Compute the trip counts of the floor loops.
  Builder.SetCurrentDebugLocation(DL);
  Builder.restoreIP(OutermostLoop->getPreheaderIP());
  SmallVector<Value *, 4> FloorCount, FloorRems;
  for (int i = 0; i < NumLoops; ++i) {
    Value *TileSize = TileSizes[i];
    Value *OrigTripCount = OrigTripCounts[i];
    Type *IVType = OrigTripCount->getType();

    Value *FloorTripCount = Builder.CreateUDiv(OrigTripCount, TileSize);
    Value *FloorTripRem = Builder.CreateURem(OrigTripCount, TileSize);

    // 0 if the tile size divides the trip count, 1 otherwise.
    // 1 means we need an additional iteration for a partial tile.
    //
    // Unfortunately we cannot just use the roundup-formula
    //   (tripcount + tilesize - 1)/tilesize
    // because the summation might overflow. We do not want to introduce
    // undefined behavior when the untiled loop nest did not have it.
    Value *FloorTripOverflow =
        Builder.CreateICmpNE(FloorTripRem, ConstantInt::get(IVType, 0));

    FloorTripOverflow = Builder.CreateZExt(FloorTripOverflow, IVType);
    FloorTripCount =
        Builder.CreateAdd(FloorTripCount, FloorTripOverflow,
                          "omp_floor" + Twine(i) + ".tripcount", true);

    // Remember some values for later use.
    FloorCount.push_back(FloorTripCount);
    FloorRems.push_back(FloorTripRem);
  }

  // Generate the new loop nest, from the outermost to the innermost.
  std::vector<CanonicalLoopInfo *> Result;
  Result.reserve(NumLoops * 2);

  // The basic block of the surrounding loop that enters the generated loop
  // nest.
  BasicBlock *Enter = OutermostLoop->getPreheader();

  // The basic block of the surrounding loop where the inner code should
  // continue.
  BasicBlock *Continue = OutermostLoop->getAfter();

  // Where the next loop basic block should be inserted.
  BasicBlock *OutroInsertBefore = InnermostLoop->getExit();

  auto EmbeddNewLoop =
      [this, DL, F, InnerEnter, &Enter, &Continue, &OutroInsertBefore](
          Value *TripCount, const Twine &Name) -> CanonicalLoopInfo * {
    CanonicalLoopInfo *EmbeddedLoop = createLoopSkeleton(
        DL, TripCount, F, InnerEnter, OutroInsertBefore, Name);
    redirectTo(Enter, EmbeddedLoop->getPreheader(), DL);
    redirectTo(EmbeddedLoop->getAfter(), Continue, DL);

    // Setup the position where the next embedded loop connects to this loop.
    Enter = EmbeddedLoop->getBody();
    Continue = EmbeddedLoop->getLatch();
    OutroInsertBefore = EmbeddedLoop->getLatch();
    return EmbeddedLoop;
  };

  auto EmbeddNewLoops = [&Result, &EmbeddNewLoop](ArrayRef<Value *> TripCounts,
                                                  const Twine &NameBase) {
    for (auto P : enumerate(TripCounts)) {
      CanonicalLoopInfo *EmbeddedLoop =
          EmbeddNewLoop(P.value(), NameBase + Twine(P.index()));
      Result.push_back(EmbeddedLoop);
    }
  };

  EmbeddNewLoops(FloorCount, "floor");

  // Within the innermost floor loop, emit the code that computes the tile
  // sizes.
  Builder.SetInsertPoint(Enter->getTerminator());
  SmallVector<Value *, 4> TileCounts;
  for (int i = 0; i < NumLoops; ++i) {
    CanonicalLoopInfo *FloorLoop = Result[i];
    Value *TileSize = TileSizes[i];

    Value *FloorIsEpilogue =
        Builder.CreateICmpEQ(FloorLoop->getIndVar(), FloorCount[i]);
    Value *TileTripCount =
        Builder.CreateSelect(FloorIsEpilogue, FloorRems[i], TileSize);

    TileCounts.push_back(TileTripCount);
  }

  // Create the tile loops.
  EmbeddNewLoops(TileCounts, "tile");

  // Insert the in-between code into the body.
  BasicBlock *BodyEnter = Enter;
  BasicBlock *BodyEntered = nullptr;
  for (std::pair<BasicBlock *, BasicBlock *> P : InbetweenCode) {
    BasicBlock *EnterBB = P.first;
    BasicBlock *ExitBB = P.second;

    if (BodyEnter)
      redirectTo(BodyEnter, EnterBB, DL);
    else
      redirectAllPredecessorsTo(BodyEntered, EnterBB, DL);

    BodyEnter = nullptr;
    BodyEntered = ExitBB;
  }

  // Append the original loop nest body into the generated loop nest body.
  if (BodyEnter)
    redirectTo(BodyEnter, InnerEnter, DL);
  else
    redirectAllPredecessorsTo(BodyEntered, InnerEnter, DL);
  redirectAllPredecessorsTo(InnerLatch, Continue, DL);

  // Replace the original induction variable with an induction variable
  // computed from the tile and floor induction variables.
  Builder.restoreIP(Result.back()->getBodyIP());
  for (int i = 0; i < NumLoops; ++i) {
    CanonicalLoopInfo *FloorLoop = Result[i];
    CanonicalLoopInfo *TileLoop = Result[NumLoops + i];
    Value *OrigIndVar = OrigIndVars[i];
    Value *Size = TileSizes[i];

    Value *Scale =
        Builder.CreateMul(Size, FloorLoop->getIndVar(), {}, /*HasNUW=*/true);
    Value *Shift =
        Builder.CreateAdd(Scale, TileLoop->getIndVar(), {}, /*HasNUW=*/true);
    OrigIndVar->replaceAllUsesWith(Shift);
  }

  // Remove unused parts of the original loops.
  removeUnusedBlocksFromParent(OldControlBBs);

  for (CanonicalLoopInfo *L : Loops)
    L->invalidate();

#ifndef NDEBUG
  for (CanonicalLoopInfo *GenL : Result)
    GenL->assertOK();
#endif
  return Result;
}
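
// For example (a sketch of the intended arithmetic): tiling a single loop
// with trip count 100 by a tile size of 32 produces a floor loop with trip
// count 100 / 32 + 1 = 4 (the "+ 1" accounting for the partial tile, since
// 100 % 32 = 4 is nonzero). Full tiles run 32 iterations of the tile loop,
// the epilogue tile runs the remainder of 4, and the original induction
// variable is recomputed as 32 * floor_iv + tile_iv.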

/// Attach metadata \p Properties to the basic block described by \p BB. If the
/// basic block already has metadata, the basic block properties are appended.
static void addBasicBlockMetadata(BasicBlock *BB,
                                  ArrayRef<Metadata *> Properties) {
  // Nothing to do if no property to attach.
  if (Properties.empty())
    return;

  LLVMContext &Ctx = BB->getContext();
  SmallVector<Metadata *> NewProperties;
  NewProperties.push_back(nullptr);

  // If the basic block already has metadata, prepend it to the new metadata.
  MDNode *Existing = BB->getTerminator()->getMetadata(LLVMContext::MD_loop);
  if (Existing)
    append_range(NewProperties, drop_begin(Existing->operands(), 1));

  append_range(NewProperties, Properties);
  MDNode *BasicBlockID = MDNode::getDistinct(Ctx, NewProperties);
  BasicBlockID->replaceOperandWith(0, BasicBlockID);

  BB->getTerminator()->setMetadata(LLVMContext::MD_loop, BasicBlockID);
}
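
// For illustration, attaching a single unroll property to a terminator that
// had no loop metadata yields a self-referential loop ID like the following
// (a sketch of the textual IR):
//
//   br label %header, !llvm.loop !0
//   !0 = distinct !{!0, !1}
//   !1 = !{!"llvm.loop.unroll.enable"}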

/// Attach loop metadata \p Properties to the loop described by \p Loop. If the
/// loop already has metadata, the loop properties are appended.
static void addLoopMetadata(CanonicalLoopInfo *Loop,
                            ArrayRef<Metadata *> Properties) {
  assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");

  // Attach metadata to the loop's latch.
  BasicBlock *Latch = Loop->getLatch();
  assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
  addBasicBlockMetadata(Latch, Properties);
}

/// Attach llvm.access.group metadata to the memref instructions of \p Block.
static void addSimdMetadata(BasicBlock *Block, MDNode *AccessGroup,
                            LoopInfo &LI) {
  for (Instruction &I : *Block) {
    if (I.mayReadOrWriteMemory()) {
      // TODO: This instruction may already have an access group from other
      // pragmas, e.g. #pragma clang loop vectorize. Append so that the
      // existing metadata is not overwritten.
      I.setMetadata(LLVMContext::MD_access_group, AccessGroup);
    }
  }
}

void OpenMPIRBuilder::unrollLoopFull(DebugLoc, CanonicalLoopInfo *Loop) {
  LLVMContext &Ctx = Builder.getContext();
  addLoopMetadata(
      Loop, {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
             MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.full"))});
}
3548 
unrollLoopHeuristic(DebugLoc,CanonicalLoopInfo * Loop)3549 void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
3550   LLVMContext &Ctx = Builder.getContext();
3551   addLoopMetadata(
3552       Loop, {
3553                 MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
3554             });
3555 }
3556 
createIfVersion(CanonicalLoopInfo * CanonicalLoop,Value * IfCond,ValueToValueMapTy & VMap,const Twine & NamePrefix)3557 void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop,
3558                                       Value *IfCond, ValueToValueMapTy &VMap,
3559                                       const Twine &NamePrefix) {
3560   Function *F = CanonicalLoop->getFunction();
3561 
3562   // Define where the if branch should be inserted.
3563   Instruction *SplitBefore;
3564   if (isa<Instruction>(IfCond)) {
3565     SplitBefore = cast<Instruction>(IfCond);
3566   } else {
3567     SplitBefore = CanonicalLoop->getPreheader()->getTerminator();
3568   }
3569 
3570   // TODO: We should not rely on the pass manager. Currently we use the pass
3571   // manager only for getting the llvm::Loop which corresponds to the given
3572   // CanonicalLoopInfo object. We should have a method which returns all
3573   // blocks between CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter().
3574   FunctionAnalysisManager FAM;
3575   FAM.registerPass([]() { return DominatorTreeAnalysis(); });
3576   FAM.registerPass([]() { return LoopAnalysis(); });
3577   FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
3578 
3579   // Get the loop which needs to be cloned
3580   LoopAnalysis LIA;
3581   LoopInfo &&LI = LIA.run(*F, FAM);
3582   Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());
3583 
3584   // Create additional blocks for the if statement
3585   BasicBlock *Head = SplitBefore->getParent();
3586   Instruction *HeadOldTerm = Head->getTerminator();
3587   llvm::LLVMContext &C = Head->getContext();
3588   llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create(
3589       C, NamePrefix + ".if.then", Head->getParent(), Head->getNextNode());
3590   llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create(
3591       C, NamePrefix + ".if.else", Head->getParent(), CanonicalLoop->getExit());
3592 
3593   // Create if condition branch.
3594   Builder.SetInsertPoint(HeadOldTerm);
3595   Instruction *BrInstr =
3596       Builder.CreateCondBr(IfCond, ThenBlock, /*ifFalse*/ ElseBlock);
3597   InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()};
3598   // The then block contains the branch to the OpenMP loop that is to be vectorized.
3599   spliceBB(IP, ThenBlock, false);
3600   ThenBlock->replaceSuccessorsPhiUsesWith(Head, ThenBlock);
3601 
3602   Builder.SetInsertPoint(ElseBlock);
3603 
3604   // Clone loop for the else branch
3605   SmallVector<BasicBlock *, 8> NewBlocks;
3606 
3607   VMap[CanonicalLoop->getPreheader()] = ElseBlock;
3608   for (BasicBlock *Block : L->getBlocks()) {
3609     BasicBlock *NewBB = CloneBasicBlock(Block, VMap, "", F);
3610     NewBB->moveBefore(CanonicalLoop->getExit());
3611     VMap[Block] = NewBB;
3612     NewBlocks.push_back(NewBB);
3613   }
3614   remapInstructionsInBlocks(NewBlocks, VMap);
3615   Builder.CreateBr(NewBlocks.front());
3616 }
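
// Schematically, createIfVersion rewrites the CFG as follows (block names
// assume NamePrefix == "simd"):
//
//   preheader:               preheader:
//     br %header               br i1 %IfCond, label %simd.if.then,
//                      =>                     label %simd.if.else
//   <loop>                   simd.if.then:  ; branches to the original loop
//                            simd.if.else:  ; runs the VMap-cloned copy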
3617 
3618 unsigned
3619 OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
3620                                            const StringMap<bool> &Features) {
3621   if (TargetTriple.isX86()) {
3622     if (Features.lookup("avx512f"))
3623       return 512;
3624     else if (Features.lookup("avx"))
3625       return 256;
3626     return 128;
3627   }
3628   if (TargetTriple.isPPC())
3629     return 128;
3630   if (TargetTriple.isWasm())
3631     return 128;
3632   return 0;
3633 }
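
// A minimal usage sketch (the feature map contents are hypothetical):
//
//   StringMap<bool> Features;
//   Features["avx"] = true;
//   unsigned AlignBits = OpenMPIRBuilder::getOpenMPDefaultSimdAlign(
//       Triple("x86_64-unknown-linux-gnu"), Features); // 256 per the code above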
3634 
3635 void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
3636                                 MapVector<Value *, Value *> AlignedVars,
3637                                 Value *IfCond, OrderKind Order,
3638                                 ConstantInt *Simdlen, ConstantInt *Safelen) {
3639   LLVMContext &Ctx = Builder.getContext();
3640 
3641   Function *F = CanonicalLoop->getFunction();
3642 
3643   // TODO: We should not rely on the pass manager. Currently we use the pass
3644   // manager only for getting the llvm::Loop which corresponds to the given
3645   // CanonicalLoopInfo object. We should have a method which returns all
3646   // blocks between CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter().
3647   FunctionAnalysisManager FAM;
3648   FAM.registerPass([]() { return DominatorTreeAnalysis(); });
3649   FAM.registerPass([]() { return LoopAnalysis(); });
3650   FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
3651 
3652   LoopAnalysis LIA;
3653   LoopInfo &&LI = LIA.run(*F, FAM);
3654 
3655   Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());
3656   if (!AlignedVars.empty()) {
3657     InsertPointTy IP = Builder.saveIP();
3658     Builder.SetInsertPoint(CanonicalLoop->getPreheader()->getTerminator());
3659     for (auto &AlignedItem : AlignedVars) {
3660       Value *AlignedPtr = AlignedItem.first;
3661       Value *Alignment = AlignedItem.second;
3662       Builder.CreateAlignmentAssumption(F->getParent()->getDataLayout(),
3663                                         AlignedPtr, Alignment);
3664     }
3665     Builder.restoreIP(IP);
3666   }
3667 
3668   if (IfCond) {
3669     ValueToValueMapTy VMap;
3670     createIfVersion(CanonicalLoop, IfCond, VMap, "simd");
3671     // Add metadata to the cloned loop which disables vectorization
3672     Value *MappedLatch = VMap.lookup(CanonicalLoop->getLatch());
3673     assert(MappedLatch &&
3674            "Cannot find value which corresponds to original loop latch");
3675     assert(isa<BasicBlock>(MappedLatch) &&
3676            "Cannot cast mapped latch block value to BasicBlock");
3677   BasicBlock *NewLatchBlock = cast<BasicBlock>(MappedLatch);
3678     ConstantAsMetadata *BoolConst =
3679         ConstantAsMetadata::get(ConstantInt::getFalse(Type::getInt1Ty(Ctx)));
3680     addBasicBlockMetadata(
3681         NewLatchBlock,
3682         {MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"),
3683                            BoolConst})});
3684   }
3685 
3686   SmallSet<BasicBlock *, 8> Reachable;
3687 
3688   // Get the basic blocks from the loop in which memref instructions
3689   // can be found.
3690   // TODO: Generalize getting all blocks inside a CanonicalLoopInfo,
3691   // preferably without running any passes.
3692   for (BasicBlock *Block : L->getBlocks()) {
3693     if (Block == CanonicalLoop->getCond() ||
3694         Block == CanonicalLoop->getHeader())
3695       continue;
3696     Reachable.insert(Block);
3697   }
3698 
3699   SmallVector<Metadata *> LoopMDList;
3700 
3701   // In presence of finite 'safelen', it may be unsafe to mark all
3702   // the memory instructions parallel, because loop-carried
3703   // dependences of 'safelen' iterations are possible.
3704   // If clause order(concurrent) is specified then the memory instructions
3705   // are marked parallel even if 'safelen' is finite.
3706   if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent)) {
3707     // Add access group metadata to memory-access instructions.
3708     MDNode *AccessGroup = MDNode::getDistinct(Ctx, {});
3709     for (BasicBlock *BB : Reachable)
3710       addSimdMetadata(BB, AccessGroup, LI);
3711     // TODO:  If the loop has existing parallel access metadata, have
3712     // to combine two lists.
3713     LoopMDList.push_back(MDNode::get(
3714         Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"), AccessGroup}));
3715   }
3716 
3717   // Use the above access group metadata to create loop level
3718   // metadata, which should be distinct for each loop.
3719   ConstantAsMetadata *BoolConst =
3720       ConstantAsMetadata::get(ConstantInt::getTrue(Type::getInt1Ty(Ctx)));
3721   LoopMDList.push_back(MDNode::get(
3722       Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"), BoolConst}));
3723 
3724   if (Simdlen || Safelen) {
3725     // If both simdlen and safelen clauses are specified, the value of the
3726     // simdlen parameter must be less than or equal to the value of the safelen
3727     // parameter. Therefore, use safelen only in the absence of simdlen.
3728     ConstantInt *VectorizeWidth = Simdlen == nullptr ? Safelen : Simdlen;
3729     LoopMDList.push_back(
3730         MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.width"),
3731                           ConstantAsMetadata::get(VectorizeWidth)}));
3732   }
3733 
3734   addLoopMetadata(CanonicalLoop, LoopMDList);
3735 }
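
// For example, applySimd with Simdlen == 8 and no Safelen leaves the latch
// terminator with loop metadata roughly like:
//
//   !0 = distinct !{!0, !1, !2, !3}
//   !1 = !{!"llvm.loop.parallel_accesses", !4}
//   !2 = !{!"llvm.loop.vectorize.enable", i1 true}
//   !3 = !{!"llvm.loop.vectorize.width", i32 8}
//   !4 = distinct !{} ; the access group attached to the memory instructions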
3736 
3737 /// Create the TargetMachine object to query the backend for optimization
3738 /// preferences.
3739 ///
3740 /// Ideally, this would be passed from the front-end to the OpenMPBuilder, but
3741 /// e.g. Clang does not pass it to its CodeGen layer and creates it only when
3742 /// needed for the LLVM pass pipeline. We use some default options to avoid
3743 /// having to pass too many settings from the frontend that probably do not
3744 /// matter.
3745 ///
3746 /// Currently, TargetMachine is only used sometimes by the unrollLoopPartial
3747 /// method. If we are going to use TargetMachine for more purposes, especially
3748 /// those that are sensitive to TargetOptions, RelocModel and CodeModel, it
3749 /// might be worth requiring front-ends to pass on their TargetMachine,
3750 /// or at least cache it between methods. Note that while frontends such as
3751 /// Clang have just a single main TargetMachine per translation unit,
3752 /// "target-cpu" and "target-features" that determine the TargetMachine are
3753 /// per-function and can be overridden using __attribute__((target("OPTIONS"))).
3754 static std::unique_ptr<TargetMachine>
3755 createTargetMachine(Function *F, CodeGenOptLevel OptLevel) {
3756   Module *M = F->getParent();
3757 
3758   StringRef CPU = F->getFnAttribute("target-cpu").getValueAsString();
3759   StringRef Features = F->getFnAttribute("target-features").getValueAsString();
3760   const std::string &Triple = M->getTargetTriple();
3761 
3762   std::string Error;
3763   const llvm::Target *TheTarget = TargetRegistry::lookupTarget(Triple, Error);
3764   if (!TheTarget)
3765     return {};
3766 
3767   llvm::TargetOptions Options;
3768   return std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
3769       Triple, CPU, Features, Options, /*RelocModel=*/std::nullopt,
3770       /*CodeModel=*/std::nullopt, OptLevel));
3771 }
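
// Hypothetical example: for a function carrying the attributes
// "target-cpu"="skylake" and "target-features"="+avx2" in a module with an
// x86_64 triple, this returns a TargetMachine configured for that CPU and
// feature set, or nullptr if the x86 backend is not linked in.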
3772 
3773 /// Heuristically determine the best-performant unroll factor for \p CLI. This
3774 /// depends on the target processor. We are re-using the same heuristics as the
3775 /// LoopUnrollPass.
3776 static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
3777   Function *F = CLI->getFunction();
3778 
3779   // Assume the user requests the most aggressive unrolling, even if the rest of
3780   // the code is optimized using a lower setting.
3781   CodeGenOptLevel OptLevel = CodeGenOptLevel::Aggressive;
3782   std::unique_ptr<TargetMachine> TM = createTargetMachine(F, OptLevel);
3783 
3784   FunctionAnalysisManager FAM;
3785   FAM.registerPass([]() { return TargetLibraryAnalysis(); });
3786   FAM.registerPass([]() { return AssumptionAnalysis(); });
3787   FAM.registerPass([]() { return DominatorTreeAnalysis(); });
3788   FAM.registerPass([]() { return LoopAnalysis(); });
3789   FAM.registerPass([]() { return ScalarEvolutionAnalysis(); });
3790   FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
3791   TargetIRAnalysis TIRA;
3792   if (TM)
3793     TIRA = TargetIRAnalysis(
3794         [&](const Function &F) { return TM->getTargetTransformInfo(F); });
3795   FAM.registerPass([&]() { return TIRA; });
3796 
3797   TargetIRAnalysis::Result &&TTI = TIRA.run(*F, FAM);
3798   ScalarEvolutionAnalysis SEA;
3799   ScalarEvolution &&SE = SEA.run(*F, FAM);
3800   DominatorTreeAnalysis DTA;
3801   DominatorTree &&DT = DTA.run(*F, FAM);
3802   LoopAnalysis LIA;
3803   LoopInfo &&LI = LIA.run(*F, FAM);
3804   AssumptionAnalysis ACT;
3805   AssumptionCache &&AC = ACT.run(*F, FAM);
3806   OptimizationRemarkEmitter ORE{F};
3807 
3808   Loop *L = LI.getLoopFor(CLI->getHeader());
3809   assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");
3810 
3811   TargetTransformInfo::UnrollingPreferences UP =
3812       gatherUnrollingPreferences(L, SE, TTI,
3813                                  /*BlockFrequencyInfo=*/nullptr,
3814                                  /*ProfileSummaryInfo=*/nullptr, ORE, static_cast<int>(OptLevel),
3815                                  /*UserThreshold=*/std::nullopt,
3816                                  /*UserCount=*/std::nullopt,
3817                                  /*UserAllowPartial=*/true,
3818                                  /*UserAllowRuntime=*/true,
3819                                  /*UserUpperBound=*/std::nullopt,
3820                                  /*UserFullUnrollMaxCount=*/std::nullopt);
3821 
3822   UP.Force = true;
3823 
3824   // Account for additional optimizations taking place before the LoopUnrollPass
3825   // would unroll the loop.
3826   UP.Threshold *= UnrollThresholdFactor;
3827   UP.PartialThreshold *= UnrollThresholdFactor;
3828 
3829   // Use normal unroll factors even if the rest of the code is optimized for
3830   // size.
3831   UP.OptSizeThreshold = UP.Threshold;
3832   UP.PartialOptSizeThreshold = UP.PartialThreshold;
3833 
3834   LLVM_DEBUG(dbgs() << "Unroll heuristic thresholds:\n"
3835                     << "  Threshold=" << UP.Threshold << "\n"
3836                     << "  PartialThreshold=" << UP.PartialThreshold << "\n"
3837                     << "  OptSizeThreshold=" << UP.OptSizeThreshold << "\n"
3838                     << "  PartialOptSizeThreshold="
3839                     << UP.PartialOptSizeThreshold << "\n");
3840 
3841   // Disable peeling.
3842   TargetTransformInfo::PeelingPreferences PP =
3843       gatherPeelingPreferences(L, SE, TTI,
3844                                /*UserAllowPeeling=*/false,
3845                                /*UserAllowProfileBasedPeeling=*/false,
3846                                /*UnrollingSpecficValues=*/false);
3847 
3848   SmallPtrSet<const Value *, 32> EphValues;
3849   CodeMetrics::collectEphemeralValues(L, &AC, EphValues);
3850 
3851   // Assume that reads and writes to stack variables can be eliminated by
3852   // Mem2Reg, SROA or LICM. That is, don't count them towards the loop body's
3853   // size.
3854   for (BasicBlock *BB : L->blocks()) {
3855     for (Instruction &I : *BB) {
3856       Value *Ptr;
3857       if (auto *Load = dyn_cast<LoadInst>(&I)) {
3858         Ptr = Load->getPointerOperand();
3859       } else if (auto *Store = dyn_cast<StoreInst>(&I)) {
3860         Ptr = Store->getPointerOperand();
3861       } else
3862         continue;
3863 
3864       Ptr = Ptr->stripPointerCasts();
3865 
3866       if (auto *Alloca = dyn_cast<AllocaInst>(Ptr)) {
3867         if (Alloca->getParent() == &F->getEntryBlock())
3868           EphValues.insert(&I);
3869       }
3870     }
3871   }
3872 
3873   UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns);
3874 
3875   // The loop is not unrollable if it contains certain (e.g. convergent) instructions.
3876   if (!UCE.canUnroll() || UCE.Convergent) {
3877     LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
3878     return 1;
3879   }
3880 
3881   LLVM_DEBUG(dbgs() << "Estimated loop size is " << UCE.getRolledLoopSize()
3882                     << "\n");
3883 
3884   // TODO: Determine trip count of \p CLI if constant, computeUnrollCount might
3885   // be able to use it.
3886   int TripCount = 0;
3887   int MaxTripCount = 0;
3888   bool MaxOrZero = false;
3889   unsigned TripMultiple = 0;
3890 
3891   bool UseUpperBound = false;
3892   computeUnrollCount(L, TTI, DT, &LI, &AC, SE, EphValues, &ORE, TripCount,
3893                      MaxTripCount, MaxOrZero, TripMultiple, UCE, UP, PP,
3894                      UseUpperBound);
3895   unsigned Factor = UP.Count;
3896   LLVM_DEBUG(dbgs() << "Suggesting unroll factor of " << Factor << "\n");
3897 
3898   // This function returns 1 to signal that the loop should not be unrolled.
3899   if (Factor == 0)
3900     return 1;
3901   return Factor;
3902 }
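
// A sketch of how callers interpret the result (see unrollLoopPartial below):
//
//   int32_t Factor = computeHeuristicUnrollFactor(CLI);
//   if (Factor >= 2) {
//     // Tile by Factor and fully unroll the inner loop.
//   } // Factor == 1: leave the loop as it is.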
3903 
3904 void OpenMPIRBuilder::unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop,
3905                                         int32_t Factor,
3906                                         CanonicalLoopInfo **UnrolledCLI) {
3907   assert(Factor >= 0 && "Unroll factor must not be negative");
3908 
3909   Function *F = Loop->getFunction();
3910   LLVMContext &Ctx = F->getContext();
3911 
3912   // If the unrolled loop is not used for another loop-associated directive, it
3913   // is sufficient to add metadata for the LoopUnrollPass.
3914   if (!UnrolledCLI) {
3915     SmallVector<Metadata *, 2> LoopMetadata;
3916     LoopMetadata.push_back(
3917         MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")));
3918 
3919     if (Factor >= 1) {
3920       ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
3921           ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
3922       LoopMetadata.push_back(MDNode::get(
3923           Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst}));
3924     }
3925 
3926     addLoopMetadata(Loop, LoopMetadata);
3927     return;
3928   }
3929 
3930   // Heuristically determine the unroll factor.
3931   if (Factor == 0)
3932     Factor = computeHeuristicUnrollFactor(Loop);
3933 
3934   // No change required with unroll factor 1.
3935   if (Factor == 1) {
3936     *UnrolledCLI = Loop;
3937     return;
3938   }
3939 
3940   assert(Factor >= 2 &&
3941          "unrolling only makes sense with a factor of 2 or larger");
3942 
3943   Type *IndVarTy = Loop->getIndVarType();
3944 
3945   // Apply partial unrolling by tiling the loop by the unroll-factor, then fully
3946   // unroll the inner loop.
3947   Value *FactorVal =
3948       ConstantInt::get(IndVarTy, APInt(IndVarTy->getIntegerBitWidth(), Factor,
3949                                        /*isSigned=*/false));
3950   std::vector<CanonicalLoopInfo *> LoopNest =
3951       tileLoops(DL, {Loop}, {FactorVal});
3952   assert(LoopNest.size() == 2 && "Expect 2 loops after tiling");
3953   *UnrolledCLI = LoopNest[0];
3954   CanonicalLoopInfo *InnerLoop = LoopNest[1];
3955 
3956   // LoopUnrollPass can only fully unroll loops with constant trip count.
3957   // Unroll by the unroll factor with a fallback epilog for the remainder
3958   // iterations if necessary.
3959   ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
3960       ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
3961   addLoopMetadata(
3962       InnerLoop,
3963       {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
3964        MDNode::get(
3965            Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst})});
3966 
3967 #ifndef NDEBUG
3968   (*UnrolledCLI)->assertOK();
3969 #endif
3970 }
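
// Conceptually, a partial unroll by a factor of 4 turns
//
//   for (i = 0; i < n; ++i) body(i);
//
// into a tiled nest whose inner loop carries llvm.loop.unroll.enable and
// llvm.loop.unroll.count 4, so that the LoopUnrollPass later flattens it:
//
//   for (i = 0; i < n; i += 4)
//     for (j = i; j < min(i + 4, n); ++j) body(j);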
3971 
3972 OpenMPIRBuilder::InsertPointTy
3973 OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
3974                                    llvm::Value *BufSize, llvm::Value *CpyBuf,
3975                                    llvm::Value *CpyFn, llvm::Value *DidIt) {
3976   if (!updateToLocation(Loc))
3977     return Loc.IP;
3978 
3979   uint32_t SrcLocStrSize;
3980   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3981   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3982   Value *ThreadId = getOrCreateThreadID(Ident);
3983 
3984   llvm::Value *DidItLD = Builder.CreateLoad(Builder.getInt32Ty(), DidIt);
3985 
3986   Value *Args[] = {Ident, ThreadId, BufSize, CpyBuf, CpyFn, DidItLD};
3987 
3988   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_copyprivate);
3989   Builder.CreateCall(Fn, Args);
3990 
3991   return Builder.saveIP();
3992 }
3993 
3994 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSingle(
3995     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
3996     FinalizeCallbackTy FiniCB, bool IsNowait, llvm::Value *DidIt) {
3997 
3998   if (!updateToLocation(Loc))
3999     return Loc.IP;
4000 
4001   // If needed (i.e. not null), initialize `DidIt` with 0
4002   if (DidIt) {
4003     Builder.CreateStore(Builder.getInt32(0), DidIt);
4004   }
4005 
4006   Directive OMPD = Directive::OMPD_single;
4007   uint32_t SrcLocStrSize;
4008   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4009   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4010   Value *ThreadId = getOrCreateThreadID(Ident);
4011   Value *Args[] = {Ident, ThreadId};
4012 
4013   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_single);
4014   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
4015 
4016   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_single);
4017   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
4018 
4019   // Generates the following:
4020   // if (__kmpc_single()) {
4021   //   .... single region ...
4022   //   __kmpc_end_single
4023   // }
4024   // __kmpc_barrier
4025 
4026   EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
4027                        /*Conditional*/ true,
4028                        /*hasFinalize*/ true);
4029   if (!IsNowait)
4030     createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
4031                   omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
4032                   /* CheckCancelFlag */ false);
4033   return Builder.saveIP();
4034 }
4035 
4036 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCritical(
4037     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
4038     FinalizeCallbackTy FiniCB, StringRef CriticalName, Value *HintInst) {
4039 
4040   if (!updateToLocation(Loc))
4041     return Loc.IP;
4042 
4043   Directive OMPD = Directive::OMPD_critical;
4044   uint32_t SrcLocStrSize;
4045   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4046   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4047   Value *ThreadId = getOrCreateThreadID(Ident);
4048   Value *LockVar = getOMPCriticalRegionLock(CriticalName);
4049   Value *Args[] = {Ident, ThreadId, LockVar};
4050 
4051   SmallVector<llvm::Value *, 4> EnterArgs(std::begin(Args), std::end(Args));
4052   Function *RTFn = nullptr;
4053   if (HintInst) {
4054     // Add Hint to entry Args and create call
4055     EnterArgs.push_back(HintInst);
4056     RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical_with_hint);
4057   } else {
4058     RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical);
4059   }
4060   Instruction *EntryCall = Builder.CreateCall(RTFn, EnterArgs);
4061 
4062   Function *ExitRTLFn =
4063       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_critical);
4064   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
4065 
4066   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
4067                               /*Conditional*/ false, /*hasFinalize*/ true);
4068 }
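
// Roughly generates (the lock global's name is derived from CriticalName):
//
//   call void @__kmpc_critical(ptr @ident, i32 %tid, ptr @<lock>)
//   ; ... critical region ...
//   call void @__kmpc_end_critical(ptr @ident, i32 %tid, ptr @<lock>)
//
// With a hint, the entry call becomes __kmpc_critical_with_hint with the hint
// value appended to the arguments.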
4069 
4070 OpenMPIRBuilder::InsertPointTy
4071 OpenMPIRBuilder::createOrderedDepend(const LocationDescription &Loc,
4072                                      InsertPointTy AllocaIP, unsigned NumLoops,
4073                                      ArrayRef<llvm::Value *> StoreValues,
4074                                      const Twine &Name, bool IsDependSource) {
4075   assert(
4076       llvm::all_of(StoreValues,
4077                    [](Value *SV) { return SV->getType()->isIntegerTy(64); }) &&
4078       "OpenMP runtime requires depend vec with i64 type");
4079 
4080   if (!updateToLocation(Loc))
4081     return Loc.IP;
4082 
4083   // Allocate space for vector and generate alloc instruction.
4084   auto *ArrI64Ty = ArrayType::get(Int64, NumLoops);
4085   Builder.restoreIP(AllocaIP);
4086   AllocaInst *ArgsBase = Builder.CreateAlloca(ArrI64Ty, nullptr, Name);
4087   ArgsBase->setAlignment(Align(8));
4088   Builder.restoreIP(Loc.IP);
4089 
4090   // Store the index values with offsets into the depend vector.
4091   for (unsigned I = 0; I < NumLoops; ++I) {
4092     Value *DependAddrGEPIter = Builder.CreateInBoundsGEP(
4093         ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(I)});
4094     StoreInst *STInst = Builder.CreateStore(StoreValues[I], DependAddrGEPIter);
4095     STInst->setAlignment(Align(8));
4096   }
4097 
4098   Value *DependBaseAddrGEP = Builder.CreateInBoundsGEP(
4099       ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(0)});
4100 
4101   uint32_t SrcLocStrSize;
4102   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4103   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4104   Value *ThreadId = getOrCreateThreadID(Ident);
4105   Value *Args[] = {Ident, ThreadId, DependBaseAddrGEP};
4106 
4107   Function *RTLFn = nullptr;
4108   if (IsDependSource)
4109     RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_post);
4110   else
4111     RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_wait);
4112   Builder.CreateCall(RTLFn, Args);
4113 
4114   return Builder.saveIP();
4115 }
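
// For NumLoops == 2 this emits, schematically:
//
//   %vec = alloca [2 x i64], align 8
//   store i64 %iv0, ptr %vec        ; element 0
//   store i64 %iv1, ptr %vec.elt1   ; element 1
//   call void @__kmpc_doacross_post(ptr @ident, i32 %tid, ptr %vec)
//
// or __kmpc_doacross_wait instead of __kmpc_doacross_post when
// \p IsDependSource is false.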
4116 
4117 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createOrderedThreadsSimd(
4118     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
4119     FinalizeCallbackTy FiniCB, bool IsThreads) {
4120   if (!updateToLocation(Loc))
4121     return Loc.IP;
4122 
4123   Directive OMPD = Directive::OMPD_ordered;
4124   Instruction *EntryCall = nullptr;
4125   Instruction *ExitCall = nullptr;
4126 
4127   if (IsThreads) {
4128     uint32_t SrcLocStrSize;
4129     Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4130     Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4131     Value *ThreadId = getOrCreateThreadID(Ident);
4132     Value *Args[] = {Ident, ThreadId};
4133 
4134     Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_ordered);
4135     EntryCall = Builder.CreateCall(EntryRTLFn, Args);
4136 
4137     Function *ExitRTLFn =
4138         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_ordered);
4139     ExitCall = Builder.CreateCall(ExitRTLFn, Args);
4140   }
4141 
4142   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
4143                               /*Conditional*/ false, /*hasFinalize*/ true);
4144 }
4145 
4146 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::EmitOMPInlinedRegion(
4147     Directive OMPD, Instruction *EntryCall, Instruction *ExitCall,
4148     BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, bool Conditional,
4149     bool HasFinalize, bool IsCancellable) {
4150 
4151   if (HasFinalize)
4152     FinalizationStack.push_back({FiniCB, OMPD, IsCancellable});
4153 
4154   // Create inlined region's entry and body blocks, in preparation
4155   // for conditional creation
4156   BasicBlock *EntryBB = Builder.GetInsertBlock();
4157   Instruction *SplitPos = EntryBB->getTerminator();
4158   if (!isa_and_nonnull<BranchInst>(SplitPos))
4159     SplitPos = new UnreachableInst(Builder.getContext(), EntryBB);
4160   BasicBlock *ExitBB = EntryBB->splitBasicBlock(SplitPos, "omp_region.end");
4161   BasicBlock *FiniBB =
4162       EntryBB->splitBasicBlock(EntryBB->getTerminator(), "omp_region.finalize");
4163 
4164   Builder.SetInsertPoint(EntryBB->getTerminator());
4165   emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, Conditional);
4166 
4167   // generate body
4168   BodyGenCB(/* AllocaIP */ InsertPointTy(),
4169             /* CodeGenIP */ Builder.saveIP());
4170 
4171   // emit exit call and do any needed finalization.
4172   auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt());
4173   assert(FiniBB->getTerminator()->getNumSuccessors() == 1 &&
4174          FiniBB->getTerminator()->getSuccessor(0) == ExitBB &&
4175          "Unexpected control flow graph state!!");
4176   emitCommonDirectiveExit(OMPD, FinIP, ExitCall, HasFinalize);
4177   assert(FiniBB->getUniquePredecessor()->getUniqueSuccessor() == FiniBB &&
4178          "Unexpected Control Flow State!");
4179   MergeBlockIntoPredecessor(FiniBB);
4180 
4181   // If we are skipping the region of a non-conditional, remove the exit
4182   // block, and clear the builder's insertion point.
4183   assert(SplitPos->getParent() == ExitBB &&
4184          "Unexpected Insertion point location!");
4185   bool Merged = MergeBlockIntoPredecessor(ExitBB);
4186   BasicBlock *ExitPredBB = SplitPos->getParent();
4187   BasicBlock *InsertBB = Merged ? ExitPredBB : ExitBB;
4188   if (!isa_and_nonnull<BranchInst>(SplitPos))
4189     SplitPos->eraseFromParent();
4190   Builder.SetInsertPoint(InsertBB);
4191 
4192   return Builder.saveIP();
4193 }
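
// Schematically, before the block merging at the end, the emitted region is:
//
//   entry:
//     <EntryCall>        ; for Conditional: br i1 (<EntryCall> != 0), body, end
//   omp_region.body:
//     <BodyGenCB output>
//   omp_region.finalize:
//     <FiniCB output>, <ExitCall>
//   omp_region.end: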
4194 
4195 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry(
4196     Directive OMPD, Value *EntryCall, BasicBlock *ExitBB, bool Conditional) {
4197   // If there is nothing to do, return the current insertion point.
4198   if (!Conditional || !EntryCall)
4199     return Builder.saveIP();
4200 
4201   BasicBlock *EntryBB = Builder.GetInsertBlock();
4202   Value *CallBool = Builder.CreateIsNotNull(EntryCall);
4203   auto *ThenBB = BasicBlock::Create(M.getContext(), "omp_region.body");
4204   auto *UI = new UnreachableInst(Builder.getContext(), ThenBB);
4205 
4206   // Emit thenBB and set the Builder's insertion point there for
4207   // body generation next. Place the block after the current block.
4208   Function *CurFn = EntryBB->getParent();
4209   CurFn->insert(std::next(EntryBB->getIterator()), ThenBB);
4210 
4211   // Move Entry branch to end of ThenBB, and replace with conditional
4212   // branch (If-stmt)
4213   Instruction *EntryBBTI = EntryBB->getTerminator();
4214   Builder.CreateCondBr(CallBool, ThenBB, ExitBB);
4215   EntryBBTI->removeFromParent();
4216   Builder.SetInsertPoint(UI);
4217   Builder.Insert(EntryBBTI);
4218   UI->eraseFromParent();
4219   Builder.SetInsertPoint(ThenBB->getTerminator());
4220 
4221   // return an insertion point to ExitBB.
4222   return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt());
4223 }
4224 
4225 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveExit(
4226     omp::Directive OMPD, InsertPointTy FinIP, Instruction *ExitCall,
4227     bool HasFinalize) {
4228 
4229   Builder.restoreIP(FinIP);
4230 
4231   // If there is finalization to do, emit it before the exit call
4232   if (HasFinalize) {
4233     assert(!FinalizationStack.empty() &&
4234            "Unexpected finalization stack state!");
4235 
4236     FinalizationInfo Fi = FinalizationStack.pop_back_val();
4237     assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");
4238 
4239     Fi.FiniCB(FinIP);
4240 
4241     BasicBlock *FiniBB = FinIP.getBlock();
4242     Instruction *FiniBBTI = FiniBB->getTerminator();
4243 
4244     // set Builder IP for call creation
4245     Builder.SetInsertPoint(FiniBBTI);
4246   }
4247 
4248   if (!ExitCall)
4249     return Builder.saveIP();
4250 
4251   // Place the exit call last, right before the finalization block terminator.
4252   ExitCall->removeFromParent();
4253   Builder.Insert(ExitCall);
4254 
4255   return IRBuilder<>::InsertPoint(ExitCall->getParent(),
4256                                   ExitCall->getIterator());
4257 }
4258 
4259 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCopyinClauseBlocks(
4260     InsertPointTy IP, Value *MasterAddr, Value *PrivateAddr,
4261     llvm::IntegerType *IntPtrTy, bool BranchtoEnd) {
4262   if (!IP.isSet())
4263     return IP;
4264 
4265   IRBuilder<>::InsertPointGuard IPG(Builder);
4266 
4267   // Creates the following CFG structure:
4268   //   OMP_Entry : (MasterAddr != PrivateAddr)?
4269   //       F     T
4270   //       |      \
4271   //       |     copyin.not.master
4272   //       |      /
4273   //       v     /
4274   //   copyin.not.master.end
4275   //         |
4276   //         v
4277   //   OMP.Entry.Next
4278 
4279   BasicBlock *OMP_Entry = IP.getBlock();
4280   Function *CurFn = OMP_Entry->getParent();
4281   BasicBlock *CopyBegin =
4282       BasicBlock::Create(M.getContext(), "copyin.not.master", CurFn);
4283   BasicBlock *CopyEnd = nullptr;
4284 
4285   // If the entry block is terminated, split it to preserve the branch to the
4286   // following basic block (i.e. OMP.Entry.Next); otherwise, leave everything as is.
4287   if (isa_and_nonnull<BranchInst>(OMP_Entry->getTerminator())) {
4288     CopyEnd = OMP_Entry->splitBasicBlock(OMP_Entry->getTerminator(),
4289                                          "copyin.not.master.end");
4290     OMP_Entry->getTerminator()->eraseFromParent();
4291   } else {
4292     CopyEnd =
4293         BasicBlock::Create(M.getContext(), "copyin.not.master.end", CurFn);
4294   }
4295 
4296   Builder.SetInsertPoint(OMP_Entry);
4297   Value *MasterPtr = Builder.CreatePtrToInt(MasterAddr, IntPtrTy);
4298   Value *PrivatePtr = Builder.CreatePtrToInt(PrivateAddr, IntPtrTy);
4299   Value *Cmp = Builder.CreateICmpNE(MasterPtr, PrivatePtr);
4300   Builder.CreateCondBr(Cmp, CopyBegin, CopyEnd);
4301 
4302   Builder.SetInsertPoint(CopyBegin);
4303   if (BranchtoEnd)
4304     Builder.SetInsertPoint(Builder.CreateBr(CopyEnd));
4305 
4306   return Builder.saveIP();
4307 }
4308 
4309 CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc,
4310                                           Value *Size, Value *Allocator,
4311                                           std::string Name) {
4312   IRBuilder<>::InsertPointGuard IPG(Builder);
4313   Builder.restoreIP(Loc.IP);
4314 
4315   uint32_t SrcLocStrSize;
4316   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4317   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4318   Value *ThreadId = getOrCreateThreadID(Ident);
4319   Value *Args[] = {ThreadId, Size, Allocator};
4320 
4321   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc);
4322 
4323   return Builder.CreateCall(Fn, Args, Name);
4324 }
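
// Emits, schematically:
//
//   %ptr = call ptr @__kmpc_alloc(i32 %tid, i64 %size, ptr %allocator)
//
// createOMPFree below is the symmetric wrapper around __kmpc_free.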
4325 
4326 CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
4327                                          Value *Addr, Value *Allocator,
4328                                          std::string Name) {
4329   IRBuilder<>::InsertPointGuard IPG(Builder);
4330   Builder.restoreIP(Loc.IP);
4331 
4332   uint32_t SrcLocStrSize;
4333   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4334   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4335   Value *ThreadId = getOrCreateThreadID(Ident);
4336   Value *Args[] = {ThreadId, Addr, Allocator};
4337   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free);
4338   return Builder.CreateCall(Fn, Args, Name);
4339 }
4340 
4341 CallInst *OpenMPIRBuilder::createOMPInteropInit(
4342     const LocationDescription &Loc, Value *InteropVar,
4343     omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
4344     Value *DependenceAddress, bool HaveNowaitClause) {
4345   IRBuilder<>::InsertPointGuard IPG(Builder);
4346   Builder.restoreIP(Loc.IP);
4347 
4348   uint32_t SrcLocStrSize;
4349   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4350   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4351   Value *ThreadId = getOrCreateThreadID(Ident);
4352   if (Device == nullptr)
4353     Device = ConstantInt::get(Int32, -1);
4354   Constant *InteropTypeVal = ConstantInt::get(Int32, (int)InteropType);
4355   if (NumDependences == nullptr) {
4356     NumDependences = ConstantInt::get(Int32, 0);
4357     PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
4358     DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
4359   }
4360   Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
4361   Value *Args[] = {
4362       Ident,  ThreadId,       InteropVar,        InteropTypeVal,
4363       Device, NumDependences, DependenceAddress, HaveNowaitClauseVal};
4364 
4365   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_init);
4366 
4367   return Builder.CreateCall(Fn, Args);
4368 }
4369 
4370 CallInst *OpenMPIRBuilder::createOMPInteropDestroy(
4371     const LocationDescription &Loc, Value *InteropVar, Value *Device,
4372     Value *NumDependences, Value *DependenceAddress, bool HaveNowaitClause) {
4373   IRBuilder<>::InsertPointGuard IPG(Builder);
4374   Builder.restoreIP(Loc.IP);
4375 
4376   uint32_t SrcLocStrSize;
4377   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4378   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4379   Value *ThreadId = getOrCreateThreadID(Ident);
4380   if (Device == nullptr)
4381     Device = ConstantInt::get(Int32, -1);
4382   if (NumDependences == nullptr) {
4383     NumDependences = ConstantInt::get(Int32, 0);
4384     PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
4385     DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
4386   }
4387   Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
4388   Value *Args[] = {
4389       Ident,          ThreadId,          InteropVar,         Device,
4390       NumDependences, DependenceAddress, HaveNowaitClauseVal};
4391 
4392   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_destroy);
4393 
4394   return Builder.CreateCall(Fn, Args);
4395 }
4396 
4397 CallInst *OpenMPIRBuilder::createOMPInteropUse(const LocationDescription &Loc,
4398                                                Value *InteropVar, Value *Device,
4399                                                Value *NumDependences,
4400                                                Value *DependenceAddress,
4401                                                bool HaveNowaitClause) {
4402   IRBuilder<>::InsertPointGuard IPG(Builder);
4403   Builder.restoreIP(Loc.IP);
4404   uint32_t SrcLocStrSize;
4405   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4406   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4407   Value *ThreadId = getOrCreateThreadID(Ident);
4408   if (Device == nullptr)
4409     Device = ConstantInt::get(Int32, -1);
4410   if (NumDependences == nullptr) {
4411     NumDependences = ConstantInt::get(Int32, 0);
4412     PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
4413     DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
4414   }
4415   Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
4416   Value *Args[] = {
4417       Ident,          ThreadId,          InteropVar,         Device,
4418       NumDependences, DependenceAddress, HaveNowaitClauseVal};
4419 
4420   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_use);
4421 
4422   return Builder.CreateCall(Fn, Args);
4423 }
4424 
4425 CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
4426     const LocationDescription &Loc, llvm::Value *Pointer,
4427     llvm::ConstantInt *Size, const llvm::Twine &Name) {
4428   IRBuilder<>::InsertPointGuard IPG(Builder);
4429   Builder.restoreIP(Loc.IP);
4430 
4431   uint32_t SrcLocStrSize;
4432   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4433   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4434   Value *ThreadId = getOrCreateThreadID(Ident);
4435   Constant *ThreadPrivateCache =
4436       getOrCreateInternalVariable(Int8PtrPtr, Name.str());
4437   llvm::Value *Args[] = {Ident, ThreadId, Pointer, Size, ThreadPrivateCache};
4438 
4439   Function *Fn =
4440       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_threadprivate_cached);
4441 
4442   return Builder.CreateCall(Fn, Args);
4443 }
4444 
4445 OpenMPIRBuilder::InsertPointTy
4446 OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
4447                                   int32_t MinThreadsVal, int32_t MaxThreadsVal,
4448                                   int32_t MinTeamsVal, int32_t MaxTeamsVal) {
4449   if (!updateToLocation(Loc))
4450     return Loc.IP;
4451 
4452   uint32_t SrcLocStrSize;
4453   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4454   Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4455   Constant *IsSPMDVal = ConstantInt::getSigned(
4456       Int8, IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
4457   Constant *UseGenericStateMachineVal = ConstantInt::getSigned(Int8, !IsSPMD);
4458   Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true);
4459   Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0);
4460 
4461   Function *Kernel = Builder.GetInsertBlock()->getParent();
4462 
4463   // Manifest the launch configuration in the metadata matching the kernel
4464   // environment.
4465   if (MinTeamsVal > 1 || MaxTeamsVal > 0)
4466     writeTeamsForKernel(T, *Kernel, MinTeamsVal, MaxTeamsVal);
4467 
4468   // For max values, < 0 means unset, == 0 means set but unknown.
4469   if (MaxThreadsVal < 0)
4470     MaxThreadsVal = std::max(
4471         int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), MinThreadsVal);
4472 
4473   if (MaxThreadsVal > 0)
4474     writeThreadBoundsForKernel(T, *Kernel, MinThreadsVal, MaxThreadsVal);
4475 
4476   Constant *MinThreads = ConstantInt::getSigned(Int32, MinThreadsVal);
4477   Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
4478   Constant *MinTeams = ConstantInt::getSigned(Int32, MinTeamsVal);
4479   Constant *MaxTeams = ConstantInt::getSigned(Int32, MaxTeamsVal);
4480   Constant *ReductionDataSize = ConstantInt::getSigned(Int32, 0);
4481   Constant *ReductionBufferLength = ConstantInt::getSigned(Int32, 0);
4482 
4483   // We need to strip the debug prefix to get the correct kernel name.
4484   StringRef KernelName = Kernel->getName();
4485   const std::string DebugPrefix = "_debug__";
4486   if (KernelName.ends_with(DebugPrefix))
4487     KernelName = KernelName.drop_back(DebugPrefix.length());
4488 
4489   Function *Fn = getOrCreateRuntimeFunctionPtr(
4490       omp::RuntimeFunction::OMPRTL___kmpc_target_init);
4491   const DataLayout &DL = Fn->getParent()->getDataLayout();
4492 
4493   Twine DynamicEnvironmentName = KernelName + "_dynamic_environment";
4494   Constant *DynamicEnvironmentInitializer =
4495       ConstantStruct::get(DynamicEnvironment, {DebugIndentionLevelVal});
4496   GlobalVariable *DynamicEnvironmentGV = new GlobalVariable(
4497       M, DynamicEnvironment, /*IsConstant=*/false, GlobalValue::WeakODRLinkage,
4498       DynamicEnvironmentInitializer, DynamicEnvironmentName,
4499       /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
4500       DL.getDefaultGlobalsAddressSpace());
4501   DynamicEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);
4502 
4503   Constant *DynamicEnvironment =
4504       DynamicEnvironmentGV->getType() == DynamicEnvironmentPtr
4505           ? DynamicEnvironmentGV
4506           : ConstantExpr::getAddrSpaceCast(DynamicEnvironmentGV,
4507                                            DynamicEnvironmentPtr);
4508 
4509   Constant *ConfigurationEnvironmentInitializer = ConstantStruct::get(
4510       ConfigurationEnvironment, {
4511                                     UseGenericStateMachineVal,
4512                                     MayUseNestedParallelismVal,
4513                                     IsSPMDVal,
4514                                     MinThreads,
4515                                     MaxThreads,
4516                                     MinTeams,
4517                                     MaxTeams,
4518                                     ReductionDataSize,
4519                                     ReductionBufferLength,
4520                                 });
4521   Constant *KernelEnvironmentInitializer = ConstantStruct::get(
4522       KernelEnvironment, {
4523                              ConfigurationEnvironmentInitializer,
4524                              Ident,
4525                              DynamicEnvironment,
4526                          });
4527   Twine KernelEnvironmentName = KernelName + "_kernel_environment";
4528   GlobalVariable *KernelEnvironmentGV = new GlobalVariable(
4529       M, KernelEnvironment, /*IsConstant=*/true, GlobalValue::WeakODRLinkage,
4530       KernelEnvironmentInitializer, KernelEnvironmentName,
4531       /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
4532       DL.getDefaultGlobalsAddressSpace());
4533   KernelEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);
4534 
4535   Constant *KernelEnvironment =
4536       KernelEnvironmentGV->getType() == KernelEnvironmentPtr
4537           ? KernelEnvironmentGV
4538           : ConstantExpr::getAddrSpaceCast(KernelEnvironmentGV,
4539                                            KernelEnvironmentPtr);
4540   Value *KernelLaunchEnvironment = Kernel->getArg(0);
4541   CallInst *ThreadKind =
4542       Builder.CreateCall(Fn, {KernelEnvironment, KernelLaunchEnvironment});
4543 
4544   Value *ExecUserCode = Builder.CreateICmpEQ(
4545       ThreadKind, ConstantInt::get(ThreadKind->getType(), -1),
4546       "exec_user_code");
4547 
4548   // ThreadKind = __kmpc_target_init(...)
4549   // if (ThreadKind == -1)
4550   //   user_code
4551   // else
4552   //   return;
4553 
4554   auto *UI = Builder.CreateUnreachable();
4555   BasicBlock *CheckBB = UI->getParent();
4556   BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(UI, "user_code.entry");
4557 
4558   BasicBlock *WorkerExitBB = BasicBlock::Create(
4559       CheckBB->getContext(), "worker.exit", CheckBB->getParent());
4560   Builder.SetInsertPoint(WorkerExitBB);
4561   Builder.CreateRetVoid();
4562 
4563   auto *CheckBBTI = CheckBB->getTerminator();
4564   Builder.SetInsertPoint(CheckBBTI);
4565   Builder.CreateCondBr(ExecUserCode, UI->getParent(), WorkerExitBB);
4566 
4567   CheckBBTI->eraseFromParent();
4568   UI->eraseFromParent();
4569 
4570   // Continue in the "user_code" block, see diagram above and in
4571   // openmp/libomptarget/deviceRTLs/common/include/target.h .
4572   return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
4573 }
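
// For a kernel "foo" this materializes, roughly:
//
//   @foo_dynamic_environment = weak_odr protected global ... zeroinitializer
//   @foo_kernel_environment = weak_odr protected constant
//       { <configuration environment>, ptr @ident, ptr @foo_dynamic_environment }
//
// followed by the __kmpc_target_init call and the ThreadKind == -1 check
// shown in the comment above.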
4574 
4575 void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
4576                                          int32_t TeamsReductionDataSize,
4577                                          int32_t TeamsReductionBufferLength) {
4578   if (!updateToLocation(Loc))
4579     return;
4580 
4581   Function *Fn = getOrCreateRuntimeFunctionPtr(
4582       omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);
4583 
4584   Builder.CreateCall(Fn, {});
4585 
4586   if (!TeamsReductionBufferLength || !TeamsReductionDataSize)
4587     return;
4588 
4589   Function *Kernel = Builder.GetInsertBlock()->getParent();
4590   // We need to strip the debug prefix to get the correct kernel name.
4591   StringRef KernelName = Kernel->getName();
4592   const std::string DebugPrefix = "_debug__";
4593   if (KernelName.ends_with(DebugPrefix))
4594     KernelName = KernelName.drop_back(DebugPrefix.length());
4595   auto *KernelEnvironmentGV =
4596       M.getNamedGlobal((KernelName + "_kernel_environment").str());
4597   assert(KernelEnvironmentGV && "Expected kernel environment global\n");
4598   auto *KernelEnvironmentInitializer = KernelEnvironmentGV->getInitializer();
4599   auto *NewInitializer = ConstantFoldInsertValueInstruction(
4600       KernelEnvironmentInitializer,
4601       ConstantInt::get(Int32, TeamsReductionDataSize), {0, 7});
4602   NewInitializer = ConstantFoldInsertValueInstruction(
4603       NewInitializer, ConstantInt::get(Int32, TeamsReductionBufferLength),
4604       {0, 8});
4605   KernelEnvironmentGV->setInitializer(NewInitializer);
4606 }
4607 
4608 static MDNode *getNVPTXMDNode(Function &Kernel, StringRef Name) {
4609   Module &M = *Kernel.getParent();
4610   NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
4611   for (auto *Op : MD->operands()) {
4612     if (Op->getNumOperands() != 3)
4613       continue;
4614     auto *KernelOp = dyn_cast<ConstantAsMetadata>(Op->getOperand(0));
4615     if (!KernelOp || KernelOp->getValue() != &Kernel)
4616       continue;
4617     auto *Prop = dyn_cast<MDString>(Op->getOperand(1));
4618     if (!Prop || Prop->getString() != Name)
4619       continue;
4620     return Op;
4621   }
4622   return nullptr;
4623 }
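
// The nvvm.annotations operands scanned here have the form, e.g.:
//
//   !{ptr @kernel, !"maxntidx", i32 128}
//
// i.e. {function, property name, value}; the node is returned only if both
// the function and the property name match.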
4624 
4625 static void updateNVPTXMetadata(Function &Kernel, StringRef Name, int32_t Value,
4626                                 bool Min) {
4627   // Update the "maxntidx" metadata for NVIDIA, or add it.
4628   MDNode *ExistingOp = getNVPTXMDNode(Kernel, Name);
4629   if (ExistingOp) {
4630     auto *OldVal = cast<ConstantAsMetadata>(ExistingOp->getOperand(2));
4631     int32_t OldLimit = cast<ConstantInt>(OldVal->getValue())->getZExtValue();
4632     ExistingOp->replaceOperandWith(
4633         2, ConstantAsMetadata::get(ConstantInt::get(
4634                OldVal->getValue()->getType(),
4635                Min ? std::min(OldLimit, Value) : std::max(OldLimit, Value))));
4636   } else {
4637     LLVMContext &Ctx = Kernel.getContext();
4638     Metadata *MDVals[] = {ConstantAsMetadata::get(&Kernel),
4639                           MDString::get(Ctx, Name),
4640                           ConstantAsMetadata::get(
4641                               ConstantInt::get(Type::getInt32Ty(Ctx), Value))};
4642     // Append metadata to nvvm.annotations
4643     Module &M = *Kernel.getParent();
4644     NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
4645     MD->addOperand(MDNode::get(Ctx, MDVals));
4646   }
4647 }
4648 
4649 std::pair<int32_t, int32_t>
4650 OpenMPIRBuilder::readThreadBoundsForKernel(const Triple &T, Function &Kernel) {
4651   int32_t ThreadLimit =
4652       Kernel.getFnAttributeAsParsedInteger("omp_target_thread_limit");
4653 
4654   if (T.isAMDGPU()) {
4655     const auto &Attr = Kernel.getFnAttribute("amdgpu-flat-work-group-size");
4656     if (!Attr.isValid() || !Attr.isStringAttribute())
4657       return {0, ThreadLimit};
4658     auto [LBStr, UBStr] = Attr.getValueAsString().split(',');
4659     int32_t LB, UB;
4660     if (!llvm::to_integer(UBStr, UB, 10))
4661       return {0, ThreadLimit};
4662     UB = ThreadLimit ? std::min(ThreadLimit, UB) : UB;
4663     if (!llvm::to_integer(LBStr, LB, 10))
4664       return {0, UB};
4665     return {LB, UB};
4666   }
4667 
4668   if (MDNode *ExistingOp = getNVPTXMDNode(Kernel, "maxntidx")) {
4669     auto *OldVal = cast<ConstantAsMetadata>(ExistingOp->getOperand(2));
4670     int32_t UB = cast<ConstantInt>(OldVal->getValue())->getZExtValue();
4671     return {0, ThreadLimit ? std::min(ThreadLimit, UB) : UB};
4672   }
4673   return {0, ThreadLimit};
4674 }
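
// For example, an AMDGPU kernel carrying
// "amdgpu-flat-work-group-size"="1,256" yields {1, 256}; if
// "omp_target_thread_limit"=128 is also set, the bounds are clamped to
// {1, 128}. On NVPTX the "maxntidx" annotation provides the upper bound.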
4675 
4676 void OpenMPIRBuilder::writeThreadBoundsForKernel(const Triple &T,
4677                                                  Function &Kernel, int32_t LB,
4678                                                  int32_t UB) {
4679   Kernel.addFnAttr("omp_target_thread_limit", std::to_string(UB));
4680 
4681   if (T.isAMDGPU()) {
4682     Kernel.addFnAttr("amdgpu-flat-work-group-size",
4683                      llvm::utostr(LB) + "," + llvm::utostr(UB));
4684     return;
4685   }
4686 
4687   updateNVPTXMetadata(Kernel, "maxntidx", UB, /*Min=*/true);
4688 }
4689 
4690 std::pair<int32_t, int32_t>
4691 OpenMPIRBuilder::readTeamBoundsForKernel(const Triple &, Function &Kernel) {
4692   // TODO: Read from backend annotations if available.
4693   return {0, Kernel.getFnAttributeAsParsedInteger("omp_target_num_teams")};
4694 }
4695 
4696 void OpenMPIRBuilder::writeTeamsForKernel(const Triple &T, Function &Kernel,
4697                                           int32_t LB, int32_t UB) {
4698   if (T.isNVPTX()) {
4699     if (UB > 0)
4700       updateNVPTXMetadata(Kernel, "maxclusterrank", UB, /*Min=*/true);
4701     updateNVPTXMetadata(Kernel, "minctasm", LB, /*Min=*/false);
4702   }
4703   Kernel.addFnAttr("omp_target_num_teams", std::to_string(LB));
4704 }
4705 
4706 void OpenMPIRBuilder::setOutlinedTargetRegionFunctionAttributes(
4707     Function *OutlinedFn) {
4708   if (Config.isTargetDevice()) {
4709     OutlinedFn->setLinkage(GlobalValue::WeakODRLinkage);
4710     // TODO: Determine if DSO local can be set to true.
4711     OutlinedFn->setDSOLocal(false);
4712     OutlinedFn->setVisibility(GlobalValue::ProtectedVisibility);
4713     if (T.isAMDGCN())
4714       OutlinedFn->setCallingConv(CallingConv::AMDGPU_KERNEL);
4715   }
4716 }
4717 
4718 Constant *OpenMPIRBuilder::createOutlinedFunctionID(Function *OutlinedFn,
4719                                                     StringRef EntryFnIDName) {
4720   if (Config.isTargetDevice()) {
4721     assert(OutlinedFn && "The outlined function must exist if embedded");
4722     return OutlinedFn;
4723   }
4724 
4725   return new GlobalVariable(
4726       M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
4727       Constant::getNullValue(Builder.getInt8Ty()), EntryFnIDName);
4728 }
4729 
4730 Constant *OpenMPIRBuilder::createTargetRegionEntryAddr(Function *OutlinedFn,
4731                                                        StringRef EntryFnName) {
4732   if (OutlinedFn)
4733     return OutlinedFn;
4734 
4735   assert(!M.getGlobalVariable(EntryFnName, true) &&
4736          "Named kernel already exists?");
4737   return new GlobalVariable(
4738       M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::InternalLinkage,
4739       Constant::getNullValue(Builder.getInt8Ty()), EntryFnName);
4740 }
4741 
4742 void OpenMPIRBuilder::emitTargetRegionFunction(
4743     TargetRegionEntryInfo &EntryInfo,
4744     FunctionGenCallback &GenerateFunctionCallback, bool IsOffloadEntry,
4745     Function *&OutlinedFn, Constant *&OutlinedFnID) {
4746 
4747   SmallString<64> EntryFnName;
4748   OffloadInfoManager.getTargetRegionEntryFnName(EntryFnName, EntryInfo);
4749 
4750   OutlinedFn = Config.isTargetDevice() || !Config.openMPOffloadMandatory()
4751                    ? GenerateFunctionCallback(EntryFnName)
4752                    : nullptr;
4753 
4754   // If this target outline function is not an offload entry, we don't need to
4755   // register it. This may be the case for a constant-false if clause, or when
4756   // there are no OpenMP targets.
4757   if (!IsOffloadEntry)
4758     return;
4759 
4760   std::string EntryFnIDName =
4761       Config.isTargetDevice()
4762           ? std::string(EntryFnName)
4763           : createPlatformSpecificName({EntryFnName, "region_id"});
4764 
4765   OutlinedFnID = registerTargetRegionFunction(EntryInfo, OutlinedFn,
4766                                               EntryFnName, EntryFnIDName);
4767 }
4768 
registerTargetRegionFunction(TargetRegionEntryInfo & EntryInfo,Function * OutlinedFn,StringRef EntryFnName,StringRef EntryFnIDName)4769 Constant *OpenMPIRBuilder::registerTargetRegionFunction(
4770     TargetRegionEntryInfo &EntryInfo, Function *OutlinedFn,
4771     StringRef EntryFnName, StringRef EntryFnIDName) {
4772   if (OutlinedFn)
4773     setOutlinedTargetRegionFunctionAttributes(OutlinedFn);
4774   auto OutlinedFnID = createOutlinedFunctionID(OutlinedFn, EntryFnIDName);
4775   auto EntryAddr = createTargetRegionEntryAddr(OutlinedFn, EntryFnName);
4776   OffloadInfoManager.registerTargetRegionEntryInfo(
4777       EntryInfo, EntryAddr, OutlinedFnID,
4778       OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion);
4779   return OutlinedFnID;
4780 }
4781 
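// For a 'target data'-style region this brackets the body with paired mapper
// runtime calls; a rough sketch of the host IR produced (names illustrative,
// argument order as in OffloadingArgs below):
//
//   call void @__tgt_target_data_begin_mapper(ptr @ident, i64 %device_id,
//       i32 %num_ptrs, ptr %baseptrs, ptr %ptrs, ptr %sizes, ptr %maptypes,
//       ptr %mapnames, ptr %mappers)
//   ; ... region body ...
//   call void @__tgt_target_data_end_mapper(...same argument layout...)
//
// Standalone constructs instead emit a single call to *MapperFunc.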
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
    TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
    omp::RuntimeFunction *MapperFunc,
    function_ref<InsertPointTy(InsertPointTy CodeGenIP, BodyGenTy BodyGenType)>
        BodyGenCB,
    function_ref<void(unsigned int, Value *)> DeviceAddrCB,
    function_ref<Value *(unsigned int)> CustomMapperCB, Value *SrcLocInfo) {
  if (!updateToLocation(Loc))
    return InsertPointTy();

  // Disable TargetData CodeGen on Device pass.
  if (Config.IsTargetDevice.value_or(false))
    return Builder.saveIP();

  Builder.restoreIP(CodeGenIP);
  bool IsStandAlone = !BodyGenCB;
  MapInfosTy *MapInfo;
  // Generate the code for the opening of the data environment. Capture all the
  // arguments of the runtime call by reference because they are used in the
  // closing of the region.
  auto BeginThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
    MapInfo = &GenMapInfoCB(Builder.saveIP());
    emitOffloadingArrays(AllocaIP, Builder.saveIP(), *MapInfo, Info,
                         /*IsNonContiguous=*/true, DeviceAddrCB,
                         CustomMapperCB);

    TargetDataRTArgs RTArgs;
    emitOffloadingArraysArgument(Builder, RTArgs, Info,
                                 !MapInfo->Names.empty());

    // Emit the number of elements in the offloading arrays.
    Value *PointerNum = Builder.getInt32(Info.NumberOfPtrs);

    // Source location for the ident struct.
    if (!SrcLocInfo) {
      uint32_t SrcLocStrSize;
      Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
      SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
    }

    Value *OffloadingArgs[] = {SrcLocInfo,           DeviceID,
                               PointerNum,           RTArgs.BasePointersArray,
                               RTArgs.PointersArray, RTArgs.SizesArray,
                               RTArgs.MapTypesArray, RTArgs.MapNamesArray,
                               RTArgs.MappersArray};

    if (IsStandAlone) {
      assert(MapperFunc && "MapperFunc missing for standalone target data");
      Builder.CreateCall(getOrCreateRuntimeFunctionPtr(*MapperFunc),
                         OffloadingArgs);
    } else {
      Function *BeginMapperFunc = getOrCreateRuntimeFunctionPtr(
          omp::OMPRTL___tgt_target_data_begin_mapper);

      Builder.CreateCall(BeginMapperFunc, OffloadingArgs);

      for (auto DeviceMap : Info.DevicePtrInfoMap) {
        if (isa<AllocaInst>(DeviceMap.second.second)) {
          auto *LI =
              Builder.CreateLoad(Builder.getPtrTy(), DeviceMap.second.first);
          Builder.CreateStore(LI, DeviceMap.second.second);
        }
      }

      // If device pointer privatization is required, emit the body of the
      // region here. It will have to be duplicated: with and without
      // privatization.
      Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::Priv));
    }
  };

  // If we need device pointer privatization, we need to emit the body of the
  // region with no privatization in the 'else' branch of the conditional.
  // Otherwise, we don't have to do anything.
  auto BeginElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
    Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::DupNoPriv));
  };

  // Generate code for the closing of the data region.
  auto EndThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
    TargetDataRTArgs RTArgs;
    emitOffloadingArraysArgument(Builder, RTArgs, Info, !MapInfo->Names.empty(),
                                 /*ForEndCall=*/true);

    // Emit the number of elements in the offloading arrays.
    Value *PointerNum = Builder.getInt32(Info.NumberOfPtrs);

    // Source location for the ident struct.
    if (!SrcLocInfo) {
      uint32_t SrcLocStrSize;
      Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
      SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
    }

    Value *OffloadingArgs[] = {SrcLocInfo,           DeviceID,
                               PointerNum,           RTArgs.BasePointersArray,
                               RTArgs.PointersArray, RTArgs.SizesArray,
                               RTArgs.MapTypesArray, RTArgs.MapNamesArray,
                               RTArgs.MappersArray};
    Function *EndMapperFunc =
        getOrCreateRuntimeFunctionPtr(omp::OMPRTL___tgt_target_data_end_mapper);

    Builder.CreateCall(EndMapperFunc, OffloadingArgs);
  };

  // We don't have to do anything to close the region if the if clause
  // evaluates to false.
  auto EndElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};

  if (BodyGenCB) {
    if (IfCond) {
      emitIfClause(IfCond, BeginThenGen, BeginElseGen, AllocaIP);
    } else {
      BeginThenGen(AllocaIP, Builder.saveIP());
    }

    // If we don't require privatization of device pointers, we emit the body
    // in between the runtime calls. This avoids duplicating the body code.
    Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv));

    if (IfCond) {
      emitIfClause(IfCond, EndThenGen, EndElseGen, AllocaIP);
    } else {
      EndThenGen(AllocaIP, Builder.saveIP());
    }
  } else {
    if (IfCond) {
      emitIfClause(IfCond, BeginThenGen, EndElseGen, AllocaIP);
    } else {
      BeginThenGen(AllocaIP, Builder.saveIP());
    }
  }

  return Builder.saveIP();
}

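// The helpers below only select a runtime entry point from the induction
// variable's width and signedness, e.g. IVSize == 32 with IVSigned == true
// and no GPU distribute yields __kmpc_for_static_init_4, while an unsigned
// 64-bit IV yields __kmpc_for_static_init_8u.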
FunctionCallee
OpenMPIRBuilder::createForStaticInitFunction(unsigned IVSize, bool IVSigned,
                                             bool IsGPUDistribute) {
  assert((IVSize == 32 || IVSize == 64) &&
         "IV size is not compatible with the omp runtime");
  RuntimeFunction Name;
  if (IsGPUDistribute)
    Name = IVSize == 32
               ? (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_4
                           : omp::OMPRTL___kmpc_distribute_static_init_4u)
               : (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_8
                           : omp::OMPRTL___kmpc_distribute_static_init_8u);
  else
    Name = IVSize == 32 ? (IVSigned ? omp::OMPRTL___kmpc_for_static_init_4
                                    : omp::OMPRTL___kmpc_for_static_init_4u)
                        : (IVSigned ? omp::OMPRTL___kmpc_for_static_init_8
                                    : omp::OMPRTL___kmpc_for_static_init_8u);

  return getOrCreateRuntimeFunction(M, Name);
}

FunctionCallee OpenMPIRBuilder::createDispatchInitFunction(unsigned IVSize,
                                                           bool IVSigned) {
  assert((IVSize == 32 || IVSize == 64) &&
         "IV size is not compatible with the omp runtime");
  RuntimeFunction Name = IVSize == 32
                             ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_4
                                         : omp::OMPRTL___kmpc_dispatch_init_4u)
                             : (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_8
                                         : omp::OMPRTL___kmpc_dispatch_init_8u);

  return getOrCreateRuntimeFunction(M, Name);
}

FunctionCallee OpenMPIRBuilder::createDispatchNextFunction(unsigned IVSize,
                                                           bool IVSigned) {
  assert((IVSize == 32 || IVSize == 64) &&
         "IV size is not compatible with the omp runtime");
  RuntimeFunction Name = IVSize == 32
                             ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_4
                                         : omp::OMPRTL___kmpc_dispatch_next_4u)
                             : (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_8
                                         : omp::OMPRTL___kmpc_dispatch_next_8u);

  return getOrCreateRuntimeFunction(M, Name);
}

FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize,
                                                           bool IVSigned) {
  assert((IVSize == 32 || IVSize == 64) &&
         "IV size is not compatible with the omp runtime");
  RuntimeFunction Name = IVSize == 32
                             ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_4
                                         : omp::OMPRTL___kmpc_dispatch_fini_4u)
                             : (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_8
                                         : omp::OMPRTL___kmpc_dispatch_fini_8u);

  return getOrCreateRuntimeFunction(M, Name);
}

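// Example of what the next two helpers handle: if an instruction in Func uses
// the ConstantExpr 'getelementptr inbounds (i8, ptr @G, i64 4)', that use is
// rewritten to go through an equivalent GEP instruction materialized next to
// the user, which can then be updated with replaceUsesOfWith.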
static void replaceConstantExprUsesInFuncWithInstr(ConstantExpr *ConstExpr,
                                                   Function *Func) {
  for (User *User : make_early_inc_range(ConstExpr->users()))
    if (auto *Instr = dyn_cast<Instruction>(User))
      if (Instr->getFunction() == Func)
        Instr->replaceUsesOfWith(ConstExpr, ConstExpr->getAsInstruction(Instr));
}

static void replaceConstantValueUsesInFuncWithInstr(llvm::Value *Input,
                                                    Function *Func) {
  for (User *User : make_early_inc_range(Input->users()))
    if (auto *Const = dyn_cast<Constant>(User))
      if (auto *ConstExpr = dyn_cast<ConstantExpr>(Const))
        replaceConstantExprUsesInFuncWithInstr(ConstExpr, Func);
}

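// On the device the generated kernel takes the implicit %dyn_ptr first and
// every remaining input either as a pointer or widened to i64, so for inputs
// (ptr %p, i32 %n) the signature would look roughly like (name illustrative):
//   define internal void @__omp_offloading_xy_foo(ptr %dyn_ptr, ptr %p, i64 %n)
// On the host the inputs keep their original types.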
static Function *createOutlinedFunction(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
    SmallVectorImpl<Value *> &Inputs,
    OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
    OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
  SmallVector<Type *> ParameterTypes;
  if (OMPBuilder.Config.isTargetDevice()) {
    // Add the "implicit" runtime argument we use to provide launch specific
    // information for target devices.
    auto *Int8PtrTy = PointerType::getUnqual(Builder.getContext());
    ParameterTypes.push_back(Int8PtrTy);

    // All parameters to target devices are passed as pointers
    // or i64. This assumes 64-bit address spaces/pointers.
    for (auto &Arg : Inputs)
      ParameterTypes.push_back(Arg->getType()->isPointerTy()
                                   ? Arg->getType()
                                   : Type::getInt64Ty(Builder.getContext()));
  } else {
    for (auto &Arg : Inputs)
      ParameterTypes.push_back(Arg->getType());
  }

  auto FuncType = FunctionType::get(Builder.getVoidTy(), ParameterTypes,
                                    /*isVarArg*/ false);
  auto Func = Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName,
                               Builder.GetInsertBlock()->getModule());

  // Save insert point.
  auto OldInsertPoint = Builder.saveIP();

  // Generate the region into the function.
  BasicBlock *EntryBB = BasicBlock::Create(Builder.getContext(), "entry", Func);
  Builder.SetInsertPoint(EntryBB);

  // Insert target init call in the device compilation pass.
  if (OMPBuilder.Config.isTargetDevice())
    Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false));

  BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();

  // Generate the user code, then insert the target deinit call in the device
  // compilation pass.
  Builder.restoreIP(CBFunc(Builder.saveIP(), Builder.saveIP()));
  if (OMPBuilder.Config.isTargetDevice())
    OMPBuilder.createTargetDeinit(Builder);

  // Insert return instruction.
  Builder.CreateRetVoid();

  // New alloca IP at entry point of the created function.
  Builder.SetInsertPoint(EntryBB->getFirstNonPHI());
  auto AllocaIP = Builder.saveIP();

  Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg());

  // Skip the artificial dyn_ptr on the device.
  const auto &ArgRange =
      OMPBuilder.Config.isTargetDevice()
          ? make_range(Func->arg_begin() + 1, Func->arg_end())
          : Func->args();

  // Rewrite uses of the input values to the outlined function's parameters.
  for (auto InArg : zip(Inputs, ArgRange)) {
    Value *Input = std::get<0>(InArg);
    Argument &Arg = std::get<1>(InArg);
    Value *InputCopy = nullptr;

    Builder.restoreIP(
        ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP()));

    // Things like GEPs can come in the form of Constants. Constants and
    // ConstantExprs do not know which function they are used in, so we must
    // look through their users for instructions inside the function we are
    // outlining. Each such ConstantExpr use is replaced with an equivalent
    // instruction: unlike a Constant, that instruction is owned by the
    // outlined function, so replaceUsesOfWith can be invoked on it in the
    // loop below. Materializing a fresh instruction is also safer in case the
    // old expression is still used outside the function (unlikely by the
    // nature of a Constant, but possible).
    replaceConstantValueUsesInFuncWithInstr(Input, Func);

    // Replace all uses of the input inside the outlined function with the
    // copy produced by the argument accessor.
    for (User *User : make_early_inc_range(Input->users()))
      if (auto *Instr = dyn_cast<Instruction>(User))
        if (Instr->getFunction() == Func)
          Instr->replaceUsesOfWith(Input, InputCopy);
  }

  // Restore insert point.
  Builder.restoreIP(OldInsertPoint);

  return Func;
}

static void emitTargetOutlinedFunction(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
    TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
    Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
    OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
    OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {

  OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
      [&OMPBuilder, &Builder, &Inputs, &CBFunc,
       &ArgAccessorFuncCB](StringRef EntryFnName) {
        return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
                                      CBFunc, ArgAccessorFuncCB);
      };

  OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction, true,
                                      OutlinedFn, OutlinedFnID);
}

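// Host-side launch: pack the offload arrays into TargetKernelArgs and hand
// them to emitKernelLaunch, which emits the target kernel runtime call and,
// if offloading fails at run time, falls back to calling the outlined
// function directly on the host (see EmitTargetCallFallbackCB below).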
static void emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                           OpenMPIRBuilder::InsertPointTy AllocaIP,
                           Function *OutlinedFn, Constant *OutlinedFnID,
                           int32_t NumTeams, int32_t NumThreads,
                           SmallVectorImpl<Value *> &Args,
                           OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB) {

  OpenMPIRBuilder::TargetDataInfo Info(
      /*RequiresDevicePointerInfo=*/false,
      /*SeparateBeginEndCalls=*/true);

  OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
  OMPBuilder.emitOffloadingArrays(AllocaIP, Builder.saveIP(), MapInfo, Info,
                                  /*IsNonContiguous=*/true);

  OpenMPIRBuilder::TargetDataRTArgs RTArgs;
  OMPBuilder.emitOffloadingArraysArgument(Builder, RTArgs, Info,
                                          !MapInfo.Names.empty());

  // Fallback for emitKernelLaunch: if offloading fails, run the outlined
  // function directly on the host.
  auto &&EmitTargetCallFallbackCB =
      [&](OpenMPIRBuilder::InsertPointTy IP) -> OpenMPIRBuilder::InsertPointTy {
    Builder.restoreIP(IP);
    Builder.CreateCall(OutlinedFn, Args);
    return Builder.saveIP();
  };

  unsigned NumTargetItems = MapInfo.BasePointers.size();
  // TODO: Use correct device ID
  Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
  Value *NumTeamsVal = Builder.getInt32(NumTeams);
  Value *NumThreadsVal = Builder.getInt32(NumThreads);
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
  Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
                                             llvm::omp::IdentFlag(0), 0);
  // TODO: Use correct NumIterations
  Value *NumIterations = Builder.getInt64(0);
  // TODO: Use correct DynCGGroupMem
  Value *DynCGGroupMem = Builder.getInt32(0);

  bool HasNoWait = false;

  OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, NumIterations,
                                          NumTeamsVal, NumThreadsVal,
                                          DynCGGroupMem, HasNoWait);

  Builder.restoreIP(OMPBuilder.emitKernelLaunch(
      Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
      DeviceID, RTLoc, AllocaIP));
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, int32_t NumTeams,
    int32_t NumThreads, SmallVectorImpl<Value *> &Args,
    GenMapInfoCallbackTy GenMapInfoCB,
    OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
    OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB) {
  if (!updateToLocation(Loc))
    return InsertPointTy();

  Builder.restoreIP(CodeGenIP);

  Function *OutlinedFn;
  Constant *OutlinedFnID;
  emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn,
                             OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB);
  if (!Config.isTargetDevice())
    emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
                   NumThreads, Args, GenMapInfoCB);

  return Builder.saveIP();
}

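// E.g. getNameWithSeparators({"pfx", "var"}, "$", ".") yields "$pfx.var";
// the first separator is used exactly once, in front of the leading part.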
std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
                                                   StringRef FirstSeparator,
                                                   StringRef Separator) {
  SmallString<128> Buffer;
  llvm::raw_svector_ostream OS(Buffer);
  StringRef Sep = FirstSeparator;
  for (StringRef Part : Parts) {
    OS << Sep << Part;
    Sep = Separator;
  }
  return OS.str().str();
}

std::string
OpenMPIRBuilder::createPlatformSpecificName(ArrayRef<StringRef> Parts) const {
  return OpenMPIRBuilder::getNameWithSeparators(Parts, Config.firstSeparator(),
                                                Config.separator());
}

GlobalVariable *
OpenMPIRBuilder::getOrCreateInternalVariable(Type *Ty, const StringRef &Name,
                                             unsigned AddressSpace) {
  auto &Elem = *InternalVars.try_emplace(Name, nullptr).first;
  if (Elem.second) {
    assert(Elem.second->getValueType() == Ty &&
           "OMP internal variable has different type than requested");
  } else {
    // TODO: investigate the appropriate linkage type used for the global
    // variable for possibly changing that to internal or private, or maybe
    // create different versions of the function for different OMP internal
    // variables.
    auto Linkage = this->M.getTargetTriple().rfind("wasm32") == 0
                       ? GlobalValue::ExternalLinkage
                       : GlobalValue::CommonLinkage;
    auto *GV = new GlobalVariable(M, Ty, /*IsConstant=*/false, Linkage,
                                  Constant::getNullValue(Ty), Elem.first(),
                                  /*InsertBefore=*/nullptr,
                                  GlobalValue::NotThreadLocal, AddressSpace);
    const DataLayout &DL = M.getDataLayout();
    const llvm::Align TypeAlign = DL.getABITypeAlign(Ty);
    const llvm::Align PtrAlign = DL.getPointerABIAlignment(AddressSpace);
    GV->setAlignment(std::max(TypeAlign, PtrAlign));
    Elem.second = GV;
  }

  return Elem.second;
}

Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) {
  std::string Prefix = Twine("gomp_critical_user_", CriticalName).str();
  std::string Name = getNameWithSeparators({Prefix, "var"}, ".", ".");
  return getOrCreateInternalVariable(KmpCriticalNameTy, Name);
}

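// Classic null-GEP sizeof idiom, roughly:
//   %gep  = getelementptr ptr, ptr null, i32 1
//   %size = ptrtoint ptr %gep to i64
// Note the GEP source element type is BasePtr's own (pointer) type, so this
// measures the size of a pointer.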
Value *OpenMPIRBuilder::getSizeInBytes(Value *BasePtr) {
  LLVMContext &Ctx = Builder.getContext();
  Value *Null =
      Constant::getNullValue(PointerType::getUnqual(BasePtr->getContext()));
  Value *SizeGep =
      Builder.CreateGEP(BasePtr->getType(), Null, Builder.getInt32(1));
  Value *SizePtrToInt = Builder.CreatePtrToInt(SizeGep, Type::getInt64Ty(Ctx));
  return SizePtrToInt;
}

GlobalVariable *
OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
                                       std::string VarName) {
  llvm::Constant *MaptypesArrayInit =
      llvm::ConstantDataArray::get(M.getContext(), Mappings);
  auto *MaptypesArrayGlobal = new llvm::GlobalVariable(
      M, MaptypesArrayInit->getType(),
      /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MaptypesArrayInit,
      VarName);
  MaptypesArrayGlobal->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
  return MaptypesArrayGlobal;
}

void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc,
                                          InsertPointTy AllocaIP,
                                          unsigned NumOperands,
                                          struct MapperAllocas &MapperAllocas) {
  if (!updateToLocation(Loc))
    return;

  auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
  auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
  Builder.restoreIP(AllocaIP);
  AllocaInst *ArgsBase = Builder.CreateAlloca(
      ArrI8PtrTy, /* ArraySize = */ nullptr, ".offload_baseptrs");
  AllocaInst *Args = Builder.CreateAlloca(ArrI8PtrTy, /* ArraySize = */ nullptr,
                                          ".offload_ptrs");
  AllocaInst *ArgSizes = Builder.CreateAlloca(
      ArrI64Ty, /* ArraySize = */ nullptr, ".offload_sizes");
  Builder.restoreIP(Loc.IP);
  MapperAllocas.ArgsBase = ArgsBase;
  MapperAllocas.Args = Args;
  MapperAllocas.ArgSizes = ArgSizes;
}

void OpenMPIRBuilder::emitMapperCall(const LocationDescription &Loc,
                                     Function *MapperFunc, Value *SrcLocInfo,
                                     Value *MaptypesArg, Value *MapnamesArg,
                                     struct MapperAllocas &MapperAllocas,
                                     int64_t DeviceID, unsigned NumOperands) {
  if (!updateToLocation(Loc))
    return;

  auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
  auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
  Value *ArgsBaseGEP =
      Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.ArgsBase,
                                {Builder.getInt32(0), Builder.getInt32(0)});
  Value *ArgsGEP =
      Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.Args,
                                {Builder.getInt32(0), Builder.getInt32(0)});
  Value *ArgSizesGEP =
      Builder.CreateInBoundsGEP(ArrI64Ty, MapperAllocas.ArgSizes,
                                {Builder.getInt32(0), Builder.getInt32(0)});
  Value *NullPtr =
      Constant::getNullValue(PointerType::getUnqual(Int8Ptr->getContext()));
  Builder.CreateCall(MapperFunc,
                     {SrcLocInfo, Builder.getInt64(DeviceID),
                      Builder.getInt32(NumOperands), ArgsBaseGEP, ArgsGEP,
                      ArgSizesGEP, MaptypesArg, MapnamesArg, NullPtr});
}

void OpenMPIRBuilder::emitOffloadingArraysArgument(IRBuilderBase &Builder,
                                                   TargetDataRTArgs &RTArgs,
                                                   TargetDataInfo &Info,
                                                   bool EmitDebug,
                                                   bool ForEndCall) {
  assert((!ForEndCall || Info.separateBeginEndCalls()) &&
         "expected region end call to runtime only when end call is separate");
  auto UnqualPtrTy = PointerType::getUnqual(M.getContext());
  auto VoidPtrTy = UnqualPtrTy;
  auto VoidPtrPtrTy = UnqualPtrTy;
  auto Int64Ty = Type::getInt64Ty(M.getContext());
  auto Int64PtrTy = UnqualPtrTy;

  if (!Info.NumberOfPtrs) {
    RTArgs.BasePointersArray = ConstantPointerNull::get(VoidPtrPtrTy);
    RTArgs.PointersArray = ConstantPointerNull::get(VoidPtrPtrTy);
    RTArgs.SizesArray = ConstantPointerNull::get(Int64PtrTy);
    RTArgs.MapTypesArray = ConstantPointerNull::get(Int64PtrTy);
    RTArgs.MapNamesArray = ConstantPointerNull::get(VoidPtrPtrTy);
    RTArgs.MappersArray = ConstantPointerNull::get(VoidPtrPtrTy);
    return;
  }

  RTArgs.BasePointersArray = Builder.CreateConstInBoundsGEP2_32(
      ArrayType::get(VoidPtrTy, Info.NumberOfPtrs),
      Info.RTArgs.BasePointersArray,
      /*Idx0=*/0, /*Idx1=*/0);
  RTArgs.PointersArray = Builder.CreateConstInBoundsGEP2_32(
      ArrayType::get(VoidPtrTy, Info.NumberOfPtrs), Info.RTArgs.PointersArray,
      /*Idx0=*/0,
      /*Idx1=*/0);
  RTArgs.SizesArray = Builder.CreateConstInBoundsGEP2_32(
      ArrayType::get(Int64Ty, Info.NumberOfPtrs), Info.RTArgs.SizesArray,
      /*Idx0=*/0, /*Idx1=*/0);
  RTArgs.MapTypesArray = Builder.CreateConstInBoundsGEP2_32(
      ArrayType::get(Int64Ty, Info.NumberOfPtrs),
      ForEndCall && Info.RTArgs.MapTypesArrayEnd ? Info.RTArgs.MapTypesArrayEnd
                                                 : Info.RTArgs.MapTypesArray,
      /*Idx0=*/0,
      /*Idx1=*/0);

  // Only emit the mapper information arrays if debug information is
  // requested.
  if (!EmitDebug)
    RTArgs.MapNamesArray = ConstantPointerNull::get(VoidPtrPtrTy);
  else
    RTArgs.MapNamesArray = Builder.CreateConstInBoundsGEP2_32(
        ArrayType::get(VoidPtrTy, Info.NumberOfPtrs), Info.RTArgs.MapNamesArray,
        /*Idx0=*/0,
        /*Idx1=*/0);
  // If there is no user-defined mapper, set the mapper array to nullptr to
  // avoid an unnecessary data privatization.
  if (!Info.HasMapper)
    RTArgs.MappersArray = ConstantPointerNull::get(VoidPtrPtrTy);
  else
    RTArgs.MappersArray =
        Builder.CreatePointerCast(Info.RTArgs.MappersArray, VoidPtrPtrTy);
}

void OpenMPIRBuilder::emitNonContiguousDescriptor(InsertPointTy AllocaIP,
                                                  InsertPointTy CodeGenIP,
                                                  MapInfosTy &CombinedInfo,
                                                  TargetDataInfo &Info) {
  MapInfosTy::StructNonContiguousInfo &NonContigInfo =
      CombinedInfo.NonContigInfo;

  // Build an array of struct descriptor_dim and then assign it to
  // offload_args.
  //
  // struct descriptor_dim {
  //  uint64_t offset;
  //  uint64_t count;
  //  uint64_t stride
  // };
  Type *Int64Ty = Builder.getInt64Ty();
  StructType *DimTy = StructType::create(
      M.getContext(), ArrayRef<Type *>({Int64Ty, Int64Ty, Int64Ty}),
      "struct.descriptor_dim");

  enum { OffsetFD = 0, CountFD, StrideFD };
  // We need two index variables here: "Dims" has one entry per component,
  // while the offset, count, and stride lists only have entries for the
  // non-contiguous base declarations.
  for (unsigned I = 0, L = 0, E = NonContigInfo.Dims.size(); I < E; ++I) {
    // Skip emitting IR if the dimension size is 1 since it cannot be
    // non-contiguous.
    if (NonContigInfo.Dims[I] == 1)
      continue;
    Builder.restoreIP(AllocaIP);
    ArrayType *ArrayTy = ArrayType::get(DimTy, NonContigInfo.Dims[I]);
    AllocaInst *DimsAddr =
        Builder.CreateAlloca(ArrayTy, /* ArraySize = */ nullptr, "dims");
    Builder.restoreIP(CodeGenIP);
    for (unsigned II = 0, EE = NonContigInfo.Dims[I]; II < EE; ++II) {
      unsigned RevIdx = EE - II - 1;
      Value *DimsLVal = Builder.CreateInBoundsGEP(
          DimsAddr->getAllocatedType(), DimsAddr,
          {Builder.getInt64(0), Builder.getInt64(II)});
      // Offset
      Value *OffsetLVal = Builder.CreateStructGEP(DimTy, DimsLVal, OffsetFD);
      Builder.CreateAlignedStore(
          NonContigInfo.Offsets[L][RevIdx], OffsetLVal,
          M.getDataLayout().getPrefTypeAlign(OffsetLVal->getType()));
      // Count
      Value *CountLVal = Builder.CreateStructGEP(DimTy, DimsLVal, CountFD);
      Builder.CreateAlignedStore(
          NonContigInfo.Counts[L][RevIdx], CountLVal,
          M.getDataLayout().getPrefTypeAlign(CountLVal->getType()));
      // Stride
      Value *StrideLVal = Builder.CreateStructGEP(DimTy, DimsLVal, StrideFD);
      Builder.CreateAlignedStore(
          NonContigInfo.Strides[L][RevIdx], StrideLVal,
          M.getDataLayout().getPrefTypeAlign(StrideLVal->getType()));
    }
    // args[I] = &dims
    Builder.restoreIP(CodeGenIP);
    Value *DAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        DimsAddr, Builder.getPtrTy());
    Value *P = Builder.CreateConstInBoundsGEP2_32(
        ArrayType::get(Builder.getPtrTy(), Info.NumberOfPtrs),
        Info.RTArgs.PointersArray, 0, I);
    Builder.CreateAlignedStore(
        DAddr, P, M.getDataLayout().getPrefTypeAlign(Builder.getPtrTy()));
    ++L;
  }
}

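// Size handling below mixes constant and runtime entries: compile-time sizes
// go into one private global (with 0 as a placeholder for dynamic slots,
// e.g. [i64 4, i64 0, i64 8] when entry 1 is dynamic); if any slot is dynamic
// the global is memcpy'd into a stack buffer and only those slots are stored
// over individually. If all sizes are dynamic, the global is skipped entirely
// and a plain alloca is filled in the loop further down.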
void OpenMPIRBuilder::emitOffloadingArrays(
    InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
    TargetDataInfo &Info, bool IsNonContiguous,
    function_ref<void(unsigned int, Value *)> DeviceAddrCB,
    function_ref<Value *(unsigned int)> CustomMapperCB) {

  // Reset the array information.
  Info.clearArrayInfo();
  Info.NumberOfPtrs = CombinedInfo.BasePointers.size();

  if (Info.NumberOfPtrs == 0)
    return;

  Builder.restoreIP(AllocaIP);
  // Detect if we have any capture size requiring runtime evaluation of the
  // size so that a constant array could eventually be used.
  ArrayType *PointerArrayType =
      ArrayType::get(Builder.getPtrTy(), Info.NumberOfPtrs);

  Info.RTArgs.BasePointersArray = Builder.CreateAlloca(
      PointerArrayType, /* ArraySize = */ nullptr, ".offload_baseptrs");

  Info.RTArgs.PointersArray = Builder.CreateAlloca(
      PointerArrayType, /* ArraySize = */ nullptr, ".offload_ptrs");
  AllocaInst *MappersArray = Builder.CreateAlloca(
      PointerArrayType, /* ArraySize = */ nullptr, ".offload_mappers");
  Info.RTArgs.MappersArray = MappersArray;

  // If we don't have any VLA types or other types that require runtime
  // evaluation, we can use a constant array for the map sizes, otherwise we
  // need to fill up the arrays as we do for the pointers.
  Type *Int64Ty = Builder.getInt64Ty();
  SmallVector<Constant *> ConstSizes(CombinedInfo.Sizes.size(),
                                     ConstantInt::get(Int64Ty, 0));
  SmallBitVector RuntimeSizes(CombinedInfo.Sizes.size());
  for (unsigned I = 0, E = CombinedInfo.Sizes.size(); I < E; ++I) {
    if (auto *CI = dyn_cast<Constant>(CombinedInfo.Sizes[I])) {
      if (!isa<ConstantExpr>(CI) && !isa<GlobalValue>(CI)) {
        if (IsNonContiguous &&
            static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                CombinedInfo.Types[I] &
                OpenMPOffloadMappingFlags::OMP_MAP_NON_CONTIG))
          ConstSizes[I] =
              ConstantInt::get(Int64Ty, CombinedInfo.NonContigInfo.Dims[I]);
        else
          ConstSizes[I] = CI;
        continue;
      }
    }
    RuntimeSizes.set(I);
  }

  if (RuntimeSizes.all()) {
    ArrayType *SizeArrayType = ArrayType::get(Int64Ty, Info.NumberOfPtrs);
    Info.RTArgs.SizesArray = Builder.CreateAlloca(
        SizeArrayType, /* ArraySize = */ nullptr, ".offload_sizes");
    Builder.restoreIP(CodeGenIP);
  } else {
    auto *SizesArrayInit = ConstantArray::get(
        ArrayType::get(Int64Ty, ConstSizes.size()), ConstSizes);
    std::string Name = createPlatformSpecificName({"offload_sizes"});
    auto *SizesArrayGbl =
        new GlobalVariable(M, SizesArrayInit->getType(), /*isConstant=*/true,
                           GlobalValue::PrivateLinkage, SizesArrayInit, Name);
    SizesArrayGbl->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);

    if (!RuntimeSizes.any()) {
      Info.RTArgs.SizesArray = SizesArrayGbl;
    } else {
      unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0);
      Align OffloadSizeAlign = M.getDataLayout().getABIIntegerTypeAlignment(64);
      ArrayType *SizeArrayType = ArrayType::get(Int64Ty, Info.NumberOfPtrs);
      AllocaInst *Buffer = Builder.CreateAlloca(
          SizeArrayType, /* ArraySize = */ nullptr, ".offload_sizes");
      Buffer->setAlignment(OffloadSizeAlign);
      Builder.restoreIP(CodeGenIP);
      Builder.CreateMemCpy(
          Buffer, M.getDataLayout().getPrefTypeAlign(Buffer->getType()),
          SizesArrayGbl, OffloadSizeAlign,
          Builder.getIntN(
              IndexSize,
              Buffer->getAllocationSize(M.getDataLayout())->getFixedValue()));

      Info.RTArgs.SizesArray = Buffer;
    }
    Builder.restoreIP(CodeGenIP);
  }

  // The map types are always constant so we don't need to generate code to
  // fill arrays. Instead, we create an array constant.
  SmallVector<uint64_t, 4> Mapping;
  for (auto mapFlag : CombinedInfo.Types)
    Mapping.push_back(
        static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
            mapFlag));
  std::string MaptypesName = createPlatformSpecificName({"offload_maptypes"});
  auto *MapTypesArrayGbl = createOffloadMaptypes(Mapping, MaptypesName);
  Info.RTArgs.MapTypesArray = MapTypesArrayGbl;

  // The information types are only built if provided.
  if (!CombinedInfo.Names.empty()) {
    std::string MapnamesName = createPlatformSpecificName({"offload_mapnames"});
    auto *MapNamesArrayGbl =
        createOffloadMapnames(CombinedInfo.Names, MapnamesName);
    Info.RTArgs.MapNamesArray = MapNamesArrayGbl;
  } else {
    Info.RTArgs.MapNamesArray =
        Constant::getNullValue(PointerType::getUnqual(Builder.getContext()));
  }

  // If there's a present map type modifier, it must not be applied to the end
  // of a region, so generate a separate map type array in that case.
  if (Info.separateBeginEndCalls()) {
    bool EndMapTypesDiffer = false;
    for (uint64_t &Type : Mapping) {
      if (Type & static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                     OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) {
        Type &= ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
            OpenMPOffloadMappingFlags::OMP_MAP_PRESENT);
        EndMapTypesDiffer = true;
      }
    }
    if (EndMapTypesDiffer) {
      MapTypesArrayGbl = createOffloadMaptypes(Mapping, MaptypesName);
      Info.RTArgs.MapTypesArrayEnd = MapTypesArrayGbl;
    }
  }

  PointerType *PtrTy = Builder.getPtrTy();
  for (unsigned I = 0; I < Info.NumberOfPtrs; ++I) {
    Value *BPVal = CombinedInfo.BasePointers[I];
    Value *BP = Builder.CreateConstInBoundsGEP2_32(
        ArrayType::get(PtrTy, Info.NumberOfPtrs), Info.RTArgs.BasePointersArray,
        0, I);
    Builder.CreateAlignedStore(BPVal, BP,
                               M.getDataLayout().getPrefTypeAlign(PtrTy));

    if (Info.requiresDevicePointerInfo()) {
      if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) {
        CodeGenIP = Builder.saveIP();
        Builder.restoreIP(AllocaIP);
        Info.DevicePtrInfoMap[BPVal] = {BP, Builder.CreateAlloca(PtrTy)};
        Builder.restoreIP(CodeGenIP);
        if (DeviceAddrCB)
          DeviceAddrCB(I, Info.DevicePtrInfoMap[BPVal].second);
      } else if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Address) {
        Info.DevicePtrInfoMap[BPVal] = {BP, BP};
        if (DeviceAddrCB)
          DeviceAddrCB(I, BP);
      }
    }

    Value *PVal = CombinedInfo.Pointers[I];
    Value *P = Builder.CreateConstInBoundsGEP2_32(
        ArrayType::get(PtrTy, Info.NumberOfPtrs), Info.RTArgs.PointersArray, 0,
        I);
    // TODO: Check that the alignment is correct.
    Builder.CreateAlignedStore(PVal, P,
                               M.getDataLayout().getPrefTypeAlign(PtrTy));

    if (RuntimeSizes.test(I)) {
      Value *S = Builder.CreateConstInBoundsGEP2_32(
          ArrayType::get(Int64Ty, Info.NumberOfPtrs), Info.RTArgs.SizesArray,
          /*Idx0=*/0,
          /*Idx1=*/I);
      Builder.CreateAlignedStore(Builder.CreateIntCast(CombinedInfo.Sizes[I],
                                                       Int64Ty,
                                                       /*isSigned=*/true),
                                 S, M.getDataLayout().getPrefTypeAlign(PtrTy));
    }
    // Fill up the mapper array.
    unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0);
    Value *MFunc = ConstantPointerNull::get(PtrTy);
    if (CustomMapperCB)
      if (Value *CustomMFunc = CustomMapperCB(I))
        MFunc = Builder.CreatePointerCast(CustomMFunc, PtrTy);
    Value *MAddr = Builder.CreateInBoundsGEP(
        MappersArray->getAllocatedType(), MappersArray,
        {Builder.getIntN(IndexSize, 0), Builder.getIntN(IndexSize, I)});
    Builder.CreateAlignedStore(
        MFunc, MAddr, M.getDataLayout().getPrefTypeAlign(MAddr->getType()));
  }

  if (!IsNonContiguous || CombinedInfo.NonContigInfo.Offsets.empty() ||
      Info.NumberOfPtrs == 0)
    return;
  emitNonContiguousDescriptor(AllocaIP, CodeGenIP, CombinedInfo, Info);
}

void OpenMPIRBuilder::emitBranch(BasicBlock *Target) {
  BasicBlock *CurBB = Builder.GetInsertBlock();

  if (!CurBB || CurBB->getTerminator()) {
    // If there is no insert point or the previous block is already
    // terminated, don't touch it.
  } else {
    // Otherwise, create a fall-through branch.
    Builder.CreateBr(Target);
  }

  Builder.ClearInsertionPoint();
}

void OpenMPIRBuilder::emitBlock(BasicBlock *BB, Function *CurFn,
                                bool IsFinished) {
  BasicBlock *CurBB = Builder.GetInsertBlock();

  // Fall out of the current block (if necessary).
  emitBranch(BB);

  if (IsFinished && BB->use_empty()) {
    BB->eraseFromParent();
    return;
  }

  // Place the block after the current block, if possible, or else at
  // the end of the function.
  if (CurBB && CurBB->getParent())
    CurFn->insert(std::next(CurBB->getIterator()), BB);
  else
    CurFn->insert(CurFn->end(), BB);
  Builder.SetInsertPoint(BB);
}

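// emitIfClause produces the usual diamond (block names as created below):
//
//          /-> omp_if.then -\
//   CurBB                    -> omp_if.end
//          \-> omp_if.else -/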
void OpenMPIRBuilder::emitIfClause(Value *Cond, BodyGenCallbackTy ThenGen,
                                   BodyGenCallbackTy ElseGen,
                                   InsertPointTy AllocaIP) {
  // If the condition constant folds and can be elided, try to avoid emitting
  // the condition and the dead arm of the if/else.
  if (auto *CI = dyn_cast<ConstantInt>(Cond)) {
    auto CondConstant = CI->getSExtValue();
    if (CondConstant)
      ThenGen(AllocaIP, Builder.saveIP());
    else
      ElseGen(AllocaIP, Builder.saveIP());
    return;
  }

  Function *CurFn = Builder.GetInsertBlock()->getParent();

  // Otherwise, the condition did not fold, or we couldn't elide it. Just
  // emit the conditional branch.
  BasicBlock *ThenBlock = BasicBlock::Create(M.getContext(), "omp_if.then");
  BasicBlock *ElseBlock = BasicBlock::Create(M.getContext(), "omp_if.else");
  BasicBlock *ContBlock = BasicBlock::Create(M.getContext(), "omp_if.end");
  Builder.CreateCondBr(Cond, ThenBlock, ElseBlock);
  // Emit the 'then' code.
  emitBlock(ThenBlock, CurFn);
  ThenGen(AllocaIP, Builder.saveIP());
  emitBranch(ContBlock);
  // Emit the 'else' code if present.
  // There is no need to emit line number for unconditional branch.
  emitBlock(ElseBlock, CurFn);
  ElseGen(AllocaIP, Builder.saveIP());
  // There is no need to emit line number for unconditional branch.
  emitBranch(ContBlock);
  // Emit the continuation block for code after the if.
  emitBlock(ContBlock, CurFn, /*IsFinished=*/true);
}

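// In short: reads flush with acquire semantics, writes/updates/compares with
// release semantics, and captures with whatever their ordering implies, e.g.
// a seq_cst capture requests an acq_rel flush while a monotonic access never
// flushes at all.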
bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
    const LocationDescription &Loc, llvm::AtomicOrdering AO, AtomicKind AK) {
  assert(!(AO == AtomicOrdering::NotAtomic ||
           AO == llvm::AtomicOrdering::Unordered) &&
         "Unexpected Atomic Ordering.");

  bool Flush = false;
  llvm::AtomicOrdering FlushAO = AtomicOrdering::Monotonic;

  switch (AK) {
  case Read:
    if (AO == AtomicOrdering::Acquire || AO == AtomicOrdering::AcquireRelease ||
        AO == AtomicOrdering::SequentiallyConsistent) {
      FlushAO = AtomicOrdering::Acquire;
      Flush = true;
    }
    break;
  case Write:
  case Compare:
  case Update:
    if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
        AO == AtomicOrdering::SequentiallyConsistent) {
      FlushAO = AtomicOrdering::Release;
      Flush = true;
    }
    break;
  case Capture:
    switch (AO) {
    case AtomicOrdering::Acquire:
      FlushAO = AtomicOrdering::Acquire;
      Flush = true;
      break;
    case AtomicOrdering::Release:
      FlushAO = AtomicOrdering::Release;
      Flush = true;
      break;
    case AtomicOrdering::AcquireRelease:
    case AtomicOrdering::SequentiallyConsistent:
      FlushAO = AtomicOrdering::AcquireRelease;
      Flush = true;
      break;
    default:
      // do nothing - leave silently.
      break;
    }
  }

  if (Flush) {
    // The flush runtime call does not take a memory ordering yet, so we only
    // resolve the ordering the flush would need here and still issue the
    // plain flush call.
    // TODO: pass `FlushAO` once memory ordering support is added.
    (void)FlushAO;
    emitFlush(Loc);
  }

  // For AO == AtomicOrdering::Monotonic and all other combinations, do
  // nothing.
  return Flush;
}

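// For an integer X with acquire ordering this lowers to roughly (illustrative
// value names):
//   %omp.atomic.read = load atomic i32, ptr %x acquire, align 4
//   store i32 %omp.atomic.read, ptr %v
// Floats and pointers are loaded through a same-width integer and cast back.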
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
                                  AtomicOpValue &X, AtomicOpValue &V,
                                  AtomicOrdering AO) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  assert(X.Var->getType()->isPointerTy() &&
         "OMP Atomic expects a pointer to target memory");
  Type *XElemTy = X.ElemTy;
  assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
          XElemTy->isPointerTy()) &&
         "OMP atomic read expected a scalar type");

  Value *XRead = nullptr;

  if (XElemTy->isIntegerTy()) {
    LoadInst *XLD =
        Builder.CreateLoad(XElemTy, X.Var, X.IsVolatile, "omp.atomic.read");
    XLD->setAtomic(AO);
    XRead = cast<Value>(XLD);
  } else {
    // We need to perform the atomic op as an integer.
    IntegerType *IntCastTy =
        IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
    LoadInst *XLoad =
        Builder.CreateLoad(IntCastTy, X.Var, X.IsVolatile, "omp.atomic.load");
    XLoad->setAtomic(AO);
    if (XElemTy->isFloatingPointTy()) {
      XRead = Builder.CreateBitCast(XLoad, XElemTy, "atomic.flt.cast");
    } else {
      XRead = Builder.CreateIntToPtr(XLoad, XElemTy, "atomic.ptr.cast");
    }
  }
  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Read);
  Builder.CreateStore(XRead, V.Var, V.IsVolatile);
  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
                                   AtomicOpValue &X, Value *Expr,
                                   AtomicOrdering AO) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  assert(X.Var->getType()->isPointerTy() &&
         "OMP Atomic expects a pointer to target memory");
  Type *XElemTy = X.ElemTy;
  assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
          XElemTy->isPointerTy()) &&
         "OMP atomic write expected a scalar type");

  if (XElemTy->isIntegerTy()) {
    StoreInst *XSt = Builder.CreateStore(Expr, X.Var, X.IsVolatile);
    XSt->setAtomic(AO);
  } else {
    // We need to bitcast and perform the atomic op as an integer.
    IntegerType *IntCastTy =
        IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
    Value *ExprCast =
        Builder.CreateBitCast(Expr, IntCastTy, "atomic.src.int.cast");
    StoreInst *XSt = Builder.CreateStore(ExprCast, X.Var, X.IsVolatile);
    XSt->setAtomic(AO);
  }

  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Write);
  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicUpdate(
    const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
    Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
    AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr) {
  assert(!isConflictIP(Loc.IP, AllocaIP) && "IPs must not be ambiguous");
  if (!updateToLocation(Loc))
    return Loc.IP;

  LLVM_DEBUG({
    Type *XTy = X.Var->getType();
    assert(XTy->isPointerTy() &&
           "OMP Atomic expects a pointer to target memory");
    Type *XElemTy = X.ElemTy;
    assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
            XElemTy->isPointerTy()) &&
           "OMP atomic update expected a scalar type");
    assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
           (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
           "OpenMP atomic does not support LT or GT operations");
  });

  emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, RMWOp, UpdateOp,
                   X.IsVolatile, IsXBinopExpr);
  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Update);
  return Builder.saveIP();
}

// FIXME: Duplicating AtomicExpand
Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
                                               AtomicRMWInst::BinOp RMWOp) {
  switch (RMWOp) {
  case AtomicRMWInst::Add:
    return Builder.CreateAdd(Src1, Src2);
  case AtomicRMWInst::Sub:
    return Builder.CreateSub(Src1, Src2);
  case AtomicRMWInst::And:
    return Builder.CreateAnd(Src1, Src2);
  case AtomicRMWInst::Nand:
    // NAND is the bitwise complement of AND, i.e. ~(a & b), matching the
    // semantics of 'atomicrmw nand'.
    return Builder.CreateNot(Builder.CreateAnd(Src1, Src2));
  case AtomicRMWInst::Or:
    return Builder.CreateOr(Src1, Src2);
  case AtomicRMWInst::Xor:
    return Builder.CreateXor(Src1, Src2);
  case AtomicRMWInst::Xchg:
  case AtomicRMWInst::FAdd:
  case AtomicRMWInst::FSub:
  case AtomicRMWInst::BAD_BINOP:
  case AtomicRMWInst::Max:
  case AtomicRMWInst::Min:
  case AtomicRMWInst::UMax:
  case AtomicRMWInst::UMin:
  case AtomicRMWInst::FMax:
  case AtomicRMWInst::FMin:
  case AtomicRMWInst::UIncWrap:
  case AtomicRMWInst::UDecWrap:
    llvm_unreachable("Unsupported atomic update operation");
  }
  llvm_unreachable("Unsupported atomic update operation");
}

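// When no single atomicrmw instruction fits the update, emitAtomicUpdate
// lowers it to the usual compare-exchange retry loop sketched in the block
// comment below: load the old value, compute the update into a private
// temporary, then cmpxchg and loop on failure until the exchange succeeds.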
emitAtomicUpdate(InsertPointTy AllocaIP,Value * X,Type * XElemTy,Value * Expr,AtomicOrdering AO,AtomicRMWInst::BinOp RMWOp,AtomicUpdateCallbackTy & UpdateOp,bool VolatileX,bool IsXBinopExpr)5890 std::pair<Value *, Value *> OpenMPIRBuilder::emitAtomicUpdate(
5891     InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
5892     AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
5893     AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr) {
5894   // TODO: handle the case where XElemTy is not byte-sized or not a power of 2
5895   // or a complex datatype.
5896   bool emitRMWOp = false;
5897   switch (RMWOp) {
5898   case AtomicRMWInst::Add:
5899   case AtomicRMWInst::And:
5900   case AtomicRMWInst::Nand:
5901   case AtomicRMWInst::Or:
5902   case AtomicRMWInst::Xor:
5903   case AtomicRMWInst::Xchg:
5904     emitRMWOp = XElemTy;
5905     break;
5906   case AtomicRMWInst::Sub:
5907     emitRMWOp = (IsXBinopExpr && XElemTy);
5908     break;
5909   default:
5910     emitRMWOp = false;
5911   }
5912   emitRMWOp &= XElemTy->isIntegerTy();
5913 
5914   std::pair<Value *, Value *> Res;
5915   if (emitRMWOp) {
5916     Res.first = Builder.CreateAtomicRMW(RMWOp, X, Expr, llvm::MaybeAlign(), AO);
5917     // Only needed for postfix captures, but generated anyway for
5918     // consistency with the else branch; any DCE pass will remove it.
5919     // AtomicRMWInst::Xchg has no corresponding non-atomic instruction.
5920     if (RMWOp == AtomicRMWInst::Xchg)
5921       Res.second = Res.first;
5922     else
5923       Res.second = emitRMWOpAsInstruction(Res.first, Expr, RMWOp);
5924   } else {
5925     IntegerType *IntCastTy =
5926         IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
5927     LoadInst *OldVal =
5928         Builder.CreateLoad(IntCastTy, X, X->getName() + ".atomic.load");
5929     OldVal->setAtomic(AO);
5930     // CurBB
5931     // |     /---\
5932     // ContBB    |
5933     // |     \---/
5934     // ExitBB
5935     BasicBlock *CurBB = Builder.GetInsertBlock();
5936     Instruction *CurBBTI = CurBB->getTerminator();
5937     CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
5938     BasicBlock *ExitBB =
5939         CurBB->splitBasicBlock(CurBBTI, X->getName() + ".atomic.exit");
5940     BasicBlock *ContBB = CurBB->splitBasicBlock(CurBB->getTerminator(),
5941                                                 X->getName() + ".atomic.cont");
5942     ContBB->getTerminator()->eraseFromParent();
5943     Builder.restoreIP(AllocaIP);
5944     AllocaInst *NewAtomicAddr = Builder.CreateAlloca(XElemTy);
5945     NewAtomicAddr->setName(X->getName() + "x.new.val");
5946     Builder.SetInsertPoint(ContBB);
5947     llvm::PHINode *PHI = Builder.CreatePHI(OldVal->getType(), 2);
5948     PHI->addIncoming(OldVal, CurBB);
5949     bool IsIntTy = XElemTy->isIntegerTy();
5950     Value *OldExprVal = PHI;
5951     if (!IsIntTy) {
5952       if (XElemTy->isFloatingPointTy()) {
5953         OldExprVal = Builder.CreateBitCast(PHI, XElemTy,
5954                                            X->getName() + ".atomic.fltCast");
5955       } else {
5956         OldExprVal = Builder.CreateIntToPtr(PHI, XElemTy,
5957                                             X->getName() + ".atomic.ptrCast");
5958       }
5959     }
5960 
5961     Value *Upd = UpdateOp(OldExprVal, Builder);
5962     Builder.CreateStore(Upd, NewAtomicAddr);
5963     LoadInst *DesiredVal = Builder.CreateLoad(IntCastTy, NewAtomicAddr);
5964     AtomicOrdering Failure =
5965         llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
5966     AtomicCmpXchgInst *Result = Builder.CreateAtomicCmpXchg(
5967         X, PHI, DesiredVal, llvm::MaybeAlign(), AO, Failure);
5968     Result->setVolatile(VolatileX);
5969     Value *PreviousVal = Builder.CreateExtractValue(Result, /*Idxs=*/0);
5970     Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
5971     PHI->addIncoming(PreviousVal, Builder.GetInsertBlock());
5972     Builder.CreateCondBr(SuccessFailureVal, ExitBB, ContBB);
5973 
5974     Res.first = OldExprVal;
5975     Res.second = Upd;
5976 
5977     // Set the insertion point in the exit block: drop the unreachable
5978     // placeholder if we created it; else insert before the old terminator.
5979     if (isa<UnreachableInst>(ExitBB->getTerminator())) {
5980       CurBBTI->eraseFromParent();
5981       Builder.SetInsertPoint(ExitBB);
5982     } else {
5983       Builder.SetInsertPoint(ExitBB->getTerminator());
5984     }
5985   }
5986 
5987   return Res;
5988 }
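// For reference, the compare-exchange fallback above yields IR of roughly
// this shape (a sketch, not verbatim output; 'iN' stands for the integer
// type with the bit width of XElemTy, and the ordering is whatever AO is):
//
//     %old  = load atomic iN, ptr %x            ; %x.atomic.load
//     br label %cont
//   cont:                                       ; %x.atomic.cont
//     %phi  = phi iN [ %old, %entry ], [ %prev, %cont ]
//     ; UpdateOp applied to %phi, spilled to NewAtomicAddr, reloaded as %new
//     %pair = cmpxchg ptr %x, iN %phi, iN %new
//     %prev = extractvalue { iN, i1 } %pair, 0
//     %ok   = extractvalue { iN, i1 } %pair, 1
//     br i1 %ok, label %exit, label %cont
//   exit:                                       ; %x.atomic.exit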
5989 
5990 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCapture(
5991     const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
5992     AtomicOpValue &V, Value *Expr, AtomicOrdering AO,
5993     AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp,
5994     bool UpdateExpr, bool IsPostfixUpdate, bool IsXBinopExpr) {
5995   if (!updateToLocation(Loc))
5996     return Loc.IP;
5997 
5998   LLVM_DEBUG({
5999     Type *XTy = X.Var->getType();
6000     assert(XTy->isPointerTy() &&
6001            "OMP Atomic expects a pointer to target memory");
6002     Type *XElemTy = X.ElemTy;
6003     assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
6004             XElemTy->isPointerTy()) &&
6005            "OMP atomic capture expected a scalar type");
6006     assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
6007            "OpenMP atomic does not support LT or GT operations");
6008   });
6009 
6010   // If UpdateExpr is false, i.e. 'x' is updated with some 'expr' that is
6011   // not based on 'x', then 'x' is simply atomically overwritten with 'expr'.
6012   AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? RMWOp : AtomicRMWInst::Xchg);
6013   std::pair<Value *, Value *> Result =
6014       emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, AtomicOp, UpdateOp,
6015                        X.IsVolatile, IsXBinopExpr);
6016 
6017   Value *CapturedVal = (IsPostfixUpdate ? Result.first : Result.second);
6018   Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile);
6019 
6020   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Capture);
6021   return Builder.saveIP();
6022 }
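// Capture semantics of the code above, illustrated: for a postfix form such
// as '{ v = x; x binop= expr; }' the value stored to 'v' is the old value of
// 'x' (Result.first); for a prefix form '{ x binop= expr; v = x; }' it is
// the updated value (Result.second).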
6023 
6024 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
6025     const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
6026     AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
6027     omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
6028     bool IsFailOnly) {
6029 
6030   AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
6031   return createAtomicCompare(Loc, X, V, R, E, D, AO, Op, IsXBinopExpr,
6032                              IsPostfixUpdate, IsFailOnly, Failure);
6033 }
6034 
6035 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
6036     const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
6037     AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
6038     omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
6039     bool IsFailOnly, AtomicOrdering Failure) {
6040 
6041   if (!updateToLocation(Loc))
6042     return Loc.IP;
6043 
6044   assert(X.Var->getType()->isPointerTy() &&
6045          "OMP atomic expects a pointer to target memory");
6046   // compare capture
6047   if (V.Var) {
6048     assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
6049     assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
6050   }
6051 
6052   bool IsInteger = E->getType()->isIntegerTy();
6053 
6054   if (Op == OMPAtomicCompareOp::EQ) {
6055     AtomicCmpXchgInst *Result = nullptr;
6056     if (!IsInteger) {
6057       IntegerType *IntCastTy =
6058           IntegerType::get(M.getContext(), X.ElemTy->getScalarSizeInBits());
6059       Value *EBCast = Builder.CreateBitCast(E, IntCastTy);
6060       Value *DBCast = Builder.CreateBitCast(D, IntCastTy);
6061       Result = Builder.CreateAtomicCmpXchg(X.Var, EBCast, DBCast, MaybeAlign(),
6062                                            AO, Failure);
6063     } else {
6064       Result =
6065           Builder.CreateAtomicCmpXchg(X.Var, E, D, MaybeAlign(), AO, Failure);
6066     }
6067 
6068     if (V.Var) {
6069       Value *OldValue = Builder.CreateExtractValue(Result, /*Idxs=*/0);
6070       if (!IsInteger)
6071         OldValue = Builder.CreateBitCast(OldValue, X.ElemTy);
6072       assert(OldValue->getType() == V.ElemTy &&
6073              "OldValue and V must be of same type");
6074       if (IsPostfixUpdate) {
6075         Builder.CreateStore(OldValue, V.Var, V.IsVolatile);
6076       } else {
6077         Value *SuccessOrFail = Builder.CreateExtractValue(Result, /*Idxs=*/1);
6078         if (IsFailOnly) {
6079           // CurBB----
6080           //   |     |
6081           //   v     |
6082           // ContBB  |
6083           //   |     |
6084           //   v     |
6085           // ExitBB <-
6086           //
6087           // where ContBB only contains the store of old value to 'v'.
6088           BasicBlock *CurBB = Builder.GetInsertBlock();
6089           Instruction *CurBBTI = CurBB->getTerminator();
6090           CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
6091           BasicBlock *ExitBB = CurBB->splitBasicBlock(
6092               CurBBTI, X.Var->getName() + ".atomic.exit");
6093           BasicBlock *ContBB = CurBB->splitBasicBlock(
6094               CurBB->getTerminator(), X.Var->getName() + ".atomic.cont");
6095           ContBB->getTerminator()->eraseFromParent();
6096           CurBB->getTerminator()->eraseFromParent();
6097 
6098           Builder.CreateCondBr(SuccessOrFail, ExitBB, ContBB);
6099 
6100           Builder.SetInsertPoint(ContBB);
6101           Builder.CreateStore(OldValue, V.Var);
6102           Builder.CreateBr(ExitBB);
6103 
6104           // Same placeholder-terminator handling as in emitAtomicUpdate.
6105           if (isa<UnreachableInst>(ExitBB->getTerminator())) {
6106             CurBBTI->eraseFromParent();
6107             Builder.SetInsertPoint(ExitBB);
6108           } else {
6109             Builder.SetInsertPoint(ExitBB->getTerminator());
6110           }
6111         } else {
6112           Value *CapturedValue =
6113               Builder.CreateSelect(SuccessOrFail, E, OldValue);
6114           Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
6115         }
6116       }
6117     }
6118     // The comparison result has to be stored.
6119     if (R.Var) {
6120       assert(R.Var->getType()->isPointerTy() &&
6121              "r.var must be of pointer type");
6122       assert(R.ElemTy->isIntegerTy() && "r must be of integral type");
6123 
6124       Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
6125       Value *ResultCast = R.IsSigned
6126                               ? Builder.CreateSExt(SuccessFailureVal, R.ElemTy)
6127                               : Builder.CreateZExt(SuccessFailureVal, R.ElemTy);
6128       Builder.CreateStore(ResultCast, R.Var, R.IsVolatile);
6129     }
6130   } else {
6131     assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
6132            "Op should be either max or min at this point");
6133     assert(!IsFailOnly && "IsFailOnly is only valid when the comparison is ==");
6134 
6135     // Reverse the ordering operation, as the OpenMP forms differ from the
6136     // LLVM forms. Take max as an example.
6137     // OpenMP form:
6138     //   x = x > expr ? expr : x;
6139     // LLVM form (atomicrmw max):
6140     //   *ptr = *ptr > val ? *ptr : val;
6141     // So the OpenMP form must be rewritten into the matching LLVM form:
6142     //   x = x <= expr ? x : expr;  // i.e. a minimum, hence the reversal
6143     AtomicRMWInst::BinOp NewOp;
6144     if (IsXBinopExpr) {
6145       if (IsInteger) {
6146         if (X.IsSigned)
6147           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
6148                                                 : AtomicRMWInst::Max;
6149         else
6150           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
6151                                                 : AtomicRMWInst::UMax;
6152       } else {
6153         NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMin
6154                                               : AtomicRMWInst::FMax;
6155       }
6156     } else {
6157       if (IsInteger) {
6158         if (X.IsSigned)
6159           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
6160                                                 : AtomicRMWInst::Min;
6161         else
6162           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
6163                                                 : AtomicRMWInst::UMin;
6164       } else {
6165         NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMax
6166                                               : AtomicRMWInst::FMin;
6167       }
6168     }
6169 
6170     AtomicRMWInst *OldValue =
6171         Builder.CreateAtomicRMW(NewOp, X.Var, E, MaybeAlign(), AO);
6172     if (V.Var) {
6173       Value *CapturedValue = nullptr;
6174       if (IsPostfixUpdate) {
6175         CapturedValue = OldValue;
6176       } else {
6177         CmpInst::Predicate Pred;
6178         switch (NewOp) {
6179         case AtomicRMWInst::Max:
6180           Pred = CmpInst::ICMP_SGT;
6181           break;
6182         case AtomicRMWInst::UMax:
6183           Pred = CmpInst::ICMP_UGT;
6184           break;
6185         case AtomicRMWInst::FMax:
6186           Pred = CmpInst::FCMP_OGT;
6187           break;
6188         case AtomicRMWInst::Min:
6189           Pred = CmpInst::ICMP_SLT;
6190           break;
6191         case AtomicRMWInst::UMin:
6192           Pred = CmpInst::ICMP_ULT;
6193           break;
6194         case AtomicRMWInst::FMin:
6195           Pred = CmpInst::FCMP_OLT;
6196           break;
6197         default:
6198           llvm_unreachable("unexpected comparison op");
6199         }
6200         Value *NonAtomicCmp = Builder.CreateCmp(Pred, OldValue, E);
6201         CapturedValue = Builder.CreateSelect(NonAtomicCmp, E, OldValue);
6202       }
6203       Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
6204     }
6205   }
6206 
6207   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Compare);
6208 
6209   return Builder.saveIP();
6210 }
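// Worked example for the min/max path above: given a signed integer 'x' and
// the OpenMP MAX form with IsXBinopExpr set, i.e. 'x = x > e ? e : x', the
// store happens exactly when 'x' is greater than 'e', which makes the whole
// statement a signed minimum, so 'atomicrmw min ptr %x, iN %e' is emitted.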
6211 
6212 OpenMPIRBuilder::InsertPointTy
6213 OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
6214                              BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
6215                              Value *NumTeamsUpper, Value *ThreadLimit,
6216                              Value *IfExpr) {
6217   if (!updateToLocation(Loc))
6218     return InsertPointTy();
6219 
6220   uint32_t SrcLocStrSize;
6221   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6222   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6223   Function *CurrentFunction = Builder.GetInsertBlock()->getParent();
6224 
6225   // The outer allocation basic block is the entry block of the current function.
6226   BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
6227   if (&OuterAllocaBB == Builder.GetInsertBlock()) {
6228     BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.entry");
6229     Builder.SetInsertPoint(BodyBB, BodyBB->begin());
6230   }
6231 
6232   // The current basic block is split into four basic blocks. After outlining,
6233   // they will be mapped as follows:
6234   // ```
6235   // def current_fn() {
6236   //   current_basic_block:
6237   //     br label %teams.exit
6238   //   teams.exit:
6239   //     ; instructions after teams
6240   // }
6241   //
6242   // def outlined_fn() {
6243   //   teams.alloca:
6244   //     br label %teams.body
6245   //   teams.body:
6246   //     ; instructions within teams body
6247   // }
6248   // ```
6249   BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, "teams.exit");
6250   BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.body");
6251   BasicBlock *AllocaBB =
6252       splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
6253 
6254   // Push num_teams
6255   if (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr) {
6256     assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
6257            "if lowerbound is non-null, then upperbound must also be non-null "
6258            "for bounds on num_teams");
6259 
6260     if (NumTeamsUpper == nullptr)
6261       NumTeamsUpper = Builder.getInt32(0);
6262 
6263     if (NumTeamsLower == nullptr)
6264       NumTeamsLower = NumTeamsUpper;
6265 
6266     if (IfExpr) {
6267       assert(IfExpr->getType()->isIntegerTy() &&
6268              "argument to if clause must be an integer value");
6269 
6270       // upper = ifexpr ? upper : 1
6271       if (IfExpr->getType() != Int1)
6272         IfExpr = Builder.CreateICmpNE(IfExpr,
6273                                       ConstantInt::get(IfExpr->getType(), 0));
6274       NumTeamsUpper = Builder.CreateSelect(
6275           IfExpr, NumTeamsUpper, Builder.getInt32(1), "numTeamsUpper");
6276 
6277       // lower = ifexpr ? lower : 1
6278       NumTeamsLower = Builder.CreateSelect(
6279           IfExpr, NumTeamsLower, Builder.getInt32(1), "numTeamsLower");
6280     }
6281 
6282     if (ThreadLimit == nullptr)
6283       ThreadLimit = Builder.getInt32(0);
6284 
6285     Value *ThreadNum = getOrCreateThreadID(Ident);
6286     Builder.CreateCall(
6287         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams_51),
6288         {Ident, ThreadNum, NumTeamsLower, NumTeamsUpper, ThreadLimit});
6289   }
6290   // Generate the body of teams.
6291   InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
6292   InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
6293   BodyGenCB(AllocaIP, CodeGenIP);
6294 
6295   OutlineInfo OI;
6296   OI.EntryBB = AllocaBB;
6297   OI.ExitBB = ExitBB;
6298   OI.OuterAllocaBB = &OuterAllocaBB;
6299 
6300   // Insert fake values for global tid and bound tid.
6301   std::stack<Instruction *> ToBeDeleted;
6302   InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
6303   OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
6304       Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true));
6305   OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
6306       Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true));
6307 
6308   OI.PostOutlineCB = [this, Ident, ToBeDeleted](Function &OutlinedFn) mutable {
6309     // The stale call instruction will be replaced with a new call
6310     // instruction for the runtime call that takes the outlined function.
6311 
6312     assert(OutlinedFn.getNumUses() == 1 &&
6313            "there must be a single user for the outlined function");
6314     CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
6315     ToBeDeleted.push(StaleCI);
6316 
6317     assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) &&
6318            "Outlined function must have two or three arguments only");
6319 
6320     bool HasShared = OutlinedFn.arg_size() == 3;
6321 
6322     OutlinedFn.getArg(0)->setName("global.tid.ptr");
6323     OutlinedFn.getArg(1)->setName("bound.tid.ptr");
6324     if (HasShared)
6325       OutlinedFn.getArg(2)->setName("data");
6326 
6327     // Call to the runtime function for teams in the current function.
6328     assert(StaleCI && "Error while outlining - no CallInst user found for the "
6329                       "outlined function.");
6330     Builder.SetInsertPoint(StaleCI);
6331     SmallVector<Value *> Args = {
6332         Ident, Builder.getInt32(StaleCI->arg_size() - 2), &OutlinedFn};
6333     if (HasShared)
6334       Args.push_back(StaleCI->getArgOperand(2));
6335     Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
6336                            omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
6337                        Args);
6338 
6339     while (!ToBeDeleted.empty()) {
6340       ToBeDeleted.top()->eraseFromParent();
6341       ToBeDeleted.pop();
6342     }
6343   };
6344 
6345   addOutlineInfo(std::move(OI));
6346 
6347   Builder.SetInsertPoint(ExitBB, ExitBB->begin());
6348 
6349   return Builder.saveIP();
6350 }
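// Net effect of createTeams, sketched in pseudo-IR (names illustrative; the
// push call is emitted only when one of the clauses was given, and the data
// argument only when the outlined function has a third parameter):
//
//   call void @__kmpc_push_num_teams_51(ptr %ident, i32 %tid, i32 %lb,
//                                       i32 %ub, i32 %thread_limit)
//   call void @__kmpc_fork_teams(ptr %ident, i32 <nargs>, ptr @outlined_fn
//                                [, ptr %data])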
6351 
6352 GlobalVariable *
6353 OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
6354                                        std::string VarName) {
6355   llvm::Constant *MapNamesArrayInit = llvm::ConstantArray::get(
6356       llvm::ArrayType::get(llvm::PointerType::getUnqual(M.getContext()),
6357                            Names.size()),
6358       Names);
6359   auto *MapNamesArrayGlobal = new llvm::GlobalVariable(
6360       M, MapNamesArrayInit->getType(),
6361       /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MapNamesArrayInit,
6362       VarName);
6363   return MapNamesArrayGlobal;
6364 }
6365 
6366 // Create all simple and struct types exposed by the runtime and remember
6367 // their llvm::PointerTypes for easy access later.
6368 void OpenMPIRBuilder::initializeTypes(Module &M) {
6369   LLVMContext &Ctx = M.getContext();
6370   StructType *T;
6371 #define OMP_TYPE(VarName, InitValue) VarName = InitValue;
6372 #define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize)                             \
6373   VarName##Ty = ArrayType::get(ElemTy, ArraySize);                             \
6374   VarName##PtrTy = PointerType::getUnqual(VarName##Ty);
6375 #define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...)                  \
6376   VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg);            \
6377   VarName##Ptr = PointerType::getUnqual(VarName);
6378 #define OMP_STRUCT_TYPE(VarName, StructName, Packed, ...)                      \
6379   T = StructType::getTypeByName(Ctx, StructName);                              \
6380   if (!T)                                                                      \
6381     T = StructType::create(Ctx, {__VA_ARGS__}, StructName, Packed);            \
6382   VarName = T;                                                                 \
6383   VarName##Ptr = PointerType::getUnqual(T);
6384 #include "llvm/Frontend/OpenMP/OMPKinds.def"
6385 }
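// As an illustration of the expansion driven by OMPKinds.def: a (here
// hypothetical) entry 'OMP_ARRAY_TYPE(KmpCriticalName, Int32, 8)' would
// expand to
//
//   KmpCriticalNameTy = ArrayType::get(Int32, 8);
//   KmpCriticalNamePtrTy = PointerType::getUnqual(KmpCriticalNameTy);
//
// i.e. each runtime type becomes a member plus a matching pointer type.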
6386 
6387 void OpenMPIRBuilder::OutlineInfo::collectBlocks(
6388     SmallPtrSetImpl<BasicBlock *> &BlockSet,
6389     SmallVectorImpl<BasicBlock *> &BlockVector) {
6390   SmallVector<BasicBlock *, 32> Worklist;
6391   BlockSet.insert(EntryBB);
6392   BlockSet.insert(ExitBB);
6393 
6394   Worklist.push_back(EntryBB);
6395   while (!Worklist.empty()) {
6396     BasicBlock *BB = Worklist.pop_back_val();
6397     BlockVector.push_back(BB);
6398     for (BasicBlock *SuccBB : successors(BB))
6399       if (BlockSet.insert(SuccBB).second)
6400         Worklist.push_back(SuccBB);
6401   }
6402 }
6403 
6404 void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr,
6405                                          uint64_t Size, int32_t Flags,
6406                                          GlobalValue::LinkageTypes,
6407                                          StringRef Name) {
6408   if (!Config.isGPU()) {
6409     llvm::offloading::emitOffloadingEntry(
6410         M, ID, Name.empty() ? Addr->getName() : Name, Size, Flags, /*Data=*/0,
6411         "omp_offloading_entries");
6412     return;
6413   }
6414   // TODO: Add support for global variables on the device after declare target
6415   // support.
6416   Function *Fn = dyn_cast<Function>(Addr);
6417   if (!Fn)
6418     return;
6419 
6420   Module &M = *(Fn->getParent());
6421   LLVMContext &Ctx = M.getContext();
6422 
6423   // Get "nvvm.annotations" metadata node.
6424   NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
6425 
6426   Metadata *MDVals[] = {
6427       ConstantAsMetadata::get(Fn), MDString::get(Ctx, "kernel"),
6428       ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Ctx), 1))};
6429   // Append metadata to nvvm.annotations.
6430   MD->addOperand(MDNode::get(Ctx, MDVals));
6431 
6432   // Add a function attribute for the kernel.
6433   Fn->addFnAttr(Attribute::get(Ctx, "kernel"));
6434   if (T.isAMDGCN())
6435     Fn->addFnAttr("uniform-work-group-size", "true");
6436   Fn->addFnAttr(Attribute::MustProgress);
6437 }
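// On the device path above, the metadata appended for a kernel @foo looks
// roughly like this (a sketch):
//
//   !nvvm.annotations = !{!0}
//   !0 = !{ptr @foo, !"kernel", i32 1}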
6438 
6439 // We only generate metadata for functions that contain target regions.
6440 void OpenMPIRBuilder::createOffloadEntriesAndInfoMetadata(
6441     EmitMetadataErrorReportFunctionTy &ErrorFn) {
6442 
6443   // If there are no entries, we don't need to do anything.
6444   if (OffloadInfoManager.empty())
6445     return;
6446 
6447   LLVMContext &C = M.getContext();
6448   SmallVector<std::pair<const OffloadEntriesInfoManager::OffloadEntryInfo *,
6449                         TargetRegionEntryInfo>,
6450               16>
6451       OrderedEntries(OffloadInfoManager.size());
6452 
6453   // Auxiliary methods to create metadata values and strings.
6454   auto &&GetMDInt = [this](unsigned V) {
6455     return ConstantAsMetadata::get(ConstantInt::get(Builder.getInt32Ty(), V));
6456   };
6457 
6458   auto &&GetMDString = [&C](StringRef V) { return MDString::get(C, V); };
6459 
6460   // Create the offloading info metadata node.
6461   NamedMDNode *MD = M.getOrInsertNamedMetadata("omp_offload.info");
6462   auto &&TargetRegionMetadataEmitter =
6463       [&C, MD, &OrderedEntries, &GetMDInt, &GetMDString](
6464           const TargetRegionEntryInfo &EntryInfo,
6465           const OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion &E) {
6466         // Generate metadata for target regions. Each entry of this metadata
6467         // contains:
6468         // - Entry 0 -> Kind of this type of metadata (0).
6469         // - Entry 1 -> Device ID of the file where the entry was identified.
6470         // - Entry 2 -> File ID of the file where the entry was identified.
6471         // - Entry 3 -> Mangled name of the function where the entry was
6472         // identified.
6473         // - Entry 4 -> Line in the file where the entry was identified.
6474         // - Entry 5 -> Count of regions at this DeviceID/FilesID/Line.
6475         // - Entry 6 -> Order the entry was created.
6476         // The first element of the metadata node is the kind.
6477         Metadata *Ops[] = {
6478             GetMDInt(E.getKind()),      GetMDInt(EntryInfo.DeviceID),
6479             GetMDInt(EntryInfo.FileID), GetMDString(EntryInfo.ParentName),
6480             GetMDInt(EntryInfo.Line),   GetMDInt(EntryInfo.Count),
6481             GetMDInt(E.getOrder())};
6482 
6483         // Save this entry in the right position of the ordered entries array.
6484         OrderedEntries[E.getOrder()] = std::make_pair(&E, EntryInfo);
6485 
6486         // Add metadata to the named metadata node.
6487         MD->addOperand(MDNode::get(C, Ops));
6488       };
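  // A resulting target-region entry in !omp_offload.info thus looks roughly
  // like (a sketch, values illustrative):
  //   !{i32 0, i32 <device-id>, i32 <file-id>, !"<parent-name>",
  //     i32 <line>, i32 <count>, i32 <order>}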
6489 
6490   OffloadInfoManager.actOnTargetRegionEntriesInfo(TargetRegionMetadataEmitter);
6491 
6492   // Create a function that emits metadata for each device global variable entry.
6493   auto &&DeviceGlobalVarMetadataEmitter =
6494       [&C, &OrderedEntries, &GetMDInt, &GetMDString, MD](
6495           StringRef MangledName,
6496           const OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar &E) {
6497         // Generate metadata for global variables. Each entry of this metadata
6498         // contains:
6499         // - Entry 0 -> Kind of this type of metadata (1).
6500         // - Entry 1 -> Mangled name of the variable.
6501         // - Entry 2 -> Declare target kind.
6502         // - Entry 3 -> Order the entry was created.
6503         // The first element of the metadata node is the kind.
6504         Metadata *Ops[] = {GetMDInt(E.getKind()), GetMDString(MangledName),
6505                            GetMDInt(E.getFlags()), GetMDInt(E.getOrder())};
6506 
6507         // Save this entry in the right position of the ordered entries array.
6508         TargetRegionEntryInfo varInfo(MangledName, 0, 0, 0);
6509         OrderedEntries[E.getOrder()] = std::make_pair(&E, varInfo);
6510 
6511         // Add metadata to the named metadata node.
6512         MD->addOperand(MDNode::get(C, Ops));
6513       };
6514 
6515   OffloadInfoManager.actOnDeviceGlobalVarEntriesInfo(
6516       DeviceGlobalVarMetadataEmitter);
6517 
6518   for (const auto &E : OrderedEntries) {
6519     assert(E.first && "All ordered entries must exist!");
6520     if (const auto *CE =
6521             dyn_cast<OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion>(
6522                 E.first)) {
6523       if (!CE->getID() || !CE->getAddress()) {
6524         // Do not blame the entry if the parent function is not emitted.
6525         TargetRegionEntryInfo EntryInfo = E.second;
6526         StringRef FnName = EntryInfo.ParentName;
6527         if (!M.getNamedValue(FnName))
6528           continue;
6529         ErrorFn(EMIT_MD_TARGET_REGION_ERROR, EntryInfo);
6530         continue;
6531       }
6532       createOffloadEntry(CE->getID(), CE->getAddress(),
6533                          /*Size=*/0, CE->getFlags(),
6534                          GlobalValue::WeakAnyLinkage);
6535     } else if (const auto *CE = dyn_cast<
6536                    OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar>(
6537                    E.first)) {
6538       OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags =
6539           static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
6540               CE->getFlags());
6541       switch (Flags) {
6542       case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter:
6543       case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo:
6544         if (Config.isTargetDevice() && Config.hasRequiresUnifiedSharedMemory())
6545           continue;
6546         if (!CE->getAddress()) {
6547           ErrorFn(EMIT_MD_DECLARE_TARGET_ERROR, E.second);
6548           continue;
6549         }
6550         // The variable has no definition - no need to add the entry.
6551         if (CE->getVarSize() == 0)
6552           continue;
6553         break;
6554       case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink:
6555         assert(((Config.isTargetDevice() && !CE->getAddress()) ||
6556                 (!Config.isTargetDevice() && CE->getAddress())) &&
6557                "Declaret target link address is set.");
6558         if (Config.isTargetDevice())
6559           continue;
6560         if (!CE->getAddress()) {
6561           ErrorFn(EMIT_MD_GLOBAL_VAR_LINK_ERROR, TargetRegionEntryInfo());
6562           continue;
6563         }
6564         break;
6565       default:
6566         break;
6567       }
6568 
6569       // Hidden or internal symbols on the device are not externally visible.
6570       // We should not attempt to register them by creating an offloading
6571       // entry. Indirect variables are handled separately on the device.
6572       if (auto *GV = dyn_cast<GlobalValue>(CE->getAddress()))
6573         if ((GV->hasLocalLinkage() || GV->hasHiddenVisibility()) &&
6574             Flags != OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
6575           continue;
6576 
6577       // Indirect globals need to use a special name that doesn't match the name
6578       // of the associated host global.
6579       if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
6580         createOffloadEntry(CE->getAddress(), CE->getAddress(), CE->getVarSize(),
6581                            Flags, CE->getLinkage(), CE->getVarName());
6582       else
6583         createOffloadEntry(CE->getAddress(), CE->getAddress(), CE->getVarSize(),
6584                            Flags, CE->getLinkage());
6585 
6586     } else {
6587       llvm_unreachable("Unsupported entry kind.");
6588     }
6589   }
6590 }
6591 
6592 void TargetRegionEntryInfo::getTargetRegionEntryFnName(
6593     SmallVectorImpl<char> &Name, StringRef ParentName, unsigned DeviceID,
6594     unsigned FileID, unsigned Line, unsigned Count) {
6595   raw_svector_ostream OS(Name);
6596   OS << "__omp_offloading" << llvm::format("_%x", DeviceID)
6597      << llvm::format("_%x_", FileID) << ParentName << "_l" << Line;
6598   if (Count)
6599     OS << "_" << Count;
6600 }
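// For example (hypothetical values): DeviceID 0x10303, FileID 0x4f2dab,
// ParentName "foo", Line 12, and Count 0 produce the entry name
// "__omp_offloading_10303_4f2dab_foo_l12"; a nonzero Count appends "_<count>".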
6601 
6602 void OffloadEntriesInfoManager::getTargetRegionEntryFnName(
6603     SmallVectorImpl<char> &Name, const TargetRegionEntryInfo &EntryInfo) {
6604   unsigned NewCount = getTargetRegionEntryInfoCount(EntryInfo);
6605   TargetRegionEntryInfo::getTargetRegionEntryFnName(
6606       Name, EntryInfo.ParentName, EntryInfo.DeviceID, EntryInfo.FileID,
6607       EntryInfo.Line, NewCount);
6608 }
6609 
6610 TargetRegionEntryInfo
6611 OpenMPIRBuilder::getTargetEntryUniqueInfo(FileIdentifierInfoCallbackTy CallBack,
6612                                           StringRef ParentName) {
6613   sys::fs::UniqueID ID;
6614   auto FileIDInfo = CallBack();
6615   if (auto EC = sys::fs::getUniqueID(std::get<0>(FileIDInfo), ID)) {
6616     report_fatal_error(("Unable to get unique ID for file, during "
6617                         "getTargetEntryUniqueInfo, error message: " +
6618                         EC.message())
6619                            .c_str());
6620   }
6621 
6622   return TargetRegionEntryInfo(ParentName, ID.getDevice(), ID.getFile(),
6623                                std::get<1>(FileIDInfo));
6624 }
6625 
6626 unsigned OpenMPIRBuilder::getFlagMemberOffset() {
6627   unsigned Offset = 0;
6628   for (uint64_t Remain =
6629            static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
6630                omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF);
6631        !(Remain & 1); Remain = Remain >> 1)
6632     Offset++;
6633   return Offset;
6634 }
6635 
6636 omp::OpenMPOffloadMappingFlags
6637 OpenMPIRBuilder::getMemberOfFlag(unsigned Position) {
6638   // Rotate by getFlagMemberOffset() bits.
6639   return static_cast<omp::OpenMPOffloadMappingFlags>(((uint64_t)Position + 1)
6640                                                      << getFlagMemberOffset());
6641 }
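// Example for the two helpers above, assuming the MEMBER_OF field occupies
// the usual top 16 bits of the flag word (OMP_MAP_MEMBER_OF ==
// 0xFFFF000000000000): getFlagMemberOffset() returns 48, and
// getMemberOfFlag(/*Position=*/0) yields 1ULL << 48, encoding a one-biased
// member index of 1 in the MEMBER_OF field.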
6642 
6643 void OpenMPIRBuilder::setCorrectMemberOfFlag(
6644     omp::OpenMPOffloadMappingFlags &Flags,
6645     omp::OpenMPOffloadMappingFlags MemberOfFlag) {
6646   // If the entry is PTR_AND_OBJ but has not been marked with the special
6647   // placeholder value 0xFFFF in the MEMBER_OF field, then it should not be
6648   // marked as MEMBER_OF.
6649   if (static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
6650           Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ) &&
6651       static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
6652           (Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) !=
6653           omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF))
6654     return;
6655 
6656   // Reset the placeholder value to prepare the flag for the assignment of the
6657   // proper MEMBER_OF value.
6658   Flags &= ~omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
6659   Flags |= MemberOfFlag;
6660 }
6661 
6662 Constant *OpenMPIRBuilder::getAddrOfDeclareTargetVar(
6663     OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
6664     OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
6665     bool IsDeclaration, bool IsExternallyVisible,
6666     TargetRegionEntryInfo EntryInfo, StringRef MangledName,
6667     std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
6668     std::vector<Triple> TargetTriple, Type *LlvmPtrTy,
6669     std::function<Constant *()> GlobalInitializer,
6670     std::function<GlobalValue::LinkageTypes()> VariableLinkage) {
6671   // TODO: convert this to utilise the IRBuilder Config rather than
6672   // a passed down argument.
6673   if (OpenMPSIMD)
6674     return nullptr;
6675 
6676   if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink ||
6677       ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
6678         CaptureClause ==
6679             OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
6680        Config.hasRequiresUnifiedSharedMemory())) {
6681     SmallString<64> PtrName;
6682     {
6683       raw_svector_ostream OS(PtrName);
6684       OS << MangledName;
6685       if (!IsExternallyVisible)
6686         OS << format("_%x", EntryInfo.FileID);
6687       OS << "_decl_tgt_ref_ptr";
6688     }
6689 
6690     Value *Ptr = M.getNamedValue(PtrName);
6691 
6692     if (!Ptr) {
6693       GlobalValue *GlobalValue = M.getNamedValue(MangledName);
6694       Ptr = getOrCreateInternalVariable(LlvmPtrTy, PtrName);
6695 
6696       auto *GV = cast<GlobalVariable>(Ptr);
6697       GV->setLinkage(GlobalValue::WeakAnyLinkage);
6698 
6699       if (!Config.isTargetDevice()) {
6700         if (GlobalInitializer)
6701           GV->setInitializer(GlobalInitializer());
6702         else
6703           GV->setInitializer(GlobalValue);
6704       }
6705 
6706       registerTargetGlobalVariable(
6707           CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
6708           EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
6709           GlobalInitializer, VariableLinkage, LlvmPtrTy, cast<Constant>(Ptr));
6710     }
6711 
6712     return cast<Constant>(Ptr);
6713   }
6714 
6715   return nullptr;
6716 }
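// For instance (hypothetical name): an internal declare-target link variable
// "foo" in file 0x4f2dab gets a reference pointer named
// "foo_4f2dab_decl_tgt_ref_ptr" with weak_any linkage; on the host it is
// initialized to the original global (or to GlobalInitializer's value).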
6717 
6718 void OpenMPIRBuilder::registerTargetGlobalVariable(
6719     OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
6720     OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
6721     bool IsDeclaration, bool IsExternallyVisible,
6722     TargetRegionEntryInfo EntryInfo, StringRef MangledName,
6723     std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
6724     std::vector<Triple> TargetTriple,
6725     std::function<Constant *()> GlobalInitializer,
6726     std::function<GlobalValue::LinkageTypes()> VariableLinkage, Type *LlvmPtrTy,
6727     Constant *Addr) {
6728   if (DeviceClause != OffloadEntriesInfoManager::OMPTargetDeviceClauseAny ||
6729       (TargetTriple.empty() && !Config.isTargetDevice()))
6730     return;
6731 
6732   OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags;
6733   StringRef VarName;
6734   int64_t VarSize;
6735   GlobalValue::LinkageTypes Linkage;
6736 
6737   if ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
6738        CaptureClause ==
6739            OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
6740       !Config.hasRequiresUnifiedSharedMemory()) {
6741     Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
6742     VarName = MangledName;
6743     GlobalValue *LlvmVal = M.getNamedValue(VarName);
6744 
6745     if (!IsDeclaration)
6746       VarSize = divideCeil(
6747           M.getDataLayout().getTypeSizeInBits(LlvmVal->getValueType()), 8);
6748     else
6749       VarSize = 0;
6750     Linkage = (VariableLinkage) ? VariableLinkage() : LlvmVal->getLinkage();
6751 
6752     // This is a workaround carried over from Clang which prevents undesired
6753     // optimisation of internal variables.
6754     if (Config.isTargetDevice() &&
6755         (!IsExternallyVisible || Linkage == GlobalValue::LinkOnceODRLinkage)) {
6756       // Do not create a "ref-variable" if the original is not also available
6757       // on the host.
6758       if (!OffloadInfoManager.hasDeviceGlobalVarEntryInfo(VarName))
6759         return;
6760 
6761       std::string RefName = createPlatformSpecificName({VarName, "ref"});
6762 
6763       if (!M.getNamedValue(RefName)) {
6764         Constant *AddrRef =
6765             getOrCreateInternalVariable(Addr->getType(), RefName);
6766         auto *GvAddrRef = cast<GlobalVariable>(AddrRef);
6767         GvAddrRef->setConstant(true);
6768         GvAddrRef->setLinkage(GlobalValue::InternalLinkage);
6769         GvAddrRef->setInitializer(Addr);
6770         GeneratedRefs.push_back(GvAddrRef);
6771       }
6772     }
6773   } else {
6774     if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink)
6775       Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
6776     else
6777       Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
6778 
6779     if (Config.isTargetDevice()) {
6780       VarName = (Addr) ? Addr->getName() : "";
6781       Addr = nullptr;
6782     } else {
6783       Addr = getAddrOfDeclareTargetVar(
6784           CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
6785           EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
6786           LlvmPtrTy, GlobalInitializer, VariableLinkage);
6787       VarName = (Addr) ? Addr->getName() : "";
6788     }
6789     VarSize = M.getDataLayout().getPointerSize();
6790     Linkage = GlobalValue::WeakAnyLinkage;
6791   }
6792 
6793   OffloadInfoManager.registerDeviceGlobalVarEntryInfo(VarName, Addr, VarSize,
6794                                                       Flags, Linkage);
6795 }
6796 
6797 /// Loads all the offload entries information from the host IR
6798 /// metadata.
6799 void OpenMPIRBuilder::loadOffloadInfoMetadata(Module &M) {
6800   // If we are in target mode, load the metadata from the host IR. This code has
6801   // to match the metadata creation in createOffloadEntriesAndInfoMetadata().
6802 
6803   NamedMDNode *MD = M.getNamedMetadata(ompOffloadInfoName);
6804   if (!MD)
6805     return;
6806 
6807   for (MDNode *MN : MD->operands()) {
6808     auto &&GetMDInt = [MN](unsigned Idx) {
6809       auto *V = cast<ConstantAsMetadata>(MN->getOperand(Idx));
6810       return cast<ConstantInt>(V->getValue())->getZExtValue();
6811     };
6812 
6813     auto &&GetMDString = [MN](unsigned Idx) {
6814       auto *V = cast<MDString>(MN->getOperand(Idx));
6815       return V->getString();
6816     };
6817 
6818     switch (GetMDInt(0)) {
6819     default:
6820       llvm_unreachable("Unexpected metadata!");
6821       break;
6822     case OffloadEntriesInfoManager::OffloadEntryInfo::
6823         OffloadingEntryInfoTargetRegion: {
6824       TargetRegionEntryInfo EntryInfo(/*ParentName=*/GetMDString(3),
6825                                       /*DeviceID=*/GetMDInt(1),
6826                                       /*FileID=*/GetMDInt(2),
6827                                       /*Line=*/GetMDInt(4),
6828                                       /*Count=*/GetMDInt(5));
6829       OffloadInfoManager.initializeTargetRegionEntryInfo(EntryInfo,
6830                                                          /*Order=*/GetMDInt(6));
6831       break;
6832     }
6833     case OffloadEntriesInfoManager::OffloadEntryInfo::
6834         OffloadingEntryInfoDeviceGlobalVar:
6835       OffloadInfoManager.initializeDeviceGlobalVarEntryInfo(
6836           /*MangledName=*/GetMDString(1),
6837           static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
6838               /*Flags=*/GetMDInt(2)),
6839           /*Order=*/GetMDInt(3));
6840       break;
6841     }
6842   }
6843 }
6844 
6845 void OpenMPIRBuilder::loadOffloadInfoMetadata(StringRef HostFilePath) {
6846   if (HostFilePath.empty())
6847     return;
6848 
6849   auto Buf = MemoryBuffer::getFile(HostFilePath);
6850   if (std::error_code Err = Buf.getError()) {
6851     report_fatal_error(("error opening host file from host file path inside of "
6852                         "OpenMPIRBuilder: " +
6853                         Err.message())
6854                            .c_str());
6855   }
6856 
6857   LLVMContext Ctx;
6858   auto M = expectedToErrorOrAndEmitErrors(
6859       Ctx, parseBitcodeFile(Buf.get()->getMemBufferRef(), Ctx));
6860   if (std::error_code Err = M.getError()) {
6861     report_fatal_error(
6862         ("error parsing host file inside of OpenMPIRBuilder: " + Err.message())
6863             .c_str());
6864   }
6865 
6866   loadOffloadInfoMetadata(*M.get());
6867 }
6868 
6869 Function *OpenMPIRBuilder::createRegisterRequires(StringRef Name) {
6870   // Skip the creation of the registration function if this is device codegen
6871   if (Config.isTargetDevice())
6872     return nullptr;
6873 
6874   Builder.ClearInsertionPoint();
6875 
6876   // Create registration function prototype
6877   auto *RegFnTy = FunctionType::get(Builder.getVoidTy(), {});
6878   auto *RegFn = Function::Create(
6879       RegFnTy, GlobalVariable::LinkageTypes::InternalLinkage, Name, M);
6880   RegFn->setSection(".text.startup");
6881   RegFn->addFnAttr(Attribute::NoInline);
6882   RegFn->addFnAttr(Attribute::NoUnwind);
6883 
6884   // Create registration function body
6885   auto *BB = BasicBlock::Create(M.getContext(), "entry", RegFn);
6886   ConstantInt *FlagsVal =
6887       ConstantInt::getSigned(Builder.getInt64Ty(), Config.getRequiresFlags());
6888   Function *RTLRegFn = getOrCreateRuntimeFunctionPtr(
6889       omp::RuntimeFunction::OMPRTL___tgt_register_requires);
6890 
6891   Builder.SetInsertPoint(BB);
6892   Builder.CreateCall(RTLRegFn, {FlagsVal});
6893   Builder.CreateRetVoid();
6894 
6895   return RegFn;
6896 }
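// The generated registration function has roughly this shape (a sketch):
//
//   define internal void @<Name>() section ".text.startup" {
//   entry:
//     call void @__tgt_register_requires(i64 <requires-flags>)
//     ret void
//   }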
6897 
6898 //===----------------------------------------------------------------------===//
6899 // OffloadEntriesInfoManager
6900 //===----------------------------------------------------------------------===//
6901 
6902 bool OffloadEntriesInfoManager::empty() const {
6903   return OffloadEntriesTargetRegion.empty() &&
6904          OffloadEntriesDeviceGlobalVar.empty();
6905 }
6906 
6907 unsigned OffloadEntriesInfoManager::getTargetRegionEntryInfoCount(
6908     const TargetRegionEntryInfo &EntryInfo) const {
6909   auto It = OffloadEntriesTargetRegionCount.find(
6910       getTargetRegionEntryCountKey(EntryInfo));
6911   if (It == OffloadEntriesTargetRegionCount.end())
6912     return 0;
6913   return It->second;
6914 }
6915 
6916 void OffloadEntriesInfoManager::incrementTargetRegionEntryInfoCount(
6917     const TargetRegionEntryInfo &EntryInfo) {
6918   OffloadEntriesTargetRegionCount[getTargetRegionEntryCountKey(EntryInfo)] =
6919       EntryInfo.Count + 1;
6920 }
6921 
6922 /// Initialize target region entry.
6923 void OffloadEntriesInfoManager::initializeTargetRegionEntryInfo(
6924     const TargetRegionEntryInfo &EntryInfo, unsigned Order) {
6925   OffloadEntriesTargetRegion[EntryInfo] =
6926       OffloadEntryInfoTargetRegion(Order, /*Addr=*/nullptr, /*ID=*/nullptr,
6927                                    OMPTargetRegionEntryTargetRegion);
6928   ++OffloadingEntriesNum;
6929 }
6930 
6931 void OffloadEntriesInfoManager::registerTargetRegionEntryInfo(
6932     TargetRegionEntryInfo EntryInfo, Constant *Addr, Constant *ID,
6933     OMPTargetRegionEntryKind Flags) {
6934   assert(EntryInfo.Count == 0 && "expected default EntryInfo");
6935 
6936   // Update the EntryInfo with the next available count for this location.
6937   EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
6938 
6939   // If we are emitting code for a target device, the entry is already
6940   // initialized; it only has to be registered.
6941   if (OMPBuilder->Config.isTargetDevice()) {
6942     // This could happen if the device compilation is invoked standalone.
6943     if (!hasTargetRegionEntryInfo(EntryInfo)) {
6944       return;
6945     }
6946     auto &Entry = OffloadEntriesTargetRegion[EntryInfo];
6947     Entry.setAddress(Addr);
6948     Entry.setID(ID);
6949     Entry.setFlags(Flags);
6950   } else {
6951     if (Flags == OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion &&
6952         hasTargetRegionEntryInfo(EntryInfo, /*IgnoreAddressId*/ true))
6953       return;
6954     assert(!hasTargetRegionEntryInfo(EntryInfo) &&
6955            "Target region entry already registered!");
6956     OffloadEntryInfoTargetRegion Entry(OffloadingEntriesNum, Addr, ID, Flags);
6957     OffloadEntriesTargetRegion[EntryInfo] = Entry;
6958     ++OffloadingEntriesNum;
6959   }
6960   incrementTargetRegionEntryInfoCount(EntryInfo);
6961 }
6962 
6963 bool OffloadEntriesInfoManager::hasTargetRegionEntryInfo(
6964     TargetRegionEntryInfo EntryInfo, bool IgnoreAddressId) const {
6965 
6966   // Update the EntryInfo with the next available count for this location.
6967   EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
6968 
6969   auto It = OffloadEntriesTargetRegion.find(EntryInfo);
6970   if (It == OffloadEntriesTargetRegion.end()) {
6971     return false;
6972   }
6973   // Fail if this entry is already registered.
6974   if (!IgnoreAddressId && (It->second.getAddress() || It->second.getID()))
6975     return false;
6976   return true;
6977 }
6978 
6979 void OffloadEntriesInfoManager::actOnTargetRegionEntriesInfo(
6980     const OffloadTargetRegionEntryInfoActTy &Action) {
6981   // Scan all target region entries and perform the provided action.
6982   for (const auto &It : OffloadEntriesTargetRegion) {
6983     Action(It.first, It.second);
6984   }
6985 }
6986 
6987 void OffloadEntriesInfoManager::initializeDeviceGlobalVarEntryInfo(
6988     StringRef Name, OMPTargetGlobalVarEntryKind Flags, unsigned Order) {
6989   OffloadEntriesDeviceGlobalVar.try_emplace(Name, Order, Flags);
6990   ++OffloadingEntriesNum;
6991 }
6992 
6993 void OffloadEntriesInfoManager::registerDeviceGlobalVarEntryInfo(
6994     StringRef VarName, Constant *Addr, int64_t VarSize,
6995     OMPTargetGlobalVarEntryKind Flags, GlobalValue::LinkageTypes Linkage) {
6996   if (OMPBuilder->Config.isTargetDevice()) {
6997     // This could happen if the device compilation is invoked standalone.
6998     if (!hasDeviceGlobalVarEntryInfo(VarName))
6999       return;
7000     auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
7001     if (Entry.getAddress() && hasDeviceGlobalVarEntryInfo(VarName)) {
7002       if (Entry.getVarSize() == 0) {
7003         Entry.setVarSize(VarSize);
7004         Entry.setLinkage(Linkage);
7005       }
7006       return;
7007     }
7008     Entry.setVarSize(VarSize);
7009     Entry.setLinkage(Linkage);
7010     Entry.setAddress(Addr);
7011   } else {
7012     if (hasDeviceGlobalVarEntryInfo(VarName)) {
7013       auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
7014       assert(Entry.isValid() && Entry.getFlags() == Flags &&
7015              "Entry not initialized!");
7016       if (Entry.getVarSize() == 0) {
7017         Entry.setVarSize(VarSize);
7018         Entry.setLinkage(Linkage);
7019       }
7020       return;
7021     }
7022     if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
7023       OffloadEntriesDeviceGlobalVar.try_emplace(VarName, OffloadingEntriesNum,
7024                                                 Addr, VarSize, Flags, Linkage,
7025                                                 VarName.str());
7026     else
7027       OffloadEntriesDeviceGlobalVar.try_emplace(
7028           VarName, OffloadingEntriesNum, Addr, VarSize, Flags, Linkage, "");
7029     ++OffloadingEntriesNum;
7030   }
7031 }
7032 
7033 void OffloadEntriesInfoManager::actOnDeviceGlobalVarEntriesInfo(
7034     const OffloadDeviceGlobalVarEntryInfoActTy &Action) {
7035   // Scan all device global variable entries and perform the provided action.
7036   for (const auto &E : OffloadEntriesDeviceGlobalVar)
7037     Action(E.getKey(), E.getValue());
7038 }
7039 
7040 //===----------------------------------------------------------------------===//
7041 // CanonicalLoopInfo
7042 //===----------------------------------------------------------------------===//
7043 
7044 void CanonicalLoopInfo::collectControlBlocks(
7045     SmallVectorImpl<BasicBlock *> &BBs) {
7046   // We only count those BBs as control blocks for which we do not need to
7047   // reverse the CFG, i.e. not the loop body, which can contain arbitrary
7048   // control flow. For consistency, this also means we do not add the Body
7049   // block, which is just the entry to the body code.
7050   BBs.reserve(BBs.size() + 6);
7051   BBs.append({getPreheader(), Header, Cond, Latch, Exit, getAfter()});
7052 }
7053 
7054 BasicBlock *CanonicalLoopInfo::getPreheader() const {
7055   assert(isValid() && "Requires a valid canonical loop");
7056   for (BasicBlock *Pred : predecessors(Header)) {
7057     if (Pred != Latch)
7058       return Pred;
7059   }
7060   llvm_unreachable("Missing preheader");
7061 }
7062 
7063 void CanonicalLoopInfo::setTripCount(Value *TripCount) {
7064   assert(isValid() && "Requires a valid canonical loop");
7065 
7066   Instruction *CmpI = &getCond()->front();
7067   assert(isa<CmpInst>(CmpI) && "First inst must compare IV with TripCount");
7068   CmpI->setOperand(1, TripCount);
7069 
7070 #ifndef NDEBUG
7071   assertOK();
7072 #endif
7073 }
7074 
7075 void CanonicalLoopInfo::mapIndVar(
7076     llvm::function_ref<Value *(Instruction *)> Updater) {
7077   assert(isValid() && "Requires a valid canonical loop");
7078 
7079   Instruction *OldIV = getIndVar();
7080 
7081   // Record all uses excluding those introduced by the updater. Uses by the
7082   // CanonicalLoopInfo itself to keep track of the number of iterations are
7083   // excluded.
7084   SmallVector<Use *> ReplacableUses;
7085   for (Use &U : OldIV->uses()) {
7086     auto *User = dyn_cast<Instruction>(U.getUser());
7087     if (!User)
7088       continue;
7089     if (User->getParent() == getCond())
7090       continue;
7091     if (User->getParent() == getLatch())
7092       continue;
7093     ReplacableUses.push_back(&U);
7094   }
7095 
7096   // Run the updater that may introduce new uses
7097   Value *NewIV = Updater(OldIV);
7098 
7099   // Replace the old uses with the value returned by the updater.
7100   for (Use *U : ReplacableUses)
7101     U->set(NewIV);
7102 
7103 #ifndef NDEBUG
7104   assertOK();
7105 #endif
7106 }
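
// A minimal usage sketch (assumed context, not from a specific in-tree
// caller): offsetting the logical iteration space by a runtime value, as a
// workshare lowering might do. The updater receives no insert point, so it
// positions the builder itself; inserting at the front of the body leaves the
// canonical IV updates in the header, condition, and latch untouched.
// Builder, CLI and StartOffset are assumed to exist:
//
//   CLI->mapIndVar([&](Instruction *OldIV) -> Value * {
//     Builder.SetInsertPoint(CLI->getBody(),
//                            CLI->getBody()->getFirstInsertionPt());
//     return Builder.CreateAdd(OldIV, StartOffset, "shifted.iv");
//   });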

void CanonicalLoopInfo::assertOK() const {
#ifndef NDEBUG
  // No constraints if this object currently does not describe a loop.
  if (!isValid())
    return;

  BasicBlock *Preheader = getPreheader();
  BasicBlock *Body = getBody();
  BasicBlock *After = getAfter();

  // Verify standard control-flow we use for OpenMP loops.
  assert(Preheader);
  assert(isa<BranchInst>(Preheader->getTerminator()) &&
         "Preheader must terminate with unconditional branch");
  assert(Preheader->getSingleSuccessor() == Header &&
         "Preheader must jump to header");

  assert(Header);
  assert(isa<BranchInst>(Header->getTerminator()) &&
         "Header must terminate with unconditional branch");
  assert(Header->getSingleSuccessor() == Cond &&
         "Header must jump to exiting block");

  assert(Cond);
  assert(Cond->getSinglePredecessor() == Header &&
         "Exiting block only reachable from header");

  assert(isa<BranchInst>(Cond->getTerminator()) &&
         "Exiting block must terminate with conditional branch");
  assert(size(successors(Cond)) == 2 &&
         "Exiting block must have two successors");
  assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(0) == Body &&
         "Exiting block's first successor must jump to the body");
  assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(1) == Exit &&
         "Exiting block's second successor must exit the loop");

  assert(Body);
  assert(Body->getSinglePredecessor() == Cond &&
         "Body only reachable from exiting block");
  assert(!isa<PHINode>(Body->front()));

  assert(Latch);
  assert(isa<BranchInst>(Latch->getTerminator()) &&
         "Latch must terminate with unconditional branch");
  assert(Latch->getSingleSuccessor() == Header && "Latch must jump to header");
  // TODO: To support simple redirecting of the end of the body code when it
  // has multiple exit blocks, introduce another auxiliary basic block like
  // preheader and after.
  assert(Latch->getSinglePredecessor() != nullptr);
  assert(!isa<PHINode>(Latch->front()));

  assert(Exit);
  assert(isa<BranchInst>(Exit->getTerminator()) &&
         "Exit block must terminate with unconditional branch");
  assert(Exit->getSingleSuccessor() == After &&
         "Exit block must jump to after block");

  assert(After);
  assert(After->getSinglePredecessor() == Exit &&
         "After block only reachable from exit block");
  assert(After->empty() || !isa<PHINode>(After->front()));

  Instruction *IndVar = getIndVar();
  assert(IndVar && "Canonical induction variable not found?");
  assert(isa<IntegerType>(IndVar->getType()) &&
         "Induction variable must be an integer");
  assert(cast<PHINode>(IndVar)->getParent() == Header &&
         "Induction variable must be a PHI in the loop header");
  assert(cast<PHINode>(IndVar)->getIncomingBlock(0) == Preheader);
  assert(
      cast<ConstantInt>(cast<PHINode>(IndVar)->getIncomingValue(0))->isZero());
  assert(cast<PHINode>(IndVar)->getIncomingBlock(1) == Latch);

  auto *NextIndVar = cast<PHINode>(IndVar)->getIncomingValue(1);
  assert(cast<Instruction>(NextIndVar)->getParent() == Latch);
  assert(cast<BinaryOperator>(NextIndVar)->getOpcode() == BinaryOperator::Add);
  assert(cast<BinaryOperator>(NextIndVar)->getOperand(0) == IndVar);
  assert(cast<ConstantInt>(cast<BinaryOperator>(NextIndVar)->getOperand(1))
             ->isOne());

  Value *TripCount = getTripCount();
  assert(TripCount && "Loop trip count not found?");
  assert(IndVar->getType() == TripCount->getType() &&
         "Trip count and induction variable must have the same type");

  auto *CmpI = cast<CmpInst>(&Cond->front());
  assert(CmpI->getPredicate() == CmpInst::ICMP_ULT &&
         "Exit condition must be an unsigned less-than comparison");
  assert(CmpI->getOperand(0) == IndVar &&
         "Exit condition must compare the induction variable");
  assert(CmpI->getOperand(1) == TripCount &&
         "Exit condition must compare with the trip count");
#endif
}

void CanonicalLoopInfo::invalidate() {
  Header = nullptr;
  Cond = nullptr;
  Latch = nullptr;
  Exit = nullptr;
}