1 //===- StandardToLLVM.cpp - Standard to LLVM dialect conversion -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert MLIR standard and builtin dialects
10 // into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "../PassDetail.h"
15 #include "mlir/Analysis/DataLayoutAnalysis.h"
16 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
17 #include "mlir/Conversion/LLVMCommon/Pattern.h"
18 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
19 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
20 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
21 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
22 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23 #include "mlir/Dialect/Math/IR/Math.h"
24 #include "mlir/Dialect/StandardOps/IR/Ops.h"
25 #include "mlir/Dialect/Utils/StaticValueUtils.h"
26 #include "mlir/IR/Attributes.h"
27 #include "mlir/IR/BlockAndValueMapping.h"
28 #include "mlir/IR/Builders.h"
29 #include "mlir/IR/BuiltinOps.h"
30 #include "mlir/IR/MLIRContext.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
33 #include "mlir/Support/LogicalResult.h"
34 #include "mlir/Support/MathExtras.h"
35 #include "mlir/Transforms/DialectConversion.h"
36 #include "mlir/Transforms/Passes.h"
37 #include "mlir/Transforms/Utils.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include "llvm/IR/DerivedTypes.h"
40 #include "llvm/IR/IRBuilder.h"
41 #include "llvm/IR/Type.h"
42 #include "llvm/Support/CommandLine.h"
43 #include "llvm/Support/FormatVariadic.h"
44 #include <functional>
45
46 using namespace mlir;
47
48 #define PASS_NAME "convert-std-to-llvm"
49
50 /// Only retain those attributes that are not constructed by
51 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
52 /// attributes.
filterFuncAttributes(ArrayRef<NamedAttribute> attrs,bool filterArgAttrs,SmallVectorImpl<NamedAttribute> & result)53 static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
54 bool filterArgAttrs,
55 SmallVectorImpl<NamedAttribute> &result) {
56 for (const auto &attr : attrs) {
57 if (attr.first == SymbolTable::getSymbolAttrName() ||
58 attr.first == function_like_impl::getTypeAttrName() ||
59 attr.first == "std.varargs" ||
60 (filterArgAttrs &&
61 attr.first == function_like_impl::getArgDictAttrName()))
62 continue;
63 result.push_back(attr);
64 }
65 }
66
67 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
68 /// arguments instead of unpacked arguments. This function can be called from C
69 /// by passing a pointer to a C struct corresponding to a memref descriptor.
70 /// Similarly, returned memrefs are passed via pointers to a C struct that is
71 /// passed as additional argument.
72 /// Internally, the auxiliary function unpacks the descriptor into individual
73 /// components and forwards them to `newFuncOp` and forwards the results to
74 /// the extra arguments.
wrapForExternalCallers(OpBuilder & rewriter,Location loc,LLVMTypeConverter & typeConverter,FuncOp funcOp,LLVM::LLVMFuncOp newFuncOp)75 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
76 LLVMTypeConverter &typeConverter,
77 FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
78 auto type = funcOp.getType();
79 SmallVector<NamedAttribute, 4> attributes;
80 filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/false,
81 attributes);
82 Type wrapperFuncType;
83 bool resultIsNowArg;
84 std::tie(wrapperFuncType, resultIsNowArg) =
85 typeConverter.convertFunctionTypeCWrapper(type);
86 auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
87 loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
88 wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes);
89
90 OpBuilder::InsertionGuard guard(rewriter);
91 rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
92
93 SmallVector<Value, 8> args;
94 size_t argOffset = resultIsNowArg ? 1 : 0;
95 for (auto &en : llvm::enumerate(type.getInputs())) {
96 Value arg = wrapperFuncOp.getArgument(en.index() + argOffset);
97 if (auto memrefType = en.value().dyn_cast<MemRefType>()) {
98 Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
99 MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
100 continue;
101 }
102 if (en.value().isa<UnrankedMemRefType>()) {
103 Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
104 UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
105 continue;
106 }
107
108 args.push_back(arg);
109 }
110
111 auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
112
113 if (resultIsNowArg) {
114 rewriter.create<LLVM::StoreOp>(loc, call.getResult(0),
115 wrapperFuncOp.getArgument(0));
116 rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
117 } else {
118 rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
119 }
120 }
121
/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
/// arguments instead of unpacked arguments. Creates a body for the (external)
/// `newFuncOp` that allocates a memref descriptor on stack, packs the
/// individual arguments into this descriptor and passes a pointer to it into
/// the auxiliary function. If the result of the function cannot be directly
/// returned, we write it to a special first argument that provides a pointer
/// to a corresponding struct. This auxiliary external function is now
/// compatible with functions defined in C using pointers to C structs
/// corresponding to a memref descriptor.
static void wrapExternalFunction(OpBuilder &builder, Location loc,
                                 LLVMTypeConverter &typeConverter,
                                 FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
  OpBuilder::InsertionGuard guard(builder);

  Type wrapperType;
  bool resultIsNowArg;
  std::tie(wrapperType, resultIsNowArg) =
      typeConverter.convertFunctionTypeCWrapper(funcOp.getType());
  // This conversion can only fail if it could not convert one of the argument
  // types. But since it has been applied to a non-wrapper function before, it
  // should have failed earlier and not reach this point at all.
  assert(wrapperType && "unexpected type conversion failure");

  SmallVector<NamedAttribute, 4> attributes;
  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/false,
                       attributes);

  // Create the auxiliary function. Note that no entry block is added: this is
  // a declaration whose definition is expected to come from C code.
  auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
      loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
      wrapperType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes);

  // Build the body of `newFuncOp`: pack arguments and call the wrapper.
  builder.setInsertionPointToStart(newFuncOp.addEntryBlock());

  // Get a ValueRange containing arguments.
  FunctionType type = funcOp.getType();
  SmallVector<Value, 8> args;
  args.reserve(type.getNumInputs());
  ValueRange wrapperArgsRange(newFuncOp.getArguments());

  if (resultIsNowArg) {
    // Allocate the struct on the stack and pass the pointer. `getParamType(0)`
    // is already the pointer-to-result type produced by the wrapper
    // conversion.
    Type resultType =
        wrapperType.cast<LLVM::LLVMFunctionType>().getParamType(0);
    Value one = builder.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(builder.getIndexType()),
        builder.getIntegerAttr(builder.getIndexType(), 1));
    Value result = builder.create<LLVM::AllocaOp>(loc, resultType, one);
    args.push_back(result);
  }

  // Iterate over the inputs of the original function and pack values into
  // memref descriptors if the original type is a memref.
  for (auto &en : llvm::enumerate(type.getInputs())) {
    Value arg;
    int numToDrop = 1;
    auto memRefType = en.value().dyn_cast<MemRefType>();
    auto unrankedMemRefType = en.value().dyn_cast<UnrankedMemRefType>();
    if (memRefType || unrankedMemRefType) {
      // One original memref argument was expanded into several converted
      // arguments; re-pack that group into a single descriptor value.
      numToDrop = memRefType
                      ? MemRefDescriptor::getNumUnpackedValues(memRefType)
                      : UnrankedMemRefDescriptor::getNumUnpackedValues();
      Value packed =
          memRefType
              ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
                                       wrapperArgsRange.take_front(numToDrop))
              : UnrankedMemRefDescriptor::pack(
                    builder, loc, typeConverter, unrankedMemRefType,
                    wrapperArgsRange.take_front(numToDrop));

      // Store the descriptor on the stack and pass its address, matching the
      // pointer-to-struct convention the wrapper expects.
      auto ptrTy = LLVM::LLVMPointerType::get(packed.getType());
      Value one = builder.create<LLVM::ConstantOp>(
          loc, typeConverter.convertType(builder.getIndexType()),
          builder.getIntegerAttr(builder.getIndexType(), 1));
      Value allocated =
          builder.create<LLVM::AllocaOp>(loc, ptrTy, one, /*alignment=*/0);
      builder.create<LLVM::StoreOp>(loc, packed, allocated);
      arg = allocated;
    } else {
      // Non-memref arguments are forwarded unchanged.
      arg = wrapperArgsRange[0];
    }

    args.push_back(arg);
    wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
  }
  assert(wrapperArgsRange.empty() && "did not map some of the arguments");

  auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);

  if (resultIsNowArg) {
    // Load the result the wrapper wrote through the first argument.
    Value result = builder.create<LLVM::LoadOp>(loc, args.front());
    builder.create<LLVM::ReturnOp>(loc, ValueRange{result});
  } else {
    builder.create<LLVM::ReturnOp>(loc, call.getResults());
  }
}
218
219 namespace {
220
/// Common base for FuncOp conversion patterns: performs the actual
/// FuncOp-to-LLVMFuncOp conversion; derived patterns decide on the calling
/// convention and post-processing.
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
protected:
  using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;

  // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
  // to this legalization pattern. Returns a null op if the signature could not
  // be converted or the `llvm.linkage` attribute has the wrong type.
  LLVM::LLVMFuncOp
  convertFuncOpToLLVMFuncOp(FuncOp funcOp,
                            ConversionPatternRewriter &rewriter) const {
    // Convert the original function arguments. They are converted using the
    // LLVMTypeConverter provided to this legalization pattern.
    auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("std.varargs");
    TypeConverter::SignatureConversion result(funcOp.getNumArguments());
    auto llvmType = getTypeConverter()->convertFunctionSignature(
        funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
    if (!llvmType)
      return nullptr;

    // Propagate argument attributes to all converted arguments obtained after
    // converting a given original argument.
    SmallVector<NamedAttribute, 4> attributes;
    filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
                         attributes);
    if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
      SmallVector<Attribute, 4> newArgAttrs(
          llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
      for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
        // One original argument may map to several converted arguments
        // (`mapping->size`); replicate its attribute dictionary over all of
        // them.
        auto mapping = result.getInputMapping(i);
        assert(mapping.hasValue() &&
               "unexpected deletion of function argument");
        for (size_t j = 0; j < mapping->size; ++j)
          newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
      }
      attributes.push_back(
          rewriter.getNamedAttr(function_like_impl::getArgDictAttrName(),
                                rewriter.getArrayAttr(newArgAttrs)));
    }
    // Drop "llvm.linkage" from the forwarded attributes; linkage is handled
    // explicitly below and passed to the builder as a dedicated argument.
    for (auto pair : llvm::enumerate(attributes)) {
      if (pair.value().first == "llvm.linkage") {
        attributes.erase(attributes.begin() + pair.index());
        break;
      }
    }

    // Create an LLVM function, use external linkage by default until MLIR
    // functions have linkage.
    LLVM::Linkage linkage = LLVM::Linkage::External;
    if (funcOp->hasAttr("llvm.linkage")) {
      auto attr =
          funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>();
      if (!attr) {
        funcOp->emitError()
            << "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
        return nullptr;
      }
      linkage = attr.getLinkage();
    }
    auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
        funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
        /*dsoLocal*/ false, attributes);
    // Move the body over and rewrite the entry block arguments according to
    // the signature conversion computed above.
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());
    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
                                           &result)))
      return nullptr;

    return newFuncOp;
  }
};
290
291 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
292 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
293 /// information.
294 static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
295 struct FuncOpConversion : public FuncOpConversionBase {
FuncOpConversion__anonc33f93640111::FuncOpConversion296 FuncOpConversion(LLVMTypeConverter &converter)
297 : FuncOpConversionBase(converter) {}
298
299 LogicalResult
matchAndRewrite__anonc33f93640111::FuncOpConversion300 matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
301 ConversionPatternRewriter &rewriter) const override {
302 auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
303 if (!newFuncOp)
304 return failure();
305
306 if (getTypeConverter()->getOptions().emitCWrappers ||
307 funcOp->getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
308 if (newFuncOp.isExternal())
309 wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
310 funcOp, newFuncOp);
311 else
312 wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
313 funcOp, newFuncOp);
314 }
315
316 rewriter.eraseOp(funcOp);
317 return success();
318 }
319 };
320
/// FuncOp legalization pattern that converts MemRef arguments to bare pointers
/// to the MemRef element type. This will impact the calling convention and ABI.
struct BarePtrFuncOpConversion : public FuncOpConversionBase {
  using FuncOpConversionBase::FuncOpConversionBase;

