/* Copyright (c) 2017-2018 Hans-Kristian Arntzen
 *
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files (the
 * "Software"), to deal in the Software without restriction, including
 * without limitation the rights to use, copy, modify, merge, publish,
 * distribute, sublicense, and/or sell copies of the Software, and to
 * permit persons to whom the Software is furnished to do so, subject to
 * the following conditions:
 *
 * The above copyright notice and this permission notice shall be
 * included in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

#include "shader.hpp"
#include "device.hpp"
#include <spirv_cross.hpp>

#ifdef GRANITE_SPIRV_DUMP
#include "filesystem.hpp"
#endif

using namespace std;
using namespace spirv_cross;
using namespace Util;

namespace Vulkan
{
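// Build a VkPipelineLayout from the combined resource layout reflected from a program's shaders.
// Each descriptor set slot gets a DescriptorSetAllocator, which also provides the corresponding
// VkDescriptorSetLayout.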
PipelineLayout::PipelineLayout(Hash hash, Device *device, const CombinedResourceLayout &layout)
	: IntrusiveHashMapEnabled<PipelineLayout>(hash)
	, device(device)
	, layout(layout)
{
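	// Request an allocator (and thus a VkDescriptorSetLayout) for every set slot, even unused ones,
	// so layouts[] is fully populated. Only sets up to the highest one marked in descriptor_set_mask
	// are handed to vkCreatePipelineLayout.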
	VkDescriptorSetLayout layouts[VULKAN_NUM_DESCRIPTOR_SETS] = {};
	unsigned num_sets = 0;
	for (unsigned i = 0; i < VULKAN_NUM_DESCRIPTOR_SETS; i++)
	{
		set_allocators[i] = device->request_descriptor_set_allocator(layout.sets[i], layout.stages_for_bindings[i]);
		layouts[i] = set_allocators[i]->get_layout();
		if (layout.descriptor_set_mask & (1u << i))
			num_sets = i + 1;
	}

	VkPipelineLayoutCreateInfo info = { VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO };
	if (num_sets)
	{
		info.setLayoutCount = num_sets;
		info.pSetLayouts = layouts;
	}

	if (layout.push_constant_range.stageFlags != 0)
	{
		info.pushConstantRangeCount = 1;
		info.pPushConstantRanges = &layout.push_constant_range;
	}
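	// With Fossilize enabled, record the create info before creation and associate the resulting
	// handle afterwards so the pipeline layout can be captured for later replay.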
#ifdef GRANITE_VULKAN_FOSSILIZE
	unsigned layout_index = device->register_pipeline_layout(get_hash(), info);
#endif
	LOGI("Creating pipeline layout.\n");
	if (vkCreatePipelineLayout(device->get_device(), &info, nullptr, &pipe_layout) != VK_SUCCESS)
		LOGE("Failed to create pipeline layout.\n");
#ifdef GRANITE_VULKAN_FOSSILIZE
	device->set_pipeline_layout_handle(layout_index, pipe_layout);
#endif
}

PipelineLayout::~PipelineLayout()
{
	if (pipe_layout != VK_NULL_HANDLE)
		vkDestroyPipelineLayout(device->get_device(), pipe_layout, nullptr);
}

const char *Shader::stage_to_name(ShaderStage stage)
{
	switch (stage)
	{
	case ShaderStage::Compute:
		return "compute";
	case ShaderStage::Vertex:
		return "vertex";
	case ShaderStage::Fragment:
		return "fragment";
	case ShaderStage::Geometry:
		return "geometry";
	case ShaderStage::TessControl:
		return "tess_control";
	case ShaderStage::TessEvaluation:
		return "tess_evaluation";
	default:
		return "unknown";
	}
}

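// Map a resource's variable name to one of the stock samplers by substring match. Shaders opt in to
// immutable samplers simply by naming their sampler variables accordingly.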
static bool get_stock_sampler(StockSampler &sampler, const string &name)
{
	if (name.find("NearestClamp") != string::npos)
		sampler = StockSampler::NearestClamp;
	else if (name.find("LinearClamp") != string::npos)
		sampler = StockSampler::LinearClamp;
	else if (name.find("TrilinearClamp") != string::npos)
		sampler = StockSampler::TrilinearClamp;
	else if (name.find("NearestWrap") != string::npos)
		sampler = StockSampler::NearestWrap;
	else if (name.find("LinearWrap") != string::npos)
		sampler = StockSampler::LinearWrap;
	else if (name.find("TrilinearWrap") != string::npos)
		sampler = StockSampler::TrilinearWrap;
	else if (name.find("NearestShadow") != string::npos)
		sampler = StockSampler::NearestShadow;
	else if (name.find("LinearShadow") != string::npos)
		sampler = StockSampler::LinearShadow;
	else
		return false;

	return true;
}

Shader::Shader(Hash hash, Device *device, const uint32_t *data, size_t size)
	: IntrusiveHashMapEnabled<Shader>(hash)
	, device(device)
{
#ifdef GRANITE_SPIRV_DUMP
	if (!Granite::Filesystem::get().write_buffer_to_file(string("cache://spirv/") + to_string(hash) + ".spv", data, size))
		LOGE("Failed to dump shader to file.\n");
#endif

	VkShaderModuleCreateInfo info = { VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO };
	info.codeSize = size;
	info.pCode = data;

#ifdef GRANITE_VULKAN_FOSSILIZE
	unsigned module_index = device->register_shader_module(get_hash(), info);
#endif
	LOGI("Creating shader module.\n");
	if (vkCreateShaderModule(device->get_device(), &info, nullptr, &module) != VK_SUCCESS)
		LOGE("Failed to create shader module.\n");
#ifdef GRANITE_VULKAN_FOSSILIZE
	device->set_shader_module_handle(module_index, module);
#endif

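	// Reflect the SPIR-V with SPIRV-Cross and record, per descriptor set, which bindings each
	// resource class occupies, along with vertex input/output masks, push constant size and the
	// specialization constant mask.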
	Compiler compiler(data, size / sizeof(uint32_t));

	auto resources = compiler.get_shader_resources();
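	// Combined image samplers: DimBuffer maps to sampled (uniform texel) buffers, everything else to
	// sampled images. fp_mask records bindings whose sampled type is floating point. If the variable
	// name encodes a stock sampler, bake it into the set layout as an immutable sampler, and flag
	// conflicting requests for the same binding.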
	for (auto &image : resources.sampled_images)
	{
		auto set = compiler.get_decoration(image.id, spv::DecorationDescriptorSet);
		auto binding = compiler.get_decoration(image.id, spv::DecorationBinding);
		auto &type = compiler.get_type(image.base_type_id);
		if (type.image.dim == spv::DimBuffer)
			layout.sets[set].sampled_buffer_mask |= 1u << binding;
		else
			layout.sets[set].sampled_image_mask |= 1u << binding;

		if (compiler.get_type(type.image.type).basetype == SPIRType::BaseType::Float)
			layout.sets[set].fp_mask |= 1u << binding;

		const string &name = image.name;
		StockSampler sampler;
		if (type.image.dim != spv::DimBuffer && get_stock_sampler(sampler, name))
		{
			if (has_immutable_sampler(layout.sets[set], binding))
			{
				if (sampler != get_immutable_sampler(layout.sets[set], binding))
					LOGE("Immutable sampler mismatch detected!\n");
			}
			else
				set_immutable_sampler(layout.sets[set], binding, sampler);
		}
	}

	for (auto &image : resources.subpass_inputs)
	{
		auto set = compiler.get_decoration(image.id, spv::DecorationDescriptorSet);
		auto binding = compiler.get_decoration(image.id, spv::DecorationBinding);
		layout.sets[set].input_attachment_mask |= 1u << binding;

		auto &type = compiler.get_type(image.base_type_id);
		if (compiler.get_type(type.image.type).basetype == SPIRType::BaseType::Float)
			layout.sets[set].fp_mask |= 1u << binding;
	}

	for (auto &image : resources.separate_images)
	{
		auto set = compiler.get_decoration(image.id, spv::DecorationDescriptorSet);
		auto binding = compiler.get_decoration(image.id, spv::DecorationBinding);

		auto &type = compiler.get_type(image.base_type_id);
		if (compiler.get_type(type.image.type).basetype == SPIRType::BaseType::Float)
			layout.sets[set].fp_mask |= 1u << binding;

		if (type.image.dim == spv::DimBuffer)
			layout.sets[set].sampled_buffer_mask |= 1u << binding;
		else
			layout.sets[set].separate_image_mask |= 1u << binding;
	}

	for (auto &image : resources.separate_samplers)
	{
		auto set = compiler.get_decoration(image.id, spv::DecorationDescriptorSet);
		auto binding = compiler.get_decoration(image.id, spv::DecorationBinding);
		layout.sets[set].sampler_mask |= 1u << binding;

		const string &name = image.name;
		StockSampler sampler;
		if (get_stock_sampler(sampler, name))
		{
			if (has_immutable_sampler(layout.sets[set], binding))
			{
				if (sampler != get_immutable_sampler(layout.sets[set], binding))
					LOGE("Immutable sampler mismatch detected!\n");
			}
			else
				set_immutable_sampler(layout.sets[set], binding, sampler);
		}
	}

	for (auto &image : resources.storage_images)
	{
		auto set = compiler.get_decoration(image.id, spv::DecorationDescriptorSet);
		auto binding = compiler.get_decoration(image.id, spv::DecorationBinding);
		layout.sets[set].storage_image_mask |= 1u << binding;

		auto &type = compiler.get_type(image.base_type_id);
		if (compiler.get_type(type.image.type).basetype == SPIRType::BaseType::Float)
			layout.sets[set].fp_mask |= 1u << binding;
	}

	for (auto &buffer : resources.uniform_buffers)
	{
		auto set = compiler.get_decoration(buffer.id, spv::DecorationDescriptorSet);
		auto binding = compiler.get_decoration(buffer.id, spv::DecorationBinding);
		layout.sets[set].uniform_buffer_mask |= 1u << binding;
	}

	for (auto &buffer : resources.storage_buffers)
	{
		auto set = compiler.get_decoration(buffer.id, spv::DecorationDescriptorSet);
		auto binding = compiler.get_decoration(buffer.id, spv::DecorationBinding);
		layout.sets[set].storage_buffer_mask |= 1u << binding;
	}

	for (auto &attrib : resources.stage_inputs)
	{
		auto location = compiler.get_decoration(attrib.id, spv::DecorationLocation);
		layout.input_mask |= 1u << location;
	}

	for (auto &attrib : resources.stage_outputs)
	{
		auto location = compiler.get_decoration(attrib.id, spv::DecorationLocation);
		layout.output_mask |= 1u << location;
	}

	if (!resources.push_constant_buffers.empty())
	{
		// Don't bother trying to extract which part of a push constant block we're actually using.
		// Just assume we access all of it. Older validation layers did no static analysis to
		// determine which ranges were actually accessed, so we got a lot of false positives.
		layout.push_constant_size =
		    compiler.get_declared_struct_size(compiler.get_type(resources.push_constant_buffers.front().base_type_id));
	}

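	// constant_id indexes directly into spec_constant_mask, so only IDs below
	// VULKAN_NUM_SPEC_CONSTANTS can be tracked; anything above that is logged and ignored.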
	auto spec_constants = compiler.get_specialization_constants();
	for (auto &c : spec_constants)
	{
		if (c.constant_id >= VULKAN_NUM_SPEC_CONSTANTS)
		{
			LOGE("Spec constant ID: %u is out of range and will be ignored.\n", c.constant_id);
			continue;
		}

		layout.spec_constant_mask |= 1u << c.constant_id;
	}
}

Shader::~Shader()
{
	if (module)
		vkDestroyShaderModule(device->get_device(), module, nullptr);
}

void Program::set_shader(ShaderStage stage, Shader *handle)
{
	shaders[Util::ecast(stage)] = handle;
}

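// Programs are baked up front: Device::bake_program() resolves the combined resource and pipeline
// layout for the attached shaders, so pipelines can be created from the program later on.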
Program::Program(Device *device, Shader *vertex, Shader *fragment)
    : device(device)
{
	set_shader(ShaderStage::Vertex, vertex);
	set_shader(ShaderStage::Fragment, fragment);
	device->bake_program(*this);
}

Program::Program(Device *device, Shader *compute)
    : device(device)
{
	set_shader(ShaderStage::Compute, compute);
	device->bake_program(*this);
}

VkPipeline Program::get_pipeline(Hash hash) const
{
	auto *ret = pipelines.find(hash);
	return ret ? ret->get() : VK_NULL_HANDLE;
}

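// The container yields to an existing entry, so if another thread already registered a pipeline for
// this hash, that pipeline is returned and should be used in place of the one passed in.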
VkPipeline Program::add_pipeline(Hash hash, VkPipeline pipeline)
{
	return pipelines.emplace_yield(hash, pipeline)->get();
}

Program::~Program()
{
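	// Hand every baked pipeline back to the device for destruction. The nolock variant is used when
	// the device's internal lock is presumably already held (internal_sync).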
	for (auto &pipe : pipelines)
	{
		if (internal_sync)
			device->destroy_pipeline_nolock(pipe.get());
		else
			device->destroy_pipeline(pipe.get());
	}
}
}