1//
2//  MetalConvolutionWinograd.mm
3//  MNN
4//
5//  Created by MNN on 2019/01/31.
6//  Copyright © 2018, Alibaba Group Holding Limited
7//
8
9#import "backend/metal/MetalConvolutionWinograd.hpp"
#import "core/Macro.h"
12#import "backend/metal/MetalBackend.hpp"
13#import "backend/metal/MetalConvolution.hpp"
14#import "math/WingoradGenerater.hpp"
15
16#if MNN_METAL_ENABLED
17
18#define UNIT 2
19
20namespace MNN {
bool MetalConvolutionWinograd::isValid(const Convolution2D *conv, const Tensor *input) {
    // Quantized weights and grouped convolutions are not handled by this path.
    if (conv->quanParameter() != nullptr || conv->common()->group() != 1) {
        return false;
    }
    auto common = conv->common();
    // Only single-batch, square 3x3 or 5x5 kernels with unit stride/dilation qualify.
    const int kx = common->kernelX();
    const int ky = common->kernelY();
    const bool squareKernel = (kx == ky) && (kx == 3 || kx == 5);
    if (input->batch() != 1 || !squareKernel) {
        return false;
    }
    if (common->dilateX() != 1 || common->dilateY() != 1) {
        return false;
    }
    if (common->strideX() != 1 || common->strideY() != 1) {
        return false;
    }
    // Heuristic threshold: Winograd only pays off on sufficiently large workloads.
    auto iw = input->width(), ih = input->height();
    auto ic = ROUND_UP(input->channel(), 4), oc = ROUND_UP(common->outputCount(), 4);
    return ic * oc * ih / iw >= 2048;
}
38
MetalConvolutionWinograd::MetalConvolutionWinograd(Backend *backend, const Tensor *input, const MNN::Op *op)
    : MetalConvolutionCommon(backend, op) {
    // F(UNIT x UNIT, k x k) Winograd: the transformed source tile spans UNIT + k - 1.
    auto conv = op->main_as_Convolution2D();
    mDstUnit  = UNIT;
    mSrcUnit  = mDstUnit + conv->common()->kernelY() - 1;
    loadWeight(conv);
}
46
// Prepares per-shape state for the Winograd pipeline: padding, the constant
// buffer read by the transform shaders, the matmul shape buffer, dispatch
// thread sizes, and the two transient device tensors used between stages.
ErrorCode MetalConvolutionWinograd::onResize(const std::vector<Tensor *> &inputs,
                                             const std::vector<Tensor *> &outputs) {
    auto backend = static_cast<MetalBackend *>(this->backend());
    auto context = (__bridge MNNMetalContext *)backend->context();
    auto input   = inputs[0];
    auto output  = outputs[0];

    auto ow  = output->width();
    auto oh  = output->height();
    auto uw  = UP_DIV(ow, mDstUnit);          // Winograd tiles along width
    auto uh  = UP_DIV(oh, mDstUnit);          // Winograd tiles along height
    auto us  = UP_DIV(uw * uh, 4);            // tiles grouped by 4 for the matmul
    auto iz  = UP_DIV(input->channel(), 4);   // input channel slices (C4)
    auto oz  = UP_DIV(output->channel(), 4);  // output channel slices (C4)
    int padX = mPadX;
    int padY = mPadY;

    // SAME padding: derive (roughly symmetric) padding from the in/out extents.
    if (mPadMode == PadMode_SAME) {
        int kernelWidthSize  = (mKernelX - 1) * mDilateX + 1;
        int kernelHeightSize = (mKernelY - 1) * mDilateY + 1;
        int pw = (output->width() - 1) * mStrideX + kernelWidthSize - input->width();
        int ph = (output->height() - 1) * mStrideY + kernelHeightSize - input->height();
        padX   = pw / 2;
        padY   = ph / 2;
    }

    // Constant buffer consumed by the transform shaders; the layout must match
    // the struct declared in the metal source.
    struct TransformBuffer {
        int inputSize[4];
        int outputSize[4];
        int padX;
        int padY;
        int unitWidth;
        int unitHeight;
        int unit;
        int activation;
        int remain[2];
    };
    // Value-initialize so `remain` and any padding bytes uploaded to the GPU
    // are deterministic (the original left them uninitialized before memcpy).
    TransformBuffer transform = {};
    transform.inputSize[0]  = input->width();
    transform.inputSize[1]  = input->height();
    transform.inputSize[2]  = iz;
    transform.inputSize[3]  = input->batch();
    transform.outputSize[0] = output->width();
    transform.outputSize[1] = output->height();
    transform.outputSize[2] = oz;
    transform.outputSize[3] = output->batch();
    transform.padX          = padX;
    transform.padY          = padY;
    transform.unitWidth     = uw;
    transform.unitHeight    = uh;
    transform.unit          = mDstUnit;
    transform.activation    = mActivationType;
    mConstBuffer.reset(sizeof(transform));
    ::memcpy(mConstBuffer.buffer().contents, &transform, sizeof(transform));

    // Shape buffer for the batched matmul stage.
    int shapes[] = {us, oz, iz, mSrcUnit * mSrcUnit};
    mShapeBuffer = [context newDeviceBuffer:sizeof(shapes) bytes:shapes access:CPUWriteOnly];

    // Dispatch sizes for the three pipeline stages.
    mInputTransformThreads.width   = uw;
    mInputTransformThreads.height  = uh;
    mInputTransformThreads.depth   = iz;
    mMatMulThreads.width           = us;
    mMatMulThreads.height          = oz;
    mMatMulThreads.depth           = mSrcUnit * mSrcUnit;
    mOutputTransformThreads.width  = uw;
    mOutputTransformThreads.height = uh;
    mOutputTransformThreads.depth  = oz;

    // Transient buffers between the transform and matmul stages. Acquire then
    // release so the memory pool can reuse them once this op has executed.
    int is = mSrcUnit * mSrcUnit * us * iz * 16 * sizeof(metal_float) / sizeof(uint8_t);
    int os = mSrcUnit * mSrcUnit * us * oz * 16 * sizeof(metal_float) / sizeof(uint8_t);
    mTempSrc.reset(Tensor::createDevice<uint8_t>(std::vector<int>{is}));
    mTempDst.reset(Tensor::createDevice<uint8_t>(std::vector<int>{os}));
    backend->onAcquireBuffer(mTempSrc.get(), Backend::DYNAMIC);
    backend->onAcquireBuffer(mTempDst.get(), Backend::DYNAMIC);
    backend->onReleaseBuffer(mTempSrc.get(), Backend::DYNAMIC);
    backend->onReleaseBuffer(mTempDst.get(), Backend::DYNAMIC);

    return NO_ERROR;
}
130
// Encodes the three-stage Winograd pipeline (source transform -> 4x4 matmul
// -> destination transform) onto the backend's compute encoder. The encoding
// closure is also registered with the backend so it can be replayed.
ErrorCode MetalConvolutionWinograd::onFloat(const Tensor *input, const Tensor *output) {
    auto backend = static_cast<MetalBackend *>(this->backend());
    auto context = (__bridge MNNMetalContext *)backend->context();

    // An encoder already recording means this op was encoded elsewhere; skip.
    if (backend->isCommandEncoderSet()) {
        return NO_ERROR;
    }

    auto func = [=]() {
        auto encoder = backend->encoder();
        { // source transform: image -> Winograd input tiles
            auto bandwidth = [context load:mKernelX == 3 ? @"winograd_transform_source2_3_1" : @"winograd_transform_source2_5_1" encoder:encoder];
            [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)input->deviceId() offset:0 atIndex:0];
            [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)mTempSrc->deviceId() offset:0 atIndex:1];
            [encoder setBuffer:mConstBuffer.buffer() offset:0 atIndex:2];
            [context dispatchEncoder:encoder threads:mInputTransformThreads bandwidth:bandwidth];
        }
        { // gemm: tile-wise 4x4 matrix multiply against transformed weights
            auto bandwidth = [context load:@"matmul4x4" encoder:encoder];
            [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)mTempSrc->deviceId() offset:0 atIndex:0];
            [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)mTempDst->deviceId() offset:0 atIndex:1];
            [encoder setBuffer:mWeight offset:0 atIndex:2];
            [encoder setBuffer:mShapeBuffer offset:0 atIndex:3];
            [context dispatchEncoder:encoder threads:mMatMulThreads bandwidth:bandwidth];
        }
        { // destination transform: tiles + bias -> output image
            auto bandwidth = [context load:mKernelX == 3 ? @"winograd_transform_dest2_3_1" : @"winograd_transform_dest2_5_1" encoder:encoder];
            [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)mTempDst->deviceId() offset:0 atIndex:0];
            [encoder setBuffer:mBias offset:0 atIndex:1];
            [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)output->deviceId() offset:0 atIndex:2];
            [encoder setBuffer:mConstBuffer.buffer() offset:0 atIndex:3];
            [context dispatchEncoder:encoder threads:mOutputTransformThreads bandwidth:bandwidth];
        }
        MNN_PRINT_ENCODER(context, encoder);

        // `context` is already captured by value above; the original re-bridged
        // backend->context() into a shadowing local here, which was redundant.
        if (context.isCommitEachShader) {
            backend->flushEncoder();
            [context commit_net];
        }
    };
    func();
    backend->addOpEncoder(func);

    return NO_ERROR;
}
// Transforms float OIHW weights into the Winograd domain and uploads them to
// a device buffer, converting to half precision unless full precision is on.
// Note: `group` is unused here — isValid() already restricts this path to
// group == 1.
id<MTLBuffer> MetalConvolutionWinograd::weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src) {
    auto backend = static_cast<MetalBackend *>(this->backend());
    // `backend` is already a MetalBackend*; the original applied a second,
    // redundant static_cast here.
    auto context = (__bridge MNNMetalContext *)backend->context();

    // Wrap the raw weights. Use kw for the last dimension — the original
    // passed kh twice, which mislays non-square kernels (latent bug; square
    // kernels are enforced by isValid today).
    std::shared_ptr<Tensor> srcWeight(Tensor::create<float>(std::vector<int>{oc, ic, kh, kw}, (void *)src, Tensor::CAFFE));
    Math::WinogradGenerater generater(mDstUnit, kh, 1.0f);
    std::shared_ptr<Tensor> dstWeight = generater.allocTransformWeight(srcWeight.get(), 4, 4);
    generater.transformWeight(dstWeight.get(), srcWeight.get());

#if MNN_METAL_FULL_PRECISION
    auto bytes = dstWeight->host<metal_float>();
#else
    // Half precision: convert the transformed weights element-wise into a
    // same-shaped 16-bit tensor before upload.
    std::shared_ptr<Tensor> dstWeightHalf(Tensor::create<int16_t>(dstWeight->shape()));
    auto f32 = dstWeight->host<float>();
    auto f16 = dstWeightHalf->host<metal_float>();
    for (int i = 0; i < dstWeight->elementSize(); ++i) {
        f16[i] = f32[i];
    }
    auto bytes = dstWeightHalf->host<metal_float>();
#endif
    return [context newDeviceBuffer:4 * UP_DIV(ic, 4) * UP_DIV(oc, 4) * mSrcUnit * mSrcUnit * 4 * sizeof(metal_float)
                              bytes:bytes
                             access:CPUWriteOnly];
}
201
202} // namespace MNN
203#endif /* MNN_METAL_ENABLED */
204