1 use std::{ffi, ptr, slice};
2
3 use spirv_cross::{hlsl, spirv, ErrorCode as SpirvErrorCode};
4
5 use winapi::shared::winerror;
6 use winapi::um::{d3dcommon, d3dcompiler};
7 use wio::com::ComPtr;
8
9 use auxil::spirv_cross_specialize_ast;
10 use hal::{device, pso};
11
12 use {conv, Backend, PipelineLayout};
13
14 /// Emit error during shader module creation. Used if we don't expect an error
15 /// but might panic due to an exception in SPIRV-Cross.
gen_unexpected_error(err: SpirvErrorCode) -> device::ShaderError16 fn gen_unexpected_error(err: SpirvErrorCode) -> device::ShaderError {
17 let msg = match err {
18 SpirvErrorCode::CompilationError(msg) => msg,
19 SpirvErrorCode::Unhandled => "Unexpected error".into(),
20 };
21 device::ShaderError::CompilationFailed(msg)
22 }
23
24 /// Emit error during shader module creation. Used if we execute an query command.
gen_query_error(err: SpirvErrorCode) -> device::ShaderError25 fn gen_query_error(err: SpirvErrorCode) -> device::ShaderError {
26 let msg = match err {
27 SpirvErrorCode::CompilationError(msg) => msg,
28 SpirvErrorCode::Unhandled => "Unknown query error".into(),
29 };
30 device::ShaderError::CompilationFailed(msg)
31 }
32
compile_spirv_entrypoint( raw_data: &[u32], stage: pso::Stage, source: &pso::EntryPoint<Backend>, layout: &PipelineLayout, ) -> Result<Option<ComPtr<d3dcommon::ID3DBlob>>, device::ShaderError>33 pub(crate) fn compile_spirv_entrypoint(
34 raw_data: &[u32],
35 stage: pso::Stage,
36 source: &pso::EntryPoint<Backend>,
37 layout: &PipelineLayout,
38 ) -> Result<Option<ComPtr<d3dcommon::ID3DBlob>>, device::ShaderError> {
39 let mut ast = parse_spirv(raw_data)?;
40 spirv_cross_specialize_ast(&mut ast, &source.specialization)?;
41
42 patch_spirv_resources(&mut ast, layout)?;
43 let shader_model = hlsl::ShaderModel::V5_0;
44 let shader_code = translate_spirv(&mut ast, shader_model, layout, stage)?;
45
46 let real_name = ast
47 .get_cleansed_entry_point_name(source.entry, conv::map_stage(stage))
48 .map_err(gen_query_error)?;
49
50 // TODO: opt: don't query *all* entry points.
51 let entry_points = ast.get_entry_points().map_err(gen_query_error)?;
52 entry_points
53 .iter()
54 .find(|entry_point| entry_point.name == real_name)
55 .ok_or(device::ShaderError::MissingEntryPoint(source.entry.into()))
56 .and_then(|entry_point| {
57 let stage = conv::map_execution_model(entry_point.execution_model);
58 let shader = compile_hlsl_shader(
59 stage,
60 shader_model,
61 &entry_point.name,
62 shader_code.as_bytes(),
63 )?;
64 Ok(Some(unsafe { ComPtr::from_raw(shader) }))
65 })
66 }
67
compile_hlsl_shader( stage: pso::Stage, shader_model: hlsl::ShaderModel, entry: &str, code: &[u8], ) -> Result<*mut d3dcommon::ID3DBlob, device::ShaderError>68 pub(crate) fn compile_hlsl_shader(
69 stage: pso::Stage,
70 shader_model: hlsl::ShaderModel,
71 entry: &str,
72 code: &[u8],
73 ) -> Result<*mut d3dcommon::ID3DBlob, device::ShaderError> {
74 let stage_to_str = |stage, shader_model| {
75 let stage = match stage {
76 pso::Stage::Vertex => "vs",
77 pso::Stage::Fragment => "ps",
78 pso::Stage::Compute => "cs",
79 _ => unimplemented!(),
80 };
81
82 let model = match shader_model {
83 hlsl::ShaderModel::V5_0 => "5_0",
84 // TODO: >= 11.3
85 hlsl::ShaderModel::V5_1 => "5_1",
86 // TODO: >= 12?, no mention of 11 on msdn
87 hlsl::ShaderModel::V6_0 => "6_0",
88 _ => unimplemented!(),
89 };
90
91 format!("{}_{}\0", stage, model)
92 };
93
94 let mut blob = ptr::null_mut();
95 let mut error = ptr::null_mut();
96 let entry = ffi::CString::new(entry).unwrap();
97 let hr = unsafe {
98 d3dcompiler::D3DCompile(
99 code.as_ptr() as *const _,
100 code.len(),
101 ptr::null(),
102 ptr::null(),
103 ptr::null_mut(),
104 entry.as_ptr() as *const _,
105 stage_to_str(stage, shader_model).as_ptr() as *const i8,
106 1,
107 0,
108 &mut blob as *mut *mut _,
109 &mut error as *mut *mut _,
110 )
111 };
112
113 if !winerror::SUCCEEDED(hr) {
114 let error = unsafe { ComPtr::<d3dcommon::ID3DBlob>::from_raw(error) };
115 let message = unsafe {
116 let pointer = error.GetBufferPointer();
117 let size = error.GetBufferSize();
118 let slice = slice::from_raw_parts(pointer as *const u8, size as usize);
119 String::from_utf8_lossy(slice).into_owned()
120 };
121
122 Err(device::ShaderError::CompilationFailed(message))
123 } else {
124 Ok(blob)
125 }
126 }
127
parse_spirv(raw_data: &[u32]) -> Result<spirv::Ast<hlsl::Target>, device::ShaderError>128 fn parse_spirv(raw_data: &[u32]) -> Result<spirv::Ast<hlsl::Target>, device::ShaderError> {
129 let module = spirv::Module::from_words(raw_data);
130
131 spirv::Ast::parse(&module).map_err(|err| {
132 let msg = match err {
133 SpirvErrorCode::CompilationError(msg) => msg,
134 SpirvErrorCode::Unhandled => "Unknown parsing error".into(),
135 };
136 device::ShaderError::CompilationFailed(msg)
137 })
138 }
139
patch_spirv_resources( ast: &mut spirv::Ast<hlsl::Target>, layout: &PipelineLayout, ) -> Result<(), device::ShaderError>140 fn patch_spirv_resources(
141 ast: &mut spirv::Ast<hlsl::Target>,
142 layout: &PipelineLayout,
143 ) -> Result<(), device::ShaderError> {
144 // we remap all `layout(binding = n, set = n)` to a flat space which we get from our
145 // `PipelineLayout` which knows of all descriptor set layouts
146
147 let shader_resources = ast.get_shader_resources().map_err(gen_query_error)?;
148 for image in &shader_resources.separate_images {
149 let set = ast
150 .get_decoration(image.id, spirv::Decoration::DescriptorSet)
151 .map_err(gen_query_error)? as usize;
152 let binding = ast
153 .get_decoration(image.id, spirv::Decoration::Binding)
154 .map_err(gen_query_error)?;
155 let mapping = layout.set_remapping[set]
156 .mapping
157 .iter()
158 .find(|&mapping| binding == mapping.spirv_binding)
159 .unwrap();
160
161 ast.set_decoration(
162 image.id,
163 spirv::Decoration::Binding,
164 mapping.hlsl_register as u32,
165 )
166 .map_err(gen_unexpected_error)?;
167 }
168
169 for uniform_buffer in &shader_resources.uniform_buffers {
170 let set = ast
171 .get_decoration(uniform_buffer.id, spirv::Decoration::DescriptorSet)
172 .map_err(gen_query_error)? as usize;
173 let binding = ast
174 .get_decoration(uniform_buffer.id, spirv::Decoration::Binding)
175 .map_err(gen_query_error)?;
176 let mapping = layout.set_remapping[set]
177 .mapping
178 .iter()
179 .find(|&mapping| binding == mapping.spirv_binding)
180 .unwrap();
181
182 ast.set_decoration(
183 uniform_buffer.id,
184 spirv::Decoration::Binding,
185 mapping.hlsl_register as u32,
186 )
187 .map_err(gen_unexpected_error)?;
188 }
189
190 for storage_buffer in &shader_resources.storage_buffers {
191 let set = ast
192 .get_decoration(storage_buffer.id, spirv::Decoration::DescriptorSet)
193 .map_err(gen_query_error)? as usize;
194 let binding = ast
195 .get_decoration(storage_buffer.id, spirv::Decoration::Binding)
196 .map_err(gen_query_error)?;
197 let mapping = layout.set_remapping[set]
198 .mapping
199 .iter()
200 .find(|&mapping| binding == mapping.spirv_binding)
201 .unwrap();
202
203 ast.set_decoration(
204 storage_buffer.id,
205 spirv::Decoration::Binding,
206 mapping.hlsl_register as u32,
207 )
208 .map_err(gen_unexpected_error)?;
209 }
210
211 for image in &shader_resources.storage_images {
212 let set = ast
213 .get_decoration(image.id, spirv::Decoration::DescriptorSet)
214 .map_err(gen_query_error)? as usize;
215 let binding = ast
216 .get_decoration(image.id, spirv::Decoration::Binding)
217 .map_err(gen_query_error)?;
218 let mapping = layout.set_remapping[set]
219 .mapping
220 .iter()
221 .find(|&mapping| binding == mapping.spirv_binding)
222 .unwrap();
223
224 ast.set_decoration(
225 image.id,
226 spirv::Decoration::Binding,
227 mapping.hlsl_register as u32,
228 )
229 .map_err(gen_unexpected_error)?;
230 }
231
232 for sampler in &shader_resources.separate_samplers {
233 let set = ast
234 .get_decoration(sampler.id, spirv::Decoration::DescriptorSet)
235 .map_err(gen_query_error)? as usize;
236 let binding = ast
237 .get_decoration(sampler.id, spirv::Decoration::Binding)
238 .map_err(gen_query_error)?;
239 let mapping = layout.set_remapping[set]
240 .mapping
241 .iter()
242 .find(|&mapping| binding == mapping.spirv_binding)
243 .unwrap();
244
245 ast.set_decoration(
246 sampler.id,
247 spirv::Decoration::Binding,
248 mapping.hlsl_register as u32,
249 )
250 .map_err(gen_unexpected_error)?;
251 }
252
253 for image in &shader_resources.sampled_images {
254 let set = ast
255 .get_decoration(image.id, spirv::Decoration::DescriptorSet)
256 .map_err(gen_query_error)? as usize;
257 let binding = ast
258 .get_decoration(image.id, spirv::Decoration::Binding)
259 .map_err(gen_query_error)?;
260 let mapping = layout.set_remapping[set]
261 .mapping
262 .iter()
263 .find(|&mapping| binding == mapping.spirv_binding)
264 .unwrap();
265
266 ast.set_decoration(
267 image.id,
268 spirv::Decoration::Binding,
269 mapping.hlsl_register as u32,
270 )
271 .map_err(gen_unexpected_error)?;
272 }
273
274 Ok(())
275 }
276
translate_spirv( ast: &mut spirv::Ast<hlsl::Target>, shader_model: hlsl::ShaderModel, _layout: &PipelineLayout, _stage: pso::Stage, ) -> Result<String, device::ShaderError>277 fn translate_spirv(
278 ast: &mut spirv::Ast<hlsl::Target>,
279 shader_model: hlsl::ShaderModel,
280 _layout: &PipelineLayout,
281 _stage: pso::Stage,
282 ) -> Result<String, device::ShaderError> {
283 let mut compile_options = hlsl::CompilerOptions::default();
284 compile_options.shader_model = shader_model;
285 compile_options.vertex.invert_y = true;
286
287 //let stage_flag = stage.into();
288
289 // TODO:
290 /*let root_constant_layout = layout
291 .root_constants
292 .iter()
293 .filter_map(|constant| if constant.stages.contains(stage_flag) {
294 Some(hlsl::RootConstant {
295 start: constant.range.start * 4,
296 end: constant.range.end * 4,
297 binding: constant.range.start,
298 space: 0,
299 })
300 } else {
301 None
302 })
303 .collect();*/
304 ast.set_compiler_options(&compile_options)
305 .map_err(gen_unexpected_error)?;
306 //ast.set_root_constant_layout(root_constant_layout)
307 // .map_err(gen_unexpected_error)?;
308 ast.compile().map_err(|err| {
309 let msg = match err {
310 SpirvErrorCode::CompilationError(msg) => msg,
311 SpirvErrorCode::Unhandled => "Unknown compile error".into(),
312 };
313 device::ShaderError::CompilationFailed(msg)
314 })
315 }
316