1 //===- StandardToLLVM.cpp - Standard to LLVM dialect conversion -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert MLIR standard and builtin dialects
10 // into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "../PassDetail.h"
15 #include "mlir/Analysis/DataLayoutAnalysis.h"
16 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
17 #include "mlir/Conversion/LLVMCommon/Pattern.h"
18 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
19 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
20 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
21 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
22 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23 #include "mlir/Dialect/Math/IR/Math.h"
24 #include "mlir/Dialect/StandardOps/IR/Ops.h"
25 #include "mlir/Dialect/Utils/StaticValueUtils.h"
26 #include "mlir/IR/Attributes.h"
27 #include "mlir/IR/BlockAndValueMapping.h"
28 #include "mlir/IR/Builders.h"
29 #include "mlir/IR/BuiltinOps.h"
30 #include "mlir/IR/MLIRContext.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
33 #include "mlir/Support/LogicalResult.h"
34 #include "mlir/Support/MathExtras.h"
35 #include "mlir/Transforms/DialectConversion.h"
36 #include "mlir/Transforms/Passes.h"
37 #include "mlir/Transforms/Utils.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include "llvm/IR/DerivedTypes.h"
40 #include "llvm/IR/IRBuilder.h"
41 #include "llvm/IR/Type.h"
42 #include "llvm/Support/CommandLine.h"
43 #include "llvm/Support/FormatVariadic.h"
44 #include <functional>
45 
46 using namespace mlir;
47 
48 #define PASS_NAME "convert-std-to-llvm"
49 
50 /// Only retain those attributes that are not constructed by
51 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
52 /// attributes.
filterFuncAttributes(ArrayRef<NamedAttribute> attrs,bool filterArgAttrs,SmallVectorImpl<NamedAttribute> & result)53 static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
54                                  bool filterArgAttrs,
55                                  SmallVectorImpl<NamedAttribute> &result) {
56   for (const auto &attr : attrs) {
57     if (attr.first == SymbolTable::getSymbolAttrName() ||
58         attr.first == function_like_impl::getTypeAttrName() ||
59         attr.first == "std.varargs" ||
60         (filterArgAttrs &&
61          attr.first == function_like_impl::getArgDictAttrName()))
62       continue;
63     result.push_back(attr);
64   }
65 }
66 
67 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
68 /// arguments instead of unpacked arguments. This function can be called from C
69 /// by passing a pointer to a C struct corresponding to a memref descriptor.
70 /// Similarly, returned memrefs are passed via pointers to a C struct that is
71 /// passed as additional argument.
72 /// Internally, the auxiliary function unpacks the descriptor into individual
73 /// components and forwards them to `newFuncOp` and forwards the results to
74 /// the extra arguments.
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
                                   LLVMTypeConverter &typeConverter,
                                   FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
  auto type = funcOp.getType();
  // Keep all attributes except those `LLVMFuncOp::build` recreates; argument
  // attributes are kept since the wrapper has its own argument list.
  SmallVector<NamedAttribute, 4> attributes;
  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/false,
                       attributes);
  // `resultIsNowArg` indicates that the result cannot be returned by value
  // and is instead written through a pointer passed as the first argument.
  Type wrapperFuncType;
  bool resultIsNowArg;
  std::tie(wrapperFuncType, resultIsNowArg) =
      typeConverter.convertFunctionTypeCWrapper(type);
  auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
      loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
      wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes);

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());

  // For each original argument: memref(-like) arguments arrive as a pointer
  // to a descriptor struct, which is loaded and unpacked into the individual
  // values the converted callee expects; everything else is forwarded as-is.
  SmallVector<Value, 8> args;
  size_t argOffset = resultIsNowArg ? 1 : 0;
  for (auto &en : llvm::enumerate(type.getInputs())) {
    Value arg = wrapperFuncOp.getArgument(en.index() + argOffset);
    if (auto memrefType = en.value().dyn_cast<MemRefType>()) {
      Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
      MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
      continue;
    }
    if (en.value().isa<UnrankedMemRefType>()) {
      Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
      UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
      continue;
    }

    args.push_back(arg);
  }

  auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);

  if (resultIsNowArg) {
    // Write the callee's result through the result pointer (argument 0) and
    // return nothing from the wrapper.
    rewriter.create<LLVM::StoreOp>(loc, call.getResult(0),
                                   wrapperFuncOp.getArgument(0));
    rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
  } else {
    rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
  }
}
121 
122 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
123 /// arguments instead of unpacked arguments. Creates a body for the (external)
124 /// `newFuncOp` that allocates a memref descriptor on stack, packs the
125 /// individual arguments into this descriptor and passes a pointer to it into
126 /// the auxiliary function. If the result of the function cannot be directly
127 /// returned, we write it to a special first argument that provides a pointer
128 /// to a corresponding struct. This auxiliary external function is now
129 /// compatible with functions defined in C using pointers to C structs
130 /// corresponding to a memref descriptor.
static void wrapExternalFunction(OpBuilder &builder, Location loc,
                                 LLVMTypeConverter &typeConverter,
                                 FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
  OpBuilder::InsertionGuard guard(builder);

  // `resultIsNowArg` is set when the result must be communicated through a
  // pointer passed as the wrapper's leading argument instead of by value.
  Type wrapperType;
  bool resultIsNowArg;
  std::tie(wrapperType, resultIsNowArg) =
      typeConverter.convertFunctionTypeCWrapper(funcOp.getType());
  // This conversion can only fail if it could not convert one of the argument
  // types. But since it has been applied to a non-wrapper function before, it
  // should have failed earlier and not reach this point at all.
  assert(wrapperType && "unexpected type conversion failure");

  SmallVector<NamedAttribute, 4> attributes;
  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/false,
                       attributes);

  // Create the auxiliary function (external declaration, defined in C).
  auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
      loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
      wrapperType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes);

  // Give the previously-external `newFuncOp` a body that merely repacks the
  // arguments and forwards to the wrapper.
  builder.setInsertionPointToStart(newFuncOp.addEntryBlock());

  // Get a ValueRange containing arguments.
  FunctionType type = funcOp.getType();
  SmallVector<Value, 8> args;
  args.reserve(type.getNumInputs());
  ValueRange wrapperArgsRange(newFuncOp.getArguments());

  if (resultIsNowArg) {
    // Allocate the struct on the stack and pass the pointer.
    Type resultType =
        wrapperType.cast<LLVM::LLVMFunctionType>().getParamType(0);
    Value one = builder.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(builder.getIndexType()),
        builder.getIntegerAttr(builder.getIndexType(), 1));
    Value result = builder.create<LLVM::AllocaOp>(loc, resultType, one);
    args.push_back(result);
  }

  // Iterate over the inputs of the original function and pack values into
  // memref descriptors if the original type is a memref.
  for (auto &en : llvm::enumerate(type.getInputs())) {
    Value arg;
    // Number of converted (unpacked) arguments consumed by this original one.
    int numToDrop = 1;
    auto memRefType = en.value().dyn_cast<MemRefType>();
    auto unrankedMemRefType = en.value().dyn_cast<UnrankedMemRefType>();
    if (memRefType || unrankedMemRefType) {
      // Re-pack the scalar pieces into a descriptor value, spill it to a
      // stack slot, and pass a pointer to that slot.
      numToDrop = memRefType
                      ? MemRefDescriptor::getNumUnpackedValues(memRefType)
                      : UnrankedMemRefDescriptor::getNumUnpackedValues();
      Value packed =
          memRefType
              ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
                                       wrapperArgsRange.take_front(numToDrop))
              : UnrankedMemRefDescriptor::pack(
                    builder, loc, typeConverter, unrankedMemRefType,
                    wrapperArgsRange.take_front(numToDrop));

      auto ptrTy = LLVM::LLVMPointerType::get(packed.getType());
      Value one = builder.create<LLVM::ConstantOp>(
          loc, typeConverter.convertType(builder.getIndexType()),
          builder.getIntegerAttr(builder.getIndexType(), 1));
      Value allocated =
          builder.create<LLVM::AllocaOp>(loc, ptrTy, one, /*alignment=*/0);
      builder.create<LLVM::StoreOp>(loc, packed, allocated);
      arg = allocated;
    } else {
      // Non-memref arguments are forwarded unchanged.
      arg = wrapperArgsRange[0];
    }

    args.push_back(arg);
    wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
  }
  assert(wrapperArgsRange.empty() && "did not map some of the arguments");

  auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);

  if (resultIsNowArg) {
    // Read back the result the wrapper wrote through the pointer in
    // `args.front()` and return it by value.
    Value result = builder.create<LLVM::LoadOp>(loc, args.front());
    builder.create<LLVM::ReturnOp>(loc, ValueRange{result});
  } else {
    builder.create<LLVM::ReturnOp>(loc, call.getResults());
  }
}
218 
219 namespace {
220 
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
protected:
  using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;

  // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
  // to this legalization pattern. Returns a null op on failure: unconvertible
  // signature or region types, or a malformed "llvm.linkage" attribute.
  LLVM::LLVMFuncOp
  convertFuncOpToLLVMFuncOp(FuncOp funcOp,
                            ConversionPatternRewriter &rewriter) const {
    // Convert the original function arguments. They are converted using the
    // LLVMTypeConverter provided to this legalization pattern.
    auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("std.varargs");
    TypeConverter::SignatureConversion result(funcOp.getNumArguments());
    auto llvmType = getTypeConverter()->convertFunctionSignature(
        funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
    if (!llvmType)
      return nullptr;

    // Propagate argument attributes to all converted arguments obtained after
    // converting a given original argument.
    SmallVector<NamedAttribute, 4> attributes;
    filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
                         attributes);
    if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
      SmallVector<Attribute, 4> newArgAttrs(
          llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
      for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
        // One original argument may map to several converted ones (e.g. an
        // expanded memref); replicate its attribute dictionary on each.
        auto mapping = result.getInputMapping(i);
        assert(mapping.hasValue() &&
               "unexpected deletion of function argument");
        for (size_t j = 0; j < mapping->size; ++j)
          newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
      }
      attributes.push_back(
          rewriter.getNamedAttr(function_like_impl::getArgDictAttrName(),
                                rewriter.getArrayAttr(newArgAttrs)));
    }
    // Drop "llvm.linkage" from the forwarded attributes; it is consumed below
    // to pick the linkage and must not also appear verbatim on the result.
    for (auto pair : llvm::enumerate(attributes)) {
      if (pair.value().first == "llvm.linkage") {
        attributes.erase(attributes.begin() + pair.index());
        break;
      }
    }

    // Create an LLVM function, use external linkage by default until MLIR
    // functions have linkage.
    LLVM::Linkage linkage = LLVM::Linkage::External;
    if (funcOp->hasAttr("llvm.linkage")) {
      auto attr =
          funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>();
      if (!attr) {
        funcOp->emitError()
            << "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
        return nullptr;
      }
      linkage = attr.getLinkage();
    }
    auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
        funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
        /*dsoLocal*/ false, attributes);
    // Move the body over, then rewrite its block argument types to the
    // converted signature computed above.
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());
    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
                                           &result)))
      return nullptr;

    return newFuncOp;
  }
};
290 
291 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
292 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
293 /// information.
294 static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
295 struct FuncOpConversion : public FuncOpConversionBase {
FuncOpConversion__anonc33f93640111::FuncOpConversion296   FuncOpConversion(LLVMTypeConverter &converter)
297       : FuncOpConversionBase(converter) {}
298 
299   LogicalResult
matchAndRewrite__anonc33f93640111::FuncOpConversion300   matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
301                   ConversionPatternRewriter &rewriter) const override {
302     auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
303     if (!newFuncOp)
304       return failure();
305 
306     if (getTypeConverter()->getOptions().emitCWrappers ||
307         funcOp->getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
308       if (newFuncOp.isExternal())
309         wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
310                              funcOp, newFuncOp);
311       else
312         wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
313                                funcOp, newFuncOp);
314     }
315 
316     rewriter.eraseOp(funcOp);
317     return success();
318   }
319 };
320 
321 /// FuncOp legalization pattern that converts MemRef arguments to bare pointers
322 /// to the MemRef element type. This will impact the calling convention and ABI.
struct BarePtrFuncOpConversion : public FuncOpConversionBase {
  using FuncOpConversionBase::FuncOpConversionBase;

