1 use crate::bindings as br; 2 use crate::{compiler, spirv, ErrorCode}; 3 4 use std::collections::BTreeMap; 5 use std::ffi::CStr; 6 use std::marker::PhantomData; 7 use std::ptr; 8 use std::u8; 9 10 /// A MSL target. 11 #[derive(Debug, Clone)] 12 pub enum Target {} 13 14 pub struct TargetData { 15 vertex_attribute_overrides: Vec<br::SPIRV_CROSS_NAMESPACE::MSLVertexAttr>, 16 resource_binding_overrides: Vec<br::SPIRV_CROSS_NAMESPACE::MSLResourceBinding>, 17 const_samplers: Vec<br::MslConstSamplerMapping>, 18 } 19 20 impl spirv::Target for Target { 21 type Data = TargetData; 22 } 23 24 /// Location of a vertex attribute to override 25 #[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] 26 pub struct VertexAttributeLocation(pub u32); 27 28 /// Format of the vertex attribute 29 #[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] 30 pub enum Format { 31 Other, 32 Uint8, 33 Uint16, 34 } 35 36 impl Format { as_raw(&self) -> br::SPIRV_CROSS_NAMESPACE::MSLVertexFormat37 fn as_raw(&self) -> br::SPIRV_CROSS_NAMESPACE::MSLVertexFormat { 38 use self::Format::*; 39 use crate::bindings::root::SPIRV_CROSS_NAMESPACE::MSLVertexFormat as R; 40 match self { 41 Other => R::MSL_VERTEX_FORMAT_OTHER, 42 Uint8 => R::MSL_VERTEX_FORMAT_UINT8, 43 Uint16 => R::MSL_VERTEX_FORMAT_UINT16, 44 } 45 } 46 } 47 48 /// Vertex attribute description for overriding 49 #[derive(Debug, Clone, Hash, Eq, PartialEq)] 50 pub struct VertexAttribute { 51 pub buffer_id: u32, 52 pub offset: u32, 53 pub stride: u32, 54 pub step: spirv::VertexAttributeStep, 55 pub format: Format, 56 pub built_in: Option<spirv::BuiltIn>, 57 } 58 59 /// Location of a resource binding to override 60 #[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] 61 pub struct ResourceBindingLocation { 62 pub stage: spirv::ExecutionModel, 63 pub desc_set: u32, 64 pub binding: u32, 65 } 66 67 /// Resource binding description for overriding 68 #[derive(Debug, Clone, Hash, Eq, PartialEq)] 69 pub struct ResourceBinding { 70 pub buffer_id: u32, 71 pub texture_id: u32, 72 pub sampler_id: u32, 73 } 74 75 /// Location of a sampler binding to override 76 #[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] 77 pub struct SamplerLocation { 78 pub desc_set: u32, 79 pub binding: u32, 80 } 81 82 #[repr(C)] 83 #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] 84 pub enum SamplerCoord { 85 Normalized = 0, 86 Pixel = 1, 87 } 88 89 #[repr(C)] 90 #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] 91 pub enum SamplerFilter { 92 Nearest = 0, 93 Linear = 1, 94 } 95 96 #[repr(C)] 97 #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] 98 pub enum SamplerMipFilter { 99 None = 0, 100 Nearest = 1, 101 Linear = 2, 102 } 103 104 #[repr(C)] 105 #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] 106 pub enum SamplerAddress { 107 ClampToZero = 0, 108 ClampToEdge = 1, 109 ClampToBorder = 2, 110 Repeat = 3, 111 MirroredRepeat = 4, 112 } 113 114 #[repr(C)] 115 #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] 116 pub enum SamplerCompareFunc { 117 Never = 0, 118 Less = 1, 119 LessEqual = 2, 120 Greater = 3, 121 GreaterEqual = 4, 122 Equal = 5, 123 NotEqual = 6, 124 Always = 7, 125 } 126 127 #[repr(C)] 128 #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] 129 pub enum SamplerBorderColor { 130 TransparentBlack = 0, 131 OpaqueBlack = 1, 132 OpaqueWhite = 2, 133 } 134 135 #[repr(transparent)] 136 #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] 137 pub struct LodBase16(u8); 138 139 impl LodBase16 { 140 pub const ZERO: Self = LodBase16(0); 141 pub const MAX: Self = LodBase16(!0); 142 } 143 impl From<f32> for LodBase16 { from(v: f32) -> Self144 fn from(v: f32) -> Self { 145 LodBase16((v * 16.0).max(0.0).min(u8::MAX as f32) as u8) 146 } 147 } 148 impl Into<f32> for LodBase16 { into(self) -> f32149 fn into(self) -> f32 { 150 self.0 as f32 / 16.0 151 } 152 } 153 154 /// Data fully defining a constant sampler. 155 #[derive(Debug, Clone, Hash, Eq, PartialEq)] 156 pub struct SamplerData { 157 pub coord: SamplerCoord, 158 pub min_filter: SamplerFilter, 159 pub mag_filter: SamplerFilter, 160 pub mip_filter: SamplerMipFilter, 161 pub s_address: SamplerAddress, 162 pub t_address: SamplerAddress, 163 pub r_address: SamplerAddress, 164 pub compare_func: SamplerCompareFunc, 165 pub border_color: SamplerBorderColor, 166 pub lod_clamp_min: LodBase16, 167 pub lod_clamp_max: LodBase16, 168 pub max_anisotropy: i32, 169 } 170 171 /// A MSL shader platform. 172 #[repr(u8)] 173 #[allow(non_snake_case, non_camel_case_types)] 174 #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] 175 pub enum Platform { 176 iOS = 0, 177 macOS = 1, 178 } 179 180 /// A MSL shader model version. 181 #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] 182 pub enum Version { 183 V1_0, 184 V1_1, 185 V1_2, 186 V2_0, 187 V2_1, 188 V2_2, 189 } 190 191 impl Version { as_raw(self) -> u32192 fn as_raw(self) -> u32 { 193 use self::Version::*; 194 match self { 195 V1_0 => 10000, 196 V1_1 => 10100, 197 V1_2 => 10200, 198 V2_0 => 20000, 199 V2_1 => 20100, 200 V2_2 => 20200, 201 } 202 } 203 } 204 205 #[derive(Debug, Clone, Hash, Eq, PartialEq)] 206 pub struct CompilerVertexOptions { 207 pub invert_y: bool, 208 pub transform_clip_space: bool, 209 } 210 211 impl Default for CompilerVertexOptions { default() -> Self212 fn default() -> Self { 213 CompilerVertexOptions { 214 invert_y: false, 215 transform_clip_space: false, 216 } 217 } 218 } 219 220 /// MSL compiler options. 221 #[derive(Debug, Clone, Hash, Eq, PartialEq)] 222 pub struct CompilerOptions { 223 /// The target platform. 224 pub platform: Platform, 225 /// The target MSL version. 226 pub version: Version, 227 /// Vertex compiler options. 228 pub vertex: CompilerVertexOptions, 229 /// The buffer index to use for swizzle. 230 pub swizzle_buffer_index: u32, 231 // The buffer index to use for indirect params. 232 pub indirect_params_buffer_index: u32, 233 /// The buffer index to use for output. 234 pub output_buffer_index: u32, 235 /// The buffer index to use for patch output. 236 pub patch_output_buffer_index: u32, 237 /// The buffer index to use for tessellation factor. 238 pub tessellation_factor_buffer_index: u32, 239 /// The buffer index to use for buffer size. 240 pub buffer_size_buffer_index: u32, 241 /// Whether the built-in point size should be enabled. 242 pub enable_point_size_builtin: bool, 243 /// Whether rasterization should be enabled. 244 pub enable_rasterization: bool, 245 /// Whether to capture output to buffer. 246 pub capture_output_to_buffer: bool, 247 /// Whether to swizzle texture samples. 248 pub swizzle_texture_samples: bool, 249 /// Whether to place the origin of tessellation domain shaders in the lower left. 250 pub tessellation_domain_origin_lower_left: bool, 251 /// Whether to enable use of argument buffers (only compatible with MSL 2.0). 252 pub enable_argument_buffers: bool, 253 /// Whether to pad fragment output to have at least the number of components as the render pass. 254 pub pad_fragment_output_components: bool, 255 /// MSL resource bindings overrides. 256 pub resource_binding_overrides: BTreeMap<ResourceBindingLocation, ResourceBinding>, 257 /// MSL vertex attribute overrides. 258 pub vertex_attribute_overrides: BTreeMap<VertexAttributeLocation, VertexAttribute>, 259 /// MSL const sampler mappings. 260 pub const_samplers: BTreeMap<SamplerLocation, SamplerData>, 261 } 262 263 impl CompilerOptions { as_raw(&self) -> br::ScMslCompilerOptions264 fn as_raw(&self) -> br::ScMslCompilerOptions { 265 br::ScMslCompilerOptions { 266 vertex_invert_y: self.vertex.invert_y, 267 vertex_transform_clip_space: self.vertex.transform_clip_space, 268 platform: self.platform as _, 269 version: self.version.as_raw(), 270 enable_point_size_builtin: self.enable_point_size_builtin, 271 disable_rasterization: !self.enable_rasterization, 272 swizzle_buffer_index: self.swizzle_buffer_index, 273 indirect_params_buffer_index: self.indirect_params_buffer_index, 274 shader_output_buffer_index: self.output_buffer_index, 275 shader_patch_output_buffer_index: self.patch_output_buffer_index, 276 shader_tess_factor_buffer_index: self.tessellation_factor_buffer_index, 277 buffer_size_buffer_index: self.buffer_size_buffer_index, 278 capture_output_to_buffer: self.capture_output_to_buffer, 279 swizzle_texture_samples: self.swizzle_texture_samples, 280 tess_domain_origin_lower_left: self.tessellation_domain_origin_lower_left, 281 argument_buffers: self.enable_argument_buffers, 282 pad_fragment_output_components: self.pad_fragment_output_components, 283 } 284 } 285 } 286 287 impl Default for CompilerOptions { default() -> Self288 fn default() -> Self { 289 CompilerOptions { 290 platform: Platform::macOS, 291 version: Version::V1_2, 292 vertex: CompilerVertexOptions::default(), 293 swizzle_buffer_index: 30, 294 indirect_params_buffer_index: 29, 295 output_buffer_index: 28, 296 patch_output_buffer_index: 27, 297 tessellation_factor_buffer_index: 26, 298 buffer_size_buffer_index: 25, 299 enable_point_size_builtin: true, 300 enable_rasterization: true, 301 capture_output_to_buffer: false, 302 swizzle_texture_samples: false, 303 tessellation_domain_origin_lower_left: false, 304 enable_argument_buffers: false, 305 pad_fragment_output_components: false, 306 resource_binding_overrides: Default::default(), 307 vertex_attribute_overrides: Default::default(), 308 const_samplers: Default::default(), 309 } 310 } 311 } 312 313 impl<'a> spirv::Parse<Target> for spirv::Ast<Target> { parse(module: &spirv::Module) -> Result<Self, ErrorCode>314 fn parse(module: &spirv::Module) -> Result<Self, ErrorCode> { 315 let mut sc_compiler = ptr::null_mut(); 316 unsafe { 317 check!(br::sc_internal_compiler_msl_new( 318 &mut sc_compiler, 319 module.words.as_ptr(), 320 module.words.len(), 321 )); 322 } 323 324 Ok(spirv::Ast { 325 compiler: compiler::Compiler { 326 sc_compiler, 327 target_data: TargetData { 328 resource_binding_overrides: Vec::new(), 329 vertex_attribute_overrides: Vec::new(), 330 const_samplers: Vec::new(), 331 }, 332 has_been_compiled: false, 333 }, 334 target_type: PhantomData, 335 }) 336 } 337 } 338 339 impl spirv::Compile<Target> for spirv::Ast<Target> { 340 type CompilerOptions = CompilerOptions; 341 342 /// Set MSL compiler specific compilation settings. set_compiler_options(&mut self, options: &CompilerOptions) -> Result<(), ErrorCode>343 fn set_compiler_options(&mut self, options: &CompilerOptions) -> Result<(), ErrorCode> { 344 let raw_options = options.as_raw(); 345 unsafe { 346 check!(br::sc_internal_compiler_msl_set_options( 347 self.compiler.sc_compiler, 348 &raw_options, 349 )); 350 } 351 352 self.compiler.target_data.resource_binding_overrides.clear(); 353 self.compiler.target_data.resource_binding_overrides.extend( 354 options.resource_binding_overrides.iter().map(|(loc, res)| { 355 br::SPIRV_CROSS_NAMESPACE::MSLResourceBinding { 356 stage: loc.stage.as_raw(), 357 desc_set: loc.desc_set, 358 binding: loc.binding, 359 msl_buffer: res.buffer_id, 360 msl_texture: res.texture_id, 361 msl_sampler: res.sampler_id, 362 } 363 }), 364 ); 365 366 self.compiler.target_data.vertex_attribute_overrides.clear(); 367 self.compiler.target_data.vertex_attribute_overrides.extend( 368 options.vertex_attribute_overrides.iter().map(|(loc, vat)| { 369 br::SPIRV_CROSS_NAMESPACE::MSLVertexAttr { 370 location: loc.0, 371 msl_buffer: vat.buffer_id, 372 msl_offset: vat.offset, 373 msl_stride: vat.stride, 374 per_instance: match vat.step { 375 spirv::VertexAttributeStep::Vertex => false, 376 spirv::VertexAttributeStep::Instance => true, 377 }, 378 format: vat.format.as_raw(), 379 builtin: spirv::built_in_as_raw(vat.built_in), 380 } 381 }), 382 ); 383 384 self.compiler.target_data.const_samplers.clear(); 385 self.compiler.target_data.const_samplers.extend( 386 options.const_samplers.iter().map(|(loc, data)| unsafe { 387 use std::mem::transmute; 388 br::MslConstSamplerMapping { 389 desc_set: loc.desc_set, 390 binding: loc.binding, 391 sampler: br::SPIRV_CROSS_NAMESPACE::MSLConstexprSampler { 392 coord: transmute(data.coord), 393 min_filter: transmute(data.min_filter), 394 mag_filter: transmute(data.mag_filter), 395 mip_filter: transmute(data.mip_filter), 396 s_address: transmute(data.s_address), 397 t_address: transmute(data.t_address), 398 r_address: transmute(data.r_address), 399 compare_func: transmute(data.compare_func), 400 border_color: transmute(data.border_color), 401 lod_clamp_min: data.lod_clamp_min.into(), 402 lod_clamp_max: data.lod_clamp_max.into(), 403 max_anisotropy: data.max_anisotropy, 404 compare_enable: data.compare_func != SamplerCompareFunc::Always, 405 lod_clamp_enable: data.lod_clamp_min != LodBase16::ZERO || 406 data.lod_clamp_max != LodBase16::MAX, 407 anisotropy_enable: data.max_anisotropy != 0, 408 }, 409 } 410 }), 411 ); 412 413 Ok(()) 414 } 415 416 /// Generate MSL shader from the AST. compile(&mut self) -> Result<String, ErrorCode>417 fn compile(&mut self) -> Result<String, ErrorCode> { 418 self.compile_internal() 419 } 420 } 421 422 impl spirv::Ast<Target> { compile_internal(&self) -> Result<String, ErrorCode>423 fn compile_internal(&self) -> Result<String, ErrorCode> { 424 let vat_overrides = &self.compiler.target_data.vertex_attribute_overrides; 425 let res_overrides = &self.compiler.target_data.resource_binding_overrides; 426 let const_samplers = &self.compiler.target_data.const_samplers; 427 unsafe { 428 let mut shader_ptr = ptr::null(); 429 check!(br::sc_internal_compiler_msl_compile( 430 self.compiler.sc_compiler, 431 &mut shader_ptr, 432 vat_overrides.as_ptr(), 433 vat_overrides.len(), 434 res_overrides.as_ptr(), 435 res_overrides.len(), 436 const_samplers.as_ptr(), 437 const_samplers.len(), 438 )); 439 let shader = match CStr::from_ptr(shader_ptr).to_str() { 440 Ok(v) => v.to_owned(), 441 Err(_) => return Err(ErrorCode::Unhandled), 442 }; 443 check!(br::sc_internal_free_pointer( 444 shader_ptr as *mut std::os::raw::c_void 445 )); 446 Ok(shader) 447 } 448 } 449 is_rasterization_enabled(&self) -> Result<bool, ErrorCode>450 pub fn is_rasterization_enabled(&self) -> Result<bool, ErrorCode> { 451 unsafe { 452 let mut is_disabled = false; 453 check!(br::sc_internal_compiler_msl_get_is_rasterization_disabled( 454 self.compiler.sc_compiler, 455 &mut is_disabled 456 )); 457 Ok(!is_disabled) 458 } 459 } 460 } 461 462 // TODO: Generate with bindgen 463 pub const ARGUMENT_BUFFER_BINDING: u32 = !3; 464