//
//  CUDALoop.cpp
//  MNN
//
//  Created by MNN on 2021/04/20.
//  Copyright © 2018, Alibaba Group Holding Limited
//
#include <map>
#include "MatMulExecution.hpp"
#include "backend/cuda/core/CUDABackend.hpp"
#include "Raster.cuh"
namespace MNN {
namespace CUDA {
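// Executes a LoopParam produced by the geometry pass: each RegionCommand runs
// loopNumber times, with per-iteration tensor offsets taken either directly from
// the loop counter or indirectly from iterator tensors. UnaryOp commands map to
// raster blits, BinaryOp to binary blits, and MatMul to a wrapped MatMulExecution.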
class CUDALoop : public Execution {
public:
    // One prepared sub-execution (currently only MatMul) and its placeholder tensors
    struct Unit {
        std::vector<Tensor*> inputs;
        std::vector<Tensor*> outputs;
        std::shared_ptr<Execution> exe;
    };
    CUDALoop(Backend* bn, const LoopParam* loop) : Execution(bn) {
        // The LoopParam is created by geometry, won't be released
        mLoop = loop;
        mStack.resize(loop->tensorNumber());
        mExecutions.resize(loop->commands()->size());
        mStackPtr.resize(loop->tensorNumber());
    }
    virtual ~CUDALoop() {
        // Do nothing
    }
    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override {
        mMidTensors.clear();
        mIndiceCopy.clear();
        // Bind the real inputs / outputs into the loop's tensor stack
        int inputIndexSize = mLoop->inputIndexes()->size();
        MNN_ASSERT(inputIndexSize == inputs.size());
        for (int i=0; i<inputIndexSize; ++i) {
            mStack[mLoop->inputIndexes()->data()[i]] = inputs[i];
        }
        int outputIndexSize = mLoop->outputIndexes()->size();
        MNN_ASSERT(outputIndexSize == outputs.size());
        for (int i=0; i<outputIndexSize; ++i) {
            mStack[mLoop->outputIndexes()->data()[i]] = outputs[i];
        }
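        // A loop whose single command is a raw copy (UnaryOp without parameter) is
        // handled by one indexed blit in onExecute, so there is nothing to prepare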
        if (1 == mLoop->commands()->size()) {
            auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
            auto op = cmd->op();
            if (OpType_UnaryOp == op->type() && nullptr == op->main()) {
                return NO_ERROR;
            }
        }
        for (int i=0; i<mLoop->commands()->size(); ++i) {
            auto cmd = mLoop->commands()->GetAs<RegionCommand>(i);
            auto op = cmd->op();
            auto& unit = mExecutions[i];
            // If a command addresses tensors through an iterator tensor that only
            // lives on device, create a host mirror so onExecute can read the
            // indices on CPU; the actual copy happens at execute time
            int size = cmd->iterIndexes()->size();
            for (int v=0; v<size; ++v) {
                auto tensorIndex = cmd->indexes()->data()[v];
                auto tensor = mStack[tensorIndex];
                auto iterIndex = cmd->iterIndexes()->data()[v];
                if (iterIndex >= 0 && mStack[iterIndex]->host<void>() == nullptr) {
                    std::shared_ptr<Tensor> tensorHost(new Tensor(mStack[iterIndex], mStack[iterIndex]->getDimensionType()));
                    mIndiceCopy.insert(std::make_pair(mStack[iterIndex], tensorHost.get()));
                    mStack[iterIndex] = tensorHost.get();
                    mMidTensors.emplace_back(std::move(tensorHost));
                }
            }
            // Prepare for MatMul: create shape-only placeholder tensors from the
            // command's e / l / h sizes so MatMulExecution can be resized once here;
            // their device pointers are patched per iteration in onExecute
            if (OpType_MatMul == op->type()) {
                bool transposeC = true;
                int e = cmd->size()->data()[0];
                int l = cmd->size()->data()[1];
                int h = cmd->size()->data()[2];
                std::shared_ptr<Tensor> A, B, C, Bias;
                C.reset(Tensor::createDevice<float>({e, h}));
                if (op->main_as_MatMul()->transposeA()) {
                    A.reset(Tensor::createDevice<float>({l, e}));
                } else {
                    A.reset(Tensor::createDevice<float>({e, l}));
                }
                if (op->main_as_MatMul()->transposeB()) {
                    B.reset(Tensor::createDevice<float>({h, l}));
                } else {
                    B.reset(Tensor::createDevice<float>({l, h}));
                }
                auto view = cmd->view()->GetAs<View>(0);
                if (view->stride()->data()[0] == 1) {
                    transposeC = false;
                }
                // A fourth index means the command carries a bias input
                if (cmd->indexes()->size() > 3) {
                    Bias.reset(Tensor::createDevice<float>({h}));
                    unit.inputs = {A.get(), B.get(), Bias.get()};
                } else {
                    unit.inputs = {A.get(), B.get()};
                }
                unit.outputs = {C.get()};
                unit.exe.reset(new MatMulExecution(op->main_as_MatMul()->transposeA(), op->main_as_MatMul()->transposeB(), backend()));
                if (nullptr == unit.exe) {
                    return OUT_OF_MEMORY;
                }
                auto code = unit.exe->onResize(unit.inputs, unit.outputs);
                if (NO_ERROR != code) {
                    return code;
                }
                mMidTensors.emplace_back(A);
                mMidTensors.emplace_back(B);
                mMidTensors.emplace_back(C);
                mMidTensors.emplace_back(Bias);
                continue;
            }
        }
        return NO_ERROR;
    }

    virtual ErrorCode onExecute(const std::vector<Tensor *> &originInputs, const std::vector<Tensor *> &originOutputs) override {
        auto runtime = static_cast<CUDABackend*>(backend())->getCUDARuntime();
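        // Fast path: a single raw-copy command is lowered to one BlitWithIndice
        // call that covers every iteration at once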
        if (1 == mLoop->commands()->size()) {
            auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
            auto op = cmd->op();
            if (OpType_UnaryOp == op->type() && nullptr == op->main()) {
                Tensor::InsideDescribe::Region reg;
                auto srcView = cmd->view()->GetAs<View>(1);
                auto dstView = cmd->view()->GetAs<View>(0);
                ::memcpy(reg.size, cmd->size()->data(), 3 * sizeof(int32_t));
                ::memcpy(reg.src.stride, srcView->stride()->data(), 3 * sizeof(int32_t));
                ::memcpy(reg.dst.stride, dstView->stride()->data(), 3 * sizeof(int32_t));
                auto input = mStack[cmd->indexes()->data()[1]];
                auto inputSize = input->elementSize();
                auto output = mStack[cmd->indexes()->data()[0]];
                auto bytes = input->getType().bytes();
                auto step0 = cmd->steps()->data()[0];
                auto step1 = cmd->steps()->data()[1];
                auto loopNumber = mLoop->loopNumber();
                // Iterator tensors, when present, hold per-iteration offsets on device
                auto index0 = cmd->iterIndexes()->data()[0];
                const int32_t* dstIndice = nullptr;
                if (index0 >= 0) {
                    dstIndice = (int32_t*)originInputs[index0]->deviceId();
                }
                auto index1 = cmd->iterIndexes()->data()[1];
                const int32_t* srcIndice = nullptr;
                if (index1 >= 0) {
                    srcIndice = (int32_t*)originInputs[index1]->deviceId();
                }

                BlitWithIndice(
                    (uint8_t*)(output->deviceId()) + dstView->offset() * bytes,
                    (uint8_t*)(input->deviceId()) + srcView->offset() * bytes,
                    dstIndice, srcIndice, index0, index1,
                    loopNumber, step0, step1, inputSize,
                    reg, bytes, runtime);

                return NO_ERROR;
            }
        }
        // Copy the iterator tensors to their host mirrors so the per-iteration
        // offsets can be read on CPU in the loop below
        for (auto& iter : mIndiceCopy) {
            backend()->onCopyBuffer(iter.first, iter.second);
        }
        auto bytes = sizeof(float); // TODO: Support Half
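        // General path: run every command once per iteration, recomputing each
        // tensor's device address from its step, view offset and optional indices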
        for (int iter=0; iter < mLoop->loopNumber(); ++iter) {
            for (int index=0; index<mLoop->commands()->size(); ++index) {
                auto cmd = mLoop->commands()->GetAs<RegionCommand>(index);
                auto op = cmd->op();
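                // Resolve this iteration's device address for each tensor the
                // command touches: offset = index * step + view offset, where the
                // index is the loop counter or a value read from the host-side
                // copy of the iterator tensor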
                int size = cmd->iterIndexes()->size();
                for (int v=0; v<size; ++v) {
                    auto tensorIndex = cmd->indexes()->data()[v];
                    auto tensor = mStack[tensorIndex];
                    auto iterIndex = cmd->iterIndexes()->data()[v];
                    auto offset = iter;
                    if (iterIndex >= 0) {
                        offset = mStack[iterIndex]->host<int32_t>()[iter];
                    }
                    auto view = cmd->view()->GetAs<View>(v);
                    offset = offset * cmd->steps()->data()[v] + view->offset();
                    mStackPtr[tensorIndex] = tensor->deviceId() + offset * bytes;
                }
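                // UnaryOp: strided element-wise blit from src to dst on device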
                if (OpType_UnaryOp == op->type()) {
                    auto src = (float*)mStackPtr[cmd->indexes()->data()[1]];
                    auto dst = (float*)mStackPtr[cmd->indexes()->data()[0]];
                    int unaryType = op->main_as_UnaryOp()->opType();
                    auto srcStride = cmd->view()->GetAs<View>(1)->stride()->data();
                    auto dstStride = cmd->view()->GetAs<View>(0)->stride()->data();
                    UnaryBlit((uint8_t*)dst, (const uint8_t*)src, cmd->size()->data(), srcStride, dstStride, bytes, runtime, unaryType);
                    continue;
                }
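                // MatMul: patch the placeholder tensors' device pointers for this
                // iteration and run the MatMulExecution prepared in onResize; four
                // indexes means a bias input is present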
                if (OpType_MatMul == op->type()) {
                    auto& unit = mExecutions[index];
                    if (3 == size) {
                        unit.inputs[0]->buffer().device = mStackPtr[cmd->indexes()->data()[1]];
                        unit.inputs[1]->buffer().device = mStackPtr[cmd->indexes()->data()[2]];
                        unit.outputs[0]->buffer().device = mStackPtr[cmd->indexes()->data()[0]];
                    } else {
                        MNN_ASSERT(4 == size);
                        unit.inputs[0]->buffer().device = mStackPtr[cmd->indexes()->data()[1]];
                        unit.inputs[1]->buffer().device = mStackPtr[cmd->indexes()->data()[2]];
                        unit.inputs[2]->buffer().device = mStackPtr[cmd->indexes()->data()[3]];
                        unit.outputs[0]->buffer().device = mStackPtr[cmd->indexes()->data()[0]];
                    }
                    unit.exe->onExecute(unit.inputs, unit.outputs);
                    continue;
                }
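                // BinaryOp: strided element-wise binary blit (float only for now)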
                if (OpType_BinaryOp == op->type()) {
                    auto src0 = mStackPtr[cmd->indexes()->data()[1]];
                    auto src1 = mStackPtr[cmd->indexes()->data()[2]];
                    auto dst = mStackPtr[cmd->indexes()->data()[0]];
                    auto opType = op->main_as_BinaryOp()->opType();
                    auto srcStride0 = cmd->view()->GetAs<View>(1)->stride()->data();
                    auto srcStride1 = cmd->view()->GetAs<View>(2)->stride()->data();
                    auto dstStride = cmd->view()->GetAs<View>(0)->stride()->data();
                    BinaryBlit((uint8_t*)dst, (const uint8_t*)src0, (const uint8_t*)src1,
                        cmd->size()->data(), srcStride0, srcStride1, dstStride, halide_type_of<float>(), runtime, opType);
                }
            }
        }
        return NO_ERROR;
    }
private:
    const LoopParam* mLoop;
    std::vector<Tensor*> mStack;                      // all tensors referenced by the loop, indexed by LoopParam
    std::vector<std::shared_ptr<Tensor>> mMidTensors; // owned placeholder / host-mirror tensors
    std::vector<Unit> mExecutions;                    // prepared sub-executions, one per command
    std::vector<uint64_t> mStackPtr;                  // per-iteration device addresses
    std::map<Tensor*, Tensor*> mIndiceCopy;           // device iterator tensor -> host mirror
};

class LoopCreator : public CUDABackend::Creator {
public:
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                const MNN::Op* op, Backend* backend) const override {
        if (op->main_type() != OpParameter_LoopParam) {
            return nullptr;
        }
        return new CUDALoop(backend, op->main_as_LoopParam());
    }
};

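// Loops are emitted by the geometry pass as OpType_While ops carrying a LoopParam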
static CUDACreatorRegister<LoopCreator> __init(OpType_While);

} // namespace CUDA
} // namespace MNN