  LogicalResult
  matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // TODO: bare ptr conversion could be handled by argument materialization
    // and most of the code below would go away. But to do this, we would need a
    // way to distinguish between FuncOp and other regions in the
    // addArgumentMaterialization hook.

    // Store the type of memref-typed arguments before the conversion so that we
    // can promote them to MemRef descriptor at the beginning of the function.
    SmallVector<Type, 8> oldArgTypes =
        llvm::to_vector<8>(funcOp.getType().getInputs());

    auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
    if (!newFuncOp)
      return failure();
    // External functions have no body to patch up; conversion is complete.
    if (newFuncOp.getBody().empty()) {
      rewriter.eraseOp(funcOp);
      return success();
    }

    // Promote bare pointers from memref arguments to memref descriptors at the
    // beginning of the function so that all the memrefs in the function have a
    // uniform representation.
    Block *entryBlock = &newFuncOp.getBody().front();
    auto blockArgs = entryBlock->getArguments();
    assert(blockArgs.size() == oldArgTypes.size() &&
           "The number of arguments and types doesn't match");

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(entryBlock);
    // Walk converted block arguments together with the pre-conversion types
    // that tell us which ones were memrefs.
    for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
      BlockArgument arg = std::get<0>(it);
      Type argTy = std::get<1>(it);