  LogicalResult
  matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // TODO: bare ptr conversion could be handled by argument materialization
    // and most of the code below would go away. But to do this, we would need a
    // way to distinguish between FuncOp and other regions in the
    // addArgumentMaterialization hook.

    // Store the type of memref-typed arguments before the conversion so that we
    // can promote them to MemRef descriptor at the beginning of the function.
    SmallVector<Type, 8> oldArgTypes =
        llvm::to_vector<8>(funcOp.getType().getInputs());

    auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
    if (!newFuncOp)
      return failure();
    // External functions have no body: there are no bare pointers to promote,
    // so the conversion is already complete.
    if (newFuncOp.getBody().empty()) {
      rewriter.eraseOp(funcOp);
      return success();
    }

    // Promote bare pointers from memref arguments to memref descriptors at the
    // beginning of the function so that all the memrefs in the function have a
    // uniform representation.
    Block *entryBlock = &newFuncOp.getBody().front();
    auto blockArgs = entryBlock->getArguments();
    assert(blockArgs.size() == oldArgTypes.size() &&
           "The number of arguments and types doesn't match");

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(entryBlock);
    for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
      BlockArgument arg = std::get<0>(it);
      Type argTy = std::get<1>(it);

      // Unranked memrefs are not supported in the bare pointer calling
      // convention. We should have bailed out before in the presence of
      // unranked memrefs.
      assert(!argTy.isa<UnrankedMemRefType>() &&
             "Unranked memref is not supported");
      auto memrefTy = argTy.dyn_cast<MemRefType>();
      if (!memrefTy)
        continue;

      // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
      // or unranked memref descriptor and replace placeholder with the last
      // instruction of the memref descriptor.
      // TODO: The placeholder is needed to avoid replacing barePtr uses in the
      // MemRef descriptor instructions. We may want to have a utility in the
      // rewriter to properly handle this use case.
      Location loc = funcOp.getLoc();
      auto placeholder = rewriter.create<LLVM::UndefOp>(
          loc, getTypeConverter()->convertType(memrefTy));
      rewriter.replaceUsesOfBlockArgument(arg, placeholder);

      // Build the descriptor from the bare pointer (`arg`), then swap the
      // placeholder for the finished descriptor value.
      Value desc = MemRefDescriptor::fromStaticShape(
          rewriter, loc, *getTypeConverter(), memrefTy, arg);
      rewriter.replaceOp(placeholder, {desc});
    }

    rewriter.eraseOp(funcOp);
    return success();
  }
};
391
// Straightforward lowerings: each alias instantiates the generic one-to-one
// (vector-aware) standard-to-LLVM conversion pattern for a pair of ops.
using AbsFOpLowering = VectorConvertToLLVMPattern<AbsFOp, LLVM::FAbsOp>;
using AddFOpLowering = VectorConvertToLLVMPattern<AddFOp, LLVM::FAddOp>;
using AddIOpLowering = VectorConvertToLLVMPattern<AddIOp, LLVM::AddOp>;
using AndOpLowering = VectorConvertToLLVMPattern<AndOp, LLVM::AndOp>;
using BitcastOpLowering =
    VectorConvertToLLVMPattern<BitcastOp, LLVM::BitcastOp>;
using CeilFOpLowering = VectorConvertToLLVMPattern<CeilFOp, LLVM::FCeilOp>;
using CopySignOpLowering =
    VectorConvertToLLVMPattern<CopySignOp, LLVM::CopySignOp>;
using DivFOpLowering = VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp>;
using FPExtOpLowering = VectorConvertToLLVMPattern<FPExtOp, LLVM::FPExtOp>;
using FPToSIOpLowering = VectorConvertToLLVMPattern<FPToSIOp, LLVM::FPToSIOp>;
using FPToUIOpLowering = VectorConvertToLLVMPattern<FPToUIOp, LLVM::FPToUIOp>;
using FPTruncOpLowering =
    VectorConvertToLLVMPattern<FPTruncOp, LLVM::FPTruncOp>;
using FloorFOpLowering = VectorConvertToLLVMPattern<FloorFOp, LLVM::FFloorOp>;
using FmaFOpLowering = VectorConvertToLLVMPattern<FmaFOp, LLVM::FMAOp>;
using MulFOpLowering = VectorConvertToLLVMPattern<MulFOp, LLVM::FMulOp>;
using MulIOpLowering = VectorConvertToLLVMPattern<MulIOp, LLVM::MulOp>;
using NegFOpLowering = VectorConvertToLLVMPattern<NegFOp, LLVM::FNegOp>;
using OrOpLowering = VectorConvertToLLVMPattern<OrOp, LLVM::OrOp>;
using RemFOpLowering = VectorConvertToLLVMPattern<RemFOp, LLVM::FRemOp>;
using SIToFPOpLowering = VectorConvertToLLVMPattern<SIToFPOp, LLVM::SIToFPOp>;
using SelectOpLowering = VectorConvertToLLVMPattern<SelectOp, LLVM::SelectOp>;
using SignExtendIOpLowering =
    VectorConvertToLLVMPattern<SignExtendIOp, LLVM::SExtOp>;
