1//
2//  MetalPReLU.mm
3//  MNN
4//
5//  Created by MNN on 2019/01/30.
6//  Copyright © 2018, Alibaba Group Holding Limited
7//
8
9#import "backend/metal/MetalPReLU.hpp"
10#import "backend/metal/MNNMetalContext.h"
11#import "core/Macro.h"
12#import "backend/metal/MetalBackend.hpp"
13
14#if MNN_METAL_ENABLED
15namespace MNN {
16
// Builds the Metal PReLU execution.
//
// backend: backend owning this execution; cast to MetalBackend to reach the Metal context.
// slope:   pointer to `count` slope values (only `count` floats are ever read from it).
// count:   number of slope values; exactly 1 selects the shared-slope "prelu" pipeline,
//          anything else selects "prelu_slopes" and allocates the 3-int shape buffer.
MetalPReLU::MetalPReLU(Backend *backend, const float *slope, int count) : Execution(backend) {
    auto context = (__bridge MNNMetalContext *)static_cast<MetalBackend *>(backend)->context();

    // The device buffer is sized to UP_DIV(count, 4) * 4 floats, which can be larger than
    // the `count` floats that `slope` provides. Initialising the buffer straight from
    // `slope` with the padded length would read past the end of the source array, so
    // allocate first and then fill it by hand: real values followed by zero padding.
    const int paddedCount = UP_DIV(count, 4) * 4;
    mSlope                = [context newDeviceBuffer:paddedCount * sizeof(float) access:CPUWriteOnly];
    auto slopeData        = (float *)mSlope.contents;
    if (slopeData != nullptr) {
        for (int i = 0; i < paddedCount; ++i) {
            slopeData[i] = i < count ? slope[i] : 0.0f;
        }
    }

    mShareChannel = 1 == count;
    if (!mShareChannel) {
        // Filled in onResize with the output dimensions.
        mShape = [context newDeviceBuffer:3 * sizeof(int) access:CPUWriteOnly];
    }
    mPipeline = [context pipelineWithName:mShareChannel ? @"prelu" : @"prelu_slopes"];
}
26
// Recomputes the dispatch configuration from the shape of outputs[0].
//
// With a shared slope the work is dispatched over a single flattened dimension of
// width * height * UP_DIV(channel, 4) * batch. Otherwise the three dimensions
// (width * height, UP_DIV(channel, 4), batch) are written into mShape and used as the
// dispatch dimensions. Always returns NO_ERROR.
ErrorCode MetalPReLU::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto metalBackend = static_cast<MetalBackend *>(this->backend());
    auto metalContext = (__bridge MNNMetalContext *)metalBackend->context();
    auto dst          = outputs[0];

    const int width  = dst->width();
    const int height = dst->height();
    const int slices = UP_DIV(dst->channel(), 4);
    const int batch  = dst->batch();
    const int plane  = width * height;

    if (!mShareChannel) {
        // Per-channel slopes: publish the dimensions through the shape buffer.
        auto shape = (int *)mShape.contents;
        shape[0]   = plane;
        shape[1]   = slices;
        shape[2]   = batch;
        mThreads   = [metalContext computeBestGroupAndLocal:mPipeline threads:MTLSizeMake(plane, slices, batch)];
        return NO_ERROR;
    }

    // Shared slope: one flat dimension covers every element.
    mThreads = [metalContext computeBestGroupAndLocal:mPipeline threads:MTLSizeMake(plane * slices * batch, 1, 1)];
    return NO_ERROR;
}
42
// Encodes the PReLU dispatch for inputs[0] -> outputs[0].
//
// Returns immediately when the backend reports that its command encoder is already set.
// Otherwise the encoding routine is run once right away and also handed to the backend
// through addOpEncoder. Always returns NO_ERROR.
ErrorCode MetalPReLU::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto metalBackend = static_cast<MetalBackend *>(this->backend());
    if (metalBackend->isCommandEncoderSet()) {
        return NO_ERROR;
    }

    auto src = inputs[0];
    auto dst = outputs[0];

    auto encodeOp = [this, metalBackend, src, dst]() {
        auto srcBuffer = (__bridge id<MTLBuffer>)(void *)src->deviceId();
        auto dstBuffer = (__bridge id<MTLBuffer>)(void *)dst->deviceId();

        auto commandEncoder = metalBackend->encoder();
        [commandEncoder setComputePipelineState:mPipeline];
        [commandEncoder setBuffer:srcBuffer offset:0 atIndex:0];
        [commandEncoder setBuffer:dstBuffer offset:0 atIndex:1];
        [commandEncoder setBuffer:mSlope offset:0 atIndex:2];
        if (!mShareChannel) {
            // The per-channel pipeline additionally takes the shape buffer at index 3.
            [commandEncoder setBuffer:mShape offset:0 atIndex:3];
        }
        [commandEncoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];

        auto metalContext = (__bridge MNNMetalContext *)metalBackend->context();
        if (metalContext.isCommitEachShader) {
            metalBackend->flushEncoder();
            [metalContext commit_net];
        }
    };

    encodeOp();
    metalBackend->addOpEncoder(encodeOp);
    return NO_ERROR;
}
73
// Creator that builds a MetalPReLU from an op's PRelu parameters and is registered
// for OpType_PReLU.
class MetalPReLUCreator : public MetalBackend::Creator {
public:
    // Reads the slope data and slope count from the op's PRelu table and forwards
    // them to a newly allocated MetalPReLU.
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend) const {
        auto param      = op->main_as_PRelu();
        auto slopeData  = param->slope()->data();
        auto slopeCount = param->slopeCount();
        return new MetalPReLU(backend, slopeData, slopeCount);
    }
};
REGISTER_METAL_OP_CREATOR(MetalPReLUCreator, OpType_PReLU);
82} // namespace MNN
83#endif /* MNN_METAL_ENABLED */
84