//
//  GeometryConv2DBackPropFilter.cpp
//  MNN
//
//  Created by MNN on 2020/05/07.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "ConvertUtils.hpp"
#include "GeometryConvUtils.hpp"
#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp>
namespace MNN {
class GeometryConv2DBackPropFilter : public GeometryComputer {
public:
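    // Depthwise case: every output channel depends on exactly one input channel,
    // so the filter gradient for each kernel tap (ky, kx) is
    //     dW[c][ky][kx] = sum over (n, oy, ox) of
    //         input[n][c][oy * sh + ky * dh - padY][ox * sw + kx * dw - padX]
    //         * outputDiff[n][c][oy][ox].
    // The loops below realize this as a strided gather, an elementwise multiply,
    // and two SUM reductions (spatial, then batch) per tap.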
    bool computeForDepthWise(const Convolution2DCommon* common, Tensor* input, Tensor* outputDiff, Tensor* kernelDiff,
                             Context& context, CommandBuffer& res) const {
        auto kw    = common->kernelX();
        auto kh    = common->kernelY();
        auto sw    = common->strideX();
        auto sh    = common->strideY();
        auto dw    = common->dilateX();
        auto dh    = common->dilateY();
        auto batch = outputDiff->batch();
        auto ow    = outputDiff->width();
        auto oh    = outputDiff->height();
        auto ic    = input->channel();
        auto iw    = input->width();
        auto ih    = input->height();
        auto pads  = ConvolutionCommon::convolutionPad(input, outputDiff, common);
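        // pads.first is the x (width) padding and pads.second the y (height)
        // padding, matching how they are consumed below. The strided gather
        // assumes NCHW layout, so convert both tensors first if necessary.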
        if (TensorUtils::getDescribe(input)->dimensionFormat != MNN_DATA_FORMAT_NCHW) {
            std::shared_ptr<Tensor> newT(new Tensor(input, Tensor::CAFFE, false));
            ConvertUtils::compute(input, newT.get(), res);
            input = newT.get();
            res.extras.emplace_back(newT);
        }
        if (TensorUtils::getDescribe(outputDiff)->dimensionFormat != MNN_DATA_FORMAT_NCHW) {
            std::shared_ptr<Tensor> newT(new Tensor(outputDiff, Tensor::CAFFE, false));
            ConvertUtils::compute(outputDiff, newT.get(), res);
            outputDiff = newT.get();
            res.extras.emplace_back(newT);
        }
        auto outputDes        = TensorUtils::getDescribe(kernelDiff);
        outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
        outputDes->regions.clear();
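        // For each tap (ky, kx), clip the output range [startDy, endDy] x
        // [startDx, endDx] so that the sampled input coordinates
        // (dy * sh + ky * dh - padY, dx * sw + kx * dw - padX) stay inside the
        // input; taps whose window falls entirely in padding contribute nothing.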
        for (int ky = 0; ky < kh; ++ky) {
            auto startSy = ky * dh - pads.second;
            int startDy  = 0;
            if (startSy < 0) {
                startDy = ((-startSy) + sh - 1) / sh;
                startSy = startSy + startDy * sh;
            }
            auto endDy = oh - 1;
            auto endSy = endDy * sh + ky * dh - pads.second;
            if (endSy >= ih) {
                endDy = endDy - (endSy - ih + sh) / sh;
                endSy = endDy * sh + ky * dh - pads.second;
            }
            if (startDy > endDy) {
                continue;
            }
            MNN_ASSERT(endDy >= 0);
            MNN_ASSERT(startDy < ih);
            auto dstOffsetKy = startDy * ow;
            auto srcOffsetKy = startSy * iw;
            for (int kx = 0; kx < kw; ++kx) {
                auto startSx = kx * dw - pads.first;
                int startDx  = 0;
                if (startSx < 0) {
                    startDx = ((-startSx) + sw - 1) / sw;
                    startSx = startSx + startDx * sw;
                }
                auto endDx = ow - 1;
                auto endSx = endDx * sw + kx * dw - pads.first;
                if (endSx >= iw) {
                    endDx = endDx - (endSx - iw + sw) / sw;
                    endSx = endDx * sw + kx * dw - pads.first;
                }
                if (startDx > endDx) {
                    continue;
                }
                auto dstOffsetKx = dstOffsetKy + startDx;
                auto srcOffsetKx = srcOffsetKy + startSx;
                // Sampler: gather the strided input window this tap touches,
                // laid out to align element-for-element with outputDiff.
                std::shared_ptr<Tensor> inputTensor(new Tensor(outputDiff, Tensor::CAFFE, false));
                auto des        = TensorUtils::getDescribe(inputTensor.get());
                des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
                des->regions.resize(1);
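                // src strides walk the input: ih * iw per channel plane, sh * iw
                // per output row, sw per output column; dst is densely packed
                // over (oh, ow) so the multiply below can be elementwise.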
                {
                    Tensor::InsideDescribe::Region& region = des->regions[0];
                    region.origin        = input;
                    region.size[0]       = batch * ic;
                    region.size[1]       = endDy - startDy + 1;
                    region.size[2]       = endDx - startDx + 1;
                    region.src.offset    = srcOffsetKx;
                    region.dst.offset    = dstOffsetKx;
                    region.src.stride[0] = iw * ih;
                    region.dst.stride[0] = ow * oh;
                    region.src.stride[1] = sh * iw;
                    region.dst.stride[1] = ow;
                    region.src.stride[2] = sw;
                    region.dst.stride[2] = 1;
                    res.extras.emplace_back(inputTensor);
                }

                auto currentTensor = inputTensor.get();
                // Multiply the gathered input elementwise with outputDiff.
                {
                    std::shared_ptr<Tensor> newTensor(new Tensor(outputDiff, Tensor::CAFFE, false));
                    auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_MUL, currentTensor, outputDiff,
                                                                 newTensor.get());
                    res.command.emplace_back(std::move(cmd));
                    res.extras.emplace_back(newTensor);
                    currentTensor = newTensor.get();
                }
                // Reduce step 0: sum the products over the spatial axis (ow * oh),
                // leaving one value per (batch, channel) pair.
                {
                    std::shared_ptr<Tensor> reduceInputTensor(
                        Tensor::createDevice<float>({batch * ic, ow * oh, 1}, Tensor::CAFFE));
                    {
                        auto inputDes        = TensorUtils::getDescribe(reduceInputTensor.get());
                        inputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
                        inputDes->regions    = {TensorUtils::makeFullSlice(currentTensor)};
                    }
                    std::shared_ptr<Tensor> reduceOutputTensor(
                        Tensor::createDevice<float>({batch * ic, 1, 1}, Tensor::CAFFE));
                    auto cmd = GeometryComputerUtils::makeReduce(ReductionType_SUM, reduceInputTensor.get(),
                                                                 reduceOutputTensor.get());
                    currentTensor = reduceOutputTensor.get();
                    res.command.emplace_back(std::move(cmd));
                    res.extras.emplace_back(reduceInputTensor);
                    res.extras.emplace_back(reduceOutputTensor);
                }
                // Reduce step 1: sum over the batch axis, leaving one value per channel.
                {
                    std::shared_ptr<Tensor> reduceInputTensor(
                        Tensor::createDevice<float>({1, batch, ic}, Tensor::CAFFE));
                    {
                        auto inputDes        = TensorUtils::getDescribe(reduceInputTensor.get());
                        inputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
                        inputDes->regions    = {TensorUtils::makeFullSlice(currentTensor)};
                    }
                    std::shared_ptr<Tensor> reduceOutputTensor(Tensor::createDevice<float>({1, 1, ic}, Tensor::CAFFE));
                    currentTensor = reduceOutputTensor.get();
                    auto cmd = GeometryComputerUtils::makeReduce(ReductionType_SUM, reduceInputTensor.get(),
                                                                 reduceOutputTensor.get());
                    res.command.emplace_back(std::move(cmd));
                    res.extras.emplace_back(reduceInputTensor);
                    res.extras.emplace_back(reduceOutputTensor);
                }
                // Scatter the per-channel sums into kernelDiff at tap (ky, kx);
                // dst.stride[2] = kh * kw lays the result out channel-major over
                // the kh * kw kernel positions.
                Tensor::InsideDescribe::Region region;
                region.origin        = currentTensor;
                region.size[0]       = 1;
                region.size[1]       = 1;
                region.size[2]       = ic;
                region.dst.offset    = ky * kw + kx;
                region.dst.stride[0] = 0;
                region.dst.stride[1] = 0;
                region.dst.stride[2] = kh * kw;
                region.src.offset    = 0;
                region.src.stride[0] = 0;
                region.src.stride[1] = 0;
                region.src.stride[2] = 1;
                outputDes->regions.emplace_back(std::move(region));
            }
        }
        return true;
    }
    virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                           Context& context, CommandBuffer& res) const override {
        auto common     = op->main_as_Convolution2D()->common();
        auto input      = inputs[0];
        auto outputDiff = inputs[1];
        if (inputs[0]->channel() == inputs[1]->channel() && inputs[1]->channel() == common->group()) {
            return computeForDepthWise(common, input, outputDiff, outputs[0], context, res);
        }
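        // General case: phrase the filter gradient as one matrix product,
        //     dW = B * A^T,
        // where B = im2col(input) has shape (ic * kh * kw, n * oh * ow) and
        // A is outputDiff reshaped to (oc, n * oh * ow), then transpose the
        // (ic * kh * kw, oc) result into kernelDiff's oc-major layout.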
        auto kw    = common->kernelX();
        auto kh    = common->kernelY();
        auto sw    = common->strideX();
        auto sh    = common->strideY();
        auto dw    = common->dilateX();
        auto dh    = common->dilateY();
        auto batch = outputDiff->batch();
        auto ow    = outputDiff->width();
        auto oh    = outputDiff->height();
        auto oc    = outputDiff->channel();
        auto ic    = input->channel();
        auto iw    = input->width();
        auto ih    = input->height();
        auto pads  = ConvolutionCommon::convolutionPad(input, outputDiff, common);
        MNN_ASSERT(TensorUtils::getDescribe(input)->dimensionFormat != MNN_DATA_FORMAT_NHWC);
        MNN_ASSERT(TensorUtils::getDescribe(outputDiff)->dimensionFormat != MNN_DATA_FORMAT_NHWC);
        Tensor* A = nullptr;
        Tensor* B = nullptr;
        {
            // B: Input Im2Col, n, ic, ih, iw -> ic*kh*kw, n*oh*ow
            std::shared_ptr<Tensor> im2Col(new Tensor);
            GeometryConvUtils::im2Col(im2Col.get(), input, ic, kh, kw, batch, oh, ow, ih, iw, sh, sw, dh, dw, pads);
            B = im2Col.get();
            res.extras.emplace_back(im2Col);
        }
        {
            // A: Output n, oc, oh, ow -> oc, n*oh*ow
            std::shared_ptr<Tensor> outputTranspose(new Tensor);
            A = outputTranspose.get();
            outputTranspose->buffer().type       = halide_type_of<float>();
            outputTranspose->buffer().dimensions = 2;
            outputTranspose->setLength(0, oc);
            outputTranspose->setLength(1, batch * ow * oh);
            auto des = TensorUtils::getDescribe(outputTranspose.get());
            des->regions.resize(1);
            des->memoryType   = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            auto& reg         = des->regions[0];
            reg.origin        = outputDiff;
            reg.size[0]       = oc;
            reg.size[1]       = batch;
            reg.size[2]       = ow * oh;
            reg.src.offset    = 0;
            reg.src.stride[0] = oh * ow;
            reg.src.stride[1] = oh * ow * oc;
            reg.src.stride[2] = 1;
            reg.dst.offset    = 0;
            reg.dst.stride[0] = oh * ow * batch;
            reg.dst.stride[1] = oh * ow;
            reg.dst.stride[2] = 1;
            res.extras.emplace_back(std::move(outputTranspose));
        }
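        // Shape check: B is (ic*kh*kw, n*oh*ow) and A is (oc, n*oh*ow), so the
        // matmul below, with its second operand transposed, yields C of shape
        // (ic*kh*kw, oc).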
        {
            // C = MatMul(B, A^T)
            std::shared_ptr<Tensor> C(new Tensor);
            C->buffer().type       = halide_type_of<float>();
            C->buffer().dimensions = 2;
            C->setLength(0, ic * kw * kh);
            C->setLength(1, oc);
            auto cmd                  = GeometryComputerUtils::makeMatMul(B, A, C.get(), nullptr, false, true);
            auto kernelDiffDes        = TensorUtils::getDescribe(outputs[0]);
            kernelDiffDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;

            // Transpose C (ic*kh*kw, oc) into kernelDiff's (oc, ic*kh*kw) layout
            auto len0 = kw * kh * ic;
            auto len1 = oc;
            kernelDiffDes->regions.resize(1);
            auto& desReg         = kernelDiffDes->regions[0];
            desReg.size[0]       = 1;
            desReg.size[1]       = len1;
            desReg.size[2]       = len0;
            desReg.dst.offset    = 0;
            desReg.dst.stride[0] = 0;
            desReg.dst.stride[1] = len0;
            desReg.dst.stride[2] = 1;
            desReg.src.offset    = 0;
            desReg.src.stride[0] = 0;
            desReg.src.stride[1] = 1;
            desReg.src.stride[2] = len1;
            desReg.origin        = C.get();
            res.extras.emplace_back(std::move(C));
            res.command.emplace_back(std::move(cmd));
        }
        return true;
    }
};

static void _create() {
    std::shared_ptr<GeometryComputer> comp(new GeometryConv2DBackPropFilter);
    GeometryComputer::registerGeometryComputer(comp, {OpType_Conv2DBackPropFilter});
}

REGISTER_GEOMETRY(GeometryConv2DBackPropFilter, _create);

} // namespace MNN