1 //===- OpenMPIRBuilder.cpp - Builder for LLVM-IR for OpenMP directives ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 ///
10 /// This file implements the OpenMPIRBuilder class, which is used as a
11 /// convenient way to create LLVM instructions for OpenMP directives.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
16 #include "llvm/ADT/StringRef.h"
17 #include "llvm/ADT/Triple.h"
18 #include "llvm/Analysis/AssumptionCache.h"
19 #include "llvm/Analysis/CodeMetrics.h"
20 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
21 #include "llvm/Analysis/ScalarEvolution.h"
22 #include "llvm/Analysis/TargetLibraryInfo.h"
23 #include "llvm/IR/CFG.h"
24 #include "llvm/IR/DebugInfo.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/MDBuilder.h"
27 #include "llvm/IR/PassManager.h"
28 #include "llvm/IR/Value.h"
29 #include "llvm/MC/TargetRegistry.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Error.h"
32 #include "llvm/Target/TargetMachine.h"
33 #include "llvm/Target/TargetOptions.h"
34 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
35 #include "llvm/Transforms/Utils/CodeExtractor.h"
36 #include "llvm/Transforms/Utils/LoopPeel.h"
37 #include "llvm/Transforms/Utils/ModuleUtils.h"
38 #include "llvm/Transforms/Utils/UnrollLoop.h"
39 
40 #include <sstream>
41 
42 #define DEBUG_TYPE "openmp-ir-builder"
43 
44 using namespace llvm;
45 using namespace omp;
46 
// Command-line switch: when enabled, runtime-call declarations receive
// "optimistic" attributes describing how the calls behave "as-if", which can
// enable more aggressive optimization. Hidden and off by default.
static cl::opt<bool>
    OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
                         cl::desc("Use optimistic attributes describing "
                                  "'as-if' properties of runtime calls."),
                         cl::init(false));

// Multiplier applied to the unroll threshold; accounts for later
// simplifications that shrink the unrolled code (default 1.5).
static cl::opt<double> UnrollThresholdFactor(
    "openmp-ir-builder-unroll-threshold-factor", cl::Hidden,
    cl::desc("Factor for the unroll threshold to account for code "
             "simplifications still taking place"),
    cl::init(1.5));
58 
/// Attach the statically known attributes for the runtime function \p FnID to
/// the declaration \p Fn. The per-function attribute sets come from
/// OMPKinds.def; functions without a registered set are left untouched.
void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
  LLVMContext &Ctx = Fn.getContext();

  // Get the function's current attributes.
  auto Attrs = Fn.getAttributes();
  auto FnAttrs = Attrs.getFnAttrs();
  auto RetAttrs = Attrs.getRetAttrs();
  SmallVector<AttributeSet, 4> ArgAttrs;
  for (size_t ArgNo = 0; ArgNo < Fn.arg_size(); ++ArgNo)
    ArgAttrs.emplace_back(Attrs.getParamAttrs(ArgNo));

  // Materialize the named attribute sets declared in OMPKinds.def as locals so
  // the OMP_RTL_ATTRS expansion below can refer to them by name.
