1 //===- vulkan-runtime-wrappers.cpp - MLIR Vulkan runner wrapper library ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implements C runtime wrappers around the VulkanRuntime.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>
#include <mutex>
#include <numeric>
16 
17 #include "VulkanRuntime.h"
18 
// Marks the C entry points of this wrapper library as exported. On Windows
// (_WIN32) the symbol is tagged dllexport; elsewhere it is given default
// visibility so it stays exported even if the library is built with hidden
// default visibility.
#if defined(_WIN32)
#define VULKAN_WRAPPER_SYMBOL_EXPORT __declspec(dllexport)
#else
#define VULKAN_WRAPPER_SYMBOL_EXPORT __attribute__((visibility("default")))
#endif // defined(_WIN32)
26 
27 namespace {
28 
29 class VulkanRuntimeManager {
30 public:
31   VulkanRuntimeManager() = default;
32   VulkanRuntimeManager(const VulkanRuntimeManager &) = delete;
33   VulkanRuntimeManager operator=(const VulkanRuntimeManager &) = delete;
34   ~VulkanRuntimeManager() = default;
35 
setResourceData(DescriptorSetIndex setIndex,BindingIndex bindIndex,const VulkanHostMemoryBuffer & memBuffer)36   void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex,
37                        const VulkanHostMemoryBuffer &memBuffer) {
38     std::lock_guard<std::mutex> lock(mutex);
39     vulkanRuntime.setResourceData(setIndex, bindIndex, memBuffer);
40   }
41 
setEntryPoint(const char * entryPoint)42   void setEntryPoint(const char *entryPoint) {
43     std::lock_guard<std::mutex> lock(mutex);
44     vulkanRuntime.setEntryPoint(entryPoint);
45   }
46 
setNumWorkGroups(NumWorkGroups numWorkGroups)47   void setNumWorkGroups(NumWorkGroups numWorkGroups) {
48     std::lock_guard<std::mutex> lock(mutex);
49     vulkanRuntime.setNumWorkGroups(numWorkGroups);
50   }
51 
setShaderModule(uint8_t * shader,uint32_t size)52   void setShaderModule(uint8_t *shader, uint32_t size) {
53     std::lock_guard<std::mutex> lock(mutex);
54     vulkanRuntime.setShaderModule(shader, size);
55   }
56 
runOnVulkan()57   void runOnVulkan() {
58     std::lock_guard<std::mutex> lock(mutex);
59     if (failed(vulkanRuntime.initRuntime()) || failed(vulkanRuntime.run()) ||
60         failed(vulkanRuntime.updateHostMemoryBuffers()) ||
61         failed(vulkanRuntime.destroy())) {
62       std::cerr << "runOnVulkan failed";
63     }
64   }
65 
66 private:
67   VulkanRuntime vulkanRuntime;
68   std::mutex mutex;
69 };
70 
71 } // namespace
72 
/// Plain-data mirror of a ranked MLIR memref descriptor of rank `Rank` with
/// `ElemT` elements. The `_mlir_ciface_*` entry points below receive pointers
/// to it, so the member order and types must not change.
template <typename ElemT, int Rank>
struct MemRefDescriptor {
  ElemT *allocated;      // Pointer obtained from the allocator.
  ElemT *aligned;        // Aligned base pointer of the data.
  int64_t offset;        // Element offset of the first element from `aligned`.
  int64_t sizes[Rank];   // Extent of each dimension.
  int64_t strides[Rank]; // Element stride of each dimension.
};
81 
82 template <typename T, uint32_t S>
bindMemRef(void * vkRuntimeManager,DescriptorSetIndex setIndex,BindingIndex bindIndex,MemRefDescriptor<T,S> * ptr)83 void bindMemRef(void *vkRuntimeManager, DescriptorSetIndex setIndex,
84                 BindingIndex bindIndex, MemRefDescriptor<T, S> *ptr) {
85   uint32_t size = sizeof(T);
86   for (unsigned i = 0; i < S; i++)
87     size *= ptr->sizes[i];
88   VulkanHostMemoryBuffer memBuffer{ptr->allocated, size};
89   reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
90       ->setResourceData(setIndex, bindIndex, memBuffer);
91 }
92 
93 extern "C" {
94 /// Initializes `VulkanRuntimeManager` and returns a pointer to it.
initVulkan()95 VULKAN_WRAPPER_SYMBOL_EXPORT void *initVulkan() {
96   return new VulkanRuntimeManager();
97 }
98 
99 /// Deinitializes `VulkanRuntimeManager` by the given pointer.
deinitVulkan(void * vkRuntimeManager)100 VULKAN_WRAPPER_SYMBOL_EXPORT void deinitVulkan(void *vkRuntimeManager) {
101   delete reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager);
102 }
103 
runOnVulkan(void * vkRuntimeManager)104 VULKAN_WRAPPER_SYMBOL_EXPORT void runOnVulkan(void *vkRuntimeManager) {
105   reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)->runOnVulkan();
106 }
107 
setEntryPoint(void * vkRuntimeManager,const char * entryPoint)108 VULKAN_WRAPPER_SYMBOL_EXPORT void setEntryPoint(void *vkRuntimeManager,
109                                                 const char *entryPoint) {
110   reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
111       ->setEntryPoint(entryPoint);
112 }
113 
114 VULKAN_WRAPPER_SYMBOL_EXPORT void
setNumWorkGroups(void * vkRuntimeManager,uint32_t x,uint32_t y,uint32_t z)115 setNumWorkGroups(void *vkRuntimeManager, uint32_t x, uint32_t y, uint32_t z) {
116   reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
117       ->setNumWorkGroups({x, y, z});
118 }
119 
120 VULKAN_WRAPPER_SYMBOL_EXPORT void
setBinaryShader(void * vkRuntimeManager,uint8_t * shader,uint32_t size)121 setBinaryShader(void *vkRuntimeManager, uint8_t *shader, uint32_t size) {
122   reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
123       ->setShaderModule(shader, size);
124 }
125 
/// Defines the exported entry point `bindMemRef<rank>D<suffix>`, which binds a
/// rank-`rank` memref with `elemType` elements to the given descriptor set and
/// binding index by delegating to `bindMemRef<elemType, rank>`.
#define DECLARE_BIND_MEMREF(rank, elemType, suffix)                            \
  VULKAN_WRAPPER_SYMBOL_EXPORT void bindMemRef##rank##D##suffix(               \
      void *vkRuntimeManager, DescriptorSetIndex setIndex,                     \
      BindingIndex bindIndex, MemRefDescriptor<elemType, rank> *ptr) {         \
    bindMemRef<elemType, rank>(vkRuntimeManager, setIndex, bindIndex, ptr);    \
  }
