1 //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/IR/Operation.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/Passes.h"
15 
16 using namespace mlir;
17 
18 // Updates the func op and entry block.
19 //
20 // Any args appended to the entry block are added to `appendedEntryArgs`.
updateFuncOp(FuncOp func,SmallVectorImpl<BlockArgument> & appendedEntryArgs)21 static void updateFuncOp(FuncOp func,
22                          SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
23   auto functionType = func.getType();
24 
25   // Collect information about the results will become appended arguments.
26   SmallVector<Type, 6> erasedResultTypes;
27   SmallVector<unsigned, 6> erasedResultIndices;
28   for (auto resultType : llvm::enumerate(functionType.getResults())) {
29     if (resultType.value().isa<BaseMemRefType>()) {
30       erasedResultIndices.push_back(resultType.index());
31       erasedResultTypes.push_back(resultType.value());
32     }
33   }
34 
35   // Add the new arguments to the function type.
36   auto newArgTypes = llvm::to_vector<6>(
37       llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
38   auto newFunctionType = FunctionType::get(
39       newArgTypes, functionType.getResults(), func.getContext());
40   func.setType(newFunctionType);
41 
42   // Transfer the result attributes to arg attributes.
43   for (int i = 0, e = erasedResultTypes.size(); i < e; i++)
44     func.setArgAttrs(functionType.getNumInputs() + i,
45                      func.getResultAttrs(erasedResultIndices[i]));
46 
47   // Erase the results.
48   func.eraseResults(erasedResultIndices);
49 
50   // Add the new arguments to the entry block if the function is not external.
51   if (func.isExternal())
52     return;
53   auto newArgs = func.front().addArguments(erasedResultTypes);
54   appendedEntryArgs.append(newArgs.begin(), newArgs.end());
55 }
56 
57 // Updates all ReturnOps in the scope of the given FuncOp by either keeping them
58 // as return values or copying the associated buffer contents into the given
59 // out-params.
updateReturnOps(FuncOp func,ArrayRef<BlockArgument> appendedEntryArgs)60 static void updateReturnOps(FuncOp func,
61                             ArrayRef<BlockArgument> appendedEntryArgs) {
62   func.walk([&](ReturnOp op) {
63     SmallVector<Value, 6> copyIntoOutParams;
64     SmallVector<Value, 6> keepAsReturnOperands;
65     for (Value operand : op.getOperands()) {
66       if (operand.getType().isa<BaseMemRefType>())
67         copyIntoOutParams.push_back(operand);
68       else
69         keepAsReturnOperands.push_back(operand);
70     }
71     OpBuilder builder(op);
72     for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs))
73       builder.create<linalg::CopyOp>(op.getLoc(), std::get<0>(t),
74                                      std::get<1>(t));
75     builder.create<ReturnOp>(op.getLoc(), keepAsReturnOperands);
76     op.erase();
77   });
78 }
79 
80 // Updates all CallOps in the scope of the given ModuleOp by allocating
81 // temporary buffers for newly introduced out params.
updateCalls(ModuleOp module)82 static LogicalResult updateCalls(ModuleOp module) {
83   bool didFail = false;
84   module.walk([&](CallOp op) {
85     SmallVector<Value, 6> replaceWithNewCallResults;
86     SmallVector<Value, 6> replaceWithOutParams;
87     for (OpResult result : op.getResults()) {
88       if (result.getType().isa<BaseMemRefType>())
89         replaceWithOutParams.push_back(result);
90       else
91         replaceWithNewCallResults.push_back(result);
92     }
93     SmallVector<Value, 6> outParams;
94     OpBuilder builder(op);
95     for (Value memref : replaceWithOutParams) {
96       if (!memref.getType().cast<BaseMemRefType>().hasStaticShape()) {
97         op.emitError()
98             << "cannot create out param for dynamically shaped result";
99         didFail = true;
100         return;
101       }
102       Value outParam = builder.create<AllocOp>(
103           op.getLoc(), memref.getType().cast<MemRefType>());
104       memref.replaceAllUsesWith(outParam);
105       outParams.push_back(outParam);
106     }
107 
108     auto newOperands = llvm::to_vector<6>(op.getOperands());
109     newOperands.append(outParams.begin(), outParams.end());
110     auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
111         replaceWithNewCallResults, [](Value v) { return v.getType(); }));
112     auto newCall = builder.create<CallOp>(op.getLoc(), op.calleeAttr(),
113                                           newResultTypes, newOperands);
114     for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
115       std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
116     op.erase();
117   });
118 
119   return failure(didFail);
120 }
121 
122 namespace {
123 struct BufferResultsToOutParamsPass
124     : BufferResultsToOutParamsBase<BufferResultsToOutParamsPass> {
runOnOperation__anon49a959eb0411::BufferResultsToOutParamsPass125   void runOnOperation() override {
126     ModuleOp module = getOperation();
127 
128     for (auto func : module.getOps<FuncOp>()) {
129       SmallVector<BlockArgument, 6> appendedEntryArgs;
130       updateFuncOp(func, appendedEntryArgs);
131       if (func.isExternal())
132         continue;
133       updateReturnOps(func, appendedEntryArgs);
134     }
135     if (failed(updateCalls(module)))
136       return signalPassFailure();
137   }
138 };
139 } // end anonymous namespace
140 
createBufferResultsToOutParamsPass()141 std::unique_ptr<Pass> mlir::createBufferResultsToOutParamsPass() {
142   return std::make_unique<BufferResultsToOutParamsPass>();
143 }
144