1 // Copyright 2020 GFX developers
2 //
3 // Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4 // http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5 // http://opensource.org/licenses/MIT>, at your option. This file may not be
6 // copied, modified, or distributed except according to those terms.
7 
8 use super::*;
9 
10 use objc::runtime::{BOOL, YES};
11 
// Links against the MetalPerformanceShaders framework to resolve the C entry point below.
#[link(name = "MetalPerformanceShaders", kind = "framework")]
extern "C" {
    /// Raw MPS query used by `mps_supports_device`; `device` is the untyped pointer of a
    /// Metal device object. The result is an Objective-C `BOOL`.
    fn MPSSupportsMTLDevice(device: *const std::ffi::c_void) -> BOOL;
}
16 
mps_supports_device(device: &DeviceRef) -> bool17 pub fn mps_supports_device(device: &DeviceRef) -> bool {
18     let b: BOOL = unsafe {
19         let ptr: *const DeviceRef = device;
20         MPSSupportsMTLDevice(ptr as _)
21     };
22     b == YES
23 }
24 
/// Opaque stand-in for the Objective-C `MPSKernel` class. It is uninhabited and only ever
/// used behind pointers.
pub enum MPSKernel {}

// Generates the `Kernel` handle type and its `KernelRef` reference type for `MPSKernel`.
// (Plain `//` comments only: the macro matcher does not accept attributes here.)
foreign_obj_type! {
    type CType = MPSKernel;
    pub struct Kernel;
    pub struct KernelRef;
}
32 
/// Layout of the rays read by a ray intersector (see `RayIntersectorRef::set_ray_data_type`).
///
/// The Objective-C `MPSRayDataType` is an `NSUInteger` enum that is passed by value through
/// `msg_send!`, so the Rust representation is pinned to `u64` (as for `MPSTransformType`).
/// Without it, the enum's size and calling convention would be unspecified.
#[repr(u64)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum MPSRayDataType {
    OriginDirection = 0,
    OriginMinDistanceDirectionMaxDistance = 1,
    OriginMaskDirectionMaxDistance = 2,
}
38 
bitflags! {
    /// Selects which mask tests a ray intersector performs
    /// (passed to `setRayMaskOptions:` by `RayIntersectorRef::set_ray_mask_options`).
    #[allow(non_upper_case_globals)]
    pub struct MPSRayMaskOptions: NSUInteger {
        /// Enable primitive masks
        const Primitive = 1;
        /// Enable instance masks
        const Instance = 2;
    }
}
48 
/// Options that determine the data contained in an intersection result.
///
/// Passed by value to `setIntersectionDataType:`. The Objective-C type is an `NSUInteger`
/// enum, so the representation is pinned to `u64` to keep `msg_send!` ABI-correct.
#[repr(u64)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum MPSIntersectionDataType {
    Distance = 0,
    DistancePrimitiveIndex = 1,
    DistancePrimitiveIndexCoordinates = 2,
    DistancePrimitiveIndexInstanceIndex = 3,
    DistancePrimitiveIndexInstanceIndexCoordinates = 4,
}
57 
/// Kind of intersection search performed by
/// `RayIntersectorRef::encode_intersection_to_command_buffer`.
///
/// Passed by value through `msg_send!` as an `NSUInteger`, hence `#[repr(u64)]`.
#[repr(u64)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum MPSIntersectionType {
    /// Find the closest intersection to the ray's origin along the ray direction.
    /// This is potentially slower than `Any` but is well suited to primary visibility rays.
    Nearest = 0,
    /// Find any intersection along the ray direction. This is potentially faster than `Nearest` and
    /// is well suited to shadow and occlusion rays.
    Any = 1,
}
66 
/// Operator used to compare a primitive/instance mask against the ray mask
/// (see `RayIntersectorRef::set_ray_mask_operator`).
///
/// Passed by value through `msg_send!` as an `NSUInteger`, hence `#[repr(u64)]`.
#[repr(u64)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum MPSRayMaskOperator {
    /// Accept the intersection if `(primitive mask & ray mask) != 0`.
    And = 0,
    /// Accept the intersection if `~(primitive mask & ray mask) != 0`.
    NotAnd = 1,
    /// Accept the intersection if `(primitive mask | ray mask) != 0`.
    Or = 2,
    /// Accept the intersection if `~(primitive mask | ray mask) != 0`.
    NotOr = 3,
    /// Accept the intersection if `(primitive mask ^ ray mask) != 0`.
    /// Note that this is equivalent to the "!=" operator.
    Xor = 4,
    /// Accept the intersection if `~(primitive mask ^ ray mask) != 0`.
    /// Note that this is equivalent to the "==" operator.
    NotXor = 5,
    /// Accept the intersection if `(primitive mask < ray mask) != 0`.
    LessThan = 6,
    /// Accept the intersection if `(primitive mask <= ray mask) != 0`.
    LessThanOrEqualTo = 7,
    /// Accept the intersection if `(primitive mask > ray mask) != 0`.
    GreaterThan = 8,
    /// Accept the intersection if `(primitive mask >= ray mask) != 0`.
    GreaterThanOrEqualTo = 9,
}
91 
/// Ray/triangle test used by the intersector
/// (see `RayIntersectorRef::set_triangle_intersection_test_type`).
///
/// Passed by value through `msg_send!` as an `NSUInteger`, hence `#[repr(u64)]`.
#[repr(u64)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum MPSTriangleIntersectionTestType {
    /// Use the default ray/triangle intersection test
    Default = 0,
    /// Use a watertight ray/triangle intersection test which avoids gaps along shared triangle edges.
    /// Shared vertices may still have gaps.
    /// This intersection test may be slower than `Default`.
    Watertight = 1,
}
100 
/// Build state returned by `AccelerationStructureRef::status`.
///
/// This value is *returned* from `msg_send!` as an `NSUInteger`. The representation is
/// pinned to `u64` so the full return register is read with a defined layout.
#[repr(u64)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum MPSAccelerationStructureStatus {
    Unbuilt = 0,
    Built = 1,
}
106 
bitflags! {
    /// Usage hints for an acceleration structure, read and written through
    /// `AccelerationStructureRef::usage` / `AccelerationStructureRef::set_usage`.
    #[allow(non_upper_case_globals)]
    pub struct MPSAccelerationStructureUsage: NSUInteger {
        /// No usage options specified
        const None = 0;
        /// Option that enables support for refitting the acceleration structure after it has been built.
        const Refit = 1;
        /// Option indicating that the acceleration structure will be rebuilt frequently.
        const FrequentRebuild = 2;
        /// Presumably asks MPS to prefer building on the GPU — confirm against Apple's
        /// `MPSAccelerationStructureUsage` documentation.
        const PreferGPUBuild = 4;
        /// Presumably asks MPS to prefer building on the CPU — confirm against Apple's
        /// `MPSAccelerationStructureUsage` documentation.
        const PreferCPUBuild = 8;
    }
}
120 
/// A common bit for all floating point data types.
#[allow(non_upper_case_globals)] // name mirrors Apple's `MPSDataTypeFloatBit`
const MPSDataTypeFloatBit: u32 = 0x10000000;
/// A common bit for all signed integer data types.
#[allow(non_upper_case_globals)] // name mirrors Apple's `MPSDataTypeSignedBit`
const MPSDataTypeSignedBit: u32 = 0x20000000;
/// A common bit for all normalized data types.
#[allow(non_upper_case_globals)] // name mirrors Apple's `MPSDataTypeNormalizedBit`
const MPSDataTypeNormalizedBit: u32 = 0x40000000;

