1//
2//  MetalReduction.mm
3//  MNN
4//
5//  Created by MNN on 2019/01/30.
6//  Copyright © 2018, Alibaba Group Holding Limited
7//
8
9#import "backend/metal/MetalReduction.hpp"
10#import "backend/metal/MNNMetalContext.h"
11#import "core/Macro.h"
12#import "core/Macro.h"
13#import "backend/metal/MetalBackend.hpp"
14#import "core/TensorUtils.hpp"
15
16#if MNN_METAL_ENABLED
17namespace MNN {
18
/**
 Builds a Metal reduction execution.

 Picks the compute kernel name from the reduction operation and the element
 type code, records the reduction axis, allocates the constant buffer and
 requests the pipeline for the chosen kernel from the Metal context.

 @param backend Backend that owns this execution; must be a MetalBackend.
 @param p       Reduction parameters (operation and axis list).
 @param type    Element type; only its `code` is inspected here.
 */
MetalReduction::MetalReduction(Backend *backend, const ReductionParam *p, halide_type_t type) : Execution(backend) {
    // Kernels with an "_s" suffix are chosen for halide_type_int inputs, "_f" for everything else.
    const bool useIntKernel = (type.code == halide_type_int);
    NSString *kernelName;
    switch (p->operation()) {
        case ReductionType_SUM:
            if (useIntKernel) {
                kernelName = @"reduce_sum_s";
            } else {
                kernelName = @"reduce_sum_f";
            }
            break;
        case ReductionType_MEAN:
            if (useIntKernel) {
                kernelName = @"reduce_mean_s";
            } else {
                kernelName = @"reduce_mean_f";
            }
            break;
        case ReductionType_MAXIMUM:
            if (useIntKernel) {
                kernelName = @"reduce_max_s";
            } else {
                kernelName = @"reduce_max_f";
            }
            break;
        case ReductionType_MINIMUM:
            if (useIntKernel) {
                kernelName = @"reduce_min_s";
            } else {
                kernelName = @"reduce_min_f";
            }
            break;
        case ReductionType_PROD:
            if (useIntKernel) {
                kernelName = @"reduce_prod_s";
            } else {
                kernelName = @"reduce_prod_f";
            }
            break;
        case ReductionType_ASUM:
        case ReductionType_SUMSQ:
            // Neither operation has a kernel here.
            MNN_ASSERT(false);
            break;
        default:
            break;
    }

    // Only the first entry of the axis list is used: after geometry compute
    // the reduction is expected to carry a single axis.
    mAxis = p->dim()->data()[0];

    auto metalBackend = static_cast<MetalBackend *>(backend);
    auto metalContext = (__bridge MNNMetalContext *)metalBackend->context();
    // Four ints of constants; the values are filled in by onResize.
    mConst    = [metalContext newDeviceBuffer:4 * sizeof(int) access:CPUWriteOnly];
    mPipeline = [metalContext pipelineWithName:kernelName];
}
52
/**
 Recomputes the reduction geometry for the current input shape.

 The input shape is collapsed into three extents around mAxis:
 the product of the dimensions before the axis, the axis length itself,
 and the product of the dimensions after the axis. These are written to the
 constant buffer together with (axis * inside), and the dispatch sizes are
 recomputed for an (outside x inside) thread grid.

 @param inputs  inputs[0] is the tensor being reduced.
 @param outputs Unused here.
 @return NO_ERROR.
 */
ErrorCode MetalReduction::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto source    = inputs[0];
    const int rank = source->dimensions();

    // Product of all extents that precede the reduced axis.
    int outer = 1;
    for (int d = 0; d < mAxis; ++d) {
        outer *= source->length(d);
    }

    // Extent of the reduced axis.
    const int reduced = source->length(mAxis);

    // Product of all extents that follow the reduced axis.
    int inner = 1;
    for (int d = mAxis + 1; d < rank; ++d) {
        inner *= source->length(d);
    }

    // Constant buffer layout: [outside, axis, inside, axis * inside].
    int *constants = (int *)mConst.contents;
    constants[0]   = outer;
    constants[1]   = reduced;
    constants[2]   = inner;
    constants[3]   = reduced * inner;

    auto metalBackend = static_cast<MetalBackend *>(this->backend());
    auto metalContext = (__bridge MNNMetalContext *)metalBackend->context();
    // One thread per (outside, inside) pair.
    mThreads = [metalContext computeBestGroupAndLocal:mPipeline threads:MTLSizeMake(outer, inner, 1)];
    return NO_ERROR;
}
71
/**
 Encodes the reduction kernel for the given tensors.

 Returns immediately when the backend reports that its command encoder is
 already set. Otherwise the encoding steps are wrapped in a closure that is
 run once right away and then handed to the backend via addOpEncoder
 (presumably so the backend can replay it later; confirm against MetalBackend).

 @param inputs  inputs[0] supplies the source buffer (bound at index 0).
 @param outputs outputs[0] supplies the destination buffer (bound at index 1).
 @return NO_ERROR.
 */
