1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8 
9 #include "mlir/Dialect/OpenACC/OpenACC.h"
10 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/OpImplementation.h"
13 #include "mlir/IR/StandardTypes.h"
14 
15 using namespace mlir;
16 using namespace acc;
17 
18 //===----------------------------------------------------------------------===//
19 // OpenACC operations
20 //===----------------------------------------------------------------------===//
21 
/// Dialect initialization hook: adds to the dialect the operations listed by
/// the generated OpenACCOps.cpp.inc when it is included with GET_OP_LIST
/// defined.
void OpenACCDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
      >();
}
28 
29 template <typename StructureOp>
parseRegions(OpAsmParser & parser,OperationState & state,unsigned nRegions=1)30 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
31                                 unsigned nRegions = 1) {
32 
33   SmallVector<Region *, 2> regions;
34   for (unsigned i = 0; i < nRegions; ++i)
35     regions.push_back(state.addRegion());
36 
37   for (Region *region : regions) {
38     if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
39       return failure();
40   }
41 
42   return success();
43 }
44 
45 static ParseResult
parseOperandList(OpAsmParser & parser,StringRef keyword,SmallVectorImpl<OpAsmParser::OperandType> & args,SmallVectorImpl<Type> & argTypes,OperationState & result)46 parseOperandList(OpAsmParser &parser, StringRef keyword,
47                  SmallVectorImpl<OpAsmParser::OperandType> &args,
48                  SmallVectorImpl<Type> &argTypes, OperationState &result) {
49   if (failed(parser.parseOptionalKeyword(keyword)))
50     return success();
51 
52   if (failed(parser.parseLParen()))
53     return failure();
54 
55   // Exit early if the list is empty.
56   if (succeeded(parser.parseOptionalRParen()))
57     return success();
58 
59   do {
60     OpAsmParser::OperandType arg;
61     Type type;
62 
63     if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
64       return failure();
65 
66     args.push_back(arg);
67     argTypes.push_back(type);
68   } while (succeeded(parser.parseOptionalComma()));
69 
70   if (failed(parser.parseRParen()))
71     return failure();
72 
73   return parser.resolveOperands(args, argTypes, parser.getCurrentLocation(),
74                                 result.operands);
75 }
76 
printOperandList(Operation::operand_range operands,StringRef listName,OpAsmPrinter & printer)77 static void printOperandList(Operation::operand_range operands,
78                              StringRef listName, OpAsmPrinter &printer) {
79 
80   if (operands.size() > 0) {
81     printer << " " << listName << "(";
82     llvm::interleaveComma(operands, printer, [&](Value op) {
83       printer << op << ": " << op.getType();
84     });
85     printer << ")";
86   }
87 }
88 
parseOptionalOperand(OpAsmParser & parser,StringRef keyword,OpAsmParser::OperandType & operand,Type type,bool & hasOptional,OperationState & result)89 static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword,
90                                         OpAsmParser::OperandType &operand,
91                                         Type type, bool &hasOptional,
92                                         OperationState &result) {
93   hasOptional = false;
94   if (succeeded(parser.parseOptionalKeyword(keyword))) {
95     hasOptional = true;
96     if (parser.parseLParen() || parser.parseOperand(operand) ||
97         parser.resolveOperand(operand, type, result.operands) ||
98         parser.parseRParen())
99       return failure();
100   }
101   return success();
102 }
103 
parseOperandAndType(OpAsmParser & parser,OperationState & result)104 static ParseResult parseOperandAndType(OpAsmParser &parser,
105                                        OperationState &result) {
106   OpAsmParser::OperandType operand;
107   Type type;
108   if (parser.parseOperand(operand) || parser.parseColonType(type) ||
109       parser.resolveOperand(operand, type, result.operands))
110     return failure();
111   return success();
112 }
113 
114 /// Parse optional operand and its type wrapped in parenthesis prefixed with
115 /// a keyword.
116 /// Example:
117 ///   keyword `(` %vectorLength: i64 `)`
parseOptionalOperandAndType(OpAsmParser & parser,StringRef keyword,OperationState & result)118 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser,
119                                                        StringRef keyword,
120                                                        OperationState &result) {
121   OpAsmParser::OperandType operand;
122   if (succeeded(parser.parseOptionalKeyword(keyword))) {
123     return failure(parser.parseLParen() ||
124                    parseOperandAndType(parser, result) || parser.parseRParen());
125   }
126   return llvm::None;
127 }
128 
129 /// Parse optional operand and its type wrapped in parenthesis.
130 /// Example:
131 ///   `(` %vectorLength: i64 `)`
parseOptionalOperandAndType(OpAsmParser & parser,OperationState & result)132 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser,
133                                                        OperationState &result) {
134   if (succeeded(parser.parseOptionalLParen())) {
135     return failure(parseOperandAndType(parser, result) || parser.parseRParen());
136   }
137   return llvm::None;
138 }
139 
140 /// Parse optional operand with its type prefixed with prefixKeyword `=`.
141 /// Example:
142 ///   num=%gangNum: i32
parserOptionalOperandAndTypeWithPrefix(OpAsmParser & parser,OperationState & result,StringRef prefixKeyword)143 static OptionalParseResult parserOptionalOperandAndTypeWithPrefix(
144     OpAsmParser &parser, OperationState &result, StringRef prefixKeyword) {
145   if (succeeded(parser.parseOptionalKeyword(prefixKeyword))) {
146     parser.parseEqual();
147     return parseOperandAndType(parser, result);
148   }
149   return llvm::None;
150 }
151 
isComputeOperation(Operation * op)152 static bool isComputeOperation(Operation *op) {
153   return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op);
154 }
155 
156 //===----------------------------------------------------------------------===//
157 // ParallelOp
158 //===----------------------------------------------------------------------===//
159 
160 /// Parse acc.parallel operation
161 /// operation := `acc.parallel` `async` `(` index `)`?
162 ///                             `wait` `(` index-list `)`?
163 ///                             `num_gangs` `(` value `)`?
164 ///                             `num_workers` `(` value `)`?
165 ///                             `vector_length` `(` value `)`?
166 ///                             `if` `(` value `)`?
167 ///                             `self` `(` value `)`?
168 ///                             `reduction` `(` value-list `)`?
169 ///                             `copy` `(` value-list `)`?
170 ///                             `copyin` `(` value-list `)`?
171 ///                             `copyin_readonly` `(` value-list `)`?
172 ///                             `copyout` `(` value-list `)`?
173 ///                             `copyout_zero` `(` value-list `)`?
174 ///                             `create` `(` value-list `)`?
175 ///                             `create_zero` `(` value-list `)`?
176 ///                             `no_create` `(` value-list `)`?
177 ///                             `present` `(` value-list `)`?
178 ///                             `deviceptr` `(` value-list `)`?
179 ///                             `attach` `(` value-list `)`?
180 ///                             `private` `(` value-list `)`?
181 ///                             `firstprivate` `(` value-list `)`?
182 ///                             region attr-dict?
parseParallelOp(OpAsmParser & parser,OperationState & result)183 static ParseResult parseParallelOp(OpAsmParser &parser,
184                                    OperationState &result) {
185   Builder &builder = parser.getBuilder();
186   SmallVector<OpAsmParser::OperandType, 8> privateOperands,
187       firstprivateOperands, copyOperands, copyinOperands,
188       copyinReadonlyOperands, copyoutOperands, copyoutZeroOperands,
189       createOperands, createZeroOperands, noCreateOperands, presentOperands,
190       devicePtrOperands, attachOperands, waitOperands, reductionOperands;
191   SmallVector<Type, 8> waitOperandTypes, reductionOperandTypes,
192       copyOperandTypes, copyinOperandTypes, copyinReadonlyOperandTypes,
193       copyoutOperandTypes, copyoutZeroOperandTypes, createOperandTypes,
194       createZeroOperandTypes, noCreateOperandTypes, presentOperandTypes,
195       deviceptrOperandTypes, attachOperandTypes, privateOperandTypes,
196       firstprivateOperandTypes;
197 
198   SmallVector<Type, 8> operandTypes;
199   OpAsmParser::OperandType ifCond, selfCond;
200   bool hasIfCond = false, hasSelfCond = false;
201   OptionalParseResult async, numGangs, numWorkers, vectorLength;
202   Type i1Type = builder.getI1Type();
203 
204   // async()?
205   async = parseOptionalOperandAndType(parser, ParallelOp::getAsyncKeyword(),
206                                       result);
207   if (async.hasValue() && failed(*async))
208     return failure();
209 
210   // wait()?
211   if (failed(parseOperandList(parser, ParallelOp::getWaitKeyword(),
212                               waitOperands, waitOperandTypes, result)))
213     return failure();
214 
215   // num_gangs(value)?
216   numGangs = parseOptionalOperandAndType(
217       parser, ParallelOp::getNumGangsKeyword(), result);
218   if (numGangs.hasValue() && failed(*numGangs))
219     return failure();
220 
221   // num_workers(value)?
222   numWorkers = parseOptionalOperandAndType(
223       parser, ParallelOp::getNumWorkersKeyword(), result);
224   if (numWorkers.hasValue() && failed(*numWorkers))
225     return failure();
226 
227   // vector_length(value)?
228   vectorLength = parseOptionalOperandAndType(
229       parser, ParallelOp::getVectorLengthKeyword(), result);
230   if (vectorLength.hasValue() && failed(*vectorLength))
231     return failure();
232 
233   // if()?
234   if (failed(parseOptionalOperand(parser, ParallelOp::getIfKeyword(), ifCond,
235                                   i1Type, hasIfCond, result)))
236     return failure();
237 
238   // self()?
239   if (failed(parseOptionalOperand(parser, ParallelOp::getSelfKeyword(),
240                                   selfCond, i1Type, hasSelfCond, result)))
241     return failure();
242 
243   // reduction()?
244   if (failed(parseOperandList(parser, ParallelOp::getReductionKeyword(),
245                               reductionOperands, reductionOperandTypes,
246                               result)))
247     return failure();
248 
249   // copy()?
250   if (failed(parseOperandList(parser, ParallelOp::getCopyKeyword(),
251                               copyOperands, copyOperandTypes, result)))
252     return failure();
253 
254   // copyin()?
255   if (failed(parseOperandList(parser, ParallelOp::getCopyinKeyword(),
256                               copyinOperands, copyinOperandTypes, result)))
257     return failure();
258 
259   // copyin_readonly()?
260   if (failed(parseOperandList(parser, ParallelOp::getCopyinReadonlyKeyword(),
261                               copyinReadonlyOperands,
262                               copyinReadonlyOperandTypes, result)))
263     return failure();
264 
265   // copyout()?
266   if (failed(parseOperandList(parser, ParallelOp::getCopyoutKeyword(),
267                               copyoutOperands, copyoutOperandTypes, result)))
268     return failure();
269 
270   // copyout_zero()?
271   if (failed(parseOperandList(parser, ParallelOp::getCopyoutZeroKeyword(),
272                               copyoutZeroOperands, copyoutZeroOperandTypes,
273                               result)))
274     return failure();
275 
276   // create()?
277   if (failed(parseOperandList(parser, ParallelOp::getCreateKeyword(),
278                               createOperands, createOperandTypes, result)))
279     return failure();
280 
281   // create_zero()?
282   if (failed(parseOperandList(parser, ParallelOp::getCreateZeroKeyword(),
283                               createZeroOperands, createZeroOperandTypes,
284                               result)))
285     return failure();
286 
287   // no_create()?
288   if (failed(parseOperandList(parser, ParallelOp::getNoCreateKeyword(),
289                               noCreateOperands, noCreateOperandTypes, result)))
290     return failure();
291 
292   // present()?
293   if (failed(parseOperandList(parser, ParallelOp::getPresentKeyword(),
294                               presentOperands, presentOperandTypes, result)))
295     return failure();
296 
297   // deviceptr()?
298   if (failed(parseOperandList(parser, ParallelOp::getDevicePtrKeyword(),
299                               devicePtrOperands, deviceptrOperandTypes,
300                               result)))
301     return failure();
302 
303   // attach()?
304   if (failed(parseOperandList(parser, ParallelOp::getAttachKeyword(),
305                               attachOperands, attachOperandTypes, result)))
306     return failure();
307 
308   // private()?
309   if (failed(parseOperandList(parser, ParallelOp::getPrivateKeyword(),
310                               privateOperands, privateOperandTypes, result)))
311     return failure();
312 
313   // firstprivate()?
314   if (failed(parseOperandList(parser, ParallelOp::getFirstPrivateKeyword(),
315                               firstprivateOperands, firstprivateOperandTypes,
316                               result)))
317     return failure();
318 
319   // Parallel op region
320   if (failed(parseRegions<ParallelOp>(parser, result)))
321     return failure();
322 
323   result.addAttribute(
324       ParallelOp::getOperandSegmentSizeAttr(),
325       builder.getI32VectorAttr(
326           {static_cast<int32_t>(async.hasValue() ? 1 : 0),
327            static_cast<int32_t>(waitOperands.size()),
328            static_cast<int32_t>(numGangs.hasValue() ? 1 : 0),
329            static_cast<int32_t>(numWorkers.hasValue() ? 1 : 0),
330            static_cast<int32_t>(vectorLength.hasValue() ? 1 : 0),
331            static_cast<int32_t>(hasIfCond ? 1 : 0),
332            static_cast<int32_t>(hasSelfCond ? 1 : 0),
333            static_cast<int32_t>(reductionOperands.size()),
334            static_cast<int32_t>(copyOperands.size()),
335            static_cast<int32_t>(copyinOperands.size()),
336            static_cast<int32_t>(copyinReadonlyOperands.size()),
337            static_cast<int32_t>(copyoutOperands.size()),
338            static_cast<int32_t>(copyoutZeroOperands.size()),
339            static_cast<int32_t>(createOperands.size()),
340            static_cast<int32_t>(createZeroOperands.size()),
341            static_cast<int32_t>(noCreateOperands.size()),
342            static_cast<int32_t>(presentOperands.size()),
343            static_cast<int32_t>(devicePtrOperands.size()),
344            static_cast<int32_t>(attachOperands.size()),
345            static_cast<int32_t>(privateOperands.size()),
346            static_cast<int32_t>(firstprivateOperands.size())}));
347 
348   // Additional attributes
349   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
350     return failure();
351 
352   return success();
353 }
354 
print(OpAsmPrinter & printer,ParallelOp & op)355 static void print(OpAsmPrinter &printer, ParallelOp &op) {
356   printer << ParallelOp::getOperationName();
357 
358   // async()?
359   if (Value async = op.async())
360     printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": "
361             << async.getType() << ")";
362 
363   // wait()?
364   printOperandList(op.waitOperands(), ParallelOp::getWaitKeyword(), printer);
365 
366   // num_gangs()?
367   if (Value numGangs = op.numGangs())
368     printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs
369             << ": " << numGangs.getType() << ")";
370 
371   // num_workers()?
372   if (Value numWorkers = op.numWorkers())
373     printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers
374             << ": " << numWorkers.getType() << ")";
375 
376   // vector_length()?
377   if (Value vectorLength = op.vectorLength())
378     printer << " " << ParallelOp::getVectorLengthKeyword() << "("
379             << vectorLength << ": " << vectorLength.getType() << ")";
380 
381   // if()?
382   if (Value ifCond = op.ifCond())
383     printer << " " << ParallelOp::getIfKeyword() << "(" << ifCond << ")";
384 
385   // self()?
386   if (Value selfCond = op.selfCond())
387     printer << " " << ParallelOp::getSelfKeyword() << "(" << selfCond << ")";
388 
389   // reduction()?
390   printOperandList(op.reductionOperands(), ParallelOp::getReductionKeyword(),
391                    printer);
392 
393   // copy()?
394   printOperandList(op.copyOperands(), ParallelOp::getCopyKeyword(), printer);
395 
396   // copyin()?
397   printOperandList(op.copyinOperands(), ParallelOp::getCopyinKeyword(),
398                    printer);
399 
400   // copyin_readonly()?
401   printOperandList(op.copyinReadonlyOperands(),
402                    ParallelOp::getCopyinReadonlyKeyword(), printer);
403 
404   // copyout()?
405   printOperandList(op.copyoutOperands(), ParallelOp::getCopyoutKeyword(),
406                    printer);
407 
408   // copyout_zero()?
409   printOperandList(op.copyoutZeroOperands(),
410                    ParallelOp::getCopyoutZeroKeyword(), printer);
411 
412   // create()?
413   printOperandList(op.createOperands(), ParallelOp::getCreateKeyword(),
414                    printer);
415 
416   // create_zero()?
417   printOperandList(op.createZeroOperands(), ParallelOp::getCreateZeroKeyword(),
418                    printer);
419 
420   // no_create()?
421   printOperandList(op.noCreateOperands(), ParallelOp::getNoCreateKeyword(),
422                    printer);
423 
424   // present()?
425   printOperandList(op.presentOperands(), ParallelOp::getPresentKeyword(),
426                    printer);
427 
428   // deviceptr()?
429   printOperandList(op.devicePtrOperands(), ParallelOp::getDevicePtrKeyword(),
430                    printer);
431 
432   // attach()?
433   printOperandList(op.attachOperands(), ParallelOp::getAttachKeyword(),
434                    printer);
435 
436   // private()?
437   printOperandList(op.gangPrivateOperands(), ParallelOp::getPrivateKeyword(),
438                    printer);
439 
440   // firstprivate()?
441   printOperandList(op.gangFirstPrivateOperands(),
442                    ParallelOp::getFirstPrivateKeyword(), printer);
443 
444   printer.printRegion(op.region(),
445                       /*printEntryBlockArgs=*/false,
446                       /*printBlockTerminators=*/true);
447   printer.printOptionalAttrDictWithKeyword(
448       op.getAttrs(), ParallelOp::getOperandSegmentSizeAttr());
449 }
450 
451 //===----------------------------------------------------------------------===//
452 // LoopOp
453 //===----------------------------------------------------------------------===//
454 
455 /// Parse acc.loop operation
456 /// operation := `acc.loop`
457 ///              (`gang` ( `(` (`num=` value)? (`,` `static=` value `)`)? )? )?
458 ///              (`vector` ( `(` value `)` )? )? (`worker` (`(` value `)`)? )?
459 ///              (`vector_length` `(` value `)`)?
460 ///              (`tile` `(` value-list `)`)?
461 ///              (`private` `(` value-list `)`)?
462 ///              (`reduction` `(` value-list `)`)?
463 ///              region attr-dict?
parseLoopOp(OpAsmParser & parser,OperationState & result)464 static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
465   Builder &builder = parser.getBuilder();
466   unsigned executionMapping = OpenACCExecMapping::NONE;
467   SmallVector<Type, 8> operandTypes;
468   SmallVector<OpAsmParser::OperandType, 8> privateOperands, reductionOperands;
469   SmallVector<OpAsmParser::OperandType, 8> tileOperands;
470   OptionalParseResult gangNum, gangStatic, worker, vector;
471 
472   // gang?
473   if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword())))
474     executionMapping |= OpenACCExecMapping::GANG;
475 
476   // optional gang operand
477   if (succeeded(parser.parseOptionalLParen())) {
478     gangNum = parserOptionalOperandAndTypeWithPrefix(
479         parser, result, LoopOp::getGangNumKeyword());
480     if (gangNum.hasValue() && failed(*gangNum))
481       return failure();
482     parser.parseOptionalComma();
483     gangStatic = parserOptionalOperandAndTypeWithPrefix(
484         parser, result, LoopOp::getGangStaticKeyword());
485     if (gangStatic.hasValue() && failed(*gangStatic))
486       return failure();
487     parser.parseOptionalComma();
488     if (failed(parser.parseRParen()))
489       return failure();
490   }
491 
492   // worker?
493   if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerKeyword())))
494     executionMapping |= OpenACCExecMapping::WORKER;
495 
496   // optional worker operand
497   worker = parseOptionalOperandAndType(parser, result);
498   if (worker.hasValue() && failed(*worker))
499     return failure();
500 
501   // vector?
502   if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorKeyword())))
503     executionMapping |= OpenACCExecMapping::VECTOR;
504 
505   // optional vector operand
506   vector = parseOptionalOperandAndType(parser, result);
507   if (vector.hasValue() && failed(*vector))
508     return failure();
509 
510   // tile()?
511   if (failed(parseOperandList(parser, LoopOp::getTileKeyword(), tileOperands,
512                               operandTypes, result)))
513     return failure();
514 
515   // private()?
516   if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(),
517                               privateOperands, operandTypes, result)))
518     return failure();
519 
520   // reduction()?
521   if (failed(parseOperandList(parser, LoopOp::getReductionKeyword(),
522                               reductionOperands, operandTypes, result)))
523     return failure();
524 
525   if (executionMapping != acc::OpenACCExecMapping::NONE)
526     result.addAttribute(LoopOp::getExecutionMappingAttrName(),
527                         builder.getI64IntegerAttr(executionMapping));
528 
529   // Parse optional results in case there is a reduce.
530   if (parser.parseOptionalArrowTypeList(result.types))
531     return failure();
532 
533   if (failed(parseRegions<LoopOp>(parser, result)))
534     return failure();
535 
536   result.addAttribute(LoopOp::getOperandSegmentSizeAttr(),
537                       builder.getI32VectorAttr(
538                           {static_cast<int32_t>(gangNum.hasValue() ? 1 : 0),
539                            static_cast<int32_t>(gangStatic.hasValue() ? 1 : 0),
540                            static_cast<int32_t>(worker.hasValue() ? 1 : 0),
541                            static_cast<int32_t>(vector.hasValue() ? 1 : 0),
542                            static_cast<int32_t>(tileOperands.size()),
543                            static_cast<int32_t>(privateOperands.size()),
544                            static_cast<int32_t>(reductionOperands.size())}));
545 
546   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
547     return failure();
548 
549   return success();
550 }
551 
print(OpAsmPrinter & printer,LoopOp & op)552 static void print(OpAsmPrinter &printer, LoopOp &op) {
553   printer << LoopOp::getOperationName();
554 
555   unsigned execMapping = op.exec_mapping();
556   if (execMapping & OpenACCExecMapping::GANG) {
557     printer << " " << LoopOp::getGangKeyword();
558     Value gangNum = op.gangNum();
559     Value gangStatic = op.gangStatic();
560 
561     // Print optional gang operands
562     if (gangNum || gangStatic) {
563       printer << "(";
564       if (gangNum) {
565         printer << LoopOp::getGangNumKeyword() << "=" << gangNum << ": "
566                 << gangNum.getType();
567         if (gangStatic)
568           printer << ", ";
569       }
570       if (gangStatic)
571         printer << LoopOp::getGangStaticKeyword() << "=" << gangStatic << ": "
572                 << gangStatic.getType();
573       printer << ")";
574     }
575   }
576 
577   if (execMapping & OpenACCExecMapping::WORKER) {
578     printer << " " << LoopOp::getWorkerKeyword();
579 
580     // Print optional worker operand if present
581     if (Value workerNum = op.workerNum())
582       printer << "(" << workerNum << ": " << workerNum.getType() << ")";
583   }
584 
585   if (execMapping & OpenACCExecMapping::VECTOR) {
586     printer << " " << LoopOp::getVectorKeyword();
587 
588     // Print optional vector operand if present
589     if (Value vectorLength = op.vectorLength())
590       printer << "(" << vectorLength << ": " << vectorLength.getType() << ")";
591   }
592 
593   // tile()?
594   printOperandList(op.tileOperands(), LoopOp::getTileKeyword(), printer);
595 
596   // private()?
597   printOperandList(op.privateOperands(), LoopOp::getPrivateKeyword(), printer);
598 
599   // reduction()?
600   printOperandList(op.reductionOperands(), LoopOp::getReductionKeyword(),
601                    printer);
602 
603   if (op.getNumResults() > 0)
604     printer << " -> (" << op.getResultTypes() << ")";
605 
606   printer.printRegion(op.region(),
607                       /*printEntryBlockArgs=*/false,
608                       /*printBlockTerminators=*/true);
609 
610   printer.printOptionalAttrDictWithKeyword(
611       op.getAttrs(), {LoopOp::getExecutionMappingAttrName(),
612                       LoopOp::getOperandSegmentSizeAttr()});
613 }
614 
verifyLoopOp(acc::LoopOp loopOp)615 static LogicalResult verifyLoopOp(acc::LoopOp loopOp) {
616   // auto, independent and seq attribute are mutually exclusive.
617   if ((loopOp.auto_() && (loopOp.independent() || loopOp.seq())) ||
618       (loopOp.independent() && loopOp.seq())) {
619     loopOp.emitError("only one of " + acc::LoopOp::getAutoAttrName() + ", " +
620                      acc::LoopOp::getIndependentAttrName() + ", " +
621                      acc::LoopOp::getSeqAttrName() +
622                      " can be present at the same time");
623     return failure();
624   }
625 
626   // Gang, worker and vector are incompatible with seq.
627   if (loopOp.seq() && loopOp.exec_mapping() != OpenACCExecMapping::NONE) {
628     loopOp.emitError("gang, worker or vector cannot appear with the seq attr");
629     return failure();
630   }
631 
632   // Check non-empty body().
633   if (loopOp.region().empty()) {
634     loopOp.emitError("expected non-empty body.");
635     return failure();
636   }
637 
638   return success();
639 }
640 
641 //===----------------------------------------------------------------------===//
642 // DataOp
643 //===----------------------------------------------------------------------===//
644 
verify(acc::DataOp dataOp)645 static LogicalResult verify(acc::DataOp dataOp) {
646   // 2.6.5. Data Construct restriction
647   // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
648   // attach, or default clause must appear on a data construct.
649   if (dataOp.getOperands().size() == 0 && !dataOp.defaultAttr())
650     return dataOp.emitError("at least one operand or the default attribute "
651                             "must appear on the data operation");
652   return success();
653 }
654 
655 //===----------------------------------------------------------------------===//
656 // ExitDataOp
657 //===----------------------------------------------------------------------===//
658 
verify(acc::ExitDataOp op)659 static LogicalResult verify(acc::ExitDataOp op) {
660   // 2.6.6. Data Exit Directive restriction
661   // At least one copyout, delete, or detach clause must appear on an exit data
662   // directive.
663   if (op.copyoutOperands().empty() && op.deleteOperands().empty() &&
664       op.detachOperands().empty())
665     return op.emitError(
666         "at least one operand in copyout, delete or detach must appear on the "
667         "exit data operation");
668 
669   // The async attribute represent the async clause without value. Therefore the
670   // attribute and operand cannot appear at the same time.
671   if (op.asyncOperand() && op.async())
672     return op.emitError("async attribute cannot appear with asyncOperand");
673 
674   // The wait attribute represent the wait clause without values. Therefore the
675   // attribute and operands cannot appear at the same time.
676   if (!op.waitOperands().empty() && op.wait())
677     return op.emitError("wait attribute cannot appear with waitOperands");
678 
679   if (op.waitDevnum() && op.waitOperands().empty())
680     return op.emitError("wait_devnum cannot appear without waitOperands");
681 
682   return success();
683 }
684 
685 //===----------------------------------------------------------------------===//
// EnterDataOp
687 //===----------------------------------------------------------------------===//
688 
verify(acc::EnterDataOp op)689 static LogicalResult verify(acc::EnterDataOp op) {
690   // 2.6.6. Data Enter Directive restriction
691   // At least one copyin, create, or attach clause must appear on an enter data
692   // directive.
693   if (op.copyinOperands().empty() && op.createOperands().empty() &&
694       op.createZeroOperands().empty() && op.attachOperands().empty())
695     return op.emitError(
696         "at least one operand in copyin, create, "
697         "create_zero or attach must appear on the enter data operation");
698 
699   // The async attribute represent the async clause without value. Therefore the
700   // attribute and operand cannot appear at the same time.
701   if (op.asyncOperand() && op.async())
702     return op.emitError("async attribute cannot appear with asyncOperand");
703 
704   // The wait attribute represent the wait clause without values. Therefore the
705   // attribute and operands cannot appear at the same time.
706   if (!op.waitOperands().empty() && op.wait())
707     return op.emitError("wait attribute cannot appear with waitOperands");
708 
709   if (op.waitDevnum() && op.waitOperands().empty())
710     return op.emitError("wait_devnum cannot appear without waitOperands");
711 
712   return success();
713 }
714 
715 //===----------------------------------------------------------------------===//
716 // InitOp
717 //===----------------------------------------------------------------------===//
718 
verify(acc::InitOp initOp)719 static LogicalResult verify(acc::InitOp initOp) {
720   Operation *currOp = initOp;
721   while ((currOp = currOp->getParentOp())) {
722     if (isComputeOperation(currOp))
723       return initOp.emitOpError("cannot be nested in a compute operation");
724   }
725   return success();
726 }
727 
728 //===----------------------------------------------------------------------===//
729 // ShutdownOp
730 //===----------------------------------------------------------------------===//
731 
verify(acc::ShutdownOp op)732 static LogicalResult verify(acc::ShutdownOp op) {
733   Operation *currOp = op;
734   while ((currOp = currOp->getParentOp())) {
735     if (isComputeOperation(currOp))
736       return op.emitOpError("cannot be nested in a compute operation");
737   }
738   return success();
739 }
740 
741 //===----------------------------------------------------------------------===//
742 // UpdateOp
743 //===----------------------------------------------------------------------===//
744 
verify(acc::UpdateOp updateOp)745 static LogicalResult verify(acc::UpdateOp updateOp) {
746   // At least one of host or device should have a value.
747   if (updateOp.hostOperands().size() == 0 &&
748       updateOp.deviceOperands().size() == 0)
749     return updateOp.emitError("at least one value must be present in"
750                               " hostOperands or deviceOperands");
751 
752   // The async attribute represent the async clause without value. Therefore the
753   // attribute and operand cannot appear at the same time.
754   if (updateOp.asyncOperand() && updateOp.async())
755     return updateOp.emitError("async attribute cannot appear with "
756                               " asyncOperand");
757 
758   // The wait attribute represent the wait clause without values. Therefore the
759   // attribute and operands cannot appear at the same time.
760   if (updateOp.waitOperands().size() > 0 && updateOp.wait())
761     return updateOp.emitError("wait attribute cannot appear with waitOperands");
762 
763   if (updateOp.waitDevnum() && updateOp.waitOperands().size() == 0)
764     return updateOp.emitError("wait_devnum cannot appear without waitOperands");
765 
766   return success();
767 }
768 
769 //===----------------------------------------------------------------------===//
770 // WaitOp
771 //===----------------------------------------------------------------------===//
772 
verify(acc::WaitOp waitOp)773 static LogicalResult verify(acc::WaitOp waitOp) {
774   // The async attribute represent the async clause without value. Therefore the
775   // attribute and operand cannot appear at the same time.
776   if (waitOp.asyncOperand() && waitOp.async())
777     return waitOp.emitError("async attribute cannot appear with asyncOperand");
778 
779   if (waitOp.waitDevnum() && waitOp.waitOperands().empty())
780     return waitOp.emitError("wait_devnum cannot appear without waitOperands");
781 
782   return success();
783 }
784 
785 #define GET_OP_CLASSES
786 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
787