1 use crate::bindings as br;
2 use crate::{compiler, spirv, ErrorCode};
3 
4 use std::collections::BTreeMap;
5 use std::ffi::CStr;
6 use std::marker::PhantomData;
7 use std::ptr;
8 use std::u8;
9 
/// An MSL target.
///
/// This enum has no variants, so it can never be instantiated; it is used
/// purely as a type-level marker (for example in `spirv::Ast<Target>`).
#[derive(Debug, Clone)]
pub enum Target {}
13 
/// MSL-specific data stored next to the native compiler handle.
///
/// Each vector holds the raw, bindings-level form of one of the override maps
/// in `CompilerOptions`. They are refilled by `set_compiler_options` and handed
/// to the native compile call as pointer/length pairs.
pub struct TargetData {
    // Raw form of `CompilerOptions::vertex_attribute_overrides`.
    vertex_attribute_overrides: Vec<br::SPIRV_CROSS_NAMESPACE::MSLVertexAttr>,
    // Raw form of `CompilerOptions::resource_binding_overrides`.
    resource_binding_overrides: Vec<br::SPIRV_CROSS_NAMESPACE::MSLResourceBinding>,
    // Raw form of `CompilerOptions::const_samplers`.
    const_samplers: Vec<br::MslConstSamplerMapping>,
}
19 
// Marks `Target` as a compilation target and selects `TargetData` as the
// per-compiler data that accompanies it.
impl spirv::Target for Target {
    type Data = TargetData;
}
23 
/// Location of a vertex attribute to override.
///
/// Used as the key of `CompilerOptions::vertex_attribute_overrides`; the
/// wrapped value is forwarded as the `location` of the raw attribute.
#[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct VertexAttributeLocation(pub u32);
27 
/// Format of the vertex attribute.
#[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum Format {
    /// Converted to `MSL_VERTEX_FORMAT_OTHER`.
    Other,
    /// Converted to `MSL_VERTEX_FORMAT_UINT8`.
    Uint8,
    /// Converted to `MSL_VERTEX_FORMAT_UINT16`.
    Uint16,
}
35 
36 impl Format {
as_raw(&self) -> br::SPIRV_CROSS_NAMESPACE::MSLVertexFormat37     fn as_raw(&self) -> br::SPIRV_CROSS_NAMESPACE::MSLVertexFormat {
38         use self::Format::*;
39         use crate::bindings::root::SPIRV_CROSS_NAMESPACE::MSLVertexFormat as R;
40         match self {
41             Other => R::MSL_VERTEX_FORMAT_OTHER,
42             Uint8 => R::MSL_VERTEX_FORMAT_UINT8,
43             Uint16 => R::MSL_VERTEX_FORMAT_UINT16,
44         }
45     }
46 }
47 
/// Vertex attribute description for overriding.
///
/// Used as the value of `CompilerOptions::vertex_attribute_overrides`.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct VertexAttribute {
    /// Forwarded as `msl_buffer` of the raw attribute.
    pub buffer_id: u32,
    /// Forwarded as `msl_offset` of the raw attribute.
    pub offset: u32,
    /// Forwarded as `msl_stride` of the raw attribute.
    pub stride: u32,
    /// Step of the attribute; `Instance` sets the raw `per_instance` flag,
    /// `Vertex` clears it.
    pub step: spirv::VertexAttributeStep,
    /// Format of the attribute data.
    pub format: Format,
    /// Optional built-in, converted with `spirv::built_in_as_raw`.
    pub built_in: Option<spirv::BuiltIn>,
}
58 
/// Location of a resource binding to override.
///
/// Used as the key of `CompilerOptions::resource_binding_overrides`.
#[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct ResourceBindingLocation {
    /// Execution model (shader stage) the override applies to.
    pub stage: spirv::ExecutionModel,
    /// Forwarded as `desc_set` of the raw binding.
    pub desc_set: u32,
    /// Forwarded as `binding` of the raw binding.
    pub binding: u32,
}
66 
/// Resource binding description for overriding.
///
/// Used as the value of `CompilerOptions::resource_binding_overrides`.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct ResourceBinding {
    /// Forwarded as `msl_buffer` of the raw binding.
    pub buffer_id: u32,
    /// Forwarded as `msl_texture` of the raw binding.
    pub texture_id: u32,
    /// Forwarded as `msl_sampler` of the raw binding.
    pub sampler_id: u32,
}
74 
/// Location of a sampler binding to override.
///
/// Used as the key of `CompilerOptions::const_samplers`.
#[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct SamplerLocation {
    /// Forwarded as `desc_set` of the raw constant sampler mapping.
    pub desc_set: u32,
    /// Forwarded as `binding` of the raw constant sampler mapping.
    pub binding: u32,
}
81 
/// Coordinate mode of a constant sampler (`SamplerData::coord`).
///
/// NOTE: `set_compiler_options` transmutes this value into the raw sampler
/// field, so the `#[repr(C)]` layout and explicit discriminants presumably
/// have to match the native enum; verify against the bindings before
/// changing them.
#[repr(C)]
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum SamplerCoord {
    Normalized = 0,
    Pixel = 1,
}
88 
/// Filter used for `SamplerData::min_filter` and `SamplerData::mag_filter`.
///
/// NOTE: `set_compiler_options` transmutes this value into the raw sampler
/// field, so the `#[repr(C)]` layout and explicit discriminants presumably
/// have to match the native enum; verify against the bindings before
/// changing them.
#[repr(C)]
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum SamplerFilter {
    Nearest = 0,
    Linear = 1,
}
95 
/// Mip filter of a constant sampler (`SamplerData::mip_filter`).
///
/// NOTE: `set_compiler_options` transmutes this value into the raw sampler
/// field, so the `#[repr(C)]` layout and explicit discriminants presumably
/// have to match the native enum; verify against the bindings before
/// changing them.
#[repr(C)]
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum SamplerMipFilter {
    None = 0,
    Nearest = 1,
    Linear = 2,
}
103 
/// Addressing mode used for the `s`, `t` and `r` address fields of
/// `SamplerData`.
///
/// NOTE: `set_compiler_options` transmutes this value into the raw sampler
/// field, so the `#[repr(C)]` layout and explicit discriminants presumably
/// have to match the native enum; verify against the bindings before
/// changing them.
#[repr(C)]
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum SamplerAddress {
    ClampToZero = 0,
    ClampToEdge = 1,
    ClampToBorder = 2,
    Repeat = 3,
    MirroredRepeat = 4,
}
113 
/// Comparison function of a constant sampler (`SamplerData::compare_func`).
///
/// `set_compiler_options` sets the raw `compare_enable` flag for every value
/// except `Always`.
///
/// NOTE: `set_compiler_options` transmutes this value into the raw sampler
/// field, so the `#[repr(C)]` layout and explicit discriminants presumably
/// have to match the native enum; verify against the bindings before
/// changing them.
#[repr(C)]
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum SamplerCompareFunc {
    Never = 0,
    Less = 1,
    LessEqual = 2,
    Greater = 3,
    GreaterEqual = 4,
    Equal = 5,
    NotEqual = 6,
    Always = 7,
}
126 
/// Border color of a constant sampler (`SamplerData::border_color`).
///
/// NOTE: `set_compiler_options` transmutes this value into the raw sampler
/// field, so the `#[repr(C)]` layout and explicit discriminants presumably
/// have to match the native enum; verify against the bindings before
/// changing them.
#[repr(C)]
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum SamplerBorderColor {
    TransparentBlack = 0,
    OpaqueBlack = 1,
    OpaqueWhite = 2,
}
134 
/// A level-of-detail value stored as an unsigned fixed-point number with a
/// scale of 16.
///
/// The wrapped `u8` holds `lod * 16`, so the representable values run from
/// `0.0` to `255.0 / 16.0` in steps of `1.0 / 16.0`.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub struct LodBase16(u8);