      // Unranked memrefs are not supported in the bare pointer calling
      // convention. We should have bailed out before in the presence of
      // unranked memrefs.
      assert(!argTy.isa<UnrankedMemRefType>() &&
             "Unranked memref is not supported");
      auto memrefTy = argTy.dyn_cast<MemRefType>();
      if (!memrefTy)
        continue;

      // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
      // or unranked memref descriptor and replace placeholder with the last
      // instruction of the memref descriptor.
      // TODO: The placeholder is needed to avoid replacing barePtr uses in the
      // MemRef descriptor instructions. We may want to have a utility in the
      // rewriter to properly handle this use case.
      Location loc = funcOp.getLoc();
      auto placeholder = rewriter.create<LLVM::UndefOp>(
          loc, getTypeConverter()->convertType(memrefTy));
      rewriter.replaceUsesOfBlockArgument(arg, placeholder);

      // Build a full descriptor from the bare pointer (requires a static
      // shape), then swap it in for the placeholder.
      Value desc = MemRefDescriptor::fromStaticShape(
          rewriter, loc, *getTypeConverter(), memrefTy, arg);
      rewriter.replaceOp(placeholder, {desc});
    }

    rewriter.eraseOp(funcOp);
    return success();
  }
};
391 
// Straightforward lowerings. Each alias below maps one standard op 1-1 onto
// the LLVM dialect op it lowers to, with identical operands;
// VectorConvertToLLVMPattern (see LLVMCommon/VectorPattern.h) additionally
// adapts each pattern to operands of vector type.
using AbsFOpLowering = VectorConvertToLLVMPattern<AbsFOp, LLVM::FAbsOp>;
using AddFOpLowering = VectorConvertToLLVMPattern<AddFOp, LLVM::FAddOp>;
using AddIOpLowering = VectorConvertToLLVMPattern<AddIOp, LLVM::AddOp>;
using AndOpLowering = VectorConvertToLLVMPattern<AndOp, LLVM::AndOp>;
using BitcastOpLowering =
    VectorConvertToLLVMPattern<BitcastOp, LLVM::BitcastOp>;
using CeilFOpLowering = VectorConvertToLLVMPattern<CeilFOp, LLVM::FCeilOp>;
using CopySignOpLowering =
    VectorConvertToLLVMPattern<CopySignOp, LLVM::CopySignOp>;
using DivFOpLowering = VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp>;
using FPExtOpLowering = VectorConvertToLLVMPattern<FPExtOp, LLVM::FPExtOp>;
using FPToSIOpLowering = VectorConvertToLLVMPattern<FPToSIOp, LLVM::FPToSIOp>;
using FPToUIOpLowering = VectorConvertToLLVMPattern<FPToUIOp, LLVM::FPToUIOp>;
using FPTruncOpLowering =
    VectorConvertToLLVMPattern<FPTruncOp, LLVM::FPTruncOp>;