/// Element data type of MPS buffers. It is used for index buffers
/// (`PolygonAccelerationStructureRef::set_index_type`) and ray index buffers
/// (`RayIntersectorRef::set_ray_index_data_type`).
///
/// The Objective-C `MPSDataType` is a `uint32_t` enum passed by value through `msg_send!`,
/// so the representation is pinned to `u32`. Each value is a size in bits combined with
/// the category bits above.
#[repr(u32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum MPSDataType {
    Invalid = 0,

    Float32 = MPSDataTypeFloatBit | 32,
    Float16 = MPSDataTypeFloatBit | 16,

    // Signed integers.
    Int8 = MPSDataTypeSignedBit | 8,
    Int16 = MPSDataTypeSignedBit | 16,
    Int32 = MPSDataTypeSignedBit | 32,

    // Unsigned integers. Range: [0, UTYPE_MAX]
    UInt8 = 8,
    UInt16 = 16,
    UInt32 = 32,

    // Unsigned normalized. Range: [0, 1.0]
    Unorm1 = MPSDataTypeNormalizedBit | 1,
    Unorm8 = MPSDataTypeNormalizedBit | 8,
}
146 
/// A kernel that performs intersection tests between rays and geometry.
///
/// Opaque stand-in for the Objective-C `MPSRayIntersector` class; only used behind pointers.
pub enum MPSRayIntersector {}

