//
//  MetalConvolutionWinograd.mm
//  MNN
//
//  Created by MNN on 2019/01/31.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#import "backend/metal/MetalConvolutionWinograd.hpp"
#import "core/Macro.h"
#import "backend/metal/MetalBackend.hpp"
#import "backend/metal/MetalConvolution.hpp"
#import "math/WingoradGenerater.hpp"

#if MNN_METAL_ENABLED

// Winograd output tile size: this implementation computes F(2x2, KxK).
#define UNIT 2

namespace MNN {

/// Returns true when the Winograd path can handle this convolution.
/// Requirements enforced below: float weights (no quantization), group == 1,
/// batch == 1, square 3x3 or 5x5 kernel, stride 1, dilation 1.
/// The final ratio test is a profitability heuristic: Winograd only pays off
/// when there is enough channel work relative to the spatial width.
bool MetalConvolutionWinograd::isValid(const Convolution2D *conv, const Tensor *input) {
    if (conv->quanParameter() != nullptr || conv->common()->group() != 1) {
        return false;
    }
    auto common = conv->common();
    if (input->batch() != 1
        || !((common->kernelX() == common->kernelY()) && ((common->kernelX() == 3) || (common->kernelX() == 5)))
        || common->dilateX() != 1
        || common->dilateY() != 1
        || common->strideX() != 1
        || common->strideY() != 1) {
        return false;
    }
    auto iw = input->width(), ih = input->height();
    auto ic = ROUND_UP(input->channel(), 4), oc = ROUND_UP(common->outputCount(), 4);
    // Heuristic threshold: channel-heavy, narrow inputs benefit from Winograd.
    return ic * oc * ih / iw >= 2048;
}

MetalConvolutionWinograd::MetalConvolutionWinograd(Backend *backend, const Tensor *input, const MNN::Op *op)
    : MetalConvolutionCommon(backend, op) {
    auto conv = op->main_as_Convolution2D();
    // For F(m, r): input tile size = m + r - 1, output tile size = m.
    mSrcUnit = UNIT + conv->common()->kernelY() - 1;
    mDstUnit = UNIT;
    loadWeight(conv);
}

ErrorCode MetalConvolutionWinograd::onResize(const std::vector<Tensor *> &inputs,
                                             const std::vector<Tensor *> &outputs) {
    auto backend = static_cast<MetalBackend *>(this->backend());
    auto context = (__bridge MNNMetalContext *)backend->context();
    auto input   = inputs[0];
    auto output  = outputs[0];

    auto ow = output->width();
    auto oh = output->height();
    auto uw = UP_DIV(ow, mDstUnit);         // Winograd tiles along width
    auto uh = UP_DIV(oh, mDstUnit);         // Winograd tiles along height
    auto us = UP_DIV(uw * uh, 4);           // tile slices (4 tiles per slice)
    auto iz = UP_DIV(input->channel(), 4);  // input channel slices (C4 layout)
    auto oz = UP_DIV(output->channel(), 4); // output channel slices (C4 layout)
    int padX = mPadX;
    int padY = mPadY;

    if (mPadMode == PadMode_SAME) {
        // Recompute symmetric padding so the output keeps the input's spatial size.
        int kernelWidthSize  = (mKernelX - 1) * mDilateX + 1;
        int kernelHeightSize = (mKernelY - 1) * mDilateY + 1;
        int pw = (output->width() - 1) * mStrideX + kernelWidthSize - input->width();
        int ph = (output->height() - 1) * mStrideY + kernelHeightSize - input->height();
        padX   = pw / 2;
        padY   = ph / 2;
    }

    // Constant buffer layout shared with the transform shaders.
    struct TransformBuffer {
        int inputSize[4];
        int outputSize[4];
        int padX;
        int padY;
        int unitWidth;
        int unitHeight;
        int unit;
        int activation;
        int remain[2]; // trailing padding; presumably keeps 16-byte alignment for Metal
    };
    // Zero-initialize so the never-assigned remain[] and any padding bytes
    // uploaded to the GPU are deterministic (the original left them uninitialized).
    TransformBuffer transform = {};
    transform.inputSize[0]  = input->width();
    transform.inputSize[1]  = input->height();
    transform.inputSize[2]  = iz;
    transform.inputSize[3]  = input->batch();
    transform.outputSize[0] = output->width();
    transform.outputSize[1] = output->height();
    transform.outputSize[2] = oz;
    transform.outputSize[3] = output->batch();
    transform.padX          = padX;
    transform.padY          = padY;
    transform.unitWidth     = uw;
    transform.unitHeight    = uh;
    transform.unit          = mDstUnit;
    transform.activation    = mActivationType;
    mConstBuffer.reset(sizeof(transform));
    ::memcpy(mConstBuffer.buffer().contents, &transform, sizeof(transform));

    // Create the matmul shape buffer: {tile slices, out slices, in slices, Winograd elements}.
    int shapes[] = {us, oz, iz, mSrcUnit * mSrcUnit};
    mShapeBuffer = [context newDeviceBuffer:sizeof(shapes) bytes:shapes access:CPUWriteOnly];

    // Save dispatch sizes for the three kernels.
    mInputTransformThreads.width   = uw;
    mInputTransformThreads.height  = uh;
    mInputTransformThreads.depth   = iz;
    mMatMulThreads.width           = us;
    mMatMulThreads.height          = oz;
    mMatMulThreads.depth           = mSrcUnit * mSrcUnit;
    mOutputTransformThreads.width  = uw;
    mOutputTransformThreads.height = uh;
    mOutputTransformThreads.depth  = oz;

    // Acquire scratch space for the Winograd-domain source/destination.
    // Acquire-then-release with DYNAMIC lets the backend reuse the memory
    // for other ops once this op's execution window has passed.
    int is = mSrcUnit * mSrcUnit * us * iz * 16 * sizeof(metal_float) / sizeof(uint8_t);
    int os = mSrcUnit * mSrcUnit * us * oz * 16 * sizeof(metal_float) / sizeof(uint8_t);
    mTempSrc.reset(Tensor::createDevice<uint8_t>(std::vector<int>{is}));
    mTempDst.reset(Tensor::createDevice<uint8_t>(std::vector<int>{os}));
    backend->onAcquireBuffer(mTempSrc.get(), Backend::DYNAMIC);
    backend->onAcquireBuffer(mTempDst.get(), Backend::DYNAMIC);
    backend->onReleaseBuffer(mTempSrc.get(), Backend::DYNAMIC);
    backend->onReleaseBuffer(mTempDst.get(), Backend::DYNAMIC);

    return NO_ERROR;
}

ErrorCode MetalConvolutionWinograd::onFloat(const Tensor *input, const Tensor *output) {
    auto backend = static_cast<MetalBackend *>(this->backend());
    auto context = (__bridge MNNMetalContext *)backend->context();

    // If an external command encoder is already set, the ops were encoded earlier.
    if (backend->isCommandEncoderSet()) {
        return NO_ERROR;
    }

    auto func = [=]() {
        auto encoder = backend->encoder();
        { // input transform: image tiles -> Winograd domain
            auto bandwidth = [context load:mKernelX == 3 ? @"winograd_transform_source2_3_1" : @"winograd_transform_source2_5_1" encoder:encoder];
            [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)input->deviceId() offset:0 atIndex:0];
            [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)mTempSrc->deviceId() offset:0 atIndex:1];
            [encoder setBuffer:mConstBuffer.buffer() offset:0 atIndex:2];
            [context dispatchEncoder:encoder threads:mInputTransformThreads bandwidth:bandwidth];
        }
        { // gemm: batched 4x4 matmul with the transformed weights
            auto bandwidth = [context load:@"matmul4x4" encoder:encoder];
            [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)mTempSrc->deviceId() offset:0 atIndex:0];
            [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)mTempDst->deviceId() offset:0 atIndex:1];
            [encoder setBuffer:mWeight offset:0 atIndex:2];
            [encoder setBuffer:mShapeBuffer offset:0 atIndex:3];
            [context dispatchEncoder:encoder threads:mMatMulThreads bandwidth:bandwidth];
        }
        { // output transform: Winograd domain -> image tiles (+ bias, activation)
            auto bandwidth = [context load:mKernelX == 3 ? @"winograd_transform_dest2_3_1" : @"winograd_transform_dest2_5_1" encoder:encoder];
            [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)mTempDst->deviceId() offset:0 atIndex:0];
            [encoder setBuffer:mBias offset:0 atIndex:1];
            [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)output->deviceId() offset:0 atIndex:2];
            [encoder setBuffer:mConstBuffer.buffer() offset:0 atIndex:3];
            [context dispatchEncoder:encoder threads:mOutputTransformThreads bandwidth:bandwidth];
        }
        MNN_PRINT_ENCODER(context, encoder);

        // Use the captured `context` directly; the original re-bridged
        // backend->context() into a shadowing local for no reason.
        if (context.isCommitEachShader) {
            backend->flushEncoder();
            [context commit_net];
        }
    };
    func();
    backend->addOpEncoder(func);

    return NO_ERROR;
}

