1 //===- OpenACCToLLVMIRTranslation.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR OpenACC dialect and LLVM
10 // IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
15 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/OpenACC/OpenACC.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/Operation.h"
20 #include "mlir/Support/LLVM.h"
21 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
22 
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Frontend/OpenMP/OMPConstants.h"
25 #include "llvm/Support/FormatVariadic.h"
26 
27 using namespace mlir;
28 
29 using OpenACCIRBuilder = llvm::OpenMPIRBuilder;
30 
31 //===----------------------------------------------------------------------===//
32 // Utility functions
33 //===----------------------------------------------------------------------===//
34 
/// Map-type flags handed to the offloading runtime. The values are extracted
/// from openmp/libomptarget/include/omptarget.h; the comment on each constant
/// names the OpenACC clauses that are lowered with it in this file.

/// `create` operands: an `alloc` mapping (no transfer bit set).
static constexpr uint64_t kCreateFlag = 0x000;
/// `copyin` operands and `update device` operands: a `to` mapping.
static constexpr uint64_t kDeviceCopyinFlag = 0x001;
/// `copyout` operands and `update host` operands: a `from` mapping.
static constexpr uint64_t kHostCopyoutFlag = 0x002;
/// `copy` operands: both the `to` and the `from` bits.
static constexpr uint64_t kCopyFlag = kDeviceCopyinFlag | kHostCopyoutFlag;
/// `delete` operands of `acc.exit_data`.
static constexpr uint64_t kDeleteFlag = 0x008;
/// `present` operands of `acc.data`.
static constexpr uint64_t kPresentFlag = 0x1000;
/// Runtime extension to implement the OpenACC second reference counter; it is
/// OR-ed into every flag used for the structured `acc.data` region.
static constexpr uint64_t kHoldFlag = 0x2000;