// Generates `RayIntersector` / `RayIntersectorRef`, with `KernelRef` declared as the parent
// type (presumably mirroring the Objective-C subclassing of `MPSKernel`).
foreign_obj_type! {
    type CType = MPSRayIntersector;
    pub struct RayIntersector;
    pub struct RayIntersectorRef;
    type ParentType = KernelRef;
}
156 
157 impl RayIntersector {
from_device(device: &DeviceRef) -> Option<Self>158     pub fn from_device(device: &DeviceRef) -> Option<Self> {
159         unsafe {
160             let intersector: RayIntersector = msg_send![class!(MPSRayIntersector), alloc];
161             let ptr: *mut Object = msg_send![intersector.as_ref(), initWithDevice: device];
162             if ptr.is_null() {
163                 None
164             } else {
165                 Some(intersector)
166             }
167         }
168     }
169 }
170 
171 impl RayIntersectorRef {
set_cull_mode(&self, mode: MTLCullMode)172     pub fn set_cull_mode(&self, mode: MTLCullMode) {
173         unsafe { msg_send![self, setCullMode: mode] }
174     }
175 
set_front_facing_winding(&self, winding: MTLWinding)176     pub fn set_front_facing_winding(&self, winding: MTLWinding) {
177         unsafe { msg_send![self, setFrontFacingWinding: winding] }
178     }
179 
set_intersection_data_type(&self, options: MPSIntersectionDataType)180     pub fn set_intersection_data_type(&self, options: MPSIntersectionDataType) {
181         unsafe { msg_send![self, setIntersectionDataType: options] }
182     }
183 
set_intersection_stride(&self, stride: NSUInteger)184     pub fn set_intersection_stride(&self, stride: NSUInteger) {
185         unsafe { msg_send![self, setIntersectionStride: stride] }
186     }
187 
set_ray_data_type(&self, ty: MPSRayDataType)188     pub fn set_ray_data_type(&self, ty: MPSRayDataType) {
189         unsafe { msg_send![self, setRayDataType: ty] }
190     }
191 
set_ray_index_data_type(&self, ty: MPSDataType)192     pub fn set_ray_index_data_type(&self, ty: MPSDataType) {
193         unsafe { msg_send![self, setRayIndexDataType: ty] }
194     }
195 
set_ray_mask(&self, ray_mask: u32)196     pub fn set_ray_mask(&self, ray_mask: u32) {
197         unsafe { msg_send![self, setRayMask: ray_mask] }
198     }
199 
set_ray_mask_operator(&self, operator: MPSRayMaskOperator)200     pub fn set_ray_mask_operator(&self, operator: MPSRayMaskOperator) {
201         unsafe { msg_send![self, setRayMaskOperator: operator] }
202     }
203 
set_ray_mask_options(&self, options: MPSRayMaskOptions)204     pub fn set_ray_mask_options(&self, options: MPSRayMaskOptions) {
205         unsafe { msg_send![self, setRayMaskOptions: options] }
206     }
207 
set_ray_stride(&self, stride: NSUInteger)208     pub fn set_ray_stride(&self, stride: NSUInteger) {
209         unsafe { msg_send![self, setRayStride: stride] }
210     }
211 
set_triangle_intersection_test_type(&self, test_type: MPSTriangleIntersectionTestType)212     pub fn set_triangle_intersection_test_type(&self, test_type: MPSTriangleIntersectionTestType) {
213         unsafe { msg_send![self, setTriangleIntersectionTestType: test_type] }
214     }
215 
encode_intersection_to_command_buffer( &self, command_buffer: &CommandBufferRef, intersection_type: MPSIntersectionType, ray_buffer: &BufferRef, ray_buffer_offset: NSUInteger, intersection_buffer: &BufferRef, intersection_buffer_offset: NSUInteger, ray_count: NSUInteger, acceleration_structure: &AccelerationStructureRef, )216     pub fn encode_intersection_to_command_buffer(
217         &self,
218         command_buffer: &CommandBufferRef,
219         intersection_type: MPSIntersectionType,
220         ray_buffer: &BufferRef,
221         ray_buffer_offset: NSUInteger,
222         intersection_buffer: &BufferRef,
223         intersection_buffer_offset: NSUInteger,
224         ray_count: NSUInteger,
225         acceleration_structure: &AccelerationStructureRef,
226     ) {
227         unsafe {
228             msg_send![
229                 self,
230                 encodeIntersectionToCommandBuffer: command_buffer
231                 intersectionType: intersection_type
232                 rayBuffer: ray_buffer
233                 rayBufferOffset: ray_buffer_offset
234                 intersectionBuffer: intersection_buffer
235                 intersectionBufferOffset: intersection_buffer_offset
236                 rayCount: ray_count
237                 accelerationStructure: acceleration_structure
238             ]
239         }
240     }
241 
recommended_minimum_ray_batch_size_for_ray_count( &self, ray_count: NSUInteger, ) -> NSUInteger242     pub fn recommended_minimum_ray_batch_size_for_ray_count(
243         &self,
244         ray_count: NSUInteger,
245     ) -> NSUInteger {
246         unsafe { msg_send![self, recommendedMinimumRayBatchSizeForRayCount: ray_count] }
247     }
248 }
249 
/// A group of acceleration structures which may be used together in an instance acceleration structure
///
/// Opaque stand-in for the Objective-C class; only used behind pointers.
pub enum MPSAccelerationStructureGroup {}