#define OMP_ATTRS_SET(VarName, AttrSet) AttributeSet VarName = AttrSet;
#include "llvm/Frontend/OpenMP/OMPKinds.def"

  // Add attributes to the function declaration. Each case merges the
  // statically known function/return/argument attributes into the existing
  // ones and writes the combined list back.
  switch (FnID) {
#define OMP_RTL_ATTRS(Enum, FnAttrSet, RetAttrSet, ArgAttrSets)                \
  case Enum:                                                                   \
    FnAttrs = FnAttrs.addAttributes(Ctx, FnAttrSet);                           \
    RetAttrs = RetAttrs.addAttributes(Ctx, RetAttrSet);                        \
    for (size_t ArgNo = 0; ArgNo < ArgAttrSets.size(); ++ArgNo)                \
      ArgAttrs[ArgNo] =                                                        \
          ArgAttrs[ArgNo].addAttributes(Ctx, ArgAttrSets[ArgNo]);              \
    Fn.setAttributes(AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs));    \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    // Attributes are optional.
    break;
  }
}
90 
/// Return a FunctionCallee for the OpenMP runtime function \p FnID, creating
/// the declaration in \p M if it does not exist yet. The callee is bitcast to
/// the expected function pointer type in case an existing declaration has a
/// mismatching type.
FunctionCallee
OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
  FunctionType *FnTy = nullptr;
  Function *Fn = nullptr;

  // Try to find the declaration in the module first. The switch also computes
  // the expected function type from the OMPKinds.def table.
  switch (FnID) {
#define OMP_RTL(Enum, Str, IsVarArg, ReturnType, ...)                          \
  case Enum:                                                                   \
    FnTy = FunctionType::get(ReturnType, ArrayRef<Type *>{__VA_ARGS__},        \
                             IsVarArg);                                        \
    Fn = M.getFunction(Str);                                                   \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  if (!Fn) {
    // Create a new declaration if we need one.
    switch (FnID) {
#define OMP_RTL(Enum, Str, ...)                                                \
  case Enum:                                                                   \
    Fn = Function::Create(FnTy, GlobalValue::ExternalLinkage, Str, M);         \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
    }

    // Add information if the runtime function takes a callback function.
    if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
      if (!Fn->hasMetadata(LLVMContext::MD_callback)) {
        LLVMContext &Ctx = Fn->getContext();
        MDBuilder MDB(Ctx);
        // Annotate the callback behavior of the runtime function:
        //  - The callback callee is argument number 2 (microtask).
        //  - The first two arguments of the callback callee are unknown (-1).
        //  - All variadic arguments to the runtime function are passed to the
        //    callback callee.
        Fn->addMetadata(
            LLVMContext::MD_callback,
            *MDNode::get(Ctx, {MDB.createCallbackEncoding(
                                  2, {-1, -1}, /* VarArgsArePassed */ true)}));
      }
    }

    LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
    // Newly created declarations get the statically known attributes.
    addAttributes(FnID, *Fn);

  } else {
    LLVM_DEBUG(dbgs() << "Found OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
  }

  assert(Fn && "Failed to create OpenMP runtime function");

  // Cast the function to the expected type if necessary.
  Constant *C = ConstantExpr::getBitCast(Fn, FnTy->getPointerTo());
  return {FnTy, C};
}
149 
getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID)150 Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
151   FunctionCallee RTLFn = getOrCreateRuntimeFunction(M, FnID);
152   auto *Fn = dyn_cast<llvm::Function>(RTLFn.getCallee());
153   assert(Fn && "Failed to create OpenMP runtime function pointer");
154   return Fn;
155 }
156 
initialize()157 void OpenMPIRBuilder::initialize() { initializeTypes(M); }
158 
/// Perform the deferred outlining for all regions recorded in OutlineInfos.
/// If \p Fn is non-null, only regions inside \p Fn are outlined now; the rest
/// are kept for a later finalize() call (this happens with nested function
/// generation). \p AllowExtractorSinking permits moving instructions the
/// CodeExtractor sank into its artificial entry block into the region entry.
void OpenMPIRBuilder::finalize(Function *Fn, bool AllowExtractorSinking) {
  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  SmallVector<OutlineInfo, 16> DeferredOutlines;
  for (OutlineInfo &OI : OutlineInfos) {
    // Skip functions that have not finalized yet; may happen with nested
    // function generation.
    if (Fn && OI.getFunction() != Fn) {
      DeferredOutlines.push_back(OI);
      continue;
    }

    ParallelRegionBlockSet.clear();
    Blocks.clear();
    OI.collectBlocks(ParallelRegionBlockSet, Blocks);

    Function *OuterFn = OI.getFunction();
    CodeExtractorAnalysisCache CEAC(*OuterFn);
    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                            /* AggregateArgs */ false,
                            /* BlockFrequencyInfo */ nullptr,
                            /* BranchProbabilityInfo */ nullptr,
                            /* AssumptionCache */ nullptr,
                            /* AllowVarArgs */ true,
                            /* AllowAlloca */ true,
                            /* Suffix */ ".omp_par");

    LLVM_DEBUG(dbgs() << "Before     outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
                      << " Exit: " << OI.ExitBB->getName() << "\n");
    assert(Extractor.isEligible() &&
           "Expected OpenMP outlining to be possible!");

    Function *OutlinedFn = Extractor.extractCodeRegion(CEAC);

    LLVM_DEBUG(dbgs() << "After      outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "   Outlined function: " << *OutlinedFn << "\n");
    assert(OutlinedFn->getReturnType()->isVoidTy() &&
           "OpenMP outlined functions should not return a value!");

    // For compatibility with the clang CG we move the outlined function after
    // the one with the parallel region.
    OutlinedFn->removeFromParent();
    M.getFunctionList().insertAfter(OuterFn->getIterator(), OutlinedFn);

    // Remove the artificial entry introduced by the extractor right away, we
    // made our own entry block after all.
    {
      BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
      assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
      assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
      if (AllowExtractorSinking) {
        // Move instructions from the to-be-deleted ArtificialEntry to the entry
        // basic block of the parallel region. CodeExtractor may have sunk
        // allocas/bitcasts for values that are solely used in the outlined
        // region and do not escape.
        assert(!ArtificialEntry.empty() &&
               "Expected instructions to sink in the outlined region");
        // Iterate with a pre-advanced iterator: moveBefore invalidates the
        // current position.
        for (BasicBlock::iterator It = ArtificialEntry.begin(),
                                  End = ArtificialEntry.end();
             It != End;) {
          Instruction &I = *It;
          It++;

          // The terminator is deleted together with ArtificialEntry below.
          if (I.isTerminator())
            continue;

          I.moveBefore(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
        }
      }
      OI.EntryBB->moveBefore(&ArtificialEntry);
      ArtificialEntry.eraseFromParent();
    }
    assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
    assert(OutlinedFn && OutlinedFn->getNumUses() == 1);

    // Run a user callback, e.g. to add attributes.
    if (OI.PostOutlineCB)
      OI.PostOutlineCB(*OutlinedFn);
  }

  // Remove work items that have been completed.
  OutlineInfos = std::move(DeferredOutlines);
}
243 
// The destructor only checks an invariant: finalize() must have consumed all
// pending outline requests before the builder is destroyed.
OpenMPIRBuilder::~OpenMPIRBuilder() {
  assert(OutlineInfos.empty() && "There must be no outstanding outlinings");
}
247 
createGlobalFlag(unsigned Value,StringRef Name)248 GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
249   IntegerType *I32Ty = Type::getInt32Ty(M.getContext());
250   auto *GV =
251       new GlobalVariable(M, I32Ty,
252                          /* isConstant = */ true, GlobalValue::WeakODRLinkage,
253                          ConstantInt::get(I32Ty, Value), Name);
254 
255   return GV;
256 }
257 
getOrCreateIdent(Constant * SrcLocStr,IdentFlag LocFlags,unsigned Reserve2Flags)258 Value *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
259                                          IdentFlag LocFlags,
260                                          unsigned Reserve2Flags) {
261   // Enable "C-mode".
262   LocFlags |= OMP_IDENT_FLAG_KMPC;
263 
264   Value *&Ident =
265       IdentMap[{SrcLocStr, uint64_t(LocFlags) << 31 | Reserve2Flags}];
266   if (!Ident) {
267     Constant *I32Null = ConstantInt::getNullValue(Int32);
268     Constant *IdentData[] = {
269         I32Null, ConstantInt::get(Int32, uint32_t(LocFlags)),
270         ConstantInt::get(Int32, Reserve2Flags), I32Null, SrcLocStr};
271     Constant *Initializer =
272         ConstantStruct::get(OpenMPIRBuilder::Ident, IdentData);
273 
274     // Look for existing encoding of the location + flags, not needed but
275     // minimizes the difference to the existing solution while we transition.
276     for (GlobalVariable &GV : M.getGlobalList())
277       if (GV.getValueType() == OpenMPIRBuilder::Ident && GV.hasInitializer())
278         if (GV.getInitializer() == Initializer)
279           Ident = &GV;
280 
281     if (!Ident) {
282       auto *GV = new GlobalVariable(
283           M, OpenMPIRBuilder::Ident,
284           /* isConstant = */ true, GlobalValue::PrivateLinkage, Initializer, "",
285           nullptr, GlobalValue::NotThreadLocal,
286           M.getDataLayout().getDefaultGlobalsAddressSpace());
287       GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
288       GV->setAlignment(Align(8));
289       Ident = GV;
290     }
291   }
292 
293   return Builder.CreatePointerCast(Ident, IdentPtr);
294 }
295 
getOrCreateSrcLocStr(StringRef LocStr)296 Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef LocStr) {
297   Constant *&SrcLocStr = SrcLocStrMap[LocStr];
298   if (!SrcLocStr) {
299     Constant *Initializer =
300         ConstantDataArray::getString(M.getContext(), LocStr);
301 
302     // Look for existing encoding of the location, not needed but minimizes the
303     // difference to the existing solution while we transition.
304     for (GlobalVariable &GV : M.getGlobalList())
305       if (GV.isConstant() && GV.hasInitializer() &&
306           GV.getInitializer() == Initializer)
307         return SrcLocStr = ConstantExpr::getPointerCast(&GV, Int8Ptr);
308 
309     SrcLocStr = Builder.CreateGlobalStringPtr(LocStr, /* Name */ "",
310                                               /* AddressSpace */ 0, &M);
311   }
312   return SrcLocStr;
313 }
314 
getOrCreateSrcLocStr(StringRef FunctionName,StringRef FileName,unsigned Line,unsigned Column)315 Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef FunctionName,
316                                                 StringRef FileName,
317                                                 unsigned Line,
318                                                 unsigned Column) {
319   SmallString<128> Buffer;
320   Buffer.push_back(';');
321   Buffer.append(FileName);
322   Buffer.push_back(';');
323   Buffer.append(FunctionName);
324   Buffer.push_back(';');
325   Buffer.append(std::to_string(Line));
326   Buffer.push_back(';');
327   Buffer.append(std::to_string(Column));
328   Buffer.push_back(';');
329   Buffer.push_back(';');
330   return getOrCreateSrcLocStr(Buffer.str());
331 }
332 
// Fallback location string used when no debug information is available.
Constant *OpenMPIRBuilder::getOrCreateDefaultSrcLocStr() {
  return getOrCreateSrcLocStr(";unknown;unknown;0;0;;");
}
336 
getOrCreateSrcLocStr(DebugLoc DL,Function * F)337 Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(DebugLoc DL, Function *F) {
338   DILocation *DIL = DL.get();
339   if (!DIL)
340     return getOrCreateDefaultSrcLocStr();
341   StringRef FileName = M.getName();
342   if (DIFile *DIF = DIL->getFile())
343     if (Optional<StringRef> Source = DIF->getSource())
344       FileName = *Source;
345   StringRef Function = DIL->getScope()->getSubprogram()->getName();
346   if (Function.empty() && F)
347     Function = F->getName();
348   return getOrCreateSrcLocStr(Function, FileName, DIL->getLine(),
349                               DIL->getColumn());
350 }
351 
// Convenience overload: extract the debug location and enclosing function
// from a LocationDescription.
Constant *
OpenMPIRBuilder::getOrCreateSrcLocStr(const LocationDescription &Loc) {
  return getOrCreateSrcLocStr(Loc.DL, Loc.IP.getBlock()->getParent());
}
356 
getOrCreateThreadID(Value * Ident)357 Value *OpenMPIRBuilder::getOrCreateThreadID(Value *Ident) {
358   return Builder.CreateCall(
359       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num), Ident,
360       "omp_global_thread_num");
361 }
362 
/// Public entry point for emitting a barrier for directive kind \p DK at
/// \p Loc. Returns the (unchanged) insertion point when \p Loc is unusable.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createBarrier(const LocationDescription &Loc, Directive DK,
                               bool ForceSimpleCall, bool CheckCancelFlag) {
  if (!updateToLocation(Loc))
    return Loc.IP;
  return emitBarrierImpl(Loc, DK, ForceSimpleCall, CheckCancelFlag);
}
370 
/// Emit the actual barrier runtime call. Picks the ident flag matching the
/// directive \p Kind and, inside a cancellable parallel region (unless
/// \p ForceSimpleCall), emits __kmpc_cancel_barrier plus an optional
/// cancellation check instead of the plain __kmpc_barrier.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::emitBarrierImpl(const LocationDescription &Loc, Directive Kind,
                                 bool ForceSimpleCall, bool CheckCancelFlag) {
  // Build call __kmpc_cancel_barrier(loc, thread_id) or
  //            __kmpc_barrier(loc, thread_id);

  IdentFlag BarrierLocFlags;
  switch (Kind) {
  case OMPD_for:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_FOR;
    break;
  case OMPD_sections:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SECTIONS;
    break;
  case OMPD_single:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SINGLE;
    break;
  case OMPD_barrier:
    // An explicit `#pragma omp barrier`.
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_EXPL;
    break;
  default:
    // Any other directive gets the generic implicit-barrier flag.
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL;
    break;
  }

  // Note: the barrier call gets the flag-annotated ident while the thread-id
  // query uses a default-flag ident for the same location.
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
  Value *Args[] = {getOrCreateIdent(SrcLocStr, BarrierLocFlags),
                   getOrCreateThreadID(getOrCreateIdent(SrcLocStr))};

  // If we are in a cancellable parallel region, barriers are cancellation
  // points.
  // TODO: Check why we would force simple calls or to ignore the cancel flag.
  bool UseCancelBarrier =
      !ForceSimpleCall && isLastFinalizationInfoCancellable(OMPD_parallel);

  Value *Result =
      Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
                             UseCancelBarrier ? OMPRTL___kmpc_cancel_barrier
                                              : OMPRTL___kmpc_barrier),
                         Args);

  // A cancel barrier returns a flag that must be checked to branch to the
  // region's finalization code when cancellation was requested.
  if (UseCancelBarrier && CheckCancelFlag)
    emitCancelationCheckImpl(Result, OMPD_parallel);

  return Builder.saveIP();
}
417 
/// Emit a `#pragma omp cancel` for \p CanceledDirective, optionally guarded by
/// \p IfCondition. Emits a __kmpc_cancel call followed by the shared
/// cancellation-check control flow; for parallel regions a cancellation
/// barrier is emitted on the exit path.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createCancel(const LocationDescription &Loc,
                              Value *IfCondition,
                              omp::Directive CanceledDirective) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // LLVM utilities like blocks with terminators.
  auto *UI = Builder.CreateUnreachable();

  // With an `if` clause, only the "then" branch performs the cancel; the
  // "else" branch falls through untouched.
  Instruction *ThenTI = UI, *ElseTI = nullptr;
  if (IfCondition)
    SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
  Builder.SetInsertPoint(ThenTI);

  // Map the directive to the runtime's cancel-kind constant (OMPKinds.def).
  Value *CancelKind = nullptr;
  switch (CanceledDirective) {
#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
  case DirectiveEnum:                                                          \
    CancelKind = Builder.getInt32(Value);                                      \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    llvm_unreachable("Unknown cancel kind!");
  }

  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
  Value *Ident = getOrCreateIdent(SrcLocStr);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
  Value *Result = Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancel), Args);
  // On the cancellation path of a parallel region, threads must still meet at
  // a barrier before leaving; emit it via this exit callback.
  auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) {
    if (CanceledDirective == OMPD_parallel) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                    omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
                    /* CheckCancelFlag */ false);
    }
  };

  // The actual cancel logic is shared with others, e.g., cancel_barriers.
  emitCancelationCheckImpl(Result, CanceledDirective, ExitCB);

  // Update the insertion point and remove the terminator we introduced.
  Builder.SetInsertPoint(UI->getParent());
  UI->eraseFromParent();

  return Builder.saveIP();
}
468 
/// Emit the control flow that checks a runtime cancellation flag
/// (\p CancelFlag): branch to a ".cont" block when the flag is zero, or to a
/// ".cncl" block that runs \p ExitCB (if any) and the innermost finalization
/// callback when cancellation was requested. Leaves the builder positioned at
/// the start of the continuation block.
void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
                                               omp::Directive CanceledDirective,
                                               FinalizeCallbackTy ExitCB) {
  assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
         "Unexpected cancellation!");

  // For a cancel barrier we create two new blocks.
  BasicBlock *BB = Builder.GetInsertBlock();
  BasicBlock *NonCancellationBlock;
  if (Builder.GetInsertPoint() == BB->end()) {
    // TODO: This branch will not be needed once we moved to the
    // OpenMPIRBuilder codegen completely.
    NonCancellationBlock = BasicBlock::Create(
        BB->getContext(), BB->getName() + ".cont", BB->getParent());
  } else {
    // Split at the insertion point and drop the unconditional branch the
    // split created; the conditional branch below replaces it.
    NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint());
    BB->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(BB);
  }
  BasicBlock *CancellationBlock = BasicBlock::Create(
      BB->getContext(), BB->getName() + ".cncl", BB->getParent());

  // Jump to them based on the return value.
  Value *Cmp = Builder.CreateIsNull(CancelFlag);
  Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock,
                       /* TODO weight */ nullptr, nullptr);

  // From the cancellation block we finalize all variables and go to the
  // post finalization block that is known to the FiniCB callback.
  Builder.SetInsertPoint(CancellationBlock);
  if (ExitCB)
    ExitCB(Builder.saveIP());
  auto &FI = FinalizationStack.back();
  FI.FiniCB(Builder.saveIP());

  // The continuation block is where code generation continues.
  Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
}
507 
createParallel(const LocationDescription & Loc,InsertPointTy OuterAllocaIP,BodyGenCallbackTy BodyGenCB,PrivatizeCallbackTy PrivCB,FinalizeCallbackTy FiniCB,Value * IfCondition,Value * NumThreads,omp::ProcBindKind ProcBind,bool IsCancellable)508 IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
509     const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
510     BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
511     FinalizeCallbackTy FiniCB, Value *IfCondition, Value *NumThreads,
512     omp::ProcBindKind ProcBind, bool IsCancellable) {
513   if (!updateToLocation(Loc))
514     return Loc.IP;
515 
516   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
517   Value *Ident = getOrCreateIdent(SrcLocStr);
518   Value *ThreadID = getOrCreateThreadID(Ident);
519 
520   if (NumThreads) {
521     // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
522     Value *Args[] = {
523         Ident, ThreadID,
524         Builder.CreateIntCast(NumThreads, Int32, /*isSigned*/ false)};
525     Builder.CreateCall(
526         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_threads), Args);
527   }
528 
529   if (ProcBind != OMP_PROC_BIND_default) {
530     // Build call __kmpc_push_proc_bind(&Ident, global_tid, proc_bind)
531     Value *Args[] = {
532         Ident, ThreadID,
533         ConstantInt::get(Int32, unsigned(ProcBind), /*isSigned=*/true)};
534     Builder.CreateCall(
535         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_proc_bind), Args);
536   }
537 
538   BasicBlock *InsertBB = Builder.GetInsertBlock();
539   Function *OuterFn = InsertBB->getParent();
540 
541   // Save the outer alloca block because the insertion iterator may get
542   // invalidated and we still need this later.
543   BasicBlock *OuterAllocaBlock = OuterAllocaIP.getBlock();
544 
545   // Vector to remember instructions we used only during the modeling but which
546   // we want to delete at the end.
547   SmallVector<Instruction *, 4> ToBeDeleted;
548 
549   // Change the location to the outer alloca insertion point to create and
550   // initialize the allocas we pass into the parallel region.
551   Builder.restoreIP(OuterAllocaIP);
552   AllocaInst *TIDAddr = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
553   AllocaInst *ZeroAddr = Builder.CreateAlloca(Int32, nullptr, "zero.addr");
554 
555   // If there is an if condition we actually use the TIDAddr and ZeroAddr in the
556   // program, otherwise we only need them for modeling purposes to get the
557   // associated arguments in the outlined function. In the former case,
558   // initialize the allocas properly, in the latter case, delete them later.
559   if (IfCondition) {
560     Builder.CreateStore(Constant::getNullValue(Int32), TIDAddr);
561     Builder.CreateStore(Constant::getNullValue(Int32), ZeroAddr);
562   } else {
563     ToBeDeleted.push_back(TIDAddr);
564     ToBeDeleted.push_back(ZeroAddr);
565   }
566 
567   // Create an artificial insertion point that will also ensure the blocks we
568   // are about to split are not degenerated.
569   auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);
570 
571   Instruction *ThenTI = UI, *ElseTI = nullptr;
572   if (IfCondition)
573     SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
574 
575   BasicBlock *ThenBB = ThenTI->getParent();
576   BasicBlock *PRegEntryBB = ThenBB->splitBasicBlock(ThenTI, "omp.par.entry");
577   BasicBlock *PRegBodyBB =
578       PRegEntryBB->splitBasicBlock(ThenTI, "omp.par.region");
579   BasicBlock *PRegPreFiniBB =
580       PRegBodyBB->splitBasicBlock(ThenTI, "omp.par.pre_finalize");
581   BasicBlock *PRegExitBB =
582       PRegPreFiniBB->splitBasicBlock(ThenTI, "omp.par.exit");
583 
584   auto FiniCBWrapper = [&](InsertPointTy IP) {
585     // Hide "open-ended" blocks from the given FiniCB by setting the right jump
586     // target to the region exit block.
587     if (IP.getBlock()->end() == IP.getPoint()) {
588       IRBuilder<>::InsertPointGuard IPG(Builder);
589       Builder.restoreIP(IP);
590       Instruction *I = Builder.CreateBr(PRegExitBB);
591       IP = InsertPointTy(I->getParent(), I->getIterator());
592     }
593     assert(IP.getBlock()->getTerminator()->getNumSuccessors() == 1 &&
594            IP.getBlock()->getTerminator()->getSuccessor(0) == PRegExitBB &&
595            "Unexpected insertion point for finalization call!");
596     return FiniCB(IP);
597   };
598 
599   FinalizationStack.push_back({FiniCBWrapper, OMPD_parallel, IsCancellable});
600 
601   // Generate the privatization allocas in the block that will become the entry
602   // of the outlined function.
603   Builder.SetInsertPoint(PRegEntryBB->getTerminator());
604   InsertPointTy InnerAllocaIP = Builder.saveIP();
605 
606   AllocaInst *PrivTIDAddr =
607       Builder.CreateAlloca(Int32, nullptr, "tid.addr.local");
608   Instruction *PrivTID = Builder.CreateLoad(Int32, PrivTIDAddr, "tid");
609 
610   // Add some fake uses for OpenMP provided arguments.
611   ToBeDeleted.push_back(Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use"));
612   Instruction *ZeroAddrUse =
613       Builder.CreateLoad(Int32, ZeroAddr, "zero.addr.use");
614   ToBeDeleted.push_back(ZeroAddrUse);
615 
616   // ThenBB
617   //   |
618   //   V
619   // PRegionEntryBB         <- Privatization allocas are placed here.
620   //   |
621   //   V
622   // PRegionBodyBB          <- BodeGen is invoked here.
623   //   |
624   //   V
625   // PRegPreFiniBB          <- The block we will start finalization from.
626   //   |
627   //   V
628   // PRegionExitBB          <- A common exit to simplify block collection.
629   //
630 
631   LLVM_DEBUG(dbgs() << "Before body codegen: " << *OuterFn << "\n");
632 
633   // Let the caller create the body.
634   assert(BodyGenCB && "Expected body generation callback!");
635   InsertPointTy CodeGenIP(PRegBodyBB, PRegBodyBB->begin());
636   BodyGenCB(InnerAllocaIP, CodeGenIP, *PRegPreFiniBB);
637 
638   LLVM_DEBUG(dbgs() << "After  body codegen: " << *OuterFn << "\n");
639 
640   FunctionCallee RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
641   if (auto *F = dyn_cast<llvm::Function>(RTLFn.getCallee())) {
642     if (!F->hasMetadata(llvm::LLVMContext::MD_callback)) {
643       llvm::LLVMContext &Ctx = F->getContext();
644       MDBuilder MDB(Ctx);
645       // Annotate the callback behavior of the __kmpc_fork_call:
646       //  - The callback callee is argument number 2 (microtask).
647       //  - The first two arguments of the callback callee are unknown (-1).
648       //  - All variadic arguments to the __kmpc_fork_call are passed to the
649       //    callback callee.
650       F->addMetadata(
651           llvm::LLVMContext::MD_callback,
652           *llvm::MDNode::get(
653               Ctx, {MDB.createCallbackEncoding(2, {-1, -1},
654                                                /* VarArgsArePassed */ true)}));
655     }
656   }
657 
658   OutlineInfo OI;
659   OI.PostOutlineCB = [=](Function &OutlinedFn) {
660     // Add some known attributes.
661     OutlinedFn.addParamAttr(0, Attribute::NoAlias);
662     OutlinedFn.addParamAttr(1, Attribute::NoAlias);
663     OutlinedFn.addFnAttr(Attribute::NoUnwind);
664     OutlinedFn.addFnAttr(Attribute::NoRecurse);
665 
666     assert(OutlinedFn.arg_size() >= 2 &&
667            "Expected at least tid and bounded tid as arguments");
668     unsigned NumCapturedVars =
669         OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
670 
671     CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
672     CI->getParent()->setName("omp_parallel");
673     Builder.SetInsertPoint(CI);
674 
675     // Build call __kmpc_fork_call(Ident, n, microtask, var1, .., varn);
676     Value *ForkCallArgs[] = {
677         Ident, Builder.getInt32(NumCapturedVars),
678         Builder.CreateBitCast(&OutlinedFn, ParallelTaskPtr)};
679 
680     SmallVector<Value *, 16> RealArgs;
681     RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
682     RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());
683 
684     Builder.CreateCall(RTLFn, RealArgs);
685 
686     LLVM_DEBUG(dbgs() << "With fork_call placed: "
687                       << *Builder.GetInsertBlock()->getParent() << "\n");
688 
689     InsertPointTy ExitIP(PRegExitBB, PRegExitBB->end());
690 
691     // Initialize the local TID stack location with the argument value.
692     Builder.SetInsertPoint(PrivTID);
693     Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
694     Builder.CreateStore(Builder.CreateLoad(Int32, OutlinedAI), PrivTIDAddr);
695 
696     // If no "if" clause was present we do not need the call created during
697     // outlining, otherwise we reuse it in the serialized parallel region.
698     if (!ElseTI) {
699       CI->eraseFromParent();
700     } else {
701 
702       // If an "if" clause was present we are now generating the serialized
703       // version into the "else" branch.
704       Builder.SetInsertPoint(ElseTI);
705 
706       // Build calls __kmpc_serialized_parallel(&Ident, GTid);
707       Value *SerializedParallelCallArgs[] = {Ident, ThreadID};
708       Builder.CreateCall(
709           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_serialized_parallel),
710           SerializedParallelCallArgs);
711 
712       // OutlinedFn(&GTid, &zero, CapturedStruct);
713       CI->removeFromParent();
714       Builder.Insert(CI);
715 
716       // __kmpc_end_serialized_parallel(&Ident, GTid);
717       Value *EndArgs[] = {Ident, ThreadID};
718       Builder.CreateCall(
719           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_serialized_parallel),
720           EndArgs);
721 
722       LLVM_DEBUG(dbgs() << "With serialized parallel region: "
723                         << *Builder.GetInsertBlock()->getParent() << "\n");
724     }
725 
726     for (Instruction *I : ToBeDeleted)
727       I->eraseFromParent();
728   };
729 
730   // Adjust the finalization stack, verify the adjustment, and call the
731   // finalize function a last time to finalize values between the pre-fini
732   // block and the exit block if we left the parallel "the normal way".
733   auto FiniInfo = FinalizationStack.pop_back_val();
734   (void)FiniInfo;
735   assert(FiniInfo.DK == OMPD_parallel &&
736          "Unexpected finalization stack state!");
737 
738   Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();
739 
740   InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
741   FiniCB(PreFiniIP);
742 
743   OI.EntryBB = PRegEntryBB;
744   OI.ExitBB = PRegExitBB;
745 
746   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
747   SmallVector<BasicBlock *, 32> Blocks;
748   OI.collectBlocks(ParallelRegionBlockSet, Blocks);
749 
750   // Ensure a single exit node for the outlined region by creating one.
751   // We might have multiple incoming edges to the exit now due to finalizations,
752   // e.g., cancel calls that cause the control flow to leave the region.
753   BasicBlock *PRegOutlinedExitBB = PRegExitBB;
754   PRegExitBB = SplitBlock(PRegExitBB, &*PRegExitBB->getFirstInsertionPt());
755   PRegOutlinedExitBB->setName("omp.par.outlined.exit");
756   Blocks.push_back(PRegOutlinedExitBB);
757 
758   CodeExtractorAnalysisCache CEAC(*OuterFn);
759   CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
760                           /* AggregateArgs */ false,
761                           /* BlockFrequencyInfo */ nullptr,
762                           /* BranchProbabilityInfo */ nullptr,
763                           /* AssumptionCache */ nullptr,
764                           /* AllowVarArgs */ true,
765                           /* AllowAlloca */ true,
766                           /* Suffix */ ".omp_par");
767 
768   // Find inputs to, outputs from the code region.
769   BasicBlock *CommonExit = nullptr;
770   SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
771   Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
772   Extractor.findInputsOutputs(Inputs, Outputs, SinkingCands);
773 
774   LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");
775 
776   FunctionCallee TIDRTLFn =
777       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);
778 
779   auto PrivHelper = [&](Value &V) {
780     if (&V == TIDAddr || &V == ZeroAddr)
781       return;
782 
783     SetVector<Use *> Uses;
784     for (Use &U : V.uses())
785       if (auto *UserI = dyn_cast<Instruction>(U.getUser()))
786         if (ParallelRegionBlockSet.count(UserI->getParent()))
787           Uses.insert(&U);
788 
789     // __kmpc_fork_call expects extra arguments as pointers. If the input
790     // already has a pointer type, everything is fine. Otherwise, store the
791     // value onto stack and load it back inside the to-be-outlined region. This
792     // will ensure only the pointer will be passed to the function.
793     // FIXME: if there are more than 15 trailing arguments, they must be
794     // additionally packed in a struct.
795     Value *Inner = &V;
796     if (!V.getType()->isPointerTy()) {
797       IRBuilder<>::InsertPointGuard Guard(Builder);
798       LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");
799 
800       Builder.restoreIP(OuterAllocaIP);
801       Value *Ptr =
802           Builder.CreateAlloca(V.getType(), nullptr, V.getName() + ".reloaded");
803 
804       // Store to stack at end of the block that currently branches to the entry
805       // block of the to-be-outlined region.
806       Builder.SetInsertPoint(InsertBB,
807                              InsertBB->getTerminator()->getIterator());
808       Builder.CreateStore(&V, Ptr);
809 
810       // Load back next to allocations in the to-be-outlined region.
811       Builder.restoreIP(InnerAllocaIP);
812       Inner = Builder.CreateLoad(V.getType(), Ptr);
813     }
814 
815     Value *ReplacementValue = nullptr;
816     CallInst *CI = dyn_cast<CallInst>(&V);
817     if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
818       ReplacementValue = PrivTID;
819     } else {
820       Builder.restoreIP(
821           PrivCB(InnerAllocaIP, Builder.saveIP(), V, *Inner, ReplacementValue));
822       assert(ReplacementValue &&
823              "Expected copy/create callback to set replacement value!");
824       if (ReplacementValue == &V)
825         return;
826     }
827 
828     for (Use *UPtr : Uses)
829       UPtr->set(ReplacementValue);
830   };
831 
832   // Reset the inner alloca insertion as it will be used for loading the values
833   // wrapped into pointers before passing them into the to-be-outlined region.
834   // Configure it to insert immediately after the fake use of zero address so
835   // that they are available in the generated body and so that the
836   // OpenMP-related values (thread ID and zero address pointers) remain leading
837   // in the argument list.
838   InnerAllocaIP = IRBuilder<>::InsertPoint(
839       ZeroAddrUse->getParent(), ZeroAddrUse->getNextNode()->getIterator());
840 
841   // Reset the outer alloca insertion point to the entry of the relevant block
842   // in case it was invalidated.
843   OuterAllocaIP = IRBuilder<>::InsertPoint(
844       OuterAllocaBlock, OuterAllocaBlock->getFirstInsertionPt());
845 
846   for (Value *Input : Inputs) {
847     LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
848     PrivHelper(*Input);
849   }
850   LLVM_DEBUG({
851     for (Value *Output : Outputs)
852       LLVM_DEBUG(dbgs() << "Captured output: " << *Output << "\n");
853   });
854   assert(Outputs.empty() &&
855          "OpenMP outlining should not produce live-out values!");
856 
857   LLVM_DEBUG(dbgs() << "After  privatization: " << *OuterFn << "\n");
858   LLVM_DEBUG({
859     for (auto *BB : Blocks)
860       dbgs() << " PBR: " << BB->getName() << "\n";
861   });
862 
863   // Register the outlined info.
864   addOutlineInfo(std::move(OI));
865 
866   InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
867   UI->eraseFromParent();
868 
869   return AfterIP;
870 }
871 
emitFlush(const LocationDescription & Loc)872 void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
873   // Build call void __kmpc_flush(ident_t *loc)
874   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
875   Value *Args[] = {getOrCreateIdent(SrcLocStr)};
876 
877   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_flush), Args);
878 }
879 
createFlush(const LocationDescription & Loc)880 void OpenMPIRBuilder::createFlush(const LocationDescription &Loc) {
881   if (!updateToLocation(Loc))
882     return;
883   emitFlush(Loc);
884 }
885 
emitTaskwaitImpl(const LocationDescription & Loc)886 void OpenMPIRBuilder::emitTaskwaitImpl(const LocationDescription &Loc) {
887   // Build call kmp_int32 __kmpc_omp_taskwait(ident_t *loc, kmp_int32
888   // global_tid);
889   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
890   Value *Ident = getOrCreateIdent(SrcLocStr);
891   Value *Args[] = {Ident, getOrCreateThreadID(Ident)};
892 
893   // Ignore return result until untied tasks are supported.
894   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskwait),
895                      Args);
896 }
897 
createTaskwait(const LocationDescription & Loc)898 void OpenMPIRBuilder::createTaskwait(const LocationDescription &Loc) {
899   if (!updateToLocation(Loc))
900     return;
901   emitTaskwaitImpl(Loc);
902 }
903 
emitTaskyieldImpl(const LocationDescription & Loc)904 void OpenMPIRBuilder::emitTaskyieldImpl(const LocationDescription &Loc) {
905   // Build call __kmpc_omp_taskyield(loc, thread_id, 0);
906   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
907   Value *Ident = getOrCreateIdent(SrcLocStr);
908   Constant *I32Null = ConstantInt::getNullValue(Int32);
909   Value *Args[] = {Ident, getOrCreateThreadID(Ident), I32Null};
910 
911   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskyield),
912                      Args);
913 }
914 
createTaskyield(const LocationDescription & Loc)915 void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
916   if (!updateToLocation(Loc))
917     return;
918   emitTaskyieldImpl(Loc);
919 }
920 
// Emit an OpenMP 'sections' construct: each section body from \p SectionCBs
// becomes one case of a switch on the induction variable of a canonical loop,
// which is then lowered as a statically scheduled workshare loop.
// NOTE(review): \p PrivCB and \p IsNowait are not referenced in this body —
// confirm whether they are intended to be wired through.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSections(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
    FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // Wrap the user finalization callback so it can also be invoked from a
  // cancellation block that has had its terminator removed.
  auto FiniCBWrapper = [&](InsertPointTy IP) {
    // If the insertion point is not at the end of its block, the block still
    // has a terminator and the callback can run as-is.
    if (IP.getBlock()->end() != IP.getPoint())
      return FiniCB(IP);
    // This must be done otherwise any nested constructs using FinalizeOMPRegion
    // will fail because that function requires the Finalization Basic Block to
    // have a terminator, which is already removed by EmitOMPRegionBody.
    // IP is currently at cancelation block.
    // We need to backtrack to the condition block to fetch
    // the exit block and create a branch from cancelation
    // to exit block.
    IRBuilder<>::InsertPointGuard IPG(Builder);
    Builder.restoreIP(IP);
    // Walk case block -> (two predecessors up) condition block, then take the
    // condition's second successor as the construct's exit block.
    auto *CaseBB = IP.getBlock()->getSinglePredecessor();
    auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
    auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
    Instruction *I = Builder.CreateBr(ExitBB);
    // Re-point IP at the freshly created branch so FiniCB emits before it.
    IP = InsertPointTy(I->getParent(), I->getIterator());
    return FiniCB(IP);
  };

  FinalizationStack.push_back({FiniCBWrapper, OMPD_sections, IsCancellable});

  // Each section is emitted as a switch case
  // Each finalization callback is handled from clang.EmitOMPSectionDirective()
  // -> OMP.createSection() which generates the IR for each section
  // Iterate through all sections and emit a switch construct:
  // switch (IV) {
  //   case 0:
  //     <SectionStmt[0]>;
  //     break;
  // ...
  //   case <NumSection> - 1:
  //     <SectionStmt[<NumSection> - 1]>;
  //     break;
  // }
  // ...
  // section_loop.after:
  // <FiniCB>;
  auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, Value *IndVar) {
    auto *CurFn = CodeGenIP.getBlock()->getParent();
    // The loop body's single successor is the increment block; the default
    // switch destination falls through to it (no matching section index).
    auto *ForIncBB = CodeGenIP.getBlock()->getSingleSuccessor();
    auto *ForExitBB = CodeGenIP.getBlock()
                          ->getSinglePredecessor()
                          ->getTerminator()
                          ->getSuccessor(1);
    SwitchInst *SwitchStmt = Builder.CreateSwitch(IndVar, ForIncBB);
    Builder.restoreIP(CodeGenIP);
    unsigned CaseNumber = 0;
    for (auto SectionCB : SectionCBs) {
      // One fresh case block per section; the section callback fills it in.
      auto *CaseBB = BasicBlock::Create(M.getContext(),
                                        "omp_section_loop.body.case", CurFn);
      SwitchStmt->addCase(Builder.getInt32(CaseNumber), CaseBB);
      Builder.SetInsertPoint(CaseBB);
      SectionCB(InsertPointTy(), Builder.saveIP(), *ForExitBB);
      CaseNumber++;
    }
    // remove the existing terminator from body BB since there can be no
    // terminators after switch/case
    CodeGenIP.getBlock()->getTerminator()->eraseFromParent();
  };
  // Loop body ends here
  // LowerBound, UpperBound, and STride for createCanonicalLoop
  Type *I32Ty = Type::getInt32Ty(M.getContext());
  Value *LB = ConstantInt::get(I32Ty, 0);
  Value *UB = ConstantInt::get(I32Ty, SectionCBs.size());
  Value *ST = ConstantInt::get(I32Ty, 1);
  llvm::CanonicalLoopInfo *LoopInfo = createCanonicalLoop(
      Loc, LoopBodyGenCB, LB, UB, ST, true, false, AllocaIP, "section_loop");
  InsertPointTy AfterIP =
      applyStaticWorkshareLoop(Loc.DL, LoopInfo, AllocaIP, true);
  BasicBlock *LoopAfterBB = AfterIP.getBlock();
  // Ensure the after-block ends in a branch we can split at; synthesize an
  // unreachable terminator if none exists.
  Instruction *SplitPos = LoopAfterBB->getTerminator();
  if (!isa_and_nonnull<BranchInst>(SplitPos))
    SplitPos = new UnreachableInst(Builder.getContext(), LoopAfterBB);
  // ExitBB after LoopAfterBB because LoopAfterBB is used for FinalizationCB,
  // which requires a BB with branch
  BasicBlock *ExitBB =
      LoopAfterBB->splitBasicBlock(SplitPos, "omp_sections.end");
  SplitPos->eraseFromParent();

  // Apply the finalization callback in LoopAfterBB
  auto FiniInfo = FinalizationStack.pop_back_val();
  assert(FiniInfo.DK == OMPD_sections &&
         "Unexpected finalization stack state!");
  Builder.SetInsertPoint(LoopAfterBB->getTerminator());
  FiniInfo.FiniCB(Builder.saveIP());
  Builder.SetInsertPoint(ExitBB);

  return Builder.saveIP();
}
1018 
// Emit a single OpenMP 'section' region. The body is generated inline (no
// entry/exit runtime calls); finalization is routed through a wrapper that
// can cope with cancellation blocks whose terminator was already removed.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createSection(const LocationDescription &Loc,
                               BodyGenCallbackTy BodyGenCB,
                               FinalizeCallbackTy FiniCB) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  auto FiniCBWrapper = [&](InsertPointTy IP) {
    // If the block still has a terminator (IP not at block end), no CFG
    // repair is needed; run the user callback directly.
    if (IP.getBlock()->end() != IP.getPoint())
      return FiniCB(IP);
    // This must be done otherwise any nested constructs using FinalizeOMPRegion
    // will fail because that function requires the Finalization Basic Block to
    // have a terminator, which is already removed by EmitOMPRegionBody.
    // IP is currently at cancelation block.
    // We need to backtrack to the condition block to fetch
    // the exit block and create a branch from cancelation
    // to exit block.
    IRBuilder<>::InsertPointGuard IPG(Builder);
    Builder.restoreIP(IP);
    // Unlike createSections, the case block here is the block captured in
    // the location description, not a predecessor of IP's block.
    auto *CaseBB = Loc.IP.getBlock();
    auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
    auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
    Instruction *I = Builder.CreateBr(ExitBB);
    // Re-point IP at the new branch so FiniCB emits before it.
    IP = InsertPointTy(I->getParent(), I->getIterator());
    return FiniCB(IP);
  };

  Directive OMPD = Directive::OMPD_sections;
  // Since we are using Finalization Callback here, HasFinalize
  // and IsCancellable have to be true
  return EmitOMPInlinedRegion(OMPD, nullptr, nullptr, BodyGenCB, FiniCBWrapper,
                              /*Conditional*/ false, /*hasFinalize*/ true,
                              /*IsCancellable*/ true);
}
1053 
1054 /// Create a function with a unique name and a "void (i8*, i8*)" signature in
1055 /// the given module and return it.
getFreshReductionFunc(Module & M)1056 Function *getFreshReductionFunc(Module &M) {
1057   Type *VoidTy = Type::getVoidTy(M.getContext());
1058   Type *Int8PtrTy = Type::getInt8PtrTy(M.getContext());
1059   auto *FuncTy =
1060       FunctionType::get(VoidTy, {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ false);
1061   return Function::Create(FuncTy, GlobalVariable::InternalLinkage,
1062                           M.getDataLayout().getDefaultGlobalsAddressSpace(),
1063                           ".omp.reduction.func", &M);
1064 }
1065 
// Emit the runtime-orchestrated reduction epilogue: pack pointers to the
// private reduction values into a type-erased array, call
// __kmpc_reduce[_nowait], and emit both the non-atomic and atomic reduction
// paths selected by the runtime's return value via a switch.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    ArrayRef<ReductionInfo> ReductionInfos, bool IsNoWait) {
  // Validate the caller-provided reduction descriptors up front. The (void)RI
  // silences unused-variable warnings in NDEBUG builds where the asserts
  // compile away.
  for (const ReductionInfo &RI : ReductionInfos) {
    (void)RI;
    assert(RI.Variable && "expected non-null variable");
    assert(RI.PrivateVariable && "expected non-null private variable");
    assert(RI.ReductionGen && "expected non-null reduction generator callback");
    assert(RI.Variable->getType() == RI.PrivateVariable->getType() &&
           "expected variables and their private equivalents to have the same "
           "type");
    assert(RI.Variable->getType()->isPointerTy() &&
           "expected variables to be pointers");
  }

  if (!updateToLocation(Loc))
    return InsertPointTy();

  // Split off everything after the insertion point into the continuation
  // block; the reduction control flow is built in between.
  BasicBlock *InsertBlock = Loc.IP.getBlock();
  BasicBlock *ContinuationBlock =
      InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
  // Drop the unconditional branch splitBasicBlock created; it is re-created
  // as a switch below.
  InsertBlock->getTerminator()->eraseFromParent();

  // Create and populate array of type-erased pointers to private reduction
  // values.
  unsigned NumReductions = ReductionInfos.size();
  Type *RedArrayTy = ArrayType::get(Builder.getInt8PtrTy(), NumReductions);
  Builder.restoreIP(AllocaIP);
  Value *RedArray = Builder.CreateAlloca(RedArrayTy, nullptr, "red.array");

  Builder.SetInsertPoint(InsertBlock, InsertBlock->end());

  for (auto En : enumerate(ReductionInfos)) {
    unsigned Index = En.index();
    const ReductionInfo &RI = En.value();
    Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64(
        RedArrayTy, RedArray, 0, Index, "red.array.elem." + Twine(Index));
    Value *Casted =
        Builder.CreateBitCast(RI.PrivateVariable, Builder.getInt8PtrTy(),
                              "private.red.var." + Twine(Index) + ".casted");
    Builder.CreateStore(Casted, RedArrayElemPtr);
  }

  // Emit a call to the runtime function that orchestrates the reduction.
  // Declare the reduction function in the process.
  Function *Func = Builder.GetInsertBlock()->getParent();
  Module *Module = Func->getParent();
  Value *RedArrayPtr =
      Builder.CreateBitCast(RedArray, Builder.getInt8PtrTy(), "red.array.ptr");
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
  // The atomic path is only viable if every reduction supplies an atomic
  // generator; advertise that to the runtime via the ident flags.
  bool CanGenerateAtomic =
      llvm::all_of(ReductionInfos, [](const ReductionInfo &RI) {
        return RI.AtomicReductionGen;
      });
  Value *Ident = getOrCreateIdent(
      SrcLocStr, CanGenerateAtomic ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
                                   : IdentFlag(0));
  Value *ThreadId = getOrCreateThreadID(Ident);
  Constant *NumVariables = Builder.getInt32(NumReductions);
  const DataLayout &DL = Module->getDataLayout();
  unsigned RedArrayByteSize = DL.getTypeStoreSize(RedArrayTy);
  Constant *RedArraySize = Builder.getInt64(RedArrayByteSize);
  Function *ReductionFunc = getFreshReductionFunc(*Module);
  Value *Lock = getOMPCriticalRegionLock(".reduction");
  Function *ReduceFunc = getOrCreateRuntimeFunctionPtr(
      IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
               : RuntimeFunction::OMPRTL___kmpc_reduce);
  CallInst *ReduceCall =
      Builder.CreateCall(ReduceFunc,
                         {Ident, ThreadId, NumVariables, RedArraySize,
                          RedArrayPtr, ReductionFunc, Lock},
                         "reduce");

  // Create final reduction entry blocks for the atomic and non-atomic case.
  // Emit IR that dispatches control flow to one of the blocks based on the
  // reduction supporting the atomic mode.
  BasicBlock *NonAtomicRedBlock =
      BasicBlock::Create(Module->getContext(), "reduce.switch.nonatomic", Func);
  BasicBlock *AtomicRedBlock =
      BasicBlock::Create(Module->getContext(), "reduce.switch.atomic", Func);
  // __kmpc_reduce returns 1 for the non-atomic path, 2 for the atomic path,
  // and any other value (e.g. 0) means no reduction work for this thread.
  SwitchInst *Switch =
      Builder.CreateSwitch(ReduceCall, ContinuationBlock, /* NumCases */ 2);
  Switch->addCase(Builder.getInt32(1), NonAtomicRedBlock);
  Switch->addCase(Builder.getInt32(2), AtomicRedBlock);

  // Populate the non-atomic reduction using the elementwise reduction function.
  // This loads the elements from the global and private variables and reduces
  // them before storing back the result to the global variable.
  Builder.SetInsertPoint(NonAtomicRedBlock);
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Type *ValueType = RI.getElementType();
    Value *RedValue = Builder.CreateLoad(ValueType, RI.Variable,
                                         "red.value." + Twine(En.index()));
    Value *PrivateRedValue =
        Builder.CreateLoad(ValueType, RI.PrivateVariable,
                           "red.private.value." + Twine(En.index()));
    Value *Reduced;
    Builder.restoreIP(
        RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced));
    // A cleared insert block signals the callback bailed out; propagate by
    // returning an unset insertion point.
    if (!Builder.GetInsertBlock())
      return InsertPointTy();
    Builder.CreateStore(Reduced, RI.Variable);
  }
  Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr(
      IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
               : RuntimeFunction::OMPRTL___kmpc_end_reduce);
  Builder.CreateCall(EndReduceFunc, {Ident, ThreadId, Lock});
  Builder.CreateBr(ContinuationBlock);

  // Populate the atomic reduction using the atomic elementwise reduction
  // function. There are no loads/stores here because they will be happening
  // inside the atomic elementwise reduction.
  Builder.SetInsertPoint(AtomicRedBlock);
  if (CanGenerateAtomic) {
    for (const ReductionInfo &RI : ReductionInfos) {
      Builder.restoreIP(RI.AtomicReductionGen(Builder.saveIP(), RI.Variable,
                                              RI.PrivateVariable));
      if (!Builder.GetInsertBlock())
        return InsertPointTy();
    }
    Builder.CreateBr(ContinuationBlock);
  } else {
    // The runtime should never select the atomic path when the ident flags
    // did not advertise it; mark the block unreachable.
    Builder.CreateUnreachable();
  }

  // Populate the outlined reduction function using the elementwise reduction
  // function. Partial values are extracted from the type-erased array of
  // pointers to private variables.
  BasicBlock *ReductionFuncBlock =
      BasicBlock::Create(Module->getContext(), "", ReductionFunc);
  Builder.SetInsertPoint(ReductionFuncBlock);
  Value *LHSArrayPtr = Builder.CreateBitCast(ReductionFunc->getArg(0),
                                             RedArrayTy->getPointerTo());
  Value *RHSArrayPtr = Builder.CreateBitCast(ReductionFunc->getArg(1),
                                             RedArrayTy->getPointerTo());
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
        RedArrayTy, LHSArrayPtr, 0, En.index());
    Value *LHSI8Ptr = Builder.CreateLoad(Builder.getInt8PtrTy(), LHSI8PtrPtr);
    Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType());
    Value *LHS = Builder.CreateLoad(RI.getElementType(), LHSPtr);
    Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
        RedArrayTy, RHSArrayPtr, 0, En.index());
    Value *RHSI8Ptr = Builder.CreateLoad(Builder.getInt8PtrTy(), RHSI8PtrPtr);
    Value *RHSPtr =
        Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType());
    Value *RHS = Builder.CreateLoad(RI.getElementType(), RHSPtr);
    Value *Reduced;
    Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced));
    if (!Builder.GetInsertBlock())
      return InsertPointTy();
    // The reduced value is accumulated into the LHS slot, matching the
    // runtime's in-place reduction contract.
    Builder.CreateStore(Reduced, LHSPtr);
  }
  Builder.CreateRetVoid();

  Builder.SetInsertPoint(ContinuationBlock);
  return Builder.saveIP();
}
1226 
1227 OpenMPIRBuilder::InsertPointTy
createMaster(const LocationDescription & Loc,BodyGenCallbackTy BodyGenCB,FinalizeCallbackTy FiniCB)1228 OpenMPIRBuilder::createMaster(const LocationDescription &Loc,
1229                               BodyGenCallbackTy BodyGenCB,
1230                               FinalizeCallbackTy FiniCB) {
1231 
1232   if (!updateToLocation(Loc))
1233     return Loc.IP;
1234 
1235   Directive OMPD = Directive::OMPD_master;
1236   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
1237   Value *Ident = getOrCreateIdent(SrcLocStr);
1238   Value *ThreadId = getOrCreateThreadID(Ident);
1239   Value *Args[] = {Ident, ThreadId};
1240 
1241   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_master);
1242   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
1243 
1244   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_master);
1245   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
1246 
1247   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
1248                               /*Conditional*/ true, /*hasFinalize*/ true);
1249 }
1250 
1251 OpenMPIRBuilder::InsertPointTy
createMasked(const LocationDescription & Loc,BodyGenCallbackTy BodyGenCB,FinalizeCallbackTy FiniCB,Value * Filter)1252 OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
1253                               BodyGenCallbackTy BodyGenCB,
1254                               FinalizeCallbackTy FiniCB, Value *Filter) {
1255   if (!updateToLocation(Loc))
1256     return Loc.IP;
1257 
1258   Directive OMPD = Directive::OMPD_masked;
1259   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
1260   Value *Ident = getOrCreateIdent(SrcLocStr);
1261   Value *ThreadId = getOrCreateThreadID(Ident);
1262   Value *Args[] = {Ident, ThreadId, Filter};
1263   Value *ArgsEnd[] = {Ident, ThreadId};
1264 
1265   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_masked);
1266   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
1267 
1268   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_masked);
1269   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, ArgsEnd);
1270 
1271   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
1272                               /*Conditional*/ true, /*hasFinalize*/ true);
1273 }
1274 
createLoopSkeleton(DebugLoc DL,Value * TripCount,Function * F,BasicBlock * PreInsertBefore,BasicBlock * PostInsertBefore,const Twine & Name)1275 CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
1276     DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
1277     BasicBlock *PostInsertBefore, const Twine &Name) {
1278   Module *M = F->getParent();
1279   LLVMContext &Ctx = M->getContext();
1280   Type *IndVarTy = TripCount->getType();
1281 
1282   // Create the basic block structure.
1283   BasicBlock *Preheader =
1284       BasicBlock::Create(Ctx, "omp_" + Name + ".preheader", F, PreInsertBefore);
1285   BasicBlock *Header =
1286       BasicBlock::Create(Ctx, "omp_" + Name + ".header", F, PreInsertBefore);
1287   BasicBlock *Cond =
1288       BasicBlock::Create(Ctx, "omp_" + Name + ".cond", F, PreInsertBefore);
1289   BasicBlock *Body =
1290       BasicBlock::Create(Ctx, "omp_" + Name + ".body", F, PreInsertBefore);
1291   BasicBlock *Latch =
1292       BasicBlock::Create(Ctx, "omp_" + Name + ".inc", F, PostInsertBefore);
1293   BasicBlock *Exit =
1294       BasicBlock::Create(Ctx, "omp_" + Name + ".exit", F, PostInsertBefore);
1295   BasicBlock *After =
1296       BasicBlock::Create(Ctx, "omp_" + Name + ".after", F, PostInsertBefore);
1297 
1298   // Use specified DebugLoc for new instructions.
1299   Builder.SetCurrentDebugLocation(DL);
1300 
1301   Builder.SetInsertPoint(Preheader);
1302   Builder.CreateBr(Header);
1303 
1304   Builder.SetInsertPoint(Header);
1305   PHINode *IndVarPHI = Builder.CreatePHI(IndVarTy, 2, "omp_" + Name + ".iv");
1306   IndVarPHI->addIncoming(ConstantInt::get(IndVarTy, 0), Preheader);
1307   Builder.CreateBr(Cond);
1308 
1309   Builder.SetInsertPoint(Cond);
1310   Value *Cmp =
1311       Builder.CreateICmpULT(IndVarPHI, TripCount, "omp_" + Name + ".cmp");
1312   Builder.CreateCondBr(Cmp, Body, Exit);
1313 
1314   Builder.SetInsertPoint(Body);
1315   Builder.CreateBr(Latch);
1316 
1317   Builder.SetInsertPoint(Latch);
1318   Value *Next = Builder.CreateAdd(IndVarPHI, ConstantInt::get(IndVarTy, 1),
1319                                   "omp_" + Name + ".next", /*HasNUW=*/true);
1320   Builder.CreateBr(Header);
1321   IndVarPHI->addIncoming(Next, Latch);
1322 
1323   Builder.SetInsertPoint(Exit);
1324   Builder.CreateBr(After);
1325 
1326   // Remember and return the canonical control flow.
1327   LoopInfos.emplace_front();
1328   CanonicalLoopInfo *CL = &LoopInfos.front();
1329 
1330   CL->Preheader = Preheader;
1331   CL->Header = Header;
1332   CL->Cond = Cond;
1333   CL->Body = Body;
1334   CL->Latch = Latch;
1335   CL->Exit = Exit;
1336   CL->After = After;
1337 
1338 #ifndef NDEBUG
1339   CL->assertOK();
1340 #endif
1341   return CL;
1342 }
1343 
1344 CanonicalLoopInfo *
createCanonicalLoop(const LocationDescription & Loc,LoopBodyGenCallbackTy BodyGenCB,Value * TripCount,const Twine & Name)1345 OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
1346                                      LoopBodyGenCallbackTy BodyGenCB,
1347                                      Value *TripCount, const Twine &Name) {
1348   BasicBlock *BB = Loc.IP.getBlock();
1349   BasicBlock *NextBB = BB->getNextNode();
1350 
1351   CanonicalLoopInfo *CL = createLoopSkeleton(Loc.DL, TripCount, BB->getParent(),
1352                                              NextBB, NextBB, Name);
1353   BasicBlock *After = CL->getAfter();
1354 
1355   // If location is not set, don't connect the loop.
1356   if (updateToLocation(Loc)) {
1357     // Split the loop at the insertion point: Branch to the preheader and move
1358     // every following instruction to after the loop (the After BB). Also, the
1359     // new successor is the loop's after block.
1360     Builder.CreateBr(CL->Preheader);
1361     After->getInstList().splice(After->begin(), BB->getInstList(),
1362                                 Builder.GetInsertPoint(), BB->end());
1363     After->replaceSuccessorsPhiUsesWith(BB, After);
1364   }
1365 
1366   // Emit the body content. We do it after connecting the loop to the CFG to
1367   // avoid that the callback encounters degenerate BBs.
1368   BodyGenCB(CL->getBodyIP(), CL->getIndVar());
1369 
1370 #ifndef NDEBUG
1371   CL->assertOK();
1372 #endif
1373   return CL;
1374 }
1375 
CanonicalLoopInfo *OpenMPIRBuilder::createCanonicalLoop(
    const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
    Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
    InsertPointTy ComputeIP, const Twine &Name) {

  // Consider the following difficulties (assuming 8-bit signed integers):
  //  * Adding \p Step to the loop counter which passes \p Stop may overflow:
  //      DO I = 1, 100, 50
  //  * A \p Step of INT_MIN cannot be normalized to a positive direction:
  //      DO I = 100, 0, -128

  // Start, Stop and Step must be of the same integer type.
  auto *IndVarTy = cast<IntegerType>(Start->getType());
  assert(IndVarTy == Stop->getType() && "Stop type mismatch");
  assert(IndVarTy == Step->getType() && "Step type mismatch");

  // Emit the trip-count computation either at the caller-provided ComputeIP
  // or, if none was given, at the loop's own location.
  LocationDescription ComputeLoc =
      ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
  updateToLocation(ComputeLoc);

  ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
  ConstantInt *One = ConstantInt::get(IndVarTy, 1);

  // Like Step, but always positive.
  Value *Incr = Step;

  // Distance between Start and Stop; always positive.
  Value *Span;

  // Condition whether no iterations are executed at all, e.g. because
  // UB < LB.
  Value *ZeroCmp;

  if (IsSigned) {
    // Ensure that increment is positive. If not, negate and invert LB and UB.
    Value *IsNeg = Builder.CreateICmpSLT(Step, Zero);
    Incr = Builder.CreateSelect(IsNeg, Builder.CreateNeg(Step), Step);
    Value *LB = Builder.CreateSelect(IsNeg, Stop, Start);
    Value *UB = Builder.CreateSelect(IsNeg, Start, Stop);
    Span = Builder.CreateSub(UB, LB, "", false, true);
    ZeroCmp = Builder.CreateICmp(
        InclusiveStop ? CmpInst::ICMP_SLT : CmpInst::ICMP_SLE, UB, LB);
  } else {
    Span = Builder.CreateSub(Stop, Start, "", true);
    ZeroCmp = Builder.CreateICmp(
        InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, Stop, Start);
  }

  Value *CountIfLooping;
  if (InclusiveStop) {
    CountIfLooping = Builder.CreateAdd(Builder.CreateUDiv(Span, Incr), One);
  } else {
    // Avoid incrementing past stop since it could overflow.
    Value *CountIfTwo = Builder.CreateAdd(
        Builder.CreateUDiv(Builder.CreateSub(Span, One), Incr), One);
    // NOTE(review): this branch only runs when InclusiveStop is false, so the
    // conditional below always selects ICMP_ULE.
    Value *OneCmp = Builder.CreateICmp(
        InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, Span, Incr);
    CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
  }
  Value *TripCount = Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
                                          "omp_" + Name + ".tripcount");

  // Map the canonical IV (0 .. tripcount-1) back to the user-visible
  // induction variable: IndVar = Start + IV * Step.
  auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
    Builder.restoreIP(CodeGenIP);
    Value *Span = Builder.CreateMul(IV, Step);
    Value *IndVar = Builder.CreateAdd(Span, Start);
    BodyGenCB(Builder.saveIP(), IndVar);
  };
  LocationDescription LoopLoc = ComputeIP.isSet() ? Loc.IP : Builder.saveIP();
  return createCanonicalLoop(LoopLoc, BodyGen, TripCount, Name);
}
1447 
1448 // Returns an LLVM function to call for initializing loop bounds using OpenMP
1449 // static scheduling depending on `type`. Only i32 and i64 are supported by the
1450 // runtime. Always interpret integers as unsigned similarly to
1451 // CanonicalLoopInfo.
getKmpcForStaticInitForType(Type * Ty,Module & M,OpenMPIRBuilder & OMPBuilder)1452 static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
1453                                                   OpenMPIRBuilder &OMPBuilder) {
1454   unsigned Bitwidth = Ty->getIntegerBitWidth();
1455   if (Bitwidth == 32)
1456     return OMPBuilder.getOrCreateRuntimeFunction(
1457         M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u);
1458   if (Bitwidth == 64)
1459     return OMPBuilder.getOrCreateRuntimeFunction(
1460         M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u);
1461   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
1462 }
1463 
1464 // Sets the number of loop iterations to the given value. This value must be
1465 // valid in the condition block (i.e., defined in the preheader) and is
1466 // interpreted as an unsigned integer.
setCanonicalLoopTripCount(CanonicalLoopInfo * CLI,Value * TripCount)1467 void setCanonicalLoopTripCount(CanonicalLoopInfo *CLI, Value *TripCount) {
1468   Instruction *CmpI = &CLI->getCond()->front();
1469   assert(isa<CmpInst>(CmpI) && "First inst must compare IV with TripCount");
1470   CmpI->setOperand(1, TripCount);
1471   CLI->assertOK();
1472 }
1473 
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::applyStaticWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
                                          InsertPointTy AllocaIP,
                                          bool NeedsBarrier, Value *Chunk) {
  assert(CLI->isValid() && "Requires a valid canonical loop");

  // Set up the source location value for OpenMP runtime.
  Builder.restoreIP(CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  Constant *SrcLocStr = getOrCreateSrcLocStr(DL);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr);

  // Declare useful OpenMP runtime functions.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee StaticInit = getKmpcForStaticInitForType(IVTy, M, *this);
  FunctionCallee StaticFini =
      getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);

  // Allocate space for computed loop bounds as expected by the "init" function.
  // The runtime communicates this thread's chunk back through these slots.
  Builder.restoreIP(AllocaIP);
  Type *I32Type = Type::getInt32Ty(M.getContext());
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. A canonical loop
  // always iterates from 0 to trip-count with step 1. Note that "init" expects
  // and produces an inclusive upper bound.
  Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
  Constant *Zero = ConstantInt::get(IVTy, 0);
  Constant *One = ConstantInt::get(IVTy, 1);
  Builder.CreateStore(Zero, PLowerBound);
  Value *UpperBound = Builder.CreateSub(CLI->getTripCount(), One);
  Builder.CreateStore(UpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  // FIXME: schedule(static) is NOT the same as schedule(static,1)
  if (!Chunk)
    Chunk = One;

  Value *ThreadNum = getOrCreateThreadID(SrcLoc);

  Constant *SchedulingType =
      ConstantInt::get(I32Type, static_cast<int>(OMPScheduleType::Static));

  // Call the "init" function and update the trip count of the loop with the
  // value it produced. The loop then iterates only over this thread's share
  // of the iteration space.
  Builder.CreateCall(StaticInit,
                     {SrcLoc, ThreadNum, SchedulingType, PLastIter, PLowerBound,
                      PUpperBound, PStride, One, Chunk});
  Value *LowerBound = Builder.CreateLoad(IVTy, PLowerBound);
  Value *InclusiveUpperBound = Builder.CreateLoad(IVTy, PUpperBound);
  Value *TripCountMinusOne = Builder.CreateSub(InclusiveUpperBound, LowerBound);
  Value *TripCount = Builder.CreateAdd(TripCountMinusOne, One);
  setCanonicalLoopTripCount(CLI, TripCount);

  // Update all uses of the induction variable except the one in the condition
  // block that compares it with the actual upper bound, and the increment in
  // the latch block.
  // TODO: this can eventually move to CanonicalLoopInfo or to a new
  // CanonicalLoopInfoUpdater interface.
  Builder.SetInsertPoint(CLI->getBody(), CLI->getBody()->getFirstInsertionPt());
  Value *UpdatedIV = Builder.CreateAdd(IV, LowerBound);
  IV->replaceUsesWithIf(UpdatedIV, [&](Use &U) {
    auto *Instr = dyn_cast<Instruction>(U.getUser());
    return !Instr ||
           (Instr->getParent() != CLI->getCond() &&
            Instr->getParent() != CLI->getLatch() && Instr != UpdatedIV);
  });

  // In the "exit" block, call the "fini" function.
  Builder.SetInsertPoint(CLI->getExit(),
                         CLI->getExit()->getTerminator()->getIterator());
  Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});

  // Add the barrier if requested.
  if (NeedsBarrier)
    createBarrier(LocationDescription(Builder.saveIP(), DL),
                  omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                  /* CheckCancelFlag */ false);

  // The CLI no longer describes the loop after its bounds were rewritten;
  // invalidate it and hand back the insertion point after the loop.
  InsertPointTy AfterIP = CLI->getAfterIP();
  CLI->invalidate();

  return AfterIP;
}
1564 
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::applyWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
                                    InsertPointTy AllocaIP, bool NeedsBarrier) {
  // Currently only supports static schedules; forwards directly to the
  // static-schedule lowering (dynamic schedules must be requested explicitly
  // via applyDynamicWorkshareLoop).
  return applyStaticWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier);
}
1571 
1572 /// Returns an LLVM function to call for initializing loop bounds using OpenMP
1573 /// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
1574 /// the runtime. Always interpret integers as unsigned similarly to
1575 /// CanonicalLoopInfo.
1576 static FunctionCallee
getKmpcForDynamicInitForType(Type * Ty,Module & M,OpenMPIRBuilder & OMPBuilder)1577 getKmpcForDynamicInitForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
1578   unsigned Bitwidth = Ty->getIntegerBitWidth();
1579   if (Bitwidth == 32)
1580     return OMPBuilder.getOrCreateRuntimeFunction(
1581         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_4u);
1582   if (Bitwidth == 64)
1583     return OMPBuilder.getOrCreateRuntimeFunction(
1584         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_8u);
1585   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
1586 }
1587 
1588 /// Returns an LLVM function to call for updating the next loop using OpenMP
1589 /// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
1590 /// the runtime. Always interpret integers as unsigned similarly to
1591 /// CanonicalLoopInfo.
1592 static FunctionCallee
getKmpcForDynamicNextForType(Type * Ty,Module & M,OpenMPIRBuilder & OMPBuilder)1593 getKmpcForDynamicNextForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
1594   unsigned Bitwidth = Ty->getIntegerBitWidth();
1595   if (Bitwidth == 32)
1596     return OMPBuilder.getOrCreateRuntimeFunction(
1597         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_4u);
1598   if (Bitwidth == 64)
1599     return OMPBuilder.getOrCreateRuntimeFunction(
1600         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_8u);
1601   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
1602 }
1603 
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyDynamicWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    OMPScheduleType SchedType, bool NeedsBarrier, Value *Chunk) {
  assert(CLI->isValid() && "Requires a valid canonical loop");

  // Set up the source location value for OpenMP runtime.
  Builder.SetCurrentDebugLocation(DL);

  Constant *SrcLocStr = getOrCreateSrcLocStr(DL);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr);

  // Declare useful OpenMP runtime functions.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee DynamicInit = getKmpcForDynamicInitForType(IVTy, M, *this);
  FunctionCallee DynamicNext = getKmpcForDynamicNextForType(IVTy, M, *this);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.restoreIP(AllocaIP);
  Type *I32Type = Type::getInt32Ty(M.getContext());
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. A canonical loop
  // always iterates from 0 to trip-count with step 1. Note that "init" expects
  // and produces an inclusive upper bound.
  BasicBlock *PreHeader = CLI->getPreheader();
  Builder.SetInsertPoint(PreHeader->getTerminator());
  Constant *One = ConstantInt::get(IVTy, 1);
  Builder.CreateStore(One, PLowerBound);
  Value *UpperBound = CLI->getTripCount();
  Builder.CreateStore(UpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  // Remember the blocks we still need before the CFG is rewritten below.
  BasicBlock *Header = CLI->getHeader();
  BasicBlock *Exit = CLI->getExit();
  BasicBlock *Cond = CLI->getCond();
  InsertPointTy AfterIP = CLI->getAfterIP();

  // The CLI will be "broken" in the code below, as the loop is no longer
  // a valid canonical loop.

  if (!Chunk)
    Chunk = One;

  Value *ThreadNum = getOrCreateThreadID(SrcLoc);

  Constant *SchedulingType =
      ConstantInt::get(I32Type, static_cast<int>(SchedType));

  // Call the "init" function.
  Builder.CreateCall(DynamicInit,
                     {SrcLoc, ThreadNum, SchedulingType, /* LowerBound */ One,
                      UpperBound, /* step */ One, Chunk});

  // An outer loop around the existing one: each iteration asks the runtime
  // ("next") for another chunk and re-enters the inner loop with its bounds.
  BasicBlock *OuterCond = BasicBlock::Create(
      PreHeader->getContext(), Twine(PreHeader->getName()) + ".outer.cond",
      PreHeader->getParent());
  Builder.SetInsertPoint(OuterCond, OuterCond->getFirstInsertionPt());
  Value *Res =
      Builder.CreateCall(DynamicNext, {SrcLoc, ThreadNum, PLastIter,
                                       PLowerBound, PUpperBound, PStride});
  // The "more work" flag returned by "next" is always i32, independent of
  // IVTy, so compare against a dedicated 32-bit zero.
  Constant *Zero32 = ConstantInt::get(I32Type, 0);
  Value *MoreWork = Builder.CreateCmp(CmpInst::ICMP_NE, Res, Zero32);
  // "next" produces an inclusive, 1-based lower bound; subtract one to get
  // back to the canonical 0-based IV.
  Value *LowerBound =
      Builder.CreateSub(Builder.CreateLoad(IVTy, PLowerBound), One, "lb");
  Builder.CreateCondBr(MoreWork, Header, Exit);

  // Change PHI-node in loop header to use outer cond rather than preheader,
  // and set IV to the LowerBound.
  Instruction *Phi = &Header->front();
  auto *PI = cast<PHINode>(Phi);
  PI->setIncomingBlock(0, OuterCond);
  PI->setIncomingValue(0, LowerBound);

  // Then set the pre-header to jump to the OuterCond
  Instruction *Term = PreHeader->getTerminator();
  auto *Br = cast<BranchInst>(Term);
  Br->setSuccessor(0, OuterCond);

  // Modify the inner condition:
  // * Use the UpperBound returned from the DynamicNext call.
  // * jump to the loop outer loop when done with one of the inner loops.
  Builder.SetInsertPoint(Cond, Cond->getFirstInsertionPt());
  UpperBound = Builder.CreateLoad(IVTy, PUpperBound, "ub");
  Instruction *Comp = &*Builder.GetInsertPoint();
  auto *CI = cast<CmpInst>(Comp);
  CI->setOperand(1, UpperBound);
  // Redirect the inner exit to branch to outer condition.
  Instruction *Branch = &Cond->back();
  auto *BI = cast<BranchInst>(Branch);
  assert(BI->getSuccessor(1) == Exit);
  BI->setSuccessor(1, OuterCond);

  // Add the barrier if requested.
  if (NeedsBarrier) {
    Builder.SetInsertPoint(&Exit->back());
    createBarrier(LocationDescription(Builder.saveIP(), DL),
                  omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                  /* CheckCancelFlag */ false);
  }

  CLI->invalidate();
  return AfterIP;
}
1714 
1715 /// Make \p Source branch to \p Target.
1716 ///
1717 /// Handles two situations:
1718 /// * \p Source already has an unconditional branch.
1719 /// * \p Source is a degenerate block (no terminator because the BB is
1720 ///             the current head of the IR construction).
redirectTo(BasicBlock * Source,BasicBlock * Target,DebugLoc DL)1721 static void redirectTo(BasicBlock *Source, BasicBlock *Target, DebugLoc DL) {
1722   if (Instruction *Term = Source->getTerminator()) {
1723     auto *Br = cast<BranchInst>(Term);
1724     assert(!Br->isConditional() &&
1725            "BB's terminator must be an unconditional branch (or degenerate)");
1726     BasicBlock *Succ = Br->getSuccessor(0);
1727     Succ->removePredecessor(Source, /*KeepOneInputPHIs=*/true);
1728     Br->setSuccessor(0, Target);
1729     return;
1730   }
1731 
1732   auto *NewBr = BranchInst::Create(Target, Source);
1733   NewBr->setDebugLoc(DL);
1734 }
1735 
1736 /// Redirect all edges that branch to \p OldTarget to \p NewTarget. That is,
1737 /// after this \p OldTarget will be orphaned.
redirectAllPredecessorsTo(BasicBlock * OldTarget,BasicBlock * NewTarget,DebugLoc DL)1738 static void redirectAllPredecessorsTo(BasicBlock *OldTarget,
1739                                       BasicBlock *NewTarget, DebugLoc DL) {
1740   for (BasicBlock *Pred : make_early_inc_range(predecessors(OldTarget)))
1741     redirectTo(Pred, NewTarget, DL);
1742 }
1743 
1744 /// Determine which blocks in \p BBs are reachable from outside and remove the
1745 /// ones that are not reachable from the function.
removeUnusedBlocksFromParent(ArrayRef<BasicBlock * > BBs)1746 static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
1747   SmallPtrSet<BasicBlock *, 6> BBsToErase{BBs.begin(), BBs.end()};
1748   auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) {
1749     for (Use &U : BB->uses()) {
1750       auto *UseInst = dyn_cast<Instruction>(U.getUser());
1751       if (!UseInst)
1752         continue;
1753       if (BBsToErase.count(UseInst->getParent()))
1754         continue;
1755       return true;
1756     }
1757     return false;
1758   };
1759 
1760   while (true) {
1761     bool Changed = false;
1762     for (BasicBlock *BB : make_early_inc_range(BBsToErase)) {
1763       if (HasRemainingUses(BB)) {
1764         BBsToErase.erase(BB);
1765         Changed = true;
1766       }
1767     }
1768     if (!Changed)
1769       break;
1770   }
1771 
1772   SmallVector<BasicBlock *, 7> BBVec(BBsToErase.begin(), BBsToErase.end());
1773   DeleteDeadBlocks(BBVec);
1774 }
1775 
CanonicalLoopInfo *
OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
                               InsertPointTy ComputeIP) {
  assert(Loops.size() >= 1 && "At least one loop required");
  size_t NumLoops = Loops.size();

  // Nothing to do if there is already just one loop.
  if (NumLoops == 1)
    return Loops.front();

  CanonicalLoopInfo *Outermost = Loops.front();
  CanonicalLoopInfo *Innermost = Loops.back();
  BasicBlock *OrigPreheader = Outermost->getPreheader();
  BasicBlock *OrigAfter = Outermost->getAfter();
  Function *F = OrigPreheader->getParent();

  // Setup the IRBuilder for inserting the trip count computation.
  Builder.SetCurrentDebugLocation(DL);
  if (ComputeIP.isSet())
    Builder.restoreIP(ComputeIP);
  else
    Builder.restoreIP(Outermost->getPreheaderIP());

  // Derive the collapsed loop's trip count: the product of all input trip
  // counts.
  // TODO: Find common/largest indvar type.
  Value *CollapsedTripCount = nullptr;
  for (CanonicalLoopInfo *L : Loops) {
    assert(L->isValid() &&
           "All loops to collapse must be valid canonical loops");
    Value *OrigTripCount = L->getTripCount();
    if (!CollapsedTripCount) {
      CollapsedTripCount = OrigTripCount;
      continue;
    }

    // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
    CollapsedTripCount = Builder.CreateMul(CollapsedTripCount, OrigTripCount,
                                           {}, /*HasNUW=*/true);
  }

  // Create the collapsed loop control flow.
  CanonicalLoopInfo *Result =
      createLoopSkeleton(DL, CollapsedTripCount, F,
                         OrigPreheader->getNextNode(), OrigAfter, "collapsed");

  // Build the collapsed loop body code.
  // Start with deriving the input loop induction variables from the collapsed
  // one, using a divmod scheme. To preserve the original loops' order, the
  // innermost loop uses the least significant bits.
  Builder.restoreIP(Result->getBodyIP());

  Value *Leftover = Result->getIndVar();
  SmallVector<Value *> NewIndVars;
  NewIndVars.set_size(NumLoops);
  for (int i = NumLoops - 1; i >= 1; --i) {
    Value *OrigTripCount = Loops[i]->getTripCount();

    Value *NewIndVar = Builder.CreateURem(Leftover, OrigTripCount);
    NewIndVars[i] = NewIndVar;

    Leftover = Builder.CreateUDiv(Leftover, OrigTripCount);
  }
  // Outermost loop gets all the remaining bits.
  NewIndVars[0] = Leftover;

  // Construct the loop body control flow.
  // We progressively construct the branch structure following in direction of
  // the control flow, from the leading in-between code, the loop nest body, the
  // trailing in-between code, and rejoining the collapsed loop's latch.
  // ContinueBlock and ContinuePred keep track of the source(s) of next edge. If
  // the ContinueBlock is set, continue with that block. If ContinuePred, use
  // its predecessors as sources.
  BasicBlock *ContinueBlock = Result->getBody();
  BasicBlock *ContinuePred = nullptr;
  auto ContinueWith = [&ContinueBlock, &ContinuePred, DL](BasicBlock *Dest,
                                                          BasicBlock *NextSrc) {
    if (ContinueBlock)
      redirectTo(ContinueBlock, Dest, DL);
    else
      redirectAllPredecessorsTo(ContinuePred, Dest, DL);

    ContinueBlock = nullptr;
    ContinuePred = NextSrc;
  };

  // The code before the nested loop of each level.
  // Because we are sinking it into the nest, it will be executed more often
  // than the original loop. More sophisticated schemes could keep track of what
  // the in-between code is and instantiate it only once per thread.
  for (size_t i = 0; i < NumLoops - 1; ++i)
    ContinueWith(Loops[i]->getBody(), Loops[i + 1]->getHeader());

  // Connect the loop nest body.
  ContinueWith(Innermost->getBody(), Innermost->getLatch());

  // The code after the nested loop at each level.
  for (size_t i = NumLoops - 1; i > 0; --i)
    ContinueWith(Loops[i]->getAfter(), Loops[i - 1]->getLatch());

  // Connect the finished loop to the collapsed loop latch.
  ContinueWith(Result->getLatch(), nullptr);

  // Replace the input loops with the new collapsed loop.
  redirectTo(Outermost->getPreheader(), Result->getPreheader(), DL);
  redirectTo(Result->getAfter(), Outermost->getAfter(), DL);

  // Replace the input loop indvars with the derived ones.
  for (size_t i = 0; i < NumLoops; ++i)
    Loops[i]->getIndVar()->replaceAllUsesWith(NewIndVars[i]);

  // Remove unused parts of the input loops.
  SmallVector<BasicBlock *, 12> OldControlBBs;
  OldControlBBs.reserve(6 * Loops.size());
  for (CanonicalLoopInfo *Loop : Loops)
    Loop->collectControlBlocks(OldControlBBs);
  removeUnusedBlocksFromParent(OldControlBBs);

  // The original CLIs no longer describe live loops.
  for (CanonicalLoopInfo *L : Loops)
    L->invalidate();

#ifndef NDEBUG
  Result->assertOK();
#endif
  return Result;
}
1901 
std::vector<CanonicalLoopInfo *>
OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
                           ArrayRef<Value *> TileSizes) {
  assert(TileSizes.size() == Loops.size() &&
         "Must pass as many tile sizes as there are loops");
  int NumLoops = Loops.size();
  assert(NumLoops >= 1 && "At least one loop to tile required");

  CanonicalLoopInfo *OutermostLoop = Loops.front();
  CanonicalLoopInfo *InnermostLoop = Loops.back();
  Function *F = OutermostLoop->getBody()->getParent();
  BasicBlock *InnerEnter = InnermostLoop->getBody();
  BasicBlock *InnerLatch = InnermostLoop->getLatch();

  // Collect original trip counts and induction variable to be accessible by
  // index. Also, the structure of the original loops is not preserved during
  // the construction of the tiled loops, so do it before we scavenge the BBs of
  // any original CanonicalLoopInfo.
  SmallVector<Value *, 4> OrigTripCounts, OrigIndVars;
  for (CanonicalLoopInfo *L : Loops) {
    assert(L->isValid() && "All input loops must be valid canonical loops");
    OrigTripCounts.push_back(L->getTripCount());
    OrigIndVars.push_back(L->getIndVar());
  }

  // Collect the code between loop headers. These may contain SSA definitions
  // that are used in the loop nest body. To be usable within the innermost
  // body, these BasicBlocks will be sunk into the loop nest body. That is,
  // these instructions may be executed more often than before the tiling.
  // TODO: It would be sufficient to only sink them into body of the
  // corresponding tile loop.
  SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> InbetweenCode;
  for (int i = 0; i < NumLoops - 1; ++i) {
    CanonicalLoopInfo *Surrounding = Loops[i];
    CanonicalLoopInfo *Nested = Loops[i + 1];

    BasicBlock *EnterBB = Surrounding->getBody();
    BasicBlock *ExitBB = Nested->getHeader();
    InbetweenCode.emplace_back(EnterBB, ExitBB);
  }

  // Compute the trip counts of the floor loops.
  Builder.SetCurrentDebugLocation(DL);
  Builder.restoreIP(OutermostLoop->getPreheaderIP());
  SmallVector<Value *, 4> FloorCount, FloorRems;
  for (int i = 0; i < NumLoops; ++i) {
    Value *TileSize = TileSizes[i];
    Value *OrigTripCount = OrigTripCounts[i];
    Type *IVType = OrigTripCount->getType();

    Value *FloorTripCount = Builder.CreateUDiv(OrigTripCount, TileSize);
    Value *FloorTripRem = Builder.CreateURem(OrigTripCount, TileSize);

    // 0 if tripcount divides the tilesize, 1 otherwise.
    // 1 means we need an additional iteration for a partial tile.
    //
    // Unfortunately we cannot just use the roundup-formula
    //   (tripcount + tilesize - 1)/tilesize
    // because the summation might overflow. We do not want to introduce
    // undefined behavior when the untiled loop nest did not.
    Value *FloorTripOverflow =
        Builder.CreateICmpNE(FloorTripRem, ConstantInt::get(IVType, 0));

    FloorTripOverflow = Builder.CreateZExt(FloorTripOverflow, IVType);
    FloorTripCount =
        Builder.CreateAdd(FloorTripCount, FloorTripOverflow,
                          "omp_floor" + Twine(i) + ".tripcount", true);

    // Remember some values for later use.
    FloorCount.push_back(FloorTripCount);
    FloorRems.push_back(FloorTripRem);
  }

  // Generate the new loop nest, from the outermost to the innermost.
  std::vector<CanonicalLoopInfo *> Result;
  Result.reserve(NumLoops * 2);

  // The basic block of the surrounding loop that enters the nest generated
  // loop.
  BasicBlock *Enter = OutermostLoop->getPreheader();

  // The basic block of the surrounding loop where the inner code should
  // continue.
  BasicBlock *Continue = OutermostLoop->getAfter();

  // Where the next loop basic block should be inserted.
  BasicBlock *OutroInsertBefore = InnermostLoop->getExit();

  // Create one new loop, splice it between Enter and Continue, and advance
  // those cursors so the next loop nests inside it.
  auto EmbeddNewLoop =
      [this, DL, F, InnerEnter, &Enter, &Continue, &OutroInsertBefore](
          Value *TripCount, const Twine &Name) -> CanonicalLoopInfo * {
    CanonicalLoopInfo *EmbeddedLoop = createLoopSkeleton(
        DL, TripCount, F, InnerEnter, OutroInsertBefore, Name);
    redirectTo(Enter, EmbeddedLoop->getPreheader(), DL);
    redirectTo(EmbeddedLoop->getAfter(), Continue, DL);

    // Setup the position where the next embedded loop connects to this loop.
    Enter = EmbeddedLoop->getBody();
    Continue = EmbeddedLoop->getLatch();
    OutroInsertBefore = EmbeddedLoop->getLatch();
    return EmbeddedLoop;
  };

  auto EmbeddNewLoops = [&Result, &EmbeddNewLoop](ArrayRef<Value *> TripCounts,
                                                  const Twine &NameBase) {
    for (auto P : enumerate(TripCounts)) {
      CanonicalLoopInfo *EmbeddedLoop =
          EmbeddNewLoop(P.value(), NameBase + Twine(P.index()));
      Result.push_back(EmbeddedLoop);
    }
  };

  EmbeddNewLoops(FloorCount, "floor");

  // Within the innermost floor loop, emit the code that computes the tile
  // sizes.
  Builder.SetInsertPoint(Enter->getTerminator());
  SmallVector<Value *, 4> TileCounts;
  for (int i = 0; i < NumLoops; ++i) {
    CanonicalLoopInfo *FloorLoop = Result[i];
    Value *TileSize = TileSizes[i];

    // The last floor iteration runs the (possibly smaller) epilogue tile.
    Value *FloorIsEpilogue =
        Builder.CreateICmpEQ(FloorLoop->getIndVar(), FloorCount[i]);
    Value *TileTripCount =
        Builder.CreateSelect(FloorIsEpilogue, FloorRems[i], TileSize);

    TileCounts.push_back(TileTripCount);
  }

  // Create the tile loops.
  EmbeddNewLoops(TileCounts, "tile");

  // Insert the inbetween code into the body.
  BasicBlock *BodyEnter = Enter;
  BasicBlock *BodyEntered = nullptr;
  for (std::pair<BasicBlock *, BasicBlock *> P : InbetweenCode) {
    BasicBlock *EnterBB = P.first;
    BasicBlock *ExitBB = P.second;

    if (BodyEnter)
      redirectTo(BodyEnter, EnterBB, DL);
    else
      redirectAllPredecessorsTo(BodyEntered, EnterBB, DL);

    BodyEnter = nullptr;
    BodyEntered = ExitBB;
  }

  // Append the original loop nest body into the generated loop nest body.
  if (BodyEnter)
    redirectTo(BodyEnter, InnerEnter, DL);
  else
    redirectAllPredecessorsTo(BodyEntered, InnerEnter, DL);
  redirectAllPredecessorsTo(InnerLatch, Continue, DL);

  // Replace the original induction variable with an induction variable computed
  // from the tile and floor induction variables:
  //   OrigIndVar = FloorIndVar * TileSize + TileIndVar.
  Builder.restoreIP(Result.back()->getBodyIP());
  for (int i = 0; i < NumLoops; ++i) {
    CanonicalLoopInfo *FloorLoop = Result[i];
    CanonicalLoopInfo *TileLoop = Result[NumLoops + i];
    Value *OrigIndVar = OrigIndVars[i];
    Value *Size = TileSizes[i];

    Value *Scale =
        Builder.CreateMul(Size, FloorLoop->getIndVar(), {}, /*HasNUW=*/true);
    Value *Shift =
        Builder.CreateAdd(Scale, TileLoop->getIndVar(), {}, /*HasNUW=*/true);
    OrigIndVar->replaceAllUsesWith(Shift);
  }

  // Remove unused parts of the original loops.
  SmallVector<BasicBlock *, 12> OldControlBBs;
  OldControlBBs.reserve(6 * Loops.size());
  for (CanonicalLoopInfo *Loop : Loops)
    Loop->collectControlBlocks(OldControlBBs);
  removeUnusedBlocksFromParent(OldControlBBs);

  // The original CLIs no longer describe live loops.
  for (CanonicalLoopInfo *L : Loops)
    L->invalidate();

#ifndef NDEBUG
  for (CanonicalLoopInfo *GenL : Result)
    GenL->assertOK();
#endif
  return Result;
}
2090 
2091 /// Attach loop metadata \p Properties to the loop described by \p Loop. If the
2092 /// loop already has metadata, the loop properties are appended.
addLoopMetadata(CanonicalLoopInfo * Loop,ArrayRef<Metadata * > Properties)2093 static void addLoopMetadata(CanonicalLoopInfo *Loop,
2094                             ArrayRef<Metadata *> Properties) {
2095   assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");
2096 
2097   // Nothing to do if no property to attach.
2098   if (Properties.empty())
2099     return;
2100 
2101   LLVMContext &Ctx = Loop->getFunction()->getContext();
2102   SmallVector<Metadata *> NewLoopProperties;
2103   NewLoopProperties.push_back(nullptr);
2104 
2105   // If the loop already has metadata, prepend it to the new metadata.
2106   BasicBlock *Latch = Loop->getLatch();
2107   assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
2108   MDNode *Existing = Latch->getTerminator()->getMetadata(LLVMContext::MD_loop);
2109   if (Existing)
2110     append_range(NewLoopProperties, drop_begin(Existing->operands(), 1));
2111 
2112   append_range(NewLoopProperties, Properties);
2113   MDNode *LoopID = MDNode::getDistinct(Ctx, NewLoopProperties);
2114   LoopID->replaceOperandWith(0, LoopID);
2115 
2116   Latch->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopID);
2117 }
2118 
unrollLoopFull(DebugLoc,CanonicalLoopInfo * Loop)2119 void OpenMPIRBuilder::unrollLoopFull(DebugLoc, CanonicalLoopInfo *Loop) {
2120   LLVMContext &Ctx = Builder.getContext();
2121   addLoopMetadata(
2122       Loop, {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
2123              MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.full"))});
2124 }
2125 
unrollLoopHeuristic(DebugLoc,CanonicalLoopInfo * Loop)2126 void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
2127   LLVMContext &Ctx = Builder.getContext();
2128   addLoopMetadata(
2129       Loop, {
2130                 MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
2131             });
2132 }
2133 
2134 /// Create the TargetMachine object to query the backend for optimization
2135 /// preferences.
2136 ///
2137 /// Ideally, this would be passed from the front-end to the OpenMPBuilder, but
2138 /// e.g. Clang does not pass it to its CodeGen layer and creates it only when
2139 /// needed for the LLVM pass pipline. We use some default options to avoid
2140 /// having to pass too many settings from the frontend that probably do not
2141 /// matter.
2142 ///
2143 /// Currently, TargetMachine is only used sometimes by the unrollLoopPartial
2144 /// method. If we are going to use TargetMachine for more purposes, especially
2145 /// those that are sensitive to TargetOptions, RelocModel and CodeModel, it
2146 /// might become be worth requiring front-ends to pass on their TargetMachine,
2147 /// or at least cache it between methods. Note that while fontends such as Clang
2148 /// have just a single main TargetMachine per translation unit, "target-cpu" and
2149 /// "target-features" that determine the TargetMachine are per-function and can
2150 /// be overrided using __attribute__((target("OPTIONS"))).
2151 static std::unique_ptr<TargetMachine>
createTargetMachine(Function * F,CodeGenOpt::Level OptLevel)2152 createTargetMachine(Function *F, CodeGenOpt::Level OptLevel) {
2153   Module *M = F->getParent();
2154 
2155   StringRef CPU = F->getFnAttribute("target-cpu").getValueAsString();
2156   StringRef Features = F->getFnAttribute("target-features").getValueAsString();
2157   const std::string &Triple = M->getTargetTriple();
2158 
2159   std::string Error;
2160   const llvm::Target *TheTarget = TargetRegistry::lookupTarget(Triple, Error);
2161   if (!TheTarget)
2162     return {};
2163 
2164   llvm::TargetOptions Options;
2165   return std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
2166       Triple, CPU, Features, Options, /*RelocModel=*/None, /*CodeModel=*/None,
2167       OptLevel));
2168 }
2169 
/// Heuristically determine the best-performant unroll factor for \p CLI. This
/// depends on the target processor. We are re-using the same heuristics as the
/// LoopUnrollPass.
///
/// \param CLI The canonical loop to evaluate; must be valid and already
///            materialized in the IR so LoopInfo can recognize it.
/// \returns   The suggested unroll factor; a result of 1 means "do not
///            unroll".
static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
  Function *F = CLI->getFunction();

  // Assume the user requests the most aggressive unrolling, even if the rest of
  // the code is optimized using a lower setting.
  CodeGenOpt::Level OptLevel = CodeGenOpt::Aggressive;
  std::unique_ptr<TargetMachine> TM = createTargetMachine(F, OptLevel);

  // Set up a throwaway FunctionAnalysisManager with the analyses the unroll
  // heuristics query; we are running outside of any regular pass pipeline.
  FunctionAnalysisManager FAM;
  FAM.registerPass([]() { return TargetLibraryAnalysis(); });
  FAM.registerPass([]() { return AssumptionAnalysis(); });
  FAM.registerPass([]() { return DominatorTreeAnalysis(); });
  FAM.registerPass([]() { return LoopAnalysis(); });
  FAM.registerPass([]() { return ScalarEvolutionAnalysis(); });
  FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
  // Use target-specific cost modelling when a TargetMachine could be created;
  // a default-constructed TargetIRAnalysis falls back to the generic TTI.
  TargetIRAnalysis TIRA;
  if (TM)
    TIRA = TargetIRAnalysis(
        [&](const Function &F) { return TM->getTargetTransformInfo(F); });
  FAM.registerPass([&]() { return TIRA; });

  // Bind the results to rvalue references so the temporaries returned by
  // run() stay alive for the rest of this function.
  TargetIRAnalysis::Result &&TTI = TIRA.run(*F, FAM);
  ScalarEvolutionAnalysis SEA;
  ScalarEvolution &&SE = SEA.run(*F, FAM);
  DominatorTreeAnalysis DTA;
  DominatorTree &&DT = DTA.run(*F, FAM);
  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(*F, FAM);
  AssumptionAnalysis ACT;
  AssumptionCache &&AC = ACT.run(*F, FAM);
  OptimizationRemarkEmitter ORE{F};

  Loop *L = LI.getLoopFor(CLI->getHeader());
  assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");

  // Query the same per-target preferences the LoopUnrollPass consults.
  TargetTransformInfo::UnrollingPreferences UP =
      gatherUnrollingPreferences(L, SE, TTI,
                                 /*BlockFrequencyInfo=*/nullptr,
                                 /*ProfileSummaryInfo=*/nullptr, ORE, OptLevel,
                                 /*UserThreshold=*/None,
                                 /*UserCount=*/None,
                                 /*UserAllowPartial=*/true,
                                 /*UserAllowRuntime=*/true,
                                 /*UserUpperBound=*/None,
                                 /*UserFullUnrollMaxCount=*/None);

  UP.Force = true;

  // Account for additional optimizations taking place before the LoopUnrollPass
  // would unroll the loop.
  UP.Threshold *= UnrollThresholdFactor;
  UP.PartialThreshold *= UnrollThresholdFactor;

  // Use normal unroll factors even if the rest of the code is optimized for
  // size.
  UP.OptSizeThreshold = UP.Threshold;
  UP.PartialOptSizeThreshold = UP.PartialThreshold;

  LLVM_DEBUG(dbgs() << "Unroll heuristic thresholds:\n"
                    << "  Threshold=" << UP.Threshold << "\n"
                    << "  PartialThreshold=" << UP.PartialThreshold << "\n"
                    << "  OptSizeThreshold=" << UP.OptSizeThreshold << "\n"
                    << "  PartialOptSizeThreshold="
                    << UP.PartialOptSizeThreshold << "\n");

  // Disable peeling.
  TargetTransformInfo::PeelingPreferences PP =
      gatherPeelingPreferences(L, SE, TTI,
                               /*UserAllowPeeling=*/false,
                               /*UserAllowProfileBasedPeeling=*/false,
                               /*UserUnrollingSpecficValues=*/false);

  SmallPtrSet<const Value *, 32> EphValues;
  CodeMetrics::collectEphemeralValues(L, &AC, EphValues);

  // Assume that reads and writes to stack variables can be eliminated by
  // Mem2Reg, SROA or LICM. That is, don't count them towards the loop body's
  // size.
  for (BasicBlock *BB : L->blocks()) {
    for (Instruction &I : *BB) {
      Value *Ptr;
      if (auto *Load = dyn_cast<LoadInst>(&I)) {
        Ptr = Load->getPointerOperand();
      } else if (auto *Store = dyn_cast<StoreInst>(&I)) {
        Ptr = Store->getPointerOperand();
      } else
        continue;

      Ptr = Ptr->stripPointerCasts();

      // Only allocas in the entry block are candidates for promotion.
      if (auto *Alloca = dyn_cast<AllocaInst>(Ptr)) {
        if (Alloca->getParent() == &F->getEntryBlock())
          EphValues.insert(&I);
      }
    }
  }

  unsigned NumInlineCandidates;
  bool NotDuplicatable;
  bool Convergent;
  unsigned LoopSize =
      ApproximateLoopSize(L, NumInlineCandidates, NotDuplicatable, Convergent,
                          TTI, EphValues, UP.BEInsns);
  LLVM_DEBUG(dbgs() << "Estimated loop size is " << LoopSize << "\n");

  // Loop is not unrollable if the loop contains certain instructions.
  if (NotDuplicatable || Convergent) {
    LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
    return 1;
  }

  // TODO: Determine trip count of \p CLI if constant, computeUnrollCount might
  // be able to use it.
  int TripCount = 0;
  int MaxTripCount = 0;
  bool MaxOrZero = false;
  unsigned TripMultiple = 0;

  bool UseUpperBound = false;
  // computeUnrollCount writes its decision into UP.Count.
  computeUnrollCount(L, TTI, DT, &LI, SE, EphValues, &ORE, TripCount,
                     MaxTripCount, MaxOrZero, TripMultiple, LoopSize, UP, PP,
                     UseUpperBound);
  unsigned Factor = UP.Count;
  LLVM_DEBUG(dbgs() << "Suggesting unroll factor of " << Factor << "\n");

  // This function returns 1 to signal to not unroll a loop.
  if (Factor == 0)
    return 1;
  return Factor;
}
2303 
/// Partially unroll \p Loop by \p Factor. A factor of 0 requests a heuristic
/// choice. If \p UnrolledCLI is non-null, the unrolled loop must remain
/// addressable for further loop-associated directives, so the unrolling is
/// materialized via tiling instead of metadata alone, and *UnrolledCLI
/// receives the resulting outer loop.
void OpenMPIRBuilder::unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop,
                                        int32_t Factor,
                                        CanonicalLoopInfo **UnrolledCLI) {
  assert(Factor >= 0 && "Unroll factor must not be negative");

  Function *F = Loop->getFunction();
  LLVMContext &Ctx = F->getContext();

  // If the unrolled loop is not used for another loop-associated directive, it
  // is sufficient to add metadata for the LoopUnrollPass.
  if (!UnrolledCLI) {
    SmallVector<Metadata *, 2> LoopMetadata;
    LoopMetadata.push_back(
        MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")));

    // Factor == 0 leaves the count to the LoopUnrollPass heuristic, so no
    // explicit llvm.loop.unroll.count property is emitted in that case.
    if (Factor >= 1) {
      ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
          ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
      LoopMetadata.push_back(MDNode::get(
          Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst}));
    }

    addLoopMetadata(Loop, LoopMetadata);
    return;
  }

  // Heuristically determine the unroll factor.
  if (Factor == 0)
    Factor = computeHeuristicUnrollFactor(Loop);

  // No change required with unroll factor 1.
  if (Factor == 1) {
    *UnrolledCLI = Loop;
    return;
  }

  assert(Factor >= 2 &&
         "unrolling only makes sense with a factor of 2 or larger");

  Type *IndVarTy = Loop->getIndVarType();

  // Apply partial unrolling by tiling the loop by the unroll-factor, then fully
  // unroll the inner loop.
  Value *FactorVal =
      ConstantInt::get(IndVarTy, APInt(IndVarTy->getIntegerBitWidth(), Factor,
                                       /*isSigned=*/false));
  std::vector<CanonicalLoopInfo *> LoopNest =
      tileLoops(DL, {Loop}, {FactorVal});
  assert(LoopNest.size() == 2 && "Expect 2 loops after tiling");
  // The outer (floor) loop is what later directives apply to; the inner (tile)
  // loop is the one that gets fully unrolled below.
  *UnrolledCLI = LoopNest[0];
  CanonicalLoopInfo *InnerLoop = LoopNest[1];

  // LoopUnrollPass can only fully unroll loops with constant trip count.
  // Unroll by the unroll factor with a fallback epilog for the remainder
  // iterations if necessary.
  ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
      ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
  addLoopMetadata(
      InnerLoop,
      {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
       MDNode::get(
           Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst})});

