1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/EmitC/IR/EmitC.h"
10 #include "mlir/Dialect/SCF/SCF.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Support/IndentedOstream.h"
17 #include "mlir/Target/Cpp/CppEmitter.h"
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/StringMap.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/FormatVariadic.h"
24 
25 #define DEBUG_TYPE "translate-to-cpp"
26 
27 using namespace mlir;
28 using namespace mlir::emitc;
29 using llvm::formatv;
30 
31 /// Convenience functions to produce interleaved output with functions returning
32 /// a LogicalResult. This is different than those in STLExtras as functions used
33 /// on each element doesn't return a string.
34 template <typename ForwardIterator, typename UnaryFunctor,
35           typename NullaryFunctor>
36 inline LogicalResult
interleaveWithError(ForwardIterator begin,ForwardIterator end,UnaryFunctor eachFn,NullaryFunctor betweenFn)37 interleaveWithError(ForwardIterator begin, ForwardIterator end,
38                     UnaryFunctor eachFn, NullaryFunctor betweenFn) {
39   if (begin == end)
40     return success();
41   if (failed(eachFn(*begin)))
42     return failure();
43   ++begin;
44   for (; begin != end; ++begin) {
45     betweenFn();
46     if (failed(eachFn(*begin)))
47       return failure();
48   }
49   return success();
50 }
51 
52 template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
interleaveWithError(const Container & c,UnaryFunctor eachFn,NullaryFunctor betweenFn)53 inline LogicalResult interleaveWithError(const Container &c,
54                                          UnaryFunctor eachFn,
55                                          NullaryFunctor betweenFn) {
56   return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
57 }
58 
59 template <typename Container, typename UnaryFunctor>
interleaveCommaWithError(const Container & c,raw_ostream & os,UnaryFunctor eachFn)60 inline LogicalResult interleaveCommaWithError(const Container &c,
61                                               raw_ostream &os,
62                                               UnaryFunctor eachFn) {
63   return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
64 }
65 
66 namespace {
67 /// Emitter that uses dialect specific emitters to emit C++ code.
68 struct CppEmitter {
69   explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
70 
71   /// Emits attribute or returns failure.
72   LogicalResult emitAttribute(Location loc, Attribute attr);
73 
74   /// Emits operation 'op' with/without training semicolon or returns failure.
75   LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
76 
77   /// Emits type 'type' or returns failure.
78   LogicalResult emitType(Location loc, Type type);
79 
80   /// Emits array of types as a std::tuple of the emitted types.
81   /// - emits void for an empty array;
82   /// - emits the type of the only element for arrays of size one;
83   /// - emits a std::tuple otherwise;
84   LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
85 
86   /// Emits array of types as a std::tuple of the emitted types independently of
87   /// the array size.
88   LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
89 
90   /// Emits an assignment for a variable which has been declared previously.
91   LogicalResult emitVariableAssignment(OpResult result);
92 
93   /// Emits a variable declaration for a result of an operation.
94   LogicalResult emitVariableDeclaration(OpResult result,
95                                         bool trailingSemicolon);
96 
97   /// Emits the variable declaration and assignment prefix for 'op'.
98   /// - emits separate variable followed by std::tie for multi-valued operation;
99   /// - emits single type followed by variable for single result;
100   /// - emits nothing if no value produced by op;
101   /// Emits final '=' operator where a type is produced. Returns failure if
102   /// any result type could not be converted.
103   LogicalResult emitAssignPrefix(Operation &op);
104 
105   /// Emits a label for the block.
106   LogicalResult emitLabel(Block &block);
107 
108   /// Emits the operands and atttributes of the operation. All operands are
109   /// emitted first and then all attributes in alphabetical order.
110   LogicalResult emitOperandsAndAttributes(Operation &op,
111                                           ArrayRef<StringRef> exclude = {});
112 
113   /// Emits the operands of the operation. All operands are emitted in order.
114   LogicalResult emitOperands(Operation &op);
115 
116   /// Return the existing or a new name for a Value.
117   StringRef getOrCreateName(Value val);
118 
119   /// Return the existing or a new label of a Block.
120   StringRef getOrCreateName(Block &block);
121 
122   /// Whether to map an mlir integer to a unsigned integer in C++.
123   bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
124 
125   /// RAII helper function to manage entering/exiting C++ scopes.
126   struct Scope {
Scope__anon45c64d1e0211::CppEmitter::Scope127     Scope(CppEmitter &emitter)
128         : valueMapperScope(emitter.valueMapper),
129           blockMapperScope(emitter.blockMapper), emitter(emitter) {
130       emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
131       emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
132     }
~Scope__anon45c64d1e0211::CppEmitter::Scope133     ~Scope() {
134       emitter.valueInScopeCount.pop();
135       emitter.labelInScopeCount.pop();
136     }
137 
138   private:
139     llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
140     llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
141     CppEmitter &emitter;
142   };
143 
144   /// Returns wether the Value is assigned to a C++ variable in the scope.
145   bool hasValueInScope(Value val);
146 
147   // Returns whether a label is assigned to the block.
148   bool hasBlockLabel(Block &block);
149 
150   /// Returns the output stream.
ostream__anon45c64d1e0211::CppEmitter151   raw_indented_ostream &ostream() { return os; };
152 
153   /// Returns if all variables for op results and basic block arguments need to
154   /// be declared at the beginning of a function.
shouldDeclareVariablesAtTop__anon45c64d1e0211::CppEmitter155   bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
156 
157 private:
158   using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
159   using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
160 
161   /// Output stream to emit to.
162   raw_indented_ostream os;
163 
164   /// Boolean to enforce that all variables for op results and block
165   /// arguments are declared at the beginning of the function. This also
166   /// includes results from ops located in nested regions.
167   bool declareVariablesAtTop;
168 
169   /// Map from value to name of C++ variable that contain the name.
170   ValueMapper valueMapper;
171 
172   /// Map from block to name of C++ label.
173   BlockMapper blockMapper;
174 
175   /// The number of values in the current scope. This is used to declare the
176   /// names of values in a scope.
177   std::stack<int64_t> valueInScopeCount;
178   std::stack<int64_t> labelInScopeCount;
179 };
180 } // namespace
181 
printConstantOp(CppEmitter & emitter,Operation * operation,Attribute value)182 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
183                                      Attribute value) {
184   OpResult result = operation->getResult(0);
185 
186   // Only emit an assignment as the variable was already declared when printing
187   // the FuncOp.
188   if (emitter.shouldDeclareVariablesAtTop()) {
189     // Skip the assignment if the emitc.constant has no value.
190     if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
191       if (oAttr.getValue().empty())
192         return success();
193     }
194 
195     if (failed(emitter.emitVariableAssignment(result)))
196       return failure();
197     return emitter.emitAttribute(operation->getLoc(), value);
198   }
199 
200   // Emit a variable declaration for an emitc.constant op without value.
201   if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
202     if (oAttr.getValue().empty())
203       // The semicolon gets printed by the emitOperation function.
204       return emitter.emitVariableDeclaration(result,
205                                              /*trailingSemicolon=*/false);
206   }
207 
208   // Emit a variable declaration.
209   if (failed(emitter.emitAssignPrefix(*operation)))
210     return failure();
211   return emitter.emitAttribute(operation->getLoc(), value);
212 }
213 
printOperation(CppEmitter & emitter,emitc::ConstantOp constantOp)214 static LogicalResult printOperation(CppEmitter &emitter,
215                                     emitc::ConstantOp constantOp) {
216   Operation *operation = constantOp.getOperation();
217   Attribute value = constantOp.value();
218 
219   return printConstantOp(emitter, operation, value);
220 }
221 
printOperation(CppEmitter & emitter,mlir::ConstantOp constantOp)222 static LogicalResult printOperation(CppEmitter &emitter,
223                                     mlir::ConstantOp constantOp) {
224   Operation *operation = constantOp.getOperation();
225   Attribute value = constantOp.value();
226 
227   return printConstantOp(emitter, operation, value);
228 }
229 
printOperation(CppEmitter & emitter,BranchOp branchOp)230 static LogicalResult printOperation(CppEmitter &emitter, BranchOp branchOp) {
231   raw_ostream &os = emitter.ostream();
232   Block &successor = *branchOp.getSuccessor();
233 
234   for (auto pair :
235        llvm::zip(branchOp.getOperands(), successor.getArguments())) {
236     Value &operand = std::get<0>(pair);
237     BlockArgument &argument = std::get<1>(pair);
238     os << emitter.getOrCreateName(argument) << " = "
239        << emitter.getOrCreateName(operand) << ";\n";
240   }
241 
242   os << "goto ";
243   if (!(emitter.hasBlockLabel(successor)))
244     return branchOp.emitOpError("unable to find label for successor block");
245   os << emitter.getOrCreateName(successor);
246   return success();
247 }
248 
printOperation(CppEmitter & emitter,CondBranchOp condBranchOp)249 static LogicalResult printOperation(CppEmitter &emitter,
250                                     CondBranchOp condBranchOp) {
251   raw_indented_ostream &os = emitter.ostream();
252   Block &trueSuccessor = *condBranchOp.getTrueDest();
253   Block &falseSuccessor = *condBranchOp.getFalseDest();
254 
255   os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
256      << ") {\n";
257 
258   os.indent();
259 
260   // If condition is true.
261   for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
262                              trueSuccessor.getArguments())) {
263     Value &operand = std::get<0>(pair);
264     BlockArgument &argument = std::get<1>(pair);
265     os << emitter.getOrCreateName(argument) << " = "
266        << emitter.getOrCreateName(operand) << ";\n";
267   }
268 
269   os << "goto ";
270   if (!(emitter.hasBlockLabel(trueSuccessor))) {
271     return condBranchOp.emitOpError("unable to find label for successor block");
272   }
273   os << emitter.getOrCreateName(trueSuccessor) << ";\n";
274   os.unindent() << "} else {\n";
275   os.indent();
276   // If condition is false.
277   for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
278                              falseSuccessor.getArguments())) {
279     Value &operand = std::get<0>(pair);
280     BlockArgument &argument = std::get<1>(pair);
281     os << emitter.getOrCreateName(argument) << " = "
282        << emitter.getOrCreateName(operand) << ";\n";
283   }
284 
285   os << "goto ";
286   if (!(emitter.hasBlockLabel(falseSuccessor))) {
287     return condBranchOp.emitOpError()
288            << "unable to find label for successor block";
289   }
290   os << emitter.getOrCreateName(falseSuccessor) << ";\n";
291   os.unindent() << "}";
292   return success();
293 }
294 
printOperation(CppEmitter & emitter,mlir::CallOp callOp)295 static LogicalResult printOperation(CppEmitter &emitter, mlir::CallOp callOp) {
296   if (failed(emitter.emitAssignPrefix(*callOp.getOperation())))
297     return failure();
298 
299   raw_ostream &os = emitter.ostream();
300   os << callOp.getCallee() << "(";
301   if (failed(emitter.emitOperands(*callOp.getOperation())))
302     return failure();
303   os << ")";
304   return success();
305 }
306 
printOperation(CppEmitter & emitter,emitc::CallOp callOp)307 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
308   raw_ostream &os = emitter.ostream();
309   Operation &op = *callOp.getOperation();
310 
311   if (failed(emitter.emitAssignPrefix(op)))
312     return failure();
313   os << callOp.callee();
314 
315   auto emitArgs = [&](Attribute attr) -> LogicalResult {
316     if (auto t = attr.dyn_cast<IntegerAttr>()) {
317       // Index attributes are treated specially as operand index.
318       if (t.getType().isIndex()) {
319         int64_t idx = t.getInt();
320         if ((idx < 0) || (idx >= op.getNumOperands()))
321           return op.emitOpError("invalid operand index");
322         if (!emitter.hasValueInScope(op.getOperand(idx)))
323           return op.emitOpError("operand ")
324                  << idx << "'s value not defined in scope";
325         os << emitter.getOrCreateName(op.getOperand(idx));
326         return success();
327       }
328     }
329     if (failed(emitter.emitAttribute(op.getLoc(), attr)))
330       return failure();
331 
332     return success();
333   };
334 
335   if (callOp.template_args()) {
336     os << "<";
337     if (failed(interleaveCommaWithError(*callOp.template_args(), os, emitArgs)))
338       return failure();
339     os << ">";
340   }
341 
342   os << "(";
343 
344   LogicalResult emittedArgs =
345       callOp.args() ? interleaveCommaWithError(*callOp.args(), os, emitArgs)
346                     : emitter.emitOperands(op);
347   if (failed(emittedArgs))
348     return failure();
349   os << ")";
350   return success();
351 }
352 
printOperation(CppEmitter & emitter,emitc::ApplyOp applyOp)353 static LogicalResult printOperation(CppEmitter &emitter,
354                                     emitc::ApplyOp applyOp) {
355   raw_ostream &os = emitter.ostream();
356   Operation &op = *applyOp.getOperation();
357 
358   if (failed(emitter.emitAssignPrefix(op)))
359     return failure();
360   os << applyOp.applicableOperator();
361   os << emitter.getOrCreateName(applyOp.getOperand());
362 
363   return success();
364 }
365 
printOperation(CppEmitter & emitter,emitc::IncludeOp includeOp)366 static LogicalResult printOperation(CppEmitter &emitter,
367                                     emitc::IncludeOp includeOp) {
368   raw_ostream &os = emitter.ostream();
369 
370   os << "#include ";
371   if (includeOp.is_standard_include())
372     os << "<" << includeOp.include() << ">";
373   else
374     os << "\"" << includeOp.include() << "\"";
375 
376   return success();
377 }
378 
printOperation(CppEmitter & emitter,scf::ForOp forOp)379 static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
380 
381   raw_indented_ostream &os = emitter.ostream();
382 
383   OperandRange operands = forOp.getIterOperands();
384   Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
385   Operation::result_range results = forOp.getResults();
386 
387   if (!emitter.shouldDeclareVariablesAtTop()) {
388     for (OpResult result : results) {
389       if (failed(emitter.emitVariableDeclaration(result,
390                                                  /*trailingSemicolon=*/true)))
391         return failure();
392     }
393   }
394 
395   for (auto pair : llvm::zip(iterArgs, operands)) {
396     if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType())))
397       return failure();
398     os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = ";
399     os << emitter.getOrCreateName(std::get<1>(pair)) << ";";
400     os << "\n";
401   }
402 
403   os << "for (";
404   if (failed(
405           emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
406     return failure();
407   os << " ";
408   os << emitter.getOrCreateName(forOp.getInductionVar());
409   os << " = ";
410   os << emitter.getOrCreateName(forOp.lowerBound());
411   os << "; ";
412   os << emitter.getOrCreateName(forOp.getInductionVar());
413   os << " < ";
414   os << emitter.getOrCreateName(forOp.upperBound());
415   os << "; ";
416   os << emitter.getOrCreateName(forOp.getInductionVar());
417   os << " += ";
418   os << emitter.getOrCreateName(forOp.step());
419   os << ") {\n";
420   os.indent();
421 
422   Region &forRegion = forOp.region();
423   auto regionOps = forRegion.getOps();
424 
425   // We skip the trailing yield op because this updates the result variables
426   // of the for op in the generated code. Instead we update the iterArgs at
427   // the end of a loop iteration and set the result variables after the for
428   // loop.
429   for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
430     if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
431       return failure();
432   }
433 
434   Operation *yieldOp = forRegion.getBlocks().front().getTerminator();
435   // Copy yield operands into iterArgs at the end of a loop iteration.
436   for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) {
437     BlockArgument iterArg = std::get<0>(pair);
438     Value operand = std::get<1>(pair);
439     os << emitter.getOrCreateName(iterArg) << " = "
440        << emitter.getOrCreateName(operand) << ";\n";
441   }
442 
443   os.unindent() << "}";
444 
445   // Copy iterArgs into results after the for loop.
446   for (auto pair : llvm::zip(results, iterArgs)) {
447     OpResult result = std::get<0>(pair);
448     BlockArgument iterArg = std::get<1>(pair);
449     os << "\n"
450        << emitter.getOrCreateName(result) << " = "
451        << emitter.getOrCreateName(iterArg) << ";";
452   }
453 
454   return success();
455 }
456 
printOperation(CppEmitter & emitter,scf::IfOp ifOp)457 static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) {
458   raw_indented_ostream &os = emitter.ostream();
459 
460   if (!emitter.shouldDeclareVariablesAtTop()) {
461     for (OpResult result : ifOp.getResults()) {
462       if (failed(emitter.emitVariableDeclaration(result,
463                                                  /*trailingSemicolon=*/true)))
464         return failure();
465     }
466   }
467 
468   os << "if (";
469   if (failed(emitter.emitOperands(*ifOp.getOperation())))
470     return failure();
471   os << ") {\n";
472   os.indent();
473 
474   Region &thenRegion = ifOp.thenRegion();
475   for (Operation &op : thenRegion.getOps()) {
476     // Note: This prints a superfluous semicolon if the terminating yield op has
477     // zero results.
478     if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
479       return failure();
480   }
481 
482   os.unindent() << "}";
483 
484   Region &elseRegion = ifOp.elseRegion();
485   if (!elseRegion.empty()) {
486     os << " else {\n";
487     os.indent();
488 
489     for (Operation &op : elseRegion.getOps()) {
490       // Note: This prints a superfluous semicolon if the terminating yield op
491       // has zero results.
492       if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
493         return failure();
494     }
495 
496     os.unindent() << "}";
497   }
498 
499   return success();
500 }
501 
printOperation(CppEmitter & emitter,scf::YieldOp yieldOp)502 static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) {
503   raw_ostream &os = emitter.ostream();
504   Operation &parentOp = *yieldOp.getOperation()->getParentOp();
505 
506   if (yieldOp.getNumOperands() != parentOp.getNumResults()) {
507     return yieldOp.emitError("number of operands does not to match the number "
508                              "of the parent op's results");
509   }
510 
511   if (failed(interleaveWithError(
512           llvm::zip(parentOp.getResults(), yieldOp.getOperands()),
513           [&](auto pair) -> LogicalResult {
514             auto result = std::get<0>(pair);
515             auto operand = std::get<1>(pair);
516             os << emitter.getOrCreateName(result) << " = ";
517 
518             if (!emitter.hasValueInScope(operand))
519               return yieldOp.emitError("operand value not in scope");
520             os << emitter.getOrCreateName(operand);
521             return success();
522           },
523           [&]() { os << ";\n"; })))
524     return failure();
525 
526   return success();
527 }
528 
printOperation(CppEmitter & emitter,ReturnOp returnOp)529 static LogicalResult printOperation(CppEmitter &emitter, ReturnOp returnOp) {
530   raw_ostream &os = emitter.ostream();
531   os << "return";
532   switch (returnOp.getNumOperands()) {
533   case 0:
534     return success();
535   case 1:
536     os << " " << emitter.getOrCreateName(returnOp.getOperand(0));
537     return success(emitter.hasValueInScope(returnOp.getOperand(0)));
538   default:
539     os << " std::make_tuple(";
540     if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
541       return failure();
542     os << ")";
543     return success();
544   }
545 }
546 
printOperation(CppEmitter & emitter,ModuleOp moduleOp)547 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
548   CppEmitter::Scope scope(emitter);
549 
550   for (Operation &op : moduleOp) {
551     if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
552       return failure();
553   }
554   return success();
555 }
556 
printOperation(CppEmitter & emitter,FuncOp functionOp)557 static LogicalResult printOperation(CppEmitter &emitter, FuncOp functionOp) {
558   // We need to declare variables at top if the function has multiple blocks.
559   if (!emitter.shouldDeclareVariablesAtTop() &&
560       functionOp.getBlocks().size() > 1) {
561     return functionOp.emitOpError(
562         "with multiple blocks needs variables declared at top");
563   }
564 
565   CppEmitter::Scope scope(emitter);
566   raw_indented_ostream &os = emitter.ostream();
567   if (failed(emitter.emitTypes(functionOp.getLoc(),
568                                functionOp.getType().getResults())))
569     return failure();
570   os << " " << functionOp.getName();
571 
572   os << "(";
573   if (failed(interleaveCommaWithError(
574           functionOp.getArguments(), os,
575           [&](BlockArgument arg) -> LogicalResult {
576             if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
577               return failure();
578             os << " " << emitter.getOrCreateName(arg);
579             return success();
580           })))
581     return failure();
582   os << ") {\n";
583   os.indent();
584   if (emitter.shouldDeclareVariablesAtTop()) {
585     // Declare all variables that hold op results including those from nested
586     // regions.
587     WalkResult result =
588         functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
589           for (OpResult result : op->getResults()) {
590             if (failed(emitter.emitVariableDeclaration(
591                     result, /*trailingSemicolon=*/true))) {
592               return WalkResult(
593                   op->emitError("unable to declare result variable for op"));
594             }
595           }
596           return WalkResult::advance();
597         });
598     if (result.wasInterrupted())
599       return failure();
600   }
601 
602   Region::BlockListType &blocks = functionOp.getBlocks();
603   // Create label names for basic blocks.
604   for (Block &block : blocks) {
605     emitter.getOrCreateName(block);
606   }
607 
608   // Declare variables for basic block arguments.
609   for (auto it = std::next(blocks.begin()); it != blocks.end(); ++it) {
610     Block &block = *it;
611     for (BlockArgument &arg : block.getArguments()) {
612       if (emitter.hasValueInScope(arg))
613         return functionOp.emitOpError(" block argument #")
614                << arg.getArgNumber() << " is out of scope";
615       if (failed(
616               emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
617         return failure();
618       }
619       os << " " << emitter.getOrCreateName(arg) << ";\n";
620     }
621   }
622 
623   for (Block &block : blocks) {
624     // Only print a label if there is more than one block.
625     if (blocks.size() > 1) {
626       if (failed(emitter.emitLabel(block)))
627         return failure();
628     }
629     for (Operation &op : block.getOperations()) {
630       // When generating code for an scf.if or std.cond_br op no semicolon needs
631       // to be printed after the closing brace.
632       // When generating code for an scf.for op, printing a trailing semicolon
633       // is handled within the printOperation function.
634       bool trailingSemicolon = !isa<scf::IfOp, scf::ForOp, CondBranchOp>(op);
635 
636       if (failed(emitter.emitOperation(
637               op, /*trailingSemicolon=*/trailingSemicolon)))
638         return failure();
639     }
640   }
641   os.unindent() << "}\n";
642   return success();
643 }
644 
CppEmitter(raw_ostream & os,bool declareVariablesAtTop)645 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
646     : os(os), declareVariablesAtTop(declareVariablesAtTop) {
647   valueInScopeCount.push(0);
648   labelInScopeCount.push(0);
649 }
650 
651 /// Return the existing or a new name for a Value.
getOrCreateName(Value val)652 StringRef CppEmitter::getOrCreateName(Value val) {
653   if (!valueMapper.count(val))
654     valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
655   return *valueMapper.begin(val);
656 }
657 
658 /// Return the existing or a new label for a Block.
getOrCreateName(Block & block)659 StringRef CppEmitter::getOrCreateName(Block &block) {
660   if (!blockMapper.count(&block))
661     blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
662   return *blockMapper.begin(&block);
663 }
664 
shouldMapToUnsigned(IntegerType::SignednessSemantics val)665 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
666   switch (val) {
667   case IntegerType::Signless:
668     return false;
669   case IntegerType::Signed:
670     return false;
671   case IntegerType::Unsigned:
672     return true;
673   }
674   llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
675 }
676 
hasValueInScope(Value val)677 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); }
678 
hasBlockLabel(Block & block)679 bool CppEmitter::hasBlockLabel(Block &block) {
680   return blockMapper.count(&block);
681 }
682 
emitAttribute(Location loc,Attribute attr)683 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
684   auto printInt = [&](APInt val, bool isUnsigned) {
685     if (val.getBitWidth() == 1) {
686       if (val.getBoolValue())
687         os << "true";
688       else
689         os << "false";
690     } else {
691       SmallString<128> strValue;
692       val.toString(strValue, 10, !isUnsigned, false);
693       os << strValue;
694     }
695   };
696 
697   auto printFloat = [&](APFloat val) {
698     if (val.isFinite()) {
699       SmallString<128> strValue;
700       // Use default values of toString except don't truncate zeros.
701       val.toString(strValue, 0, 0, false);
702       switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
703       case llvm::APFloatBase::S_IEEEsingle:
704         os << "(float)";
705         break;
706       case llvm::APFloatBase::S_IEEEdouble:
707         os << "(double)";
708         break;
709       default:
710         break;
711       };
712       os << strValue;
713     } else if (val.isNaN()) {
714       os << "NAN";
715     } else if (val.isInfinity()) {
716       if (val.isNegative())
717         os << "-";
718       os << "INFINITY";
719     }
720   };
721 
722   // Print floating point attributes.
723   if (auto fAttr = attr.dyn_cast<FloatAttr>()) {
724     printFloat(fAttr.getValue());
725     return success();
726   }
727   if (auto dense = attr.dyn_cast<DenseFPElementsAttr>()) {
728     os << '{';
729     interleaveComma(dense, os, [&](APFloat val) { printFloat(val); });
730     os << '}';
731     return success();
732   }
733 
734   // Print integer attributes.
735   if (auto iAttr = attr.dyn_cast<IntegerAttr>()) {
736     if (auto iType = iAttr.getType().dyn_cast<IntegerType>()) {
737       printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
738       return success();
739     }
740     if (auto iType = iAttr.getType().dyn_cast<IndexType>()) {
741       printInt(iAttr.getValue(), false);
742       return success();
743     }
744   }
745   if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) {
746     if (auto iType = dense.getType()
747                          .cast<TensorType>()
748                          .getElementType()
749                          .dyn_cast<IntegerType>()) {
750       os << '{';
751       interleaveComma(dense, os, [&](APInt val) {
752         printInt(val, shouldMapToUnsigned(iType.getSignedness()));
753       });
754       os << '}';
755       return success();
756     }
757     if (auto iType = dense.getType()
758                          .cast<TensorType>()
759                          .getElementType()
760                          .dyn_cast<IndexType>()) {
761       os << '{';
762       interleaveComma(dense, os, [&](APInt val) { printInt(val, false); });
763       os << '}';
764       return success();
765     }
766   }
767 
768   // Print opaque attributes.
769   if (auto oAttr = attr.dyn_cast<emitc::OpaqueAttr>()) {
770     os << oAttr.getValue();
771     return success();
772   }
773 
774   // Print symbolic reference attributes.
775   if (auto sAttr = attr.dyn_cast<SymbolRefAttr>()) {
776     if (sAttr.getNestedReferences().size() > 1)
777       return emitError(loc, "attribute has more than 1 nested reference");
778     os << sAttr.getRootReference().getValue();
779     return success();
780   }
781 
782   // Print type attributes.
783   if (auto type = attr.dyn_cast<TypeAttr>())
784     return emitType(loc, type.getValue());
785 
786   return emitError(loc, "cannot emit attribute of type ") << attr.getType();
787 }
788 
emitOperands(Operation & op)789 LogicalResult CppEmitter::emitOperands(Operation &op) {
790   auto emitOperandName = [&](Value result) -> LogicalResult {
791     if (!hasValueInScope(result))
792       return op.emitOpError() << "operand value not in scope";
793     os << getOrCreateName(result);
794     return success();
795   };
796   return interleaveCommaWithError(op.getOperands(), os, emitOperandName);
797 }
798 
799 LogicalResult
emitOperandsAndAttributes(Operation & op,ArrayRef<StringRef> exclude)800 CppEmitter::emitOperandsAndAttributes(Operation &op,
801                                       ArrayRef<StringRef> exclude) {
802   if (failed(emitOperands(op)))
803     return failure();
804   // Insert comma in between operands and non-filtered attributes if needed.
805   if (op.getNumOperands() > 0) {
806     for (NamedAttribute attr : op.getAttrs()) {
807       if (!llvm::is_contained(exclude, attr.first.strref())) {
808         os << ", ";
809         break;
810       }
811     }
812   }
813   // Emit attributes.
814   auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
815     if (llvm::is_contained(exclude, attr.first.strref()))
816       return success();
817     os << "/* " << attr.first << " */";
818     if (failed(emitAttribute(op.getLoc(), attr.second)))
819       return failure();
820     return success();
821   };
822   return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
823 }
824 
emitVariableAssignment(OpResult result)825 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
826   if (!hasValueInScope(result)) {
827     return result.getDefiningOp()->emitOpError(
828         "result variable for the operation has not been declared");
829   }
830   os << getOrCreateName(result) << " = ";
831   return success();
832 }
833 
emitVariableDeclaration(OpResult result,bool trailingSemicolon)834 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
835                                                   bool trailingSemicolon) {
836   if (hasValueInScope(result)) {
837     return result.getDefiningOp()->emitError(
838         "result variable for the operation already declared");
839   }
840   if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
841     return failure();
842   os << " " << getOrCreateName(result);
843   if (trailingSemicolon)
844     os << ";\n";
845   return success();
846 }
847 
emitAssignPrefix(Operation & op)848 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
849   switch (op.getNumResults()) {
850   case 0:
851     break;
852   case 1: {
853     OpResult result = op.getResult(0);
854     if (shouldDeclareVariablesAtTop()) {
855       if (failed(emitVariableAssignment(result)))
856         return failure();
857     } else {
858       if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
859         return failure();
860       os << " = ";
861     }
862     break;
863   }
864   default:
865     if (!shouldDeclareVariablesAtTop()) {
866       for (OpResult result : op.getResults()) {
867         if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
868           return failure();
869       }
870     }
871     os << "std::tie(";
872     interleaveComma(op.getResults(), os,
873                     [&](Value result) { os << getOrCreateName(result); });
874     os << ") = ";
875   }
876   return success();
877 }
878 
emitLabel(Block & block)879 LogicalResult CppEmitter::emitLabel(Block &block) {
880   if (!hasBlockLabel(block))
881     return block.getParentOp()->emitError("label for block not found");
882   // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
883   // label instead of using `getOStream`.
884   os.getOStream() << getOrCreateName(block) << ":\n";
885   return success();
886 }
887 
emitOperation(Operation & op,bool trailingSemicolon)888 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
889   LogicalResult status =
890       llvm::TypeSwitch<Operation *, LogicalResult>(&op)
891           // EmitC ops.
892           .Case<emitc::ApplyOp, emitc::CallOp, emitc::ConstantOp,
893                 emitc::IncludeOp>(
894               [&](auto op) { return printOperation(*this, op); })
895           // SCF ops.
896           .Case<scf::ForOp, scf::IfOp, scf::YieldOp>(
897               [&](auto op) { return printOperation(*this, op); })
898           // Standard ops.
899           .Case<BranchOp, mlir::CallOp, CondBranchOp, mlir::ConstantOp, FuncOp,
900                 ModuleOp, ReturnOp>(
901               [&](auto op) { return printOperation(*this, op); })
902           .Default([&](Operation *) {
903             return op.emitOpError("unable to find printer for op");
904           });
905 
906   if (failed(status))
907     return failure();
908   os << (trailingSemicolon ? ";\n" : "\n");
909   return success();
910 }
911 
emitType(Location loc,Type type)912 LogicalResult CppEmitter::emitType(Location loc, Type type) {
913   if (auto iType = type.dyn_cast<IntegerType>()) {
914     switch (iType.getWidth()) {
915     case 1:
916       return (os << "bool"), success();
917     case 8:
918     case 16:
919     case 32:
920     case 64:
921       if (shouldMapToUnsigned(iType.getSignedness()))
922         return (os << "uint" << iType.getWidth() << "_t"), success();
923       else
924         return (os << "int" << iType.getWidth() << "_t"), success();
925     default:
926       return emitError(loc, "cannot emit integer type ") << type;
927     }
928   }
929   if (auto fType = type.dyn_cast<FloatType>()) {
930     switch (fType.getWidth()) {
931     case 32:
932       return (os << "float"), success();
933     case 64:
934       return (os << "double"), success();
935     default:
936       return emitError(loc, "cannot emit float type ") << type;
937     }
938   }
939   if (auto iType = type.dyn_cast<IndexType>())
940     return (os << "size_t"), success();
941   if (auto tType = type.dyn_cast<TensorType>()) {
942     if (!tType.hasRank())
943       return emitError(loc, "cannot emit unranked tensor type");
944     if (!tType.hasStaticShape())
945       return emitError(loc, "cannot emit tensor type with non static shape");
946     os << "Tensor<";
947     if (failed(emitType(loc, tType.getElementType())))
948       return failure();
949     auto shape = tType.getShape();
950     for (auto dimSize : shape) {
951       os << ", ";
952       os << dimSize;
953     }
954     os << ">";
955     return success();
956   }
957   if (auto tType = type.dyn_cast<TupleType>())
958     return emitTupleType(loc, tType.getTypes());
959   if (auto oType = type.dyn_cast<emitc::OpaqueType>()) {
960     os << oType.getValue();
961     return success();
962   }
963   return emitError(loc, "cannot emit type ") << type;
964 }
965 
emitTypes(Location loc,ArrayRef<Type> types)966 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
967   switch (types.size()) {
968   case 0:
969     os << "void";
970     return success();
971   case 1:
972     return emitType(loc, types.front());
973   default:
974     return emitTupleType(loc, types);
975   }
976 }
977 
emitTupleType(Location loc,ArrayRef<Type> types)978 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
979   os << "std::tuple<";
980   if (failed(interleaveCommaWithError(
981           types, os, [&](Type type) { return emitType(loc, type); })))
982     return failure();
983   os << ">";
984   return success();
985 }
986 
translateToCpp(Operation * op,raw_ostream & os,bool declareVariablesAtTop)987 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
988                                     bool declareVariablesAtTop) {
989   CppEmitter emitter(os, declareVariablesAtTop);
990   return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
991 }
992