1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/OperationSupport.h"
19 
20 #include "llvm/ADT/SmallString.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/ADT/StringSwitch.h"
24 #include <cstddef>
25 
26 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
27 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
28 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
29 
30 using namespace mlir;
31 using namespace mlir::omp;
32 
33 namespace {
34 /// Model for pointer-like types that already provide a `getElementType` method.
35 template <typename T>
36 struct PointerLikeModel
37     : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
getElementType__anon959c6e340111::PointerLikeModel38   Type getElementType(Type pointer) const {
39     return pointer.cast<T>().getElementType();
40   }
41 };
42 } // end namespace
43 
initialize()44 void OpenMPDialect::initialize() {
45   addOperations<
46 #define GET_OP_LIST
47 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
48       >();
49 
50   LLVM::LLVMPointerType::attachInterface<
51       PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
52   MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
53 }
54 
55 //===----------------------------------------------------------------------===//
56 // ParallelOp
57 //===----------------------------------------------------------------------===//
58 
build(OpBuilder & builder,OperationState & state,ArrayRef<NamedAttribute> attributes)59 void ParallelOp::build(OpBuilder &builder, OperationState &state,
60                        ArrayRef<NamedAttribute> attributes) {
61   ParallelOp::build(
62       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
63       /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
64       /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
65       /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
66       /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
67   state.addAttributes(attributes);
68 }
69 
70 /// Parse a list of operands with types.
71 ///
72 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
73 /// ssa-id-and-type-list ::= ssa-id-and-type |
74 ///                          ssa-id-and-type `,` ssa-id-and-type-list
75 /// ssa-id-and-type ::= ssa-id `:` type
76 static ParseResult
parseOperandAndTypeList(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::OperandType> & operands,SmallVectorImpl<Type> & types)77 parseOperandAndTypeList(OpAsmParser &parser,
78                         SmallVectorImpl<OpAsmParser::OperandType> &operands,
79                         SmallVectorImpl<Type> &types) {
80   return parser.parseCommaSeparatedList(
81       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
82         OpAsmParser::OperandType operand;
83         Type type;
84         if (parser.parseOperand(operand) || parser.parseColonType(type))
85           return failure();
86         operands.push_back(operand);
87         types.push_back(type);
88         return success();
89       });
90 }
91 
92 /// Parse an allocate clause with allocators and a list of operands with types.
93 ///
94 /// operand-and-type-list ::= `(` allocate-operand-list `)`
95 /// allocate-operand-list :: = allocate-operand |
96 ///                            allocator-operand `,` allocate-operand-list
97 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
98 /// ssa-id-and-type ::= ssa-id `:` type
parseAllocateAndAllocator(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::OperandType> & operandsAllocate,SmallVectorImpl<Type> & typesAllocate,SmallVectorImpl<OpAsmParser::OperandType> & operandsAllocator,SmallVectorImpl<Type> & typesAllocator)99 static ParseResult parseAllocateAndAllocator(
100     OpAsmParser &parser,
101     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
102     SmallVectorImpl<Type> &typesAllocate,
103     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
104     SmallVectorImpl<Type> &typesAllocator) {
105 
106   return parser.parseCommaSeparatedList(
107       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
108         OpAsmParser::OperandType operand;
109         Type type;
110         if (parser.parseOperand(operand) || parser.parseColonType(type))
111           return failure();
112         operandsAllocator.push_back(operand);
113         typesAllocator.push_back(type);
114         if (parser.parseArrow())
115           return failure();
116         if (parser.parseOperand(operand) || parser.parseColonType(type))
117           return failure();
118 
119         operandsAllocate.push_back(operand);
120         typesAllocate.push_back(type);
121         return success();
122       });
123 }
124 
verifyParallelOp(ParallelOp op)125 static LogicalResult verifyParallelOp(ParallelOp op) {
126   if (op.allocate_vars().size() != op.allocators_vars().size())
127     return op.emitError(
128         "expected equal sizes for allocate and allocator variables");
129   return success();
130 }
131 
printParallelOp(OpAsmPrinter & p,ParallelOp op)132 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
133   if (auto ifCond = op.if_expr_var())
134     p << " if(" << ifCond << " : " << ifCond.getType() << ")";
135 
136   if (auto threads = op.num_threads_var())
137     p << " num_threads(" << threads << " : " << threads.getType() << ")";
138 
139   // Print private, firstprivate, shared and copyin parameters
140   auto printDataVars = [&p](StringRef name, OperandRange vars) {
141     if (vars.size()) {
142       p << " " << name << "(";
143       for (unsigned i = 0; i < vars.size(); ++i) {
144         std::string separator = i == vars.size() - 1 ? ")" : ", ";
145         p << vars[i] << " : " << vars[i].getType() << separator;
146       }
147     }
148   };
149 
150   // Print allocator and allocate parameters
151   auto printAllocateAndAllocator = [&p](OperandRange varsAllocate,
152                                         OperandRange varsAllocator) {
153     if (varsAllocate.empty())
154       return;
155 
156     p << " allocate(";
157     for (unsigned i = 0; i < varsAllocate.size(); ++i) {
158       std::string separator = i == varsAllocate.size() - 1 ? ")" : ", ";
159       p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
160       p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
161     }
162   };
163 
164   printDataVars("private", op.private_vars());
165   printDataVars("firstprivate", op.firstprivate_vars());
166   printDataVars("shared", op.shared_vars());
167   printDataVars("copyin", op.copyin_vars());
168   printAllocateAndAllocator(op.allocate_vars(), op.allocators_vars());
169 
170   if (auto def = op.default_val())
171     p << " default(" << def->drop_front(3) << ")";
172 
173   if (auto bind = op.proc_bind_val())
174     p << " proc_bind(" << bind << ")";
175 
176   p.printRegion(op.getRegion());
177 }
178 
179 /// Emit an error if the same clause is present more than once on an operation.
allowedOnce(OpAsmParser & parser,StringRef clause,StringRef operation)180 static ParseResult allowedOnce(OpAsmParser &parser, StringRef clause,
181                                StringRef operation) {
182   return parser.emitError(parser.getNameLoc())
183          << " at most one " << clause << " clause can appear on the "
184          << operation << " operation";
185 }
186 
187 /// Parses a parallel operation.
188 ///
189 /// operation ::= `omp.parallel` clause-list
190 /// clause-list ::= clause | clause clause-list
191 /// clause ::= if | numThreads | private | firstprivate | shared | copyin |
192 ///            default | procBind
193 /// if ::= `if` `(` ssa-id `)`
194 /// numThreads ::= `num_threads` `(` ssa-id-and-type `)`
195 /// private ::= `private` operand-and-type-list
196 /// firstprivate ::= `firstprivate` operand-and-type-list
197 /// shared ::= `shared` operand-and-type-list
198 /// copyin ::= `copyin` operand-and-type-list
199 /// allocate ::= `allocate` operand-and-type `->` operand-and-type-list
200 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`)
201 /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
202 ///
203 /// Note that each clause can only appear once in the clase-list.
parseParallelOp(OpAsmParser & parser,OperationState & result)204 static ParseResult parseParallelOp(OpAsmParser &parser,
205                                    OperationState &result) {
206   std::pair<OpAsmParser::OperandType, Type> ifCond;
207   std::pair<OpAsmParser::OperandType, Type> numThreads;
208   SmallVector<OpAsmParser::OperandType, 4> privates;
209   SmallVector<Type, 4> privateTypes;
210   SmallVector<OpAsmParser::OperandType, 4> firstprivates;
211   SmallVector<Type, 4> firstprivateTypes;
212   SmallVector<OpAsmParser::OperandType, 4> shareds;
213   SmallVector<Type, 4> sharedTypes;
214   SmallVector<OpAsmParser::OperandType, 4> copyins;
215   SmallVector<Type, 4> copyinTypes;
216   SmallVector<OpAsmParser::OperandType, 4> allocates;
217   SmallVector<Type, 4> allocateTypes;
218   SmallVector<OpAsmParser::OperandType, 4> allocators;
219   SmallVector<Type, 4> allocatorTypes;
220   std::array<int, 8> segments{0, 0, 0, 0, 0, 0, 0, 0};
221   StringRef keyword;
222   bool defaultVal = false;
223   bool procBind = false;
224 
225   const int ifClausePos = 0;
226   const int numThreadsClausePos = 1;
227   const int privateClausePos = 2;
228   const int firstprivateClausePos = 3;
229   const int sharedClausePos = 4;
230   const int copyinClausePos = 5;
231   const int allocateClausePos = 6;
232   const int allocatorPos = 7;
233   const StringRef opName = result.name.getStringRef();
234 
235   while (succeeded(parser.parseOptionalKeyword(&keyword))) {
236     if (keyword == "if") {
237       // Fail if there was already another if condition.
238       if (segments[ifClausePos])
239         return allowedOnce(parser, "if", opName);
240       if (parser.parseLParen() || parser.parseOperand(ifCond.first) ||
241           parser.parseColonType(ifCond.second) || parser.parseRParen())
242         return failure();
243       segments[ifClausePos] = 1;
244     } else if (keyword == "num_threads") {
245       // Fail if there was already another num_threads clause.
246       if (segments[numThreadsClausePos])
247         return allowedOnce(parser, "num_threads", opName);
248       if (parser.parseLParen() || parser.parseOperand(numThreads.first) ||
249           parser.parseColonType(numThreads.second) || parser.parseRParen())
250         return failure();
251       segments[numThreadsClausePos] = 1;
252     } else if (keyword == "private") {
253       // Fail if there was already another private clause.
254       if (segments[privateClausePos])
255         return allowedOnce(parser, "private", opName);
256       if (parseOperandAndTypeList(parser, privates, privateTypes))
257         return failure();
258       segments[privateClausePos] = privates.size();
259     } else if (keyword == "firstprivate") {
260       // Fail if there was already another firstprivate clause.
261       if (segments[firstprivateClausePos])
262         return allowedOnce(parser, "firstprivate", opName);
263       if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
264         return failure();
265       segments[firstprivateClausePos] = firstprivates.size();
266     } else if (keyword == "shared") {
267       // Fail if there was already another shared clause.
268       if (segments[sharedClausePos])
269         return allowedOnce(parser, "shared", opName);
270       if (parseOperandAndTypeList(parser, shareds, sharedTypes))
271         return failure();
272       segments[sharedClausePos] = shareds.size();
273     } else if (keyword == "copyin") {
274       // Fail if there was already another copyin clause.
275       if (segments[copyinClausePos])
276         return allowedOnce(parser, "copyin", opName);
277       if (parseOperandAndTypeList(parser, copyins, copyinTypes))
278         return failure();
279       segments[copyinClausePos] = copyins.size();
280     } else if (keyword == "allocate") {
281       // Fail if there was already another allocate clause.
282       if (segments[allocateClausePos])
283         return allowedOnce(parser, "allocate", opName);
284       if (parseAllocateAndAllocator(parser, allocates, allocateTypes,
285                                     allocators, allocatorTypes))
286         return failure();
287       segments[allocateClausePos] = allocates.size();
288       segments[allocatorPos] = allocators.size();
289     } else if (keyword == "default") {
290       // Fail if there was already another default clause.
291       if (defaultVal)
292         return allowedOnce(parser, "default", opName);
293       defaultVal = true;
294       StringRef defval;
295       if (parser.parseLParen() || parser.parseKeyword(&defval) ||
296           parser.parseRParen())
297         return failure();
298       // The def prefix is required for the attribute as "private" is a keyword
299       // in C++.
300       auto attr = parser.getBuilder().getStringAttr("def" + defval);
301       result.addAttribute("default_val", attr);
302     } else if (keyword == "proc_bind") {
303       // Fail if there was already another proc_bind clause.
304       if (procBind)
305         return allowedOnce(parser, "proc_bind", opName);
306       procBind = true;
307       StringRef bind;
308       if (parser.parseLParen() || parser.parseKeyword(&bind) ||
309           parser.parseRParen())
310         return failure();
311       auto attr = parser.getBuilder().getStringAttr(bind);
312       result.addAttribute("proc_bind_val", attr);
313     } else {
314       return parser.emitError(parser.getNameLoc())
315              << keyword << " is not a valid clause for the " << opName
316              << " operation";
317     }
318   }
319 
320   // Add if parameter.
321   if (segments[ifClausePos] &&
322       parser.resolveOperand(ifCond.first, ifCond.second, result.operands))
323     return failure();
324 
325   // Add num_threads parameter.
326   if (segments[numThreadsClausePos] &&
327       parser.resolveOperand(numThreads.first, numThreads.second,
328                             result.operands))
329     return failure();
330 
331   // Add private parameters.
332   if (segments[privateClausePos] &&
333       parser.resolveOperands(privates, privateTypes, privates[0].location,
334                              result.operands))
335     return failure();
336 
337   // Add firstprivate parameters.
338   if (segments[firstprivateClausePos] &&
339       parser.resolveOperands(firstprivates, firstprivateTypes,
340                              firstprivates[0].location, result.operands))
341     return failure();
342 
343   // Add shared parameters.
344   if (segments[sharedClausePos] &&
345       parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
346                              result.operands))
347     return failure();
348 
349   // Add copyin parameters.
350   if (segments[copyinClausePos] &&
351       parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
352                              result.operands))
353     return failure();
354 
355   // Add allocate parameters.
356   if (segments[allocateClausePos] &&
357       parser.resolveOperands(allocates, allocateTypes, allocates[0].location,
358                              result.operands))
359     return failure();
360 
361   // Add allocator parameters.
362   if (segments[allocatorPos] &&
363       parser.resolveOperands(allocators, allocatorTypes, allocators[0].location,
364                              result.operands))
365     return failure();
366 
367   result.addAttribute("operand_segment_sizes",
368                       parser.getBuilder().getI32VectorAttr(segments));
369 
370   Region *body = result.addRegion();
371   SmallVector<OpAsmParser::OperandType, 4> regionArgs;
372   SmallVector<Type, 4> regionArgTypes;
373   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
374     return failure();
375   return success();
376 }
377 
378 /// linear ::= `linear` `(` linear-list `)`
379 /// linear-list := linear-val | linear-val linear-list
380 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
381 static ParseResult
parseLinearClause(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::OperandType> & vars,SmallVectorImpl<Type> & types,SmallVectorImpl<OpAsmParser::OperandType> & stepVars)382 parseLinearClause(OpAsmParser &parser,
383                   SmallVectorImpl<OpAsmParser::OperandType> &vars,
384                   SmallVectorImpl<Type> &types,
385                   SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
386   if (parser.parseLParen())
387     return failure();
388 
389   do {
390     OpAsmParser::OperandType var;
391     Type type;
392     OpAsmParser::OperandType stepVar;
393     if (parser.parseOperand(var) || parser.parseEqual() ||
394         parser.parseOperand(stepVar) || parser.parseColonType(type))
395       return failure();
396 
397     vars.push_back(var);
398     types.push_back(type);
399     stepVars.push_back(stepVar);
400   } while (succeeded(parser.parseOptionalComma()));
401 
402   if (parser.parseRParen())
403     return failure();
404 
405   return success();
406 }
407 
408 /// schedule ::= `schedule` `(` sched-list `)`
409 /// sched-list ::= sched-val | sched-val sched-list
410 /// sched-val ::= sched-with-chunk | sched-wo-chunk
411 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
412 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
413 /// sched-wo-chunk ::=  `auto` | `runtime`
414 static ParseResult
parseScheduleClause(OpAsmParser & parser,SmallString<8> & schedule,Optional<OpAsmParser::OperandType> & chunkSize)415 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
416                     Optional<OpAsmParser::OperandType> &chunkSize) {
417   if (parser.parseLParen())
418     return failure();
419 
420   StringRef keyword;
421   if (parser.parseKeyword(&keyword))
422     return failure();
423 
424   schedule = keyword;
425   if (keyword == "static" || keyword == "dynamic" || keyword == "guided") {
426     if (succeeded(parser.parseOptionalEqual())) {
427       chunkSize = OpAsmParser::OperandType{};
428       if (parser.parseOperand(*chunkSize))
429         return failure();
430     } else {
431       chunkSize = llvm::NoneType::None;
432     }
433   } else if (keyword == "auto" || keyword == "runtime") {
434     chunkSize = llvm::NoneType::None;
435   } else {
436     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
437   }
438 
439   if (parser.parseRParen())
440     return failure();
441 
442   return success();
443 }
444 
445 /// reduction-init ::= `reduction` `(` reduction-entry-list `)`
446 /// reduction-entry-list ::= reduction-entry
447 ///                        | reduction-entry-list `,` reduction-entry
448 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
449 static ParseResult
parseReductionVarList(OpAsmParser & parser,SmallVectorImpl<SymbolRefAttr> & symbols,SmallVectorImpl<OpAsmParser::OperandType> & operands,SmallVectorImpl<Type> & types)450 parseReductionVarList(OpAsmParser &parser,
451                       SmallVectorImpl<SymbolRefAttr> &symbols,
452                       SmallVectorImpl<OpAsmParser::OperandType> &operands,
453                       SmallVectorImpl<Type> &types) {
454   if (failed(parser.parseLParen()))
455     return failure();
456 
457   do {
458     if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
459         parser.parseOperand(operands.emplace_back()) ||
460         parser.parseColonType(types.emplace_back()))
461       return failure();
462   } while (succeeded(parser.parseOptionalComma()));
463   return parser.parseRParen();
464 }
465 
466 /// Parses an OpenMP Workshare Loop operation
467 ///
468 /// operation ::= `omp.wsloop` loop-control clause-list
469 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
470 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps
471 /// steps := `step` `(`ssa-id-list`)`
472 /// clause-list ::= clause | empty | clause-list
473 /// clause ::= private | firstprivate | lastprivate | linear | schedule |
474 //             collapse | nowait | ordered | order | inclusive
475 /// private ::= `private` `(` ssa-id-and-type-list `)`
476 /// firstprivate ::= `firstprivate` `(` ssa-id-and-type-list `)`
477 /// lastprivate ::= `lastprivate` `(` ssa-id-and-type-list `)`
478 /// linear ::= `linear` `(` linear-list `)`
479 /// schedule ::= `schedule` `(` sched-list `)`
480 /// collapse ::= `collapse` `(` ssa-id-and-type `)`
481 /// nowait ::= `nowait`
482 /// ordered ::= `ordered` `(` ssa-id-and-type `)`
483 /// order ::= `order` `(` `concurrent` `)`
484 /// inclusive ::= `inclusive`
485 ///
parseWsLoopOp(OpAsmParser & parser,OperationState & result)486 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
487   Type loopVarType;
488   int numIVs;
489 
490   // Parse an opening `(` followed by induction variables followed by `)`
491   SmallVector<OpAsmParser::OperandType> ivs;
492   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
493                                      OpAsmParser::Delimiter::Paren))
494     return failure();
495 
496   numIVs = static_cast<int>(ivs.size());
497 
498   if (parser.parseColonType(loopVarType))
499     return failure();
500 
501   // Parse loop bounds.
502   SmallVector<OpAsmParser::OperandType> lower;
503   if (parser.parseEqual() ||
504       parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
505       parser.resolveOperands(lower, loopVarType, result.operands))
506     return failure();
507 
508   SmallVector<OpAsmParser::OperandType> upper;
509   if (parser.parseKeyword("to") ||
510       parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
511       parser.resolveOperands(upper, loopVarType, result.operands))
512     return failure();
513 
514   // Parse step values.
515   SmallVector<OpAsmParser::OperandType> steps;
516   if (parser.parseKeyword("step") ||
517       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
518       parser.resolveOperands(steps, loopVarType, result.operands))
519     return failure();
520 
521   SmallVector<OpAsmParser::OperandType> privates;
522   SmallVector<Type> privateTypes;
523   SmallVector<OpAsmParser::OperandType> firstprivates;
524   SmallVector<Type> firstprivateTypes;
525   SmallVector<OpAsmParser::OperandType> lastprivates;
526   SmallVector<Type> lastprivateTypes;
527   SmallVector<OpAsmParser::OperandType> linears;
528   SmallVector<Type> linearTypes;
529   SmallVector<OpAsmParser::OperandType> linearSteps;
530   SmallVector<SymbolRefAttr> reductionSymbols;
531   SmallVector<OpAsmParser::OperandType> reductionVars;
532   SmallVector<Type> reductionVarTypes;
533   SmallString<8> schedule;
534   Optional<OpAsmParser::OperandType> scheduleChunkSize;
535 
536   const StringRef opName = result.name.getStringRef();
537   StringRef keyword;
538 
539   enum SegmentPos {
540     lbPos = 0,
541     ubPos,
542     stepPos,
543     privateClausePos,
544     firstprivateClausePos,
545     lastprivateClausePos,
546     linearClausePos,
547     linearStepPos,
548     reductionVarPos,
549     scheduleClausePos,
550   };
551   std::array<int, 10> segments{numIVs, numIVs, numIVs, 0, 0, 0, 0, 0, 0, 0};
552 
553   while (succeeded(parser.parseOptionalKeyword(&keyword))) {
554     if (keyword == "private") {
555       if (segments[privateClausePos])
556         return allowedOnce(parser, "private", opName);
557       if (parseOperandAndTypeList(parser, privates, privateTypes))
558         return failure();
559       segments[privateClausePos] = privates.size();
560     } else if (keyword == "firstprivate") {
561       // fail if there was already another firstprivate clause
562       if (segments[firstprivateClausePos])
563         return allowedOnce(parser, "firstprivate", opName);
564       if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
565         return failure();
566       segments[firstprivateClausePos] = firstprivates.size();
567     } else if (keyword == "lastprivate") {
568       // fail if there was already another shared clause
569       if (segments[lastprivateClausePos])
570         return allowedOnce(parser, "lastprivate", opName);
571       if (parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
572         return failure();
573       segments[lastprivateClausePos] = lastprivates.size();
574     } else if (keyword == "linear") {
575       // fail if there was already another linear clause
576       if (segments[linearClausePos])
577         return allowedOnce(parser, "linear", opName);
578       if (parseLinearClause(parser, linears, linearTypes, linearSteps))
579         return failure();
580       segments[linearClausePos] = linears.size();
581       segments[linearStepPos] = linearSteps.size();
582     } else if (keyword == "schedule") {
583       if (!schedule.empty())
584         return allowedOnce(parser, "schedule", opName);
585       if (parseScheduleClause(parser, schedule, scheduleChunkSize))
586         return failure();
587       if (scheduleChunkSize) {
588         segments[scheduleClausePos] = 1;
589       }
590     } else if (keyword == "collapse") {
591       auto type = parser.getBuilder().getI64Type();
592       mlir::IntegerAttr attr;
593       if (parser.parseLParen() || parser.parseAttribute(attr, type) ||
594           parser.parseRParen())
595         return failure();
596       result.addAttribute("collapse_val", attr);
597     } else if (keyword == "nowait") {
598       auto attr = UnitAttr::get(parser.getContext());
599       result.addAttribute("nowait", attr);
600     } else if (keyword == "ordered") {
601       mlir::IntegerAttr attr;
602       if (succeeded(parser.parseOptionalLParen())) {
603         auto type = parser.getBuilder().getI64Type();
604         if (parser.parseAttribute(attr, type))
605           return failure();
606         if (parser.parseRParen())
607           return failure();
608       } else {
609         // Use 0 to represent no ordered parameter was specified
610         attr = parser.getBuilder().getI64IntegerAttr(0);
611       }
612       result.addAttribute("ordered_val", attr);
613     } else if (keyword == "order") {
614       StringRef order;
615       if (parser.parseLParen() || parser.parseKeyword(&order) ||
616           parser.parseRParen())
617         return failure();
618       auto attr = parser.getBuilder().getStringAttr(order);
619       result.addAttribute("order", attr);
620     } else if (keyword == "inclusive") {
621       auto attr = UnitAttr::get(parser.getContext());
622       result.addAttribute("inclusive", attr);
623     } else if (keyword == "reduction") {
624       if (segments[reductionVarPos])
625         return allowedOnce(parser, "reduction", opName);
626       if (failed(parseReductionVarList(parser, reductionSymbols, reductionVars,
627                                        reductionVarTypes)))
628         return failure();
629       segments[reductionVarPos] = reductionVars.size();
630     }
631   }
632 
633   if (segments[privateClausePos]) {
634     parser.resolveOperands(privates, privateTypes, privates[0].location,
635                            result.operands);
636   }
637 
638   if (segments[firstprivateClausePos]) {
639     parser.resolveOperands(firstprivates, firstprivateTypes,
640                            firstprivates[0].location, result.operands);
641   }
642 
643   if (segments[lastprivateClausePos]) {
644     parser.resolveOperands(lastprivates, lastprivateTypes,
645                            lastprivates[0].location, result.operands);
646   }
647 
648   if (segments[linearClausePos]) {
649     parser.resolveOperands(linears, linearTypes, linears[0].location,
650                            result.operands);
651     auto linearStepType = parser.getBuilder().getI32Type();
652     SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
653     parser.resolveOperands(linearSteps, linearStepTypes,
654                            linearSteps[0].location, result.operands);
655   }
656 
657   if (segments[reductionVarPos]) {
658     if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
659                                       parser.getNameLoc(), result.operands))) {
660       return failure();
661     }
662     SmallVector<Attribute> reductions(reductionSymbols.begin(),
663                                       reductionSymbols.end());
664     result.addAttribute("reductions",
665                         parser.getBuilder().getArrayAttr(reductions));
666   }
667 
668   if (!schedule.empty()) {
669     schedule[0] = llvm::toUpper(schedule[0]);
670     auto attr = parser.getBuilder().getStringAttr(schedule);
671     result.addAttribute("schedule_val", attr);
672     if (scheduleChunkSize) {
673       auto chunkSizeType = parser.getBuilder().getI32Type();
674       parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands);
675     }
676   }
677 
678   result.addAttribute("operand_segment_sizes",
679                       parser.getBuilder().getI32VectorAttr(segments));
680 
681   // Now parse the body.
682   Region *body = result.addRegion();
683   SmallVector<Type> ivTypes(numIVs, loopVarType);
684   SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
685   if (parser.parseRegion(*body, blockArgs, ivTypes))
686     return failure();
687   return success();
688 }
689 
printWsLoopOp(OpAsmPrinter & p,WsLoopOp op)690 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
691   auto args = op.getRegion().front().getArguments();
692   p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound()
693     << ") to (" << op.upperBound() << ") step (" << op.step() << ")";
694 
695   // Print private, firstprivate, shared and copyin parameters
696   auto printDataVars = [&p](StringRef name, OperandRange vars) {
697     if (vars.empty())
698       return;
699 
700     p << " " << name << "(";
701     llvm::interleaveComma(
702         vars, p, [&](const Value &v) { p << v << " : " << v.getType(); });
703     p << ")";
704   };
705   printDataVars("private", op.private_vars());
706   printDataVars("firstprivate", op.firstprivate_vars());
707   printDataVars("lastprivate", op.lastprivate_vars());
708 
709   auto linearVars = op.linear_vars();
710   auto linearVarsSize = linearVars.size();
711   if (linearVarsSize) {
712     p << " "
713       << "linear"
714       << "(";
715     for (unsigned i = 0; i < linearVarsSize; ++i) {
716       std::string separator = i == linearVarsSize - 1 ? ")" : ", ";
717       p << linearVars[i];
718       if (op.linear_step_vars().size() > i)
719         p << " = " << op.linear_step_vars()[i];
720       p << " : " << linearVars[i].getType() << separator;
721     }
722   }
723 
724   if (auto sched = op.schedule_val()) {
725     auto schedLower = sched->lower();
726     p << " schedule(" << schedLower;
727     if (auto chunk = op.schedule_chunk_var()) {
728       p << " = " << chunk;
729     }
730     p << ")";
731   }
732 
733   if (auto collapse = op.collapse_val())
734     p << " collapse(" << collapse << ")";
735 
736   if (op.nowait())
737     p << " nowait";
738 
739   if (auto ordered = op.ordered_val()) {
740     p << " ordered(" << ordered << ")";
741   }
742 
743   if (!op.reduction_vars().empty()) {
744     p << " reduction(";
745     for (unsigned i = 0, e = op.getNumReductionVars(); i < e; ++i) {
746       if (i != 0)
747         p << ", ";
748       p << (*op.reductions())[i] << " -> " << op.reduction_vars()[i] << " : "
749         << op.reduction_vars()[i].getType();
750     }
751     p << ")";
752   }
753 
754   if (op.inclusive()) {
755     p << " inclusive";
756   }
757 
758   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
759 }
760 
761 //===----------------------------------------------------------------------===//
762 // ReductionOp
763 //===----------------------------------------------------------------------===//
764 
parseAtomicReductionRegion(OpAsmParser & parser,Region & region)765 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
766                                               Region &region) {
767   if (parser.parseOptionalKeyword("atomic"))
768     return success();
769   return parser.parseRegion(region);
770 }
771 
printAtomicReductionRegion(OpAsmPrinter & printer,ReductionDeclareOp op,Region & region)772 static void printAtomicReductionRegion(OpAsmPrinter &printer,
773                                        ReductionDeclareOp op, Region &region) {
774   if (region.empty())
775     return;
776   printer << "atomic ";
777   printer.printRegion(region);
778 }
779 
verifyReductionDeclareOp(ReductionDeclareOp op)780 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) {
781   if (op.initializerRegion().empty())
782     return op.emitOpError() << "expects non-empty initializer region";
783   Block &initializerEntryBlock = op.initializerRegion().front();
784   if (initializerEntryBlock.getNumArguments() != 1 ||
785       initializerEntryBlock.getArgument(0).getType() != op.type()) {
786     return op.emitOpError() << "expects initializer region with one argument "
787                                "of the reduction type";
788   }
789 
790   for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) {
791     if (yieldOp.results().size() != 1 ||
792         yieldOp.results().getTypes()[0] != op.type())
793       return op.emitOpError() << "expects initializer region to yield a value "
794                                  "of the reduction type";
795   }
796 
797   if (op.reductionRegion().empty())
798     return op.emitOpError() << "expects non-empty reduction region";
799   Block &reductionEntryBlock = op.reductionRegion().front();
800   if (reductionEntryBlock.getNumArguments() != 2 ||
801       reductionEntryBlock.getArgumentTypes()[0] !=
802           reductionEntryBlock.getArgumentTypes()[1] ||
803       reductionEntryBlock.getArgumentTypes()[0] != op.type())
804     return op.emitOpError() << "expects reduction region with two arguments of "
805                                "the reduction type";
806   for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) {
807     if (yieldOp.results().size() != 1 ||
808         yieldOp.results().getTypes()[0] != op.type())
809       return op.emitOpError() << "expects reduction region to yield a value "
810                                  "of the reduction type";
811   }
812 
813   if (op.atomicReductionRegion().empty())
814     return success();
815 
816   Block &atomicReductionEntryBlock = op.atomicReductionRegion().front();
817   if (atomicReductionEntryBlock.getNumArguments() != 2 ||
818       atomicReductionEntryBlock.getArgumentTypes()[0] !=
819           atomicReductionEntryBlock.getArgumentTypes()[1])
820     return op.emitOpError() << "expects atomic reduction region with two "
821                                "arguments of the same type";
822   auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
823                      .dyn_cast<PointerLikeType>();
824   if (!ptrType || ptrType.getElementType() != op.type())
825     return op.emitOpError() << "expects atomic reduction region arguments to "
826                                "be accumulators containing the reduction type";
827   return success();
828 }
829 
verifyReductionOp(ReductionOp op)830 static LogicalResult verifyReductionOp(ReductionOp op) {
831   // TODO: generalize this to an op interface when there is more than one op
832   // that supports reductions.
833   auto container = op->getParentOfType<WsLoopOp>();
834   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
835     if (container.reduction_vars()[i] == op.accumulator())
836       return success();
837 
838   return op.emitOpError() << "the accumulator is not used by the parent";
839 }
840 
841 //===----------------------------------------------------------------------===//
842 // WsLoopOp
843 //===----------------------------------------------------------------------===//
844 
build(OpBuilder & builder,OperationState & state,ValueRange lowerBound,ValueRange upperBound,ValueRange step,ArrayRef<NamedAttribute> attributes)845 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
846                      ValueRange lowerBound, ValueRange upperBound,
847                      ValueRange step, ArrayRef<NamedAttribute> attributes) {
848   build(builder, state, TypeRange(), lowerBound, upperBound, step,
849         /*private_vars=*/ValueRange(),
850         /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
851         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
852         /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
853         /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
854         /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
855         /*inclusive=*/nullptr, /*buildBody=*/false);
856   state.addAttributes(attributes);
857 }
858 
build(OpBuilder &,OperationState & state,TypeRange resultTypes,ValueRange operands,ArrayRef<NamedAttribute> attributes)859 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
860                      ValueRange operands, ArrayRef<NamedAttribute> attributes) {
861   state.addOperands(operands);
862   state.addAttributes(attributes);
863   (void)state.addRegion();
864   assert(resultTypes.empty() && "mismatched number of return types");
865   state.addTypes(resultTypes);
866 }
867 
build(OpBuilder & builder,OperationState & result,TypeRange typeRange,ValueRange lowerBounds,ValueRange upperBounds,ValueRange steps,ValueRange privateVars,ValueRange firstprivateVars,ValueRange lastprivateVars,ValueRange linearVars,ValueRange linearStepVars,ValueRange reductionVars,StringAttr scheduleVal,Value scheduleChunkVar,IntegerAttr collapseVal,UnitAttr nowait,IntegerAttr orderedVal,StringAttr orderVal,UnitAttr inclusive,bool buildBody)868 void WsLoopOp::build(OpBuilder &builder, OperationState &result,
869                      TypeRange typeRange, ValueRange lowerBounds,
870                      ValueRange upperBounds, ValueRange steps,
871                      ValueRange privateVars, ValueRange firstprivateVars,
872                      ValueRange lastprivateVars, ValueRange linearVars,
873                      ValueRange linearStepVars, ValueRange reductionVars,
874                      StringAttr scheduleVal, Value scheduleChunkVar,
875                      IntegerAttr collapseVal, UnitAttr nowait,
876                      IntegerAttr orderedVal, StringAttr orderVal,
877                      UnitAttr inclusive, bool buildBody) {
878   result.addOperands(lowerBounds);
879   result.addOperands(upperBounds);
880   result.addOperands(steps);
881   result.addOperands(privateVars);
882   result.addOperands(firstprivateVars);
883   result.addOperands(linearVars);
884   result.addOperands(linearStepVars);
885   if (scheduleChunkVar)
886     result.addOperands(scheduleChunkVar);
887 
888   if (scheduleVal)
889     result.addAttribute("schedule_val", scheduleVal);
890   if (collapseVal)
891     result.addAttribute("collapse_val", collapseVal);
892   if (nowait)
893     result.addAttribute("nowait", nowait);
894   if (orderedVal)
895     result.addAttribute("ordered_val", orderedVal);
896   if (orderVal)
897     result.addAttribute("order", orderVal);
898   if (inclusive)
899     result.addAttribute("inclusive", inclusive);
900   result.addAttribute(
901       WsLoopOp::getOperandSegmentSizeAttr(),
902       builder.getI32VectorAttr(
903           {static_cast<int32_t>(lowerBounds.size()),
904            static_cast<int32_t>(upperBounds.size()),
905            static_cast<int32_t>(steps.size()),
906            static_cast<int32_t>(privateVars.size()),
907            static_cast<int32_t>(firstprivateVars.size()),
908            static_cast<int32_t>(lastprivateVars.size()),
909            static_cast<int32_t>(linearVars.size()),
910            static_cast<int32_t>(linearStepVars.size()),
911            static_cast<int32_t>(reductionVars.size()),
912            static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
913 
914   Region *bodyRegion = result.addRegion();
915   if (buildBody) {
916     OpBuilder::InsertionGuard guard(builder);
917     unsigned numIVs = steps.size();
918     SmallVector<Type, 8> argTypes(numIVs, steps.getType().front());
919     builder.createBlock(bodyRegion, {}, argTypes);
920   }
921 }
922 
verifyWsLoopOp(WsLoopOp op)923 static LogicalResult verifyWsLoopOp(WsLoopOp op) {
924   if (op.getNumReductionVars() != 0) {
925     if (!op.reductions() ||
926         op.reductions()->size() != op.getNumReductionVars()) {
927       return op.emitOpError() << "expected as many reduction symbol references "
928                                  "as reduction variables";
929     }
930   } else {
931     if (op.reductions())
932       return op.emitOpError() << "unexpected reduction symbol references";
933     return success();
934   }
935 
936   DenseSet<Value> accumulators;
937   for (auto args : llvm::zip(op.reduction_vars(), *op.reductions())) {
938     Value accum = std::get<0>(args);
939     if (!accumulators.insert(accum).second) {
940       return op.emitOpError() << "accumulator variable used more than once";
941     }
942     Type varType = accum.getType().cast<PointerLikeType>();
943     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
944     auto decl =
945         SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
946     if (!decl) {
947       return op.emitOpError() << "expected symbol reference " << symbolRef
948                               << " to point to a reduction declaration";
949     }
950 
951     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) {
952       return op.emitOpError()
953              << "expected accumulator (" << varType
954              << ") to be the same type as reduction declaration ("
955              << decl.getAccumulatorType() << ")";
956     }
957   }
958 
959   return success();
960 }
961 
962 //===----------------------------------------------------------------------===//
963 // Parser, printer and verifier for Synchronization Hint (2.17.12)
964 //===----------------------------------------------------------------------===//
965 
966 /// Parses a Synchronization Hint clause. The value of hint is an integer
967 /// which is a combination of different hints from `omp_sync_hint_t`.
968 ///
969 /// hint-clause = `hint` `(` hint-value `)`
parseSynchronizationHint(OpAsmParser & parser,IntegerAttr & hintAttr)970 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
971                                             IntegerAttr &hintAttr) {
972   if (failed(parser.parseOptionalKeyword("hint"))) {
973     hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
974     return success();
975   }
976 
977   if (failed(parser.parseLParen()))
978     return failure();
979   StringRef hintKeyword;
980   int64_t hint = 0;
981   do {
982     if (failed(parser.parseKeyword(&hintKeyword)))
983       return failure();
984     if (hintKeyword == "uncontended")
985       hint |= 1;
986     else if (hintKeyword == "contended")
987       hint |= 2;
988     else if (hintKeyword == "nonspeculative")
989       hint |= 4;
990     else if (hintKeyword == "speculative")
991       hint |= 8;
992     else
993       return parser.emitError(parser.getCurrentLocation())
994              << hintKeyword << " is not a valid hint";
995   } while (succeeded(parser.parseOptionalComma()));
996   if (failed(parser.parseRParen()))
997     return failure();
998   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
999   return success();
1000 }
1001 
1002 /// Prints a Synchronization Hint clause
printSynchronizationHint(OpAsmPrinter & p,Operation * op,IntegerAttr hintAttr)1003 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
1004                                      IntegerAttr hintAttr) {
1005   int64_t hint = hintAttr.getInt();
1006 
1007   if (hint == 0)
1008     return;
1009 
1010   // Helper function to get n-th bit from the right end of `value`
1011   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1012 
1013   bool uncontended = bitn(hint, 0);
1014   bool contended = bitn(hint, 1);
1015   bool nonspeculative = bitn(hint, 2);
1016   bool speculative = bitn(hint, 3);
1017 
1018   SmallVector<StringRef> hints;
1019   if (uncontended)
1020     hints.push_back("uncontended");
1021   if (contended)
1022     hints.push_back("contended");
1023   if (nonspeculative)
1024     hints.push_back("nonspeculative");
1025   if (speculative)
1026     hints.push_back("speculative");
1027 
1028   p << "hint(";
1029   llvm::interleaveComma(hints, p);
1030   p << ")";
1031 }
1032 
1033 /// Verifies a synchronization hint clause
verifySynchronizationHint(Operation * op,int32_t hint)1034 static LogicalResult verifySynchronizationHint(Operation *op, int32_t hint) {
1035 
1036   // Helper function to get n-th bit from the right end of `value`
1037   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1038 
1039   bool uncontended = bitn(hint, 0);
1040   bool contended = bitn(hint, 1);
1041   bool nonspeculative = bitn(hint, 2);
1042   bool speculative = bitn(hint, 3);
1043 
1044   if (uncontended && contended)
1045     return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1046                                 "omp_sync_hint_contended cannot be combined";
1047   if (nonspeculative && speculative)
1048     return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1049                                 "omp_sync_hint_speculative cannot be combined.";
1050   return success();
1051 }
1052 
1053 //===----------------------------------------------------------------------===//
1054 // Verifier for critical construct (2.17.1)
1055 //===----------------------------------------------------------------------===//
1056 
verifyCriticalOp(CriticalOp op)1057 static LogicalResult verifyCriticalOp(CriticalOp op) {
1058 
1059   if (failed(verifySynchronizationHint(op, op.hint()))) {
1060     return failure();
1061   }
1062   if (!op.name().hasValue() && (op.hint() != 0))
1063     return op.emitOpError() << "must specify a name unless the effect is as if "
1064                                "no hint is specified";
1065 
1066   if (op.nameAttr()) {
1067     auto symbolRef = op.nameAttr().cast<SymbolRefAttr>();
1068     auto decl =
1069         SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef);
1070     if (!decl) {
1071       return op.emitOpError() << "expected symbol reference " << symbolRef
1072                               << " to point to a critical declaration";
1073     }
1074   }
1075 
1076   return success();
1077 }
1078 
1079 #define GET_OP_CLASSES
1080 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
1081