#ifndef NDEBUG
  (*UnrolledCLI)->assertOK();
#endif
}
2371 
2372 OpenMPIRBuilder::InsertPointTy
createCopyPrivate(const LocationDescription & Loc,llvm::Value * BufSize,llvm::Value * CpyBuf,llvm::Value * CpyFn,llvm::Value * DidIt)2373 OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
2374                                    llvm::Value *BufSize, llvm::Value *CpyBuf,
2375                                    llvm::Value *CpyFn, llvm::Value *DidIt) {
2376   if (!updateToLocation(Loc))
2377     return Loc.IP;
2378 
2379   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
2380   Value *Ident = getOrCreateIdent(SrcLocStr);
2381   Value *ThreadId = getOrCreateThreadID(Ident);
2382 
2383   llvm::Value *DidItLD = Builder.CreateLoad(Builder.getInt32Ty(), DidIt);
2384 
2385   Value *Args[] = {Ident, ThreadId, BufSize, CpyBuf, CpyFn, DidItLD};
2386 
2387   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_copyprivate);
2388   Builder.CreateCall(Fn, Args);
2389 
2390   return Builder.saveIP();
2391 }
2392 
2393 OpenMPIRBuilder::InsertPointTy
createSingle(const LocationDescription & Loc,BodyGenCallbackTy BodyGenCB,FinalizeCallbackTy FiniCB,llvm::Value * DidIt)2394 OpenMPIRBuilder::createSingle(const LocationDescription &Loc,
2395                               BodyGenCallbackTy BodyGenCB,
2396                               FinalizeCallbackTy FiniCB, llvm::Value *DidIt) {
2397 
2398   if (!updateToLocation(Loc))
2399     return Loc.IP;
2400 
2401   // If needed (i.e. not null), initialize `DidIt` with 0
2402   if (DidIt) {
2403     Builder.CreateStore(Builder.getInt32(0), DidIt);
2404   }
2405 
2406   Directive OMPD = Directive::OMPD_single;
2407   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
2408   Value *Ident = getOrCreateIdent(SrcLocStr);
2409   Value *ThreadId = getOrCreateThreadID(Ident);
2410   Value *Args[] = {Ident, ThreadId};
2411 
2412   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_single);
2413   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
2414 
2415   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_single);
2416   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
2417 
2418   // generates the following:
2419   // if (__kmpc_single()) {
2420   //		.... single region ...
2421   // 		__kmpc_end_single
2422   // }
2423 
2424   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
2425                               /*Conditional*/ true, /*hasFinalize*/ true);
2426 }
2427 
createCritical(const LocationDescription & Loc,BodyGenCallbackTy BodyGenCB,FinalizeCallbackTy FiniCB,StringRef CriticalName,Value * HintInst)2428 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCritical(
2429     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
2430     FinalizeCallbackTy FiniCB, StringRef CriticalName, Value *HintInst) {
2431 
2432   if (!updateToLocation(Loc))
2433     return Loc.IP;
2434 
2435   Directive OMPD = Directive::OMPD_critical;
2436   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
2437   Value *Ident = getOrCreateIdent(SrcLocStr);
2438   Value *ThreadId = getOrCreateThreadID(Ident);
2439   Value *LockVar = getOMPCriticalRegionLock(CriticalName);
2440   Value *Args[] = {Ident, ThreadId, LockVar};
2441 
2442   SmallVector<llvm::Value *, 4> EnterArgs(std::begin(Args), std::end(Args));
2443   Function *RTFn = nullptr;
2444   if (HintInst) {
2445     // Add Hint to entry Args and create call
2446     EnterArgs.push_back(HintInst);
2447     RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical_with_hint);
2448   } else {
2449     RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical);
2450   }
2451   Instruction *EntryCall = Builder.CreateCall(RTFn, EnterArgs);
2452 
2453   Function *ExitRTLFn =
2454       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_critical);
2455   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
2456 
2457   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
2458                               /*Conditional*/ false, /*hasFinalize*/ true);
2459 }
2460 
2461 OpenMPIRBuilder::InsertPointTy
createOrderedDepend(const LocationDescription & Loc,InsertPointTy AllocaIP,unsigned NumLoops,ArrayRef<llvm::Value * > StoreValues,const Twine & Name,bool IsDependSource)2462 OpenMPIRBuilder::createOrderedDepend(const LocationDescription &Loc,
2463                                      InsertPointTy AllocaIP, unsigned NumLoops,
2464                                      ArrayRef<llvm::Value *> StoreValues,
2465                                      const Twine &Name, bool IsDependSource) {
2466   if (!updateToLocation(Loc))
2467     return Loc.IP;
2468 
2469   // Allocate space for vector and generate alloc instruction.
2470   auto *ArrI64Ty = ArrayType::get(Int64, NumLoops);
2471   Builder.restoreIP(AllocaIP);
2472   AllocaInst *ArgsBase = Builder.CreateAlloca(ArrI64Ty, nullptr, Name);
2473   ArgsBase->setAlignment(Align(8));
2474   Builder.restoreIP(Loc.IP);
2475 
2476   // Store the index value with offset in depend vector.
2477   for (unsigned I = 0; I < NumLoops; ++I) {
2478     Value *DependAddrGEPIter = Builder.CreateInBoundsGEP(
2479         ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(I)});
2480     Builder.CreateStore(StoreValues[I], DependAddrGEPIter);
2481   }
2482 
2483   Value *DependBaseAddrGEP = Builder.CreateInBoundsGEP(
2484       ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(0)});
2485 
2486   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
2487   Value *Ident = getOrCreateIdent(SrcLocStr);
2488   Value *ThreadId = getOrCreateThreadID(Ident);
2489   Value *Args[] = {Ident, ThreadId, DependBaseAddrGEP};
2490 
2491   Function *RTLFn = nullptr;
2492   if (IsDependSource)
2493     RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_post);
2494   else
2495     RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_wait);
2496   Builder.CreateCall(RTLFn, Args);
2497 
2498   return Builder.saveIP();
2499 }
2500 
createOrderedThreadsSimd(const LocationDescription & Loc,BodyGenCallbackTy BodyGenCB,FinalizeCallbackTy FiniCB,bool IsThreads)2501 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createOrderedThreadsSimd(
2502     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
2503     FinalizeCallbackTy FiniCB, bool IsThreads) {
2504   if (!updateToLocation(Loc))
2505     return Loc.IP;
2506 
2507   Directive OMPD = Directive::OMPD_ordered;
2508   Instruction *EntryCall = nullptr;
2509   Instruction *ExitCall = nullptr;
2510 
2511   if (IsThreads) {
2512     Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
2513     Value *Ident = getOrCreateIdent(SrcLocStr);
2514     Value *ThreadId = getOrCreateThreadID(Ident);
2515     Value *Args[] = {Ident, ThreadId};
2516 
2517     Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_ordered);
2518     EntryCall = Builder.CreateCall(EntryRTLFn, Args);
2519 
2520     Function *ExitRTLFn =
2521         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_ordered);
2522     ExitCall = Builder.CreateCall(ExitRTLFn, Args);
2523   }
2524 
2525   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
2526                               /*Conditional*/ false, /*hasFinalize*/ true);
2527 }
2528 
/// Emit an inlined directive region: split the current block into
/// entry/finalize/exit blocks, run \p BodyGenCB to fill in the body, then wire
/// up the (optionally conditional) entry call, exit call and finalization.
/// Returns the insertion point after the region (cleared if the region turned
/// out to be unreachable and non-conditional).
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::EmitOMPInlinedRegion(
    Directive OMPD, Instruction *EntryCall, Instruction *ExitCall,
    BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, bool Conditional,
    bool HasFinalize, bool IsCancellable) {

  if (HasFinalize)
    FinalizationStack.push_back({FiniCB, OMPD, IsCancellable});

  // Create inlined region's entry and body blocks, in preparation
  // for conditional creation
  BasicBlock *EntryBB = Builder.GetInsertBlock();
  Instruction *SplitPos = EntryBB->getTerminator();
  // If the entry block is not terminated by a branch, insert a temporary
  // unreachable as the split position so the two splits below are well-formed.
  if (!isa_and_nonnull<BranchInst>(SplitPos))
    SplitPos = new UnreachableInst(Builder.getContext(), EntryBB);
  BasicBlock *ExitBB = EntryBB->splitBasicBlock(SplitPos, "omp_region.end");
  BasicBlock *FiniBB =
      EntryBB->splitBasicBlock(EntryBB->getTerminator(), "omp_region.finalize");

  Builder.SetInsertPoint(EntryBB->getTerminator());
  // May turn the entry call into a conditional branch guarding the body.
  emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, Conditional);

  // generate body
  BodyGenCB(/* AllocaIP */ InsertPointTy(),
            /* CodeGenIP */ Builder.saveIP(), *FiniBB);

  // If we didn't emit a branch to FiniBB during body generation, it means
  // FiniBB is unreachable (e.g. while(1);). stop generating all the
  // unreachable blocks, and remove anything we are not going to use.
  auto SkipEmittingRegion = FiniBB->hasNPredecessors(0);
  if (SkipEmittingRegion) {
    FiniBB->eraseFromParent();
    ExitCall->eraseFromParent();
    // Discard finalization if we have it.
    if (HasFinalize) {
      assert(!FinalizationStack.empty() &&
             "Unexpected finalization stack state!");
      FinalizationStack.pop_back();
    }
  } else {
    // emit exit call and do any needed finalization.
    auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt());
    assert(FiniBB->getTerminator()->getNumSuccessors() == 1 &&
           FiniBB->getTerminator()->getSuccessor(0) == ExitBB &&
           "Unexpected control flow graph state!!");
    emitCommonDirectiveExit(OMPD, FinIP, ExitCall, HasFinalize);
    assert(FiniBB->getUniquePredecessor()->getUniqueSuccessor() == FiniBB &&
           "Unexpected Control Flow State!");
    // Fold the finalize block back into its predecessor to keep the CFG tidy.
    MergeBlockIntoPredecessor(FiniBB);
  }

  // If we are skipping the region of a non conditional, remove the exit
  // block, and clear the builder's insertion point.
  assert(SplitPos->getParent() == ExitBB &&
         "Unexpected Insertion point location!");
  if (!Conditional && SkipEmittingRegion) {
    ExitBB->eraseFromParent();
    Builder.ClearInsertionPoint();
  } else {
    auto merged = MergeBlockIntoPredecessor(ExitBB);
    BasicBlock *ExitPredBB = SplitPos->getParent();
    auto InsertBB = merged ? ExitPredBB : ExitBB;
    // Drop the temporary unreachable we inserted at the top, if any.
    if (!isa_and_nonnull<BranchInst>(SplitPos))
      SplitPos->eraseFromParent();
    Builder.SetInsertPoint(InsertBB);
  }

  return Builder.saveIP();
}
2597 
/// Emit the entry of a conditional directive region: turn the entry call's
/// result into an i1 and branch to the body block when it is non-zero, to
/// \p ExitBB otherwise. For unconditional regions (or when there is no entry
/// call) this is a no-op. Returns an insertion point into \p ExitBB.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry(
    Directive OMPD, Value *EntryCall, BasicBlock *ExitBB, bool Conditional) {
  // if nothing to do, Return current insertion point.
  if (!Conditional || !EntryCall)
    return Builder.saveIP();

  BasicBlock *EntryBB = Builder.GetInsertBlock();
  Value *CallBool = Builder.CreateIsNotNull(EntryCall);
  auto *ThenBB = BasicBlock::Create(M.getContext(), "omp_region.body");
  // Temporary terminator so ThenBB is well-formed while we re-wire branches;
  // erased below once the original terminator has been moved in.
  auto *UI = new UnreachableInst(Builder.getContext(), ThenBB);

  // Emit thenBB and set the Builder's insertion point there for
  // body generation next. Place the block after the current block.
  Function *CurFn = EntryBB->getParent();
  CurFn->getBasicBlockList().insertAfter(EntryBB->getIterator(), ThenBB);

  // Move Entry branch to end of ThenBB, and replace with conditional
  // branch (If-stmt)
  Instruction *EntryBBTI = EntryBB->getTerminator();
  Builder.CreateCondBr(CallBool, ThenBB, ExitBB);
  EntryBBTI->removeFromParent();
  Builder.SetInsertPoint(UI);
  Builder.Insert(EntryBBTI);
  UI->eraseFromParent();
  Builder.SetInsertPoint(ThenBB->getTerminator());

  // return an insertion point to ExitBB.
  return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt());
}
2627 
/// Emit the exit of a directive region: run the pending finalization callback
/// (if \p HasFinalize), then move \p ExitCall so it becomes the last
/// instruction before the finalization block's terminator. Returns an
/// insertion point at the relocated exit call, or the current IP when there is
/// no exit call.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveExit(
    omp::Directive OMPD, InsertPointTy FinIP, Instruction *ExitCall,
    bool HasFinalize) {

  Builder.restoreIP(FinIP);

  // If there is finalization to do, emit it before the exit call
  if (HasFinalize) {
    assert(!FinalizationStack.empty() &&
           "Unexpected finalization stack state!");

    // The top of the stack must belong to this directive; pop it so nested
    // regions see a consistent stack.
    FinalizationInfo Fi = FinalizationStack.pop_back_val();
    assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");

    Fi.FiniCB(FinIP);

    BasicBlock *FiniBB = FinIP.getBlock();
    Instruction *FiniBBTI = FiniBB->getTerminator();

    // set Builder IP for call creation
    Builder.SetInsertPoint(FiniBBTI);
  }

  if (!ExitCall)
    return Builder.saveIP();

  // place the Exitcall as last instruction before Finalization block terminator
  ExitCall->removeFromParent();
  Builder.Insert(ExitCall);

  return IRBuilder<>::InsertPoint(ExitCall->getParent(),
                                  ExitCall->getIterator());
}
2661 
/// Build the control flow for a `copyin` clause: compare the master and
/// private addresses and branch to a "not master" block where the caller emits
/// the actual copies. Returns an insertion point inside that block (after its
/// branch to the end block when \p BranchtoEnd is set; otherwise the block is
/// deliberately left unterminated for the caller to finish).
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCopyinClauseBlocks(
    InsertPointTy IP, Value *MasterAddr, Value *PrivateAddr,
    llvm::IntegerType *IntPtrTy, bool BranchtoEnd) {
  if (!IP.isSet())
    return IP;

  // Restore the caller's insertion point on scope exit.
  IRBuilder<>::InsertPointGuard IPG(Builder);

  // creates the following CFG structure
  //   OMP_Entry : (MasterAddr != PrivateAddr)?
  //       F     T
  //       |      \
  //       |     copyin.not.master
  //       |      /
  //       v     /
  //   copyin.not.master.end
  //       |
  //       v
  //   OMP.Entry.Next

  BasicBlock *OMP_Entry = IP.getBlock();
  Function *CurFn = OMP_Entry->getParent();
  BasicBlock *CopyBegin =
      BasicBlock::Create(M.getContext(), "copyin.not.master", CurFn);
  BasicBlock *CopyEnd = nullptr;

  // If entry block is terminated, split to preserve the branch to following
  // basic block (i.e. OMP.Entry.Next), otherwise, leave everything as is.
  if (isa_and_nonnull<BranchInst>(OMP_Entry->getTerminator())) {
    CopyEnd = OMP_Entry->splitBasicBlock(OMP_Entry->getTerminator(),
                                         "copyin.not.master.end");
    // Drop the unconditional branch created by the split; a conditional
    // branch replaces it below.
    OMP_Entry->getTerminator()->eraseFromParent();
  } else {
    CopyEnd =
        BasicBlock::Create(M.getContext(), "copyin.not.master.end", CurFn);
  }

  // Compare the two addresses as integers; only copy when they differ.
  Builder.SetInsertPoint(OMP_Entry);
  Value *MasterPtr = Builder.CreatePtrToInt(MasterAddr, IntPtrTy);
  Value *PrivatePtr = Builder.CreatePtrToInt(PrivateAddr, IntPtrTy);
  Value *cmp = Builder.CreateICmpNE(MasterPtr, PrivatePtr);
  Builder.CreateCondBr(cmp, CopyBegin, CopyEnd);

  Builder.SetInsertPoint(CopyBegin);
  if (BranchtoEnd)
    Builder.SetInsertPoint(Builder.CreateBr(CopyEnd));

  return Builder.saveIP();
}
2711 
createOMPAlloc(const LocationDescription & Loc,Value * Size,Value * Allocator,std::string Name)2712 CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc,
2713                                           Value *Size, Value *Allocator,
2714                                           std::string Name) {
2715   IRBuilder<>::InsertPointGuard IPG(Builder);
2716   Builder.restoreIP(Loc.IP);
2717 
2718   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
2719   Value *Ident = getOrCreateIdent(SrcLocStr);
2720   Value *ThreadId = getOrCreateThreadID(Ident);
2721   Value *Args[] = {ThreadId, Size, Allocator};
2722 
2723   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc);
2724 
2725   return Builder.CreateCall(Fn, Args, Name);
2726 }
2727 
createOMPFree(const LocationDescription & Loc,Value * Addr,Value * Allocator,std::string Name)2728 CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
2729                                          Value *Addr, Value *Allocator,
2730                                          std::string Name) {
2731   IRBuilder<>::InsertPointGuard IPG(Builder);
2732   Builder.restoreIP(Loc.IP);
2733 
2734   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
2735   Value *Ident = getOrCreateIdent(SrcLocStr);
2736   Value *ThreadId = getOrCreateThreadID(Ident);
2737   Value *Args[] = {ThreadId, Addr, Allocator};
2738   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free);
2739   return Builder.CreateCall(Fn, Args, Name);
2740 }
2741 
createCachedThreadPrivate(const LocationDescription & Loc,llvm::Value * Pointer,llvm::ConstantInt * Size,const llvm::Twine & Name)2742 CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
2743     const LocationDescription &Loc, llvm::Value *Pointer,
2744     llvm::ConstantInt *Size, const llvm::Twine &Name) {
2745   IRBuilder<>::InsertPointGuard IPG(Builder);
2746   Builder.restoreIP(Loc.IP);
2747 
2748   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
2749   Value *Ident = getOrCreateIdent(SrcLocStr);
2750   Value *ThreadId = getOrCreateThreadID(Ident);
2751   Constant *ThreadPrivateCache =
2752       getOrCreateOMPInternalVariable(Int8PtrPtr, Name);
2753   llvm::Value *Args[] = {Ident, ThreadId, Pointer, Size, ThreadPrivateCache};
2754 
2755   Function *Fn =
2756       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_threadprivate_cached);
2757 
2758   return Builder.CreateCall(Fn, Args);
2759 }
2760 
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
                                  bool RequiresFullRuntime) {
  // Emits the device entry sequence for a target region: call
  // __kmpc_target_init and route threads that should run the user code into
  // a fresh "user_code.entry" block while all other threads return.
  if (!updateToLocation(Loc))
    return Loc.IP;

  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
  Value *Ident = getOrCreateIdent(SrcLocStr);
  // i8 flag selecting SPMD vs. generic execution mode.
  ConstantInt *IsSPMDVal = ConstantInt::getSigned(
      IntegerType::getInt8Ty(Int8->getContext()),
      IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
  // Generic mode uses the runtime-provided state machine; SPMD does not.
  ConstantInt *UseGenericStateMachine =
      ConstantInt::getBool(Int32->getContext(), !IsSPMD);
  ConstantInt *RequiresFullRuntimeVal =
      ConstantInt::getBool(Int32->getContext(), RequiresFullRuntime);

  Function *Fn = getOrCreateRuntimeFunctionPtr(
      omp::RuntimeFunction::OMPRTL___kmpc_target_init);

  CallInst *ThreadKind = Builder.CreateCall(
      Fn, {Ident, IsSPMDVal, UseGenericStateMachine, RequiresFullRuntimeVal});

  // __kmpc_target_init returns -1 for threads that should execute the user
  // code; everything else (e.g. generic-mode workers) must exit.
  Value *ExecUserCode = Builder.CreateICmpEQ(
      ThreadKind, ConstantInt::get(ThreadKind->getType(), -1),
      "exec_user_code");

  // ThreadKind = __kmpc_target_init(...)
  // if (ThreadKind == -1)
  //   user_code
  // else
  //   return;

  // Placeholder terminator so the current block can be split; it is erased
  // once the real conditional branch is in place.
  auto *UI = Builder.CreateUnreachable();
  BasicBlock *CheckBB = UI->getParent();
  BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(UI, "user_code.entry");

  // Non-user-code threads simply return from the kernel.
  BasicBlock *WorkerExitBB = BasicBlock::Create(
      CheckBB->getContext(), "worker.exit", CheckBB->getParent());
  Builder.SetInsertPoint(WorkerExitBB);
  Builder.CreateRetVoid();

  // Replace the unconditional branch created by splitBasicBlock with the
  // conditional dispatch between user code and worker exit.
  auto *CheckBBTI = CheckBB->getTerminator();
  Builder.SetInsertPoint(CheckBBTI);
  Builder.CreateCondBr(ExecUserCode, UI->getParent(), WorkerExitBB);

  CheckBBTI->eraseFromParent();
  UI->eraseFromParent();

  // Continue in the "user_code" block, see diagram above and in
  // openmp/libomptarget/deviceRTLs/common/include/target.h .
  return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
}
2813 
createTargetDeinit(const LocationDescription & Loc,bool IsSPMD,bool RequiresFullRuntime)2814 void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
2815                                          bool IsSPMD,
2816                                          bool RequiresFullRuntime) {
2817   if (!updateToLocation(Loc))
2818     return;
2819 
2820   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
2821   Value *Ident = getOrCreateIdent(SrcLocStr);
2822   ConstantInt *IsSPMDVal = ConstantInt::getSigned(
2823       IntegerType::getInt8Ty(Int8->getContext()),
2824       IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
2825   ConstantInt *RequiresFullRuntimeVal =
2826       ConstantInt::getBool(Int32->getContext(), RequiresFullRuntime);
2827 
2828   Function *Fn = getOrCreateRuntimeFunctionPtr(
2829       omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);
2830 
2831   Builder.CreateCall(Fn, {Ident, IsSPMDVal, RequiresFullRuntimeVal});
2832 }
2833 
getNameWithSeparators(ArrayRef<StringRef> Parts,StringRef FirstSeparator,StringRef Separator)2834 std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
2835                                                    StringRef FirstSeparator,
2836                                                    StringRef Separator) {
2837   SmallString<128> Buffer;
2838   llvm::raw_svector_ostream OS(Buffer);
2839   StringRef Sep = FirstSeparator;
2840   for (StringRef Part : Parts) {
2841     OS << Sep << Part;
2842     Sep = Separator;
2843   }
2844   return OS.str().str();
2845 }
2846 
getOrCreateOMPInternalVariable(llvm::Type * Ty,const llvm::Twine & Name,unsigned AddressSpace)2847 Constant *OpenMPIRBuilder::getOrCreateOMPInternalVariable(
2848     llvm::Type *Ty, const llvm::Twine &Name, unsigned AddressSpace) {
2849   // TODO: Replace the twine arg with stringref to get rid of the conversion
2850   // logic. However This is taken from current implementation in clang as is.
2851   // Since this method is used in many places exclusively for OMP internal use
2852   // we will keep it as is for temporarily until we move all users to the
2853   // builder and then, if possible, fix it everywhere in one go.
2854   SmallString<256> Buffer;
2855   llvm::raw_svector_ostream Out(Buffer);
2856   Out << Name;
2857   StringRef RuntimeName = Out.str();
2858   auto &Elem = *InternalVars.try_emplace(RuntimeName, nullptr).first;
2859   if (Elem.second) {
2860     assert(Elem.second->getType()->getPointerElementType() == Ty &&
2861            "OMP internal variable has different type than requested");
2862   } else {
2863     // TODO: investigate the appropriate linkage type used for the global
2864     // variable for possibly changing that to internal or private, or maybe
2865     // create different versions of the function for different OMP internal
2866     // variables.
2867     Elem.second = new llvm::GlobalVariable(
2868         M, Ty, /*IsConstant*/ false, llvm::GlobalValue::CommonLinkage,
2869         llvm::Constant::getNullValue(Ty), Elem.first(),
2870         /*InsertBefore=*/nullptr, llvm::GlobalValue::NotThreadLocal,
2871         AddressSpace);
2872   }
2873 
2874   return Elem.second;
2875 }
2876 
getOMPCriticalRegionLock(StringRef CriticalName)2877 Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) {
2878   std::string Prefix = Twine("gomp_critical_user_", CriticalName).str();
2879   std::string Name = getNameWithSeparators({Prefix, "var"}, ".", ".");
2880   return getOrCreateOMPInternalVariable(KmpCriticalNameTy, Name);
2881 }
2882 
2883 GlobalVariable *
createOffloadMaptypes(SmallVectorImpl<uint64_t> & Mappings,std::string VarName)2884 OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
2885                                        std::string VarName) {
2886   llvm::Constant *MaptypesArrayInit =
2887       llvm::ConstantDataArray::get(M.getContext(), Mappings);
2888   auto *MaptypesArrayGlobal = new llvm::GlobalVariable(
2889       M, MaptypesArrayInit->getType(),
2890       /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MaptypesArrayInit,
2891       VarName);
2892   MaptypesArrayGlobal->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
2893   return MaptypesArrayGlobal;
2894 }
2895 
createMapperAllocas(const LocationDescription & Loc,InsertPointTy AllocaIP,unsigned NumOperands,struct MapperAllocas & MapperAllocas)2896 void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc,
2897                                           InsertPointTy AllocaIP,
2898                                           unsigned NumOperands,
2899                                           struct MapperAllocas &MapperAllocas) {
2900   if (!updateToLocation(Loc))
2901     return;
2902 
2903   auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
2904   auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
2905   Builder.restoreIP(AllocaIP);
2906   AllocaInst *ArgsBase = Builder.CreateAlloca(ArrI8PtrTy);
2907   AllocaInst *Args = Builder.CreateAlloca(ArrI8PtrTy);
2908   AllocaInst *ArgSizes = Builder.CreateAlloca(ArrI64Ty);
2909   Builder.restoreIP(Loc.IP);
2910   MapperAllocas.ArgsBase = ArgsBase;
2911   MapperAllocas.Args = Args;
2912   MapperAllocas.ArgSizes = ArgSizes;
2913 }
2914 
emitMapperCall(const LocationDescription & Loc,Function * MapperFunc,Value * SrcLocInfo,Value * MaptypesArg,Value * MapnamesArg,struct MapperAllocas & MapperAllocas,int64_t DeviceID,unsigned NumOperands)2915 void OpenMPIRBuilder::emitMapperCall(const LocationDescription &Loc,
2916                                      Function *MapperFunc, Value *SrcLocInfo,
2917                                      Value *MaptypesArg, Value *MapnamesArg,
2918                                      struct MapperAllocas &MapperAllocas,
2919                                      int64_t DeviceID, unsigned NumOperands) {
2920   if (!updateToLocation(Loc))
2921     return;
2922 
2923   auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
2924   auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
2925   Value *ArgsBaseGEP =
2926       Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.ArgsBase,
2927                                 {Builder.getInt32(0), Builder.getInt32(0)});
2928   Value *ArgsGEP =
2929       Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.Args,
2930                                 {Builder.getInt32(0), Builder.getInt32(0)});
2931   Value *ArgSizesGEP =
2932       Builder.CreateInBoundsGEP(ArrI64Ty, MapperAllocas.ArgSizes,
2933                                 {Builder.getInt32(0), Builder.getInt32(0)});
2934   Value *NullPtr = Constant::getNullValue(Int8Ptr->getPointerTo());
2935   Builder.CreateCall(MapperFunc,
2936                      {SrcLocInfo, Builder.getInt64(DeviceID),
2937                       Builder.getInt32(NumOperands), ArgsBaseGEP, ArgsGEP,
2938                       ArgSizesGEP, MaptypesArg, MapnamesArg, NullPtr});
2939 }
2940 
checkAndEmitFlushAfterAtomic(const LocationDescription & Loc,llvm::AtomicOrdering AO,AtomicKind AK)2941 bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
2942     const LocationDescription &Loc, llvm::AtomicOrdering AO, AtomicKind AK) {
2943   assert(!(AO == AtomicOrdering::NotAtomic ||
2944            AO == llvm::AtomicOrdering::Unordered) &&
2945          "Unexpected Atomic Ordering.");
2946 
2947   bool Flush = false;
2948   llvm::AtomicOrdering FlushAO = AtomicOrdering::Monotonic;
2949 
2950   switch (AK) {
2951   case Read:
2952     if (AO == AtomicOrdering::Acquire || AO == AtomicOrdering::AcquireRelease ||
2953         AO == AtomicOrdering::SequentiallyConsistent) {
2954       FlushAO = AtomicOrdering::Acquire;
2955       Flush = true;
2956     }
2957     break;
2958   case Write:
2959   case Update:
2960     if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
2961         AO == AtomicOrdering::SequentiallyConsistent) {
2962       FlushAO = AtomicOrdering::Release;
2963       Flush = true;
2964     }
2965     break;
2966   case Capture:
2967     switch (AO) {
2968     case AtomicOrdering::Acquire:
2969       FlushAO = AtomicOrdering::Acquire;
2970       Flush = true;
2971       break;
2972     case AtomicOrdering::Release:
2973       FlushAO = AtomicOrdering::Release;
2974       Flush = true;
2975       break;
2976     case AtomicOrdering::AcquireRelease:
2977     case AtomicOrdering::SequentiallyConsistent:
2978       FlushAO = AtomicOrdering::AcquireRelease;
2979       Flush = true;
2980       break;
2981     default:
2982       // do nothing - leave silently.
2983       break;
2984     }
2985   }
2986 
2987   if (Flush) {
2988     // Currently Flush RT call still doesn't take memory_ordering, so for when
2989     // that happens, this tries to do the resolution of which atomic ordering
2990     // to use with but issue the flush call
2991     // TODO: pass `FlushAO` after memory ordering support is added
2992     (void)FlushAO;
2993     emitFlush(Loc);
2994   }
2995 
2996   // for AO == AtomicOrdering::Monotonic and  all other case combinations
2997   // do nothing
2998   return Flush;
2999 }
3000 
3001 OpenMPIRBuilder::InsertPointTy
createAtomicRead(const LocationDescription & Loc,AtomicOpValue & X,AtomicOpValue & V,AtomicOrdering AO)3002 OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
3003                                   AtomicOpValue &X, AtomicOpValue &V,
3004                                   AtomicOrdering AO) {
3005   if (!updateToLocation(Loc))
3006     return Loc.IP;
3007 
3008   Type *XTy = X.Var->getType();
3009   assert(XTy->isPointerTy() && "OMP Atomic expects a pointer to target memory");
3010   Type *XElemTy = XTy->getPointerElementType();
3011   assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
3012           XElemTy->isPointerTy()) &&
3013          "OMP atomic read expected a scalar type");
3014 
3015   Value *XRead = nullptr;
3016 
3017   if (XElemTy->isIntegerTy()) {
3018     LoadInst *XLD =
3019         Builder.CreateLoad(XElemTy, X.Var, X.IsVolatile, "omp.atomic.read");
3020     XLD->setAtomic(AO);
3021     XRead = cast<Value>(XLD);
3022   } else {
3023     // We need to bitcast and perform atomic op as integer
3024     unsigned Addrspace = cast<PointerType>(XTy)->getAddressSpace();
3025     IntegerType *IntCastTy =
3026         IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
3027     Value *XBCast = Builder.CreateBitCast(
3028         X.Var, IntCastTy->getPointerTo(Addrspace), "atomic.src.int.cast");
3029     LoadInst *XLoad =
3030         Builder.CreateLoad(IntCastTy, XBCast, X.IsVolatile, "omp.atomic.load");
3031     XLoad->setAtomic(AO);
3032     if (XElemTy->isFloatingPointTy()) {
3033       XRead = Builder.CreateBitCast(XLoad, XElemTy, "atomic.flt.cast");
3034     } else {
3035       XRead = Builder.CreateIntToPtr(XLoad, XElemTy, "atomic.ptr.cast");
3036     }
3037   }
3038   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Read);
3039   Builder.CreateStore(XRead, V.Var, V.IsVolatile);
3040   return Builder.saveIP();
3041 }
3042 
3043 OpenMPIRBuilder::InsertPointTy
createAtomicWrite(const LocationDescription & Loc,AtomicOpValue & X,Value * Expr,AtomicOrdering AO)3044 OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
3045                                    AtomicOpValue &X, Value *Expr,
3046                                    AtomicOrdering AO) {
3047   if (!updateToLocation(Loc))
3048     return Loc.IP;
3049 
3050   Type *XTy = X.Var->getType();
3051   assert(XTy->isPointerTy() && "OMP Atomic expects a pointer to target memory");
3052   Type *XElemTy = XTy->getPointerElementType();
3053   assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
3054           XElemTy->isPointerTy()) &&
3055          "OMP atomic write expected a scalar type");
3056 
3057   if (XElemTy->isIntegerTy()) {
3058     StoreInst *XSt = Builder.CreateStore(Expr, X.Var, X.IsVolatile);
3059     XSt->setAtomic(AO);
3060   } else {
3061     // We need to bitcast and perform atomic op as integers
3062     unsigned Addrspace = cast<PointerType>(XTy)->getAddressSpace();
3063     IntegerType *IntCastTy =
3064         IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
3065     Value *XBCast = Builder.CreateBitCast(
3066         X.Var, IntCastTy->getPointerTo(Addrspace), "atomic.dst.int.cast");
3067     Value *ExprCast =
3068         Builder.CreateBitCast(Expr, IntCastTy, "atomic.src.int.cast");
3069     StoreInst *XSt = Builder.CreateStore(ExprCast, XBCast, X.IsVolatile);
3070     XSt->setAtomic(AO);
3071   }
3072 
3073   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Write);
3074   return Builder.saveIP();
3075 }
3076 
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicUpdate(
    const LocationDescription &Loc, Instruction *AllocIP, AtomicOpValue &X,
    Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
    AtomicUpdateCallbackTy &UpdateOp, bool IsXLHSInRHSPart) {
  // Emit `x = x binop expr` atomically, delegating the actual lowering
  // (atomicrmw or cmpxchg loop) to emitAtomicUpdate, then emit any flush
  // the ordering requires.
  if (!updateToLocation(Loc))
    return Loc.IP;

  // Debug-only validation of the operands; compiled out in release builds.
  LLVM_DEBUG({
    Type *XTy = X.Var->getType();
    assert(XTy->isPointerTy() &&
           "OMP Atomic expects a pointer to target memory");
    Type *XElemTy = XTy->getPointerElementType();
    assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
            XElemTy->isPointerTy()) &&
           "OMP atomic update expected a scalar type");
    assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
           (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
           "OpenMP atomic does not support LT or GT operations");
  });

  emitAtomicUpdate(AllocIP, X.Var, Expr, AO, RMWOp, UpdateOp, X.IsVolatile,
                   IsXLHSInRHSPart);
  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Update);
  return Builder.saveIP();
}
3102 
emitRMWOpAsInstruction(Value * Src1,Value * Src2,AtomicRMWInst::BinOp RMWOp)3103 Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
3104                                                AtomicRMWInst::BinOp RMWOp) {
3105   switch (RMWOp) {
3106   case AtomicRMWInst::Add:
3107     return Builder.CreateAdd(Src1, Src2);
3108   case AtomicRMWInst::Sub:
3109     return Builder.CreateSub(Src1, Src2);
3110   case AtomicRMWInst::And:
3111     return Builder.CreateAnd(Src1, Src2);
3112   case AtomicRMWInst::Nand:
3113     return Builder.CreateNeg(Builder.CreateAnd(Src1, Src2));
3114   case AtomicRMWInst::Or:
3115     return Builder.CreateOr(Src1, Src2);
3116   case AtomicRMWInst::Xor:
3117     return Builder.CreateXor(Src1, Src2);
3118   case AtomicRMWInst::Xchg:
3119   case AtomicRMWInst::FAdd:
3120   case AtomicRMWInst::FSub:
3121   case AtomicRMWInst::BAD_BINOP:
3122   case AtomicRMWInst::Max:
3123   case AtomicRMWInst::Min:
3124   case AtomicRMWInst::UMax:
3125   case AtomicRMWInst::UMin:
3126     llvm_unreachable("Unsupported atomic update operation");
3127   }
3128   llvm_unreachable("Unsupported atomic update operation");
3129 }
3130 
3131 std::pair<Value *, Value *>
emitAtomicUpdate(Instruction * AllocIP,Value * X,Value * Expr,AtomicOrdering AO,AtomicRMWInst::BinOp RMWOp,AtomicUpdateCallbackTy & UpdateOp,bool VolatileX,bool IsXLHSInRHSPart)3132 OpenMPIRBuilder::emitAtomicUpdate(Instruction *AllocIP, Value *X, Value *Expr,
3133                                   AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
3134                                   AtomicUpdateCallbackTy &UpdateOp,
3135                                   bool VolatileX, bool IsXLHSInRHSPart) {
3136   Type *XElemTy = X->getType()->getPointerElementType();
3137 
3138   bool DoCmpExch =
3139       ((RMWOp == AtomicRMWInst::BAD_BINOP) || (RMWOp == AtomicRMWInst::FAdd)) ||
3140       (RMWOp == AtomicRMWInst::FSub) ||
3141       (RMWOp == AtomicRMWInst::Sub && !IsXLHSInRHSPart);
3142 
3143   std::pair<Value *, Value *> Res;
3144   if (XElemTy->isIntegerTy() && !DoCmpExch) {
3145     Res.first = Builder.CreateAtomicRMW(RMWOp, X, Expr, llvm::MaybeAlign(), AO);
3146     // not needed except in case of postfix captures. Generate anyway for
3147     // consistency with the else part. Will be removed with any DCE pass.
3148     Res.second = emitRMWOpAsInstruction(Res.first, Expr, RMWOp);
3149   } else {
3150     unsigned Addrspace = cast<PointerType>(X->getType())->getAddressSpace();
3151     IntegerType *IntCastTy =
3152         IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
3153     Value *XBCast =
3154         Builder.CreateBitCast(X, IntCastTy->getPointerTo(Addrspace));
3155     LoadInst *OldVal =
3156         Builder.CreateLoad(IntCastTy, XBCast, X->getName() + ".atomic.load");
3157     OldVal->setAtomic(AO);
3158     // CurBB
3159     // |     /---\
3160 		// ContBB    |
3161     // |     \---/
3162     // ExitBB
3163     BasicBlock *CurBB = Builder.GetInsertBlock();
3164     Instruction *CurBBTI = CurBB->getTerminator();
3165     CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
3166     BasicBlock *ExitBB =
3167         CurBB->splitBasicBlock(CurBBTI, X->getName() + ".atomic.exit");
3168     BasicBlock *ContBB = CurBB->splitBasicBlock(CurBB->getTerminator(),
3169                                                 X->getName() + ".atomic.cont");
3170     ContBB->getTerminator()->eraseFromParent();
3171     Builder.SetInsertPoint(ContBB);
3172     llvm::PHINode *PHI = Builder.CreatePHI(OldVal->getType(), 2);
3173     PHI->addIncoming(OldVal, CurBB);
3174     AllocaInst *NewAtomicAddr = Builder.CreateAlloca(XElemTy);
3175     NewAtomicAddr->setName(X->getName() + "x.new.val");
3176     NewAtomicAddr->moveBefore(AllocIP);
3177     IntegerType *NewAtomicCastTy =
3178         IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
3179     bool IsIntTy = XElemTy->isIntegerTy();
3180     Value *NewAtomicIntAddr =
3181         (IsIntTy)
3182             ? NewAtomicAddr
3183             : Builder.CreateBitCast(NewAtomicAddr,
3184                                     NewAtomicCastTy->getPointerTo(Addrspace));
3185     Value *OldExprVal = PHI;
3186     if (!IsIntTy) {
3187       if (XElemTy->isFloatingPointTy()) {
3188         OldExprVal = Builder.CreateBitCast(PHI, XElemTy,
3189                                            X->getName() + ".atomic.fltCast");
3190       } else {
3191         OldExprVal = Builder.CreateIntToPtr(PHI, XElemTy,
3192                                             X->getName() + ".atomic.ptrCast");
3193       }
3194     }
3195 
3196     Value *Upd = UpdateOp(OldExprVal, Builder);
3197     Builder.CreateStore(Upd, NewAtomicAddr);
3198     LoadInst *DesiredVal = Builder.CreateLoad(XElemTy, NewAtomicIntAddr);
3199     Value *XAddr =
3200         (IsIntTy)
3201             ? X
3202             : Builder.CreateBitCast(X, IntCastTy->getPointerTo(Addrspace));
3203     AtomicOrdering Failure =
3204         llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
3205     AtomicCmpXchgInst *Result = Builder.CreateAtomicCmpXchg(
3206         XAddr, OldExprVal, DesiredVal, llvm::MaybeAlign(), AO, Failure);
3207     Result->setVolatile(VolatileX);
3208     Value *PreviousVal = Builder.CreateExtractValue(Result, /*Idxs=*/0);
3209     Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
3210     PHI->addIncoming(PreviousVal, Builder.GetInsertBlock());
3211     Builder.CreateCondBr(SuccessFailureVal, ExitBB, ContBB);
3212 
3213     Res.first = OldExprVal;
3214     Res.second = Upd;
3215 
3216     // set Insertion point in exit block
3217     if (UnreachableInst *ExitTI =
3218             dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
3219       CurBBTI->eraseFromParent();
3220       Builder.SetInsertPoint(ExitBB);
3221     } else {
3222       Builder.SetInsertPoint(ExitTI);
3223     }
3224   }
3225 
3226   return Res;
3227 }
3228 
createAtomicCapture(const LocationDescription & Loc,Instruction * AllocIP,AtomicOpValue & X,AtomicOpValue & V,Value * Expr,AtomicOrdering AO,AtomicRMWInst::BinOp RMWOp,AtomicUpdateCallbackTy & UpdateOp,bool UpdateExpr,bool IsPostfixUpdate,bool IsXLHSInRHSPart)3229 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCapture(
3230     const LocationDescription &Loc, Instruction *AllocIP, AtomicOpValue &X,
3231     AtomicOpValue &V, Value *Expr, AtomicOrdering AO,
3232     AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp,
3233     bool UpdateExpr, bool IsPostfixUpdate, bool IsXLHSInRHSPart) {
3234   if (!updateToLocation(Loc))
3235     return Loc.IP;
3236 
3237   LLVM_DEBUG({
3238     Type *XTy = X.Var->getType();
3239     assert(XTy->isPointerTy() &&
3240            "OMP Atomic expects a pointer to target memory");
3241     Type *XElemTy = XTy->getPointerElementType();
3242     assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
3243             XElemTy->isPointerTy()) &&
3244            "OMP atomic capture expected a scalar type");
3245     assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
3246            "OpenMP atomic does not support LT or GT operations");
3247   });
3248 
3249   // If UpdateExpr is 'x' updated with some `expr` not based on 'x',
3250   // 'x' is simply atomically rewritten with 'expr'.
3251   AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? RMWOp : AtomicRMWInst::Xchg);
3252   std::pair<Value *, Value *> Result =
3253       emitAtomicUpdate(AllocIP, X.Var, Expr, AO, AtomicOp, UpdateOp,
3254                        X.IsVolatile, IsXLHSInRHSPart);
3255 
3256   Value *CapturedVal = (IsPostfixUpdate ? Result.first : Result.second);
3257   Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile);
3258 
3259   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Capture);
3260   return Builder.saveIP();
3261 }
3262 
3263 GlobalVariable *
createOffloadMapnames(SmallVectorImpl<llvm::Constant * > & Names,std::string VarName)3264 OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
3265                                        std::string VarName) {
3266   llvm::Constant *MapNamesArrayInit = llvm::ConstantArray::get(
3267       llvm::ArrayType::get(
3268           llvm::Type::getInt8Ty(M.getContext())->getPointerTo(), Names.size()),
3269       Names);
3270   auto *MapNamesArrayGlobal = new llvm::GlobalVariable(
3271       M, MapNamesArrayInit->getType(),
3272       /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MapNamesArrayInit,
3273       VarName);
3274   return MapNamesArrayGlobal;
3275 }
3276 
// Create all simple and struct types exposed by the runtime and remember
// the llvm::PointerTypes of them for easy access later.
void OpenMPIRBuilder::initializeTypes(Module &M) {
  LLVMContext &Ctx = M.getContext();
  StructType *T;
  // Each OMP_*TYPE macro below initializes one member variable (and its
  // pointer type) per entry in OMPKinds.def, which is included at the end.
#define OMP_TYPE(VarName, InitValue) VarName = InitValue;
#define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize)                             \
  VarName##Ty = ArrayType::get(ElemTy, ArraySize);                             \
  VarName##PtrTy = PointerType::getUnqual(VarName##Ty);
#define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...)                  \
  VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg);            \
  VarName##Ptr = PointerType::getUnqual(VarName);
