1 use std::{ffi, ptr, slice};
2 
3 use spirv_cross::{hlsl, spirv, ErrorCode as SpirvErrorCode};
4 
5 use winapi::shared::winerror;
6 use winapi::um::{d3dcommon, d3dcompiler};
7 use wio::com::ComPtr;
8 
9 use auxil::spirv_cross_specialize_ast;
10 use hal::{device, pso};
11 
12 use {conv, Backend, PipelineLayout};
13 
14 /// Emit error during shader module creation. Used if we don't expect an error
15 /// but might panic due to an exception in SPIRV-Cross.
gen_unexpected_error(err: SpirvErrorCode) -> device::ShaderError16 fn gen_unexpected_error(err: SpirvErrorCode) -> device::ShaderError {
17     let msg = match err {
18         SpirvErrorCode::CompilationError(msg) => msg,
19         SpirvErrorCode::Unhandled => "Unexpected error".into(),
20     };
21     device::ShaderError::CompilationFailed(msg)
22 }
23 
24 /// Emit error during shader module creation. Used if we execute an query command.
gen_query_error(err: SpirvErrorCode) -> device::ShaderError25 fn gen_query_error(err: SpirvErrorCode) -> device::ShaderError {
26     let msg = match err {
27         SpirvErrorCode::CompilationError(msg) => msg,
28         SpirvErrorCode::Unhandled => "Unknown query error".into(),
29     };
30     device::ShaderError::CompilationFailed(msg)
31 }
32 
compile_spirv_entrypoint( raw_data: &[u32], stage: pso::Stage, source: &pso::EntryPoint<Backend>, layout: &PipelineLayout, ) -> Result<Option<ComPtr<d3dcommon::ID3DBlob>>, device::ShaderError>33 pub(crate) fn compile_spirv_entrypoint(
34     raw_data: &[u32],
35     stage: pso::Stage,
36     source: &pso::EntryPoint<Backend>,
37     layout: &PipelineLayout,
38 ) -> Result<Option<ComPtr<d3dcommon::ID3DBlob>>, device::ShaderError> {
39     let mut ast = parse_spirv(raw_data)?;
40     spirv_cross_specialize_ast(&mut ast, &source.specialization)?;
41 
42     patch_spirv_resources(&mut ast, layout)?;
43     let shader_model = hlsl::ShaderModel::V5_0;
44     let shader_code = translate_spirv(&mut ast, shader_model, layout, stage)?;
45 
46     let real_name = ast
47         .get_cleansed_entry_point_name(source.entry, conv::map_stage(stage))
48         .map_err(gen_query_error)?;
49 
50     // TODO: opt: don't query *all* entry points.
51     let entry_points = ast.get_entry_points().map_err(gen_query_error)?;
52     entry_points
53         .iter()
54         .find(|entry_point| entry_point.name == real_name)
55         .ok_or(device::ShaderError::MissingEntryPoint(source.entry.into()))
56         .and_then(|entry_point| {
57             let stage = conv::map_execution_model(entry_point.execution_model);
58             let shader = compile_hlsl_shader(
59                 stage,
60                 shader_model,
61                 &entry_point.name,
62                 shader_code.as_bytes(),
63             )?;
64             Ok(Some(unsafe { ComPtr::from_raw(shader) }))
65         })
66 }
67 
compile_hlsl_shader( stage: pso::Stage, shader_model: hlsl::ShaderModel, entry: &str, code: &[u8], ) -> Result<*mut d3dcommon::ID3DBlob, device::ShaderError>68 pub(crate) fn compile_hlsl_shader(
69     stage: pso::Stage,
70     shader_model: hlsl::ShaderModel,
71     entry: &str,
72     code: &[u8],
73 ) -> Result<*mut d3dcommon::ID3DBlob, device::ShaderError> {
74     let stage_to_str = |stage, shader_model| {
75         let stage = match stage {
76             pso::Stage::Vertex => "vs",
77             pso::Stage::Fragment => "ps",
78             pso::Stage::Compute => "cs",
79             _ => unimplemented!(),
80         };
81 
82         let model = match shader_model {
83             hlsl::ShaderModel::V5_0 => "5_0",
84             // TODO: >= 11.3
85             hlsl::ShaderModel::V5_1 => "5_1",
86             // TODO: >= 12?, no mention of 11 on msdn
87             hlsl::ShaderModel::V6_0 => "6_0",
88             _ => unimplemented!(),
89         };
90 
91         format!("{}_{}\0", stage, model)
92     };
93 
94     let mut blob = ptr::null_mut();
95     let mut error = ptr::null_mut();
96     let entry = ffi::CString::new(entry).unwrap();
97     let hr = unsafe {
98         d3dcompiler::D3DCompile(
99             code.as_ptr() as *const _,
100             code.len(),
101             ptr::null(),
102             ptr::null(),
103             ptr::null_mut(),
104             entry.as_ptr() as *const _,
105             stage_to_str(stage, shader_model).as_ptr() as *const i8,
106             1,
107             0,
108             &mut blob as *mut *mut _,
109             &mut error as *mut *mut _,
110         )
111     };
112 
113     if !winerror::SUCCEEDED(hr) {
114         let error = unsafe { ComPtr::<d3dcommon::ID3DBlob>::from_raw(error) };
115         let message = unsafe {
116             let pointer = error.GetBufferPointer();
117             let size = error.GetBufferSize();
118             let slice = slice::from_raw_parts(pointer as *const u8, size as usize);
119             String::from_utf8_lossy(slice).into_owned()
120         };
121 
122         Err(device::ShaderError::CompilationFailed(message))
123     } else {
124         Ok(blob)
125     }
126 }
127 
parse_spirv(raw_data: &[u32]) -> Result<spirv::Ast<hlsl::Target>, device::ShaderError>128 fn parse_spirv(raw_data: &[u32]) -> Result<spirv::Ast<hlsl::Target>, device::ShaderError> {
129     let module = spirv::Module::from_words(raw_data);
130 
131     spirv::Ast::parse(&module).map_err(|err| {
132         let msg = match err {
133             SpirvErrorCode::CompilationError(msg) => msg,
134             SpirvErrorCode::Unhandled => "Unknown parsing error".into(),
135         };
136         device::ShaderError::CompilationFailed(msg)
137     })
138 }
139 
patch_spirv_resources( ast: &mut spirv::Ast<hlsl::Target>, layout: &PipelineLayout, ) -> Result<(), device::ShaderError>140 fn patch_spirv_resources(
141     ast: &mut spirv::Ast<hlsl::Target>,
142     layout: &PipelineLayout,
143 ) -> Result<(), device::ShaderError> {
144     // we remap all `layout(binding = n, set = n)` to a flat space which we get from our
145     // `PipelineLayout` which knows of all descriptor set layouts
146 
147     let shader_resources = ast.get_shader_resources().map_err(gen_query_error)?;
148     for image in &shader_resources.separate_images {
149         let set = ast
150             .get_decoration(image.id, spirv::Decoration::DescriptorSet)
151             .map_err(gen_query_error)? as usize;
152         let binding = ast
153             .get_decoration(image.id, spirv::Decoration::Binding)
154             .map_err(gen_query_error)?;
155         let mapping = layout.set_remapping[set]
156             .mapping
157             .iter()
158             .find(|&mapping| binding == mapping.spirv_binding)
159             .unwrap();
160 
161         ast.set_decoration(
162             image.id,
163             spirv::Decoration::Binding,
164             mapping.hlsl_register as u32,
165         )
166         .map_err(gen_unexpected_error)?;
167     }
168 
169     for uniform_buffer in &shader_resources.uniform_buffers {
170         let set = ast
171             .get_decoration(uniform_buffer.id, spirv::Decoration::DescriptorSet)
172             .map_err(gen_query_error)? as usize;
173         let binding = ast
174             .get_decoration(uniform_buffer.id, spirv::Decoration::Binding)
175             .map_err(gen_query_error)?;
176         let mapping = layout.set_remapping[set]
177             .mapping
178             .iter()
179             .find(|&mapping| binding == mapping.spirv_binding)
180             .unwrap();
181 
182         ast.set_decoration(
183             uniform_buffer.id,
184             spirv::Decoration::Binding,
185             mapping.hlsl_register as u32,
186         )
187         .map_err(gen_unexpected_error)?;
188     }
189 
190     for storage_buffer in &shader_resources.storage_buffers {
191         let set = ast
192             .get_decoration(storage_buffer.id, spirv::Decoration::DescriptorSet)
193             .map_err(gen_query_error)? as usize;
194         let binding = ast
195             .get_decoration(storage_buffer.id, spirv::Decoration::Binding)
196             .map_err(gen_query_error)?;
197         let mapping = layout.set_remapping[set]
198             .mapping
199             .iter()
200             .find(|&mapping| binding == mapping.spirv_binding)
201             .unwrap();
202 
203         ast.set_decoration(
204             storage_buffer.id,
205             spirv::Decoration::Binding,
206             mapping.hlsl_register as u32,
207         )
208         .map_err(gen_unexpected_error)?;
209     }
210 
211     for image in &shader_resources.storage_images {
212         let set = ast
213             .get_decoration(image.id, spirv::Decoration::DescriptorSet)
214             .map_err(gen_query_error)? as usize;
215         let binding = ast
216             .get_decoration(image.id, spirv::Decoration::Binding)
217             .map_err(gen_query_error)?;
218         let mapping = layout.set_remapping[set]
219             .mapping
220             .iter()
221             .find(|&mapping| binding == mapping.spirv_binding)
222             .unwrap();
223 
224         ast.set_decoration(
225             image.id,
226             spirv::Decoration::Binding,
227             mapping.hlsl_register as u32,
228         )
229         .map_err(gen_unexpected_error)?;
230     }
231 
232     for sampler in &shader_resources.separate_samplers {
233         let set = ast
234             .get_decoration(sampler.id, spirv::Decoration::DescriptorSet)
235             .map_err(gen_query_error)? as usize;
236         let binding = ast
237             .get_decoration(sampler.id, spirv::Decoration::Binding)
238             .map_err(gen_query_error)?;
239         let mapping = layout.set_remapping[set]
240             .mapping
241             .iter()
242             .find(|&mapping| binding == mapping.spirv_binding)
243             .unwrap();
244 
245         ast.set_decoration(
246             sampler.id,
247             spirv::Decoration::Binding,
248             mapping.hlsl_register as u32,
249         )
250         .map_err(gen_unexpected_error)?;
251     }
252 
253     for image in &shader_resources.sampled_images {
254         let set = ast
255             .get_decoration(image.id, spirv::Decoration::DescriptorSet)
256             .map_err(gen_query_error)? as usize;
257         let binding = ast
258             .get_decoration(image.id, spirv::Decoration::Binding)
259             .map_err(gen_query_error)?;
260         let mapping = layout.set_remapping[set]
261             .mapping
262             .iter()
263             .find(|&mapping| binding == mapping.spirv_binding)
264             .unwrap();
265 
266         ast.set_decoration(
267             image.id,
268             spirv::Decoration::Binding,
269             mapping.hlsl_register as u32,
270         )
271         .map_err(gen_unexpected_error)?;
272     }
273 
274     Ok(())
275 }
276 
translate_spirv( ast: &mut spirv::Ast<hlsl::Target>, shader_model: hlsl::ShaderModel, _layout: &PipelineLayout, _stage: pso::Stage, ) -> Result<String, device::ShaderError>277 fn translate_spirv(
278     ast: &mut spirv::Ast<hlsl::Target>,
279     shader_model: hlsl::ShaderModel,
280     _layout: &PipelineLayout,
281     _stage: pso::Stage,
282 ) -> Result<String, device::ShaderError> {
283     let mut compile_options = hlsl::CompilerOptions::default();
284     compile_options.shader_model = shader_model;
285     compile_options.vertex.invert_y = true;
286 
287     //let stage_flag = stage.into();
288 
289     // TODO:
290     /*let root_constant_layout = layout
291     .root_constants
292     .iter()
293     .filter_map(|constant| if constant.stages.contains(stage_flag) {
294         Some(hlsl::RootConstant {
295             start: constant.range.start * 4,
296             end: constant.range.end * 4,
297             binding: constant.range.start,
298             space: 0,
299         })
300     } else {
301         None
302     })
303     .collect();*/
304     ast.set_compiler_options(&compile_options)
305         .map_err(gen_unexpected_error)?;
306     //ast.set_root_constant_layout(root_constant_layout)
307     //    .map_err(gen_unexpected_error)?;
308     ast.compile().map_err(|err| {
309         let msg = match err {
310             SpirvErrorCode::CompilationError(msg) => msg,
311             SpirvErrorCode::Unhandled => "Unknown compile error".into(),
312         };
313         device::ShaderError::CompilationFailed(msg)
314     })
315 }
316