using FloorFOpLowering = VectorConvertToLLVMPattern<FloorFOp, LLVM::FFloorOp>;
using FmaFOpLowering = VectorConvertToLLVMPattern<FmaFOp, LLVM::FMAOp>;
using MulFOpLowering = VectorConvertToLLVMPattern<MulFOp, LLVM::FMulOp>;
using MulIOpLowering = VectorConvertToLLVMPattern<MulIOp, LLVM::MulOp>;
using NegFOpLowering = VectorConvertToLLVMPattern<NegFOp, LLVM::FNegOp>;
using OrOpLowering = VectorConvertToLLVMPattern<OrOp, LLVM::OrOp>;
using RemFOpLowering = VectorConvertToLLVMPattern<RemFOp, LLVM::FRemOp>;
using SIToFPOpLowering = VectorConvertToLLVMPattern<SIToFPOp, LLVM::SIToFPOp>;
using SelectOpLowering = VectorConvertToLLVMPattern<SelectOp, LLVM::SelectOp>;
using SignExtendIOpLowering =
    VectorConvertToLLVMPattern<SignExtendIOp, LLVM::SExtOp>;
using ShiftLeftOpLowering =
    VectorConvertToLLVMPattern<ShiftLeftOp, LLVM::ShlOp>;
using SignedDivIOpLowering =
    VectorConvertToLLVMPattern<SignedDivIOp, LLVM::SDivOp>;
using SignedRemIOpLowering =
    VectorConvertToLLVMPattern<SignedRemIOp, LLVM::SRemOp>;
using SignedShiftRightOpLowering =
    VectorConvertToLLVMPattern<SignedShiftRightOp, LLVM::AShrOp>;
using SubFOpLowering = VectorConvertToLLVMPattern<SubFOp, LLVM::FSubOp>;
using SubIOpLowering = VectorConvertToLLVMPattern<SubIOp, LLVM::SubOp>;
using TruncateIOpLowering =
    VectorConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp>;
using UIToFPOpLowering = VectorConvertToLLVMPattern<UIToFPOp, LLVM::UIToFPOp>;
using UnsignedDivIOpLowering =
    VectorConvertToLLVMPattern<UnsignedDivIOp, LLVM::UDivOp>;
using UnsignedRemIOpLowering =
    VectorConvertToLLVMPattern<UnsignedRemIOp, LLVM::URemOp>;
using UnsignedShiftRightOpLowering =
    VectorConvertToLLVMPattern<UnsignedShiftRightOp, LLVM::LShrOp>;
using XOrOpLowering = VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp>;
using ZeroExtendIOpLowering =
    VectorConvertToLLVMPattern<ZeroExtendIOp, LLVM::ZExtOp>;