using ShiftLeftOpLowering =
    VectorConvertToLLVMPattern<ShiftLeftOp, LLVM::ShlOp>;
using SignedDivIOpLowering =
    VectorConvertToLLVMPattern<SignedDivIOp, LLVM::SDivOp>;
using SignedRemIOpLowering =
    VectorConvertToLLVMPattern<SignedRemIOp, LLVM::SRemOp>;
using SignedShiftRightOpLowering =
    VectorConvertToLLVMPattern<SignedShiftRightOp, LLVM::AShrOp>;
using SubFOpLowering = VectorConvertToLLVMPattern<SubFOp, LLVM::FSubOp>;
using SubIOpLowering = VectorConvertToLLVMPattern<SubIOp, LLVM::SubOp>;
using TruncateIOpLowering =
    VectorConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp>;
using UIToFPOpLowering = VectorConvertToLLVMPattern<UIToFPOp, LLVM::UIToFPOp>;
using UnsignedDivIOpLowering =
    VectorConvertToLLVMPattern<UnsignedDivIOp, LLVM::UDivOp>;
using UnsignedRemIOpLowering =
    VectorConvertToLLVMPattern<UnsignedRemIOp, LLVM::URemOp>;
using UnsignedShiftRightOpLowering =
    VectorConvertToLLVMPattern<UnsignedShiftRightOp, LLVM::LShrOp>;
using XOrOpLowering = VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp>;
using ZeroExtendIOpLowering =
    VectorConvertToLLVMPattern<ZeroExtendIOp, LLVM::ZExtOp>;
