1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/OperationSupport.h"
19
20 #include "llvm/ADT/SmallString.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/ADT/StringSwitch.h"
24 #include <cstddef>
25
26 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
27 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
28 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
29
30 using namespace mlir;
31 using namespace mlir::omp;
32
33 namespace {
34 /// Model for pointer-like types that already provide a `getElementType` method.
35 template <typename T>
36 struct PointerLikeModel
37 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
getElementType__anon47251b610111::PointerLikeModel38 Type getElementType(Type pointer) const {
39 return pointer.cast<T>().getElementType();
40 }
41 };
42 } // end namespace
43
initialize()44 void OpenMPDialect::initialize() {
45 addOperations<
46 #define GET_OP_LIST
47 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
48 >();
49
50 LLVM::LLVMPointerType::attachInterface<
51 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
52 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
53 }
54
55 //===----------------------------------------------------------------------===//
56 // ParallelOp
57 //===----------------------------------------------------------------------===//
58
build(OpBuilder & builder,OperationState & state,ArrayRef<NamedAttribute> attributes)59 void ParallelOp::build(OpBuilder &builder, OperationState &state,
60 ArrayRef<NamedAttribute> attributes) {
61 ParallelOp::build(
62 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
63 /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
64 /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
65 /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
66 /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
67 state.addAttributes(attributes);
68 }
69
70 /// Parse a list of operands with types.
71 ///
72 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
73 /// ssa-id-and-type-list ::= ssa-id-and-type |
74 /// ssa-id-and-type `,` ssa-id-and-type-list
75 /// ssa-id-and-type ::= ssa-id `:` type
76 static ParseResult
parseOperandAndTypeList(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::OperandType> & operands,SmallVectorImpl<Type> & types)77 parseOperandAndTypeList(OpAsmParser &parser,
78 SmallVectorImpl<OpAsmParser::OperandType> &operands,
79 SmallVectorImpl<Type> &types) {
80 if (parser.parseLParen())
81 return failure();
82
83 do {
84 OpAsmParser::OperandType operand;
85 Type type;
86 if (parser.parseOperand(operand) || parser.parseColonType(type))
87 return failure();
88 operands.push_back(operand);
89 types.push_back(type);
90 } while (succeeded(parser.parseOptionalComma()));
91
92 if (parser.parseRParen())
93 return failure();
94
95 return success();
96 }
97
98 /// Parse an allocate clause with allocators and a list of operands with types.
99 ///
100 /// operand-and-type-list ::= `(` allocate-operand-list `)`
101 /// allocate-operand-list :: = allocate-operand |
102 /// allocator-operand `,` allocate-operand-list
103 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
104 /// ssa-id-and-type ::= ssa-id `:` type
parseAllocateAndAllocator(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::OperandType> & operandsAllocate,SmallVectorImpl<Type> & typesAllocate,SmallVectorImpl<OpAsmParser::OperandType> & operandsAllocator,SmallVectorImpl<Type> & typesAllocator)105 static ParseResult parseAllocateAndAllocator(
106 OpAsmParser &parser,
107 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
108 SmallVectorImpl<Type> &typesAllocate,
109 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
110 SmallVectorImpl<Type> &typesAllocator) {
111 if (parser.parseLParen())
112 return failure();
113
114 do {
115 OpAsmParser::OperandType operand;
116 Type type;
117
118 if (parser.parseOperand(operand) || parser.parseColonType(type))
119 return failure();
120 operandsAllocator.push_back(operand);
121 typesAllocator.push_back(type);
122 if (parser.parseArrow())
123 return failure();
124 if (parser.parseOperand(operand) || parser.parseColonType(type))
125 return failure();
126
127 operandsAllocate.push_back(operand);
128 typesAllocate.push_back(type);
129 } while (succeeded(parser.parseOptionalComma()));
130
131 if (parser.parseRParen())
132 return failure();
133
134 return success();
135 }
136
verifyParallelOp(ParallelOp op)137 static LogicalResult verifyParallelOp(ParallelOp op) {
138 if (op.allocate_vars().size() != op.allocators_vars().size())
139 return op.emitError(
140 "expected equal sizes for allocate and allocator variables");
141 return success();
142 }
143
printParallelOp(OpAsmPrinter & p,ParallelOp op)144 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
145 p << "omp.parallel";
146
147 if (auto ifCond = op.if_expr_var())
148 p << " if(" << ifCond << " : " << ifCond.getType() << ")";
149
150 if (auto threads = op.num_threads_var())
151 p << " num_threads(" << threads << " : " << threads.getType() << ")";
152
153 // Print private, firstprivate, shared and copyin parameters
154 auto printDataVars = [&p](StringRef name, OperandRange vars) {
155 if (vars.size()) {
156 p << " " << name << "(";
157 for (unsigned i = 0; i < vars.size(); ++i) {
158 std::string separator = i == vars.size() - 1 ? ")" : ", ";
159 p << vars[i] << " : " << vars[i].getType() << separator;
160 }
161 }
162 };
163
164 // Print allocator and allocate parameters
165 auto printAllocateAndAllocator = [&p](OperandRange varsAllocate,
166 OperandRange varsAllocator) {
167 if (varsAllocate.empty())
168 return;
169
170 p << " allocate(";
171 for (unsigned i = 0; i < varsAllocate.size(); ++i) {
172 std::string separator = i == varsAllocate.size() - 1 ? ")" : ", ";
173 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
174 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
175 }
176 };
177
178 printDataVars("private", op.private_vars());
179 printDataVars("firstprivate", op.firstprivate_vars());
180 printDataVars("shared", op.shared_vars());
181 printDataVars("copyin", op.copyin_vars());
182 printAllocateAndAllocator(op.allocate_vars(), op.allocators_vars());
183
184 if (auto def = op.default_val())
185 p << " default(" << def->drop_front(3) << ")";
186
187 if (auto bind = op.proc_bind_val())
188 p << " proc_bind(" << bind << ")";
189
190 p.printRegion(op.getRegion());
191 }
192
193 /// Emit an error if the same clause is present more than once on an operation.
allowedOnce(OpAsmParser & parser,StringRef clause,StringRef operation)194 static ParseResult allowedOnce(OpAsmParser &parser, StringRef clause,
195 StringRef operation) {
196 return parser.emitError(parser.getNameLoc())
197 << " at most one " << clause << " clause can appear on the "
198 << operation << " operation";
199 }
200
201 /// Parses a parallel operation.
202 ///
203 /// operation ::= `omp.parallel` clause-list
204 /// clause-list ::= clause | clause clause-list
205 /// clause ::= if | numThreads | private | firstprivate | shared | copyin |
206 /// default | procBind
207 /// if ::= `if` `(` ssa-id `)`
208 /// numThreads ::= `num_threads` `(` ssa-id-and-type `)`
209 /// private ::= `private` operand-and-type-list
210 /// firstprivate ::= `firstprivate` operand-and-type-list
211 /// shared ::= `shared` operand-and-type-list
212 /// copyin ::= `copyin` operand-and-type-list
213 /// allocate ::= `allocate` operand-and-type `->` operand-and-type-list
214 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`)
215 /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
216 ///
217 /// Note that each clause can only appear once in the clase-list.
parseParallelOp(OpAsmParser & parser,OperationState & result)218 static ParseResult parseParallelOp(OpAsmParser &parser,
219 OperationState &result) {
220 std::pair<OpAsmParser::OperandType, Type> ifCond;
221 std::pair<OpAsmParser::OperandType, Type> numThreads;
222 SmallVector<OpAsmParser::OperandType, 4> privates;
223 SmallVector<Type, 4> privateTypes;
224 SmallVector<OpAsmParser::OperandType, 4> firstprivates;
225 SmallVector<Type, 4> firstprivateTypes;
226 SmallVector<OpAsmParser::OperandType, 4> shareds;
227 SmallVector<Type, 4> sharedTypes;
228 SmallVector<OpAsmParser::OperandType, 4> copyins;
229 SmallVector<Type, 4> copyinTypes;
230 SmallVector<OpAsmParser::OperandType, 4> allocates;
231 SmallVector<Type, 4> allocateTypes;
232 SmallVector<OpAsmParser::OperandType, 4> allocators;
233 SmallVector<Type, 4> allocatorTypes;
234 std::array<int, 8> segments{0, 0, 0, 0, 0, 0, 0, 0};
235 StringRef keyword;
236 bool defaultVal = false;
237 bool procBind = false;
238
239 const int ifClausePos = 0;
240 const int numThreadsClausePos = 1;
241 const int privateClausePos = 2;
242 const int firstprivateClausePos = 3;
243 const int sharedClausePos = 4;
244 const int copyinClausePos = 5;
245 const int allocateClausePos = 6;
246 const int allocatorPos = 7;
247 const StringRef opName = result.name.getStringRef();
248
249 while (succeeded(parser.parseOptionalKeyword(&keyword))) {
250 if (keyword == "if") {
251 // Fail if there was already another if condition.
252 if (segments[ifClausePos])
253 return allowedOnce(parser, "if", opName);
254 if (parser.parseLParen() || parser.parseOperand(ifCond.first) ||
255 parser.parseColonType(ifCond.second) || parser.parseRParen())
256 return failure();
257 segments[ifClausePos] = 1;
258 } else if (keyword == "num_threads") {
259 // Fail if there was already another num_threads clause.
260 if (segments[numThreadsClausePos])
261 return allowedOnce(parser, "num_threads", opName);
262 if (parser.parseLParen() || parser.parseOperand(numThreads.first) ||
263 parser.parseColonType(numThreads.second) || parser.parseRParen())
264 return failure();
265 segments[numThreadsClausePos] = 1;
266 } else if (keyword == "private") {
267 // Fail if there was already another private clause.
268 if (segments[privateClausePos])
269 return allowedOnce(parser, "private", opName);
270 if (parseOperandAndTypeList(parser, privates, privateTypes))
271 return failure();
272 segments[privateClausePos] = privates.size();
273 } else if (keyword == "firstprivate") {
274 // Fail if there was already another firstprivate clause.
275 if (segments[firstprivateClausePos])
276 return allowedOnce(parser, "firstprivate", opName);
277 if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
278 return failure();
279 segments[firstprivateClausePos] = firstprivates.size();
280 } else if (keyword == "shared") {
281 // Fail if there was already another shared clause.
282 if (segments[sharedClausePos])
283 return allowedOnce(parser, "shared", opName);
284 if (parseOperandAndTypeList(parser, shareds, sharedTypes))
285 return failure();
286 segments[sharedClausePos] = shareds.size();
287 } else if (keyword == "copyin") {
288 // Fail if there was already another copyin clause.
289 if (segments[copyinClausePos])
290 return allowedOnce(parser, "copyin", opName);
291 if (parseOperandAndTypeList(parser, copyins, copyinTypes))
292 return failure();
293 segments[copyinClausePos] = copyins.size();
294 } else if (keyword == "allocate") {
295 // Fail if there was already another allocate clause.
296 if (segments[allocateClausePos])
297 return allowedOnce(parser, "allocate", opName);
298 if (parseAllocateAndAllocator(parser, allocates, allocateTypes,
299 allocators, allocatorTypes))
300 return failure();
301 segments[allocateClausePos] = allocates.size();
302 segments[allocatorPos] = allocators.size();
303 } else if (keyword == "default") {
304 // Fail if there was already another default clause.
305 if (defaultVal)
306 return allowedOnce(parser, "default", opName);
307 defaultVal = true;
308 StringRef defval;
309 if (parser.parseLParen() || parser.parseKeyword(&defval) ||
310 parser.parseRParen())
311 return failure();
312 // The def prefix is required for the attribute as "private" is a keyword
313 // in C++.
314 auto attr = parser.getBuilder().getStringAttr("def" + defval);
315 result.addAttribute("default_val", attr);
316 } else if (keyword == "proc_bind") {
317 // Fail if there was already another proc_bind clause.
318 if (procBind)
319 return allowedOnce(parser, "proc_bind", opName);
320 procBind = true;
321 StringRef bind;
322 if (parser.parseLParen() || parser.parseKeyword(&bind) ||
323 parser.parseRParen())
324 return failure();
325 auto attr = parser.getBuilder().getStringAttr(bind);
326 result.addAttribute("proc_bind_val", attr);
327 } else {
328 return parser.emitError(parser.getNameLoc())
329 << keyword << " is not a valid clause for the " << opName
330 << " operation";
331 }
332 }
333
334 // Add if parameter.
335 if (segments[ifClausePos] &&
336 parser.resolveOperand(ifCond.first, ifCond.second, result.operands))
337 return failure();
338
339 // Add num_threads parameter.
340 if (segments[numThreadsClausePos] &&
341 parser.resolveOperand(numThreads.first, numThreads.second,
342 result.operands))
343 return failure();
344
345 // Add private parameters.
346 if (segments[privateClausePos] &&
347 parser.resolveOperands(privates, privateTypes, privates[0].location,
348 result.operands))
349 return failure();
350
351 // Add firstprivate parameters.
352 if (segments[firstprivateClausePos] &&
353 parser.resolveOperands(firstprivates, firstprivateTypes,
354 firstprivates[0].location, result.operands))
355 return failure();
356
357 // Add shared parameters.
358 if (segments[sharedClausePos] &&
359 parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
360 result.operands))
361 return failure();
362
363 // Add copyin parameters.
364 if (segments[copyinClausePos] &&
365 parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
366 result.operands))
367 return failure();
368
369 // Add allocate parameters.
370 if (segments[allocateClausePos] &&
371 parser.resolveOperands(allocates, allocateTypes, allocates[0].location,
372 result.operands))
373 return failure();
374
375 // Add allocator parameters.
376 if (segments[allocatorPos] &&
377 parser.resolveOperands(allocators, allocatorTypes, allocators[0].location,
378 result.operands))
379 return failure();
380
381 result.addAttribute("operand_segment_sizes",
382 parser.getBuilder().getI32VectorAttr(segments));
383
384 Region *body = result.addRegion();
385 SmallVector<OpAsmParser::OperandType, 4> regionArgs;
386 SmallVector<Type, 4> regionArgTypes;
387 if (parser.parseRegion(*body, regionArgs, regionArgTypes))
388 return failure();
389 return success();
390 }
391
392 /// linear ::= `linear` `(` linear-list `)`
393 /// linear-list := linear-val | linear-val linear-list
394 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
395 static ParseResult
parseLinearClause(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::OperandType> & vars,SmallVectorImpl<Type> & types,SmallVectorImpl<OpAsmParser::OperandType> & stepVars)396 parseLinearClause(OpAsmParser &parser,
397 SmallVectorImpl<OpAsmParser::OperandType> &vars,
398 SmallVectorImpl<Type> &types,
399 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
400 if (parser.parseLParen())
401 return failure();
402
403 do {
404 OpAsmParser::OperandType var;
405 Type type;
406 OpAsmParser::OperandType stepVar;
407 if (parser.parseOperand(var) || parser.parseEqual() ||
408 parser.parseOperand(stepVar) || parser.parseColonType(type))
409 return failure();
410
411 vars.push_back(var);
412 types.push_back(type);
413 stepVars.push_back(stepVar);
414 } while (succeeded(parser.parseOptionalComma()));
415
416 if (parser.parseRParen())
417 return failure();
418
419 return success();
420 }
421
422 /// schedule ::= `schedule` `(` sched-list `)`
423 /// sched-list ::= sched-val | sched-val sched-list
424 /// sched-val ::= sched-with-chunk | sched-wo-chunk
425 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
426 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
427 /// sched-wo-chunk ::= `auto` | `runtime`
428 static ParseResult
parseScheduleClause(OpAsmParser & parser,SmallString<8> & schedule,Optional<OpAsmParser::OperandType> & chunkSize)429 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
430 Optional<OpAsmParser::OperandType> &chunkSize) {
431 if (parser.parseLParen())
432 return failure();
433
434 StringRef keyword;
435 if (parser.parseKeyword(&keyword))
436 return failure();
437
438 schedule = keyword;
439 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") {
440 if (succeeded(parser.parseOptionalEqual())) {
441 chunkSize = OpAsmParser::OperandType{};
442 if (parser.parseOperand(*chunkSize))
443 return failure();
444 } else {
445 chunkSize = llvm::NoneType::None;
446 }
447 } else if (keyword == "auto" || keyword == "runtime") {
448 chunkSize = llvm::NoneType::None;
449 } else {
450 return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
451 }
452
453 if (parser.parseRParen())
454 return failure();
455
456 return success();
457 }
458
459 /// reduction-init ::= `reduction` `(` reduction-entry-list `)`
460 /// reduction-entry-list ::= reduction-entry
461 /// | reduction-entry-list `,` reduction-entry
462 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
463 static ParseResult
parseReductionVarList(OpAsmParser & parser,SmallVectorImpl<SymbolRefAttr> & symbols,SmallVectorImpl<OpAsmParser::OperandType> & operands,SmallVectorImpl<Type> & types)464 parseReductionVarList(OpAsmParser &parser,
465 SmallVectorImpl<SymbolRefAttr> &symbols,
466 SmallVectorImpl<OpAsmParser::OperandType> &operands,
467 SmallVectorImpl<Type> &types) {
468 if (failed(parser.parseLParen()))
469 return failure();
470
471 do {
472 if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
473 parser.parseOperand(operands.emplace_back()) ||
474 parser.parseColonType(types.emplace_back()))
475 return failure();
476 } while (succeeded(parser.parseOptionalComma()));
477 return parser.parseRParen();
478 }
479
480 /// Parses an OpenMP Workshare Loop operation
481 ///
482 /// operation ::= `omp.wsloop` loop-control clause-list
483 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds
484 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps
485 /// steps := `step` `(`ssa-id-list`)`
486 /// clause-list ::= clause | empty | clause-list
487 /// clause ::= private | firstprivate | lastprivate | linear | schedule |
488 // collapse | nowait | ordered | order | inclusive
489 /// private ::= `private` `(` ssa-id-and-type-list `)`
490 /// firstprivate ::= `firstprivate` `(` ssa-id-and-type-list `)`
491 /// lastprivate ::= `lastprivate` `(` ssa-id-and-type-list `)`
492 /// linear ::= `linear` `(` linear-list `)`
493 /// schedule ::= `schedule` `(` sched-list `)`
494 /// collapse ::= `collapse` `(` ssa-id-and-type `)`
495 /// nowait ::= `nowait`
496 /// ordered ::= `ordered` `(` ssa-id-and-type `)`
497 /// order ::= `order` `(` `concurrent` `)`
498 /// inclusive ::= `inclusive`
499 ///
parseWsLoopOp(OpAsmParser & parser,OperationState & result)500 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
501 Type loopVarType;
502 int numIVs;
503
504 // Parse an opening `(` followed by induction variables followed by `)`
505 SmallVector<OpAsmParser::OperandType> ivs;
506 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
507 OpAsmParser::Delimiter::Paren))
508 return failure();
509
510 numIVs = static_cast<int>(ivs.size());
511
512 if (parser.parseColonType(loopVarType))
513 return failure();
514
515 // Parse loop bounds.
516 SmallVector<OpAsmParser::OperandType> lower;
517 if (parser.parseEqual() ||
518 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
519 parser.resolveOperands(lower, loopVarType, result.operands))
520 return failure();
521
522 SmallVector<OpAsmParser::OperandType> upper;
523 if (parser.parseKeyword("to") ||
524 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
525 parser.resolveOperands(upper, loopVarType, result.operands))
526 return failure();
527
528 // Parse step values.
529 SmallVector<OpAsmParser::OperandType> steps;
530 if (parser.parseKeyword("step") ||
531 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
532 parser.resolveOperands(steps, loopVarType, result.operands))
533 return failure();
534
535 SmallVector<OpAsmParser::OperandType> privates;
536 SmallVector<Type> privateTypes;
537 SmallVector<OpAsmParser::OperandType> firstprivates;
538 SmallVector<Type> firstprivateTypes;
539 SmallVector<OpAsmParser::OperandType> lastprivates;
540 SmallVector<Type> lastprivateTypes;
541 SmallVector<OpAsmParser::OperandType> linears;
542 SmallVector<Type> linearTypes;
543 SmallVector<OpAsmParser::OperandType> linearSteps;
544 SmallVector<SymbolRefAttr> reductionSymbols;
545 SmallVector<OpAsmParser::OperandType> reductionVars;
546 SmallVector<Type> reductionVarTypes;
547 SmallString<8> schedule;
548 Optional<OpAsmParser::OperandType> scheduleChunkSize;
549
550 const StringRef opName = result.name.getStringRef();
551 StringRef keyword;
552
553 enum SegmentPos {
554 lbPos = 0,
555 ubPos,
556 stepPos,
557 privateClausePos,
558 firstprivateClausePos,
559 lastprivateClausePos,
560 linearClausePos,
561 linearStepPos,
562 reductionVarPos,
563 scheduleClausePos,
564 };
565 std::array<int, 10> segments{numIVs, numIVs, numIVs, 0, 0, 0, 0, 0, 0, 0};
566
567 while (succeeded(parser.parseOptionalKeyword(&keyword))) {
568 if (keyword == "private") {
569 if (segments[privateClausePos])
570 return allowedOnce(parser, "private", opName);
571 if (parseOperandAndTypeList(parser, privates, privateTypes))
572 return failure();
573 segments[privateClausePos] = privates.size();
574 } else if (keyword == "firstprivate") {
575 // fail if there was already another firstprivate clause
576 if (segments[firstprivateClausePos])
577 return allowedOnce(parser, "firstprivate", opName);
578 if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
579 return failure();
580 segments[firstprivateClausePos] = firstprivates.size();
581 } else if (keyword == "lastprivate") {
582 // fail if there was already another shared clause
583 if (segments[lastprivateClausePos])
584 return allowedOnce(parser, "lastprivate", opName);
585 if (parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
586 return failure();
587 segments[lastprivateClausePos] = lastprivates.size();
588 } else if (keyword == "linear") {
589 // fail if there was already another linear clause
590 if (segments[linearClausePos])
591 return allowedOnce(parser, "linear", opName);
592 if (parseLinearClause(parser, linears, linearTypes, linearSteps))
593 return failure();
594 segments[linearClausePos] = linears.size();
595 segments[linearStepPos] = linearSteps.size();
596 } else if (keyword == "schedule") {
597 if (!schedule.empty())
598 return allowedOnce(parser, "schedule", opName);
599 if (parseScheduleClause(parser, schedule, scheduleChunkSize))
600 return failure();
601 if (scheduleChunkSize) {
602 segments[scheduleClausePos] = 1;
603 }
604 } else if (keyword == "collapse") {
605 auto type = parser.getBuilder().getI64Type();
606 mlir::IntegerAttr attr;
607 if (parser.parseLParen() || parser.parseAttribute(attr, type) ||
608 parser.parseRParen())
609 return failure();
610 result.addAttribute("collapse_val", attr);
611 } else if (keyword == "nowait") {
612 auto attr = UnitAttr::get(parser.getBuilder().getContext());
613 result.addAttribute("nowait", attr);
614 } else if (keyword == "ordered") {
615 mlir::IntegerAttr attr;
616 if (succeeded(parser.parseOptionalLParen())) {
617 auto type = parser.getBuilder().getI64Type();
618 if (parser.parseAttribute(attr, type))
619 return failure();
620 if (parser.parseRParen())
621 return failure();
622 } else {
623 // Use 0 to represent no ordered parameter was specified
624 attr = parser.getBuilder().getI64IntegerAttr(0);
625 }
626 result.addAttribute("ordered_val", attr);
627 } else if (keyword == "order") {
628 StringRef order;
629 if (parser.parseLParen() || parser.parseKeyword(&order) ||
630 parser.parseRParen())
631 return failure();
632 auto attr = parser.getBuilder().getStringAttr(order);
633 result.addAttribute("order", attr);
634 } else if (keyword == "inclusive") {
635 auto attr = UnitAttr::get(parser.getBuilder().getContext());
636 result.addAttribute("inclusive", attr);
637 } else if (keyword == "reduction") {
638 if (segments[reductionVarPos])
639 return allowedOnce(parser, "reduction", opName);
640 if (failed(parseReductionVarList(parser, reductionSymbols, reductionVars,
641 reductionVarTypes)))
642 return failure();
643 segments[reductionVarPos] = reductionVars.size();
644 }
645 }
646
647 if (segments[privateClausePos]) {
648 parser.resolveOperands(privates, privateTypes, privates[0].location,
649 result.operands);
650 }
651
652 if (segments[firstprivateClausePos]) {
653 parser.resolveOperands(firstprivates, firstprivateTypes,
654 firstprivates[0].location, result.operands);
655 }
656
657 if (segments[lastprivateClausePos]) {
658 parser.resolveOperands(lastprivates, lastprivateTypes,
659 lastprivates[0].location, result.operands);
660 }
661
662 if (segments[linearClausePos]) {
663 parser.resolveOperands(linears, linearTypes, linears[0].location,
664 result.operands);
665 auto linearStepType = parser.getBuilder().getI32Type();
666 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
667 parser.resolveOperands(linearSteps, linearStepTypes,
668 linearSteps[0].location, result.operands);
669 }
670
671 if (segments[reductionVarPos]) {
672 if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
673 parser.getNameLoc(), result.operands))) {
674 return failure();
675 }
676 SmallVector<Attribute> reductions(reductionSymbols.begin(),
677 reductionSymbols.end());
678 result.addAttribute("reductions",
679 parser.getBuilder().getArrayAttr(reductions));
680 }
681
682 if (!schedule.empty()) {
683 schedule[0] = llvm::toUpper(schedule[0]);
684 auto attr = parser.getBuilder().getStringAttr(schedule);
685 result.addAttribute("schedule_val", attr);
686 if (scheduleChunkSize) {
687 auto chunkSizeType = parser.getBuilder().getI32Type();
688 parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands);
689 }
690 }
691
692 result.addAttribute("operand_segment_sizes",
693 parser.getBuilder().getI32VectorAttr(segments));
694
695 // Now parse the body.
696 Region *body = result.addRegion();
697 SmallVector<Type> ivTypes(numIVs, loopVarType);
698 SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
699 if (parser.parseRegion(*body, blockArgs, ivTypes))
700 return failure();
701 return success();
702 }
703
printWsLoopOp(OpAsmPrinter & p,WsLoopOp op)704 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
705 auto args = op.getRegion().front().getArguments();
706 p << op.getOperationName() << " (" << args << ") : " << args[0].getType()
707 << " = (" << op.lowerBound() << ") to (" << op.upperBound() << ") step ("
708 << op.step() << ")";
709
710 // Print private, firstprivate, shared and copyin parameters
711 auto printDataVars = [&p](StringRef name, OperandRange vars) {
712 if (vars.empty())
713 return;
714
715 p << " " << name << "(";
716 llvm::interleaveComma(
717 vars, p, [&](const Value &v) { p << v << " : " << v.getType(); });
718 p << ")";
719 };
720 printDataVars("private", op.private_vars());
721 printDataVars("firstprivate", op.firstprivate_vars());
722 printDataVars("lastprivate", op.lastprivate_vars());
723
724 auto linearVars = op.linear_vars();
725 auto linearVarsSize = linearVars.size();
726 if (linearVarsSize) {
727 p << " "
728 << "linear"
729 << "(";
730 for (unsigned i = 0; i < linearVarsSize; ++i) {
731 std::string separator = i == linearVarsSize - 1 ? ")" : ", ";
732 p << linearVars[i];
733 if (op.linear_step_vars().size() > i)
734 p << " = " << op.linear_step_vars()[i];
735 p << " : " << linearVars[i].getType() << separator;
736 }
737 }
738
739 if (auto sched = op.schedule_val()) {
740 auto schedLower = sched->lower();
741 p << " schedule(" << schedLower;
742 if (auto chunk = op.schedule_chunk_var()) {
743 p << " = " << chunk;
744 }
745 p << ")";
746 }
747
748 if (auto collapse = op.collapse_val())
749 p << " collapse(" << collapse << ")";
750
751 if (op.nowait())
752 p << " nowait";
753
754 if (auto ordered = op.ordered_val()) {
755 p << " ordered(" << ordered << ")";
756 }
757
758 if (!op.reduction_vars().empty()) {
759 p << " reduction(";
760 for (unsigned i = 0, e = op.getNumReductionVars(); i < e; ++i) {
761 if (i != 0)
762 p << ", ";
763 p << (*op.reductions())[i] << " -> " << op.reduction_vars()[i] << " : "
764 << op.reduction_vars()[i].getType();
765 }
766 p << ")";
767 }
768
769 if (op.inclusive()) {
770 p << " inclusive";
771 }
772
773 p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
774 }
775
776 //===----------------------------------------------------------------------===//
777 // ReductionOp
778 //===----------------------------------------------------------------------===//
779
parseAtomicReductionRegion(OpAsmParser & parser,Region & region)780 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
781 Region ®ion) {
782 if (parser.parseOptionalKeyword("atomic"))
783 return success();
784 return parser.parseRegion(region);
785 }
786
printAtomicReductionRegion(OpAsmPrinter & printer,ReductionDeclareOp op,Region & region)787 static void printAtomicReductionRegion(OpAsmPrinter &printer,
788 ReductionDeclareOp op, Region ®ion) {
789 if (region.empty())
790 return;
791 printer << "atomic ";
792 printer.printRegion(region);
793 }
794
verifyReductionDeclareOp(ReductionDeclareOp op)795 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) {
796 if (op.initializerRegion().empty())
797 return op.emitOpError() << "expects non-empty initializer region";
798 Block &initializerEntryBlock = op.initializerRegion().front();
799 if (initializerEntryBlock.getNumArguments() != 1 ||
800 initializerEntryBlock.getArgument(0).getType() != op.type()) {
801 return op.emitOpError() << "expects initializer region with one argument "
802 "of the reduction type";
803 }
804
805 for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) {
806 if (yieldOp.results().size() != 1 ||
807 yieldOp.results().getTypes()[0] != op.type())
808 return op.emitOpError() << "expects initializer region to yield a value "
809 "of the reduction type";
810 }
811
812 if (op.reductionRegion().empty())
813 return op.emitOpError() << "expects non-empty reduction region";
814 Block &reductionEntryBlock = op.reductionRegion().front();
815 if (reductionEntryBlock.getNumArguments() != 2 ||
816 reductionEntryBlock.getArgumentTypes()[0] !=
817 reductionEntryBlock.getArgumentTypes()[1] ||
818 reductionEntryBlock.getArgumentTypes()[0] != op.type())
819 return op.emitOpError() << "expects reduction region with two arguments of "
820 "the reduction type";
821 for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) {
822 if (yieldOp.results().size() != 1 ||
823 yieldOp.results().getTypes()[0] != op.type())
824 return op.emitOpError() << "expects reduction region to yield a value "
825 "of the reduction type";
826 }
827
828 if (op.atomicReductionRegion().empty())
829 return success();
830
831 Block &atomicReductionEntryBlock = op.atomicReductionRegion().front();
832 if (atomicReductionEntryBlock.getNumArguments() != 2 ||
833 atomicReductionEntryBlock.getArgumentTypes()[0] !=
834 atomicReductionEntryBlock.getArgumentTypes()[1])
835 return op.emitOpError() << "expects atomic reduction region with two "
836 "arguments of the same type";
837 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
838 .dyn_cast<PointerLikeType>();
839 if (!ptrType || ptrType.getElementType() != op.type())
840 return op.emitOpError() << "expects atomic reduction region arguments to "
841 "be accumulators containing the reduction type";
842 return success();
843 }
844
verifyReductionOp(ReductionOp op)845 static LogicalResult verifyReductionOp(ReductionOp op) {
846 // TODO: generalize this to an op interface when there is more than one op
847 // that supports reductions.
848 auto container = op->getParentOfType<WsLoopOp>();
849 for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
850 if (container.reduction_vars()[i] == op.accumulator())
851 return success();
852
853 return op.emitOpError() << "the accumulator is not used by the parent";
854 }
855
856 //===----------------------------------------------------------------------===//
857 // WsLoopOp
858 //===----------------------------------------------------------------------===//
859
build(OpBuilder & builder,OperationState & state,ValueRange lowerBound,ValueRange upperBound,ValueRange step,ArrayRef<NamedAttribute> attributes)860 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
861 ValueRange lowerBound, ValueRange upperBound,
862 ValueRange step, ArrayRef<NamedAttribute> attributes) {
863 build(builder, state, TypeRange(), lowerBound, upperBound, step,
864 /*private_vars=*/ValueRange(),
865 /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
866 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
867 /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
868 /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
869 /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
870 /*inclusive=*/nullptr, /*buildBody=*/false);
871 state.addAttributes(attributes);
872 }
873
build(OpBuilder &,OperationState & state,TypeRange resultTypes,ValueRange operands,ArrayRef<NamedAttribute> attributes)874 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
875 ValueRange operands, ArrayRef<NamedAttribute> attributes) {
876 state.addOperands(operands);
877 state.addAttributes(attributes);
878 (void)state.addRegion();
879 assert(resultTypes.empty() && "mismatched number of return types");
880 state.addTypes(resultTypes);
881 }
882
build(OpBuilder & builder,OperationState & result,TypeRange typeRange,ValueRange lowerBounds,ValueRange upperBounds,ValueRange steps,ValueRange privateVars,ValueRange firstprivateVars,ValueRange lastprivateVars,ValueRange linearVars,ValueRange linearStepVars,ValueRange reductionVars,StringAttr scheduleVal,Value scheduleChunkVar,IntegerAttr collapseVal,UnitAttr nowait,IntegerAttr orderedVal,StringAttr orderVal,UnitAttr inclusive,bool buildBody)883 void WsLoopOp::build(OpBuilder &builder, OperationState &result,
884 TypeRange typeRange, ValueRange lowerBounds,
885 ValueRange upperBounds, ValueRange steps,
886 ValueRange privateVars, ValueRange firstprivateVars,
887 ValueRange lastprivateVars, ValueRange linearVars,
888 ValueRange linearStepVars, ValueRange reductionVars,
889 StringAttr scheduleVal, Value scheduleChunkVar,
890 IntegerAttr collapseVal, UnitAttr nowait,
891 IntegerAttr orderedVal, StringAttr orderVal,
892 UnitAttr inclusive, bool buildBody) {
893 result.addOperands(lowerBounds);
894 result.addOperands(upperBounds);
895 result.addOperands(steps);
896 result.addOperands(privateVars);
897 result.addOperands(firstprivateVars);
898 result.addOperands(linearVars);
899 result.addOperands(linearStepVars);
900 if (scheduleChunkVar)
901 result.addOperands(scheduleChunkVar);
902
903 if (scheduleVal)
904 result.addAttribute("schedule_val", scheduleVal);
905 if (collapseVal)
906 result.addAttribute("collapse_val", collapseVal);
907 if (nowait)
908 result.addAttribute("nowait", nowait);
909 if (orderedVal)
910 result.addAttribute("ordered_val", orderedVal);
911 if (orderVal)
912 result.addAttribute("order", orderVal);
913 if (inclusive)
914 result.addAttribute("inclusive", inclusive);
915 result.addAttribute(
916 WsLoopOp::getOperandSegmentSizeAttr(),
917 builder.getI32VectorAttr(
918 {static_cast<int32_t>(lowerBounds.size()),
919 static_cast<int32_t>(upperBounds.size()),
920 static_cast<int32_t>(steps.size()),
921 static_cast<int32_t>(privateVars.size()),
922 static_cast<int32_t>(firstprivateVars.size()),
923 static_cast<int32_t>(lastprivateVars.size()),
924 static_cast<int32_t>(linearVars.size()),
925 static_cast<int32_t>(linearStepVars.size()),
926 static_cast<int32_t>(reductionVars.size()),
927 static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
928
929 Region *bodyRegion = result.addRegion();
930 if (buildBody) {
931 OpBuilder::InsertionGuard guard(builder);
932 unsigned numIVs = steps.size();
933 SmallVector<Type, 8> argTypes(numIVs, steps.getType().front());
934 builder.createBlock(bodyRegion, {}, argTypes);
935 }
936 }
937
verifyWsLoopOp(WsLoopOp op)938 static LogicalResult verifyWsLoopOp(WsLoopOp op) {
939 if (op.getNumReductionVars() != 0) {
940 if (!op.reductions() ||
941 op.reductions()->size() != op.getNumReductionVars()) {
942 return op.emitOpError() << "expected as many reduction symbol references "
943 "as reduction variables";
944 }
945 } else {
946 if (op.reductions())
947 return op.emitOpError() << "unexpected reduction symbol references";
948 return success();
949 }
950
951 DenseSet<Value> accumulators;
952 for (auto args : llvm::zip(op.reduction_vars(), *op.reductions())) {
953 Value accum = std::get<0>(args);
954 if (!accumulators.insert(accum).second) {
955 return op.emitOpError() << "accumulator variable used more than once";
956 }
957 Type varType = accum.getType().cast<PointerLikeType>();
958 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
959 auto decl =
960 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
961 if (!decl) {
962 return op.emitOpError() << "expected symbol reference " << symbolRef
963 << " to point to a reduction declaration";
964 }
965
966 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) {
967 return op.emitOpError()
968 << "expected accumulator (" << varType
969 << ") to be the same type as reduction declaration ("
970 << decl.getAccumulatorType() << ")";
971 }
972 }
973
974 return success();
975 }
976
977 #define GET_OP_CLASSES
978 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
979