441 
442 /// Lower `std.assert`. The default lowering calls the `abort` function if the
443 /// assertion is violated and has no effect otherwise. The failure message is
444 /// ignored by the default lowering but should be propagated by any custom
445 /// lowering.
struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
  using ConvertOpToLLVMPattern<AssertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(AssertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    // Insert the `abort` declaration if necessary. The declaration is created
    // at the module start with an unknown location, since it does not
    // correspond to any single assert.
    auto module = op->getParentOfType<ModuleOp>();
    auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
    if (!abortFunc) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(module.getBody());
      auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
      abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
                                                    "abort", abortFuncTy);
    }

    // Split block at `assert` operation so everything after the assert becomes
    // the success continuation.
    Block *opBlock = rewriter.getInsertionBlock();
    auto opPosition = rewriter.getInsertionPoint();
    Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);

    // Generate IR to call `abort`. Note: the assert's failure message is
    // intentionally dropped by this default lowering.
    Block *failureBlock = rewriter.createBlock(opBlock->getParent());
    rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
    rewriter.create<LLVM::UnreachableOp>(loc);

    // Generate assertion test: branch to the continuation when the condition
    // holds, to the abort block otherwise.
    rewriter.setInsertionPointToEnd(opBlock);
    rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
        op, adaptor.arg(), continuationBlock, failureBlock);

    return success();
  }
};
483 
484 struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
485   using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;
486 
487   LogicalResult
matchAndRewrite__anonc33f93640111::ConstantOpLowering488   matchAndRewrite(ConstantOp op, OpAdaptor adaptor,
489                   ConversionPatternRewriter &rewriter) const override {
490     // If constant refers to a function, convert it to "addressof".
491     if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
492       auto type = typeConverter->convertType(op.getResult().getType());
493       if (!type || !LLVM::isCompatibleType(type))
494         return rewriter.notifyMatchFailure(op, "failed to convert result type");
495 
496       auto newOp = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type,
497                                                       symbolRef.getValue());
498       for (const NamedAttribute &attr : op->getAttrs()) {
499         if (attr.first.strref() == "value")
500           continue;
501         newOp->setAttr(attr.first, attr.second);
502       }
503       rewriter.replaceOp(op, newOp->getResults());
504       return success();
505     }
506 
507     // Calling into other scopes (non-flat reference) is not supported in LLVM.
508     if (op.getValue().isa<SymbolRefAttr>())
509       return rewriter.notifyMatchFailure(
510           op, "referring to a symbol outside of the current module");
511 
512     return LLVM::detail::oneToOneRewrite(
513         op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
514         *getTypeConverter(), rewriter);
515   }
516 };
517 
518 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
519 // passes the pointer to the MemRef across function boundaries.
template <typename CallOpType>
struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
  using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
  using Super = CallOpInterfaceLowering<CallOpType>;
  using Base = ConvertOpToLLVMPattern<CallOpType>;

  LogicalResult
  matchAndRewrite(CallOpType callOp, typename CallOpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Pack the result types into a struct: LLVM functions have at most one
    // result, so multiple results are returned as one structure value.
    Type packedResult = nullptr;
    unsigned numResults = callOp.getNumResults();
    auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());

    if (numResults != 0) {
      if (!(packedResult =
                this->getTypeConverter()->packFunctionResults(resultTypes)))
        return failure();
    }

    // Promote memref operands across the call boundary (alloca + store,
    // passing pointers) per the converter's calling convention.
    auto promoted = this->getTypeConverter()->promoteOperands(
        callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
        adaptor.getOperands(), rewriter);
    auto newOp = rewriter.create<LLVM::CallOp>(
        callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
        promoted, callOp->getAttrs());

    SmallVector<Value, 4> results;
    if (numResults < 2) {
      // If < 2 results, packing did not do anything and we can just return.
      results.append(newOp.result_begin(), newOp.result_end());
    } else {
      // Otherwise, it had been converted to an operation producing a structure.
      // Extract individual results from the structure and return them as list.
      results.reserve(numResults);
      for (unsigned i = 0; i < numResults; ++i) {
        auto type =
            this->typeConverter->convertType(callOp.getResult(i).getType());
        results.push_back(rewriter.create<LLVM::ExtractValueOp>(
            callOp.getLoc(), type, newOp->getResult(0),
            rewriter.getI64ArrayAttr(i)));
      }
    }

    if (this->getTypeConverter()->getOptions().useBarePtrCallConv) {
      // For the bare-ptr calling convention, promote memref results to
      // descriptors.
      assert(results.size() == resultTypes.size() &&
             "The number of arguments and types doesn't match");
      this->getTypeConverter()->promoteBarePtrsToDescriptors(
          rewriter, callOp.getLoc(), resultTypes, results);
    } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
                                                    resultTypes, results,
                                                    /*toDynamic=*/false))) {
      return failure();
    }

    rewriter.replaceOp(callOp, results);
    return success();
  }
};
581 
/// Lowering for direct calls (std.call); all logic lives in the shared
/// CallOpInterfaceLowering base.
struct CallOpLowering : public CallOpInterfaceLowering<CallOp> {
  using Super::Super;
};
585 
/// Lowering for indirect calls (std.call_indirect); all logic lives in the
/// shared CallOpInterfaceLowering base.
struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
  using Super::Super;
};
589 
590 struct UnrealizedConversionCastOpLowering
591     : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
592   using ConvertOpToLLVMPattern<
593       UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
594 
595   LogicalResult
matchAndRewrite__anonc33f93640111::UnrealizedConversionCastOpLowering596   matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
597                   ConversionPatternRewriter &rewriter) const override {
598     SmallVector<Type> convertedTypes;
599     if (succeeded(typeConverter->convertTypes(op.outputs().getTypes(),
600                                               convertedTypes)) &&
601         convertedTypes == adaptor.inputs().getTypes()) {
602       rewriter.replaceOp(op, adaptor.inputs());
603       return success();
604     }
605 
606     convertedTypes.clear();
607     if (succeeded(typeConverter->convertTypes(adaptor.inputs().getTypes(),
608                                               convertedTypes)) &&
609         convertedTypes == op.outputs().getType()) {
610       rewriter.replaceOp(op, adaptor.inputs());
611       return success();
612     }
613     return failure();
614   }
615 };
616 
617 struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
618   using ConvertOpToLLVMPattern<RankOp>::ConvertOpToLLVMPattern;
619 
620   LogicalResult
matchAndRewrite__anonc33f93640111::RankOpLowering621   matchAndRewrite(RankOp op, OpAdaptor adaptor,
622                   ConversionPatternRewriter &rewriter) const override {
623     Location loc = op.getLoc();
624     Type operandType = op.memrefOrTensor().getType();
625     if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
626       UnrankedMemRefDescriptor desc(adaptor.memrefOrTensor());
627       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
628       return success();
629     }
630     if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
631       rewriter.replaceOp(
632           op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
633       return success();
634     }
635     return failure();
636   }
637 };
638 
639 // Common base for load and store operations on MemRefs.  Restricts the match
640 // to supported MemRef types. Provides functionality to emit code accessing a
641 // specific element of the underlying data buffer.
642 template <typename Derived>
643 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
644   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
645   using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
646   using Base = LoadStoreOpLowering<Derived>;
647 
match__anonc33f93640111::LoadStoreOpLowering648   LogicalResult match(Derived op) const override {
649     MemRefType type = op.getMemRefType();
650     return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
651   }
652 };
653 
654 // The lowering of index_cast becomes an integer conversion since index becomes
655 // an integer.  If the bit width of the source and target integer types is the
656 // same, just erase the cast.  If the target type is wider, sign-extend the
657 // value, otherwise truncate it.
658 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
659   using ConvertOpToLLVMPattern<IndexCastOp>::ConvertOpToLLVMPattern;
660 
661   LogicalResult
matchAndRewrite__anonc33f93640111::IndexCastOpLowering662   matchAndRewrite(IndexCastOp indexCastOp, OpAdaptor adaptor,
663                   ConversionPatternRewriter &rewriter) const override {
664     auto targetType =
665         typeConverter->convertType(indexCastOp.getResult().getType());
666     auto targetElementType =
667         typeConverter
668             ->convertType(getElementTypeOrSelf(indexCastOp.getResult()))
669             .cast<IntegerType>();
670     auto sourceElementType =
671         getElementTypeOrSelf(adaptor.in()).cast<IntegerType>();
672     unsigned targetBits = targetElementType.getWidth();
673     unsigned sourceBits = sourceElementType.getWidth();
674 
675     if (targetBits == sourceBits)
676       rewriter.replaceOp(indexCastOp, adaptor.in());
677     else if (targetBits < sourceBits)
678       rewriter.replaceOpWithNewOp<LLVM::TruncOp>(indexCastOp, targetType,
679                                                  adaptor.in());
680     else
681       rewriter.replaceOpWithNewOp<LLVM::SExtOp>(indexCastOp, targetType,
682                                                 adaptor.in());
683     return success();
684   }
685 };
686 
// Convert std.cmp predicate into the LLVM dialect CmpPredicate.  The two
// enums share the numerical values so just cast.
// NOTE(review): correctness relies on the enumerators of StdPredType and
// LLVMPredType having identical numeric values — no remapping is performed.
template <typename LLVMPredType, typename StdPredType>
static LLVMPredType convertCmpPredicate(StdPredType pred) {
  return static_cast<LLVMPredType>(pred);
}
693 
694 struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
695   using ConvertOpToLLVMPattern<CmpIOp>::ConvertOpToLLVMPattern;
696 
697   LogicalResult
matchAndRewrite__anonc33f93640111::CmpIOpLowering698   matchAndRewrite(CmpIOp cmpiOp, OpAdaptor adaptor,
699                   ConversionPatternRewriter &rewriter) const override {
700     auto operandType = adaptor.lhs().getType();
701     auto resultType = cmpiOp.getResult().getType();
702 
703     // Handle the scalar and 1D vector cases.
704     if (!operandType.isa<LLVM::LLVMArrayType>()) {
705       rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
706           cmpiOp, typeConverter->convertType(resultType),
707           convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
708           adaptor.lhs(), adaptor.rhs());
709       return success();
710     }
711 
712     auto vectorType = resultType.dyn_cast<VectorType>();
713     if (!vectorType)
714       return rewriter.notifyMatchFailure(cmpiOp, "expected vector result type");
715 
716     return LLVM::detail::handleMultidimensionalVectors(
717         cmpiOp.getOperation(), adaptor.getOperands(), *getTypeConverter(),
718         [&](Type llvm1DVectorTy, ValueRange operands) {
719           CmpIOpAdaptor adaptor(operands);
720           return rewriter.create<LLVM::ICmpOp>(
721               cmpiOp.getLoc(), llvm1DVectorTy,
722               convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
723               adaptor.lhs(), adaptor.rhs());
724         },
725         rewriter);
726 
727     return success();
728   }
729 };
730 
731 struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
732   using ConvertOpToLLVMPattern<CmpFOp>::ConvertOpToLLVMPattern;
733 
734   LogicalResult
matchAndRewrite__anonc33f93640111::CmpFOpLowering735   matchAndRewrite(CmpFOp cmpfOp, OpAdaptor adaptor,
736                   ConversionPatternRewriter &rewriter) const override {
737     auto operandType = adaptor.lhs().getType();
738     auto resultType = cmpfOp.getResult().getType();
739 
740     // Handle the scalar and 1D vector cases.
741     if (!operandType.isa<LLVM::LLVMArrayType>()) {
742       rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
743           cmpfOp, typeConverter->convertType(resultType),
744           convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
745           adaptor.lhs(), adaptor.rhs());
746       return success();
747     }
748 
749     auto vectorType = resultType.dyn_cast<VectorType>();
750     if (!vectorType)
751       return rewriter.notifyMatchFailure(cmpfOp, "expected vector result type");
752 
753     return LLVM::detail::handleMultidimensionalVectors(
754         cmpfOp.getOperation(), adaptor.getOperands(), *getTypeConverter(),
755         [&](Type llvm1DVectorTy, ValueRange operands) {
756           CmpFOpAdaptor adaptor(operands);
757           return rewriter.create<LLVM::FCmpOp>(
758               cmpfOp.getLoc(), llvm1DVectorTy,
759               convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
760               adaptor.lhs(), adaptor.rhs());
761         },
762         rewriter);
763   }
764 };
765 
// Base class for LLVM IR lowering terminator operations with successors.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
    : public ConvertOpToLLVMPattern<SourceOp> {
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
  using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;

  // Replaces the source terminator with the target op, forwarding the
  // already-converted operands and the original successors and attributes
  // unchanged.
  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
                                          op->getSuccessors(), op->getAttrs());
    return success();
  }
};
781 
// Special lowering pattern for `ReturnOps`.  Unlike all other operations,
// `ReturnOp` interacts with the function signature and must have as many
// operands as the function has return values.  Because in LLVM IR, functions
// can only return 0 or 1 value, we pack multiple values into a structure type.
// Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
// necessary before returning it.
struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
  using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    unsigned numArguments = op.getNumOperands();
    SmallVector<Value, 4> updatedOperands;

    if (getTypeConverter()->getOptions().useBarePtrCallConv) {
      // For the bare-ptr calling convention, extract the aligned pointer to
      // be returned from the memref descriptor.
      for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
        Type oldTy = std::get<0>(it).getType();
        Value newOperand = std::get<1>(it);
        if (oldTy.isa<MemRefType>()) {
          MemRefDescriptor memrefDesc(newOperand);
          newOperand = memrefDesc.alignedPtr(rewriter, loc);
        } else if (oldTy.isa<UnrankedMemRefType>()) {
          // Unranked memref is not supported in the bare pointer calling
          // convention.
          return failure();
        }
        updatedOperands.push_back(newOperand);
      }
    } else {
      // Default calling convention: copy unranked descriptors to dynamically
      // allocated memory so they outlive this function's stack frame.  The
      // LogicalResult is intentionally discarded (the (void) cast).
      updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
      (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
                                    updatedOperands,
                                    /*toDynamic=*/true);
    }

    // If ReturnOp has 0 or 1 operand, create it and return immediately.
    if (numArguments == 0) {
      rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
                                                  op->getAttrs());
      return success();
    }
    if (numArguments == 1) {
      rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
          op, TypeRange(), updatedOperands, op->getAttrs());
      return success();
    }

    // Otherwise, we need to pack the arguments into an LLVM struct type before
    // returning.
    auto packedType = getTypeConverter()->packFunctionResults(
        llvm::to_vector<4>(op.getOperandTypes()));

    // Build the struct incrementally: undef, then one insertvalue per result.
    Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
    for (unsigned i = 0; i < numArguments; ++i) {
      packed = rewriter.create<LLVM::InsertValueOp>(
          loc, packedType, packed, updatedOperands[i],
          rewriter.getI64ArrayAttr(i));
    }
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
                                                op->getAttrs());
    return success();
  }
};
849 
// FIXME: this should be tablegen'ed as well.
// Lowers std.br to llvm.br (operands, successors, attributes forwarded).
struct BranchOpLowering
    : public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> {
  using Super::Super;
};
// Lowers std.cond_br to llvm.cond_br (operands, successors, attrs forwarded).
struct CondBranchOpLowering
    : public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> {
  using Super::Super;
};
// Lowers std.switch to llvm.switch (operands, successors, attrs forwarded).
struct SwitchOpLowering
    : public OneToOneLLVMTerminatorLowering<SwitchOp, LLVM::SwitchOp> {
  using Super::Super;
};
863 
864 // The Splat operation is lowered to an insertelement + a shufflevector
865 // operation. Splat to only 1-d vector result types are lowered.
866 struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
867   using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
868 
869   LogicalResult
matchAndRewrite__anonc33f93640111::SplatOpLowering870   matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
871                   ConversionPatternRewriter &rewriter) const override {
872     VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
873     if (!resultType || resultType.getRank() != 1)
874       return failure();
875 
876     // First insert it into an undef vector so we can shuffle it.
877     auto vectorType = typeConverter->convertType(splatOp.getType());
878     Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
879     auto zero = rewriter.create<LLVM::ConstantOp>(
880         splatOp.getLoc(),
881         typeConverter->convertType(rewriter.getIntegerType(32)),
882         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
883 
884     auto v = rewriter.create<LLVM::InsertElementOp>(
885         splatOp.getLoc(), vectorType, undef, adaptor.input(), zero);
886 
887     int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
888     SmallVector<int32_t, 4> zeroValues(width, 0);
889 
890     // Shuffle the value across the desired number of elements.
891     ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
892     rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
893                                                        zeroAttrs);
894     return success();
895   }
896 };
897 
// The Splat operation is lowered to an insertelement + a shufflevector
// operation. Splat to only 2+-d vector result types are lowered by the
// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Rank-1 splats are handled by SplatOpLowering; bail out here.
    VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
    if (!resultType || resultType.getRank() == 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value (an LLVM array-of-vectors "descriptor").
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                     adaptor.input(), zero);

    // Shuffle the value across the desired number of elements (the innermost
    // dimension of the result vector).
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector<int32_t, 4> zeroValues(width, 0);
    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);

    // Iterate of linear index, convert to coords space and insert splatted 1-D
    // vector in each position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
                                                  position);
    });
    rewriter.replaceOp(splatOp, desc);
    return success();
  }
};
948 
949 } // namespace
950 
951 /// Try to match the kind of a std.atomic_rmw to determine whether to use a
952 /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
matchSimpleAtomicOp(AtomicRMWOp atomicOp)953 static Optional<LLVM::AtomicBinOp> matchSimpleAtomicOp(AtomicRMWOp atomicOp) {
954   switch (atomicOp.kind()) {
955   case AtomicRMWKind::addf:
956     return LLVM::AtomicBinOp::fadd;
957   case AtomicRMWKind::addi:
958     return LLVM::AtomicBinOp::add;
959   case AtomicRMWKind::assign:
960     return LLVM::AtomicBinOp::xchg;
961   case AtomicRMWKind::maxs:
962     return LLVM::AtomicBinOp::max;
963   case AtomicRMWKind::maxu:
964     return LLVM::AtomicBinOp::umax;
965   case AtomicRMWKind::mins:
966     return LLVM::AtomicBinOp::min;
967   case AtomicRMWKind::minu:
968     return LLVM::AtomicBinOp::umin;
969   default:
970     return llvm::None;
971   }
972   llvm_unreachable("Invalid AtomicRMWKind");
973 }
974 
975 namespace {
976 
977 struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
978   using Base::Base;
979 
980   LogicalResult
matchAndRewrite__anonc33f93640511::AtomicRMWOpLowering981   matchAndRewrite(AtomicRMWOp atomicOp, OpAdaptor adaptor,
982                   ConversionPatternRewriter &rewriter) const override {
983     if (failed(match(atomicOp)))
984       return failure();
985     auto maybeKind = matchSimpleAtomicOp(atomicOp);
986     if (!maybeKind)
987       return failure();
988     auto resultType = adaptor.value().getType();
989     auto memRefType = atomicOp.getMemRefType();
990     auto dataPtr =
991         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
992                              adaptor.indices(), rewriter);
993     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
994         atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
995         LLVM::AtomicOrdering::acq_rel);
996     return success();
997   }
998 };
999 
/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   br loop(%loaded)              |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |-----------        |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock =
        rewriter.createBlock(initBlock->getParent(),
                             std::next(Region::iterator(initBlock)), valueType);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`: everything that followed
    // the atomic op in its original block must execute after the loop.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
                                        adaptor.indices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.  The block
    // argument stands for the value loaded on the current iteration.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    // cmpxchg yields a {value, i1} pair: the loaded value and a success flag.
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    // Rebuild the trailing ops in the end block, remapping the atomic result.
    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};
1118 
1119 } // namespace
1120 
/// Populates the single pattern that converts `std.func`.  The conversion
/// chosen depends on the calling convention selected in the converter
/// options: bare-pointer or the default memref-descriptor one.
void mlir::populateStdToLLVMFuncOpConversionPattern(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  if (converter.getOptions().useBarePtrCallConv)
    patterns.add<BarePtrFuncOpConversion>(converter);
  else
    patterns.add<FuncOpConversion>(converter);
}
1128 
/// Registers all Standard-to-LLVM lowering patterns, including the FuncOp
/// conversion selected by the converter options.
void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                               RewritePatternSet &patterns) {
  populateStdToLLVMFuncOpConversionPattern(converter, patterns);
  // clang-format off
  patterns.add<
      AbsFOpLowering,
      AddFOpLowering,
      AddIOpLowering,
      AndOpLowering,
      AssertOpLowering,
      AtomicRMWOpLowering,
      BitcastOpLowering,
      BranchOpLowering,
      CallIndirectOpLowering,
      CallOpLowering,
      CeilFOpLowering,
      CmpFOpLowering,
      CmpIOpLowering,
      CondBranchOpLowering,
      CopySignOpLowering,
      ConstantOpLowering,
      DivFOpLowering,
      FloorFOpLowering,
      FmaFOpLowering,
      GenericAtomicRMWOpLowering,
      FPExtOpLowering,
      FPToSIOpLowering,
      FPToUIOpLowering,
      FPTruncOpLowering,
      IndexCastOpLowering,
      MulFOpLowering,
      MulIOpLowering,
      NegFOpLowering,
      OrOpLowering,
      RemFOpLowering,
      RankOpLowering,
      ReturnOpLowering,
      SIToFPOpLowering,
      SelectOpLowering,
      ShiftLeftOpLowering,
      SignExtendIOpLowering,
      SignedDivIOpLowering,
      SignedRemIOpLowering,
      SignedShiftRightOpLowering,
      SplatOpLowering,
      SplatNdOpLowering,
      SubFOpLowering,
      SubIOpLowering,
      SwitchOpLowering,
      TruncateIOpLowering,
      UIToFPOpLowering,
      UnsignedDivIOpLowering,
      UnsignedRemIOpLowering,
      UnsignedShiftRightOpLowering,
      XOrOpLowering,
      ZeroExtendIOpLowering>(converter);
  // clang-format on
}
1187 
1188 namespace {
/// A pass converting MLIR operations into the LLVM IR dialect.
struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
  LLVMLoweringPass() = default;
  /// Constructor overriding the default (tablegen'ed) pass options.
  /// NOTE(review): `useAlignedAlloc` is accepted but not stored by this
  /// constructor — confirm whether it should feed into a pass option.
  LLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers,
                   unsigned indexBitwidth, bool useAlignedAlloc,
                   const llvm::DataLayout &dataLayout) {
    this->useBarePtrCallConv = useBarePtrCallConv;
    this->emitCWrappers = emitCWrappers;
    this->indexBitwidth = indexBitwidth;
    this->dataLayout = dataLayout.getStringRepresentation();
  }

  /// Run the dialect converter on the module.
  void runOnOperation() override {
    // The bare-pointer convention and C wrapper emission are incompatible.
    if (useBarePtrCallConv && emitCWrappers) {
      getOperation().emitError()
          << "incompatible conversion options: bare-pointer calling convention "
             "and C wrapper emission";
      signalPassFailure();
      return;
    }
    // Reject malformed data layout strings before building the converter.
    if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
            this->dataLayout, [this](const Twine &message) {
              getOperation().emitError() << message.str();
            }))) {
      signalPassFailure();
      return;
    }

    ModuleOp m = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();

    // Build the lowering options from the pass options and the data layout
    // in effect at (or above) the module.
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(m));
    options.useBarePtrCallConv = useBarePtrCallConv;
    options.emitCWrappers = emitCWrappers;
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.dataLayout = llvm::DataLayout(this->dataLayout);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);

    RewritePatternSet patterns(&getContext());
    populateStdToLLVMConversionPatterns(typeConverter, patterns);

    LLVMConversionTarget target(getContext());
    if (failed(applyPartialConversion(m, target, std::move(patterns))))
      signalPassFailure();

    // Attach the data layout string to the module; note this runs even when
    // the conversion above signaled failure.
    m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
               StringAttr::get(m.getContext(), this->dataLayout));
  }
};
1243 } // end namespace
1244 
createLowerToLLVMPass()1245 std::unique_ptr<OperationPass<ModuleOp>> mlir::createLowerToLLVMPass() {
1246   return std::make_unique<LLVMLoweringPass>();
1247 }
1248 
1249 std::unique_ptr<OperationPass<ModuleOp>>
createLowerToLLVMPass(const LowerToLLVMOptions & options)1250 mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) {
1251   auto allocLowering = options.allocLowering;
1252   // There is no way to provide additional patterns for pass, so
1253   // AllocLowering::None will always fail.
1254   assert(allocLowering != LowerToLLVMOptions::AllocLowering::None &&
1255          "LLVMLoweringPass doesn't support AllocLowering::None");
1256   bool useAlignedAlloc =
1257       (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc);
1258   return std::make_unique<LLVMLoweringPass>(
1259       options.useBarePtrCallConv, options.emitCWrappers,
1260       options.getIndexBitwidth(), useAlignedAlloc, options.dataLayout);
1261 }
1262