//
//  MetalConvolutionCommon.mm
//  MNN
//
//  Created by MNN on 2019/02/25.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#import "backend/metal/MetalConvolutionCommon.hpp"
#import "core/Macro.h"
#import "backend/metal/MetalBackend.hpp"
#import "backend/metal/MetalConvolution1x1.hpp"
#import "backend/metal/MetalConvolutionWinograd.hpp"
#import "core/TensorUtils.hpp"

#if MNN_METAL_ENABLED
namespace MNN {

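// Copy the convolution bias into a Metal device buffer whose length is padded
// up to a multiple of 4 output channels (UP_DIV(oc, 4) * 4) to match the
// backend's channel-packed layout, storing values as metal_float.
// Note: only the first `oc` entries are written below; the padded tail is left
// as allocated by newDeviceBuffer.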
static id<MTLBuffer> biasForConv(MNNMetalContext *context, const Convolution2D *conv) {
    auto bias   = conv->bias();
    auto oc     = conv->common()->outputCount();
    auto buffer = [context newDeviceBuffer:UP_DIV(oc, 4) * 4 * sizeof(metal_float) access:CPUWriteOnly];
    auto src    = bias->data();
    auto dst    = (metal_float *)buffer.contents;
#pragma clang loop vectorize(enable) unroll(enable)
    for (int i = 0; i < oc; i++) {
        dst[i] = src[i];
    }
    return buffer;
}

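// Cache the convolution hyper-parameters (group count, kernel size, padding,
// stride, dilation) from the op, upload the bias buffer, and record the
// activation type: 0 = none, 1 = ReLU, 2 = ReLU6.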
MetalConvolutionCommon::MetalConvolutionCommon(Backend *backend, const MNN::Op *op) : Execution(backend), mConstBuffer(static_cast<MetalBackend*>(backend)->runtime()) {
    auto context    = (__bridge MNNMetalContext *)static_cast<MetalBackend *>(backend)->context();
    auto conv       = op->main_as_Convolution2D();
    auto common     = conv->common();
    mDepthwise      = op->type() == OpType_ConvolutionDepthwise;
    mGroups         = common->group();
    mKernelX        = common->kernelX();
    mKernelY        = common->kernelY();
    mPadMode        = common->padMode();
    mPadX           = common->padX();
    mPadY           = common->padY();
    mStrideX        = common->strideX();
    mStrideY        = common->strideY();
    mDilateX        = common->dilateX();
    mDilateY        = common->dilateY();
    mBias           = biasForConv(context, conv);
    mActivationType = common->relu() ? 1 : (common->relu6() ? 2 : 0);
}

ErrorCode MetalConvolutionCommon::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    return NO_ERROR;
}

ErrorCode MetalConvolutionCommon::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    return onFloat(inputs[0], outputs[0]);
}

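// Repack convolution weights from the source layout [g][o][i][h][w] into the
// tiled layout [g][o/4][i/4][h][w][16] consumed by the Metal kernels, where the
// trailing 16 holds a 4x4 (output x input) channel block. For a source element
// (g, o, i, h, w) the destination offset computed below is:
//   ((((g * goc_4 + o/4) * gic_4 + i/4) * kh + h) * kw + w) * 16 + (o % 4) * 4 + (i % 4)
// FType is the source scalar type, TType the device storage type (metal_float).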
template <typename FType, typename TType>
static id<MTLBuffer> weightInBlock(MNNMetalContext *context, int group, int oc, int ic, int kh, int kw,
                                   const FType *src) {
    auto goc    = oc / group;
    auto gic    = ic / group;
    auto goc_4  = UP_DIV(goc, 4);
    auto gic_4  = UP_DIV(gic, 4);
    auto buffer = [context newDeviceBuffer:group * goc_4 * gic_4 * kw * kh * 16 * sizeof(TType) access:CPUWriteOnly];
    auto dst    = (TType *)buffer.contents;

    for (int g = 0; g < group; g++) {
        auto g_dst = dst + g * goc_4 * gic_4 * kh * kw * 16; // g
        for (int o = 0; o < goc; o++) {
            auto zo = o / 4, ro = o % 4;
            auto o_dst = g_dst + zo * gic_4 * kh * kw * 16 + ro * 4; // o/4 x 4
            for (int i = 0; i < gic; i++) {
                auto zi = i / 4, ri = i % 4;
                auto i_dst = o_dst + zi * kh * kw * 16 + ri; // i/4 x 4
                for (int h = 0; h < kh; h++) {
                    for (int w = 0; w < kw; w++) {
                        // to   [g][o/4][i/4][h][w][16]
                        // from [g][o][i][h][w]
                        i_dst[(h * kw + w) * 16] = *src++;
                    }
                }
            }
        }
    }
    return buffer;
}

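// Load the convolution weights. If the op carries quantized parameters they are
// decoded first via ConvolutionCommon::load so dequantized float weights are
// available; otherwise the raw float weights stored on the op are used directly.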
void MetalConvolutionCommon::loadWeight(const MNN::Convolution2D *conv) {
    std::shared_ptr<ConvolutionCommon::Int8Common> qnt = nullptr;
    if (conv->quanParameter()) {
        qnt = ConvolutionCommon::load(conv->quanParameter(), true);
    }
    mWeight = weightForConv(conv, qnt.get(), mDepthwise);
}

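// Repack raw float weights into the tiled device layout, converting to the
// backend's metal_float storage type.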
id<MTLBuffer> MetalConvolutionCommon::weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src) {
    auto backend = static_cast<MetalBackend *>(this->backend());
    auto context = (__bridge MNNMetalContext *)backend->context();
    return weightInBlock<float, metal_float>(context, group, oc, ic, kh, kw, src);
}

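// Derive the weight tensor geometry from the op. The input channel count is
// recovered from the total element count:
//   ic = size / (kw * kh * (oc / group))
// which equals group * (ic / group), i.e. the full input channel count when ic
// is divisible by group. Dequantized float weights are preferred when present.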
id<MTLBuffer> MetalConvolutionCommon::weightForConv(const Convolution2D *conv, ConvolutionCommon::Int8Common *qnt,
                                                    bool depthwise) {
    // param
    auto size   = qnt ? MAX(qnt->weight.size(), qnt->weightFloat.size()) : conv->weight()->size();
    auto common = conv->common();
    auto kw     = common->kernelX();
    auto kh     = common->kernelY();
    auto group  = common->group();
    auto oc     = common->outputCount();
    auto ic     = size / kw / kh / (oc / group);

    // convert
    if (qnt && qnt->weightFloat.size() > 0) {
        return weightForFloat(group, oc, ic, kh, kw, qnt->weightFloat.get());
    } else {
        return weightForFloat(group, oc, ic, kh, kw, conv->weight()->data());
    }
}
} // namespace MNN

#endif /* MNN_METAL_ENABLED */