1 //
2 //  VulkanRelu.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/01/31.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef VulkanRelu_hpp
10 #define VulkanRelu_hpp
11 
12 #include <stdio.h>
13 
14 #include "VulkanBasicExecution.hpp"
15 
16 namespace MNN {
17 
18 class VulkanRelu : public VulkanBasicExecution {
19 public:
20     VulkanRelu(Backend* bn, const Op* op);
21     virtual ~VulkanRelu();
22     ErrorCode onEncode(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
23                        const VulkanCommandPool::Buffer* cmdBuffer) override;
24 
25 private:
26     float mSlope[4];
27     std::shared_ptr<VulkanBuffer> mGpuReluParam;
28     const VulkanPipeline* mReluPipeline;
29     std::shared_ptr<VulkanPipeline::DescriptorSet> mDescriptorSet;
30 };
31 
32 class VulkanPrelu : public VulkanBasicExecution {
33 public:
34     VulkanPrelu(Backend* bn, const Op* op);
35     virtual ~VulkanPrelu();
36     ErrorCode onEncode(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
37                        const VulkanCommandPool::Buffer* cmdBuffer) override;
38 
39 private:
40     std::shared_ptr<VulkanBuffer> mGpuPreluParam;
41     std::shared_ptr<VulkanImage> mSlope;
42     const VulkanPipeline* mPreluPipeline;
43     std::shared_ptr<VulkanPipeline::DescriptorSet> mDescriptorSet;
44 };
45 
46 } // namespace MNN
47 
48 #endif /* VulkanRelu_hpp */
49