1 // This file is part of OpenCV project. 2 // It is subject to the license terms in the LICENSE file found in the top-level directory 3 // of this distribution and at http://opencv.org/license.html. 4 // 5 // Copyright (C) 2018, Intel Corporation, all rights reserved. 6 // Third party copyrights are property of their respective owners. 7 8 #ifndef OPENCV_DNN_OP_VKCOM_HPP 9 #define OPENCV_DNN_OP_VKCOM_HPP 10 11 #include <opencv2/dnn/shape_utils.hpp> 12 #ifdef HAVE_VULKAN 13 #include "vkcom/include/vkcom.hpp" 14 #endif // HAVE_VULKAN 15 16 namespace cv 17 { 18 namespace dnn 19 { 20 #ifdef HAVE_VULKAN 21 std::vector<vkcom::Tensor> VkComTensors(const std::vector<Ptr<BackendWrapper> >& ptrs); 22 23 vkcom::Tensor VkComTensor(const Ptr<BackendWrapper>& ptr); 24 25 // Data copied from/to Mat to/from Tensor. Change the shape of dst if 26 // needed to make it the same shape as src 27 void copyToTensor(vkcom::Tensor &dst, const Mat &src); 28 29 void copyToMat(Mat &dst, const vkcom::Tensor &src); 30 31 class VkComBackendNode : public BackendNode 32 { 33 public: 34 VkComBackendNode(const std::vector<Ptr<BackendWrapper> >& inputsWrapper, 35 const std::shared_ptr<vkcom::OpBase> &op, 36 const std::vector<Ptr<BackendWrapper> >& blobsWrapper = 37 std::vector<Ptr<BackendWrapper> >()); 38 39 bool forward(std::vector<vkcom::Tensor>& outs); 40 41 private: 42 std::vector<vkcom::Tensor> ins; 43 std::vector<vkcom::Tensor> blobs; 44 std::vector<Ptr<BackendWrapper> > inputsWrapper_; 45 std::shared_ptr<vkcom::OpBase> operation; 46 }; 47 48 class VkComBackendWrapper : public BackendWrapper 49 { 50 public: 51 VkComBackendWrapper(Mat& m); 52 VkComBackendWrapper(const Ptr<BackendWrapper>& baseBuffer, Mat& m); 53 54 virtual void copyToHost() CV_OVERRIDE; 55 virtual void setHostDirty() CV_OVERRIDE; 56 void setDeviceDirty(); 57 void copyToDevice(); 58 vkcom::Tensor getTensor(); 59 60 private: 61 vkcom::Tensor tensor; 62 Mat* host; 63 bool hostDirty; 64 bool deviceDirty; 65 }; 66 #endif // HAVE_VULKAN 67 68 void forwardVkCom(std::vector<Ptr<BackendWrapper> > &outputs, const Ptr<BackendNode>& node); 69 70 bool haveVulkan(); 71 } // namespace dnn 72 } // namespace cv 73 74 #endif // OPENCV_DNN_OP_VKCOM_HPP 75