134 
135 DECLARE_BIND_MEMREF(1, float, Float)
136 DECLARE_BIND_MEMREF(2, float, Float)
137 DECLARE_BIND_MEMREF(3, float, Float)
138 DECLARE_BIND_MEMREF(1, int32_t, Int32)
139 DECLARE_BIND_MEMREF(2, int32_t, Int32)
140 DECLARE_BIND_MEMREF(3, int32_t, Int32)
141 DECLARE_BIND_MEMREF(1, int16_t, Int16)
142 DECLARE_BIND_MEMREF(2, int16_t, Int16)
143 DECLARE_BIND_MEMREF(3, int16_t, Int16)
144 DECLARE_BIND_MEMREF(1, int8_t, Int8)
145 DECLARE_BIND_MEMREF(2, int8_t, Int8)
146 DECLARE_BIND_MEMREF(3, int8_t, Int8)
147 DECLARE_BIND_MEMREF(1, int16_t, Half)
148 DECLARE_BIND_MEMREF(2, int16_t, Half)
149 DECLARE_BIND_MEMREF(3, int16_t, Half)
150 
151 /// Fills the given 1D float memref with the given float value.
152 VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource1DFloat(MemRefDescriptor<float,1> * ptr,float value)153 _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
154                                  float value) {
155   std::fill_n(ptr->allocated, ptr->sizes[0], value);
156 }
157 
158 /// Fills the given 2D float memref with the given float value.
159 VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource2DFloat(MemRefDescriptor<float,2> * ptr,float value)160 _mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT
161                                  float value) {
162   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
163 }
164 
165 /// Fills the given 3D float memref with the given float value.
166 VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource3DFloat(MemRefDescriptor<float,3> * ptr,float value)167 _mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT
168                                  float value) {
169   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
170               value);
171 }
172 
173 /// Fills the given 1D int memref with the given int value.
174 VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t,1> * ptr,int32_t value)175 _mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT
176                                int32_t value) {
177   std::fill_n(ptr->allocated, ptr->sizes[0], value);
178 }
179 
180 /// Fills the given 2D int memref with the given int value.
181 VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t,2> * ptr,int32_t value)182 _mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT
183                                int32_t value) {
184   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
185 }
186 
187 /// Fills the given 3D int memref with the given int value.
188 VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t,3> * ptr,int32_t value)189 _mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT
190                                int32_t value) {
191   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
192               value);
193 }
194 
195 /// Fills the given 1D int memref with the given int8 value.
196 VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t,1> * ptr,int8_t value)197 _mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t, 1> *ptr, // NOLINT
198                                 int8_t value) {
199   std::fill_n(ptr->allocated, ptr->sizes[0], value);
200 }
201 
202 /// Fills the given 2D int memref with the given int8 value.
203 VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t,2> * ptr,int8_t value)204 _mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t, 2> *ptr, // NOLINT
205                                 int8_t value) {
206   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
207 }
208 
209 /// Fills the given 3D int memref with the given int8 value.
210 VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t,3> * ptr,int8_t value)211 _mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t, 3> *ptr, // NOLINT
212                                 int8_t value) {
213   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
214               value);
215 }
216 }
217