// Generates the `AccelerationStructureGroup` handle and its `AccelerationStructureGroupRef`
// reference type.
foreign_obj_type! {
    type CType = MPSAccelerationStructureGroup;
    pub struct AccelerationStructureGroup;
    pub struct AccelerationStructureGroupRef;
}
258 
259 impl AccelerationStructureGroup {
new_with_device(device: &DeviceRef) -> Option<Self>260     pub fn new_with_device(device: &DeviceRef) -> Option<Self> {
261         unsafe {
262             let group: AccelerationStructureGroup =
263                 msg_send![class!(MPSAccelerationStructureGroup), alloc];
264             let ptr: *mut Object = msg_send![group.as_ref(), initWithDevice: device];
265             if ptr.is_null() {
266                 None
267             } else {
268                 Some(group)
269             }
270         }
271     }
272 }
273 
274 impl AccelerationStructureGroupRef {
device(&self) -> &DeviceRef275     pub fn device(&self) -> &DeviceRef {
276         unsafe { msg_send![self, device] }
277     }
278 }
279 
/// The base class for data structures that are built over geometry and used to accelerate ray tracing.
///
/// Opaque stand-in for the Objective-C class; only used behind pointers.
pub enum MPSAccelerationStructure {}

// Generates the `AccelerationStructure` handle and its `AccelerationStructureRef` reference type.
// Other acceleration-structure wrappers in this file name `AccelerationStructureRef` as their parent.
foreign_obj_type! {
    type CType = MPSAccelerationStructure;
    pub struct AccelerationStructure;
    pub struct AccelerationStructureRef;
}
288 
289 impl AccelerationStructureRef {
status(&self) -> MPSAccelerationStructureStatus290     pub fn status(&self) -> MPSAccelerationStructureStatus {
291         unsafe { msg_send![self, status] }
292     }
293 
usage(&self) -> MPSAccelerationStructureUsage294     pub fn usage(&self) -> MPSAccelerationStructureUsage {
295         unsafe { msg_send![self, usage] }
296     }
297 
set_usage(&self, usage: MPSAccelerationStructureUsage)298     pub fn set_usage(&self, usage: MPSAccelerationStructureUsage) {
299         unsafe { msg_send![self, setUsage: usage] }
300     }
301 
group(&self) -> &AccelerationStructureGroupRef302     pub fn group(&self) -> &AccelerationStructureGroupRef {
303         unsafe { msg_send![self, group] }
304     }
305 
encode_refit_to_command_buffer(&self, buffer: &CommandBufferRef)306     pub fn encode_refit_to_command_buffer(&self, buffer: &CommandBufferRef) {
307         unsafe { msg_send![self, encodeRefitToCommandBuffer: buffer] }
308     }
309 
rebuild(&self)310     pub fn rebuild(&self) {
311         unsafe { msg_send![self, rebuild] }
312     }
313 }
314 
/// Opaque stand-in for the Objective-C `MPSPolygonAccelerationStructure` class; only used
/// behind pointers.
pub enum MPSPolygonAccelerationStructure {}

