1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/EmitC/IR/EmitC.h"
10 #include "mlir/Dialect/SCF/SCF.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Support/IndentedOstream.h"
17 #include "mlir/Target/Cpp/CppEmitter.h"
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/StringMap.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/FormatVariadic.h"
24
25 #define DEBUG_TYPE "translate-to-cpp"
26
27 using namespace mlir;
28 using namespace mlir::emitc;
29 using llvm::formatv;
30
31 /// Convenience functions to produce interleaved output with functions returning
32 /// a LogicalResult. This is different than those in STLExtras as functions used
33 /// on each element doesn't return a string.
34 template <typename ForwardIterator, typename UnaryFunctor,
35 typename NullaryFunctor>
36 inline LogicalResult
interleaveWithError(ForwardIterator begin,ForwardIterator end,UnaryFunctor eachFn,NullaryFunctor betweenFn)37 interleaveWithError(ForwardIterator begin, ForwardIterator end,
38 UnaryFunctor eachFn, NullaryFunctor betweenFn) {
39 if (begin == end)
40 return success();
41 if (failed(eachFn(*begin)))
42 return failure();
43 ++begin;
44 for (; begin != end; ++begin) {
45 betweenFn();
46 if (failed(eachFn(*begin)))
47 return failure();
48 }
49 return success();
50 }
51
52 template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
interleaveWithError(const Container & c,UnaryFunctor eachFn,NullaryFunctor betweenFn)53 inline LogicalResult interleaveWithError(const Container &c,
54 UnaryFunctor eachFn,
55 NullaryFunctor betweenFn) {
56 return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
57 }
58
59 template <typename Container, typename UnaryFunctor>
interleaveCommaWithError(const Container & c,raw_ostream & os,UnaryFunctor eachFn)60 inline LogicalResult interleaveCommaWithError(const Container &c,
61 raw_ostream &os,
62 UnaryFunctor eachFn) {
63 return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
64 }
65
66 namespace {
67 /// Emitter that uses dialect specific emitters to emit C++ code.
68 struct CppEmitter {
69 explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
70
71 /// Emits attribute or returns failure.
72 LogicalResult emitAttribute(Location loc, Attribute attr);
73
74 /// Emits operation 'op' with/without training semicolon or returns failure.
75 LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
76
77 /// Emits type 'type' or returns failure.
78 LogicalResult emitType(Location loc, Type type);
79
80 /// Emits array of types as a std::tuple of the emitted types.
81 /// - emits void for an empty array;
82 /// - emits the type of the only element for arrays of size one;
83 /// - emits a std::tuple otherwise;
84 LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
85
86 /// Emits array of types as a std::tuple of the emitted types independently of
87 /// the array size.
88 LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
89
90 /// Emits an assignment for a variable which has been declared previously.
91 LogicalResult emitVariableAssignment(OpResult result);
92
93 /// Emits a variable declaration for a result of an operation.
94 LogicalResult emitVariableDeclaration(OpResult result,
95 bool trailingSemicolon);
96
97 /// Emits the variable declaration and assignment prefix for 'op'.
98 /// - emits separate variable followed by std::tie for multi-valued operation;
99 /// - emits single type followed by variable for single result;
100 /// - emits nothing if no value produced by op;
101 /// Emits final '=' operator where a type is produced. Returns failure if
102 /// any result type could not be converted.
103 LogicalResult emitAssignPrefix(Operation &op);
104
105 /// Emits a label for the block.
106 LogicalResult emitLabel(Block &block);
107
108 /// Emits the operands and atttributes of the operation. All operands are
109 /// emitted first and then all attributes in alphabetical order.
110 LogicalResult emitOperandsAndAttributes(Operation &op,
111 ArrayRef<StringRef> exclude = {});
112
113 /// Emits the operands of the operation. All operands are emitted in order.
114 LogicalResult emitOperands(Operation &op);
115
116 /// Return the existing or a new name for a Value.
117 StringRef getOrCreateName(Value val);
118
119 /// Return the existing or a new label of a Block.
120 StringRef getOrCreateName(Block &block);
121
122 /// Whether to map an mlir integer to a unsigned integer in C++.
123 bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
124
125 /// RAII helper function to manage entering/exiting C++ scopes.
126 struct Scope {
Scope__anon45c64d1e0211::CppEmitter::Scope127 Scope(CppEmitter &emitter)
128 : valueMapperScope(emitter.valueMapper),
129 blockMapperScope(emitter.blockMapper), emitter(emitter) {
130 emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
131 emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
132 }
~Scope__anon45c64d1e0211::CppEmitter::Scope133 ~Scope() {
134 emitter.valueInScopeCount.pop();
135 emitter.labelInScopeCount.pop();
136 }
137
138 private:
139 llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
140 llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
141 CppEmitter &emitter;
142 };
143
144 /// Returns wether the Value is assigned to a C++ variable in the scope.
145 bool hasValueInScope(Value val);
146
147 // Returns whether a label is assigned to the block.
148 bool hasBlockLabel(Block &block);
149
150 /// Returns the output stream.
ostream__anon45c64d1e0211::CppEmitter151 raw_indented_ostream &ostream() { return os; };
152
153 /// Returns if all variables for op results and basic block arguments need to
154 /// be declared at the beginning of a function.
shouldDeclareVariablesAtTop__anon45c64d1e0211::CppEmitter155 bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
156
157 private:
158 using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
159 using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
160
161 /// Output stream to emit to.
162 raw_indented_ostream os;
163
164 /// Boolean to enforce that all variables for op results and block
165 /// arguments are declared at the beginning of the function. This also
166 /// includes results from ops located in nested regions.
167 bool declareVariablesAtTop;
168
169 /// Map from value to name of C++ variable that contain the name.
170 ValueMapper valueMapper;
171
172 /// Map from block to name of C++ label.
173 BlockMapper blockMapper;
174
175 /// The number of values in the current scope. This is used to declare the
176 /// names of values in a scope.
177 std::stack<int64_t> valueInScopeCount;
178 std::stack<int64_t> labelInScopeCount;
179 };
180 } // namespace
181
printConstantOp(CppEmitter & emitter,Operation * operation,Attribute value)182 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
183 Attribute value) {
184 OpResult result = operation->getResult(0);
185
186 // Only emit an assignment as the variable was already declared when printing
187 // the FuncOp.
188 if (emitter.shouldDeclareVariablesAtTop()) {
189 // Skip the assignment if the emitc.constant has no value.
190 if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
191 if (oAttr.getValue().empty())
192 return success();
193 }
194
195 if (failed(emitter.emitVariableAssignment(result)))
196 return failure();
197 return emitter.emitAttribute(operation->getLoc(), value);
198 }
199
200 // Emit a variable declaration for an emitc.constant op without value.
201 if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
202 if (oAttr.getValue().empty())
203 // The semicolon gets printed by the emitOperation function.
204 return emitter.emitVariableDeclaration(result,
205 /*trailingSemicolon=*/false);
206 }
207
208 // Emit a variable declaration.
209 if (failed(emitter.emitAssignPrefix(*operation)))
210 return failure();
211 return emitter.emitAttribute(operation->getLoc(), value);
212 }
213
printOperation(CppEmitter & emitter,emitc::ConstantOp constantOp)214 static LogicalResult printOperation(CppEmitter &emitter,
215 emitc::ConstantOp constantOp) {
216 Operation *operation = constantOp.getOperation();
217 Attribute value = constantOp.value();
218
219 return printConstantOp(emitter, operation, value);
220 }
221
printOperation(CppEmitter & emitter,mlir::ConstantOp constantOp)222 static LogicalResult printOperation(CppEmitter &emitter,
223 mlir::ConstantOp constantOp) {
224 Operation *operation = constantOp.getOperation();
225 Attribute value = constantOp.value();
226
227 return printConstantOp(emitter, operation, value);
228 }
229
printOperation(CppEmitter & emitter,BranchOp branchOp)230 static LogicalResult printOperation(CppEmitter &emitter, BranchOp branchOp) {
231 raw_ostream &os = emitter.ostream();
232 Block &successor = *branchOp.getSuccessor();
233
234 for (auto pair :
235 llvm::zip(branchOp.getOperands(), successor.getArguments())) {
236 Value &operand = std::get<0>(pair);
237 BlockArgument &argument = std::get<1>(pair);
238 os << emitter.getOrCreateName(argument) << " = "
239 << emitter.getOrCreateName(operand) << ";\n";
240 }
241
242 os << "goto ";
243 if (!(emitter.hasBlockLabel(successor)))
244 return branchOp.emitOpError("unable to find label for successor block");
245 os << emitter.getOrCreateName(successor);
246 return success();
247 }
248
printOperation(CppEmitter & emitter,CondBranchOp condBranchOp)249 static LogicalResult printOperation(CppEmitter &emitter,
250 CondBranchOp condBranchOp) {
251 raw_indented_ostream &os = emitter.ostream();
252 Block &trueSuccessor = *condBranchOp.getTrueDest();
253 Block &falseSuccessor = *condBranchOp.getFalseDest();
254
255 os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
256 << ") {\n";
257
258 os.indent();
259
260 // If condition is true.
261 for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
262 trueSuccessor.getArguments())) {
263 Value &operand = std::get<0>(pair);
264 BlockArgument &argument = std::get<1>(pair);
265 os << emitter.getOrCreateName(argument) << " = "
266 << emitter.getOrCreateName(operand) << ";\n";
267 }
268
269 os << "goto ";
270 if (!(emitter.hasBlockLabel(trueSuccessor))) {
271 return condBranchOp.emitOpError("unable to find label for successor block");
272 }
273 os << emitter.getOrCreateName(trueSuccessor) << ";\n";
274 os.unindent() << "} else {\n";
275 os.indent();
276 // If condition is false.
277 for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
278 falseSuccessor.getArguments())) {
279 Value &operand = std::get<0>(pair);
280 BlockArgument &argument = std::get<1>(pair);
281 os << emitter.getOrCreateName(argument) << " = "
282 << emitter.getOrCreateName(operand) << ";\n";
283 }
284
285 os << "goto ";
286 if (!(emitter.hasBlockLabel(falseSuccessor))) {
287 return condBranchOp.emitOpError()
288 << "unable to find label for successor block";
289 }
290 os << emitter.getOrCreateName(falseSuccessor) << ";\n";
291 os.unindent() << "}";
292 return success();
293 }
294
printOperation(CppEmitter & emitter,mlir::CallOp callOp)295 static LogicalResult printOperation(CppEmitter &emitter, mlir::CallOp callOp) {
296 if (failed(emitter.emitAssignPrefix(*callOp.getOperation())))
297 return failure();
298
299 raw_ostream &os = emitter.ostream();
300 os << callOp.getCallee() << "(";
301 if (failed(emitter.emitOperands(*callOp.getOperation())))
302 return failure();
303 os << ")";
304 return success();
305 }
306
printOperation(CppEmitter & emitter,emitc::CallOp callOp)307 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
308 raw_ostream &os = emitter.ostream();
309 Operation &op = *callOp.getOperation();
310
311 if (failed(emitter.emitAssignPrefix(op)))
312 return failure();
313 os << callOp.callee();
314
315 auto emitArgs = [&](Attribute attr) -> LogicalResult {
316 if (auto t = attr.dyn_cast<IntegerAttr>()) {
317 // Index attributes are treated specially as operand index.
318 if (t.getType().isIndex()) {
319 int64_t idx = t.getInt();
320 if ((idx < 0) || (idx >= op.getNumOperands()))
321 return op.emitOpError("invalid operand index");
322 if (!emitter.hasValueInScope(op.getOperand(idx)))
323 return op.emitOpError("operand ")
324 << idx << "'s value not defined in scope";
325 os << emitter.getOrCreateName(op.getOperand(idx));
326 return success();
327 }
328 }
329 if (failed(emitter.emitAttribute(op.getLoc(), attr)))
330 return failure();
331
332 return success();
333 };
334
335 if (callOp.template_args()) {
336 os << "<";
337 if (failed(interleaveCommaWithError(*callOp.template_args(), os, emitArgs)))
338 return failure();
339 os << ">";
340 }
341
342 os << "(";
343
344 LogicalResult emittedArgs =
345 callOp.args() ? interleaveCommaWithError(*callOp.args(), os, emitArgs)
346 : emitter.emitOperands(op);
347 if (failed(emittedArgs))
348 return failure();
349 os << ")";
350 return success();
351 }
352
printOperation(CppEmitter & emitter,emitc::ApplyOp applyOp)353 static LogicalResult printOperation(CppEmitter &emitter,
354 emitc::ApplyOp applyOp) {
355 raw_ostream &os = emitter.ostream();
356 Operation &op = *applyOp.getOperation();
357
358 if (failed(emitter.emitAssignPrefix(op)))
359 return failure();
360 os << applyOp.applicableOperator();
361 os << emitter.getOrCreateName(applyOp.getOperand());
362
363 return success();
364 }
365
printOperation(CppEmitter & emitter,emitc::IncludeOp includeOp)366 static LogicalResult printOperation(CppEmitter &emitter,
367 emitc::IncludeOp includeOp) {
368 raw_ostream &os = emitter.ostream();
369
370 os << "#include ";
371 if (includeOp.is_standard_include())
372 os << "<" << includeOp.include() << ">";
373 else
374 os << "\"" << includeOp.include() << "\"";
375
376 return success();
377 }
378
printOperation(CppEmitter & emitter,scf::ForOp forOp)379 static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
380
381 raw_indented_ostream &os = emitter.ostream();
382
383 OperandRange operands = forOp.getIterOperands();
384 Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
385 Operation::result_range results = forOp.getResults();
386
387 if (!emitter.shouldDeclareVariablesAtTop()) {
388 for (OpResult result : results) {
389 if (failed(emitter.emitVariableDeclaration(result,
390 /*trailingSemicolon=*/true)))
391 return failure();
392 }
393 }
394
395 for (auto pair : llvm::zip(iterArgs, operands)) {
396 if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType())))
397 return failure();
398 os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = ";
399 os << emitter.getOrCreateName(std::get<1>(pair)) << ";";
400 os << "\n";
401 }
402
403 os << "for (";
404 if (failed(
405 emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
406 return failure();
407 os << " ";
408 os << emitter.getOrCreateName(forOp.getInductionVar());
409 os << " = ";
410 os << emitter.getOrCreateName(forOp.lowerBound());
411 os << "; ";
412 os << emitter.getOrCreateName(forOp.getInductionVar());
413 os << " < ";
414 os << emitter.getOrCreateName(forOp.upperBound());
415 os << "; ";
416 os << emitter.getOrCreateName(forOp.getInductionVar());
417 os << " += ";
418 os << emitter.getOrCreateName(forOp.step());
419 os << ") {\n";
420 os.indent();
421
422 Region &forRegion = forOp.region();
423 auto regionOps = forRegion.getOps();
424
425 // We skip the trailing yield op because this updates the result variables
426 // of the for op in the generated code. Instead we update the iterArgs at
427 // the end of a loop iteration and set the result variables after the for
428 // loop.
429 for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
430 if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
431 return failure();
432 }
433
434 Operation *yieldOp = forRegion.getBlocks().front().getTerminator();
435 // Copy yield operands into iterArgs at the end of a loop iteration.
436 for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) {
437 BlockArgument iterArg = std::get<0>(pair);
438 Value operand = std::get<1>(pair);
439 os << emitter.getOrCreateName(iterArg) << " = "
440 << emitter.getOrCreateName(operand) << ";\n";
441 }
442
443 os.unindent() << "}";
444
445 // Copy iterArgs into results after the for loop.
446 for (auto pair : llvm::zip(results, iterArgs)) {
447 OpResult result = std::get<0>(pair);
448 BlockArgument iterArg = std::get<1>(pair);
449 os << "\n"
450 << emitter.getOrCreateName(result) << " = "
451 << emitter.getOrCreateName(iterArg) << ";";
452 }
453
454 return success();
455 }
456
printOperation(CppEmitter & emitter,scf::IfOp ifOp)457 static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) {
458 raw_indented_ostream &os = emitter.ostream();
459
460 if (!emitter.shouldDeclareVariablesAtTop()) {
461 for (OpResult result : ifOp.getResults()) {
462 if (failed(emitter.emitVariableDeclaration(result,
463 /*trailingSemicolon=*/true)))
464 return failure();
465 }
466 }
467
468 os << "if (";
469 if (failed(emitter.emitOperands(*ifOp.getOperation())))
470 return failure();
471 os << ") {\n";
472 os.indent();
473
474 Region &thenRegion = ifOp.thenRegion();
475 for (Operation &op : thenRegion.getOps()) {
476 // Note: This prints a superfluous semicolon if the terminating yield op has
477 // zero results.
478 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
479 return failure();
480 }
481
482 os.unindent() << "}";
483
484 Region &elseRegion = ifOp.elseRegion();
485 if (!elseRegion.empty()) {
486 os << " else {\n";
487 os.indent();
488
489 for (Operation &op : elseRegion.getOps()) {
490 // Note: This prints a superfluous semicolon if the terminating yield op
491 // has zero results.
492 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
493 return failure();
494 }
495
496 os.unindent() << "}";
497 }
498
499 return success();
500 }
501
printOperation(CppEmitter & emitter,scf::YieldOp yieldOp)502 static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) {
503 raw_ostream &os = emitter.ostream();
504 Operation &parentOp = *yieldOp.getOperation()->getParentOp();
505
506 if (yieldOp.getNumOperands() != parentOp.getNumResults()) {
507 return yieldOp.emitError("number of operands does not to match the number "
508 "of the parent op's results");
509 }
510
511 if (failed(interleaveWithError(
512 llvm::zip(parentOp.getResults(), yieldOp.getOperands()),
513 [&](auto pair) -> LogicalResult {
514 auto result = std::get<0>(pair);
515 auto operand = std::get<1>(pair);
516 os << emitter.getOrCreateName(result) << " = ";
517
518 if (!emitter.hasValueInScope(operand))
519 return yieldOp.emitError("operand value not in scope");
520 os << emitter.getOrCreateName(operand);
521 return success();
522 },
523 [&]() { os << ";\n"; })))
524 return failure();
525
526 return success();
527 }
528
printOperation(CppEmitter & emitter,ReturnOp returnOp)529 static LogicalResult printOperation(CppEmitter &emitter, ReturnOp returnOp) {
530 raw_ostream &os = emitter.ostream();
531 os << "return";
532 switch (returnOp.getNumOperands()) {
533 case 0:
534 return success();
535 case 1:
536 os << " " << emitter.getOrCreateName(returnOp.getOperand(0));
537 return success(emitter.hasValueInScope(returnOp.getOperand(0)));
538 default:
539 os << " std::make_tuple(";
540 if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
541 return failure();
542 os << ")";
543 return success();
544 }
545 }
546
printOperation(CppEmitter & emitter,ModuleOp moduleOp)547 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
548 CppEmitter::Scope scope(emitter);
549
550 for (Operation &op : moduleOp) {
551 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
552 return failure();
553 }
554 return success();
555 }
556
printOperation(CppEmitter & emitter,FuncOp functionOp)557 static LogicalResult printOperation(CppEmitter &emitter, FuncOp functionOp) {
558 // We need to declare variables at top if the function has multiple blocks.
559 if (!emitter.shouldDeclareVariablesAtTop() &&
560 functionOp.getBlocks().size() > 1) {
561 return functionOp.emitOpError(
562 "with multiple blocks needs variables declared at top");
563 }
564
565 CppEmitter::Scope scope(emitter);
566 raw_indented_ostream &os = emitter.ostream();
567 if (failed(emitter.emitTypes(functionOp.getLoc(),
568 functionOp.getType().getResults())))
569 return failure();
570 os << " " << functionOp.getName();
571
572 os << "(";
573 if (failed(interleaveCommaWithError(
574 functionOp.getArguments(), os,
575 [&](BlockArgument arg) -> LogicalResult {
576 if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
577 return failure();
578 os << " " << emitter.getOrCreateName(arg);
579 return success();
580 })))
581 return failure();
582 os << ") {\n";
583 os.indent();
584 if (emitter.shouldDeclareVariablesAtTop()) {
585 // Declare all variables that hold op results including those from nested
586 // regions.
587 WalkResult result =
588 functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
589 for (OpResult result : op->getResults()) {
590 if (failed(emitter.emitVariableDeclaration(
591 result, /*trailingSemicolon=*/true))) {
592 return WalkResult(
593 op->emitError("unable to declare result variable for op"));
594 }
595 }
596 return WalkResult::advance();
597 });
598 if (result.wasInterrupted())
599 return failure();
600 }
601
602 Region::BlockListType &blocks = functionOp.getBlocks();
603 // Create label names for basic blocks.
604 for (Block &block : blocks) {
605 emitter.getOrCreateName(block);
606 }
607
608 // Declare variables for basic block arguments.
609 for (auto it = std::next(blocks.begin()); it != blocks.end(); ++it) {
610 Block &block = *it;
611 for (BlockArgument &arg : block.getArguments()) {
612 if (emitter.hasValueInScope(arg))
613 return functionOp.emitOpError(" block argument #")
614 << arg.getArgNumber() << " is out of scope";
615 if (failed(
616 emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
617 return failure();
618 }
619 os << " " << emitter.getOrCreateName(arg) << ";\n";
620 }
621 }
622
623 for (Block &block : blocks) {
624 // Only print a label if there is more than one block.
625 if (blocks.size() > 1) {
626 if (failed(emitter.emitLabel(block)))
627 return failure();
628 }
629 for (Operation &op : block.getOperations()) {
630 // When generating code for an scf.if or std.cond_br op no semicolon needs
631 // to be printed after the closing brace.
632 // When generating code for an scf.for op, printing a trailing semicolon
633 // is handled within the printOperation function.
634 bool trailingSemicolon = !isa<scf::IfOp, scf::ForOp, CondBranchOp>(op);
635
636 if (failed(emitter.emitOperation(
637 op, /*trailingSemicolon=*/trailingSemicolon)))
638 return failure();
639 }
640 }
641 os.unindent() << "}\n";
642 return success();
643 }
644
CppEmitter(raw_ostream & os,bool declareVariablesAtTop)645 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
646 : os(os), declareVariablesAtTop(declareVariablesAtTop) {
647 valueInScopeCount.push(0);
648 labelInScopeCount.push(0);
649 }
650
651 /// Return the existing or a new name for a Value.
getOrCreateName(Value val)652 StringRef CppEmitter::getOrCreateName(Value val) {
653 if (!valueMapper.count(val))
654 valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
655 return *valueMapper.begin(val);
656 }
657
658 /// Return the existing or a new label for a Block.
getOrCreateName(Block & block)659 StringRef CppEmitter::getOrCreateName(Block &block) {
660 if (!blockMapper.count(&block))
661 blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
662 return *blockMapper.begin(&block);
663 }
664
shouldMapToUnsigned(IntegerType::SignednessSemantics val)665 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
666 switch (val) {
667 case IntegerType::Signless:
668 return false;
669 case IntegerType::Signed:
670 return false;
671 case IntegerType::Unsigned:
672 return true;
673 }
674 llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
675 }
676
hasValueInScope(Value val)677 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); }
678
hasBlockLabel(Block & block)679 bool CppEmitter::hasBlockLabel(Block &block) {
680 return blockMapper.count(&block);
681 }
682
emitAttribute(Location loc,Attribute attr)683 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
684 auto printInt = [&](APInt val, bool isUnsigned) {
685 if (val.getBitWidth() == 1) {
686 if (val.getBoolValue())
687 os << "true";
688 else
689 os << "false";
690 } else {
691 SmallString<128> strValue;
692 val.toString(strValue, 10, !isUnsigned, false);
693 os << strValue;
694 }
695 };
696
697 auto printFloat = [&](APFloat val) {
698 if (val.isFinite()) {
699 SmallString<128> strValue;
700 // Use default values of toString except don't truncate zeros.
701 val.toString(strValue, 0, 0, false);
702 switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
703 case llvm::APFloatBase::S_IEEEsingle:
704 os << "(float)";
705 break;
706 case llvm::APFloatBase::S_IEEEdouble:
707 os << "(double)";
708 break;
709 default:
710 break;
711 };
712 os << strValue;
713 } else if (val.isNaN()) {
714 os << "NAN";
715 } else if (val.isInfinity()) {
716 if (val.isNegative())
717 os << "-";
718 os << "INFINITY";
719 }
720 };
721
722 // Print floating point attributes.
723 if (auto fAttr = attr.dyn_cast<FloatAttr>()) {
724 printFloat(fAttr.getValue());
725 return success();
726 }
727 if (auto dense = attr.dyn_cast<DenseFPElementsAttr>()) {
728 os << '{';
729 interleaveComma(dense, os, [&](APFloat val) { printFloat(val); });
730 os << '}';
731 return success();
732 }
733
734 // Print integer attributes.
735 if (auto iAttr = attr.dyn_cast<IntegerAttr>()) {
736 if (auto iType = iAttr.getType().dyn_cast<IntegerType>()) {
737 printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
738 return success();
739 }
740 if (auto iType = iAttr.getType().dyn_cast<IndexType>()) {
741 printInt(iAttr.getValue(), false);
742 return success();
743 }
744 }
745 if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) {
746 if (auto iType = dense.getType()
747 .cast<TensorType>()
748 .getElementType()
749 .dyn_cast<IntegerType>()) {
750 os << '{';
751 interleaveComma(dense, os, [&](APInt val) {
752 printInt(val, shouldMapToUnsigned(iType.getSignedness()));
753 });
754 os << '}';
755 return success();
756 }
757 if (auto iType = dense.getType()
758 .cast<TensorType>()
759 .getElementType()
760 .dyn_cast<IndexType>()) {
761 os << '{';
762 interleaveComma(dense, os, [&](APInt val) { printInt(val, false); });
763 os << '}';
764 return success();
765 }
766 }
767
768 // Print opaque attributes.
769 if (auto oAttr = attr.dyn_cast<emitc::OpaqueAttr>()) {
770 os << oAttr.getValue();
771 return success();
772 }
773
774 // Print symbolic reference attributes.
775 if (auto sAttr = attr.dyn_cast<SymbolRefAttr>()) {
776 if (sAttr.getNestedReferences().size() > 1)
777 return emitError(loc, "attribute has more than 1 nested reference");
778 os << sAttr.getRootReference().getValue();
779 return success();
780 }
781
782 // Print type attributes.
783 if (auto type = attr.dyn_cast<TypeAttr>())
784 return emitType(loc, type.getValue());
785
786 return emitError(loc, "cannot emit attribute of type ") << attr.getType();
787 }
788
emitOperands(Operation & op)789 LogicalResult CppEmitter::emitOperands(Operation &op) {
790 auto emitOperandName = [&](Value result) -> LogicalResult {
791 if (!hasValueInScope(result))
792 return op.emitOpError() << "operand value not in scope";
793 os << getOrCreateName(result);
794 return success();
795 };
796 return interleaveCommaWithError(op.getOperands(), os, emitOperandName);
797 }
798
799 LogicalResult
emitOperandsAndAttributes(Operation & op,ArrayRef<StringRef> exclude)800 CppEmitter::emitOperandsAndAttributes(Operation &op,
801 ArrayRef<StringRef> exclude) {
802 if (failed(emitOperands(op)))
803 return failure();
804 // Insert comma in between operands and non-filtered attributes if needed.
805 if (op.getNumOperands() > 0) {
806 for (NamedAttribute attr : op.getAttrs()) {
807 if (!llvm::is_contained(exclude, attr.first.strref())) {
808 os << ", ";
809 break;
810 }
811 }
812 }
813 // Emit attributes.
814 auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
815 if (llvm::is_contained(exclude, attr.first.strref()))
816 return success();
817 os << "/* " << attr.first << " */";
818 if (failed(emitAttribute(op.getLoc(), attr.second)))
819 return failure();
820 return success();
821 };
822 return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
823 }
824
emitVariableAssignment(OpResult result)825 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
826 if (!hasValueInScope(result)) {
827 return result.getDefiningOp()->emitOpError(
828 "result variable for the operation has not been declared");
829 }
830 os << getOrCreateName(result) << " = ";
831 return success();
832 }
833
emitVariableDeclaration(OpResult result,bool trailingSemicolon)834 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
835 bool trailingSemicolon) {
836 if (hasValueInScope(result)) {
837 return result.getDefiningOp()->emitError(
838 "result variable for the operation already declared");
839 }
840 if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
841 return failure();
842 os << " " << getOrCreateName(result);
843 if (trailingSemicolon)
844 os << ";\n";
845 return success();
846 }
847
emitAssignPrefix(Operation & op)848 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
849 switch (op.getNumResults()) {
850 case 0:
851 break;
852 case 1: {
853 OpResult result = op.getResult(0);
854 if (shouldDeclareVariablesAtTop()) {
855 if (failed(emitVariableAssignment(result)))
856 return failure();
857 } else {
858 if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
859 return failure();
860 os << " = ";
861 }
862 break;
863 }
864 default:
865 if (!shouldDeclareVariablesAtTop()) {
866 for (OpResult result : op.getResults()) {
867 if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
868 return failure();
869 }
870 }
871 os << "std::tie(";
872 interleaveComma(op.getResults(), os,
873 [&](Value result) { os << getOrCreateName(result); });
874 os << ") = ";
875 }
876 return success();
877 }
878
emitLabel(Block & block)879 LogicalResult CppEmitter::emitLabel(Block &block) {
880 if (!hasBlockLabel(block))
881 return block.getParentOp()->emitError("label for block not found");
882 // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
883 // label instead of using `getOStream`.
884 os.getOStream() << getOrCreateName(block) << ":\n";
885 return success();
886 }
887
emitOperation(Operation & op,bool trailingSemicolon)888 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
889 LogicalResult status =
890 llvm::TypeSwitch<Operation *, LogicalResult>(&op)
891 // EmitC ops.
892 .Case<emitc::ApplyOp, emitc::CallOp, emitc::ConstantOp,
893 emitc::IncludeOp>(
894 [&](auto op) { return printOperation(*this, op); })
895 // SCF ops.
896 .Case<scf::ForOp, scf::IfOp, scf::YieldOp>(
897 [&](auto op) { return printOperation(*this, op); })
898 // Standard ops.
899 .Case<BranchOp, mlir::CallOp, CondBranchOp, mlir::ConstantOp, FuncOp,
900 ModuleOp, ReturnOp>(
901 [&](auto op) { return printOperation(*this, op); })
902 .Default([&](Operation *) {
903 return op.emitOpError("unable to find printer for op");
904 });
905
906 if (failed(status))
907 return failure();
908 os << (trailingSemicolon ? ";\n" : "\n");
909 return success();
910 }
911
emitType(Location loc,Type type)912 LogicalResult CppEmitter::emitType(Location loc, Type type) {
913 if (auto iType = type.dyn_cast<IntegerType>()) {
914 switch (iType.getWidth()) {
915 case 1:
916 return (os << "bool"), success();
917 case 8:
918 case 16:
919 case 32:
920 case 64:
921 if (shouldMapToUnsigned(iType.getSignedness()))
922 return (os << "uint" << iType.getWidth() << "_t"), success();
923 else
924 return (os << "int" << iType.getWidth() << "_t"), success();
925 default:
926 return emitError(loc, "cannot emit integer type ") << type;
927 }
928 }
929 if (auto fType = type.dyn_cast<FloatType>()) {
930 switch (fType.getWidth()) {
931 case 32:
932 return (os << "float"), success();
933 case 64:
934 return (os << "double"), success();
935 default:
936 return emitError(loc, "cannot emit float type ") << type;
937 }
938 }
939 if (auto iType = type.dyn_cast<IndexType>())
940 return (os << "size_t"), success();
941 if (auto tType = type.dyn_cast<TensorType>()) {
942 if (!tType.hasRank())
943 return emitError(loc, "cannot emit unranked tensor type");
944 if (!tType.hasStaticShape())
945 return emitError(loc, "cannot emit tensor type with non static shape");
946 os << "Tensor<";
947 if (failed(emitType(loc, tType.getElementType())))
948 return failure();
949 auto shape = tType.getShape();
950 for (auto dimSize : shape) {
951 os << ", ";
952 os << dimSize;
953 }
954 os << ">";
955 return success();
956 }
957 if (auto tType = type.dyn_cast<TupleType>())
958 return emitTupleType(loc, tType.getTypes());
959 if (auto oType = type.dyn_cast<emitc::OpaqueType>()) {
960 os << oType.getValue();
961 return success();
962 }
963 return emitError(loc, "cannot emit type ") << type;
964 }
965
emitTypes(Location loc,ArrayRef<Type> types)966 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
967 switch (types.size()) {
968 case 0:
969 os << "void";
970 return success();
971 case 1:
972 return emitType(loc, types.front());
973 default:
974 return emitTupleType(loc, types);
975 }
976 }
977
emitTupleType(Location loc,ArrayRef<Type> types)978 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
979 os << "std::tuple<";
980 if (failed(interleaveCommaWithError(
981 types, os, [&](Type type) { return emitType(loc, type); })))
982 return failure();
983 os << ">";
984 return success();
985 }
986
translateToCpp(Operation * op,raw_ostream & os,bool declareVariablesAtTop)987 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
988 bool declareVariablesAtTop) {
989 CppEmitter emitter(os, declareVariablesAtTop);
990 return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
991 }
992