#define OMP_STRUCT_TYPE(VarName, StructName, ...)                              \
  T = StructType::getTypeByName(Ctx, StructName);                              \
  if (!T)                                                                      \
    T = StructType::create(Ctx, {__VA_ARGS__}, StructName);                    \
  VarName = T;                                                                 \
  VarName##Ptr = PointerType::getUnqual(T);
#include "llvm/Frontend/OpenMP/OMPKinds.def"
}
3297 
collectBlocks(SmallPtrSetImpl<BasicBlock * > & BlockSet,SmallVectorImpl<BasicBlock * > & BlockVector)3298 void OpenMPIRBuilder::OutlineInfo::collectBlocks(
3299     SmallPtrSetImpl<BasicBlock *> &BlockSet,
3300     SmallVectorImpl<BasicBlock *> &BlockVector) {
3301   SmallVector<BasicBlock *, 32> Worklist;
3302   BlockSet.insert(EntryBB);
3303   BlockSet.insert(ExitBB);
3304 
3305   Worklist.push_back(EntryBB);
3306   while (!Worklist.empty()) {
3307     BasicBlock *BB = Worklist.pop_back_val();
3308     BlockVector.push_back(BB);
3309     for (BasicBlock *SuccBB : successors(BB))
3310       if (BlockSet.insert(SuccBB).second)
3311         Worklist.push_back(SuccBB);
3312   }
3313 }
3314 
collectControlBlocks(SmallVectorImpl<BasicBlock * > & BBs)3315 void CanonicalLoopInfo::collectControlBlocks(
3316     SmallVectorImpl<BasicBlock *> &BBs) {
3317   // We only count those BBs as control block for which we do not need to
3318   // reverse the CFG, i.e. not the loop body which can contain arbitrary control
3319   // flow. For consistency, this also means we do not add the Body block, which
3320   // is just the entry to the body code.
3321   BBs.reserve(BBs.size() + 6);
3322   BBs.append({Preheader, Header, Cond, Latch, Exit, After});
3323 }
3324 
assertOK() const3325 void CanonicalLoopInfo::assertOK() const {
3326 #ifndef NDEBUG
3327   // No constraints if this object currently does not describe a loop.
3328   if (!isValid())
3329     return;
3330 
3331   // Verify standard control-flow we use for OpenMP loops.
3332   assert(Preheader);
3333   assert(isa<BranchInst>(Preheader->getTerminator()) &&
3334          "Preheader must terminate with unconditional branch");
3335   assert(Preheader->getSingleSuccessor() == Header &&
3336          "Preheader must jump to header");
3337 
3338   assert(Header);
3339   assert(isa<BranchInst>(Header->getTerminator()) &&
3340          "Header must terminate with unconditional branch");
3341   assert(Header->getSingleSuccessor() == Cond &&
3342          "Header must jump to exiting block");
3343 
3344   assert(Cond);
3345   assert(Cond->getSinglePredecessor() == Header &&
3346          "Exiting block only reachable from header");
3347 
3348   assert(isa<BranchInst>(Cond->getTerminator()) &&
3349          "Exiting block must terminate with conditional branch");
3350   assert(size(successors(Cond)) == 2 &&
3351          "Exiting block must have two successors");
3352   assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(0) == Body &&
3353          "Exiting block's first successor jump to the body");
3354   assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(1) == Exit &&
3355          "Exiting block's second successor must exit the loop");
3356 
3357   assert(Body);
3358   assert(Body->getSinglePredecessor() == Cond &&
3359          "Body only reachable from exiting block");
3360   assert(!isa<PHINode>(Body->front()));
3361 
3362   assert(Latch);
3363   assert(isa<BranchInst>(Latch->getTerminator()) &&
3364          "Latch must terminate with unconditional branch");
3365   assert(Latch->getSingleSuccessor() == Header && "Latch must jump to header");
3366   // TODO: To support simple redirecting of the end of the body code that has
3367   // multiple; introduce another auxiliary basic block like preheader and after.
3368   assert(Latch->getSinglePredecessor() != nullptr);
3369   assert(!isa<PHINode>(Latch->front()));
3370 
3371   assert(Exit);
3372   assert(isa<BranchInst>(Exit->getTerminator()) &&
3373          "Exit block must terminate with unconditional branch");
3374   assert(Exit->getSingleSuccessor() == After &&
3375          "Exit block must jump to after block");
3376 
3377   assert(After);
3378   assert(After->getSinglePredecessor() == Exit &&
3379          "After block only reachable from exit block");
3380   assert(After->empty() || !isa<PHINode>(After->front()));
3381 
3382   Instruction *IndVar = getIndVar();
3383   assert(IndVar && "Canonical induction variable not found?");
3384   assert(isa<IntegerType>(IndVar->getType()) &&
3385          "Induction variable must be an integer");
3386   assert(cast<PHINode>(IndVar)->getParent() == Header &&
3387          "Induction variable must be a PHI in the loop header");
3388   assert(cast<PHINode>(IndVar)->getIncomingBlock(0) == Preheader);
3389   assert(
3390       cast<ConstantInt>(cast<PHINode>(IndVar)->getIncomingValue(0))->isZero());
3391   assert(cast<PHINode>(IndVar)->getIncomingBlock(1) == Latch);
3392 
3393   auto *NextIndVar = cast<PHINode>(IndVar)->getIncomingValue(1);
3394   assert(cast<Instruction>(NextIndVar)->getParent() == Latch);
3395   assert(cast<BinaryOperator>(NextIndVar)->getOpcode() == BinaryOperator::Add);
3396   assert(cast<BinaryOperator>(NextIndVar)->getOperand(0) == IndVar);
3397   assert(cast<ConstantInt>(cast<BinaryOperator>(NextIndVar)->getOperand(1))
3398              ->isOne());
3399 
3400   Value *TripCount = getTripCount();
3401   assert(TripCount && "Loop trip count not found?");
3402   assert(IndVar->getType() == TripCount->getType() &&
3403          "Trip count and induction variable must have the same type");
3404 
3405   auto *CmpI = cast<CmpInst>(&Cond->front());
3406   assert(CmpI->getPredicate() == CmpInst::ICMP_ULT &&
3407          "Exit condition must be a signed less-than comparison");
3408   assert(CmpI->getOperand(0) == IndVar &&
3409          "Exit condition must compare the induction variable");
3410   assert(CmpI->getOperand(1) == TripCount &&
3411          "Exit condition must compare with the trip count");
3412 #endif
3413 }
3414 
invalidate()3415 void CanonicalLoopInfo::invalidate() {
3416   Preheader = nullptr;
3417   Header = nullptr;
3418   Cond = nullptr;
3419   Body = nullptr;
3420   Latch = nullptr;
3421   Exit = nullptr;
3422   After = nullptr;
3423 }
3424