// Generates `PolygonAccelerationStructure` / `PolygonAccelerationStructureRef`, with
// `AccelerationStructureRef` declared as the parent type.
foreign_obj_type! {
    type CType = MPSPolygonAccelerationStructure;
    pub struct PolygonAccelerationStructure;
    pub struct PolygonAccelerationStructureRef;
    type ParentType = AccelerationStructureRef;
}
323 
324 impl PolygonAccelerationStructureRef {
set_index_buffer(&self, buffer: Option<&BufferRef>)325     pub fn set_index_buffer(&self, buffer: Option<&BufferRef>) {
326         unsafe { msg_send![self, setIndexBuffer: buffer] }
327     }
328 
set_index_buffer_offset(&self, offset: NSUInteger)329     pub fn set_index_buffer_offset(&self, offset: NSUInteger) {
330         unsafe { msg_send![self, setIndexBufferOffset: offset] }
331     }
332 
set_index_type(&self, data_type: MPSDataType)333     pub fn set_index_type(&self, data_type: MPSDataType) {
334         unsafe { msg_send![self, setIndexType: data_type] }
335     }
336 
set_mask_buffer(&self, buffer: Option<&BufferRef>)337     pub fn set_mask_buffer(&self, buffer: Option<&BufferRef>) {
338         unsafe { msg_send![self, setMaskBuffer: buffer] }
339     }
340 
set_mask_buffer_offset(&self, offset: NSUInteger)341     pub fn set_mask_buffer_offset(&self, offset: NSUInteger) {
342         unsafe { msg_send![self, setMaskBufferOffset: offset] }
343     }
344 
set_vertex_buffer(&self, buffer: Option<&BufferRef>)345     pub fn set_vertex_buffer(&self, buffer: Option<&BufferRef>) {
346         unsafe { msg_send![self, setVertexBuffer: buffer] }
347     }
348 
set_vertex_buffer_offset(&self, offset: NSUInteger)349     pub fn set_vertex_buffer_offset(&self, offset: NSUInteger) {
350         unsafe { msg_send![self, setVertexBufferOffset: offset] }
351     }
352 
set_vertex_stride(&self, stride: NSUInteger)353     pub fn set_vertex_stride(&self, stride: NSUInteger) {
354         unsafe { msg_send![self, setVertexStride: stride] }
355     }
356 }
357 
/// An acceleration structure built over triangles.
///
/// Opaque stand-in for the Objective-C class; only used behind pointers.
pub enum MPSTriangleAccelerationStructure {}

// Generates `TriangleAccelerationStructure` / `TriangleAccelerationStructureRef`, with
// `PolygonAccelerationStructureRef` declared as the parent type.
foreign_obj_type! {
    type CType = MPSTriangleAccelerationStructure;
    pub struct TriangleAccelerationStructure;
    pub struct TriangleAccelerationStructureRef;
    type ParentType = PolygonAccelerationStructureRef;
}
367 
368 impl TriangleAccelerationStructure {
from_device(device: &DeviceRef) -> Option<Self>369     pub fn from_device(device: &DeviceRef) -> Option<Self> {
370         unsafe {
371             let structure: TriangleAccelerationStructure =
372                 msg_send![class!(MPSTriangleAccelerationStructure), alloc];
373             let ptr: *mut Object = msg_send![structure.as_ref(), initWithDevice: device];
374             if ptr.is_null() {
375                 None
376             } else {
377                 Some(structure)
378             }
379         }
380     }
381 }
382 
383 impl TriangleAccelerationStructureRef {
triangle_count(&self) -> NSUInteger384     pub fn triangle_count(&self) -> NSUInteger {
385         unsafe { msg_send![self, triangleCount] }
386     }
387 
set_triangle_count(&self, count: NSUInteger)388     pub fn set_triangle_count(&self, count: NSUInteger) {
389         unsafe { msg_send![self, setTriangleCount: count] }
390     }
391 }
392 
/// Format of the per-instance transforms of an instance acceleration structure
/// (see `InstanceAccelerationStructureRef::set_transform_type`).
///
/// Represented as `u64` so it matches the `NSUInteger` exchanged through `msg_send!`.
#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)]
#[repr(u64)]
pub enum MPSTransformType {
    Float4x4 = 0,
    Identity = 1,
}
399 
/// An acceleration structure built over instances of other acceleration structures
///
/// Opaque stand-in for the Objective-C class; only used behind pointers.
pub enum MPSInstanceAccelerationStructure {}

