1 // 2 // VulkanRelu.hpp 3 // MNN 4 // 5 // Created by MNN on 2019/01/31. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #ifndef VulkanRelu_hpp 10 #define VulkanRelu_hpp 11 12 #include <stdio.h> 13 14 #include "VulkanBasicExecution.hpp" 15 16 namespace MNN { 17 18 class VulkanRelu : public VulkanBasicExecution { 19 public: 20 VulkanRelu(Backend* bn, const Op* op); 21 virtual ~VulkanRelu(); 22 ErrorCode onEncode(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, 23 const VulkanCommandPool::Buffer* cmdBuffer) override; 24 25 private: 26 float mSlope[4]; 27 std::shared_ptr<VulkanBuffer> mGpuReluParam; 28 const VulkanPipeline* mReluPipeline; 29 std::shared_ptr<VulkanPipeline::DescriptorSet> mDescriptorSet; 30 }; 31 32 class VulkanPrelu : public VulkanBasicExecution { 33 public: 34 VulkanPrelu(Backend* bn, const Op* op); 35 virtual ~VulkanPrelu(); 36 ErrorCode onEncode(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, 37 const VulkanCommandPool::Buffer* cmdBuffer) override; 38 39 private: 40 std::shared_ptr<VulkanBuffer> mGpuPreluParam; 41 std::shared_ptr<VulkanImage> mSlope; 42 const VulkanPipeline* mPreluPipeline; 43 std::shared_ptr<VulkanPipeline::DescriptorSet> mDescriptorSet; 44 }; 45 46 } // namespace MNN 47 48 #endif /* VulkanRelu_hpp */ 49