1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4 //
5 // Copyright (C) 2018, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
7 
8 #ifndef OPENCV_DNN_OP_VKCOM_HPP
9 #define OPENCV_DNN_OP_VKCOM_HPP
10 
11 #include <opencv2/dnn/shape_utils.hpp>
12 #ifdef HAVE_VULKAN
13 #include "vkcom/include/vkcom.hpp"
14 #endif  // HAVE_VULKAN
15 
16 namespace cv
17 {
18 namespace dnn
19 {
20 #ifdef HAVE_VULKAN
21     std::vector<vkcom::Tensor> VkComTensors(const std::vector<Ptr<BackendWrapper> >& ptrs);
22 
23     vkcom::Tensor VkComTensor(const Ptr<BackendWrapper>& ptr);
24 
25     // Data copied from/to Mat to/from Tensor. Change the shape of dst if
26     // needed to make it the same shape as src
27     void copyToTensor(vkcom::Tensor &dst, const Mat &src);
28 
29     void copyToMat(Mat &dst, const vkcom::Tensor &src);
30 
31     class VkComBackendNode : public BackendNode
32     {
33     public:
34         VkComBackendNode(const std::vector<Ptr<BackendWrapper> >& inputsWrapper,
35                          const std::shared_ptr<vkcom::OpBase> &op,
36                          const std::vector<Ptr<BackendWrapper> >& blobsWrapper =
37                          std::vector<Ptr<BackendWrapper> >());
38 
39         bool forward(std::vector<vkcom::Tensor>& outs);
40 
41     private:
42         std::vector<vkcom::Tensor> ins;
43         std::vector<vkcom::Tensor> blobs;
44         std::vector<Ptr<BackendWrapper> > inputsWrapper_;
45         std::shared_ptr<vkcom::OpBase> operation;
46     };
47 
48     class VkComBackendWrapper : public BackendWrapper
49     {
50     public:
51         VkComBackendWrapper(Mat& m);
52         VkComBackendWrapper(const Ptr<BackendWrapper>& baseBuffer, Mat& m);
53 
54         virtual void copyToHost() CV_OVERRIDE;
55         virtual void setHostDirty() CV_OVERRIDE;
56         void setDeviceDirty();
57         void copyToDevice();
58         vkcom::Tensor getTensor();
59 
60     private:
61         vkcom::Tensor tensor;
62         Mat* host;
63         bool hostDirty;
64         bool deviceDirty;
65     };
66 #endif // HAVE_VULKAN
67 
68     void forwardVkCom(std::vector<Ptr<BackendWrapper> > &outputs, const Ptr<BackendNode>& node);
69 
70     bool haveVulkan();
71 }  // namespace dnn
72 }  // namespace cv
73 
74 #endif  // OPENCV_DNN_OP_VKCOM_HPP
75