// Generates `InstanceAccelerationStructure` / `InstanceAccelerationStructureRef`, with
// `AccelerationStructureRef` declared as the parent type.
foreign_obj_type! {
    type CType = MPSInstanceAccelerationStructure;
    pub struct InstanceAccelerationStructure;
    pub struct InstanceAccelerationStructureRef;
    type ParentType = AccelerationStructureRef;
}
409 
410 impl InstanceAccelerationStructure {
init_with_group(group: &AccelerationStructureGroupRef) -> Option<Self>411     pub fn init_with_group(group: &AccelerationStructureGroupRef) -> Option<Self> {
412         unsafe {
413             let structure: InstanceAccelerationStructure =
414                 msg_send![class!(MPSInstanceAccelerationStructure), alloc];
415             let ptr: *mut Object = msg_send![structure.as_ref(), initWithGroup: group];
416             if ptr.is_null() {
417                 None
418             } else {
419                 Some(structure)
420             }
421         }
422     }
423 }
424 
425 impl InstanceAccelerationStructureRef {
426     /// Marshal to Rust Vec
acceleration_structures(&self) -> Vec<PolygonAccelerationStructure>427     pub fn acceleration_structures(&self) -> Vec<PolygonAccelerationStructure> {
428         unsafe {
429             let acs: *mut Object = msg_send![self, accelerationStructures];
430             let count: NSUInteger = msg_send![acs, count];
431             let ret = (0..count)
432                 .map(|i| {
433                     let ac = msg_send![acs, objectAtIndex: i];
434                     PolygonAccelerationStructure::from_ptr(ac)
435                 })
436                 .collect();
437             ret
438         }
439     }
440 
441     /// Marshal from Rust slice
set_acceleration_structures(&self, acs: &[&PolygonAccelerationStructureRef])442     pub fn set_acceleration_structures(&self, acs: &[&PolygonAccelerationStructureRef]) {
443         let ns_array = Array::<PolygonAccelerationStructure>::from_slice(acs);
444         unsafe { msg_send![self, setAccelerationStructures: ns_array] }
445     }
446 
instance_buffer(&self) -> &BufferRef447     pub fn instance_buffer(&self) -> &BufferRef {
448         unsafe { msg_send![self, instanceBuffer] }
449     }
450 
set_instance_buffer(&self, buffer: &BufferRef)451     pub fn set_instance_buffer(&self, buffer: &BufferRef) {
452         unsafe { msg_send![self, setInstanceBuffer: buffer] }
453     }
454 
instance_buffer_offset(&self) -> NSUInteger455     pub fn instance_buffer_offset(&self) -> NSUInteger {
456         unsafe { msg_send![self, instanceBufferOffset] }
457     }
458 
set_instance_buffer_offset(&self, offset: NSUInteger)459     pub fn set_instance_buffer_offset(&self, offset: NSUInteger) {
460         unsafe { msg_send![self, setInstanceBufferOffset: offset] }
461     }
462 
transform_buffer(&self) -> &BufferRef463     pub fn transform_buffer(&self) -> &BufferRef {
464         unsafe { msg_send![self, transformBuffer] }
465     }
466 
set_transform_buffer(&self, buffer: &BufferRef)467     pub fn set_transform_buffer(&self, buffer: &BufferRef) {
468         unsafe { msg_send![self, setTransformBuffer: buffer] }
469     }
470 
transform_buffer_offset(&self) -> NSUInteger471     pub fn transform_buffer_offset(&self) -> NSUInteger {
472         unsafe { msg_send![self, transformBufferOffset] }
473     }
474 
set_transform_buffer_offset(&self, offset: NSUInteger)475     pub fn set_transform_buffer_offset(&self, offset: NSUInteger) {
476         unsafe { msg_send![self, setTransformBufferOffset: offset] }
477     }
478 
transform_type(&self) -> MPSTransformType479     pub fn transform_type(&self) -> MPSTransformType {
480         unsafe { msg_send![self, transformType] }
481     }
482 
set_transform_type(&self, transform_type: MPSTransformType)483     pub fn set_transform_type(&self, transform_type: MPSTransformType) {
484         unsafe { msg_send![self, setTransformType: transform_type] }
485     }
486 
mask_buffer(&self) -> &BufferRef487     pub fn mask_buffer(&self) -> &BufferRef {
488         unsafe { msg_send![self, maskBuffer] }
489     }
490 
set_mask_buffer(&self, buffer: &BufferRef)491     pub fn set_mask_buffer(&self, buffer: &BufferRef) {
492         unsafe { msg_send![self, setMaskBuffer: buffer] }
493     }
494 
mask_buffer_offset(&self) -> NSUInteger495     pub fn mask_buffer_offset(&self) -> NSUInteger {
496         unsafe { msg_send![self, maskBufferOffset] }
497     }
498 
set_mask_buffer_offset(&self, offset: NSUInteger)499     pub fn set_mask_buffer_offset(&self, offset: NSUInteger) {
500         unsafe { msg_send![self, setMaskBufferOffset: offset] }
501     }
502 
instance_count(&self) -> NSUInteger503     pub fn instance_count(&self) -> NSUInteger {
504         unsafe { msg_send![self, instanceCount] }
505     }
506 
set_instance_count(&self, count: NSUInteger)507     pub fn set_instance_count(&self, count: NSUInteger) {
508         unsafe { msg_send![self, setInstanceCount: count] }
509     }
510 }
511 
/// Three tightly packed `f32` values laid out as a C `float[3]` (12 bytes, 4-byte aligned).
///
/// Plain data, so it derives `Copy`, `Clone`, `Debug`, `Default` and `PartialEq`. This lets
/// values be built, copied and inspected without field-by-field boilerplate.
#[repr(C)]
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub struct MPSPackedFloat3 {
    /// The three components.
    pub elements: [f32; 3],
}
516 
/// Represents a 3D ray with an origin, a direction, and an intersection distance range from the origin.
///
/// `#[repr(C)]` with a 12-byte `MPSPackedFloat3` gives a 32-byte struct with fields at
/// offsets 0, 12, 16 and 28 and no padding. Do not reorder fields: the layout must match
/// the ray buffer contents the intersector reads.
#[repr(C)]
pub struct MPSRayOriginMinDistanceDirectionMaxDistance {
    /// Ray origin. The intersection test will be skipped if the origin contains NaNs or infinities.
    pub origin: MPSPackedFloat3,
    /// Minimum intersection distance from the origin along the ray direction.
    /// The intersection test will be skipped if the minimum distance is equal to positive infinity or NaN.
    pub min_distance: f32,
    /// Ray direction. Does not need to be normalized. The intersection test will be skipped if
    /// the direction has length zero or contains NaNs or infinities.
    pub direction: MPSPackedFloat3,
    /// Maximum intersection distance from the origin along the ray direction. May be infinite.
    /// The intersection test will be skipped if the maximum distance is less than zero, NaN, or
    /// less than the minimum intersection distance.
    pub max_distance: f32,
}
533 
/// Intersection result which contains the distance from the ray origin to the intersection point,
/// the index of the intersected primitive, and the first two barycentric coordinates of the intersection point.
///
/// `#[repr(C)]`, 16 bytes: `distance` at offset 0, `primitive_index` at 4, `coordinates` at 8.
/// Plain data, so it derives `Copy`, `Clone`, `Debug`, `Default` and `PartialEq` for
/// convenient inspection of results read back from an intersection buffer.
#[repr(C)]
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub struct MPSIntersectionDistancePrimitiveIndexCoordinates {
    /// Distance from the ray origin to the intersection point along the ray direction vector such
    /// that `intersection = ray.origin + ray.direction * distance`.
    /// Is negative if there is no intersection. If the intersection type is
    /// [`MPSIntersectionType::Any`], is a positive value for a hit or a negative value for a miss.
    pub distance: f32,
    /// Index of the intersected primitive. Undefined if the ray does not intersect a primitive or
    /// if the intersection type is [`MPSIntersectionType::Any`].
    pub primitive_index: u32,
    /// The first two barycentric coordinates `U` and `V` of the intersection point.
    /// The third coordinate `W = 1 - U - V`. Undefined if the ray does not intersect a primitive or
    /// if the intersection type is [`MPSIntersectionType::Any`].
    pub coordinates: [f32; 2],
}
551