//
//  CPUMatrixBandPart.cpp
//  MNN
//
//  Created by MNN on 2019/09/17.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "backend/cpu/CPUMatrixBandPart.hpp"
#include "backend/cpu/compute/ConvOpt.h"
#include "core/TensorUtils.hpp"
#include "core/Macro.h"

namespace MNN {
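// onResize: allocate a temporary float mask sized to the last two input dimensions.
// The DYNAMIC buffer is acquired and then released here so the backend's memory planner
// can reuse its space after this op has executed.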
ErrorCode CPUMatrixBandPart::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    MNN_ASSERT(3 == inputs.size());
    auto dimensions = inputs[0]->dimensions();
    auto height     = inputs[0]->length(dimensions - 2);
    auto width      = inputs[0]->length(dimensions - 1);
    mMask.reset(Tensor::createDevice<float>({1, height * width}, Tensor::CAFFE_C4));
    auto res = backend()->onAcquireBuffer(mMask.get(), Backend::DYNAMIC);
    if (!res) {
        return OUT_OF_MEMORY;
    }
    backend()->onReleaseBuffer(mMask.get(), Backend::DYNAMIC);
    return NO_ERROR;
}
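
// onExecute: build a 0/1 band mask from the scalar lower/upper inputs, then multiply every
// trailing height x width matrix of the input by it. A negative bound keeps the corresponding
// full triangle, matching the usual MatrixBandPart semantics.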
ErrorCode CPUMatrixBandPart::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    // Generate Mask
    auto lower      = inputs[1]->host<int32_t>()[0];
    auto upper      = inputs[2]->host<int32_t>()[0];
    auto maskPtr    = mMask->host<float>();
    auto dimensions = inputs[0]->dimensions();
    auto height     = inputs[0]->length(dimensions - 2);
    auto width      = inputs[0]->length(dimensions - 1);

    for (int y = 0; y < height; ++y) {
        auto maskY = maskPtr + y * width;
        for (int x = 0; x < width; ++x) {
            bool valid = (lower < 0 || (y - x) <= lower) && (upper < 0 || (x - y) <= upper);
            maskY[x]   = valid ? 1.0f : 0.0f;
        }
    }

    // Run Mul
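    // Every outer batch shares the same mask: each iteration multiplies a contiguous block
    // of height * width input values by maskPtr and writes the result to the output.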
    auto outputPtr = outputs[0]->host<float>();
    auto inputPtr  = inputs[0]->host<float>();
    int outside    = 1;
    for (int i = 0; i < inputs[0]->dimensions() - 2; ++i) {
        outside *= inputs[0]->length(i);
    }
    auto inside = height * width;
    for (int i = 0; i < outside; ++i) {
        MNNMatrixProdCommon(outputPtr + i * inside, inputPtr + i * inside, maskPtr, inside, 0, 0, 0, 1);
    }
    return NO_ERROR;
}

class CPUMatrixBandPartCreator : public CPUBackend::Creator {
public:
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
                                const MNN::Op *op, Backend *backend) const override {
        return new CPUMatrixBandPart(backend);
    }
};

REGISTER_CPU_OP_CREATOR(CPUMatrixBandPartCreator, OpType_MatrixBandPart);
} // namespace MNN