//
//  MetalConvolutionCommon.mm
//  MNN
//
//  Created by MNN on 2019/02/25.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#import "backend/metal/MetalConvolutionCommon.hpp"
#include <algorithm>
#include <cstddef>
#import "core/Macro.h"
#import "backend/metal/MetalBackend.hpp"
#import "backend/metal/MetalConvolution1x1.hpp"
#import "backend/metal/MetalConvolutionWinograd.hpp"
#import "core/TensorUtils.hpp"

#if MNN_METAL_ENABLED
namespace MNN {

/// Builds the device buffer holding a convolution's bias.
///
/// The buffer has room for UP_DIV(oc, 4) * 4 metal_float values. Entries
/// [0, count) receive the op's bias values converted to metal_float, where
/// count is min(oc, bias length). Every other entry is set to zero. That
/// covers the padding up to a multiple of 4, and any channel for which the
/// op supplies no bias.
///
/// @param context  Metal context used to allocate the buffer.
/// @param conv     Convolution parameters; `conv->bias()` may be absent.
/// @return         A CPU-write-only device buffer with the padded bias.
static id<MTLBuffer> biasForConv(MNNMetalContext *context, const Convolution2D *conv) {
    const auto bias   = conv->bias();
    const int oc      = conv->common()->outputCount();
    const int padded  = UP_DIV(oc, 4) * 4;
    auto buffer       = [context newDeviceBuffer:padded * sizeof(metal_float) access:CPUWriteOnly];
    auto dst          = static_cast<metal_float *>(buffer.contents);

    // Zero the whole buffer first so the padded lanes never depend on how the
    // allocator initialises fresh device memory.
    std::fill_n(dst, padded, static_cast<metal_float>(0));

    // A missing or short bias vector is treated as zero bias rather than being
    // dereferenced (null) or read past its end.
    const int count = bias ? std::min<int>(oc, static_cast<int>(bias->size())) : 0;
    if (count > 0) {
        const auto src = bias->data();
#pragma clang loop vectorize(enable) unroll(enable)
        for (int i = 0; i < count; i++) {
            dst[i] = src[i];
        }
    }
    return buffer;
}

/// Caches the convolution geometry from `op` and uploads the bias buffer.
///
/// The activation is encoded as 1 for relu, 2 for relu6 and 0 otherwise.
/// Relu takes priority if both flags are set.
MetalConvolutionCommon::MetalConvolutionCommon(Backend *backend, const MNN::Op *op)
    : Execution(backend), mConstBuffer(static_cast<MetalBackend *>(backend)->runtime()) {
    auto context    = (__bridge MNNMetalContext *)static_cast<MetalBackend *>(backend)->context();
    auto conv       = op->main_as_Convolution2D();
    auto common     = conv->common();
    mDepthwise      = op->type() == OpType_ConvolutionDepthwise;
    mGroups         = common->group();
    mKernelX        = common->kernelX();
    mKernelY        = common->kernelY();
    mPadMode        = common->padMode();
    mPadX           = common->padX();
    mPadY           = common->padY();
    mStrideX        = common->strideX();
    mStrideY        = common->strideY();
    mDilateX        = common->dilateX();
    mDilateY        = common->dilateY();
    mBias           = biasForConv(context, conv);
    mActivationType = common->relu() ? 1 : (common->relu6() ? 2 : 0);
}

/// No resize work is needed at this level; it always succeeds.
ErrorCode MetalConvolutionCommon::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    return NO_ERROR;
}

/// Runs the convolution on the first input/output pair via `onFloat`.
ErrorCode MetalConvolutionCommon::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    return onFloat(inputs[0], outputs[0]);
}