/// Device id passed to every mapper call emitted by this translation.
static constexpr int64_t kDefaultDevice = -1;
48 
49 /// Create a constant string location from the MLIR Location information.
createSourceLocStrFromLocation(Location loc,OpenACCIRBuilder & builder,StringRef name)50 static llvm::Constant *createSourceLocStrFromLocation(Location loc,
51                                                       OpenACCIRBuilder &builder,
52                                                       StringRef name) {
53   if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) {
54     StringRef fileName = fileLoc.getFilename();
55     unsigned lineNo = fileLoc.getLine();
56     unsigned colNo = fileLoc.getColumn();
57     return builder.getOrCreateSrcLocStr(name, fileName, lineNo, colNo);
58   } else {
59     std::string locStr;
60     llvm::raw_string_ostream locOS(locStr);
61     locOS << loc;
62     return builder.getOrCreateSrcLocStr(locOS.str());
63   }
64 }
65 
66 /// Create the location struct from the operation location information.
createSourceLocationInfo(OpenACCIRBuilder & builder,Operation * op)67 static llvm::Value *createSourceLocationInfo(OpenACCIRBuilder &builder,
68                                              Operation *op) {
69   auto loc = op->getLoc();
70   auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
71   StringRef funcName = funcOp ? funcOp.getName() : "unknown";
72   llvm::Constant *locStr =
73       createSourceLocStrFromLocation(loc, builder, funcName);
74   return builder.getOrCreateIdent(locStr);
75 }
76 
77 /// Create a constant string representing the mapping information extracted from
78 /// the MLIR location information.
createMappingInformation(Location loc,OpenACCIRBuilder & builder)79 static llvm::Constant *createMappingInformation(Location loc,
80                                                 OpenACCIRBuilder &builder) {
81   if (auto nameLoc = loc.dyn_cast<NameLoc>()) {
82     StringRef name = nameLoc.getName();
83     return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name);
84   } else {
85     return createSourceLocStrFromLocation(loc, builder, "unknown");
86   }
87 }
88 
89 /// Return the runtime function used to lower the given operation.
getAssociatedFunction(OpenACCIRBuilder & builder,Operation * op)90 static llvm::Function *getAssociatedFunction(OpenACCIRBuilder &builder,
91                                              Operation *op) {
92   return llvm::TypeSwitch<Operation *, llvm::Function *>(op)
93       .Case([&](acc::EnterDataOp) {
94         return builder.getOrCreateRuntimeFunctionPtr(
95             llvm::omp::OMPRTL___tgt_target_data_begin_mapper);
96       })
97       .Case([&](acc::ExitDataOp) {
98         return builder.getOrCreateRuntimeFunctionPtr(
99             llvm::omp::OMPRTL___tgt_target_data_end_mapper);
100       })
101       .Case([&](acc::UpdateOp) {
102         return builder.getOrCreateRuntimeFunctionPtr(
103             llvm::omp::OMPRTL___tgt_target_data_update_mapper);
104       });
105   llvm_unreachable("Unknown OpenACC operation");
106 }
107 
108 /// Computes the size of type in bytes.
getSizeInBytes(llvm::IRBuilderBase & builder,llvm::Value * basePtr)109 static llvm::Value *getSizeInBytes(llvm::IRBuilderBase &builder,
110                                    llvm::Value *basePtr) {
111   llvm::LLVMContext &ctx = builder.getContext();
112   llvm::Value *null =
113       llvm::Constant::getNullValue(basePtr->getType()->getPointerTo());
114   llvm::Value *sizeGep =
115       builder.CreateGEP(basePtr->getType(), null, builder.getInt32(1));
116   llvm::Value *sizePtrToInt =
117       builder.CreatePtrToInt(sizeGep, llvm::Type::getInt64Ty(ctx));
118   return sizePtrToInt;
119 }
120 
121 /// Extract pointer, size and mapping information from operands
122 /// to populate the future functions arguments.
123 static LogicalResult
processOperands(llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation,Operation * op,ValueRange operands,unsigned totalNbOperand,uint64_t operandFlag,SmallVector<uint64_t> & flags,SmallVectorImpl<llvm::Constant * > & names,unsigned & index,struct OpenACCIRBuilder::MapperAllocas & mapperAllocas)124 processOperands(llvm::IRBuilderBase &builder,
125                 LLVM::ModuleTranslation &moduleTranslation, Operation *op,
126                 ValueRange operands, unsigned totalNbOperand,
127                 uint64_t operandFlag, SmallVector<uint64_t> &flags,
128                 SmallVectorImpl<llvm::Constant *> &names, unsigned &index,
129                 struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
130   OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
131   llvm::LLVMContext &ctx = builder.getContext();
132   auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx);
133   auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand);
134   auto *i64Ty = llvm::Type::getInt64Ty(ctx);
135   auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand);
136 
137   for (Value data : operands) {
138     llvm::Value *dataValue = moduleTranslation.lookupValue(data);
139 
140     llvm::Value *dataPtrBase;
141     llvm::Value *dataPtr;
142     llvm::Value *dataSize;
143 
144     // Handle operands that were converted to DataDescriptor.
145     if (DataDescriptor::isValid(data)) {
146       dataPtrBase =
147           builder.CreateExtractValue(dataValue, kPtrBasePosInDataDescriptor);
148       dataPtr = builder.CreateExtractValue(dataValue, kPtrPosInDataDescriptor);
149       dataSize =
150           builder.CreateExtractValue(dataValue, kSizePosInDataDescriptor);
151     } else if (data.getType().isa<LLVM::LLVMPointerType>()) {
152       dataPtrBase = dataValue;
153       dataPtr = dataValue;
154       dataSize = getSizeInBytes(builder, dataValue);
155     } else {
156       return op->emitOpError()
157              << "Data operand must be legalized before translation."
158              << "Unsupported type: " << data.getType();
159     }
160 
161     // Store base pointer extracted from operand into the i-th position of
162     // argBase.
163     llvm::Value *ptrBaseGEP = builder.CreateInBoundsGEP(
164         arrI8PtrTy, mapperAllocas.ArgsBase,
165         {builder.getInt32(0), builder.getInt32(index)});
166     llvm::Value *ptrBaseCast = builder.CreateBitCast(
167         ptrBaseGEP, dataPtrBase->getType()->getPointerTo());
168     builder.CreateStore(dataPtrBase, ptrBaseCast);
169 
170     // Store pointer extracted from operand into the i-th position of args.
171     llvm::Value *ptrGEP = builder.CreateInBoundsGEP(
172         arrI8PtrTy, mapperAllocas.Args,
173         {builder.getInt32(0), builder.getInt32(index)});
174     llvm::Value *ptrCast =
175         builder.CreateBitCast(ptrGEP, dataPtr->getType()->getPointerTo());
176     builder.CreateStore(dataPtr, ptrCast);
177 
178     // Store size extracted from operand into the i-th position of argSizes.
179     llvm::Value *sizeGEP = builder.CreateInBoundsGEP(
180         arrI64Ty, mapperAllocas.ArgSizes,
181         {builder.getInt32(0), builder.getInt32(index)});
182     builder.CreateStore(dataSize, sizeGEP);
183 
184     flags.push_back(operandFlag);
185     llvm::Constant *mapName =
186         createMappingInformation(data.getLoc(), *accBuilder);
187     names.push_back(mapName);
188     ++index;
189   }
190   return success();
191 }
192 
193 /// Process data operands from acc::EnterDataOp
194 static LogicalResult
processDataOperands(llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation,acc::EnterDataOp op,SmallVector<uint64_t> & flags,SmallVectorImpl<llvm::Constant * > & names,struct OpenACCIRBuilder::MapperAllocas & mapperAllocas)195 processDataOperands(llvm::IRBuilderBase &builder,
196                     LLVM::ModuleTranslation &moduleTranslation,
197                     acc::EnterDataOp op, SmallVector<uint64_t> &flags,
198                     SmallVectorImpl<llvm::Constant *> &names,
199                     struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
200   // TODO add `create_zero` and `attach` operands
201 
202   unsigned index = 0;
203 
204   // Create operands are handled as `alloc` call.
205   if (failed(processOperands(builder, moduleTranslation, op,
206                              op.createOperands(), op.getNumDataOperands(),
207                              kCreateFlag, flags, names, index, mapperAllocas)))
208     return failure();
209 
210   // Copyin operands are handled as `to` call.
211   if (failed(processOperands(builder, moduleTranslation, op,
212                              op.copyinOperands(), op.getNumDataOperands(),
213                              kDeviceCopyinFlag, flags, names, index,
214                              mapperAllocas)))
215     return failure();
216 
217   return success();
218 }
219 
220 /// Process data operands from acc::ExitDataOp
221 static LogicalResult
processDataOperands(llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation,acc::ExitDataOp op,SmallVector<uint64_t> & flags,SmallVectorImpl<llvm::Constant * > & names,struct OpenACCIRBuilder::MapperAllocas & mapperAllocas)222 processDataOperands(llvm::IRBuilderBase &builder,
223                     LLVM::ModuleTranslation &moduleTranslation,
224                     acc::ExitDataOp op, SmallVector<uint64_t> &flags,
225                     SmallVectorImpl<llvm::Constant *> &names,
226                     struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
227   // TODO add `detach` operands
228 
229   unsigned index = 0;
230 
231   // Delete operands are handled as `delete` call.
232   if (failed(processOperands(builder, moduleTranslation, op,
233                              op.deleteOperands(), op.getNumDataOperands(),
234                              kDeleteFlag, flags, names, index, mapperAllocas)))
235     return failure();
236 
237   // Copyout operands are handled as `from` call.
238   if (failed(processOperands(builder, moduleTranslation, op,
239                              op.copyoutOperands(), op.getNumDataOperands(),
240                              kHostCopyoutFlag, flags, names, index,
241                              mapperAllocas)))
242     return failure();
243 
244   return success();
245 }
246 
247 /// Process data operands from acc::UpdateOp
248 static LogicalResult
processDataOperands(llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation,acc::UpdateOp op,SmallVector<uint64_t> & flags,SmallVectorImpl<llvm::Constant * > & names,struct OpenACCIRBuilder::MapperAllocas & mapperAllocas)249 processDataOperands(llvm::IRBuilderBase &builder,
250                     LLVM::ModuleTranslation &moduleTranslation,
251                     acc::UpdateOp op, SmallVector<uint64_t> &flags,
252                     SmallVectorImpl<llvm::Constant *> &names,
253                     struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
254   unsigned index = 0;
255 
256   // Host operands are handled as `from` call.
257   if (failed(processOperands(builder, moduleTranslation, op, op.hostOperands(),
258                              op.getNumDataOperands(), kHostCopyoutFlag, flags,
259                              names, index, mapperAllocas)))
260     return failure();
261 
262   // Device operands are handled as `to` call.
263   if (failed(processOperands(builder, moduleTranslation, op,
264                              op.deviceOperands(), op.getNumDataOperands(),
265                              kDeviceCopyinFlag, flags, names, index,
266                              mapperAllocas)))
267     return failure();
268 
269   return success();
270 }
271 
272 //===----------------------------------------------------------------------===//
273 // Conversion functions
274 //===----------------------------------------------------------------------===//
275 
276 /// Converts an OpenACC data operation into LLVM IR.
convertDataOp(acc::DataOp & op,llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation)277 static LogicalResult convertDataOp(acc::DataOp &op,
278                                    llvm::IRBuilderBase &builder,
279                                    LLVM::ModuleTranslation &moduleTranslation) {
280   llvm::LLVMContext &ctx = builder.getContext();
281   auto enclosingFuncOp = op.getOperation()->getParentOfType<LLVM::LLVMFuncOp>();
282   llvm::Function *enclosingFunction =
283       moduleTranslation.lookupFunction(enclosingFuncOp.getName());
284 
285   OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
286 
287   llvm::Value *srcLocInfo = createSourceLocationInfo(*accBuilder, op);
288 
289   llvm::Function *beginMapperFunc = accBuilder->getOrCreateRuntimeFunctionPtr(
290       llvm::omp::OMPRTL___tgt_target_data_begin_mapper);
291 
292   llvm::Function *endMapperFunc = accBuilder->getOrCreateRuntimeFunctionPtr(
293       llvm::omp::OMPRTL___tgt_target_data_end_mapper);
294 
295   // Number of arguments in the data operation.
296   unsigned totalNbOperand = op.getNumDataOperands();
297 
298   struct OpenACCIRBuilder::MapperAllocas mapperAllocas;
299   OpenACCIRBuilder::InsertPointTy allocaIP(
300       &enclosingFunction->getEntryBlock(),
301       enclosingFunction->getEntryBlock().getFirstInsertionPt());
302   accBuilder->createMapperAllocas(builder.saveIP(), allocaIP, totalNbOperand,
303                                   mapperAllocas);
304 
305   SmallVector<uint64_t> flags;
306   SmallVector<llvm::Constant *> names;
307   unsigned index = 0;
308 
309   // TODO handle no_create, deviceptr and attach operands.
310 
311   if (failed(processOperands(builder, moduleTranslation, op, op.copyOperands(),
312                              totalNbOperand, kCopyFlag | kHoldFlag, flags,
313                              names, index, mapperAllocas)))
314     return failure();
315 
316   if (failed(processOperands(
317           builder, moduleTranslation, op, op.copyinOperands(), totalNbOperand,
318           kDeviceCopyinFlag | kHoldFlag, flags, names, index, mapperAllocas)))
319     return failure();
320 
321   // TODO copyin readonly currenlty handled as copyin. Update when extension
322   // available.
323   if (failed(processOperands(builder, moduleTranslation, op,
324                              op.copyinReadonlyOperands(), totalNbOperand,
325                              kDeviceCopyinFlag | kHoldFlag, flags, names, index,
326                              mapperAllocas)))
327     return failure();
328 
329   if (failed(processOperands(
330           builder, moduleTranslation, op, op.copyoutOperands(), totalNbOperand,
331           kHostCopyoutFlag | kHoldFlag, flags, names, index, mapperAllocas)))
332     return failure();
333 
334   // TODO copyout zero currenlty handled as copyout. Update when extension
335   // available.
336   if (failed(processOperands(builder, moduleTranslation, op,
337                              op.copyoutZeroOperands(), totalNbOperand,
338                              kHostCopyoutFlag | kHoldFlag, flags, names, index,
339                              mapperAllocas)))
340     return failure();
341 
342   if (failed(processOperands(builder, moduleTranslation, op,
343                              op.createOperands(), totalNbOperand,
344                              kCreateFlag | kHoldFlag, flags, names, index,
345                              mapperAllocas)))
346     return failure();
347 
348   // TODO create zero currenlty handled as create. Update when extension
349   // available.
350   if (failed(processOperands(builder, moduleTranslation, op,
351                              op.createZeroOperands(), totalNbOperand,
352                              kCreateFlag | kHoldFlag, flags, names, index,
353                              mapperAllocas)))
354     return failure();
355 
356   if (failed(processOperands(builder, moduleTranslation, op,
357                              op.presentOperands(), totalNbOperand,
358                              kPresentFlag | kHoldFlag, flags, names, index,
359                              mapperAllocas)))
360     return failure();
361 
362   llvm::GlobalVariable *maptypes =
363       accBuilder->createOffloadMaptypes(flags, ".offload_maptypes");
364   llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32(
365       llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand),
366       maptypes, /*Idx0=*/0, /*Idx1=*/0);
367 
368   llvm::GlobalVariable *mapnames =
369       accBuilder->createOffloadMapnames(names, ".offload_mapnames");
370   llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32(
371       llvm::ArrayType::get(llvm::Type::getInt8PtrTy(ctx), totalNbOperand),
372       mapnames, /*Idx0=*/0, /*Idx1=*/0);
373 
374   // Create call to start the data region.
375   accBuilder->emitMapperCall(builder.saveIP(), beginMapperFunc, srcLocInfo,
376                              maptypesArg, mapnamesArg, mapperAllocas,
377                              kDefaultDevice, totalNbOperand);
378 
379   // Convert the region.
380   llvm::BasicBlock *entryBlock = nullptr;
381 
382   for (Block &bb : op.region()) {
383     llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
384         ctx, "acc.data", builder.GetInsertBlock()->getParent());
385     if (entryBlock == nullptr)
386       entryBlock = llvmBB;
387     moduleTranslation.mapBlock(&bb, llvmBB);
388   }
389 
390   auto afterDataRegion = builder.saveIP();
391 
392   llvm::BranchInst *sourceTerminator = builder.CreateBr(entryBlock);
393 
394   builder.restoreIP(afterDataRegion);
395   llvm::BasicBlock *endDataBlock = llvm::BasicBlock::Create(
396       ctx, "acc.end_data", builder.GetInsertBlock()->getParent());
397 
398   SetVector<Block *> blocks =
399       LLVM::detail::getTopologicallySortedBlocks(op.region());
400   for (Block *bb : blocks) {
401     llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
402     if (bb->isEntryBlock()) {
403       assert(sourceTerminator->getNumSuccessors() == 1 &&
404              "provided entry block has multiple successors");
405       sourceTerminator->setSuccessor(0, llvmBB);
406     }
407 
408     if (failed(
409             moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) {
410       return failure();
411     }
412 
413     if (isa<acc::TerminatorOp, acc::YieldOp>(bb->getTerminator()))
414       builder.CreateBr(endDataBlock);
415   }
416 
417   // Create call to end the data region.
418   builder.SetInsertPoint(endDataBlock);
419   accBuilder->emitMapperCall(builder.saveIP(), endMapperFunc, srcLocInfo,
420                              maptypesArg, mapnamesArg, mapperAllocas,
421                              kDefaultDevice, totalNbOperand);
422 
423   return success();
424 }
425 
426 /// Converts an OpenACC standalone data operation into LLVM IR.
427 template <typename OpTy>
428 static LogicalResult
convertStandaloneDataOp(OpTy & op,llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation)429 convertStandaloneDataOp(OpTy &op, llvm::IRBuilderBase &builder,
430                         LLVM::ModuleTranslation &moduleTranslation) {
431   auto enclosingFuncOp =
432       op.getOperation()->template getParentOfType<LLVM::LLVMFuncOp>();
433   llvm::Function *enclosingFunction =
434       moduleTranslation.lookupFunction(enclosingFuncOp.getName());
435 
436   OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
437 
438   auto *srcLocInfo = createSourceLocationInfo(*accBuilder, op);
439   auto *mapperFunc = getAssociatedFunction(*accBuilder, op);
440 
441   // Number of arguments in the enter_data operation.
442   unsigned totalNbOperand = op.getNumDataOperands();
443 
444   llvm::LLVMContext &ctx = builder.getContext();
445 
446   struct OpenACCIRBuilder::MapperAllocas mapperAllocas;
447   OpenACCIRBuilder::InsertPointTy allocaIP(
448       &enclosingFunction->getEntryBlock(),
449       enclosingFunction->getEntryBlock().getFirstInsertionPt());
450   accBuilder->createMapperAllocas(builder.saveIP(), allocaIP, totalNbOperand,
451                                   mapperAllocas);
452 
453   SmallVector<uint64_t> flags;
454   SmallVector<llvm::Constant *> names;
455 
456   if (failed(processDataOperands(builder, moduleTranslation, op, flags, names,
457                                  mapperAllocas)))
458     return failure();
459 
460   llvm::GlobalVariable *maptypes =
461       accBuilder->createOffloadMaptypes(flags, ".offload_maptypes");
462   llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32(
463       llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand),
464       maptypes, /*Idx0=*/0, /*Idx1=*/0);
465 
466   llvm::GlobalVariable *mapnames =
467       accBuilder->createOffloadMapnames(names, ".offload_mapnames");
468   llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32(
469       llvm::ArrayType::get(llvm::Type::getInt8PtrTy(ctx), totalNbOperand),
470       mapnames, /*Idx0=*/0, /*Idx1=*/0);
471 
472   accBuilder->emitMapperCall(builder.saveIP(), mapperFunc, srcLocInfo,
473                              maptypesArg, mapnamesArg, mapperAllocas,
474                              kDefaultDevice, totalNbOperand);
475 
476   return success();
477 }
478 
479 namespace {
480 
481 /// Implementation of the dialect interface that converts operations belonging
482 /// to the OpenACC dialect to LLVM IR.
483 class OpenACCDialectLLVMIRTranslationInterface
484     : public LLVMTranslationDialectInterface {
485 public:
486   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
487 
488   /// Translates the given operation to LLVM IR using the provided IR builder
489   /// and saving the state in `moduleTranslation`.
490   LogicalResult
491   convertOperation(Operation *op, llvm::IRBuilderBase &builder,
492                    LLVM::ModuleTranslation &moduleTranslation) const final;
493 };
494 
495 } // end namespace
496 
497 /// Given an OpenACC MLIR operation, create the corresponding LLVM IR
498 /// (including OpenACC runtime calls).
convertOperation(Operation * op,llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation) const499 LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation(
500     Operation *op, llvm::IRBuilderBase &builder,
501     LLVM::ModuleTranslation &moduleTranslation) const {
502 
503   return llvm::TypeSwitch<Operation *, LogicalResult>(op)
504       .Case([&](acc::DataOp dataOp) {
505         return convertDataOp(dataOp, builder, moduleTranslation);
506       })
507       .Case([&](acc::EnterDataOp enterDataOp) {
508         return convertStandaloneDataOp<acc::EnterDataOp>(enterDataOp, builder,
509                                                          moduleTranslation);
510       })
511       .Case([&](acc::ExitDataOp exitDataOp) {
512         return convertStandaloneDataOp<acc::ExitDataOp>(exitDataOp, builder,
513                                                         moduleTranslation);
514       })
515       .Case([&](acc::UpdateOp updateOp) {
516         return convertStandaloneDataOp<acc::UpdateOp>(updateOp, builder,
517                                                       moduleTranslation);
518       })
519       .Case<acc::TerminatorOp, acc::YieldOp>([](auto op) {
520         // `yield` and `terminator` can be just omitted. The block structure was
521         // created in the function that handles their parent operation.
522         assert(op->getNumOperands() == 0 &&
523                "unexpected OpenACC terminator with operands");
524         return success();
525       })
526       .Default([&](Operation *op) {
527         return op->emitError("unsupported OpenACC operation: ")
528                << op->getName();
529       });
530 }
531 
registerOpenACCDialectTranslation(DialectRegistry & registry)532 void mlir::registerOpenACCDialectTranslation(DialectRegistry &registry) {
533   registry.insert<acc::OpenACCDialect>();
534   registry.addDialectInterface<acc::OpenACCDialect,
535                                OpenACCDialectLLVMIRTranslationInterface>();
536 }
537 
registerOpenACCDialectTranslation(MLIRContext & context)538 void mlir::registerOpenACCDialectTranslation(MLIRContext &context) {
539   DialectRegistry registry;
540   registerOpenACCDialectTranslation(registry);
541   context.appendDialectRegistry(registry);
542 }
543