/// Transforms the raw fp32 weights into the Winograd domain and uploads them
/// to a Metal device buffer (converted to fp16 unless MNN_METAL_FULL_PRECISION).
id<MTLBuffer> MetalConvolutionWinograd::weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src) {
    auto backend = static_cast<MetalBackend *>(this->backend());
    auto context = (__bridge MNNMetalContext *)backend->context();

    // isValid() guarantees a square kernel, but use {oc, ic, kh, kw} so the
    // shape follows the parameters (the original passed kh twice, leaving kw unused).
    std::shared_ptr<Tensor> srcWeight(Tensor::create<float>(std::vector<int>{oc, ic, kh, kw}, (void *)src, Tensor::CAFFE));
    Math::WinogradGenerater generater(mDstUnit, kh, 1.0f);
    std::shared_ptr<Tensor> dstWeight = generater.allocTransformWeight(srcWeight.get(), 4, 4);
    generater.transformWeight(dstWeight.get(), srcWeight.get());

#if MNN_METAL_FULL_PRECISION
    auto bytes = dstWeight->host<metal_float>();
#else
    // Down-convert fp32 -> fp16 element-wise for half-precision Metal buffers.
    std::shared_ptr<Tensor> dstWeightHalf(Tensor::create<int16_t>(dstWeight->shape()));
    auto f32 = dstWeight->host<float>();
    auto f16 = dstWeightHalf->host<metal_float>();
    for (int i = 0; i < dstWeight->elementSize(); ++i) {
        f16[i] = f32[i];
    }
    auto bytes = dstWeightHalf->host<metal_float>();
#endif
    return [context newDeviceBuffer:4 * UP_DIV(ic, 4) * UP_DIV(oc, 4) * mSrcUnit * mSrcUnit * 4 * sizeof(metal_float)
                              bytes:bytes
                             access:CPUWriteOnly];
}

} // namespace MNN
#endif /* MNN_METAL_ENABLED */