1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4 //
5 // Copyright (C) 2018, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
7 
8 #include "../../precomp.hpp"
9 #include "common.hpp"
10 #include "internal.hpp"
11 #include "../include/op_relu.hpp"
12 
13 namespace cv { namespace dnn { namespace vkcom {
14 
15 #ifdef HAVE_VULKAN
16 
17 #define LOCAL_SZ_X 32
18 
// Parameter block for the ReLU shader.
// An instance is filled in OpReLU::forward() and passed, together with its
// size, to recordCommandBuffer(); the same size is given to createPipeline().
// NOTE(review): field order and types presumably have to mirror the matching
// declaration in the shader source - confirm before changing them.
struct ReLUParam
{
    int   total;  // value of Tensor::count() for the input tensor
    float slope;  // slope value supplied to the OpReLU constructor
};
23 
OpReLU(const float slope)24 OpReLU::OpReLU(const float slope) : slope_(slope)
25 {
26     OpBase::initVulkanThing(2);
27     type_ = "ReLU";
28 }
29 
reshapeOutTensor(Tensor & in,Tensor & out)30 void OpReLU::reshapeOutTensor(Tensor& in, Tensor& out)
31 {
32     Shape shape = in.getShape();
33     out.reshape(NULL, shape);
34 }
35 
forward(std::vector<Tensor> & ins,std::vector<Tensor> & blobs,std::vector<Tensor> & outs)36 bool OpReLU::forward(std::vector<Tensor>& ins,
37                      std::vector<Tensor>& blobs,
38                      std::vector<Tensor>& outs)
39 {
40     return forward(ins[0], outs[0]);
41 }
42 
forward(Tensor & in,Tensor & out)43 bool OpReLU::forward(Tensor& in, Tensor& out)
44 {
45     if (pipeline_ == VK_NULL_HANDLE)
46     {
47         total_ = in.count();
48 #define maxComputeWorkGroupCount 65535
49         computeGroupCount();
50         createShaderModule(relu_spv, sizeof(relu_spv));
51         createPipeline(sizeof(ReLUParam));
52     }
53 
54     bindTensor(device_, in,  0, descriptor_set_);
55     bindTensor(device_, out, 1, descriptor_set_);
56     ReLUParam param = { total_, slope_ };
57     recordCommandBuffer((void *)&param, sizeof(ReLUParam));
58     runCommandBuffer();
59     return true;
60 }
61 
computeGroupCount()62 bool OpReLU::computeGroupCount()
63 {
64     group_x_ = alignSize(total_, LOCAL_SZ_X) / LOCAL_SZ_X;
65     if (group_x_ > maxComputeWorkGroupCount)
66         group_x_ = maxComputeWorkGroupCount;
67     group_y_ = 1;
68     group_z_ = 1;
69     return true;
70 }
71 
72 #endif // HAVE_VULKAN
73 
74 }}} // namespace cv::dnn::vkcom
75