/// Repacks convolution weights into a 4x4-blocked device layout.
///
/// Source layout:      [g][o][i][h][w]   (o, i are per-group channel indices)
/// Destination layout: [g][o/4][i/4][h][w][16], where the 16 lanes are
/// (o % 4) * 4 + (i % 4).
///
/// Output/input channel counts are rounded up to multiples of 4. The extra
/// lanes are explicitly zeroed so they contribute nothing to the convolution.
///
/// @param context  Metal context used to allocate the buffer.
/// @param group    Number of groups (oc and ic are assumed divisible by it).
/// @param oc       Total output channels.
/// @param ic       Total input channels.
/// @param kh, kw   Kernel height / width.
/// @param src      Source weights, group * (oc/group) * (ic/group) * kh * kw values.
/// @return         A CPU-write-only device buffer of TType values.
template <typename FType, typename TType>
static id<MTLBuffer> weightInBlock(MNNMetalContext *context, int group, int oc, int ic, int kh, int kw,
                                   const FType *src) {
    const int goc   = oc / group;
    const int gic   = ic / group;
    const int goc_4 = UP_DIV(goc, 4);
    const int gic_4 = UP_DIV(gic, 4);

    // Sizes are computed in size_t so the int factors cannot overflow before
    // being widened.
    const size_t groupStride = static_cast<size_t>(goc_4) * gic_4 * kh * kw * 16;
    const size_t total       = static_cast<size_t>(group) * groupStride;
    auto buffer = [context newDeviceBuffer:total * sizeof(TType) access:CPUWriteOnly];
    auto dst    = static_cast<TType *>(buffer.contents);

    // Padded lanes (when goc or gic is not a multiple of 4) must read as zero.
    std::fill_n(dst, total, static_cast<TType>(0));

    for (int g = 0; g < group; g++) {
        auto g_dst = dst + g * groupStride; // g
        for (int o = 0; o < goc; o++) {
            const int zo = o / 4, ro = o % 4;
            auto o_dst   = g_dst + static_cast<size_t>(zo) * gic_4 * kh * kw * 16 + ro * 4; // o/4 x 4
            for (int i = 0; i < gic; i++) {
                const int zi = i / 4, ri = i % 4;
                auto i_dst   = o_dst + static_cast<size_t>(zi) * kh * kw * 16 + ri; // i/4 x 4
                for (int h = 0; h < kh; h++) {
                    for (int w = 0; w < kw; w++) {
                        // to   [g][o/4][i/4][h][w][16]
                        // from [g][o][i][h][w]  (src is consumed sequentially)
                        i_dst[(h * kw + w) * 16] = *src++;
                    }
                }
            }
        }
    }
    return buffer;
}

/// Loads the convolution weights into `mWeight`.
///
/// If the op carries a quantization parameter, it is decoded with
/// ConvolutionCommon::load, and the result is handed to weightForConv. The
/// second argument is passed as `true`; presumably this requests dequantised
/// float weights (TODO confirm against ConvolutionCommon::load).
void MetalConvolutionCommon::loadWeight(const MNN::Convolution2D *conv) {
    std::shared_ptr<ConvolutionCommon::Int8Common> qnt;
    if (conv->quanParameter()) {
        qnt = ConvolutionCommon::load(conv->quanParameter(), true);
    }
    mWeight = weightForConv(conv, qnt.get(), mDepthwise);
}

/// Uploads float weights in the blocked layout produced by weightInBlock,
/// converting each value to metal_float.
id<MTLBuffer> MetalConvolutionCommon::weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src) {
    auto backend = static_cast<MetalBackend *>(this->backend());
    auto context = (__bridge MNNMetalContext *)backend->context();
    return weightInBlock<float, metal_float>(context, group, oc, ic, kh, kw, src);
}

/// Derives the weight geometry from `conv` and uploads the weights.
///
/// The input-channel count is inferred from the number of weight values:
///   ic = size / kw / kh / (oc / group).
/// Weights come from `qnt->weightFloat` when `qnt` is given and that array is
/// non-empty. Otherwise they come from `conv->weight()`.
///
/// @param depthwise  Not used by this implementation.
id<MTLBuffer> MetalConvolutionCommon::weightForConv(const Convolution2D *conv, ConvolutionCommon::Int8Common *qnt,
                                                    bool depthwise) {
    // param
    auto size   = qnt ? MAX(qnt->weight.size(), qnt->weightFloat.size()) : conv->weight()->size();
    auto common = conv->common();
    auto kw     = common->kernelX();
    auto kh     = common->kernelY();
    auto group  = common->group();
    auto oc     = common->outputCount();
    auto ic     = size / kw / kh / (oc / group);

    // convert
    // NOTE(review): when qnt is set but holds no float weights, this falls back
    // to conv->weight(), which may be absent for quantized models. The only
    // caller visible here (loadWeight) passes `true` to
    // ConvolutionCommon::load, presumably to request float weights, which
    // should avoid that path.
    if (qnt && qnt->weightFloat.size() > 0) {
        return weightForFloat(group, oc, ic, kh, kw, qnt->weightFloat.get());
    } else {
        return weightForFloat(group, oc, ic, kh, kw, conv->weight()->data());
    }
}
} // namespace MNN

#endif