1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/OperationSupport.h"
19 
20 #include "llvm/ADT/SmallString.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/ADT/StringSwitch.h"
24 #include <cstddef>
25 
26 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
27 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
28 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
29 
30 using namespace mlir;
31 using namespace mlir::omp;
32 
33 namespace {
34 /// Model for pointer-like types that already provide a `getElementType` method.
35 template <typename T>
36 struct PointerLikeModel
37     : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
getElementType__anon47251b610111::PointerLikeModel38   Type getElementType(Type pointer) const {
39     return pointer.cast<T>().getElementType();
40   }
41 };
42 } // end namespace
43 
initialize()44 void OpenMPDialect::initialize() {
45   addOperations<
46 #define GET_OP_LIST
47 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
48       >();
49 
50   LLVM::LLVMPointerType::attachInterface<
51       PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
52   MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
53 }
54 
55 //===----------------------------------------------------------------------===//
56 // ParallelOp
57 //===----------------------------------------------------------------------===//
58 
build(OpBuilder & builder,OperationState & state,ArrayRef<NamedAttribute> attributes)59 void ParallelOp::build(OpBuilder &builder, OperationState &state,
60                        ArrayRef<NamedAttribute> attributes) {
61   ParallelOp::build(
62       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
63       /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
64       /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
65       /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
66       /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
67   state.addAttributes(attributes);
68 }
69 
70 /// Parse a list of operands with types.
71 ///
72 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
73 /// ssa-id-and-type-list ::= ssa-id-and-type |
74 ///                          ssa-id-and-type `,` ssa-id-and-type-list
75 /// ssa-id-and-type ::= ssa-id `:` type
76 static ParseResult
parseOperandAndTypeList(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::OperandType> & operands,SmallVectorImpl<Type> & types)77 parseOperandAndTypeList(OpAsmParser &parser,
78                         SmallVectorImpl<OpAsmParser::OperandType> &operands,
79                         SmallVectorImpl<Type> &types) {
80   if (parser.parseLParen())
81     return failure();
82 
83   do {
84     OpAsmParser::OperandType operand;
85     Type type;
86     if (parser.parseOperand(operand) || parser.parseColonType(type))
87       return failure();
88     operands.push_back(operand);
89     types.push_back(type);
90   } while (succeeded(parser.parseOptionalComma()));
91 
92   if (parser.parseRParen())
93     return failure();
94 
95   return success();
96 }
97 
98 /// Parse an allocate clause with allocators and a list of operands with types.
99 ///
100 /// operand-and-type-list ::= `(` allocate-operand-list `)`
101 /// allocate-operand-list :: = allocate-operand |
102 ///                            allocator-operand `,` allocate-operand-list
103 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
104 /// ssa-id-and-type ::= ssa-id `:` type
parseAllocateAndAllocator(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::OperandType> & operandsAllocate,SmallVectorImpl<Type> & typesAllocate,SmallVectorImpl<OpAsmParser::OperandType> & operandsAllocator,SmallVectorImpl<Type> & typesAllocator)105 static ParseResult parseAllocateAndAllocator(
106     OpAsmParser &parser,
107     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
108     SmallVectorImpl<Type> &typesAllocate,
109     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
110     SmallVectorImpl<Type> &typesAllocator) {
111   if (parser.parseLParen())
112     return failure();
113 
114   do {
115     OpAsmParser::OperandType operand;
116     Type type;
117 
118     if (parser.parseOperand(operand) || parser.parseColonType(type))
119       return failure();
120     operandsAllocator.push_back(operand);
121     typesAllocator.push_back(type);
122     if (parser.parseArrow())
123       return failure();
124     if (parser.parseOperand(operand) || parser.parseColonType(type))
125       return failure();
126 
127     operandsAllocate.push_back(operand);
128     typesAllocate.push_back(type);
129   } while (succeeded(parser.parseOptionalComma()));
130 
131   if (parser.parseRParen())
132     return failure();
133 
134   return success();
135 }
136 
verifyParallelOp(ParallelOp op)137 static LogicalResult verifyParallelOp(ParallelOp op) {
138   if (op.allocate_vars().size() != op.allocators_vars().size())
139     return op.emitError(
140         "expected equal sizes for allocate and allocator variables");
141   return success();
142 }
143 
printParallelOp(OpAsmPrinter & p,ParallelOp op)144 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
145   p << "omp.parallel";
146 
147   if (auto ifCond = op.if_expr_var())
148     p << " if(" << ifCond << " : " << ifCond.getType() << ")";
149 
150   if (auto threads = op.num_threads_var())
151     p << " num_threads(" << threads << " : " << threads.getType() << ")";
152 
153   // Print private, firstprivate, shared and copyin parameters
154   auto printDataVars = [&p](StringRef name, OperandRange vars) {
155     if (vars.size()) {
156       p << " " << name << "(";
157       for (unsigned i = 0; i < vars.size(); ++i) {
158         std::string separator = i == vars.size() - 1 ? ")" : ", ";
159         p << vars[i] << " : " << vars[i].getType() << separator;
160       }
161     }
162   };
163 
164   // Print allocator and allocate parameters
165   auto printAllocateAndAllocator = [&p](OperandRange varsAllocate,
166                                         OperandRange varsAllocator) {
167     if (varsAllocate.empty())
168       return;
169 
170     p << " allocate(";
171     for (unsigned i = 0; i < varsAllocate.size(); ++i) {
172       std::string separator = i == varsAllocate.size() - 1 ? ")" : ", ";
173       p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
174       p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
175     }
176   };
177 
178   printDataVars("private", op.private_vars());
179   printDataVars("firstprivate", op.firstprivate_vars());
180   printDataVars("shared", op.shared_vars());
181   printDataVars("copyin", op.copyin_vars());
182   printAllocateAndAllocator(op.allocate_vars(), op.allocators_vars());
183 
184   if (auto def = op.default_val())
185     p << " default(" << def->drop_front(3) << ")";
186 
187   if (auto bind = op.proc_bind_val())
188     p << " proc_bind(" << bind << ")";
189 
190   p.printRegion(op.getRegion());
191 }
192 
193 /// Emit an error if the same clause is present more than once on an operation.
allowedOnce(OpAsmParser & parser,StringRef clause,StringRef operation)194 static ParseResult allowedOnce(OpAsmParser &parser, StringRef clause,
195                                StringRef operation) {
196   return parser.emitError(parser.getNameLoc())
197          << " at most one " << clause << " clause can appear on the "
198          << operation << " operation";
199 }
200 
201 /// Parses a parallel operation.
202 ///
203 /// operation ::= `omp.parallel` clause-list
204 /// clause-list ::= clause | clause clause-list
205 /// clause ::= if | numThreads | private | firstprivate | shared | copyin |
206 ///            default | procBind
207 /// if ::= `if` `(` ssa-id `)`
208 /// numThreads ::= `num_threads` `(` ssa-id-and-type `)`
209 /// private ::= `private` operand-and-type-list
210 /// firstprivate ::= `firstprivate` operand-and-type-list
211 /// shared ::= `shared` operand-and-type-list
212 /// copyin ::= `copyin` operand-and-type-list
213 /// allocate ::= `allocate` operand-and-type `->` operand-and-type-list
214 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`)
215 /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
216 ///
217 /// Note that each clause can only appear once in the clase-list.
parseParallelOp(OpAsmParser & parser,OperationState & result)218 static ParseResult parseParallelOp(OpAsmParser &parser,
219                                    OperationState &result) {
220   std::pair<OpAsmParser::OperandType, Type> ifCond;
221   std::pair<OpAsmParser::OperandType, Type> numThreads;
222   SmallVector<OpAsmParser::OperandType, 4> privates;
223   SmallVector<Type, 4> privateTypes;
224   SmallVector<OpAsmParser::OperandType, 4> firstprivates;
225   SmallVector<Type, 4> firstprivateTypes;
226   SmallVector<OpAsmParser::OperandType, 4> shareds;
227   SmallVector<Type, 4> sharedTypes;
228   SmallVector<OpAsmParser::OperandType, 4> copyins;
229   SmallVector<Type, 4> copyinTypes;
230   SmallVector<OpAsmParser::OperandType, 4> allocates;
231   SmallVector<Type, 4> allocateTypes;
232   SmallVector<OpAsmParser::OperandType, 4> allocators;
233   SmallVector<Type, 4> allocatorTypes;
234   std::array<int, 8> segments{0, 0, 0, 0, 0, 0, 0, 0};
235   StringRef keyword;
236   bool defaultVal = false;
237   bool procBind = false;
238 
239   const int ifClausePos = 0;
240   const int numThreadsClausePos = 1;
241   const int privateClausePos = 2;
242   const int firstprivateClausePos = 3;
243   const int sharedClausePos = 4;
244   const int copyinClausePos = 5;
245   const int allocateClausePos = 6;
246   const int allocatorPos = 7;
247   const StringRef opName = result.name.getStringRef();
248 
249   while (succeeded(parser.parseOptionalKeyword(&keyword))) {
250     if (keyword == "if") {
251       // Fail if there was already another if condition.
252       if (segments[ifClausePos])
253         return allowedOnce(parser, "if", opName);
254       if (parser.parseLParen() || parser.parseOperand(ifCond.first) ||
255           parser.parseColonType(ifCond.second) || parser.parseRParen())
256         return failure();
257       segments[ifClausePos] = 1;
258     } else if (keyword == "num_threads") {
259       // Fail if there was already another num_threads clause.
260       if (segments[numThreadsClausePos])
261         return allowedOnce(parser, "num_threads", opName);
262       if (parser.parseLParen() || parser.parseOperand(numThreads.first) ||
263           parser.parseColonType(numThreads.second) || parser.parseRParen())
264         return failure();
265       segments[numThreadsClausePos] = 1;
266     } else if (keyword == "private") {
267       // Fail if there was already another private clause.
268       if (segments[privateClausePos])
269         return allowedOnce(parser, "private", opName);
270       if (parseOperandAndTypeList(parser, privates, privateTypes))
271         return failure();
272       segments[privateClausePos] = privates.size();
273     } else if (keyword == "firstprivate") {
274       // Fail if there was already another firstprivate clause.
275       if (segments[firstprivateClausePos])
276         return allowedOnce(parser, "firstprivate", opName);
277       if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
278         return failure();
279       segments[firstprivateClausePos] = firstprivates.size();
280     } else if (keyword == "shared") {
281       // Fail if there was already another shared clause.
282       if (segments[sharedClausePos])
283         return allowedOnce(parser, "shared", opName);
284       if (parseOperandAndTypeList(parser, shareds, sharedTypes))
285         return failure();
286       segments[sharedClausePos] = shareds.size();
287     } else if (keyword == "copyin") {
288       // Fail if there was already another copyin clause.
289       if (segments[copyinClausePos])
290         return allowedOnce(parser, "copyin", opName);
291       if (parseOperandAndTypeList(parser, copyins, copyinTypes))
292         return failure();
293       segments[copyinClausePos] = copyins.size();
294     } else if (keyword == "allocate") {
295       // Fail if there was already another allocate clause.
296       if (segments[allocateClausePos])
297         return allowedOnce(parser, "allocate", opName);
298       if (parseAllocateAndAllocator(parser, allocates, allocateTypes,
299                                     allocators, allocatorTypes))
300         return failure();
301       segments[allocateClausePos] = allocates.size();
302       segments[allocatorPos] = allocators.size();
303     } else if (keyword == "default") {
304       // Fail if there was already another default clause.
305       if (defaultVal)
306         return allowedOnce(parser, "default", opName);
307       defaultVal = true;
308       StringRef defval;
309       if (parser.parseLParen() || parser.parseKeyword(&defval) ||
310           parser.parseRParen())
311         return failure();
312       // The def prefix is required for the attribute as "private" is a keyword
313       // in C++.
314       auto attr = parser.getBuilder().getStringAttr("def" + defval);
315       result.addAttribute("default_val", attr);
316     } else if (keyword == "proc_bind") {
317       // Fail if there was already another proc_bind clause.
318       if (procBind)
319         return allowedOnce(parser, "proc_bind", opName);
320       procBind = true;
321       StringRef bind;
322       if (parser.parseLParen() || parser.parseKeyword(&bind) ||
323           parser.parseRParen())
324         return failure();
325       auto attr = parser.getBuilder().getStringAttr(bind);
326       result.addAttribute("proc_bind_val", attr);
327     } else {
328       return parser.emitError(parser.getNameLoc())
329              << keyword << " is not a valid clause for the " << opName
330              << " operation";
331     }
332   }
333 
334   // Add if parameter.
335   if (segments[ifClausePos] &&
336       parser.resolveOperand(ifCond.first, ifCond.second, result.operands))
337     return failure();
338 
339   // Add num_threads parameter.
340   if (segments[numThreadsClausePos] &&
341       parser.resolveOperand(numThreads.first, numThreads.second,
342                             result.operands))
343     return failure();
344 
345   // Add private parameters.
346   if (segments[privateClausePos] &&
347       parser.resolveOperands(privates, privateTypes, privates[0].location,
348                              result.operands))
349     return failure();
350 
351   // Add firstprivate parameters.
352   if (segments[firstprivateClausePos] &&
353       parser.resolveOperands(firstprivates, firstprivateTypes,
354                              firstprivates[0].location, result.operands))
355     return failure();
356 
357   // Add shared parameters.
358   if (segments[sharedClausePos] &&
359       parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
360                              result.operands))
361     return failure();
362 
363   // Add copyin parameters.
364   if (segments[copyinClausePos] &&
365       parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
366                              result.operands))
367     return failure();
368 
369   // Add allocate parameters.
370   if (segments[allocateClausePos] &&
371       parser.resolveOperands(allocates, allocateTypes, allocates[0].location,
372                              result.operands))
373     return failure();
374 
375   // Add allocator parameters.
376   if (segments[allocatorPos] &&
377       parser.resolveOperands(allocators, allocatorTypes, allocators[0].location,
378                              result.operands))
379     return failure();
380 
381   result.addAttribute("operand_segment_sizes",
382                       parser.getBuilder().getI32VectorAttr(segments));
383 
384   Region *body = result.addRegion();
385   SmallVector<OpAsmParser::OperandType, 4> regionArgs;
386   SmallVector<Type, 4> regionArgTypes;
387   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
388     return failure();
389   return success();
390 }
391 
392 /// linear ::= `linear` `(` linear-list `)`
393 /// linear-list := linear-val | linear-val linear-list
394 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
395 static ParseResult
parseLinearClause(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::OperandType> & vars,SmallVectorImpl<Type> & types,SmallVectorImpl<OpAsmParser::OperandType> & stepVars)396 parseLinearClause(OpAsmParser &parser,
397                   SmallVectorImpl<OpAsmParser::OperandType> &vars,
398                   SmallVectorImpl<Type> &types,
399                   SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
400   if (parser.parseLParen())
401     return failure();
402 
403   do {
404     OpAsmParser::OperandType var;
405     Type type;
406     OpAsmParser::OperandType stepVar;
407     if (parser.parseOperand(var) || parser.parseEqual() ||
408         parser.parseOperand(stepVar) || parser.parseColonType(type))
409       return failure();
410 
411     vars.push_back(var);
412     types.push_back(type);
413     stepVars.push_back(stepVar);
414   } while (succeeded(parser.parseOptionalComma()));
415 
416   if (parser.parseRParen())
417     return failure();
418 
419   return success();
420 }
421 
422 /// schedule ::= `schedule` `(` sched-list `)`
423 /// sched-list ::= sched-val | sched-val sched-list
424 /// sched-val ::= sched-with-chunk | sched-wo-chunk
425 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
426 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
427 /// sched-wo-chunk ::=  `auto` | `runtime`
428 static ParseResult
parseScheduleClause(OpAsmParser & parser,SmallString<8> & schedule,Optional<OpAsmParser::OperandType> & chunkSize)429 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
430                     Optional<OpAsmParser::OperandType> &chunkSize) {
431   if (parser.parseLParen())
432     return failure();
433 
434   StringRef keyword;
435   if (parser.parseKeyword(&keyword))
436     return failure();
437 
438   schedule = keyword;
439   if (keyword == "static" || keyword == "dynamic" || keyword == "guided") {
440     if (succeeded(parser.parseOptionalEqual())) {
441       chunkSize = OpAsmParser::OperandType{};
442       if (parser.parseOperand(*chunkSize))
443         return failure();
444     } else {
445       chunkSize = llvm::NoneType::None;
446     }
447   } else if (keyword == "auto" || keyword == "runtime") {
448     chunkSize = llvm::NoneType::None;
449   } else {
450     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
451   }
452 
453   if (parser.parseRParen())
454     return failure();
455 
456   return success();
457 }
458 
459 /// reduction-init ::= `reduction` `(` reduction-entry-list `)`
460 /// reduction-entry-list ::= reduction-entry
461 ///                        | reduction-entry-list `,` reduction-entry
462 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
463 static ParseResult
parseReductionVarList(OpAsmParser & parser,SmallVectorImpl<SymbolRefAttr> & symbols,SmallVectorImpl<OpAsmParser::OperandType> & operands,SmallVectorImpl<Type> & types)464 parseReductionVarList(OpAsmParser &parser,
465                       SmallVectorImpl<SymbolRefAttr> &symbols,
466                       SmallVectorImpl<OpAsmParser::OperandType> &operands,
467                       SmallVectorImpl<Type> &types) {
468   if (failed(parser.parseLParen()))
469     return failure();
470 
471   do {
472     if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
473         parser.parseOperand(operands.emplace_back()) ||
474         parser.parseColonType(types.emplace_back()))
475       return failure();
476   } while (succeeded(parser.parseOptionalComma()));
477   return parser.parseRParen();
478 }
479 
480 /// Parses an OpenMP Workshare Loop operation
481 ///
482 /// operation ::= `omp.wsloop` loop-control clause-list
483 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
484 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps
485 /// steps := `step` `(`ssa-id-list`)`
486 /// clause-list ::= clause | empty | clause-list
487 /// clause ::= private | firstprivate | lastprivate | linear | schedule |
488 //             collapse | nowait | ordered | order | inclusive
489 /// private ::= `private` `(` ssa-id-and-type-list `)`
490 /// firstprivate ::= `firstprivate` `(` ssa-id-and-type-list `)`
491 /// lastprivate ::= `lastprivate` `(` ssa-id-and-type-list `)`
492 /// linear ::= `linear` `(` linear-list `)`
493 /// schedule ::= `schedule` `(` sched-list `)`
494 /// collapse ::= `collapse` `(` ssa-id-and-type `)`
495 /// nowait ::= `nowait`
496 /// ordered ::= `ordered` `(` ssa-id-and-type `)`
497 /// order ::= `order` `(` `concurrent` `)`
498 /// inclusive ::= `inclusive`
499 ///
parseWsLoopOp(OpAsmParser & parser,OperationState & result)500 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
501   Type loopVarType;
502   int numIVs;
503 
504   // Parse an opening `(` followed by induction variables followed by `)`
505   SmallVector<OpAsmParser::OperandType> ivs;
506   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
507                                      OpAsmParser::Delimiter::Paren))
508     return failure();
509 
510   numIVs = static_cast<int>(ivs.size());
511 
512   if (parser.parseColonType(loopVarType))
513     return failure();
514 
515   // Parse loop bounds.
516   SmallVector<OpAsmParser::OperandType> lower;
517   if (parser.parseEqual() ||
518       parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
519       parser.resolveOperands(lower, loopVarType, result.operands))
520     return failure();
521 
522   SmallVector<OpAsmParser::OperandType> upper;
523   if (parser.parseKeyword("to") ||
524       parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
525       parser.resolveOperands(upper, loopVarType, result.operands))
526     return failure();
527 
528   // Parse step values.
529   SmallVector<OpAsmParser::OperandType> steps;
530   if (parser.parseKeyword("step") ||
531       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
532       parser.resolveOperands(steps, loopVarType, result.operands))
533     return failure();
534 
535   SmallVector<OpAsmParser::OperandType> privates;
536   SmallVector<Type> privateTypes;
537   SmallVector<OpAsmParser::OperandType> firstprivates;
538   SmallVector<Type> firstprivateTypes;
539   SmallVector<OpAsmParser::OperandType> lastprivates;
540   SmallVector<Type> lastprivateTypes;
541   SmallVector<OpAsmParser::OperandType> linears;
542   SmallVector<Type> linearTypes;
543   SmallVector<OpAsmParser::OperandType> linearSteps;
544   SmallVector<SymbolRefAttr> reductionSymbols;
545   SmallVector<OpAsmParser::OperandType> reductionVars;
546   SmallVector<Type> reductionVarTypes;
547   SmallString<8> schedule;
548   Optional<OpAsmParser::OperandType> scheduleChunkSize;
549 
550   const StringRef opName = result.name.getStringRef();
551   StringRef keyword;
552 
553   enum SegmentPos {
554     lbPos = 0,
555     ubPos,
556     stepPos,
557     privateClausePos,
558     firstprivateClausePos,
559     lastprivateClausePos,
560     linearClausePos,
561     linearStepPos,
562     reductionVarPos,
563     scheduleClausePos,
564   };
565   std::array<int, 10> segments{numIVs, numIVs, numIVs, 0, 0, 0, 0, 0, 0, 0};
566 
567   while (succeeded(parser.parseOptionalKeyword(&keyword))) {
568     if (keyword == "private") {
569       if (segments[privateClausePos])
570         return allowedOnce(parser, "private", opName);
571       if (parseOperandAndTypeList(parser, privates, privateTypes))
572         return failure();
573       segments[privateClausePos] = privates.size();
574     } else if (keyword == "firstprivate") {
575       // fail if there was already another firstprivate clause
576       if (segments[firstprivateClausePos])
577         return allowedOnce(parser, "firstprivate", opName);
578       if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
579         return failure();
580       segments[firstprivateClausePos] = firstprivates.size();
581     } else if (keyword == "lastprivate") {
582       // fail if there was already another shared clause
583       if (segments[lastprivateClausePos])
584         return allowedOnce(parser, "lastprivate", opName);
585       if (parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
586         return failure();
587       segments[lastprivateClausePos] = lastprivates.size();
588     } else if (keyword == "linear") {
589       // fail if there was already another linear clause
590       if (segments[linearClausePos])
591         return allowedOnce(parser, "linear", opName);
592       if (parseLinearClause(parser, linears, linearTypes, linearSteps))
593         return failure();
594       segments[linearClausePos] = linears.size();
595       segments[linearStepPos] = linearSteps.size();
596     } else if (keyword == "schedule") {
597       if (!schedule.empty())
598         return allowedOnce(parser, "schedule", opName);
599       if (parseScheduleClause(parser, schedule, scheduleChunkSize))
600         return failure();
601       if (scheduleChunkSize) {
602         segments[scheduleClausePos] = 1;
603       }
604     } else if (keyword == "collapse") {
605       auto type = parser.getBuilder().getI64Type();
606       mlir::IntegerAttr attr;
607       if (parser.parseLParen() || parser.parseAttribute(attr, type) ||
608           parser.parseRParen())
609         return failure();
610       result.addAttribute("collapse_val", attr);
611     } else if (keyword == "nowait") {
612       auto attr = UnitAttr::get(parser.getBuilder().getContext());
613       result.addAttribute("nowait", attr);
614     } else if (keyword == "ordered") {
615       mlir::IntegerAttr attr;
616       if (succeeded(parser.parseOptionalLParen())) {
617         auto type = parser.getBuilder().getI64Type();
618         if (parser.parseAttribute(attr, type))
619           return failure();
620         if (parser.parseRParen())
621           return failure();
622       } else {
623         // Use 0 to represent no ordered parameter was specified
624         attr = parser.getBuilder().getI64IntegerAttr(0);
625       }
626       result.addAttribute("ordered_val", attr);
627     } else if (keyword == "order") {
628       StringRef order;
629       if (parser.parseLParen() || parser.parseKeyword(&order) ||
630           parser.parseRParen())
631         return failure();
632       auto attr = parser.getBuilder().getStringAttr(order);
633       result.addAttribute("order", attr);
634     } else if (keyword == "inclusive") {
635       auto attr = UnitAttr::get(parser.getBuilder().getContext());
636       result.addAttribute("inclusive", attr);
637     } else if (keyword == "reduction") {
638       if (segments[reductionVarPos])
639         return allowedOnce(parser, "reduction", opName);
640       if (failed(parseReductionVarList(parser, reductionSymbols, reductionVars,
641                                        reductionVarTypes)))
642         return failure();
643       segments[reductionVarPos] = reductionVars.size();
644     }
645   }
646 
647   if (segments[privateClausePos]) {
648     parser.resolveOperands(privates, privateTypes, privates[0].location,
649                            result.operands);
650   }
651 
652   if (segments[firstprivateClausePos]) {
653     parser.resolveOperands(firstprivates, firstprivateTypes,
654                            firstprivates[0].location, result.operands);
655   }
656 
657   if (segments[lastprivateClausePos]) {
658     parser.resolveOperands(lastprivates, lastprivateTypes,
659                            lastprivates[0].location, result.operands);
660   }
661 
662   if (segments[linearClausePos]) {
663     parser.resolveOperands(linears, linearTypes, linears[0].location,
664                            result.operands);
665     auto linearStepType = parser.getBuilder().getI32Type();
666     SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
667     parser.resolveOperands(linearSteps, linearStepTypes,
668                            linearSteps[0].location, result.operands);
669   }
670 
671   if (segments[reductionVarPos]) {
672     if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
673                                       parser.getNameLoc(), result.operands))) {
674       return failure();
675     }
676     SmallVector<Attribute> reductions(reductionSymbols.begin(),
677                                       reductionSymbols.end());
678     result.addAttribute("reductions",
679                         parser.getBuilder().getArrayAttr(reductions));
680   }
681 
682   if (!schedule.empty()) {
683     schedule[0] = llvm::toUpper(schedule[0]);
684     auto attr = parser.getBuilder().getStringAttr(schedule);
685     result.addAttribute("schedule_val", attr);
686     if (scheduleChunkSize) {
687       auto chunkSizeType = parser.getBuilder().getI32Type();
688       parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands);
689     }
690   }
691 
692   result.addAttribute("operand_segment_sizes",
693                       parser.getBuilder().getI32VectorAttr(segments));
694 
695   // Now parse the body.
696   Region *body = result.addRegion();
697   SmallVector<Type> ivTypes(numIVs, loopVarType);
698   SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
699   if (parser.parseRegion(*body, blockArgs, ivTypes))
700     return failure();
701   return success();
702 }
703 
printWsLoopOp(OpAsmPrinter & p,WsLoopOp op)704 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
705   auto args = op.getRegion().front().getArguments();
706   p << op.getOperationName() << " (" << args << ") : " << args[0].getType()
707     << " = (" << op.lowerBound() << ") to (" << op.upperBound() << ") step ("
708     << op.step() << ")";
709 
710   // Print private, firstprivate, shared and copyin parameters
711   auto printDataVars = [&p](StringRef name, OperandRange vars) {
712     if (vars.empty())
713       return;
714 
715     p << " " << name << "(";
716     llvm::interleaveComma(
717         vars, p, [&](const Value &v) { p << v << " : " << v.getType(); });
718     p << ")";
719   };
720   printDataVars("private", op.private_vars());
721   printDataVars("firstprivate", op.firstprivate_vars());
722   printDataVars("lastprivate", op.lastprivate_vars());
723 
724   auto linearVars = op.linear_vars();
725   auto linearVarsSize = linearVars.size();
726   if (linearVarsSize) {
727     p << " "
728       << "linear"
729       << "(";
730     for (unsigned i = 0; i < linearVarsSize; ++i) {
731       std::string separator = i == linearVarsSize - 1 ? ")" : ", ";
732       p << linearVars[i];
733       if (op.linear_step_vars().size() > i)
734         p << " = " << op.linear_step_vars()[i];
735       p << " : " << linearVars[i].getType() << separator;
736     }
737   }
738 
739   if (auto sched = op.schedule_val()) {
740     auto schedLower = sched->lower();
741     p << " schedule(" << schedLower;
742     if (auto chunk = op.schedule_chunk_var()) {
743       p << " = " << chunk;
744     }
745     p << ")";
746   }
747 
748   if (auto collapse = op.collapse_val())
749     p << " collapse(" << collapse << ")";
750 
751   if (op.nowait())
752     p << " nowait";
753 
754   if (auto ordered = op.ordered_val()) {
755     p << " ordered(" << ordered << ")";
756   }
757 
758   if (!op.reduction_vars().empty()) {
759     p << " reduction(";
760     for (unsigned i = 0, e = op.getNumReductionVars(); i < e; ++i) {
761       if (i != 0)
762         p << ", ";
763       p << (*op.reductions())[i] << " -> " << op.reduction_vars()[i] << " : "
764         << op.reduction_vars()[i].getType();
765     }
766     p << ")";
767   }
768 
769   if (op.inclusive()) {
770     p << " inclusive";
771   }
772 
773   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
774 }
775 
776 //===----------------------------------------------------------------------===//
777 // ReductionOp
778 //===----------------------------------------------------------------------===//
779 
parseAtomicReductionRegion(OpAsmParser & parser,Region & region)780 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
781                                               Region &region) {
782   if (parser.parseOptionalKeyword("atomic"))
783     return success();
784   return parser.parseRegion(region);
785 }
786 
printAtomicReductionRegion(OpAsmPrinter & printer,ReductionDeclareOp op,Region & region)787 static void printAtomicReductionRegion(OpAsmPrinter &printer,
788                                        ReductionDeclareOp op, Region &region) {
789   if (region.empty())
790     return;
791   printer << "atomic ";
792   printer.printRegion(region);
793 }
794 
verifyReductionDeclareOp(ReductionDeclareOp op)795 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) {
796   if (op.initializerRegion().empty())
797     return op.emitOpError() << "expects non-empty initializer region";
798   Block &initializerEntryBlock = op.initializerRegion().front();
799   if (initializerEntryBlock.getNumArguments() != 1 ||
800       initializerEntryBlock.getArgument(0).getType() != op.type()) {
801     return op.emitOpError() << "expects initializer region with one argument "
802                                "of the reduction type";
803   }
804 
805   for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) {
806     if (yieldOp.results().size() != 1 ||
807         yieldOp.results().getTypes()[0] != op.type())
808       return op.emitOpError() << "expects initializer region to yield a value "
809                                  "of the reduction type";
810   }
811 
812   if (op.reductionRegion().empty())
813     return op.emitOpError() << "expects non-empty reduction region";
814   Block &reductionEntryBlock = op.reductionRegion().front();
815   if (reductionEntryBlock.getNumArguments() != 2 ||
816       reductionEntryBlock.getArgumentTypes()[0] !=
817           reductionEntryBlock.getArgumentTypes()[1] ||
818       reductionEntryBlock.getArgumentTypes()[0] != op.type())
819     return op.emitOpError() << "expects reduction region with two arguments of "
820                                "the reduction type";
821   for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) {
822     if (yieldOp.results().size() != 1 ||
823         yieldOp.results().getTypes()[0] != op.type())
824       return op.emitOpError() << "expects reduction region to yield a value "
825                                  "of the reduction type";
826   }
827 
828   if (op.atomicReductionRegion().empty())
829     return success();
830 
831   Block &atomicReductionEntryBlock = op.atomicReductionRegion().front();
832   if (atomicReductionEntryBlock.getNumArguments() != 2 ||
833       atomicReductionEntryBlock.getArgumentTypes()[0] !=
834           atomicReductionEntryBlock.getArgumentTypes()[1])
835     return op.emitOpError() << "expects atomic reduction region with two "
836                                "arguments of the same type";
837   auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
838                      .dyn_cast<PointerLikeType>();
839   if (!ptrType || ptrType.getElementType() != op.type())
840     return op.emitOpError() << "expects atomic reduction region arguments to "
841                                "be accumulators containing the reduction type";
842   return success();
843 }
844 
verifyReductionOp(ReductionOp op)845 static LogicalResult verifyReductionOp(ReductionOp op) {
846   // TODO: generalize this to an op interface when there is more than one op
847   // that supports reductions.
848   auto container = op->getParentOfType<WsLoopOp>();
849   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
850     if (container.reduction_vars()[i] == op.accumulator())
851       return success();
852 
853   return op.emitOpError() << "the accumulator is not used by the parent";
854 }
855 
856 //===----------------------------------------------------------------------===//
857 // WsLoopOp
858 //===----------------------------------------------------------------------===//
859 
build(OpBuilder & builder,OperationState & state,ValueRange lowerBound,ValueRange upperBound,ValueRange step,ArrayRef<NamedAttribute> attributes)860 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
861                      ValueRange lowerBound, ValueRange upperBound,
862                      ValueRange step, ArrayRef<NamedAttribute> attributes) {
863   build(builder, state, TypeRange(), lowerBound, upperBound, step,
864         /*private_vars=*/ValueRange(),
865         /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
866         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
867         /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
868         /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
869         /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
870         /*inclusive=*/nullptr, /*buildBody=*/false);
871   state.addAttributes(attributes);
872 }
873 
build(OpBuilder &,OperationState & state,TypeRange resultTypes,ValueRange operands,ArrayRef<NamedAttribute> attributes)874 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
875                      ValueRange operands, ArrayRef<NamedAttribute> attributes) {
876   state.addOperands(operands);
877   state.addAttributes(attributes);
878   (void)state.addRegion();
879   assert(resultTypes.empty() && "mismatched number of return types");
880   state.addTypes(resultTypes);
881 }
882 
build(OpBuilder & builder,OperationState & result,TypeRange typeRange,ValueRange lowerBounds,ValueRange upperBounds,ValueRange steps,ValueRange privateVars,ValueRange firstprivateVars,ValueRange lastprivateVars,ValueRange linearVars,ValueRange linearStepVars,ValueRange reductionVars,StringAttr scheduleVal,Value scheduleChunkVar,IntegerAttr collapseVal,UnitAttr nowait,IntegerAttr orderedVal,StringAttr orderVal,UnitAttr inclusive,bool buildBody)883 void WsLoopOp::build(OpBuilder &builder, OperationState &result,
884                      TypeRange typeRange, ValueRange lowerBounds,
885                      ValueRange upperBounds, ValueRange steps,
886                      ValueRange privateVars, ValueRange firstprivateVars,
887                      ValueRange lastprivateVars, ValueRange linearVars,
888                      ValueRange linearStepVars, ValueRange reductionVars,
889                      StringAttr scheduleVal, Value scheduleChunkVar,
890                      IntegerAttr collapseVal, UnitAttr nowait,
891                      IntegerAttr orderedVal, StringAttr orderVal,
892                      UnitAttr inclusive, bool buildBody) {
893   result.addOperands(lowerBounds);
894   result.addOperands(upperBounds);
895   result.addOperands(steps);
896   result.addOperands(privateVars);
897   result.addOperands(firstprivateVars);
898   result.addOperands(linearVars);
899   result.addOperands(linearStepVars);
900   if (scheduleChunkVar)
901     result.addOperands(scheduleChunkVar);
902 
903   if (scheduleVal)
904     result.addAttribute("schedule_val", scheduleVal);
905   if (collapseVal)
906     result.addAttribute("collapse_val", collapseVal);
907   if (nowait)
908     result.addAttribute("nowait", nowait);
909   if (orderedVal)
910     result.addAttribute("ordered_val", orderedVal);
911   if (orderVal)
912     result.addAttribute("order", orderVal);
913   if (inclusive)
914     result.addAttribute("inclusive", inclusive);
915   result.addAttribute(
916       WsLoopOp::getOperandSegmentSizeAttr(),
917       builder.getI32VectorAttr(
918           {static_cast<int32_t>(lowerBounds.size()),
919            static_cast<int32_t>(upperBounds.size()),
920            static_cast<int32_t>(steps.size()),
921            static_cast<int32_t>(privateVars.size()),
922            static_cast<int32_t>(firstprivateVars.size()),
923            static_cast<int32_t>(lastprivateVars.size()),
924            static_cast<int32_t>(linearVars.size()),
925            static_cast<int32_t>(linearStepVars.size()),
926            static_cast<int32_t>(reductionVars.size()),
927            static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
928 
929   Region *bodyRegion = result.addRegion();
930   if (buildBody) {
931     OpBuilder::InsertionGuard guard(builder);
932     unsigned numIVs = steps.size();
933     SmallVector<Type, 8> argTypes(numIVs, steps.getType().front());
934     builder.createBlock(bodyRegion, {}, argTypes);
935   }
936 }
937 
verifyWsLoopOp(WsLoopOp op)938 static LogicalResult verifyWsLoopOp(WsLoopOp op) {
939   if (op.getNumReductionVars() != 0) {
940     if (!op.reductions() ||
941         op.reductions()->size() != op.getNumReductionVars()) {
942       return op.emitOpError() << "expected as many reduction symbol references "
943                                  "as reduction variables";
944     }
945   } else {
946     if (op.reductions())
947       return op.emitOpError() << "unexpected reduction symbol references";
948     return success();
949   }
950 
951   DenseSet<Value> accumulators;
952   for (auto args : llvm::zip(op.reduction_vars(), *op.reductions())) {
953     Value accum = std::get<0>(args);
954     if (!accumulators.insert(accum).second) {
955       return op.emitOpError() << "accumulator variable used more than once";
956     }
957     Type varType = accum.getType().cast<PointerLikeType>();
958     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
959     auto decl =
960         SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
961     if (!decl) {
962       return op.emitOpError() << "expected symbol reference " << symbolRef
963                               << " to point to a reduction declaration";
964     }
965 
966     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) {
967       return op.emitOpError()
968              << "expected accumulator (" << varType
969              << ") to be the same type as reduction declaration ("
970              << decl.getAccumulatorType() << ")";
971     }
972   }
973 
974   return success();
975 }
976 
977 #define GET_OP_CLASSES
978 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
979