//
//  CPUMatrixBandPart.cpp
//  MNN
//
//  Created by MNN on 2019/09/17.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "backend/cpu/CPUMatrixBandPart.hpp"
#include "backend/cpu/compute/ConvOpt.h"
#include "core/TensorUtils.hpp"
#include "core/Macro.h"
namespace MNN {
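
// MatrixBandPart copies its input while zeroing every element outside a band
// around the main diagonal of each innermost matrix. inputs[1] and inputs[2]
// hold the lower and upper band widths; a negative bound keeps the whole
// corresponding triangle (the same semantics as TensorFlow's MatrixBandPart).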
ErrorCode CPUMatrixBandPart::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    MNN_ASSERT(3 == inputs.size());
    auto dimensions = inputs[0]->dimensions();
    auto height     = inputs[0]->length(dimensions - 2);
    auto width      = inputs[0]->length(dimensions - 1);
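    // The mask covers a single height x width plane; onExecute fills it once
    // and reuses it for every outer slice of the input.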
    mMask.reset(Tensor::createDevice<float>({1, height * width}, Tensor::CAFFE_C4));
    auto res = backend()->onAcquireBuffer(mMask.get(), Backend::DYNAMIC);
    if (!res) {
        return OUT_OF_MEMORY;
    }
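    // Acquire-then-release in onResize is the usual MNN pattern for dynamic
    // scratch memory: the planner may recycle this buffer after the op runs,
    // but it remains valid while onExecute uses it.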
    backend()->onReleaseBuffer(mMask.get(), Backend::DYNAMIC);
    return NO_ERROR;
}
ErrorCode CPUMatrixBandPart::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    // Generate Mask
    auto lower   = inputs[1]->host<int32_t>()[0];
    auto upper   = inputs[2]->host<int32_t>()[0];
    auto maskPtr = mMask->host<float>();
    auto dimensions = inputs[0]->dimensions();
    auto height     = inputs[0]->length(dimensions - 2);
    auto width      = inputs[0]->length(dimensions - 1);

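    // Element (y, x) stays when it is at most `lower` rows below and at most
    // `upper` columns above the main diagonal; a negative bound keeps the
    // whole corresponding triangle.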
    for (int y = 0; y < height; ++y) {
        auto maskY = maskPtr + y * width;
        for (int x = 0; x < width; ++x) {
            bool valid = (lower < 0 || (y - x) <= lower) && (upper < 0 || (x - y) <= upper);
            maskY[x]   = valid ? 1.0f : 0.0f;
        }
    }

    // Run Mul
    auto outputPtr = outputs[0]->host<float>();
    auto inputPtr  = inputs[0]->host<float>();
    int outside    = 1;
    for (int i = 0; i < inputs[0]->dimensions() - 2; ++i) {
        outside *= inputs[0]->length(i);
    }
    auto inside = height * width;
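    // Apply the mask: every height x width plane of the input is multiplied
    // elementwise by the shared mask.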
    for (int i = 0; i < outside; ++i) {
        MNNMatrixProdCommon(outputPtr + i * inside, inputPtr + i * inside, maskPtr, inside, 0, 0, 0, 1);
    }
    return NO_ERROR;
}

class CPUMatrixBandPartCreator : public CPUBackend::Creator {
public:
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
                                const MNN::Op *op, Backend *backend) const override {
        return new CPUMatrixBandPart(backend);
    }
};

REGISTER_CPU_OP_CREATOR(CPUMatrixBandPartCreator, OpType_MatrixBandPart);
} // namespace MNN