impl LodBase16 {
    /// The smallest representable value (`0.0`).
    pub const ZERO: Self = LodBase16(0);
    /// The largest representable value (`255.0 / 16.0`).
    pub const MAX: Self = LodBase16(!0);
}

impl From<f32> for LodBase16 {
    /// Converts a floating-point LOD into its fixed-point form.
    ///
    /// The scaled value is clamped to the `u8` range before the cast, so
    /// negative inputs become [`LodBase16::ZERO`] and inputs that are too
    /// large become [`LodBase16::MAX`]. The fractional part of the scaled
    /// value is truncated. A NaN input becomes [`LodBase16::ZERO`], because
    /// `f32::max` returns its non-NaN argument.
    fn from(v: f32) -> Self {
        LodBase16((v * 16.0).max(0.0).min(u8::MAX as f32) as u8)
    }
}

// `From` is implemented rather than `Into`: the standard blanket impl then
// provides `Into<f32> for LodBase16`, so existing `.into()` call sites keep
// working and `f32::from(lod)` becomes available as well.
impl From<LodBase16> for f32 {
    /// Converts the fixed-point value back to a floating-point LOD.
    fn from(lod: LodBase16) -> Self {
        f32::from(lod.0) / 16.0
    }
}
153 
/// Data fully defining a constant sampler.
///
/// Used as the value of `CompilerOptions::const_samplers`. The raw sampler's
/// enable flags are not stored here; `set_compiler_options` derives them from
/// the field values, as noted on the fields below.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct SamplerData {
    pub coord: SamplerCoord,
    pub min_filter: SamplerFilter,
    pub mag_filter: SamplerFilter,
    pub mip_filter: SamplerMipFilter,
    pub s_address: SamplerAddress,
    pub t_address: SamplerAddress,
    pub r_address: SamplerAddress,
    /// Any value other than `SamplerCompareFunc::Always` sets `compare_enable`.
    pub compare_func: SamplerCompareFunc,
    pub border_color: SamplerBorderColor,
    /// `lod_clamp_enable` is set unless this is `LodBase16::ZERO` and
    /// `lod_clamp_max` is `LodBase16::MAX`.
    pub lod_clamp_min: LodBase16,
    /// See `lod_clamp_min`.
    pub lod_clamp_max: LodBase16,
    /// Any non-zero value sets `anisotropy_enable`.
    pub max_anisotropy: i32,
}
170 
/// An MSL shader platform.
///
/// The discriminant is cast directly into the `platform` field of the raw
/// compiler options, so the numeric values are part of the contract with the
/// native side.
#[repr(u8)]
// The variant names follow Apple's platform spelling rather than Rust's
// naming conventions, hence the lint overrides.
#[allow(non_snake_case, non_camel_case_types)]
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub enum Platform {
    iOS = 0,
    macOS = 1,
}
179 
/// An MSL shader model version.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub enum Version {
    V1_0,
    V1_1,
    V1_2,
    V2_0,
    V2_1,
    V2_2,
}

