1 //===- OpenACCToLLVM.cpp - Prepare OpenACC data for LLVM translation ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "../PassDetail.h"
10 #include "mlir/Conversion/LLVMCommon/Pattern.h"
11 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/Dialect/OpenACC/OpenACC.h"
14 #include "mlir/IR/Builders.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // DataDescriptor implementation
20 //===----------------------------------------------------------------------===//
21 
getStructName()22 constexpr StringRef getStructName() { return "openacc_data"; }
23 
24 /// Construct a helper for the given descriptor value.
DataDescriptor(Value descriptor)25 DataDescriptor::DataDescriptor(Value descriptor) : StructBuilder(descriptor) {
26   assert(value != nullptr && "value cannot be null");
27 }
28 
29 /// Builds IR creating an `undef` value of the data descriptor.
undef(OpBuilder & builder,Location loc,Type basePtrTy,Type ptrTy)30 DataDescriptor DataDescriptor::undef(OpBuilder &builder, Location loc,
31                                      Type basePtrTy, Type ptrTy) {
32   Type descriptorType = LLVM::LLVMStructType::getNewIdentified(
33       builder.getContext(), getStructName(),
34       {basePtrTy, ptrTy, builder.getI64Type()});
35   Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
36   return DataDescriptor(descriptor);
37 }
38 
39 /// Check whether the type is a valid data descriptor.
isValid(Value descriptor)40 bool DataDescriptor::isValid(Value descriptor) {
41   if (auto type = descriptor.getType().dyn_cast<LLVM::LLVMStructType>()) {
42     if (type.isIdentified() && type.getName().startswith(getStructName()) &&
43         type.getBody().size() == 3 &&
44         (type.getBody()[kPtrBasePosInDataDescriptor]
45              .isa<LLVM::LLVMPointerType>() ||
46          type.getBody()[kPtrBasePosInDataDescriptor]
47              .isa<LLVM::LLVMStructType>()) &&
48         type.getBody()[kPtrPosInDataDescriptor].isa<LLVM::LLVMPointerType>() &&
49         type.getBody()[kSizePosInDataDescriptor].isInteger(64))
50       return true;
51   }
52   return false;
53 }
54 
55 /// Builds IR inserting the base pointer value into the descriptor.
setBasePointer(OpBuilder & builder,Location loc,Value basePtr)56 void DataDescriptor::setBasePointer(OpBuilder &builder, Location loc,
57                                     Value basePtr) {
58   setPtr(builder, loc, kPtrBasePosInDataDescriptor, basePtr);
59 }
60 
61 /// Builds IR inserting the pointer value into the descriptor.
setPointer(OpBuilder & builder,Location loc,Value ptr)62 void DataDescriptor::setPointer(OpBuilder &builder, Location loc, Value ptr) {
63   setPtr(builder, loc, kPtrPosInDataDescriptor, ptr);
64 }
65 
66 /// Builds IR inserting the size value into the descriptor.
setSize(OpBuilder & builder,Location loc,Value size)67 void DataDescriptor::setSize(OpBuilder &builder, Location loc, Value size) {
68   setPtr(builder, loc, kSizePosInDataDescriptor, size);
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // Conversion patterns
73 //===----------------------------------------------------------------------===//
74 
75 namespace {
76 
77 template <typename Op>
78 class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> {
79   using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
80 
81   LogicalResult
matchAndRewrite(Op op,ArrayRef<Value> operands,ConversionPatternRewriter & builder) const82   matchAndRewrite(Op op, ArrayRef<Value> operands,
83                   ConversionPatternRewriter &builder) const override {
84     Location loc = op.getLoc();
85     TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
86 
87     unsigned numDataOperand = op.getNumDataOperands();
88 
89     // Keep the non data operands without modification.
90     auto nonDataOperands =
91         operands.take_front(operands.size() - numDataOperand);
92     SmallVector<Value> convertedOperands;
93     convertedOperands.append(nonDataOperands.begin(), nonDataOperands.end());
94 
95     // Go over the data operand and legalize them for translation.
96     for (unsigned idx = 0; idx < numDataOperand; ++idx) {
97       Value originalDataOperand = op.getDataOperand(idx);
98 
99       // Traverse operands that were converted to MemRefDescriptors.
100       if (auto memRefType =
101               originalDataOperand.getType().dyn_cast<MemRefType>()) {
102         Type structType = converter->convertType(memRefType);
103         Value memRefDescriptor = builder
104                                      .create<UnrealizedConversionCastOp>(
105                                          loc, structType, originalDataOperand)
106                                      .getResult(0);
107 
108         // Calculate the size of the memref and get the pointer to the allocated
109         // buffer.
110         SmallVector<Value> sizes;
111         SmallVector<Value> strides;
112         Value sizeBytes;
113         ConvertToLLVMPattern::getMemRefDescriptorSizes(
114             loc, memRefType, {}, builder, sizes, strides, sizeBytes);
115         MemRefDescriptor descriptor(memRefDescriptor);
116         Value dataPtr = descriptor.alignedPtr(builder, loc);
117         auto ptrType = descriptor.getElementPtrType();
118 
119         auto descr = DataDescriptor::undef(builder, loc, structType, ptrType);
120         descr.setBasePointer(builder, loc, memRefDescriptor);
121         descr.setPointer(builder, loc, dataPtr);
122         descr.setSize(builder, loc, sizeBytes);
123         convertedOperands.push_back(descr);
124       } else if (originalDataOperand.getType().isa<LLVM::LLVMPointerType>()) {
125         convertedOperands.push_back(originalDataOperand);
126       } else {
127         // Type not supported.
128         return builder.notifyMatchFailure(op, "unsupported type");
129       }
130     }
131 
132     builder.replaceOpWithNewOp<Op>(op, TypeRange(), convertedOperands,
133                                    op.getOperation()->getAttrs());
134 
135     return success();
136   }
137 };
138 } // namespace
139 
populateOpenACCToLLVMConversionPatterns(LLVMTypeConverter & converter,OwningRewritePatternList & patterns)140 void mlir::populateOpenACCToLLVMConversionPatterns(
141     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
142   patterns.add<LegalizeDataOpForLLVMTranslation<acc::DataOp>>(converter);
143   patterns.add<LegalizeDataOpForLLVMTranslation<acc::EnterDataOp>>(converter);
144   patterns.add<LegalizeDataOpForLLVMTranslation<acc::ExitDataOp>>(converter);
145   patterns.add<LegalizeDataOpForLLVMTranslation<acc::ParallelOp>>(converter);
146   patterns.add<LegalizeDataOpForLLVMTranslation<acc::UpdateOp>>(converter);
147 }
148 
149 namespace {
150 struct ConvertOpenACCToLLVMPass
151     : public ConvertOpenACCToLLVMBase<ConvertOpenACCToLLVMPass> {
152   void runOnOperation() override;
153 };
154 } // namespace
155 
runOnOperation()156 void ConvertOpenACCToLLVMPass::runOnOperation() {
157   auto op = getOperation();
158   auto *context = op.getContext();
159 
160   // Convert to OpenACC operations with LLVM IR dialect
161   RewritePatternSet patterns(context);
162   LLVMTypeConverter converter(context);
163   populateOpenACCToLLVMConversionPatterns(converter, patterns);
164 
165   ConversionTarget target(*context);
166   target.addLegalDialect<LLVM::LLVMDialect>();
167   target.addLegalOp<UnrealizedConversionCastOp>();
168 
169   auto allDataOperandsAreConverted = [](ValueRange operands) {
170     for (Value operand : operands) {
171       if (!DataDescriptor::isValid(operand) &&
172           !operand.getType().isa<LLVM::LLVMPointerType>())
173         return false;
174     }
175     return true;
176   };
177 
178   target.addDynamicallyLegalOp<acc::DataOp>(
179       [allDataOperandsAreConverted](acc::DataOp op) {
180         return allDataOperandsAreConverted(op.copyOperands()) &&
181                allDataOperandsAreConverted(op.copyinOperands()) &&
182                allDataOperandsAreConverted(op.copyinReadonlyOperands()) &&
183                allDataOperandsAreConverted(op.copyoutOperands()) &&
184                allDataOperandsAreConverted(op.copyoutZeroOperands()) &&
185                allDataOperandsAreConverted(op.createOperands()) &&
186                allDataOperandsAreConverted(op.createZeroOperands()) &&
187                allDataOperandsAreConverted(op.noCreateOperands()) &&
188                allDataOperandsAreConverted(op.presentOperands()) &&
189                allDataOperandsAreConverted(op.deviceptrOperands()) &&
190                allDataOperandsAreConverted(op.attachOperands());
191       });
192 
193   target.addDynamicallyLegalOp<acc::EnterDataOp>(
194       [allDataOperandsAreConverted](acc::EnterDataOp op) {
195         return allDataOperandsAreConverted(op.copyinOperands()) &&
196                allDataOperandsAreConverted(op.createOperands()) &&
197                allDataOperandsAreConverted(op.createZeroOperands()) &&
198                allDataOperandsAreConverted(op.attachOperands());
199       });
200 
201   target.addDynamicallyLegalOp<acc::ExitDataOp>(
202       [allDataOperandsAreConverted](acc::ExitDataOp op) {
203         return allDataOperandsAreConverted(op.copyoutOperands()) &&
204                allDataOperandsAreConverted(op.deleteOperands()) &&
205                allDataOperandsAreConverted(op.detachOperands());
206       });
207 
208   target.addDynamicallyLegalOp<acc::ParallelOp>(
209       [allDataOperandsAreConverted](acc::ParallelOp op) {
210         return allDataOperandsAreConverted(op.reductionOperands()) &&
211                allDataOperandsAreConverted(op.copyOperands()) &&
212                allDataOperandsAreConverted(op.copyinOperands()) &&
213                allDataOperandsAreConverted(op.copyinReadonlyOperands()) &&
214                allDataOperandsAreConverted(op.copyoutOperands()) &&
215                allDataOperandsAreConverted(op.copyoutZeroOperands()) &&
216                allDataOperandsAreConverted(op.createOperands()) &&
217                allDataOperandsAreConverted(op.createZeroOperands()) &&
218                allDataOperandsAreConverted(op.noCreateOperands()) &&
219                allDataOperandsAreConverted(op.presentOperands()) &&
220                allDataOperandsAreConverted(op.devicePtrOperands()) &&
221                allDataOperandsAreConverted(op.attachOperands()) &&
222                allDataOperandsAreConverted(op.gangPrivateOperands()) &&
223                allDataOperandsAreConverted(op.gangFirstPrivateOperands());
224       });
225 
226   target.addDynamicallyLegalOp<acc::UpdateOp>(
227       [allDataOperandsAreConverted](acc::UpdateOp op) {
228         return allDataOperandsAreConverted(op.hostOperands()) &&
229                allDataOperandsAreConverted(op.deviceOperands());
230       });
231 
232   if (failed(applyPartialConversion(op, target, std::move(patterns))))
233     signalPassFailure();
234 }
235 
236 std::unique_ptr<OperationPass<ModuleOp>>
createConvertOpenACCToLLVMPass()237 mlir::createConvertOpenACCToLLVMPass() {
238   return std::make_unique<ConvertOpenACCToLLVMPass>();
239 }
240