441
442 /// Lower `std.assert`. The default lowering calls the `abort` function if the
443 /// assertion is violated and has no effect otherwise. The failure message is
444 /// ignored by the default lowering but should be propagated by any custom
445 /// lowering.
446 struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
447 using ConvertOpToLLVMPattern<AssertOp>::ConvertOpToLLVMPattern;
448
449 LogicalResult
matchAndRewrite__anonc33f93640111::AssertOpLowering450 matchAndRewrite(AssertOp op, OpAdaptor adaptor,
451 ConversionPatternRewriter &rewriter) const override {
452 auto loc = op.getLoc();
453
454 // Insert the `abort` declaration if necessary.
455 auto module = op->getParentOfType<ModuleOp>();
456 auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
457 if (!abortFunc) {
458 OpBuilder::InsertionGuard guard(rewriter);
459 rewriter.setInsertionPointToStart(module.getBody());
460 auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
461 abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
462 "abort", abortFuncTy);
463 }
464
465 // Split block at `assert` operation.
466 Block *opBlock = rewriter.getInsertionBlock();
467 auto opPosition = rewriter.getInsertionPoint();
468 Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
469
470 // Generate IR to call `abort`.
471 Block *failureBlock = rewriter.createBlock(opBlock->getParent());
472 rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
473 rewriter.create<LLVM::UnreachableOp>(loc);
474
475 // Generate assertion test.
476 rewriter.setInsertionPointToEnd(opBlock);
477 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
478 op, adaptor.arg(), continuationBlock, failureBlock);
479
480 return success();
481 }
482 };
483
484 struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
485 using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;
486
487 LogicalResult
matchAndRewrite__anonc33f93640111::ConstantOpLowering488 matchAndRewrite(ConstantOp op, OpAdaptor adaptor,
489 ConversionPatternRewriter &rewriter) const override {
490 // If constant refers to a function, convert it to "addressof".
491 if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
492 auto type = typeConverter->convertType(op.getResult().getType());
493 if (!type || !LLVM::isCompatibleType(type))
494 return rewriter.notifyMatchFailure(op, "failed to convert result type");
495
496 auto newOp = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type,
497 symbolRef.getValue());
498 for (const NamedAttribute &attr : op->getAttrs()) {
499 if (attr.first.strref() == "value")
500 continue;
501 newOp->setAttr(attr.first, attr.second);
502 }
503 rewriter.replaceOp(op, newOp->getResults());
504 return success();
505 }
506
507 // Calling into other scopes (non-flat reference) is not supported in LLVM.
508 if (op.getValue().isa<SymbolRefAttr>())
509 return rewriter.notifyMatchFailure(
510 op, "referring to a symbol outside of the current module");
511
512 return LLVM::detail::oneToOneRewrite(
513 op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
514 *getTypeConverter(), rewriter);
515 }
516 };
517
// A CallOp automatically promotes MemRefType to a sequence of alloca/store and
// passes the pointer to the MemRef across function boundaries.
template <typename CallOpType>
struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
  using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
  using Super = CallOpInterfaceLowering<CallOpType>;
  using Base = ConvertOpToLLVMPattern<CallOpType>;

  LogicalResult
  matchAndRewrite(CallOpType callOp, typename CallOpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Pack the result types into a struct: LLVM functions return at most one
    // value, so multiple results are aggregated.
    Type packedResult = nullptr;
    unsigned numResults = callOp.getNumResults();
    auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());

    if (numResults != 0) {
      if (!(packedResult =
                this->getTypeConverter()->packFunctionResults(resultTypes)))
        return failure();
    }

    // Promote operands according to the active calling convention (e.g.
    // materializing memref operands appropriately for the callee).
    auto promoted = this->getTypeConverter()->promoteOperands(
        callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
        adaptor.getOperands(), rewriter);
    auto newOp = rewriter.create<LLVM::CallOp>(
        callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
        promoted, callOp->getAttrs());

    SmallVector<Value, 4> results;
    if (numResults < 2) {
      // If < 2 results, packing did not do anything and we can just return.
      results.append(newOp.result_begin(), newOp.result_end());
    } else {
      // Otherwise, it had been converted to an operation producing a structure.
      // Extract individual results from the structure and return them as list.
      results.reserve(numResults);
      for (unsigned i = 0; i < numResults; ++i) {
        auto type =
            this->typeConverter->convertType(callOp.getResult(i).getType());
        results.push_back(rewriter.create<LLVM::ExtractValueOp>(
            callOp.getLoc(), type, newOp->getResult(0),
            rewriter.getI64ArrayAttr(i)));
      }
    }

    if (this->getTypeConverter()->getOptions().useBarePtrCallConv) {
      // For the bare-ptr calling convention, promote memref results to
      // descriptors.
      assert(results.size() == resultTypes.size() &&
             "The number of arguments and types doesn't match");
      this->getTypeConverter()->promoteBarePtrsToDescriptors(
          rewriter, callOp.getLoc(), resultTypes, results);
    } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
                                                    resultTypes, results,
                                                    /*toDynamic=*/false))) {
      return failure();
    }

    rewriter.replaceOp(callOp, results);
    return success();
  }
};
581
/// Lowering for direct calls (`std.call`); the shared call-lowering logic
/// lives in `CallOpInterfaceLowering`.
struct CallOpLowering : public CallOpInterfaceLowering<CallOp> {
  using Super::Super;
};
585
/// Lowering for indirect calls (`std.call_indirect`); the shared call-lowering
/// logic lives in `CallOpInterfaceLowering`.
struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
  using Super::Super;
};
589
590 struct UnrealizedConversionCastOpLowering
591 : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
592 using ConvertOpToLLVMPattern<
593 UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
594
595 LogicalResult
matchAndRewrite__anonc33f93640111::UnrealizedConversionCastOpLowering596 matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
597 ConversionPatternRewriter &rewriter) const override {
598 SmallVector<Type> convertedTypes;
599 if (succeeded(typeConverter->convertTypes(op.outputs().getTypes(),
600 convertedTypes)) &&
601 convertedTypes == adaptor.inputs().getTypes()) {
602 rewriter.replaceOp(op, adaptor.inputs());
603 return success();
604 }
605
606 convertedTypes.clear();
607 if (succeeded(typeConverter->convertTypes(adaptor.inputs().getTypes(),
608 convertedTypes)) &&
609 convertedTypes == op.outputs().getType()) {
610 rewriter.replaceOp(op, adaptor.inputs());
611 return success();
612 }
613 return failure();
614 }
615 };
616
617 struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
618 using ConvertOpToLLVMPattern<RankOp>::ConvertOpToLLVMPattern;
619
620 LogicalResult
matchAndRewrite__anonc33f93640111::RankOpLowering621 matchAndRewrite(RankOp op, OpAdaptor adaptor,
622 ConversionPatternRewriter &rewriter) const override {
623 Location loc = op.getLoc();
624 Type operandType = op.memrefOrTensor().getType();
625 if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
626 UnrankedMemRefDescriptor desc(adaptor.memrefOrTensor());
627 rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
628 return success();
629 }
630 if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
631 rewriter.replaceOp(
632 op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
633 return success();
634 }
635 return failure();
636 }
637 };
638
639 // Common base for load and store operations on MemRefs. Restricts the match
640 // to supported MemRef types. Provides functionality to emit code accessing a
641 // specific element of the underlying data buffer.
642 template <typename Derived>
643 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
644 using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
645 using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
646 using Base = LoadStoreOpLowering<Derived>;
647
match__anonc33f93640111::LoadStoreOpLowering648 LogicalResult match(Derived op) const override {
649 MemRefType type = op.getMemRefType();
650 return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
651 }
652 };
653
654 // The lowering of index_cast becomes an integer conversion since index becomes
655 // an integer. If the bit width of the source and target integer types is the
656 // same, just erase the cast. If the target type is wider, sign-extend the
657 // value, otherwise truncate it.
658 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
659 using ConvertOpToLLVMPattern<IndexCastOp>::ConvertOpToLLVMPattern;
660
661 LogicalResult
matchAndRewrite__anonc33f93640111::IndexCastOpLowering662 matchAndRewrite(IndexCastOp indexCastOp, OpAdaptor adaptor,
663 ConversionPatternRewriter &rewriter) const override {
664 auto targetType =
665 typeConverter->convertType(indexCastOp.getResult().getType());
666 auto targetElementType =
667 typeConverter
668 ->convertType(getElementTypeOrSelf(indexCastOp.getResult()))
669 .cast<IntegerType>();
670 auto sourceElementType =
671 getElementTypeOrSelf(adaptor.in()).cast<IntegerType>();
672 unsigned targetBits = targetElementType.getWidth();
673 unsigned sourceBits = sourceElementType.getWidth();
674
675 if (targetBits == sourceBits)
676 rewriter.replaceOp(indexCastOp, adaptor.in());
677 else if (targetBits < sourceBits)
678 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(indexCastOp, targetType,
679 adaptor.in());
680 else
681 rewriter.replaceOpWithNewOp<LLVM::SExtOp>(indexCastOp, targetType,
682 adaptor.in());
683 return success();
684 }
685 };
686
// Convert std.cmp predicate into the LLVM dialect CmpPredicate. The two
// enums share the numerical values so just cast.
template <typename LLVMPredType, typename StdPredType>
static LLVMPredType convertCmpPredicate(StdPredType pred) {
  // The enumerators carry matching numeric values, so no mapping table is
  // needed: a value-preserving cast suffices.
  auto rawPredicate = pred;
  return static_cast<LLVMPredType>(rawPredicate);
}
693
694 struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
695 using ConvertOpToLLVMPattern<CmpIOp>::ConvertOpToLLVMPattern;
696
697 LogicalResult
matchAndRewrite__anonc33f93640111::CmpIOpLowering698 matchAndRewrite(CmpIOp cmpiOp, OpAdaptor adaptor,
699 ConversionPatternRewriter &rewriter) const override {
700 auto operandType = adaptor.lhs().getType();
701 auto resultType = cmpiOp.getResult().getType();
702
703 // Handle the scalar and 1D vector cases.
704 if (!operandType.isa<LLVM::LLVMArrayType>()) {
705 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
706 cmpiOp, typeConverter->convertType(resultType),
707 convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
708 adaptor.lhs(), adaptor.rhs());
709 return success();
710 }
711
712 auto vectorType = resultType.dyn_cast<VectorType>();
713 if (!vectorType)
714 return rewriter.notifyMatchFailure(cmpiOp, "expected vector result type");
715
716 return LLVM::detail::handleMultidimensionalVectors(
717 cmpiOp.getOperation(), adaptor.getOperands(), *getTypeConverter(),
718 [&](Type llvm1DVectorTy, ValueRange operands) {
719 CmpIOpAdaptor adaptor(operands);
720 return rewriter.create<LLVM::ICmpOp>(
721 cmpiOp.getLoc(), llvm1DVectorTy,
722 convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
723 adaptor.lhs(), adaptor.rhs());
724 },
725 rewriter);
726
727 return success();
728 }
729 };
730
/// Lowers `std.cmpf` to `llvm.fcmp`. Scalars and 1-D vectors are handled
/// directly; multi-dimensional vectors (converted to nested LLVM arrays of
/// 1-D vectors) are decomposed into one fcmp per innermost vector.
struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
  using ConvertOpToLLVMPattern<CmpFOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(CmpFOp cmpfOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto operandType = adaptor.lhs().getType();
    auto resultType = cmpfOp.getResult().getType();

    // Handle the scalar and 1D vector cases.
    if (!operandType.isa<LLVM::LLVMArrayType>()) {
      rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
          cmpfOp, typeConverter->convertType(resultType),
          convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
          adaptor.lhs(), adaptor.rhs());
      return success();
    }

    auto vectorType = resultType.dyn_cast<VectorType>();
    if (!vectorType)
      return rewriter.notifyMatchFailure(cmpfOp, "expected vector result type");

    // Emit one fcmp per innermost 1-D vector of the converted array type.
    return LLVM::detail::handleMultidimensionalVectors(
        cmpfOp.getOperation(), adaptor.getOperands(), *getTypeConverter(),
        [&](Type llvm1DVectorTy, ValueRange operands) {
          CmpFOpAdaptor adaptor(operands);
          return rewriter.create<LLVM::FCmpOp>(
              cmpfOp.getLoc(), llvm1DVectorTy,
              convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
              adaptor.lhs(), adaptor.rhs());
        },
        rewriter);
  }
};
765
766 // Base class for LLVM IR lowering terminator operations with successors.
767 template <typename SourceOp, typename TargetOp>
768 struct OneToOneLLVMTerminatorLowering
769 : public ConvertOpToLLVMPattern<SourceOp> {
770 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
771 using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
772
773 LogicalResult
matchAndRewrite__anonc33f93640111::OneToOneLLVMTerminatorLowering774 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
775 ConversionPatternRewriter &rewriter) const override {
776 rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
777 op->getSuccessors(), op->getAttrs());
778 return success();
779 }
780 };
781
782 // Special lowering pattern for `ReturnOps`. Unlike all other operations,
783 // `ReturnOp` interacts with the function signature and must have as many
784 // operands as the function has return values. Because in LLVM IR, functions
785 // can only return 0 or 1 value, we pack multiple values into a structure type.
786 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
787 // necessary before returning it
788 struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
789 using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;
790
791 LogicalResult
matchAndRewrite__anonc33f93640111::ReturnOpLowering792 matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
793 ConversionPatternRewriter &rewriter) const override {
794 Location loc = op.getLoc();
795 unsigned numArguments = op.getNumOperands();
796 SmallVector<Value, 4> updatedOperands;
797
798 if (getTypeConverter()->getOptions().useBarePtrCallConv) {
799 // For the bare-ptr calling convention, extract the aligned pointer to
800 // be returned from the memref descriptor.
801 for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
802 Type oldTy = std::get<0>(it).getType();
803 Value newOperand = std::get<1>(it);
804 if (oldTy.isa<MemRefType>()) {
805 MemRefDescriptor memrefDesc(newOperand);
806 newOperand = memrefDesc.alignedPtr(rewriter, loc);
807 } else if (oldTy.isa<UnrankedMemRefType>()) {
808 // Unranked memref is not supported in the bare pointer calling
809 // convention.
810 return failure();
811 }
812 updatedOperands.push_back(newOperand);
813 }
814 } else {
815 updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
816 (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
817 updatedOperands,
818 /*toDynamic=*/true);
819 }
820
821 // If ReturnOp has 0 or 1 operand, create it and return immediately.
822 if (numArguments == 0) {
823 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
824 op->getAttrs());
825 return success();
826 }
827 if (numArguments == 1) {
828 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
829 op, TypeRange(), updatedOperands, op->getAttrs());
830 return success();
831 }
832
833 // Otherwise, we need to pack the arguments into an LLVM struct type before
834 // returning.
835 auto packedType = getTypeConverter()->packFunctionResults(
836 llvm::to_vector<4>(op.getOperandTypes()));
837
838 Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
839 for (unsigned i = 0; i < numArguments; ++i) {
840 packed = rewriter.create<LLVM::InsertValueOp>(
841 loc, packedType, packed, updatedOperands[i],
842 rewriter.getI64ArrayAttr(i));
843 }
844 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
845 op->getAttrs());
846 return success();
847 }
848 };
849
// FIXME: this should be tablegen'ed as well.
// std.br -> llvm.br: unconditional branch, operands forwarded as-is.
struct BranchOpLowering
    : public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> {
  using Super::Super;
};
// std.cond_br -> llvm.cond_br: condition plus per-successor operands.
struct CondBranchOpLowering
    : public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> {
  using Super::Super;
};
// std.switch -> llvm.switch: flag, default, and case successors.
struct SwitchOpLowering
    : public OneToOneLLVMTerminatorLowering<SwitchOp, LLVM::SwitchOp> {
  using Super::Super;
};
863
864 // The Splat operation is lowered to an insertelement + a shufflevector
865 // operation. Splat to only 1-d vector result types are lowered.
866 struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
867 using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
868
869 LogicalResult
matchAndRewrite__anonc33f93640111::SplatOpLowering870 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
871 ConversionPatternRewriter &rewriter) const override {
872 VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
873 if (!resultType || resultType.getRank() != 1)
874 return failure();
875
876 // First insert it into an undef vector so we can shuffle it.
877 auto vectorType = typeConverter->convertType(splatOp.getType());
878 Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
879 auto zero = rewriter.create<LLVM::ConstantOp>(
880 splatOp.getLoc(),
881 typeConverter->convertType(rewriter.getIntegerType(32)),
882 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
883
884 auto v = rewriter.create<LLVM::InsertElementOp>(
885 splatOp.getLoc(), vectorType, undef, adaptor.input(), zero);
886
887 int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
888 SmallVector<int32_t, 4> zeroValues(width, 0);
889
890 // Shuffle the value across the desired number of elements.
891 ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
892 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
893 zeroAttrs);
894 return success();
895 }
896 };
897
// The Splat operation is lowered to an insertelement + a shufflevector
// operation. Splat to only 2+-d vector result types are lowered by the
// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Rank-1 splats are intentionally left to SplatOpLowering.
    VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
    if (!resultType || resultType.getRank() == 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    // The n-D vector converts to a nested LLVM array of 1-D vectors; bail
    // out if the conversion produced neither piece.
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
    // i32 zero used as the insertelement lane index.
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                     adaptor.input(), zero);

    // Shuffle the value across the desired number of elements.
    // `width` is the innermost dimension, i.e. the 1-D vector length.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector<int32_t, 4> zeroValues(width, 0);
    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);

    // Iterate of linear index, convert to coords space and insert splatted 1-D
    // vector in each position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
                                                  position);
    });
    rewriter.replaceOp(splatOp, desc);
    return success();
  }
};
948
949 } // namespace
950
951 /// Try to match the kind of a std.atomic_rmw to determine whether to use a
952 /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
matchSimpleAtomicOp(AtomicRMWOp atomicOp)953 static Optional<LLVM::AtomicBinOp> matchSimpleAtomicOp(AtomicRMWOp atomicOp) {
954 switch (atomicOp.kind()) {
955 case AtomicRMWKind::addf:
956 return LLVM::AtomicBinOp::fadd;
957 case AtomicRMWKind::addi:
958 return LLVM::AtomicBinOp::add;
959 case AtomicRMWKind::assign:
960 return LLVM::AtomicBinOp::xchg;
961 case AtomicRMWKind::maxs:
962 return LLVM::AtomicBinOp::max;
963 case AtomicRMWKind::maxu:
964 return LLVM::AtomicBinOp::umax;
965 case AtomicRMWKind::mins:
966 return LLVM::AtomicBinOp::min;
967 case AtomicRMWKind::minu:
968 return LLVM::AtomicBinOp::umin;
969 default:
970 return llvm::None;
971 }
972 llvm_unreachable("Invalid AtomicRMWKind");
973 }
974
975 namespace {
976
977 struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
978 using Base::Base;
979
980 LogicalResult
matchAndRewrite__anonc33f93640511::AtomicRMWOpLowering981 matchAndRewrite(AtomicRMWOp atomicOp, OpAdaptor adaptor,
982 ConversionPatternRewriter &rewriter) const override {
983 if (failed(match(atomicOp)))
984 return failure();
985 auto maybeKind = matchSimpleAtomicOp(atomicOp);
986 if (!maybeKind)
987 return failure();
988 auto resultType = adaptor.value().getType();
989 auto memRefType = atomicOp.getMemRefType();
990 auto dataPtr =
991 getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
992 adaptor.indices(), rewriter);
993 rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
994 atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
995 LLVM::AtomicOrdering::acq_rel);
996 return success();
997 }
998 };
999
/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   br loop(%loaded)              |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |----------         |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts. The loop block
    // carries one argument: the currently loaded value.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock =
        rewriter.createBlock(initBlock->getParent(),
                             std::next(Region::iterator(initBlock)), valueType);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`: everything that followed
    // the atomic op in the original block (captured before we mutate it).
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
                                        adaptor.indices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result. The
    // region's "current value" argument is remapped to the loop argument.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    // The region terminator's operand (the yielded value) is the new value
    // to attempt to store.
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    // cmpxchg yields a {value, success-flag} pair.
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    // Move the ops that followed the atomic op into the end block, rewriting
    // uses of the old result to the newly loaded value.
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  // `oldResult` values are remapped to `newResult` in the clones.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    // Erase after the loop so the iteration range stays valid.
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};
1118
1119 } // namespace
1120
populateStdToLLVMFuncOpConversionPattern(LLVMTypeConverter & converter,RewritePatternSet & patterns)1121 void mlir::populateStdToLLVMFuncOpConversionPattern(
1122 LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1123 if (converter.getOptions().useBarePtrCallConv)
1124 patterns.add<BarePtrFuncOpConversion>(converter);
1125 else
1126 patterns.add<FuncOpConversion>(converter);
1127 }
1128
/// Populates `patterns` with all standard-to-LLVM conversion patterns,
/// including the calling-convention-dependent FuncOp conversion.
void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                               RewritePatternSet &patterns) {
  populateStdToLLVMFuncOpConversionPattern(converter, patterns);
  // Keep the list alphabetized where possible; order has no semantic effect.
  // clang-format off
  patterns.add<
      AbsFOpLowering,
      AddFOpLowering,
      AddIOpLowering,
      AndOpLowering,
      AssertOpLowering,
      AtomicRMWOpLowering,
      BitcastOpLowering,
      BranchOpLowering,
      CallIndirectOpLowering,
      CallOpLowering,
      CeilFOpLowering,
      CmpFOpLowering,
      CmpIOpLowering,
      CondBranchOpLowering,
      CopySignOpLowering,
      ConstantOpLowering,
      DivFOpLowering,
      FloorFOpLowering,
      FmaFOpLowering,
      GenericAtomicRMWOpLowering,
      FPExtOpLowering,
      FPToSIOpLowering,
      FPToUIOpLowering,
      FPTruncOpLowering,
      IndexCastOpLowering,
      MulFOpLowering,
      MulIOpLowering,
      NegFOpLowering,
      OrOpLowering,
      RemFOpLowering,
      RankOpLowering,
      ReturnOpLowering,
      SIToFPOpLowering,
      SelectOpLowering,
      ShiftLeftOpLowering,
      SignExtendIOpLowering,
      SignedDivIOpLowering,
      SignedRemIOpLowering,
      SignedShiftRightOpLowering,
      SplatOpLowering,
      SplatNdOpLowering,
      SubFOpLowering,
      SubIOpLowering,
      SwitchOpLowering,
      TruncateIOpLowering,
      UIToFPOpLowering,
      UnsignedDivIOpLowering,
      UnsignedRemIOpLowering,
      UnsignedShiftRightOpLowering,
      XOrOpLowering,
      ZeroExtendIOpLowering>(converter);
  // clang-format on
}
1187
1188 namespace {
1189 /// A pass converting MLIR operations into the LLVM IR dialect.
1190 struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
1191 LLVMLoweringPass() = default;
LLVMLoweringPass__anonc33f93640611::LLVMLoweringPass1192 LLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers,
1193 unsigned indexBitwidth, bool useAlignedAlloc,
1194 const llvm::DataLayout &dataLayout) {
1195 this->useBarePtrCallConv = useBarePtrCallConv;
1196 this->emitCWrappers = emitCWrappers;
1197 this->indexBitwidth = indexBitwidth;
1198 this->dataLayout = dataLayout.getStringRepresentation();
1199 }
1200
1201 /// Run the dialect converter on the module.
runOnOperation__anonc33f93640611::LLVMLoweringPass1202 void runOnOperation() override {
1203 if (useBarePtrCallConv && emitCWrappers) {
1204 getOperation().emitError()
1205 << "incompatible conversion options: bare-pointer calling convention "
1206 "and C wrapper emission";
1207 signalPassFailure();
1208 return;
1209 }
1210 if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
1211 this->dataLayout, [this](const Twine &message) {
1212 getOperation().emitError() << message.str();
1213 }))) {
1214 signalPassFailure();
1215 return;
1216 }
1217
1218 ModuleOp m = getOperation();
1219 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1220
1221 LowerToLLVMOptions options(&getContext(),
1222 dataLayoutAnalysis.getAtOrAbove(m));
1223 options.useBarePtrCallConv = useBarePtrCallConv;
1224 options.emitCWrappers = emitCWrappers;
1225 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
1226 options.overrideIndexBitwidth(indexBitwidth);
1227 options.dataLayout = llvm::DataLayout(this->dataLayout);
1228
1229 LLVMTypeConverter typeConverter(&getContext(), options,
1230 &dataLayoutAnalysis);
1231
1232 RewritePatternSet patterns(&getContext());
1233 populateStdToLLVMConversionPatterns(typeConverter, patterns);
1234
1235 LLVMConversionTarget target(getContext());
1236 if (failed(applyPartialConversion(m, target, std::move(patterns))))
1237 signalPassFailure();
1238
1239 m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
1240 StringAttr::get(m.getContext(), this->dataLayout));
1241 }
1242 };
1243 } // end namespace
1244
createLowerToLLVMPass()1245 std::unique_ptr<OperationPass<ModuleOp>> mlir::createLowerToLLVMPass() {
1246 return std::make_unique<LLVMLoweringPass>();
1247 }
1248
1249 std::unique_ptr<OperationPass<ModuleOp>>
createLowerToLLVMPass(const LowerToLLVMOptions & options)1250 mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) {
1251 auto allocLowering = options.allocLowering;
1252 // There is no way to provide additional patterns for pass, so
1253 // AllocLowering::None will always fail.
1254 assert(allocLowering != LowerToLLVMOptions::AllocLowering::None &&
1255 "LLVMLoweringPass doesn't support AllocLowering::None");
1256 bool useAlignedAlloc =
1257 (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc);
1258 return std::make_unique<LLVMLoweringPass>(
1259 options.useBarePtrCallConv, options.emitCWrappers,
1260 options.getIndexBitwidth(), useAlignedAlloc, options.dataLayout);
1261 }
1262