1 use crate::{conversions as conv, PrivateCapabilities};
2 
3 use auxil::FastHashMap;
4 use hal::{
5     command::ClearColor,
6     format::{Aspects, ChannelType},
7     image::Filter,
8     pso,
9 };
10 
11 use metal;
12 use parking_lot::{Mutex, RawRwLock};
13 use storage_map::{StorageMap, StorageMapGuard};
14 
15 use std::mem;
16 
17 
18 pub type FastStorageMap<K, V> = StorageMap<RawRwLock, FastHashMap<K, V>>;
19 pub type FastStorageGuard<'a, V> = StorageMapGuard<'a, RawRwLock, V>;
20 
/// Vertex consumed by the clear pipelines.
///
/// `#[repr(C)]` pins the memory layout: the clear pipeline's vertex descriptor reads a single
/// `Float4` attribute at offset 0 with a stride of `size_of::<ClearVertex>()`, and the raw bytes
/// are handed to Metal as-is, so the layout must not be left to the Rust compiler.
#[derive(Clone, Debug)]
#[repr(C)]
pub struct ClearVertex {
    /// Position, read as a `Float4` attribute at byte offset 0.
    pub pos: [f32; 4],
}
25 
/// Vertex consumed by the blit pipelines.
///
/// `#[repr(C)]` is required: the blit pipeline's vertex descriptor hard-codes attribute 0
/// (`uv`) at byte offset 0 and attribute 1 (`pos`) at byte offset 16, and vertices are uploaded
/// with `set_vertex_bytes` as raw memory. Rust's default representation does not guarantee
/// declaration-order field layout, so without this attribute the offsets are not guaranteed.
#[derive(Clone, Debug)]
#[repr(C)]
pub struct BlitVertex {
    /// Texture coordinate, read as a `Float4` attribute at byte offset 0.
    pub uv: [f32; 4],
    /// Position, read as a `Float4` attribute at byte offset 16.
    pub pos: [f32; 4],
}
31 
/// Numeric class of a color channel; selects the `float`/`int`/`uint` variant of the service
/// shaders.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Channel {
    /// Floating-point class (also covers normalized, scaled and sRGB channel types).
    Float,
    /// Signed integer class.
    Int,
    /// Unsigned integer class.
    Uint,
}
38 
39 impl From<ChannelType> for Channel {
from(channel_type: ChannelType) -> Self40     fn from(channel_type: ChannelType) -> Self {
41         match channel_type {
42             ChannelType::Unorm
43             | ChannelType::Snorm
44             | ChannelType::Ufloat
45             | ChannelType::Sfloat
46             | ChannelType::Uscaled
47             | ChannelType::Sscaled
48             | ChannelType::Srgb => Channel::Float,
49             ChannelType::Uint => Channel::Uint,
50             ChannelType::Sint => Channel::Int,
51         }
52     }
53 }
54 
55 impl Channel {
interpret(self, raw: ClearColor) -> metal::MTLClearColor56     pub fn interpret(self, raw: ClearColor) -> metal::MTLClearColor {
57         unsafe {
58             match self {
59                 Channel::Float => metal::MTLClearColor::new(
60                     raw.float32[0] as _,
61                     raw.float32[1] as _,
62                     raw.float32[2] as _,
63                     raw.float32[3] as _,
64                 ),
65                 Channel::Int => metal::MTLClearColor::new(
66                     raw.sint32[0] as _,
67                     raw.sint32[1] as _,
68                     raw.sint32[2] as _,
69                     raw.sint32[3] as _,
70                 ),
71                 Channel::Uint => metal::MTLClearColor::new(
72                     raw.uint32[0] as _,
73                     raw.uint32[1] as _,
74                     raw.uint32[2] as _,
75                     raw.uint32[3] as _,
76                 ),
77             }
78         }
79     }
80 }
81 
82 #[derive(Debug)]
83 pub struct SamplerStates {
84     nearest: metal::SamplerState,
85     linear: metal::SamplerState,
86 }
87 
88 impl SamplerStates {
new(device: &metal::DeviceRef) -> Self89     fn new(device: &metal::DeviceRef) -> Self {
90         let desc = metal::SamplerDescriptor::new();
91         desc.set_min_filter(metal::MTLSamplerMinMagFilter::Nearest);
92         desc.set_mag_filter(metal::MTLSamplerMinMagFilter::Nearest);
93         desc.set_mip_filter(metal::MTLSamplerMipFilter::Nearest);
94         let nearest = device.new_sampler(&desc);
95         desc.set_min_filter(metal::MTLSamplerMinMagFilter::Linear);
96         desc.set_mag_filter(metal::MTLSamplerMinMagFilter::Linear);
97         let linear = device.new_sampler(&desc);
98 
99         SamplerStates { nearest, linear }
100     }
101 
get(&self, filter: Filter) -> &metal::SamplerStateRef102     pub fn get(&self, filter: Filter) -> &metal::SamplerStateRef {
103         match filter {
104             Filter::Nearest => &self.nearest,
105             Filter::Linear => &self.linear,
106         }
107     }
108 }
109 
110 #[derive(Debug)]
111 pub struct DepthStencilStates {
112     map: FastStorageMap<pso::DepthStencilDesc, metal::DepthStencilState>,
113     write_none: pso::DepthStencilDesc,
114     write_depth: pso::DepthStencilDesc,
115     write_stencil: pso::DepthStencilDesc,
116     write_all: pso::DepthStencilDesc,
117 }
118 
119 impl DepthStencilStates {
new(device: &metal::DeviceRef) -> Self120     fn new(device: &metal::DeviceRef) -> Self {
121         let write_none = pso::DepthStencilDesc {
122             depth: None,
123             depth_bounds: false,
124             stencil: None,
125         };
126         let write_depth = pso::DepthStencilDesc {
127             depth: Some(pso::DepthTest {
128                 fun: pso::Comparison::Always,
129                 write: true,
130             }),
131             depth_bounds: false,
132             stencil: None,
133         };
134         let face = pso::StencilFace {
135             fun: pso::Comparison::Always,
136             op_fail: pso::StencilOp::Replace,
137             op_depth_fail: pso::StencilOp::Replace,
138             op_pass: pso::StencilOp::Replace,
139         };
140         let write_stencil = pso::DepthStencilDesc {
141             depth: None,
142             depth_bounds: false,
143             stencil: Some(pso::StencilTest {
144                 faces: pso::Sided::new(face),
145                 ..pso::StencilTest::default()
146             }),
147         };
148         let write_all = pso::DepthStencilDesc {
149             depth: Some(pso::DepthTest {
150                 fun: pso::Comparison::Always,
151                 write: true,
152             }),
153             depth_bounds: false,
154             stencil: Some(pso::StencilTest {
155                 faces: pso::Sided::new(face),
156                 ..pso::StencilTest::default()
157             }),
158         };
159 
160         let map = FastStorageMap::default();
161         for desc in &[&write_none, &write_depth, &write_stencil, &write_all] {
162             map.get_or_create_with(*desc, || {
163                 let raw_desc = Self::create_desc(desc).unwrap();
164                 device.new_depth_stencil_state(&raw_desc)
165             });
166         }
167 
168         DepthStencilStates {
169             map,
170             write_none,
171             write_depth,
172             write_stencil,
173             write_all,
174         }
175     }
176 
get_write(&self, aspects: Aspects) -> FastStorageGuard<metal::DepthStencilState>177     pub fn get_write(&self, aspects: Aspects) -> FastStorageGuard<metal::DepthStencilState> {
178         let key = if aspects.contains(Aspects::DEPTH | Aspects::STENCIL) {
179             &self.write_all
180         } else if aspects.contains(Aspects::DEPTH) {
181             &self.write_depth
182         } else if aspects.contains(Aspects::STENCIL) {
183             &self.write_stencil
184         } else {
185             &self.write_none
186         };
187         self.map.get_or_create_with(key, || unreachable!())
188     }
189 
prepare(&self, desc: &pso::DepthStencilDesc, device: &metal::DeviceRef)190     pub fn prepare(&self, desc: &pso::DepthStencilDesc, device: &metal::DeviceRef) {
191         self.map.prepare_maybe(desc, || {
192             Self::create_desc(desc).map(|raw_desc| device.new_depth_stencil_state(&raw_desc))
193         });
194     }
195 
196     // TODO: avoid locking for writes every time
get( &self, desc: pso::DepthStencilDesc, device: &Mutex<metal::Device>, ) -> FastStorageGuard<metal::DepthStencilState>197     pub fn get(
198         &self,
199         desc: pso::DepthStencilDesc,
200         device: &Mutex<metal::Device>,
201     ) -> FastStorageGuard<metal::DepthStencilState> {
202         self.map.get_or_create_with(&desc, || {
203             let raw_desc = Self::create_desc(&desc).expect("Incomplete descriptor provided");
204             device.lock().new_depth_stencil_state(&raw_desc)
205         })
206     }
207 
create_stencil( face: &pso::StencilFace, read_mask: pso::StencilValue, write_mask: pso::StencilValue, ) -> metal::StencilDescriptor208     fn create_stencil(
209         face: &pso::StencilFace,
210         read_mask: pso::StencilValue,
211         write_mask: pso::StencilValue,
212     ) -> metal::StencilDescriptor {
213         let desc = metal::StencilDescriptor::new();
214         desc.set_stencil_compare_function(conv::map_compare_function(face.fun));
215         desc.set_read_mask(read_mask);
216         desc.set_write_mask(write_mask);
217         desc.set_stencil_failure_operation(conv::map_stencil_op(face.op_fail));
218         desc.set_depth_failure_operation(conv::map_stencil_op(face.op_depth_fail));
219         desc.set_depth_stencil_pass_operation(conv::map_stencil_op(face.op_pass));
220         desc
221     }
222 
create_desc(desc: &pso::DepthStencilDesc) -> Option<metal::DepthStencilDescriptor>223     fn create_desc(desc: &pso::DepthStencilDesc) -> Option<metal::DepthStencilDescriptor> {
224         let raw = metal::DepthStencilDescriptor::new();
225 
226         if let Some(ref stencil) = desc.stencil {
227             let read_masks = match stencil.read_masks {
228                 pso::State::Static(value) => value,
229                 pso::State::Dynamic => return None,
230             };
231             let write_masks = match stencil.write_masks {
232                 pso::State::Static(value) => value,
233                 pso::State::Dynamic => return None,
234             };
235             let front_desc =
236                 Self::create_stencil(&stencil.faces.front, read_masks.front, write_masks.front);
237             raw.set_front_face_stencil(Some(&front_desc));
238             let back_desc = if stencil.faces.front == stencil.faces.back
239                 && read_masks.front == read_masks.back
240                 && write_masks.front == write_masks.back
241             {
242                 front_desc
243             } else {
244                 Self::create_stencil(&stencil.faces.back, read_masks.back, write_masks.back)
245             };
246             raw.set_back_face_stencil(Some(&back_desc));
247         }
248 
249         if let Some(ref depth) = desc.depth {
250             raw.set_depth_compare_function(conv::map_compare_function(depth.fun));
251             raw.set_depth_write_enabled(depth.write);
252         }
253 
254         Some(raw)
255     }
256 }
257 
258 #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
259 pub struct ClearKey {
260     pub framebuffer_aspects: Aspects,
261     pub color_formats: [metal::MTLPixelFormat; 1],
262     pub depth_stencil_format: metal::MTLPixelFormat,
263     pub target_index: Option<(u8, Channel)>,
264 }
265 
266 #[derive(Debug)]
267 pub struct ImageClearPipes {
268     map: FastStorageMap<ClearKey, metal::RenderPipelineState>,
269 }
270 
271 impl ImageClearPipes {
get( &self, key: ClearKey, library: &Mutex<metal::Library>, device: &Mutex<metal::Device>, private_caps: &PrivateCapabilities, ) -> FastStorageGuard<metal::RenderPipelineState>272     pub(crate) fn get(
273         &self,
274         key: ClearKey,
275         library: &Mutex<metal::Library>,
276         device: &Mutex<metal::Device>,
277         private_caps: &PrivateCapabilities,
278     ) -> FastStorageGuard<metal::RenderPipelineState> {
279         self.map.get_or_create_with(&key, || {
280             Self::create(key, &*library.lock(), &*device.lock(), private_caps)
281         })
282     }
283 
create( key: ClearKey, library: &metal::LibraryRef, device: &metal::DeviceRef, private_caps: &PrivateCapabilities, ) -> metal::RenderPipelineState284     fn create(
285         key: ClearKey,
286         library: &metal::LibraryRef,
287         device: &metal::DeviceRef,
288         private_caps: &PrivateCapabilities,
289     ) -> metal::RenderPipelineState {
290         let pipeline = metal::RenderPipelineDescriptor::new();
291         if private_caps.layered_rendering {
292             pipeline.set_input_primitive_topology(metal::MTLPrimitiveTopologyClass::Triangle);
293         }
294 
295         let vs_clear = library.get_function("vs_clear", None).unwrap();
296         pipeline.set_vertex_function(Some(&vs_clear));
297 
298         if key.framebuffer_aspects.contains(Aspects::COLOR) {
299             for (i, &format) in key.color_formats.iter().enumerate() {
300                 pipeline
301                     .color_attachments()
302                     .object_at(i)
303                     .unwrap()
304                     .set_pixel_format(format);
305             }
306         }
307         if key.framebuffer_aspects.contains(Aspects::DEPTH) {
308             pipeline.set_depth_attachment_pixel_format(key.depth_stencil_format);
309         }
310         if key.framebuffer_aspects.contains(Aspects::STENCIL) {
311             pipeline.set_stencil_attachment_pixel_format(key.depth_stencil_format);
312         }
313 
314         if let Some((index, channel)) = key.target_index {
315             assert!(key.framebuffer_aspects.contains(Aspects::COLOR));
316             let s_channel = match channel {
317                 Channel::Float => "float",
318                 Channel::Int => "int",
319                 Channel::Uint => "uint",
320             };
321             let ps_name = format!("ps_clear{}_{}", index, s_channel);
322             let ps_fun = library.get_function(&ps_name, None).unwrap();
323             pipeline.set_fragment_function(Some(&ps_fun));
324         }
325 
326         // Vertex buffers
327         let vertex_descriptor = metal::VertexDescriptor::new();
328         let mtl_buffer_desc = vertex_descriptor.layouts().object_at(0).unwrap();
329         mtl_buffer_desc.set_stride(mem::size_of::<ClearVertex>() as _);
330         for i in 0 .. 1 {
331             let mtl_attribute_desc = vertex_descriptor
332                 .attributes()
333                 .object_at(i)
334                 .expect("too many vertex attributes");
335             mtl_attribute_desc.set_buffer_index(0);
336             mtl_attribute_desc.set_offset((i * mem::size_of::<[f32; 4]>()) as _);
337             mtl_attribute_desc.set_format(metal::MTLVertexFormat::Float4);
338         }
339         pipeline.set_vertex_descriptor(Some(&vertex_descriptor));
340 
341         device.new_render_pipeline_state(&pipeline).unwrap()
342     }
343 }
344 
345 pub type BlitKey = (
346     metal::MTLTextureType,
347     metal::MTLPixelFormat,
348     Aspects,
349     Channel,
350 );
351 
352 #[derive(Debug)]
353 pub struct ImageBlitPipes {
354     map: FastStorageMap<BlitKey, metal::RenderPipelineState>,
355 }
356 
357 impl ImageBlitPipes {
get( &self, key: BlitKey, library: &Mutex<metal::Library>, device: &Mutex<metal::Device>, private_caps: &PrivateCapabilities, ) -> FastStorageGuard<metal::RenderPipelineState>358     pub(crate) fn get(
359         &self,
360         key: BlitKey,
361         library: &Mutex<metal::Library>,
362         device: &Mutex<metal::Device>,
363         private_caps: &PrivateCapabilities,
364     ) -> FastStorageGuard<metal::RenderPipelineState> {
365         self.map.get_or_create_with(&key, || {
366             Self::create(key, &*library.lock(), &*device.lock(), private_caps)
367         })
368     }
369 
create( key: BlitKey, library: &metal::LibraryRef, device: &metal::DeviceRef, private_caps: &PrivateCapabilities, ) -> metal::RenderPipelineState370     fn create(
371         key: BlitKey,
372         library: &metal::LibraryRef,
373         device: &metal::DeviceRef,
374         private_caps: &PrivateCapabilities,
375     ) -> metal::RenderPipelineState {
376         use metal::MTLTextureType as Tt;
377 
378         let pipeline = metal::RenderPipelineDescriptor::new();
379         if private_caps.layered_rendering {
380             pipeline.set_input_primitive_topology(metal::MTLPrimitiveTopologyClass::Triangle);
381         }
382 
383         let s_type = match key.0 {
384             Tt::D1 => "1d",
385             Tt::D1Array => "1d_array",
386             Tt::D2 => "2d",
387             Tt::D2Array => "2d_array",
388             Tt::D3 => "3d",
389             Tt::D2Multisample => panic!("Can't blit MSAA surfaces"),
390             Tt::Cube | Tt::CubeArray => unimplemented!(),
391         };
392         let s_channel = if key.2.contains(Aspects::COLOR) {
393             match key.3 {
394                 Channel::Float => "float",
395                 Channel::Int => "int",
396                 Channel::Uint => "uint",
397             }
398         } else {
399             "depth" //TODO: stencil
400         };
401         let ps_name = format!("ps_blit_{}_{}", s_type, s_channel);
402 
403         let vs_blit = library.get_function("vs_blit", None).unwrap();
404         let ps_blit = library.get_function(&ps_name, None).unwrap();
405         pipeline.set_vertex_function(Some(&vs_blit));
406         pipeline.set_fragment_function(Some(&ps_blit));
407 
408         if key.2.contains(Aspects::COLOR) {
409             pipeline
410                 .color_attachments()
411                 .object_at(0)
412                 .unwrap()
413                 .set_pixel_format(key.1);
414         }
415         if key.2.contains(Aspects::DEPTH) {
416             pipeline.set_depth_attachment_pixel_format(key.1);
417         }
418         if key.2.contains(Aspects::STENCIL) {
419             pipeline.set_stencil_attachment_pixel_format(key.1);
420         }
421 
422         // Vertex buffers
423         let vertex_descriptor = metal::VertexDescriptor::new();
424         let mtl_buffer_desc = vertex_descriptor.layouts().object_at(0).unwrap();
425         mtl_buffer_desc.set_stride(mem::size_of::<BlitVertex>() as _);
426         for i in 0 .. 2 {
427             let mtl_attribute_desc = vertex_descriptor
428                 .attributes()
429                 .object_at(i)
430                 .expect("too many vertex attributes");
431             mtl_attribute_desc.set_buffer_index(0);
432             mtl_attribute_desc.set_offset((i * mem::size_of::<[f32; 4]>()) as _);
433             mtl_attribute_desc.set_format(metal::MTLVertexFormat::Float4);
434         }
435         pipeline.set_vertex_descriptor(Some(&vertex_descriptor));
436 
437         device.new_render_pipeline_state(&pipeline).unwrap()
438     }
439 }
440 
441 #[derive(Debug)]
442 pub struct ServicePipes {
443     pub library: Mutex<metal::Library>,
444     pub sampler_states: SamplerStates,
445     pub depth_stencil_states: DepthStencilStates,
446     pub clears: ImageClearPipes,
447     pub blits: ImageBlitPipes,
448     pub copy_buffer: metal::ComputePipelineState,
449     pub fill_buffer: metal::ComputePipelineState,
450 }
451 
452 impl ServicePipes {
new(device: &metal::DeviceRef) -> Self453     pub fn new(device: &metal::DeviceRef) -> Self {
454         let data = include_bytes!("./../shaders/gfx_shaders.metallib");
455         let library = device.new_library_with_data(data).unwrap();
456 
457         let copy_buffer = Self::create_copy_buffer(&library, device);
458         let fill_buffer = Self::create_fill_buffer(&library, device);
459 
460         ServicePipes {
461             library: Mutex::new(library),
462             sampler_states: SamplerStates::new(device),
463             depth_stencil_states: DepthStencilStates::new(device),
464             clears: ImageClearPipes {
465                 map: FastStorageMap::default(),
466             },
467             blits: ImageBlitPipes {
468                 map: FastStorageMap::default(),
469             },
470             copy_buffer,
471             fill_buffer,
472         }
473     }
474 
create_copy_buffer( library: &metal::LibraryRef, device: &metal::DeviceRef, ) -> metal::ComputePipelineState475     fn create_copy_buffer(
476         library: &metal::LibraryRef,
477         device: &metal::DeviceRef,
478     ) -> metal::ComputePipelineState {
479         let pipeline = metal::ComputePipelineDescriptor::new();
480 
481         let cs_copy_buffer = library.get_function("cs_copy_buffer", None).unwrap();
482         pipeline.set_compute_function(Some(&cs_copy_buffer));
483         pipeline.set_thread_group_size_is_multiple_of_thread_execution_width(true);
484 
485         /*TODO: check MacOS version
486         if let Some(buffers) = pipeline.buffers() {
487             buffers.object_at(0).unwrap().set_mutability(metal::MTLMutability::Mutable);
488             buffers.object_at(1).unwrap().set_mutability(metal::MTLMutability::Immutable);
489             buffers.object_at(2).unwrap().set_mutability(metal::MTLMutability::Immutable);
490         }*/
491 
492         unsafe { device.new_compute_pipeline_state(&pipeline) }.unwrap()
493     }
494 
create_fill_buffer( library: &metal::LibraryRef, device: &metal::DeviceRef, ) -> metal::ComputePipelineState495     fn create_fill_buffer(
496         library: &metal::LibraryRef,
497         device: &metal::DeviceRef,
498     ) -> metal::ComputePipelineState {
499         let pipeline = metal::ComputePipelineDescriptor::new();
500 
501         let cs_fill_buffer = library.get_function("cs_fill_buffer", None).unwrap();
502         pipeline.set_compute_function(Some(&cs_fill_buffer));
503         pipeline.set_thread_group_size_is_multiple_of_thread_execution_width(true);
504 
505         /*TODO: check MacOS version
506         if let Some(buffers) = pipeline.buffers() {
507             buffers.object_at(0).unwrap().set_mutability(metal::MTLMutability::Mutable);
508             buffers.object_at(1).unwrap().set_mutability(metal::MTLMutability::Immutable);
509         }*/
510 
511         unsafe { device.new_compute_pipeline_state(&pipeline) }.unwrap()
512     }
513 
simple_blit( &self, device: &Mutex<metal::Device>, cmd_buffer: &metal::CommandBufferRef, src: &metal::TextureRef, dst: &metal::TextureRef, private_caps: &PrivateCapabilities, )514     pub(crate) fn simple_blit(
515         &self,
516         device: &Mutex<metal::Device>,
517         cmd_buffer: &metal::CommandBufferRef,
518         src: &metal::TextureRef,
519         dst: &metal::TextureRef,
520         private_caps: &PrivateCapabilities,
521     ) {
522         let key = (
523             metal::MTLTextureType::D2,
524             dst.pixel_format(),
525             Aspects::COLOR,
526             Channel::Float,
527         );
528         let pso = self.blits.get(key, &self.library, device, private_caps);
529         let vertices = [
530             BlitVertex {
531                 uv: [0.0, 1.0, 0.0, 0.0],
532                 pos: [0.0, 0.0, 0.0, 0.0],
533             },
534             BlitVertex {
535                 uv: [0.0, 0.0, 0.0, 0.0],
536                 pos: [0.0, 1.0, 0.0, 0.0],
537             },
538             BlitVertex {
539                 uv: [1.0, 1.0, 0.0, 0.0],
540                 pos: [1.0, 0.0, 0.0, 0.0],
541             },
542             BlitVertex {
543                 uv: [1.0, 0.0, 0.0, 0.0],
544                 pos: [1.0, 1.0, 0.0, 0.0],
545             },
546         ];
547 
548         let descriptor = metal::RenderPassDescriptor::new();
549         if private_caps.layered_rendering {
550             descriptor.set_render_target_array_length(1);
551         }
552         let attachment = descriptor.color_attachments().object_at(0).unwrap();
553         attachment.set_texture(Some(dst));
554         attachment.set_load_action(metal::MTLLoadAction::DontCare);
555         attachment.set_store_action(metal::MTLStoreAction::Store);
556 
557         let encoder = cmd_buffer.new_render_command_encoder(descriptor);
558         encoder.set_render_pipeline_state(pso.as_ref());
559         encoder.set_fragment_sampler_state(0, Some(&self.sampler_states.linear));
560         encoder.set_fragment_texture(0, Some(src));
561         encoder.set_vertex_bytes(
562             0,
563             (vertices.len() * mem::size_of::<BlitVertex>()) as u64,
564             vertices.as_ptr() as *const _,
565         );
566         encoder.draw_primitives(metal::MTLPrimitiveType::TriangleStrip, 0, 4);
567         encoder.end_encoding();
568     }
569 }
570