//
//  MetalPReLU.mm
//  MNN
//
//  Created by MNN on 2019/01/30.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#import "backend/metal/MetalPReLU.hpp"
#include <cstring>
#import "backend/metal/MNNMetalContext.h"
#import "core/Macro.h"
#import "backend/metal/MetalBackend.hpp"

#if MNN_METAL_ENABLED
namespace MNN {

/**
 * Builds a PReLU execution for the Metal backend.
 *
 * @param backend Metal backend whose context creates the buffers and pipeline.
 * @param slope   Host pointer to `count` slope values.
 * @param count   Number of slope values. Exactly 1 selects the "prelu" pipeline
 *                (one shared slope); any other value selects "prelu_slopes" and
 *                allocates a 3-int shape buffer filled in by onResize.
 */
MetalPReLU::MetalPReLU(Backend *backend, const float *slope, int count) : Execution(backend) {
    auto context = (__bridge MNNMetalContext *)static_cast<MetalBackend *>(backend)->context();

    // The device buffer is padded to a multiple of 4 floats, but `slope` holds
    // only `count` values. Copy exactly `count` floats and zero the padded tail,
    // rather than initializing the whole padded length from `slope`, which would
    // read past the end of the caller's array whenever count % 4 != 0.
    const size_t paddedBytes = UP_DIV(count, 4) * 4 * sizeof(float);
    mSlope = [context newDeviceBuffer:paddedBytes access:CPUWriteOnly];
    ::memset(mSlope.contents, 0, paddedBytes);
    ::memcpy(mSlope.contents, slope, count * sizeof(float));

    mShareChannel = 1 == count;
    if (!mShareChannel) {
        // Holds {w * h, channel slices, batch}; written in onResize.
        mShape = [context newDeviceBuffer:3 * sizeof(int) access:CPUWriteOnly];
    }
    mPipeline = [context pipelineWithName:mShareChannel ? @"prelu" : @"prelu_slopes"];
}

/**
 * Computes the dispatch size from the output tensor.
 *
 * Shared-slope mode dispatches a 1-D grid of w * h * UP_DIV(channel, 4) * batch
 * threads. Per-channel mode dispatches a (w * h, UP_DIV(channel, 4), batch)
 * grid and records the same three extents in mShape for the kernel.
 */
ErrorCode MetalPReLU::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto backend = static_cast<MetalBackend *>(this->backend());
    auto context = (__bridge MNNMetalContext *)backend->context();
    auto output  = outputs[0];
    int w = output->width(), h = output->height(), z = UP_DIV(output->channel(), 4), b = output->batch();
    if (mShareChannel) {
        mThreads = [context computeBestGroupAndLocal:mPipeline threads:MTLSizeMake(w * h * z * b, 1, 1)];
    } else {
        ((int *)mShape.contents)[0] = w * h;
        ((int *)mShape.contents)[1] = z;
        ((int *)mShape.contents)[2] = b;
        mThreads = [context computeBestGroupAndLocal:mPipeline threads:MTLSizeMake(w * h, z, b)];
    }
    return NO_ERROR;
}

/**
 * Encodes the PReLU dispatch.
 *
 * Buffer bindings: 0 = input, 1 = output, 2 = slopes, 3 = shape (per-channel
 * mode only). The encoding closure runs immediately and is also registered
 * with the backend via addOpEncoder.
 */
ErrorCode MetalPReLU::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto backend = static_cast<MetalBackend *>(this->backend());

    // NOTE(review): when the backend reports its command encoder as already
    // set, nothing is encoded here; presumably the closure registered through
    // addOpEncoder on an earlier run is replayed instead. Confirm against
    // MetalBackend.
    if (backend->isCommandEncoderSet()) {
        return NO_ERROR;
    }

    auto func = [=]() {
        auto input = inputs[0], output = outputs[0];

        auto encoder = backend->encoder();
        [encoder setComputePipelineState:mPipeline];
        [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)input->deviceId() offset:0 atIndex:0];
        [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)output->deviceId() offset:0 atIndex:1];
        [encoder setBuffer:mSlope offset:0 atIndex:2];
        if (!mShareChannel) {
            [encoder setBuffer:mShape offset:0 atIndex:3];
        }
        [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];

        // When the context is configured to commit after every shader, flush
        // the encoder and commit right away.
        auto context = (__bridge MNNMetalContext *)backend->context();
        if (context.isCommitEachShader) {
            backend->flushEncoder();
            [context commit_net];
        }
    };
    func();
    backend->addOpEncoder(func);
    return NO_ERROR;
}

/// Creates MetalPReLU executions for OpType_PReLU ops.
class MetalPReLUCreator : public MetalBackend::Creator {
public:
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend) const {
        auto prelu = op->main_as_PRelu();
        // Refuse to build the execution, returning nullptr, when the parameters
        // are missing, the count is non-positive, or the stored slope vector is
        // shorter than the declared count. Otherwise we would dereference null
        // or read past the serialized data.
        if (nullptr == prelu || nullptr == prelu->slope()) {
            return nullptr;
        }
        const int count = prelu->slopeCount();
        if (count <= 0 || static_cast<int>(prelu->slope()->size()) < count) {
            return nullptr;
        }
        return new MetalPReLU(backend, prelu->slope()->data(), count);
    }
};
REGISTER_METAL_OP_CREATOR(MetalPReLUCreator, OpType_PReLU);
} // namespace MNN
#endif /* MNN_METAL_ENABLED */