1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include <array>
#include <cstddef>
25
26 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
27 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
28 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
29
30 using namespace mlir;
31 using namespace mlir::omp;
32
33 namespace {
34 /// Model for pointer-like types that already provide a `getElementType` method.
35 template <typename T>
36 struct PointerLikeModel
37 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
getElementType__anon959c6e340111::PointerLikeModel38 Type getElementType(Type pointer) const {
39 return pointer.cast<T>().getElementType();
40 }
41 };
42 } // end namespace
43
initialize()44 void OpenMPDialect::initialize() {
45 addOperations<
46 #define GET_OP_LIST
47 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
48 >();
49
50 LLVM::LLVMPointerType::attachInterface<
51 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
52 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
53 }
54
55 //===----------------------------------------------------------------------===//
56 // ParallelOp
57 //===----------------------------------------------------------------------===//
58
build(OpBuilder & builder,OperationState & state,ArrayRef<NamedAttribute> attributes)59 void ParallelOp::build(OpBuilder &builder, OperationState &state,
60 ArrayRef<NamedAttribute> attributes) {
61 ParallelOp::build(
62 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
63 /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
64 /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
65 /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
66 /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
67 state.addAttributes(attributes);
68 }
69
70 /// Parse a list of operands with types.
71 ///
72 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
73 /// ssa-id-and-type-list ::= ssa-id-and-type |
74 /// ssa-id-and-type `,` ssa-id-and-type-list
75 /// ssa-id-and-type ::= ssa-id `:` type
76 static ParseResult
parseOperandAndTypeList(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::OperandType> & operands,SmallVectorImpl<Type> & types)77 parseOperandAndTypeList(OpAsmParser &parser,
78 SmallVectorImpl<OpAsmParser::OperandType> &operands,
79 SmallVectorImpl<Type> &types) {
80 return parser.parseCommaSeparatedList(
81 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
82 OpAsmParser::OperandType operand;
83 Type type;
84 if (parser.parseOperand(operand) || parser.parseColonType(type))
85 return failure();
86 operands.push_back(operand);
87 types.push_back(type);
88 return success();
89 });
90 }
91
92 /// Parse an allocate clause with allocators and a list of operands with types.
93 ///
94 /// operand-and-type-list ::= `(` allocate-operand-list `)`
95 /// allocate-operand-list :: = allocate-operand |
96 /// allocator-operand `,` allocate-operand-list
97 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
98 /// ssa-id-and-type ::= ssa-id `:` type
parseAllocateAndAllocator(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::OperandType> & operandsAllocate,SmallVectorImpl<Type> & typesAllocate,SmallVectorImpl<OpAsmParser::OperandType> & operandsAllocator,SmallVectorImpl<Type> & typesAllocator)99 static ParseResult parseAllocateAndAllocator(
100 OpAsmParser &parser,
101 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
102 SmallVectorImpl<Type> &typesAllocate,
103 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
104 SmallVectorImpl<Type> &typesAllocator) {
105
106 return parser.parseCommaSeparatedList(
107 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
108 OpAsmParser::OperandType operand;
109 Type type;
110 if (parser.parseOperand(operand) || parser.parseColonType(type))
111 return failure();
112 operandsAllocator.push_back(operand);
113 typesAllocator.push_back(type);
114 if (parser.parseArrow())
115 return failure();
116 if (parser.parseOperand(operand) || parser.parseColonType(type))
117 return failure();
118
119 operandsAllocate.push_back(operand);
120 typesAllocate.push_back(type);
121 return success();
122 });
123 }
124
verifyParallelOp(ParallelOp op)125 static LogicalResult verifyParallelOp(ParallelOp op) {
126 if (op.allocate_vars().size() != op.allocators_vars().size())
127 return op.emitError(
128 "expected equal sizes for allocate and allocator variables");
129 return success();
130 }
131
printParallelOp(OpAsmPrinter & p,ParallelOp op)132 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
133 if (auto ifCond = op.if_expr_var())
134 p << " if(" << ifCond << " : " << ifCond.getType() << ")";
135
136 if (auto threads = op.num_threads_var())
137 p << " num_threads(" << threads << " : " << threads.getType() << ")";
138
139 // Print private, firstprivate, shared and copyin parameters
140 auto printDataVars = [&p](StringRef name, OperandRange vars) {
141 if (vars.size()) {
142 p << " " << name << "(";
143 for (unsigned i = 0; i < vars.size(); ++i) {
144 std::string separator = i == vars.size() - 1 ? ")" : ", ";
145 p << vars[i] << " : " << vars[i].getType() << separator;
146 }
147 }
148 };
149
150 // Print allocator and allocate parameters
151 auto printAllocateAndAllocator = [&p](OperandRange varsAllocate,
152 OperandRange varsAllocator) {
153 if (varsAllocate.empty())
154 return;
155
156 p << " allocate(";
157 for (unsigned i = 0; i < varsAllocate.size(); ++i) {
158 std::string separator = i == varsAllocate.size() - 1 ? ")" : ", ";
159 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
160 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
161 }
162 };
163
164 printDataVars("private", op.private_vars());
165 printDataVars("firstprivate", op.firstprivate_vars());
166 printDataVars("shared", op.shared_vars());
167 printDataVars("copyin", op.copyin_vars());
168 printAllocateAndAllocator(op.allocate_vars(), op.allocators_vars());
169
170 if (auto def = op.default_val())
171 p << " default(" << def->drop_front(3) << ")";
172
173 if (auto bind = op.proc_bind_val())
174 p << " proc_bind(" << bind << ")";
175
176 p.printRegion(op.getRegion());
177 }
178
179 /// Emit an error if the same clause is present more than once on an operation.
allowedOnce(OpAsmParser & parser,StringRef clause,StringRef operation)180 static ParseResult allowedOnce(OpAsmParser &parser, StringRef clause,
181 StringRef operation) {
182 return parser.emitError(parser.getNameLoc())
183 << " at most one " << clause << " clause can appear on the "
184 << operation << " operation";
185 }
186
187 /// Parses a parallel operation.
188 ///
189 /// operation ::= `omp.parallel` clause-list
190 /// clause-list ::= clause | clause clause-list
191 /// clause ::= if | numThreads | private | firstprivate | shared | copyin |
192 /// default | procBind
193 /// if ::= `if` `(` ssa-id `)`
194 /// numThreads ::= `num_threads` `(` ssa-id-and-type `)`
195 /// private ::= `private` operand-and-type-list
196 /// firstprivate ::= `firstprivate` operand-and-type-list
197 /// shared ::= `shared` operand-and-type-list
198 /// copyin ::= `copyin` operand-and-type-list
199 /// allocate ::= `allocate` operand-and-type `->` operand-and-type-list
200 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`)
201 /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
202 ///
203 /// Note that each clause can only appear once in the clase-list.
parseParallelOp(OpAsmParser & parser,OperationState & result)204 static ParseResult parseParallelOp(OpAsmParser &parser,
205 OperationState &result) {
206 std::pair<OpAsmParser::OperandType, Type> ifCond;
207 std::pair<OpAsmParser::OperandType, Type> numThreads;
208 SmallVector<OpAsmParser::OperandType, 4> privates;
209 SmallVector<Type, 4> privateTypes;
210 SmallVector<OpAsmParser::OperandType, 4> firstprivates;
211 SmallVector<Type, 4> firstprivateTypes;
212 SmallVector<OpAsmParser::OperandType, 4> shareds;
213 SmallVector<Type, 4> sharedTypes;
214 SmallVector<OpAsmParser::OperandType, 4> copyins;
215 SmallVector<Type, 4> copyinTypes;
216 SmallVector<OpAsmParser::OperandType, 4> allocates;
217 SmallVector<Type, 4> allocateTypes;
218 SmallVector<OpAsmParser::OperandType, 4> allocators;
219 SmallVector<Type, 4> allocatorTypes;
220 std::array<int, 8> segments{0, 0, 0, 0, 0, 0, 0, 0};
221 StringRef keyword;
222 bool defaultVal = false;
223 bool procBind = false;
224
225 const int ifClausePos = 0;
226 const int numThreadsClausePos = 1;
227 const int privateClausePos = 2;
228 const int firstprivateClausePos = 3;
229 const int sharedClausePos = 4;
230 const int copyinClausePos = 5;
231 const int allocateClausePos = 6;
232 const int allocatorPos = 7;
233 const StringRef opName = result.name.getStringRef();
234
235 while (succeeded(parser.parseOptionalKeyword(&keyword))) {
236 if (keyword == "if") {
237 // Fail if there was already another if condition.
238 if (segments[ifClausePos])
239 return allowedOnce(parser, "if", opName);
240 if (parser.parseLParen() || parser.parseOperand(ifCond.first) ||
241 parser.parseColonType(ifCond.second) || parser.parseRParen())
242 return failure();
243 segments[ifClausePos] = 1;
244 } else if (keyword == "num_threads") {
245 // Fail if there was already another num_threads clause.
246 if (segments[numThreadsClausePos])
247 return allowedOnce(parser, "num_threads", opName);
248 if (parser.parseLParen() || parser.parseOperand(numThreads.first) ||
249 parser.parseColonType(numThreads.second) || parser.parseRParen())
250 return failure();
251 segments[numThreadsClausePos] = 1;
252 } else if (keyword == "private") {
253 // Fail if there was already another private clause.
254 if (segments[privateClausePos])
255 return allowedOnce(parser, "private", opName);
256 if (parseOperandAndTypeList(parser, privates, privateTypes))
257 return failure();
258 segments[privateClausePos] = privates.size();
259 } else if (keyword == "firstprivate") {
260 // Fail if there was already another firstprivate clause.
261 if (segments[firstprivateClausePos])
262 return allowedOnce(parser, "firstprivate", opName);
263 if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
264 return failure();
265 segments[firstprivateClausePos] = firstprivates.size();
266 } else if (keyword == "shared") {
267 // Fail if there was already another shared clause.
268 if (segments[sharedClausePos])
269 return allowedOnce(parser, "shared", opName);
270 if (parseOperandAndTypeList(parser, shareds, sharedTypes))
271 return failure();
272 segments[sharedClausePos] = shareds.size();
273 } else if (keyword == "copyin") {
274 // Fail if there was already another copyin clause.
275 if (segments[copyinClausePos])
276 return allowedOnce(parser, "copyin", opName);
277 if (parseOperandAndTypeList(parser, copyins, copyinTypes))
278 return failure();
279 segments[copyinClausePos] = copyins.size();
280 } else if (keyword == "allocate") {
281 // Fail if there was already another allocate clause.
282 if (segments[allocateClausePos])
283 return allowedOnce(parser, "allocate", opName);
284 if (parseAllocateAndAllocator(parser, allocates, allocateTypes,
285 allocators, allocatorTypes))
286 return failure();
287 segments[allocateClausePos] = allocates.size();
288 segments[allocatorPos] = allocators.size();
289 } else if (keyword == "default") {
290 // Fail if there was already another default clause.
291 if (defaultVal)
292 return allowedOnce(parser, "default", opName);
293 defaultVal = true;
294 StringRef defval;
295 if (parser.parseLParen() || parser.parseKeyword(&defval) ||
296 parser.parseRParen())
297 return failure();
298 // The def prefix is required for the attribute as "private" is a keyword
299 // in C++.
300 auto attr = parser.getBuilder().getStringAttr("def" + defval);
301 result.addAttribute("default_val", attr);
302 } else if (keyword == "proc_bind") {
303 // Fail if there was already another proc_bind clause.
304 if (procBind)
305 return allowedOnce(parser, "proc_bind", opName);
306 procBind = true;
307 StringRef bind;
308 if (parser.parseLParen() || parser.parseKeyword(&bind) ||
309 parser.parseRParen())
310 return failure();
311 auto attr = parser.getBuilder().getStringAttr(bind);
312 result.addAttribute("proc_bind_val", attr);
313 } else {
314 return parser.emitError(parser.getNameLoc())
315 << keyword << " is not a valid clause for the " << opName
316 << " operation";
317 }
318 }
319
320 // Add if parameter.
321 if (segments[ifClausePos] &&
322 parser.resolveOperand(ifCond.first, ifCond.second, result.operands))
323 return failure();
324
325 // Add num_threads parameter.
326 if (segments[numThreadsClausePos] &&
327 parser.resolveOperand(numThreads.first, numThreads.second,
328 result.operands))
329 return failure();
330
331 // Add private parameters.
332 if (segments[privateClausePos] &&
333 parser.resolveOperands(privates, privateTypes, privates[0].location,
334 result.operands))
335 return failure();
336
337 // Add firstprivate parameters.
338 if (segments[firstprivateClausePos] &&
339 parser.resolveOperands(firstprivates, firstprivateTypes,
340 firstprivates[0].location, result.operands))
341 return failure();
342
343 // Add shared parameters.
344 if (segments[sharedClausePos] &&
345 parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
346 result.operands))
347 return failure();
348
349 // Add copyin parameters.
350 if (segments[copyinClausePos] &&
351 parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
352 result.operands))
353 return failure();
354
355 // Add allocate parameters.
356 if (segments[allocateClausePos] &&
357 parser.resolveOperands(allocates, allocateTypes, allocates[0].location,
358 result.operands))
359 return failure();
360
361 // Add allocator parameters.
362 if (segments[allocatorPos] &&
363 parser.resolveOperands(allocators, allocatorTypes, allocators[0].location,
364 result.operands))
365 return failure();
366
367 result.addAttribute("operand_segment_sizes",
368 parser.getBuilder().getI32VectorAttr(segments));
369
370 Region *body = result.addRegion();
371 SmallVector<OpAsmParser::OperandType, 4> regionArgs;
372 SmallVector<Type, 4> regionArgTypes;
373 if (parser.parseRegion(*body, regionArgs, regionArgTypes))
374 return failure();
375 return success();
376 }
377
378 /// linear ::= `linear` `(` linear-list `)`
379 /// linear-list := linear-val | linear-val linear-list
380 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
381 static ParseResult
parseLinearClause(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::OperandType> & vars,SmallVectorImpl<Type> & types,SmallVectorImpl<OpAsmParser::OperandType> & stepVars)382 parseLinearClause(OpAsmParser &parser,
383 SmallVectorImpl<OpAsmParser::OperandType> &vars,
384 SmallVectorImpl<Type> &types,
385 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
386 if (parser.parseLParen())
387 return failure();
388
389 do {
390 OpAsmParser::OperandType var;
391 Type type;
392 OpAsmParser::OperandType stepVar;
393 if (parser.parseOperand(var) || parser.parseEqual() ||
394 parser.parseOperand(stepVar) || parser.parseColonType(type))
395 return failure();
396
397 vars.push_back(var);
398 types.push_back(type);
399 stepVars.push_back(stepVar);
400 } while (succeeded(parser.parseOptionalComma()));
401
402 if (parser.parseRParen())
403 return failure();
404
405 return success();
406 }
407
408 /// schedule ::= `schedule` `(` sched-list `)`
409 /// sched-list ::= sched-val | sched-val sched-list
410 /// sched-val ::= sched-with-chunk | sched-wo-chunk
411 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
412 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
413 /// sched-wo-chunk ::= `auto` | `runtime`
414 static ParseResult
parseScheduleClause(OpAsmParser & parser,SmallString<8> & schedule,Optional<OpAsmParser::OperandType> & chunkSize)415 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
416 Optional<OpAsmParser::OperandType> &chunkSize) {
417 if (parser.parseLParen())
418 return failure();
419
420 StringRef keyword;
421 if (parser.parseKeyword(&keyword))
422 return failure();
423
424 schedule = keyword;
425 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") {
426 if (succeeded(parser.parseOptionalEqual())) {
427 chunkSize = OpAsmParser::OperandType{};
428 if (parser.parseOperand(*chunkSize))
429 return failure();
430 } else {
431 chunkSize = llvm::NoneType::None;
432 }
433 } else if (keyword == "auto" || keyword == "runtime") {
434 chunkSize = llvm::NoneType::None;
435 } else {
436 return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
437 }
438
439 if (parser.parseRParen())
440 return failure();
441
442 return success();
443 }
444
445 /// reduction-init ::= `reduction` `(` reduction-entry-list `)`
446 /// reduction-entry-list ::= reduction-entry
447 /// | reduction-entry-list `,` reduction-entry
448 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
449 static ParseResult
parseReductionVarList(OpAsmParser & parser,SmallVectorImpl<SymbolRefAttr> & symbols,SmallVectorImpl<OpAsmParser::OperandType> & operands,SmallVectorImpl<Type> & types)450 parseReductionVarList(OpAsmParser &parser,
451 SmallVectorImpl<SymbolRefAttr> &symbols,
452 SmallVectorImpl<OpAsmParser::OperandType> &operands,
453 SmallVectorImpl<Type> &types) {
454 if (failed(parser.parseLParen()))
455 return failure();
456
457 do {
458 if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
459 parser.parseOperand(operands.emplace_back()) ||
460 parser.parseColonType(types.emplace_back()))
461 return failure();
462 } while (succeeded(parser.parseOptionalComma()));
463 return parser.parseRParen();
464 }
465
466 /// Parses an OpenMP Workshare Loop operation
467 ///
468 /// operation ::= `omp.wsloop` loop-control clause-list
469 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds
470 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps
471 /// steps := `step` `(`ssa-id-list`)`
472 /// clause-list ::= clause | empty | clause-list
473 /// clause ::= private | firstprivate | lastprivate | linear | schedule |
474 // collapse | nowait | ordered | order | inclusive
475 /// private ::= `private` `(` ssa-id-and-type-list `)`
476 /// firstprivate ::= `firstprivate` `(` ssa-id-and-type-list `)`
477 /// lastprivate ::= `lastprivate` `(` ssa-id-and-type-list `)`
478 /// linear ::= `linear` `(` linear-list `)`
479 /// schedule ::= `schedule` `(` sched-list `)`
480 /// collapse ::= `collapse` `(` ssa-id-and-type `)`
481 /// nowait ::= `nowait`
482 /// ordered ::= `ordered` `(` ssa-id-and-type `)`
483 /// order ::= `order` `(` `concurrent` `)`
484 /// inclusive ::= `inclusive`
485 ///
parseWsLoopOp(OpAsmParser & parser,OperationState & result)486 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
487 Type loopVarType;
488 int numIVs;
489
490 // Parse an opening `(` followed by induction variables followed by `)`
491 SmallVector<OpAsmParser::OperandType> ivs;
492 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
493 OpAsmParser::Delimiter::Paren))
494 return failure();
495
496 numIVs = static_cast<int>(ivs.size());
497
498 if (parser.parseColonType(loopVarType))
499 return failure();
500
501 // Parse loop bounds.
502 SmallVector<OpAsmParser::OperandType> lower;
503 if (parser.parseEqual() ||
504 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
505 parser.resolveOperands(lower, loopVarType, result.operands))
506 return failure();
507
508 SmallVector<OpAsmParser::OperandType> upper;
509 if (parser.parseKeyword("to") ||
510 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
511 parser.resolveOperands(upper, loopVarType, result.operands))
512 return failure();
513
514 // Parse step values.
515 SmallVector<OpAsmParser::OperandType> steps;
516 if (parser.parseKeyword("step") ||
517 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
518 parser.resolveOperands(steps, loopVarType, result.operands))
519 return failure();
520
521 SmallVector<OpAsmParser::OperandType> privates;
522 SmallVector<Type> privateTypes;
523 SmallVector<OpAsmParser::OperandType> firstprivates;
524 SmallVector<Type> firstprivateTypes;
525 SmallVector<OpAsmParser::OperandType> lastprivates;
526 SmallVector<Type> lastprivateTypes;
527 SmallVector<OpAsmParser::OperandType> linears;
528 SmallVector<Type> linearTypes;
529 SmallVector<OpAsmParser::OperandType> linearSteps;
530 SmallVector<SymbolRefAttr> reductionSymbols;
531 SmallVector<OpAsmParser::OperandType> reductionVars;
532 SmallVector<Type> reductionVarTypes;
533 SmallString<8> schedule;
534 Optional<OpAsmParser::OperandType> scheduleChunkSize;
535
536 const StringRef opName = result.name.getStringRef();
537 StringRef keyword;
538
539 enum SegmentPos {
540 lbPos = 0,
541 ubPos,
542 stepPos,
543 privateClausePos,
544 firstprivateClausePos,
545 lastprivateClausePos,
546 linearClausePos,
547 linearStepPos,
548 reductionVarPos,
549 scheduleClausePos,
550 };
551 std::array<int, 10> segments{numIVs, numIVs, numIVs, 0, 0, 0, 0, 0, 0, 0};
552
553 while (succeeded(parser.parseOptionalKeyword(&keyword))) {
554 if (keyword == "private") {
555 if (segments[privateClausePos])
556 return allowedOnce(parser, "private", opName);
557 if (parseOperandAndTypeList(parser, privates, privateTypes))
558 return failure();
559 segments[privateClausePos] = privates.size();
560 } else if (keyword == "firstprivate") {
561 // fail if there was already another firstprivate clause
562 if (segments[firstprivateClausePos])
563 return allowedOnce(parser, "firstprivate", opName);
564 if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
565 return failure();
566 segments[firstprivateClausePos] = firstprivates.size();
567 } else if (keyword == "lastprivate") {
568 // fail if there was already another shared clause
569 if (segments[lastprivateClausePos])
570 return allowedOnce(parser, "lastprivate", opName);
571 if (parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
572 return failure();
573 segments[lastprivateClausePos] = lastprivates.size();
574 } else if (keyword == "linear") {
575 // fail if there was already another linear clause
576 if (segments[linearClausePos])
577 return allowedOnce(parser, "linear", opName);
578 if (parseLinearClause(parser, linears, linearTypes, linearSteps))
579 return failure();
580 segments[linearClausePos] = linears.size();
581 segments[linearStepPos] = linearSteps.size();
582 } else if (keyword == "schedule") {
583 if (!schedule.empty())
584 return allowedOnce(parser, "schedule", opName);
585 if (parseScheduleClause(parser, schedule, scheduleChunkSize))
586 return failure();
587 if (scheduleChunkSize) {
588 segments[scheduleClausePos] = 1;
589 }
590 } else if (keyword == "collapse") {
591 auto type = parser.getBuilder().getI64Type();
592 mlir::IntegerAttr attr;
593 if (parser.parseLParen() || parser.parseAttribute(attr, type) ||
594 parser.parseRParen())
595 return failure();
596 result.addAttribute("collapse_val", attr);
597 } else if (keyword == "nowait") {
598 auto attr = UnitAttr::get(parser.getContext());
599 result.addAttribute("nowait", attr);
600 } else if (keyword == "ordered") {
601 mlir::IntegerAttr attr;
602 if (succeeded(parser.parseOptionalLParen())) {
603 auto type = parser.getBuilder().getI64Type();
604 if (parser.parseAttribute(attr, type))
605 return failure();
606 if (parser.parseRParen())
607 return failure();
608 } else {
609 // Use 0 to represent no ordered parameter was specified
610 attr = parser.getBuilder().getI64IntegerAttr(0);
611 }
612 result.addAttribute("ordered_val", attr);
613 } else if (keyword == "order") {
614 StringRef order;
615 if (parser.parseLParen() || parser.parseKeyword(&order) ||
616 parser.parseRParen())
617 return failure();
618 auto attr = parser.getBuilder().getStringAttr(order);
619 result.addAttribute("order", attr);
620 } else if (keyword == "inclusive") {
621 auto attr = UnitAttr::get(parser.getContext());
622 result.addAttribute("inclusive", attr);
623 } else if (keyword == "reduction") {
624 if (segments[reductionVarPos])
625 return allowedOnce(parser, "reduction", opName);
626 if (failed(parseReductionVarList(parser, reductionSymbols, reductionVars,
627 reductionVarTypes)))
628 return failure();
629 segments[reductionVarPos] = reductionVars.size();
630 }
631 }
632
633 if (segments[privateClausePos]) {
634 parser.resolveOperands(privates, privateTypes, privates[0].location,
635 result.operands);
636 }
637
638 if (segments[firstprivateClausePos]) {
639 parser.resolveOperands(firstprivates, firstprivateTypes,
640 firstprivates[0].location, result.operands);
641 }
642
643 if (segments[lastprivateClausePos]) {
644 parser.resolveOperands(lastprivates, lastprivateTypes,
645 lastprivates[0].location, result.operands);
646 }
647
648 if (segments[linearClausePos]) {
649 parser.resolveOperands(linears, linearTypes, linears[0].location,
650 result.operands);
651 auto linearStepType = parser.getBuilder().getI32Type();
652 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
653 parser.resolveOperands(linearSteps, linearStepTypes,
654 linearSteps[0].location, result.operands);
655 }
656
657 if (segments[reductionVarPos]) {
658 if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
659 parser.getNameLoc(), result.operands))) {
660 return failure();
661 }
662 SmallVector<Attribute> reductions(reductionSymbols.begin(),
663 reductionSymbols.end());
664 result.addAttribute("reductions",
665 parser.getBuilder().getArrayAttr(reductions));
666 }
667
668 if (!schedule.empty()) {
669 schedule[0] = llvm::toUpper(schedule[0]);
670 auto attr = parser.getBuilder().getStringAttr(schedule);
671 result.addAttribute("schedule_val", attr);
672 if (scheduleChunkSize) {
673 auto chunkSizeType = parser.getBuilder().getI32Type();
674 parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands);
675 }
676 }
677
678 result.addAttribute("operand_segment_sizes",
679 parser.getBuilder().getI32VectorAttr(segments));
680
681 // Now parse the body.
682 Region *body = result.addRegion();
683 SmallVector<Type> ivTypes(numIVs, loopVarType);
684 SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
685 if (parser.parseRegion(*body, blockArgs, ivTypes))
686 return failure();
687 return success();
688 }
689
printWsLoopOp(OpAsmPrinter & p,WsLoopOp op)690 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
691 auto args = op.getRegion().front().getArguments();
692 p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound()
693 << ") to (" << op.upperBound() << ") step (" << op.step() << ")";
694
695 // Print private, firstprivate, shared and copyin parameters
696 auto printDataVars = [&p](StringRef name, OperandRange vars) {
697 if (vars.empty())
698 return;
699
700 p << " " << name << "(";
701 llvm::interleaveComma(
702 vars, p, [&](const Value &v) { p << v << " : " << v.getType(); });
703 p << ")";
704 };
705 printDataVars("private", op.private_vars());
706 printDataVars("firstprivate", op.firstprivate_vars());
707 printDataVars("lastprivate", op.lastprivate_vars());
708
709 auto linearVars = op.linear_vars();
710 auto linearVarsSize = linearVars.size();
711 if (linearVarsSize) {
712 p << " "
713 << "linear"
714 << "(";
715 for (unsigned i = 0; i < linearVarsSize; ++i) {
716 std::string separator = i == linearVarsSize - 1 ? ")" : ", ";
717 p << linearVars[i];
718 if (op.linear_step_vars().size() > i)
719 p << " = " << op.linear_step_vars()[i];
720 p << " : " << linearVars[i].getType() << separator;
721 }
722 }
723
724 if (auto sched = op.schedule_val()) {
725 auto schedLower = sched->lower();
726 p << " schedule(" << schedLower;
727 if (auto chunk = op.schedule_chunk_var()) {
728 p << " = " << chunk;
729 }
730 p << ")";
731 }
732
733 if (auto collapse = op.collapse_val())
734 p << " collapse(" << collapse << ")";
735
736 if (op.nowait())
737 p << " nowait";
738
739 if (auto ordered = op.ordered_val()) {
740 p << " ordered(" << ordered << ")";
741 }
742
743 if (!op.reduction_vars().empty()) {
744 p << " reduction(";
745 for (unsigned i = 0, e = op.getNumReductionVars(); i < e; ++i) {
746 if (i != 0)
747 p << ", ";
748 p << (*op.reductions())[i] << " -> " << op.reduction_vars()[i] << " : "
749 << op.reduction_vars()[i].getType();
750 }
751 p << ")";
752 }
753
754 if (op.inclusive()) {
755 p << " inclusive";
756 }
757
758 p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
759 }
760
761 //===----------------------------------------------------------------------===//
762 // ReductionOp
763 //===----------------------------------------------------------------------===//
764
parseAtomicReductionRegion(OpAsmParser & parser,Region & region)765 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
766 Region ®ion) {
767 if (parser.parseOptionalKeyword("atomic"))
768 return success();
769 return parser.parseRegion(region);
770 }
771
printAtomicReductionRegion(OpAsmPrinter & printer,ReductionDeclareOp op,Region & region)772 static void printAtomicReductionRegion(OpAsmPrinter &printer,
773 ReductionDeclareOp op, Region ®ion) {
774 if (region.empty())
775 return;
776 printer << "atomic ";
777 printer.printRegion(region);
778 }
779
verifyReductionDeclareOp(ReductionDeclareOp op)780 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) {
781 if (op.initializerRegion().empty())
782 return op.emitOpError() << "expects non-empty initializer region";
783 Block &initializerEntryBlock = op.initializerRegion().front();
784 if (initializerEntryBlock.getNumArguments() != 1 ||
785 initializerEntryBlock.getArgument(0).getType() != op.type()) {
786 return op.emitOpError() << "expects initializer region with one argument "
787 "of the reduction type";
788 }
789
790 for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) {
791 if (yieldOp.results().size() != 1 ||
792 yieldOp.results().getTypes()[0] != op.type())
793 return op.emitOpError() << "expects initializer region to yield a value "
794 "of the reduction type";
795 }
796
797 if (op.reductionRegion().empty())
798 return op.emitOpError() << "expects non-empty reduction region";
799 Block &reductionEntryBlock = op.reductionRegion().front();
800 if (reductionEntryBlock.getNumArguments() != 2 ||
801 reductionEntryBlock.getArgumentTypes()[0] !=
802 reductionEntryBlock.getArgumentTypes()[1] ||
803 reductionEntryBlock.getArgumentTypes()[0] != op.type())
804 return op.emitOpError() << "expects reduction region with two arguments of "
805 "the reduction type";
806 for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) {
807 if (yieldOp.results().size() != 1 ||
808 yieldOp.results().getTypes()[0] != op.type())
809 return op.emitOpError() << "expects reduction region to yield a value "
810 "of the reduction type";
811 }
812
813 if (op.atomicReductionRegion().empty())
814 return success();
815
816 Block &atomicReductionEntryBlock = op.atomicReductionRegion().front();
817 if (atomicReductionEntryBlock.getNumArguments() != 2 ||
818 atomicReductionEntryBlock.getArgumentTypes()[0] !=
819 atomicReductionEntryBlock.getArgumentTypes()[1])
820 return op.emitOpError() << "expects atomic reduction region with two "
821 "arguments of the same type";
822 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
823 .dyn_cast<PointerLikeType>();
824 if (!ptrType || ptrType.getElementType() != op.type())
825 return op.emitOpError() << "expects atomic reduction region arguments to "
826 "be accumulators containing the reduction type";
827 return success();
828 }
829
verifyReductionOp(ReductionOp op)830 static LogicalResult verifyReductionOp(ReductionOp op) {
831 // TODO: generalize this to an op interface when there is more than one op
832 // that supports reductions.
833 auto container = op->getParentOfType<WsLoopOp>();
834 for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
835 if (container.reduction_vars()[i] == op.accumulator())
836 return success();
837
838 return op.emitOpError() << "the accumulator is not used by the parent";
839 }
840
841 //===----------------------------------------------------------------------===//
842 // WsLoopOp
843 //===----------------------------------------------------------------------===//
844
build(OpBuilder & builder,OperationState & state,ValueRange lowerBound,ValueRange upperBound,ValueRange step,ArrayRef<NamedAttribute> attributes)845 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
846 ValueRange lowerBound, ValueRange upperBound,
847 ValueRange step, ArrayRef<NamedAttribute> attributes) {
848 build(builder, state, TypeRange(), lowerBound, upperBound, step,
849 /*private_vars=*/ValueRange(),
850 /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
851 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
852 /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
853 /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
854 /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
855 /*inclusive=*/nullptr, /*buildBody=*/false);
856 state.addAttributes(attributes);
857 }
858
build(OpBuilder &,OperationState & state,TypeRange resultTypes,ValueRange operands,ArrayRef<NamedAttribute> attributes)859 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
860 ValueRange operands, ArrayRef<NamedAttribute> attributes) {
861 state.addOperands(operands);
862 state.addAttributes(attributes);
863 (void)state.addRegion();
864 assert(resultTypes.empty() && "mismatched number of return types");
865 state.addTypes(resultTypes);
866 }
867
build(OpBuilder & builder,OperationState & result,TypeRange typeRange,ValueRange lowerBounds,ValueRange upperBounds,ValueRange steps,ValueRange privateVars,ValueRange firstprivateVars,ValueRange lastprivateVars,ValueRange linearVars,ValueRange linearStepVars,ValueRange reductionVars,StringAttr scheduleVal,Value scheduleChunkVar,IntegerAttr collapseVal,UnitAttr nowait,IntegerAttr orderedVal,StringAttr orderVal,UnitAttr inclusive,bool buildBody)868 void WsLoopOp::build(OpBuilder &builder, OperationState &result,
869 TypeRange typeRange, ValueRange lowerBounds,
870 ValueRange upperBounds, ValueRange steps,
871 ValueRange privateVars, ValueRange firstprivateVars,
872 ValueRange lastprivateVars, ValueRange linearVars,
873 ValueRange linearStepVars, ValueRange reductionVars,
874 StringAttr scheduleVal, Value scheduleChunkVar,
875 IntegerAttr collapseVal, UnitAttr nowait,
876 IntegerAttr orderedVal, StringAttr orderVal,
877 UnitAttr inclusive, bool buildBody) {
878 result.addOperands(lowerBounds);
879 result.addOperands(upperBounds);
880 result.addOperands(steps);
881 result.addOperands(privateVars);
882 result.addOperands(firstprivateVars);
883 result.addOperands(linearVars);
884 result.addOperands(linearStepVars);
885 if (scheduleChunkVar)
886 result.addOperands(scheduleChunkVar);
887
888 if (scheduleVal)
889 result.addAttribute("schedule_val", scheduleVal);
890 if (collapseVal)
891 result.addAttribute("collapse_val", collapseVal);
892 if (nowait)
893 result.addAttribute("nowait", nowait);
894 if (orderedVal)
895 result.addAttribute("ordered_val", orderedVal);
896 if (orderVal)
897 result.addAttribute("order", orderVal);
898 if (inclusive)
899 result.addAttribute("inclusive", inclusive);
900 result.addAttribute(
901 WsLoopOp::getOperandSegmentSizeAttr(),
902 builder.getI32VectorAttr(
903 {static_cast<int32_t>(lowerBounds.size()),
904 static_cast<int32_t>(upperBounds.size()),
905 static_cast<int32_t>(steps.size()),
906 static_cast<int32_t>(privateVars.size()),
907 static_cast<int32_t>(firstprivateVars.size()),
908 static_cast<int32_t>(lastprivateVars.size()),
909 static_cast<int32_t>(linearVars.size()),
910 static_cast<int32_t>(linearStepVars.size()),
911 static_cast<int32_t>(reductionVars.size()),
912 static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
913
914 Region *bodyRegion = result.addRegion();
915 if (buildBody) {
916 OpBuilder::InsertionGuard guard(builder);
917 unsigned numIVs = steps.size();
918 SmallVector<Type, 8> argTypes(numIVs, steps.getType().front());
919 builder.createBlock(bodyRegion, {}, argTypes);
920 }
921 }
922
verifyWsLoopOp(WsLoopOp op)923 static LogicalResult verifyWsLoopOp(WsLoopOp op) {
924 if (op.getNumReductionVars() != 0) {
925 if (!op.reductions() ||
926 op.reductions()->size() != op.getNumReductionVars()) {
927 return op.emitOpError() << "expected as many reduction symbol references "
928 "as reduction variables";
929 }
930 } else {
931 if (op.reductions())
932 return op.emitOpError() << "unexpected reduction symbol references";
933 return success();
934 }
935
936 DenseSet<Value> accumulators;
937 for (auto args : llvm::zip(op.reduction_vars(), *op.reductions())) {
938 Value accum = std::get<0>(args);
939 if (!accumulators.insert(accum).second) {
940 return op.emitOpError() << "accumulator variable used more than once";
941 }
942 Type varType = accum.getType().cast<PointerLikeType>();
943 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
944 auto decl =
945 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
946 if (!decl) {
947 return op.emitOpError() << "expected symbol reference " << symbolRef
948 << " to point to a reduction declaration";
949 }
950
951 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) {
952 return op.emitOpError()
953 << "expected accumulator (" << varType
954 << ") to be the same type as reduction declaration ("
955 << decl.getAccumulatorType() << ")";
956 }
957 }
958
959 return success();
960 }
961
962 //===----------------------------------------------------------------------===//
963 // Parser, printer and verifier for Synchronization Hint (2.17.12)
964 //===----------------------------------------------------------------------===//
965
966 /// Parses a Synchronization Hint clause. The value of hint is an integer
967 /// which is a combination of different hints from `omp_sync_hint_t`.
968 ///
969 /// hint-clause = `hint` `(` hint-value `)`
parseSynchronizationHint(OpAsmParser & parser,IntegerAttr & hintAttr)970 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
971 IntegerAttr &hintAttr) {
972 if (failed(parser.parseOptionalKeyword("hint"))) {
973 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
974 return success();
975 }
976
977 if (failed(parser.parseLParen()))
978 return failure();
979 StringRef hintKeyword;
980 int64_t hint = 0;
981 do {
982 if (failed(parser.parseKeyword(&hintKeyword)))
983 return failure();
984 if (hintKeyword == "uncontended")
985 hint |= 1;
986 else if (hintKeyword == "contended")
987 hint |= 2;
988 else if (hintKeyword == "nonspeculative")
989 hint |= 4;
990 else if (hintKeyword == "speculative")
991 hint |= 8;
992 else
993 return parser.emitError(parser.getCurrentLocation())
994 << hintKeyword << " is not a valid hint";
995 } while (succeeded(parser.parseOptionalComma()));
996 if (failed(parser.parseRParen()))
997 return failure();
998 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
999 return success();
1000 }
1001
1002 /// Prints a Synchronization Hint clause
printSynchronizationHint(OpAsmPrinter & p,Operation * op,IntegerAttr hintAttr)1003 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
1004 IntegerAttr hintAttr) {
1005 int64_t hint = hintAttr.getInt();
1006
1007 if (hint == 0)
1008 return;
1009
1010 // Helper function to get n-th bit from the right end of `value`
1011 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1012
1013 bool uncontended = bitn(hint, 0);
1014 bool contended = bitn(hint, 1);
1015 bool nonspeculative = bitn(hint, 2);
1016 bool speculative = bitn(hint, 3);
1017
1018 SmallVector<StringRef> hints;
1019 if (uncontended)
1020 hints.push_back("uncontended");
1021 if (contended)
1022 hints.push_back("contended");
1023 if (nonspeculative)
1024 hints.push_back("nonspeculative");
1025 if (speculative)
1026 hints.push_back("speculative");
1027
1028 p << "hint(";
1029 llvm::interleaveComma(hints, p);
1030 p << ")";
1031 }
1032
1033 /// Verifies a synchronization hint clause
verifySynchronizationHint(Operation * op,int32_t hint)1034 static LogicalResult verifySynchronizationHint(Operation *op, int32_t hint) {
1035
1036 // Helper function to get n-th bit from the right end of `value`
1037 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1038
1039 bool uncontended = bitn(hint, 0);
1040 bool contended = bitn(hint, 1);
1041 bool nonspeculative = bitn(hint, 2);
1042 bool speculative = bitn(hint, 3);
1043
1044 if (uncontended && contended)
1045 return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1046 "omp_sync_hint_contended cannot be combined";
1047 if (nonspeculative && speculative)
1048 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1049 "omp_sync_hint_speculative cannot be combined.";
1050 return success();
1051 }
1052
1053 //===----------------------------------------------------------------------===//
1054 // Verifier for critical construct (2.17.1)
1055 //===----------------------------------------------------------------------===//
1056
verifyCriticalOp(CriticalOp op)1057 static LogicalResult verifyCriticalOp(CriticalOp op) {
1058
1059 if (failed(verifySynchronizationHint(op, op.hint()))) {
1060 return failure();
1061 }
1062 if (!op.name().hasValue() && (op.hint() != 0))
1063 return op.emitOpError() << "must specify a name unless the effect is as if "
1064 "no hint is specified";
1065
1066 if (op.nameAttr()) {
1067 auto symbolRef = op.nameAttr().cast<SymbolRefAttr>();
1068 auto decl =
1069 SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef);
1070 if (!decl) {
1071 return op.emitOpError() << "expected symbol reference " << symbolRef
1072 << " to point to a critical declaration";
1073 }
1074 }
1075
1076 return success();
1077 }
1078
1079 #define GET_OP_CLASSES
1080 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
1081