ErrorCode MetalReduction::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto metalBackend = static_cast<MetalBackend *>(this->backend());

    if (metalBackend->isCommandEncoderSet()) {
        return NO_ERROR;
    }

    // Captures by value: the tensor vectors are copied into the closure.
    auto encodeReduction = [=]() {
        Tensor *source = inputs[0];
        Tensor *target = outputs[0];

        auto commandEncoder = metalBackend->encoder();
        [commandEncoder setComputePipelineState:mPipeline];
        [commandEncoder setBuffer:(__bridge id<MTLBuffer>)(void *)source->deviceId() offset:0 atIndex:0];
        [commandEncoder setBuffer:(__bridge id<MTLBuffer>)(void *)target->deviceId() offset:0 atIndex:1];
        // Geometry constants written by onResize.
        [commandEncoder setBuffer:mConst offset:0 atIndex:2];
        [commandEncoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];

        auto metalContext = (__bridge MNNMetalContext *)metalBackend->context();
        if (metalContext.isCommitEachShader) {
            // Flush and commit after this single dispatch when the context asks for it.
            metalBackend->flushEncoder();
            [metalContext commit_net];
        }
    };

    encodeReduction();
    metalBackend->addOpEncoder(encodeReduction);
    return NO_ERROR;
}
98
/**
 Creator that builds MetalReduction executions for OpType_Reduction ops.

 onCreate returns nullptr when this implementation cannot handle the op:
 - the op carries no ReductionParam, or the param has no axis entry
   (the MetalReduction constructor reads the first axis unconditionally);
 - there is no input tensor to take the element type from;
 - the reduction type is ALL, ANY, ASUM or SUMSQ.
 */
class MetalReductionCreator : public MetalBackend::Creator {
public:
    /**
     @param inputs  Input tensors of the op; inputs[0] provides the element type.
     @param op      The reduction op description.
     @param backend Backend the execution is created for.
     @return A new MetalReduction, or nullptr when the op is not handled here.
     */
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend) const {
        auto param = op->main_as_ReductionParam();
        // Reject malformed ops up front instead of dereferencing null / empty data later.
        if (nullptr == param || nullptr == param->dim() || param->dim()->size() == 0) {
            return nullptr;
        }
        if (inputs.empty() || nullptr == inputs[0]) {
            return nullptr;
        }

        switch (param->operation()) {
            case ReductionType_ALL:
            case ReductionType_ANY:
            case ReductionType_ASUM:
            case ReductionType_SUMSQ:
                // No Metal kernel is selected for these reduction types.
                return nullptr;
            default:
                break;
        }

        return new MetalReduction(backend, param, inputs[0]->getType());
    }
};
REGISTER_METAL_OP_CREATOR(MetalReductionCreator, OpType_Reduction);
117} // namespace MNN
118#endif /* MNN_METAL_ENABLED */
119