1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4 //
5 // Copyright (C) 2018, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
7
8 #include "../../precomp.hpp"
9 #include "common.hpp"
10 #include "internal.hpp"
11 #include "../include/op_relu.hpp"
12
13 namespace cv { namespace dnn { namespace vkcom {
14
15 #ifdef HAVE_VULKAN
16
// Number of elements handled per work group along X when computing the
// dispatch size in computeGroupCount().
// NOTE(review): presumably must equal the local_size_x declared in the ReLU
// compute shader (relu_spv) — confirm before changing.
#define LOCAL_SZ_X 32

// Parameters passed to the ReLU compute shader on every dispatch; its size is
// used both when creating the pipeline and when recording the command buffer.
// NOTE(review): the field order/types are presumably mirrored by the shader's
// push-constant block — keep them in sync.
struct ReLUParam {
    int total;   // element count of the input tensor (filled from total_)
    float slope; // slope forwarded to the shader (filled from slope_); presumably applied to negative inputs
};
23
OpReLU(const float slope)24 OpReLU::OpReLU(const float slope) : slope_(slope)
25 {
26 OpBase::initVulkanThing(2);
27 type_ = "ReLU";
28 }
29
reshapeOutTensor(Tensor & in,Tensor & out)30 void OpReLU::reshapeOutTensor(Tensor& in, Tensor& out)
31 {
32 Shape shape = in.getShape();
33 out.reshape(NULL, shape);
34 }
35
forward(std::vector<Tensor> & ins,std::vector<Tensor> & blobs,std::vector<Tensor> & outs)36 bool OpReLU::forward(std::vector<Tensor>& ins,
37 std::vector<Tensor>& blobs,
38 std::vector<Tensor>& outs)
39 {
40 return forward(ins[0], outs[0]);
41 }
42
forward(Tensor & in,Tensor & out)43 bool OpReLU::forward(Tensor& in, Tensor& out)
44 {
45 if (pipeline_ == VK_NULL_HANDLE)
46 {
47 total_ = in.count();
48 #define maxComputeWorkGroupCount 65535
49 computeGroupCount();
50 createShaderModule(relu_spv, sizeof(relu_spv));
51 createPipeline(sizeof(ReLUParam));
52 }
53
54 bindTensor(device_, in, 0, descriptor_set_);
55 bindTensor(device_, out, 1, descriptor_set_);
56 ReLUParam param = { total_, slope_ };
57 recordCommandBuffer((void *)¶m, sizeof(ReLUParam));
58 runCommandBuffer();
59 return true;
60 }
61
computeGroupCount()62 bool OpReLU::computeGroupCount()
63 {
64 group_x_ = alignSize(total_, LOCAL_SZ_X) / LOCAL_SZ_X;
65 if (group_x_ > maxComputeWorkGroupCount)
66 group_x_ = maxComputeWorkGroupCount;
67 group_y_ = 1;
68 group_z_ = 1;
69 return true;
70 }
71
72 #endif // HAVE_VULKAN
73
74 }}} // namespace cv::dnn::vkcom
75