impl Version {
    /// Encodes the version as `major * 10000 + minor * 100`, the numeric form
    /// stored in the raw compiler options.
    fn as_raw(self) -> u32 {
        let (major, minor): (u32, u32) = match self {
            Version::V1_0 => (1, 0),
            Version::V1_1 => (1, 1),
            Version::V1_2 => (1, 2),
            Version::V2_0 => (2, 0),
            Version::V2_1 => (2, 1),
            Version::V2_2 => (2, 2),
        };
        major * 10_000 + minor * 100
    }
}
204 
/// Vertex-specific MSL compiler options.
///
/// `Default` is derived and leaves both flags `false`, exactly as the former
/// hand-written implementation did.
#[derive(Debug, Clone, Default, Hash, Eq, PartialEq)]
pub struct CompilerVertexOptions {
    /// Forwarded as `vertex_invert_y` in the raw compiler options.
    pub invert_y: bool,
    /// Forwarded as `vertex_transform_clip_space` in the raw compiler options.
    pub transform_clip_space: bool,
}
219 
/// MSL compiler options.
///
/// The scalar fields are converted into the raw options struct by `as_raw`.
/// The three override maps are not part of that struct; they are converted
/// separately in `set_compiler_options` and handed to the native compile call.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct CompilerOptions {
    /// The target platform.
    pub platform: Platform,
    /// The target MSL version.
    pub version: Version,
    /// Vertex compiler options.
    pub vertex: CompilerVertexOptions,
    /// The buffer index to use for swizzle.
    pub swizzle_buffer_index: u32,
    /// The buffer index to use for indirect params.
    pub indirect_params_buffer_index: u32,
    /// The buffer index to use for output.
    pub output_buffer_index: u32,
    /// The buffer index to use for patch output.
    pub patch_output_buffer_index: u32,
    /// The buffer index to use for tessellation factor.
    pub tessellation_factor_buffer_index: u32,
    /// The buffer index to use for buffer size.
    pub buffer_size_buffer_index: u32,
    /// Whether the built-in point size should be enabled.
    pub enable_point_size_builtin: bool,
    /// Whether rasterization should be enabled.
    ///
    /// The raw options carry the inverse flag (`disable_rasterization`).
    pub enable_rasterization: bool,
    /// Whether to capture output to buffer.
    pub capture_output_to_buffer: bool,
    /// Whether to swizzle texture samples.
    pub swizzle_texture_samples: bool,
    /// Whether to place the origin of tessellation domain shaders in the lower left.
    pub tessellation_domain_origin_lower_left: bool,
    /// Whether to enable use of argument buffers (only compatible with MSL 2.0).
    pub enable_argument_buffers: bool,
    /// Whether to pad fragment output to have at least the number of components as the render pass.
    pub pad_fragment_output_components: bool,
    /// MSL resource bindings overrides.
    pub resource_binding_overrides: BTreeMap<ResourceBindingLocation, ResourceBinding>,
    /// MSL vertex attribute overrides.
    pub vertex_attribute_overrides: BTreeMap<VertexAttributeLocation, VertexAttribute>,
    /// MSL const sampler mappings.
    pub const_samplers: BTreeMap<SamplerLocation, SamplerData>,
}
262 
263 impl CompilerOptions {
as_raw(&self) -> br::ScMslCompilerOptions264     fn as_raw(&self) -> br::ScMslCompilerOptions {
265         br::ScMslCompilerOptions {
266             vertex_invert_y: self.vertex.invert_y,
267             vertex_transform_clip_space: self.vertex.transform_clip_space,
268             platform: self.platform as _,
269             version: self.version.as_raw(),
270             enable_point_size_builtin: self.enable_point_size_builtin,
271             disable_rasterization: !self.enable_rasterization,
272             swizzle_buffer_index: self.swizzle_buffer_index,
273             indirect_params_buffer_index: self.indirect_params_buffer_index,
274             shader_output_buffer_index: self.output_buffer_index,
275             shader_patch_output_buffer_index: self.patch_output_buffer_index,
276             shader_tess_factor_buffer_index: self.tessellation_factor_buffer_index,
277             buffer_size_buffer_index: self.buffer_size_buffer_index,
278             capture_output_to_buffer: self.capture_output_to_buffer,
279             swizzle_texture_samples: self.swizzle_texture_samples,
280             tess_domain_origin_lower_left: self.tessellation_domain_origin_lower_left,
281             argument_buffers: self.enable_argument_buffers,
282             pad_fragment_output_components: self.pad_fragment_output_components,
283         }
284     }
285 }
286 
287 impl Default for CompilerOptions {
default() -> Self288     fn default() -> Self {
289         CompilerOptions {
290             platform: Platform::macOS,
291             version: Version::V1_2,
292             vertex: CompilerVertexOptions::default(),
293             swizzle_buffer_index: 30,
294             indirect_params_buffer_index: 29,
295             output_buffer_index: 28,
296             patch_output_buffer_index: 27,
297             tessellation_factor_buffer_index: 26,
298             buffer_size_buffer_index: 25,
299             enable_point_size_builtin: true,
300             enable_rasterization: true,
301             capture_output_to_buffer: false,
302             swizzle_texture_samples: false,
303             tessellation_domain_origin_lower_left: false,
304             enable_argument_buffers: false,
305             pad_fragment_output_components: false,
306             resource_binding_overrides: Default::default(),
307             vertex_attribute_overrides: Default::default(),
308             const_samplers: Default::default(),
309         }
310     }
311 }
312 
313 impl<'a> spirv::Parse<Target> for spirv::Ast<Target> {
parse(module: &spirv::Module) -> Result<Self, ErrorCode>314     fn parse(module: &spirv::Module) -> Result<Self, ErrorCode> {
315         let mut sc_compiler = ptr::null_mut();
316         unsafe {
317             check!(br::sc_internal_compiler_msl_new(
318                 &mut sc_compiler,
319                 module.words.as_ptr(),
320                 module.words.len(),
321             ));
322         }
323 
324         Ok(spirv::Ast {
325             compiler: compiler::Compiler {
326                 sc_compiler,
327                 target_data: TargetData {
328                     resource_binding_overrides: Vec::new(),
329                     vertex_attribute_overrides: Vec::new(),
330                     const_samplers: Vec::new(),
331                 },
332                 has_been_compiled: false,
333             },
334             target_type: PhantomData,
335         })
336     }
337 }
338 
339 impl spirv::Compile<Target> for spirv::Ast<Target> {
340     type CompilerOptions = CompilerOptions;
341 
342     /// Set MSL compiler specific compilation settings.
set_compiler_options(&mut self, options: &CompilerOptions) -> Result<(), ErrorCode>343     fn set_compiler_options(&mut self, options: &CompilerOptions) -> Result<(), ErrorCode> {
344         let raw_options = options.as_raw();
345         unsafe {
346             check!(br::sc_internal_compiler_msl_set_options(
347                 self.compiler.sc_compiler,
348                 &raw_options,
349             ));
350         }
351 
352         self.compiler.target_data.resource_binding_overrides.clear();
353         self.compiler.target_data.resource_binding_overrides.extend(
354             options.resource_binding_overrides.iter().map(|(loc, res)| {
355                 br::SPIRV_CROSS_NAMESPACE::MSLResourceBinding {
356                     stage: loc.stage.as_raw(),
357                     desc_set: loc.desc_set,
358                     binding: loc.binding,
359                     msl_buffer: res.buffer_id,
360                     msl_texture: res.texture_id,
361                     msl_sampler: res.sampler_id,
362                 }
363             }),
364         );
365 
366         self.compiler.target_data.vertex_attribute_overrides.clear();
367         self.compiler.target_data.vertex_attribute_overrides.extend(
368             options.vertex_attribute_overrides.iter().map(|(loc, vat)| {
369                 br::SPIRV_CROSS_NAMESPACE::MSLVertexAttr {
370                     location: loc.0,
371                     msl_buffer: vat.buffer_id,
372                     msl_offset: vat.offset,
373                     msl_stride: vat.stride,
374                     per_instance: match vat.step {
375                         spirv::VertexAttributeStep::Vertex => false,
376                         spirv::VertexAttributeStep::Instance => true,
377                     },
378                     format: vat.format.as_raw(),
379                     builtin: spirv::built_in_as_raw(vat.built_in),
380                 }
381             }),
382         );
383 
384         self.compiler.target_data.const_samplers.clear();
385         self.compiler.target_data.const_samplers.extend(
386             options.const_samplers.iter().map(|(loc, data)| unsafe {
387                 use std::mem::transmute;
388                 br::MslConstSamplerMapping {
389                     desc_set: loc.desc_set,
390                     binding: loc.binding,
391                     sampler: br::SPIRV_CROSS_NAMESPACE::MSLConstexprSampler {
392                         coord: transmute(data.coord),
393                         min_filter: transmute(data.min_filter),
394                         mag_filter: transmute(data.mag_filter),
395                         mip_filter: transmute(data.mip_filter),
396                         s_address: transmute(data.s_address),
397                         t_address: transmute(data.t_address),
398                         r_address: transmute(data.r_address),
399                         compare_func: transmute(data.compare_func),
400                         border_color: transmute(data.border_color),
401                         lod_clamp_min: data.lod_clamp_min.into(),
402                         lod_clamp_max: data.lod_clamp_max.into(),
403                         max_anisotropy: data.max_anisotropy,
404                         compare_enable: data.compare_func != SamplerCompareFunc::Always,
405                         lod_clamp_enable: data.lod_clamp_min != LodBase16::ZERO ||
406                             data.lod_clamp_max != LodBase16::MAX,
407                         anisotropy_enable: data.max_anisotropy != 0,
408                     },
409                 }
410             }),
411         );
412 
413         Ok(())
414     }
415 
416     /// Generate MSL shader from the AST.
compile(&mut self) -> Result<String, ErrorCode>417     fn compile(&mut self) -> Result<String, ErrorCode> {
418         self.compile_internal()
419     }
420 }
421 
422 impl spirv::Ast<Target> {
compile_internal(&self) -> Result<String, ErrorCode>423     fn compile_internal(&self) -> Result<String, ErrorCode> {
424         let vat_overrides = &self.compiler.target_data.vertex_attribute_overrides;
425         let res_overrides = &self.compiler.target_data.resource_binding_overrides;
426         let const_samplers = &self.compiler.target_data.const_samplers;
427         unsafe {
428             let mut shader_ptr = ptr::null();
429             check!(br::sc_internal_compiler_msl_compile(
430                 self.compiler.sc_compiler,
431                 &mut shader_ptr,
432                 vat_overrides.as_ptr(),
433                 vat_overrides.len(),
434                 res_overrides.as_ptr(),
435                 res_overrides.len(),
436                 const_samplers.as_ptr(),
437                 const_samplers.len(),
438             ));
439             let shader = match CStr::from_ptr(shader_ptr).to_str() {
440                 Ok(v) => v.to_owned(),
441                 Err(_) => return Err(ErrorCode::Unhandled),
442             };
443             check!(br::sc_internal_free_pointer(
444                 shader_ptr as *mut std::os::raw::c_void
445             ));
446             Ok(shader)
447         }
448     }
449 
is_rasterization_enabled(&self) -> Result<bool, ErrorCode>450     pub fn is_rasterization_enabled(&self) -> Result<bool, ErrorCode> {
451         unsafe {
452             let mut is_disabled = false;
453             check!(br::sc_internal_compiler_msl_get_is_rasterization_disabled(
454                 self.compiler.sc_compiler,
455                 &mut is_disabled
456             ));
457             Ok(!is_disabled)
458         }
459     }
460 }
461 
// TODO: Generate with bindgen
/// Binding number reserved for argument buffers: the bitwise complement of
/// `3`, i.e. `0xFFFF_FFFC`.
///
/// NOTE(review): presumably mirrors a constant on the native side; confirm
/// once this is generated by bindgen.
pub const ARGUMENT_BUFFER_BINDING: u32 = !3u32;
464