1 #include "GSH_VulkanTransferLocal.h"
2 #include "GSH_VulkanMemoryUtils.h"
3 #include "MemStream.h"
4 #include "vulkan/StructDefs.h"
5 #include "nuanceur/generators/SpirvShaderGenerator.h"
6
7 using namespace GSH_Vulkan;
8
9 #define DESCRIPTOR_LOCATION_MEMORY 0
10 #define DESCRIPTOR_LOCATION_SWIZZLETABLE_SRC 1
11 #define DESCRIPTOR_LOCATION_SWIZZLETABLE_DST 2
12
13 #define LOCAL_SIZE_X 32
14 #define LOCAL_SIZE_Y 32
15
CTransferLocal(const ContextPtr & context,const FrameCommandBufferPtr & frameCommandBuffer)16 CTransferLocal::CTransferLocal(const ContextPtr& context, const FrameCommandBufferPtr& frameCommandBuffer)
17 : m_context(context)
18 , m_frameCommandBuffer(frameCommandBuffer)
19 , m_pipelineCache(context->device)
20 {
21 m_pipelineCaps <<= 0;
22 }
23
SetPipelineCaps(const PIPELINE_CAPS & pipelineCaps)24 void CTransferLocal::SetPipelineCaps(const PIPELINE_CAPS& pipelineCaps)
25 {
26 m_pipelineCaps = pipelineCaps;
27 }
28
DoTransfer()29 void CTransferLocal::DoTransfer()
30 {
31 //Find pipeline and create it if we've never encountered it before
32 auto xferPipeline = m_pipelineCache.TryGetPipeline(m_pipelineCaps);
33 if(!xferPipeline)
34 {
35 xferPipeline = m_pipelineCache.RegisterPipeline(m_pipelineCaps, CreatePipeline(m_pipelineCaps));
36 }
37
38 auto descriptorSetCaps = make_convertible<DESCRIPTORSET_CAPS>(0);
39 descriptorSetCaps.srcPsm = m_pipelineCaps.srcFormat;
40 descriptorSetCaps.dstPsm = m_pipelineCaps.dstFormat;
41
42 auto descriptorSet = PrepareDescriptorSet(xferPipeline->descriptorSetLayout, descriptorSetCaps);
43 auto commandBuffer = m_frameCommandBuffer->GetCommandBuffer();
44
45 uint32 workUnitsX = (Params.rrw + LOCAL_SIZE_X - 1) / LOCAL_SIZE_X;
46 uint32 workUnitsY = (Params.rrh + LOCAL_SIZE_Y - 1) / LOCAL_SIZE_Y;
47
48 //Add a barrier to ensure reads are complete before writing to GS memory
49 if(false)
50 {
51 auto memoryBarrier = Framework::Vulkan::MemoryBarrier();
52 memoryBarrier.srcAccessMask = VK_ACCESS_SHADER_READ_BIT;
53 memoryBarrier.dstAccessMask = VK_ACCESS_SHADER_WRITE_BIT;
54
55 m_context->device.vkCmdPipelineBarrier(commandBuffer, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
56 0, 1, &memoryBarrier, 0, nullptr, 0, nullptr);
57 }
58
59 m_context->device.vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, xferPipeline->pipelineLayout, 0, 1, &descriptorSet, 0, nullptr);
60 m_context->device.vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, xferPipeline->pipeline);
61 m_context->device.vkCmdPushConstants(commandBuffer, xferPipeline->pipelineLayout, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(XFERPARAMS), &Params);
62 m_context->device.vkCmdDispatch(commandBuffer, workUnitsX, workUnitsY, 1);
63 }
64
PrepareDescriptorSet(VkDescriptorSetLayout descriptorSetLayout,const DESCRIPTORSET_CAPS & caps)65 VkDescriptorSet CTransferLocal::PrepareDescriptorSet(VkDescriptorSetLayout descriptorSetLayout, const DESCRIPTORSET_CAPS& caps)
66 {
67 auto descriptorSetIterator = m_descriptorSetCache.find(caps);
68 if(descriptorSetIterator != std::end(m_descriptorSetCache))
69 {
70 return descriptorSetIterator->second;
71 }
72
73 VkResult result = VK_SUCCESS;
74 VkDescriptorSet descriptorSet = VK_NULL_HANDLE;
75
76 //Allocate descriptor set
77 {
78 auto setAllocateInfo = Framework::Vulkan::DescriptorSetAllocateInfo();
79 setAllocateInfo.descriptorPool = m_context->descriptorPool;
80 setAllocateInfo.descriptorSetCount = 1;
81 setAllocateInfo.pSetLayouts = &descriptorSetLayout;
82
83 result = m_context->device.vkAllocateDescriptorSets(m_context->device, &setAllocateInfo, &descriptorSet);
84 CHECKVULKANERROR(result);
85 }
86
87 //Update descriptor set
88 {
89 std::vector<VkWriteDescriptorSet> writes;
90
91 VkDescriptorBufferInfo descriptorMemoryBufferInfo = {};
92 descriptorMemoryBufferInfo.buffer = m_context->memoryBuffer;
93 descriptorMemoryBufferInfo.range = VK_WHOLE_SIZE;
94
95 VkDescriptorImageInfo descriptorSrcSwizzleTableInfo = {};
96 descriptorSrcSwizzleTableInfo.imageView = m_context->GetSwizzleTable(caps.srcPsm);
97 descriptorSrcSwizzleTableInfo.imageLayout = VK_IMAGE_LAYOUT_GENERAL;
98
99 VkDescriptorImageInfo descriptorDstSwizzleTableInfo = {};
100 descriptorDstSwizzleTableInfo.imageView = m_context->GetSwizzleTable(caps.dstPsm);
101 descriptorDstSwizzleTableInfo.imageLayout = VK_IMAGE_LAYOUT_GENERAL;
102
103 //Memory Image Descriptor
104 {
105 auto writeSet = Framework::Vulkan::WriteDescriptorSet();
106 writeSet.dstSet = descriptorSet;
107 writeSet.dstBinding = DESCRIPTOR_LOCATION_MEMORY;
108 writeSet.descriptorCount = 1;
109 writeSet.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
110 writeSet.pBufferInfo = &descriptorMemoryBufferInfo;
111 writes.push_back(writeSet);
112 }
113
114 //Src Swizzle Table
115 {
116 auto writeSet = Framework::Vulkan::WriteDescriptorSet();
117 writeSet.dstSet = descriptorSet;
118 writeSet.dstBinding = DESCRIPTOR_LOCATION_SWIZZLETABLE_SRC;
119 writeSet.descriptorCount = 1;
120 writeSet.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_IMAGE;
121 writeSet.pImageInfo = &descriptorSrcSwizzleTableInfo;
122 writes.push_back(writeSet);
123 }
124
125 //Dst Swizzle Table
126 {
127 auto writeSet = Framework::Vulkan::WriteDescriptorSet();
128 writeSet.dstSet = descriptorSet;
129 writeSet.dstBinding = DESCRIPTOR_LOCATION_SWIZZLETABLE_DST;
130 writeSet.descriptorCount = 1;
131 writeSet.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_IMAGE;
132 writeSet.pImageInfo = &descriptorDstSwizzleTableInfo;
133 writes.push_back(writeSet);
134 }
135
136 m_context->device.vkUpdateDescriptorSets(m_context->device, writes.size(), writes.data(), 0, nullptr);
137 }
138
139 m_descriptorSetCache.insert(std::make_pair(caps, descriptorSet));
140
141 return descriptorSet;
142 }
143
//Generates the SPIR-V compute shader for a transfer with the given caps.
//Each invocation handles one pixel of the transfer rectangle. It reads the pixel from the source
//buffer using caps.srcFormat's addressing/read helper, then writes it to the destination buffer
//using caps.dstFormat's addressing/write helper. Both buffers live in the same GS memory storage buffer.
//Supported formats on either side: PSMCT32, PSMCT16, PSMT8 (others hit assert(false)).
//Returns a shader module created on m_context->device from the generated SPIR-V stream.
Framework::Vulkan::CShaderModule CTransferLocal::CreateShader(const PIPELINE_CAPS& caps)
{
	using namespace Nuanceur;

	auto b = CShaderBuilder();

	//Workgroup size; DoTransfer divides the transfer rectangle by these same constants
	b.SetMetadata(CShaderBuilder::METADATA_LOCALSIZE_X, LOCAL_SIZE_X);
	b.SetMetadata(CShaderBuilder::METADATA_LOCALSIZE_Y, LOCAL_SIZE_Y);

	{
		//Global invocation id: one invocation per pixel of the transfer rectangle
		auto inputInvocationId = CInt4Lvalue(b.CreateInputInt(Nuanceur::SEMANTIC_SYSTEM_GIID));
		//Resources at the bindings declared in CreatePipeline / filled in PrepareDescriptorSet
		auto memoryBuffer = CArrayUintValue(b.CreateUniformArrayUint("memoryBuffer", DESCRIPTOR_LOCATION_MEMORY));
		auto srcSwizzleTable = CImageUint2DValue(b.CreateImage2DUint(DESCRIPTOR_LOCATION_SWIZZLETABLE_SRC));
		auto dstSwizzleTable = CImageUint2DValue(b.CreateImage2DUint(DESCRIPTOR_LOCATION_SWIZZLETABLE_DST));

		//Push constants. NOTE(review): presumably laid out in declaration order and expected to match
		//the XFERPARAMS struct pushed whole by DoTransfer -- confirm, and do not reorder these declarations.
		auto bufferParams = CInt4Lvalue(b.CreateUniformInt4("bufferParams", Nuanceur::UNIFORM_UNIT_PUSHCONSTANT));
		auto offsetParams = CInt4Lvalue(b.CreateUniformInt4("offsetParams", Nuanceur::UNIFORM_UNIT_PUSHCONSTANT));
		auto sizeParams = CInt4Lvalue(b.CreateUniformInt4("sizeParams", Nuanceur::UNIFORM_UNIT_PUSHCONSTANT));

		//bufferParams = (srcBufAddress, srcBufWidth, dstBufAddress, dstBufWidth)
		auto srcBufAddress = bufferParams->x();
		auto srcBufWidth = bufferParams->y();
		auto dstBufAddress = bufferParams->z();
		auto dstBufWidth = bufferParams->w();

		//offsetParams = (srcX, srcY, dstX, dstY) origin of the rectangle in each buffer
		auto srcOffset = offsetParams->xy();
		auto dstOffset = offsetParams->zw();

		//sizeParams.xy = transfer rectangle width/height
		auto size = sizeParams->xy();

		//The dispatch is rounded up to whole workgroups, so discard invocations past the rectangle edges
		BeginIf(b, inputInvocationId->x() >= size->x());
		{
			Return(b);
		}
		EndIf(b);

		BeginIf(b, inputInvocationId->y() >= size->y());
		{
			Return(b);
		}
		EndIf(b);

		auto srcPos = inputInvocationId->xy() + srcOffset;
		auto dstPos = inputInvocationId->xy() + dstOffset;

		//Holds the raw value read from the source; it is handed to the destination write unchanged
		//(any narrowing for a smaller destination format is up to the Memory_Write* helper)
		auto pixel = CUintLvalue(b.CreateTemporaryUint());

		//Read the source pixel using the source format's swizzle/addressing
		switch(caps.srcFormat)
		{
		case CGSHandler::PSMCT32:
		{
			auto address = CMemoryUtils::GetPixelAddress<CGsPixelFormats::STORAGEPSMCT32>(
			    b, srcSwizzleTable, srcBufAddress, srcBufWidth, srcPos);
			pixel = CMemoryUtils::Memory_Read32(b, memoryBuffer, address);
		}
		break;
		case CGSHandler::PSMCT16:
		{
			auto address = CMemoryUtils::GetPixelAddress<CGsPixelFormats::STORAGEPSMCT16>(
			    b, srcSwizzleTable, srcBufAddress, srcBufWidth, srcPos);
			pixel = CMemoryUtils::Memory_Read16(b, memoryBuffer, address);
		}
		break;
		case CGSHandler::PSMT8:
		{
			auto address = CMemoryUtils::GetPixelAddress<CGsPixelFormats::STORAGEPSMT8>(
			    b, srcSwizzleTable, srcBufAddress, srcBufWidth, srcPos);
			pixel = CMemoryUtils::Memory_Read8(b, memoryBuffer, address);
		}
		break;
		default:
			//Unsupported source format; with asserts disabled, pixel is never assigned before the write below
			assert(false);
			break;
		}

		//Write the pixel using the destination format's swizzle/addressing
		switch(caps.dstFormat)
		{
		case CGSHandler::PSMCT32:
		{
			auto address = CMemoryUtils::GetPixelAddress<CGsPixelFormats::STORAGEPSMCT32>(
			    b, dstSwizzleTable, dstBufAddress, dstBufWidth, dstPos);
			CMemoryUtils::Memory_Write32(b, memoryBuffer, address, pixel);
		}
		break;
		case CGSHandler::PSMCT16:
		{
			auto address = CMemoryUtils::GetPixelAddress<CGsPixelFormats::STORAGEPSMCT16>(
			    b, dstSwizzleTable, dstBufAddress, dstBufWidth, dstPos);
			CMemoryUtils::Memory_Write16(b, memoryBuffer, address, pixel);
		}
		break;
		case CGSHandler::PSMT8:
		{
			auto address = CMemoryUtils::GetPixelAddress<CGsPixelFormats::STORAGEPSMT8>(
			    b, dstSwizzleTable, dstBufAddress, dstBufWidth, dstPos);
			CMemoryUtils::Memory_Write8(b, memoryBuffer, address, pixel);
		}
		break;
		default:
			//Unsupported destination format; with asserts disabled, no write is emitted
			assert(false);
			break;
		}
	}

	//Emit SPIR-V into a memory stream and rewind it so the shader module can read it from the start
	Framework::CMemStream shaderStream;
	Nuanceur::CSpirvShaderGenerator::Generate(shaderStream, b, Nuanceur::CSpirvShaderGenerator::SHADER_TYPE_COMPUTE);
	shaderStream.Seek(0, Framework::STREAM_SEEK_SET);
	return Framework::Vulkan::CShaderModule(m_context->device, shaderStream);
}
252
CreatePipeline(const PIPELINE_CAPS & caps)253 PIPELINE CTransferLocal::CreatePipeline(const PIPELINE_CAPS& caps)
254 {
255 PIPELINE xferPipeline;
256
257 auto xferShader = CreateShader(caps);
258
259 VkResult result = VK_SUCCESS;
260
261 {
262 std::vector<VkDescriptorSetLayoutBinding> bindings;
263
264 //GS memory
265 {
266 VkDescriptorSetLayoutBinding binding = {};
267 binding.binding = DESCRIPTOR_LOCATION_MEMORY;
268 binding.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
269 binding.descriptorCount = 1;
270 binding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
271 bindings.push_back(binding);
272 }
273
274 //Src Swizzle Table
275 {
276 VkDescriptorSetLayoutBinding binding = {};
277 binding.binding = DESCRIPTOR_LOCATION_SWIZZLETABLE_SRC;
278 binding.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_IMAGE;
279 binding.descriptorCount = 1;
280 binding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
281 bindings.push_back(binding);
282 }
283
284 //Dst Swizzle Table
285 {
286 VkDescriptorSetLayoutBinding binding = {};
287 binding.binding = DESCRIPTOR_LOCATION_SWIZZLETABLE_DST;
288 binding.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_IMAGE;
289 binding.descriptorCount = 1;
290 binding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
291 bindings.push_back(binding);
292 }
293
294 auto createInfo = Framework::Vulkan::DescriptorSetLayoutCreateInfo();
295 createInfo.bindingCount = bindings.size();
296 createInfo.pBindings = bindings.data();
297 result = m_context->device.vkCreateDescriptorSetLayout(m_context->device, &createInfo, nullptr, &xferPipeline.descriptorSetLayout);
298 CHECKVULKANERROR(result);
299 }
300
301 {
302 VkPushConstantRange pushConstantInfo = {};
303 pushConstantInfo.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
304 pushConstantInfo.offset = 0;
305 pushConstantInfo.size = sizeof(XFERPARAMS);
306
307 auto pipelineLayoutCreateInfo = Framework::Vulkan::PipelineLayoutCreateInfo();
308 pipelineLayoutCreateInfo.pushConstantRangeCount = 1;
309 pipelineLayoutCreateInfo.pPushConstantRanges = &pushConstantInfo;
310 pipelineLayoutCreateInfo.setLayoutCount = 1;
311 pipelineLayoutCreateInfo.pSetLayouts = &xferPipeline.descriptorSetLayout;
312
313 result = m_context->device.vkCreatePipelineLayout(m_context->device, &pipelineLayoutCreateInfo, nullptr, &xferPipeline.pipelineLayout);
314 CHECKVULKANERROR(result);
315 }
316
317 {
318 auto createInfo = Framework::Vulkan::ComputePipelineCreateInfo();
319 createInfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
320 createInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
321 createInfo.stage.pName = "main";
322 createInfo.stage.module = xferShader;
323 createInfo.layout = xferPipeline.pipelineLayout;
324
325 result = m_context->device.vkCreateComputePipelines(m_context->device, VK_NULL_HANDLE, 1, &createInfo, nullptr, &xferPipeline.pipeline);
326 CHECKVULKANERROR(result);
327 }
328
329 return xferPipeline;
330 }
331