1 //===- OpenACCToLLVM.cpp - Prepare OpenACC data for LLVM translation ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "../PassDetail.h"
10 #include "mlir/Conversion/LLVMCommon/Pattern.h"
11 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/Dialect/OpenACC/OpenACC.h"
14 #include "mlir/IR/Builders.h"
15
16 using namespace mlir;
17
18 //===----------------------------------------------------------------------===//
19 // DataDescriptor implementation
20 //===----------------------------------------------------------------------===//
21
getStructName()22 constexpr StringRef getStructName() { return "openacc_data"; }
23
24 /// Construct a helper for the given descriptor value.
DataDescriptor(Value descriptor)25 DataDescriptor::DataDescriptor(Value descriptor) : StructBuilder(descriptor) {
26 assert(value != nullptr && "value cannot be null");
27 }
28
29 /// Builds IR creating an `undef` value of the data descriptor.
undef(OpBuilder & builder,Location loc,Type basePtrTy,Type ptrTy)30 DataDescriptor DataDescriptor::undef(OpBuilder &builder, Location loc,
31 Type basePtrTy, Type ptrTy) {
32 Type descriptorType = LLVM::LLVMStructType::getNewIdentified(
33 builder.getContext(), getStructName(),
34 {basePtrTy, ptrTy, builder.getI64Type()});
35 Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
36 return DataDescriptor(descriptor);
37 }
38
39 /// Check whether the type is a valid data descriptor.
isValid(Value descriptor)40 bool DataDescriptor::isValid(Value descriptor) {
41 if (auto type = descriptor.getType().dyn_cast<LLVM::LLVMStructType>()) {
42 if (type.isIdentified() && type.getName().startswith(getStructName()) &&
43 type.getBody().size() == 3 &&
44 (type.getBody()[kPtrBasePosInDataDescriptor]
45 .isa<LLVM::LLVMPointerType>() ||
46 type.getBody()[kPtrBasePosInDataDescriptor]
47 .isa<LLVM::LLVMStructType>()) &&
48 type.getBody()[kPtrPosInDataDescriptor].isa<LLVM::LLVMPointerType>() &&
49 type.getBody()[kSizePosInDataDescriptor].isInteger(64))
50 return true;
51 }
52 return false;
53 }
54
55 /// Builds IR inserting the base pointer value into the descriptor.
setBasePointer(OpBuilder & builder,Location loc,Value basePtr)56 void DataDescriptor::setBasePointer(OpBuilder &builder, Location loc,
57 Value basePtr) {
58 setPtr(builder, loc, kPtrBasePosInDataDescriptor, basePtr);
59 }
60
61 /// Builds IR inserting the pointer value into the descriptor.
setPointer(OpBuilder & builder,Location loc,Value ptr)62 void DataDescriptor::setPointer(OpBuilder &builder, Location loc, Value ptr) {
63 setPtr(builder, loc, kPtrPosInDataDescriptor, ptr);
64 }
65
66 /// Builds IR inserting the size value into the descriptor.
setSize(OpBuilder & builder,Location loc,Value size)67 void DataDescriptor::setSize(OpBuilder &builder, Location loc, Value size) {
68 setPtr(builder, loc, kSizePosInDataDescriptor, size);
69 }
70
71 //===----------------------------------------------------------------------===//
72 // Conversion patterns
73 //===----------------------------------------------------------------------===//
74
75 namespace {
76
77 template <typename Op>
78 class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> {
79 using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
80
81 LogicalResult
matchAndRewrite(Op op,ArrayRef<Value> operands,ConversionPatternRewriter & builder) const82 matchAndRewrite(Op op, ArrayRef<Value> operands,
83 ConversionPatternRewriter &builder) const override {
84 Location loc = op.getLoc();
85 TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
86
87 unsigned numDataOperand = op.getNumDataOperands();
88
89 // Keep the non data operands without modification.
90 auto nonDataOperands =
91 operands.take_front(operands.size() - numDataOperand);
92 SmallVector<Value> convertedOperands;
93 convertedOperands.append(nonDataOperands.begin(), nonDataOperands.end());
94
95 // Go over the data operand and legalize them for translation.
96 for (unsigned idx = 0; idx < numDataOperand; ++idx) {
97 Value originalDataOperand = op.getDataOperand(idx);
98
99 // Traverse operands that were converted to MemRefDescriptors.
100 if (auto memRefType =
101 originalDataOperand.getType().dyn_cast<MemRefType>()) {
102 Type structType = converter->convertType(memRefType);
103 Value memRefDescriptor = builder
104 .create<UnrealizedConversionCastOp>(
105 loc, structType, originalDataOperand)
106 .getResult(0);
107
108 // Calculate the size of the memref and get the pointer to the allocated
109 // buffer.
110 SmallVector<Value> sizes;
111 SmallVector<Value> strides;
112 Value sizeBytes;
113 ConvertToLLVMPattern::getMemRefDescriptorSizes(
114 loc, memRefType, {}, builder, sizes, strides, sizeBytes);
115 MemRefDescriptor descriptor(memRefDescriptor);
116 Value dataPtr = descriptor.alignedPtr(builder, loc);
117 auto ptrType = descriptor.getElementPtrType();
118
119 auto descr = DataDescriptor::undef(builder, loc, structType, ptrType);
120 descr.setBasePointer(builder, loc, memRefDescriptor);
121 descr.setPointer(builder, loc, dataPtr);
122 descr.setSize(builder, loc, sizeBytes);
123 convertedOperands.push_back(descr);
124 } else if (originalDataOperand.getType().isa<LLVM::LLVMPointerType>()) {
125 convertedOperands.push_back(originalDataOperand);
126 } else {
127 // Type not supported.
128 return builder.notifyMatchFailure(op, "unsupported type");
129 }
130 }
131
132 builder.replaceOpWithNewOp<Op>(op, TypeRange(), convertedOperands,
133 op.getOperation()->getAttrs());
134
135 return success();
136 }
137 };
138 } // namespace
139
populateOpenACCToLLVMConversionPatterns(LLVMTypeConverter & converter,OwningRewritePatternList & patterns)140 void mlir::populateOpenACCToLLVMConversionPatterns(
141 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
142 patterns.add<LegalizeDataOpForLLVMTranslation<acc::DataOp>>(converter);
143 patterns.add<LegalizeDataOpForLLVMTranslation<acc::EnterDataOp>>(converter);
144 patterns.add<LegalizeDataOpForLLVMTranslation<acc::ExitDataOp>>(converter);
145 patterns.add<LegalizeDataOpForLLVMTranslation<acc::ParallelOp>>(converter);
146 patterns.add<LegalizeDataOpForLLVMTranslation<acc::UpdateOp>>(converter);
147 }
148
149 namespace {
150 struct ConvertOpenACCToLLVMPass
151 : public ConvertOpenACCToLLVMBase<ConvertOpenACCToLLVMPass> {
152 void runOnOperation() override;
153 };
154 } // namespace
155
runOnOperation()156 void ConvertOpenACCToLLVMPass::runOnOperation() {
157 auto op = getOperation();
158 auto *context = op.getContext();
159
160 // Convert to OpenACC operations with LLVM IR dialect
161 RewritePatternSet patterns(context);
162 LLVMTypeConverter converter(context);
163 populateOpenACCToLLVMConversionPatterns(converter, patterns);
164
165 ConversionTarget target(*context);
166 target.addLegalDialect<LLVM::LLVMDialect>();
167 target.addLegalOp<UnrealizedConversionCastOp>();
168
169 auto allDataOperandsAreConverted = [](ValueRange operands) {
170 for (Value operand : operands) {
171 if (!DataDescriptor::isValid(operand) &&
172 !operand.getType().isa<LLVM::LLVMPointerType>())
173 return false;
174 }
175 return true;
176 };
177
178 target.addDynamicallyLegalOp<acc::DataOp>(
179 [allDataOperandsAreConverted](acc::DataOp op) {
180 return allDataOperandsAreConverted(op.copyOperands()) &&
181 allDataOperandsAreConverted(op.copyinOperands()) &&
182 allDataOperandsAreConverted(op.copyinReadonlyOperands()) &&
183 allDataOperandsAreConverted(op.copyoutOperands()) &&
184 allDataOperandsAreConverted(op.copyoutZeroOperands()) &&
185 allDataOperandsAreConverted(op.createOperands()) &&
186 allDataOperandsAreConverted(op.createZeroOperands()) &&
187 allDataOperandsAreConverted(op.noCreateOperands()) &&
188 allDataOperandsAreConverted(op.presentOperands()) &&
189 allDataOperandsAreConverted(op.deviceptrOperands()) &&
190 allDataOperandsAreConverted(op.attachOperands());
191 });
192
193 target.addDynamicallyLegalOp<acc::EnterDataOp>(
194 [allDataOperandsAreConverted](acc::EnterDataOp op) {
195 return allDataOperandsAreConverted(op.copyinOperands()) &&
196 allDataOperandsAreConverted(op.createOperands()) &&
197 allDataOperandsAreConverted(op.createZeroOperands()) &&
198 allDataOperandsAreConverted(op.attachOperands());
199 });
200
201 target.addDynamicallyLegalOp<acc::ExitDataOp>(
202 [allDataOperandsAreConverted](acc::ExitDataOp op) {
203 return allDataOperandsAreConverted(op.copyoutOperands()) &&
204 allDataOperandsAreConverted(op.deleteOperands()) &&
205 allDataOperandsAreConverted(op.detachOperands());
206 });
207
208 target.addDynamicallyLegalOp<acc::ParallelOp>(
209 [allDataOperandsAreConverted](acc::ParallelOp op) {
210 return allDataOperandsAreConverted(op.reductionOperands()) &&
211 allDataOperandsAreConverted(op.copyOperands()) &&
212 allDataOperandsAreConverted(op.copyinOperands()) &&
213 allDataOperandsAreConverted(op.copyinReadonlyOperands()) &&
214 allDataOperandsAreConverted(op.copyoutOperands()) &&
215 allDataOperandsAreConverted(op.copyoutZeroOperands()) &&
216 allDataOperandsAreConverted(op.createOperands()) &&
217 allDataOperandsAreConverted(op.createZeroOperands()) &&
218 allDataOperandsAreConverted(op.noCreateOperands()) &&
219 allDataOperandsAreConverted(op.presentOperands()) &&
220 allDataOperandsAreConverted(op.devicePtrOperands()) &&
221 allDataOperandsAreConverted(op.attachOperands()) &&
222 allDataOperandsAreConverted(op.gangPrivateOperands()) &&
223 allDataOperandsAreConverted(op.gangFirstPrivateOperands());
224 });
225
226 target.addDynamicallyLegalOp<acc::UpdateOp>(
227 [allDataOperandsAreConverted](acc::UpdateOp op) {
228 return allDataOperandsAreConverted(op.hostOperands()) &&
229 allDataOperandsAreConverted(op.deviceOperands());
230 });
231
232 if (failed(applyPartialConversion(op, target, std::move(patterns))))
233 signalPassFailure();
234 }
235
236 std::unique_ptr<OperationPass<ModuleOp>>
createConvertOpenACCToLLVMPass()237 mlir::createConvertOpenACCToLLVMPass() {
238 return std::make_unique<ConvertOpenACCToLLVMPass>();
239 }
240