1 // Copyright 2020 GFX developers
2 //
3 // Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4 // http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5 // http://opensource.org/licenses/MIT>, at your option. This file may not be
6 // copied, modified, or distributed except according to those terms.
7
8 use super::*;
9
10 use objc::runtime::{BOOL, YES};
11
// Link MetalPerformanceShaders.framework so the C symbol declared below resolves.
#[link(name = "MetalPerformanceShaders", kind = "framework")]
extern "C" {
    /// Raw binding to the MPS C function `MPSSupportsMTLDevice`.
    ///
    /// `device` is an untyped pointer to the Objective-C device object, and the result
    /// is an Objective-C `BOOL`. Callers in this module go through the safe wrapper
    /// `mps_supports_device` instead of calling this directly.
    fn MPSSupportsMTLDevice(device: *const std::ffi::c_void) -> BOOL;
}
16
mps_supports_device(device: &DeviceRef) -> bool17 pub fn mps_supports_device(device: &DeviceRef) -> bool {
18 let b: BOOL = unsafe {
19 let ptr: *const DeviceRef = device;
20 MPSSupportsMTLDevice(ptr as _)
21 };
22 b == YES
23 }
24
/// Opaque stand-in for the Objective-C class `MPSKernel`.
///
/// The enum has no variants, so it can never be constructed; it only appears behind
/// pointers/handles.
pub enum MPSKernel {}

// Handle types for `MPSKernel` generated by `foreign_obj_type!`: `Kernel` (owning) and
// `KernelRef` (borrowed), following the crate's `*Ref` naming convention.
// `RayIntersectorRef` names `KernelRef` as its parent type.
foreign_obj_type! {
    type CType = MPSKernel;
    pub struct Kernel;
    pub struct KernelRef;
}
32
/// Layout of each ray in the ray buffer consumed by `RayIntersector`.
///
/// MPS declares this enum with `NSUInteger` storage, and it is passed by value through
/// `msg_send!`. It is therefore `#[repr(u64)]`. Without an explicit representation, Rust
/// may pick a narrower discriminant (in practice one byte), which does not match the
/// 64-bit argument the Objective-C side reads.
#[repr(u64)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MPSRayDataType {
    /// Ray holds an origin and a direction.
    OriginDirection = 0,
    /// Ray holds origin, minimum distance, direction and maximum distance
    /// (see `MPSRayOriginMinDistanceDirectionMaxDistance`).
    OriginMinDistanceDirectionMaxDistance = 1,
    /// Ray holds origin, mask, direction and maximum distance.
    OriginMaskDirectionMaxDistance = 2,
}
38
39 bitflags! {
40 #[allow(non_upper_case_globals)]
41 pub struct MPSRayMaskOptions: NSUInteger {
42 /// Enable primitive masks
43 const Primitive = 1;
44 /// Enable instance masks
45 const Instance = 2;
46 }
47 }
48
/// Options that determine the data contained in an intersection result.
///
/// Declared by MPS with `NSUInteger` storage and passed by value through `msg_send!`,
/// so the Rust enum must be `#[repr(u64)]` to match that ABI.
#[repr(u64)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MPSIntersectionDataType {
    /// Distance only.
    Distance = 0,
    /// Distance and primitive index.
    DistancePrimitiveIndex = 1,
    /// Distance, primitive index and barycentric coordinates
    /// (see `MPSIntersectionDistancePrimitiveIndexCoordinates`).
    DistancePrimitiveIndexCoordinates = 2,
    /// Distance, primitive index and instance index.
    DistancePrimitiveIndexInstanceIndex = 3,
    /// Distance, primitive index, instance index and barycentric coordinates.
    DistancePrimitiveIndexInstanceIndexCoordinates = 4,
}
57
/// Which intersection the ray intersector searches for.
///
/// Declared by MPS with `NSUInteger` storage and passed through `msg_send!`, hence
/// `#[repr(u64)]`.
#[repr(u64)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MPSIntersectionType {
    /// Find the closest intersection to the ray's origin along the ray direction.
    /// This is potentially slower than `Any` but is well suited to primary visibility rays.
    Nearest = 0,
    /// Find any intersection along the ray direction. This is potentially faster than `Nearest` and
    /// is well suited to shadow and occlusion rays.
    Any = 1,
}
66
/// Operator used to combine a primitive/instance mask with the ray mask.
///
/// Declared by MPS with `NSUInteger` storage and passed through `msg_send!`, hence
/// `#[repr(u64)]`.
#[repr(u64)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MPSRayMaskOperator {
    /// Accept the intersection if `(primitive mask & ray mask) != 0`.
    And = 0,
    /// Accept the intersection if `~(primitive mask & ray mask) != 0`.
    NotAnd = 1,
    /// Accept the intersection if `(primitive mask | ray mask) != 0`.
    Or = 2,
    /// Accept the intersection if `~(primitive mask | ray mask) != 0`.
    NotOr = 3,
    /// Accept the intersection if `(primitive mask ^ ray mask) != 0`.
    /// Note that this is equivalent to the "!=" operator.
    Xor = 4,
    /// Accept the intersection if `~(primitive mask ^ ray mask) != 0`.
    /// Note that this is equivalent to the "==" operator.
    NotXor = 5,
    /// Accept the intersection if `(primitive mask < ray mask) != 0`.
    LessThan = 6,
    /// Accept the intersection if `(primitive mask <= ray mask) != 0`.
    LessThanOrEqualTo = 7,
    /// Accept the intersection if `(primitive mask > ray mask) != 0`.
    GreaterThan = 8,
    /// Accept the intersection if `(primitive mask >= ray mask) != 0`.
    GreaterThanOrEqualTo = 9,
}
91
/// Ray/triangle intersection algorithm.
///
/// Declared by MPS with `NSUInteger` storage and passed through `msg_send!`, hence
/// `#[repr(u64)]`.
#[repr(u64)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MPSTriangleIntersectionTestType {
    /// Use the default ray/triangle intersection test
    Default = 0,
    /// Use a watertight ray/triangle intersection test which avoids gaps along shared triangle edges.
    /// Shared vertices may still have gaps.
    /// This intersection test may be slower than `Default`.
    Watertight = 1,
}
100
/// Build state of an acceleration structure, as returned by `AccelerationStructureRef::status`.
///
/// This value is *returned* through `msg_send!` as an `NSUInteger`. Without
/// `#[repr(u64)]`, the Rust enum's layout would not match the 64-bit return value, and
/// any value other than 0/1 would produce an invalid enum.
#[repr(u64)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum MPSAccelerationStructureStatus {
    /// The structure has not been built yet.
    Unbuilt = 0,
    /// The structure has been built.
    Built = 1,
}
106
107 bitflags! {
108 #[allow(non_upper_case_globals)]
109 pub struct MPSAccelerationStructureUsage: NSUInteger {
110 /// No usage options specified
111 const None = 0;
112 /// Option that enables support for refitting the acceleration structure after it has been built.
113 const Refit = 1;
114 /// Option indicating that the acceleration structure will be rebuilt frequently.
115 const FrequentRebuild = 2;
116 const PreferGPUBuild = 4;
117 const PreferCPUBuild = 8;
118 }
119 }
120
/// A common bit for all floating point data types.
const MPSDataTypeFloatBit: u32 = 0x1000_0000;
/// A common bit for all signed integer data types.
const MPSDataTypeSignedBit: u32 = 0x2000_0000;
/// A common bit for all normalized data types.
const MPSDataTypeNormalizedBit: u32 = 0x4000_0000;

/// Element type of MPS data (e.g. index buffers and ray-index buffers).
///
/// MPS declares this enum with `uint32_t` storage and it is passed through `msg_send!`,
/// so it is `#[repr(u32)]`. The low bits hold the element width in bits; the high bits
/// tag float / signed / normalized types.
#[repr(u32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MPSDataType {
    Invalid = 0,

    Float32 = MPSDataTypeFloatBit | 32,
    Float16 = MPSDataTypeFloatBit | 16,

    // Signed integers.
    Int8 = MPSDataTypeSignedBit | 8,
    Int16 = MPSDataTypeSignedBit | 16,
    Int32 = MPSDataTypeSignedBit | 32,

    // Unsigned integers. Range: [0, UTYPE_MAX]
    UInt8 = 8,
    UInt16 = 16,
    UInt32 = 32,

    // Unsigned normalized. Range: [0, 1.0]
    Unorm1 = MPSDataTypeNormalizedBit | 1,
    Unorm8 = MPSDataTypeNormalizedBit | 8,
}
146
/// A kernel that performs intersection tests between rays and geometry.
///
/// Opaque stand-in for the Objective-C class `MPSRayIntersector`; it has no variants and
/// only appears behind pointers/handles.
pub enum MPSRayIntersector {}

// Handle types `RayIntersector` / `RayIntersectorRef` generated by `foreign_obj_type!`.
// `ParentType = KernelRef` makes the Ref type chain to `KernelRef` (mirroring
// `MPSRayIntersector` being an `MPSKernel` subclass).
foreign_obj_type! {
    type CType = MPSRayIntersector;
    pub struct RayIntersector;
    pub struct RayIntersectorRef;
    type ParentType = KernelRef;
}
156
157 impl RayIntersector {
from_device(device: &DeviceRef) -> Option<Self>158 pub fn from_device(device: &DeviceRef) -> Option<Self> {
159 unsafe {
160 let intersector: RayIntersector = msg_send![class!(MPSRayIntersector), alloc];
161 let ptr: *mut Object = msg_send![intersector.as_ref(), initWithDevice: device];
162 if ptr.is_null() {
163 None
164 } else {
165 Some(intersector)
166 }
167 }
168 }
169 }
170
171 impl RayIntersectorRef {
set_cull_mode(&self, mode: MTLCullMode)172 pub fn set_cull_mode(&self, mode: MTLCullMode) {
173 unsafe { msg_send![self, setCullMode: mode] }
174 }
175
set_front_facing_winding(&self, winding: MTLWinding)176 pub fn set_front_facing_winding(&self, winding: MTLWinding) {
177 unsafe { msg_send![self, setFrontFacingWinding: winding] }
178 }
179
set_intersection_data_type(&self, options: MPSIntersectionDataType)180 pub fn set_intersection_data_type(&self, options: MPSIntersectionDataType) {
181 unsafe { msg_send![self, setIntersectionDataType: options] }
182 }
183
set_intersection_stride(&self, stride: NSUInteger)184 pub fn set_intersection_stride(&self, stride: NSUInteger) {
185 unsafe { msg_send![self, setIntersectionStride: stride] }
186 }
187
set_ray_data_type(&self, ty: MPSRayDataType)188 pub fn set_ray_data_type(&self, ty: MPSRayDataType) {
189 unsafe { msg_send![self, setRayDataType: ty] }
190 }
191
set_ray_index_data_type(&self, ty: MPSDataType)192 pub fn set_ray_index_data_type(&self, ty: MPSDataType) {
193 unsafe { msg_send![self, setRayIndexDataType: ty] }
194 }
195
set_ray_mask(&self, ray_mask: u32)196 pub fn set_ray_mask(&self, ray_mask: u32) {
197 unsafe { msg_send![self, setRayMask: ray_mask] }
198 }
199
set_ray_mask_operator(&self, operator: MPSRayMaskOperator)200 pub fn set_ray_mask_operator(&self, operator: MPSRayMaskOperator) {
201 unsafe { msg_send![self, setRayMaskOperator: operator] }
202 }
203
set_ray_mask_options(&self, options: MPSRayMaskOptions)204 pub fn set_ray_mask_options(&self, options: MPSRayMaskOptions) {
205 unsafe { msg_send![self, setRayMaskOptions: options] }
206 }
207
set_ray_stride(&self, stride: NSUInteger)208 pub fn set_ray_stride(&self, stride: NSUInteger) {
209 unsafe { msg_send![self, setRayStride: stride] }
210 }
211
set_triangle_intersection_test_type(&self, test_type: MPSTriangleIntersectionTestType)212 pub fn set_triangle_intersection_test_type(&self, test_type: MPSTriangleIntersectionTestType) {
213 unsafe { msg_send![self, setTriangleIntersectionTestType: test_type] }
214 }
215
encode_intersection_to_command_buffer( &self, command_buffer: &CommandBufferRef, intersection_type: MPSIntersectionType, ray_buffer: &BufferRef, ray_buffer_offset: NSUInteger, intersection_buffer: &BufferRef, intersection_buffer_offset: NSUInteger, ray_count: NSUInteger, acceleration_structure: &AccelerationStructureRef, )216 pub fn encode_intersection_to_command_buffer(
217 &self,
218 command_buffer: &CommandBufferRef,
219 intersection_type: MPSIntersectionType,
220 ray_buffer: &BufferRef,
221 ray_buffer_offset: NSUInteger,
222 intersection_buffer: &BufferRef,
223 intersection_buffer_offset: NSUInteger,
224 ray_count: NSUInteger,
225 acceleration_structure: &AccelerationStructureRef,
226 ) {
227 unsafe {
228 msg_send![
229 self,
230 encodeIntersectionToCommandBuffer: command_buffer
231 intersectionType: intersection_type
232 rayBuffer: ray_buffer
233 rayBufferOffset: ray_buffer_offset
234 intersectionBuffer: intersection_buffer
235 intersectionBufferOffset: intersection_buffer_offset
236 rayCount: ray_count
237 accelerationStructure: acceleration_structure
238 ]
239 }
240 }
241
recommended_minimum_ray_batch_size_for_ray_count( &self, ray_count: NSUInteger, ) -> NSUInteger242 pub fn recommended_minimum_ray_batch_size_for_ray_count(
243 &self,
244 ray_count: NSUInteger,
245 ) -> NSUInteger {
246 unsafe { msg_send![self, recommendedMinimumRayBatchSizeForRayCount: ray_count] }
247 }
248 }
249
/// A group of acceleration structures which may be used together in an instance acceleration structure
///
/// Opaque stand-in for the Objective-C class `MPSAccelerationStructureGroup`; it has no
/// variants and only appears behind pointers/handles.
pub enum MPSAccelerationStructureGroup {}

// Handle types `AccelerationStructureGroup` / `AccelerationStructureGroupRef`,
// generated by `foreign_obj_type!` (no parent type).
foreign_obj_type! {
    type CType = MPSAccelerationStructureGroup;
    pub struct AccelerationStructureGroup;
    pub struct AccelerationStructureGroupRef;
}
258
259 impl AccelerationStructureGroup {
new_with_device(device: &DeviceRef) -> Option<Self>260 pub fn new_with_device(device: &DeviceRef) -> Option<Self> {
261 unsafe {
262 let group: AccelerationStructureGroup =
263 msg_send![class!(MPSAccelerationStructureGroup), alloc];
264 let ptr: *mut Object = msg_send![group.as_ref(), initWithDevice: device];
265 if ptr.is_null() {
266 None
267 } else {
268 Some(group)
269 }
270 }
271 }
272 }
273
274 impl AccelerationStructureGroupRef {
device(&self) -> &DeviceRef275 pub fn device(&self) -> &DeviceRef {
276 unsafe { msg_send![self, device] }
277 }
278 }
279
/// The base class for data structures that are built over geometry and used to accelerate ray tracing.
///
/// Opaque stand-in for the Objective-C class `MPSAccelerationStructure`; it has no
/// variants and only appears behind pointers/handles.
pub enum MPSAccelerationStructure {}

// Handle types `AccelerationStructure` / `AccelerationStructureRef`, generated by
// `foreign_obj_type!`. Polygon and instance structures below name
// `AccelerationStructureRef` as their parent type.
foreign_obj_type! {
    type CType = MPSAccelerationStructure;
    pub struct AccelerationStructure;
    pub struct AccelerationStructureRef;
}
288
289 impl AccelerationStructureRef {
status(&self) -> MPSAccelerationStructureStatus290 pub fn status(&self) -> MPSAccelerationStructureStatus {
291 unsafe { msg_send![self, status] }
292 }
293
usage(&self) -> MPSAccelerationStructureUsage294 pub fn usage(&self) -> MPSAccelerationStructureUsage {
295 unsafe { msg_send![self, usage] }
296 }
297
set_usage(&self, usage: MPSAccelerationStructureUsage)298 pub fn set_usage(&self, usage: MPSAccelerationStructureUsage) {
299 unsafe { msg_send![self, setUsage: usage] }
300 }
301
group(&self) -> &AccelerationStructureGroupRef302 pub fn group(&self) -> &AccelerationStructureGroupRef {
303 unsafe { msg_send![self, group] }
304 }
305
encode_refit_to_command_buffer(&self, buffer: &CommandBufferRef)306 pub fn encode_refit_to_command_buffer(&self, buffer: &CommandBufferRef) {
307 unsafe { msg_send![self, encodeRefitToCommandBuffer: buffer] }
308 }
309
rebuild(&self)310 pub fn rebuild(&self) {
311 unsafe { msg_send![self, rebuild] }
312 }
313 }
314
/// Opaque stand-in for the Objective-C class `MPSPolygonAccelerationStructure`.
///
/// The enum has no variants and only appears behind pointers/handles.
pub enum MPSPolygonAccelerationStructure {}

// Handle types `PolygonAccelerationStructure` / `PolygonAccelerationStructureRef`,
// generated by `foreign_obj_type!`. The Ref type chains to `AccelerationStructureRef`.
foreign_obj_type! {
    type CType = MPSPolygonAccelerationStructure;
    pub struct PolygonAccelerationStructure;
    pub struct PolygonAccelerationStructureRef;
    type ParentType = AccelerationStructureRef;
}
323
324 impl PolygonAccelerationStructureRef {
set_index_buffer(&self, buffer: Option<&BufferRef>)325 pub fn set_index_buffer(&self, buffer: Option<&BufferRef>) {
326 unsafe { msg_send![self, setIndexBuffer: buffer] }
327 }
328
set_index_buffer_offset(&self, offset: NSUInteger)329 pub fn set_index_buffer_offset(&self, offset: NSUInteger) {
330 unsafe { msg_send![self, setIndexBufferOffset: offset] }
331 }
332
set_index_type(&self, data_type: MPSDataType)333 pub fn set_index_type(&self, data_type: MPSDataType) {
334 unsafe { msg_send![self, setIndexType: data_type] }
335 }
336
set_mask_buffer(&self, buffer: Option<&BufferRef>)337 pub fn set_mask_buffer(&self, buffer: Option<&BufferRef>) {
338 unsafe { msg_send![self, setMaskBuffer: buffer] }
339 }
340
set_mask_buffer_offset(&self, offset: NSUInteger)341 pub fn set_mask_buffer_offset(&self, offset: NSUInteger) {
342 unsafe { msg_send![self, setMaskBufferOffset: offset] }
343 }
344
set_vertex_buffer(&self, buffer: Option<&BufferRef>)345 pub fn set_vertex_buffer(&self, buffer: Option<&BufferRef>) {
346 unsafe { msg_send![self, setVertexBuffer: buffer] }
347 }
348
set_vertex_buffer_offset(&self, offset: NSUInteger)349 pub fn set_vertex_buffer_offset(&self, offset: NSUInteger) {
350 unsafe { msg_send![self, setVertexBufferOffset: offset] }
351 }
352
set_vertex_stride(&self, stride: NSUInteger)353 pub fn set_vertex_stride(&self, stride: NSUInteger) {
354 unsafe { msg_send![self, setVertexStride: stride] }
355 }
356 }
357
/// An acceleration structure built over triangles.
///
/// Opaque stand-in for the Objective-C class `MPSTriangleAccelerationStructure`; it has
/// no variants and only appears behind pointers/handles.
pub enum MPSTriangleAccelerationStructure {}

// Handle types `TriangleAccelerationStructure` / `TriangleAccelerationStructureRef`,
// generated by `foreign_obj_type!`. The Ref type chains to
// `PolygonAccelerationStructureRef`, so the polygon setters are reachable from it.
foreign_obj_type! {
    type CType = MPSTriangleAccelerationStructure;
    pub struct TriangleAccelerationStructure;
    pub struct TriangleAccelerationStructureRef;
    type ParentType = PolygonAccelerationStructureRef;
}
367
368 impl TriangleAccelerationStructure {
from_device(device: &DeviceRef) -> Option<Self>369 pub fn from_device(device: &DeviceRef) -> Option<Self> {
370 unsafe {
371 let structure: TriangleAccelerationStructure =
372 msg_send![class!(MPSTriangleAccelerationStructure), alloc];
373 let ptr: *mut Object = msg_send![structure.as_ref(), initWithDevice: device];
374 if ptr.is_null() {
375 None
376 } else {
377 Some(structure)
378 }
379 }
380 }
381 }
382
383 impl TriangleAccelerationStructureRef {
triangle_count(&self) -> NSUInteger384 pub fn triangle_count(&self) -> NSUInteger {
385 unsafe { msg_send![self, triangleCount] }
386 }
387
set_triangle_count(&self, count: NSUInteger)388 pub fn set_triangle_count(&self, count: NSUInteger) {
389 unsafe { msg_send![self, setTriangleCount: count] }
390 }
391 }
392
/// Format of the per-instance transforms used by an instance acceleration structure.
///
/// Stored as a 64-bit value (`NSUInteger` on the Objective-C side).
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[repr(u64)]
pub enum MPSTransformType {
    /// Each instance transform is a 4x4 float matrix.
    Float4x4 = 0,
    /// Every instance uses the identity transform.
    Identity = 1,
}
399
/// An acceleration structure built over instances of other acceleration structures
///
/// Opaque stand-in for the Objective-C class `MPSInstanceAccelerationStructure`; it has
/// no variants and only appears behind pointers/handles.
pub enum MPSInstanceAccelerationStructure {}

// Handle types `InstanceAccelerationStructure` / `InstanceAccelerationStructureRef`,
// generated by `foreign_obj_type!`. The Ref type chains to `AccelerationStructureRef`.
foreign_obj_type! {
    type CType = MPSInstanceAccelerationStructure;
    pub struct InstanceAccelerationStructure;
    pub struct InstanceAccelerationStructureRef;
    type ParentType = AccelerationStructureRef;
}
409
410 impl InstanceAccelerationStructure {
init_with_group(group: &AccelerationStructureGroupRef) -> Option<Self>411 pub fn init_with_group(group: &AccelerationStructureGroupRef) -> Option<Self> {
412 unsafe {
413 let structure: InstanceAccelerationStructure =
414 msg_send![class!(MPSInstanceAccelerationStructure), alloc];
415 let ptr: *mut Object = msg_send![structure.as_ref(), initWithGroup: group];
416 if ptr.is_null() {
417 None
418 } else {
419 Some(structure)
420 }
421 }
422 }
423 }
424
425 impl InstanceAccelerationStructureRef {
426 /// Marshal to Rust Vec
acceleration_structures(&self) -> Vec<PolygonAccelerationStructure>427 pub fn acceleration_structures(&self) -> Vec<PolygonAccelerationStructure> {
428 unsafe {
429 let acs: *mut Object = msg_send![self, accelerationStructures];
430 let count: NSUInteger = msg_send![acs, count];
431 let ret = (0..count)
432 .map(|i| {
433 let ac = msg_send![acs, objectAtIndex: i];
434 PolygonAccelerationStructure::from_ptr(ac)
435 })
436 .collect();
437 ret
438 }
439 }
440
441 /// Marshal from Rust slice
set_acceleration_structures(&self, acs: &[&PolygonAccelerationStructureRef])442 pub fn set_acceleration_structures(&self, acs: &[&PolygonAccelerationStructureRef]) {
443 let ns_array = Array::<PolygonAccelerationStructure>::from_slice(acs);
444 unsafe { msg_send![self, setAccelerationStructures: ns_array] }
445 }
446
instance_buffer(&self) -> &BufferRef447 pub fn instance_buffer(&self) -> &BufferRef {
448 unsafe { msg_send![self, instanceBuffer] }
449 }
450
set_instance_buffer(&self, buffer: &BufferRef)451 pub fn set_instance_buffer(&self, buffer: &BufferRef) {
452 unsafe { msg_send![self, setInstanceBuffer: buffer] }
453 }
454
instance_buffer_offset(&self) -> NSUInteger455 pub fn instance_buffer_offset(&self) -> NSUInteger {
456 unsafe { msg_send![self, instanceBufferOffset] }
457 }
458
set_instance_buffer_offset(&self, offset: NSUInteger)459 pub fn set_instance_buffer_offset(&self, offset: NSUInteger) {
460 unsafe { msg_send![self, setInstanceBufferOffset: offset] }
461 }
462
transform_buffer(&self) -> &BufferRef463 pub fn transform_buffer(&self) -> &BufferRef {
464 unsafe { msg_send![self, transformBuffer] }
465 }
466
set_transform_buffer(&self, buffer: &BufferRef)467 pub fn set_transform_buffer(&self, buffer: &BufferRef) {
468 unsafe { msg_send![self, setTransformBuffer: buffer] }
469 }
470
transform_buffer_offset(&self) -> NSUInteger471 pub fn transform_buffer_offset(&self) -> NSUInteger {
472 unsafe { msg_send![self, transformBufferOffset] }
473 }
474
set_transform_buffer_offset(&self, offset: NSUInteger)475 pub fn set_transform_buffer_offset(&self, offset: NSUInteger) {
476 unsafe { msg_send![self, setTransformBufferOffset: offset] }
477 }
478
transform_type(&self) -> MPSTransformType479 pub fn transform_type(&self) -> MPSTransformType {
480 unsafe { msg_send![self, transformType] }
481 }
482
set_transform_type(&self, transform_type: MPSTransformType)483 pub fn set_transform_type(&self, transform_type: MPSTransformType) {
484 unsafe { msg_send![self, setTransformType: transform_type] }
485 }
486
mask_buffer(&self) -> &BufferRef487 pub fn mask_buffer(&self) -> &BufferRef {
488 unsafe { msg_send![self, maskBuffer] }
489 }
490
set_mask_buffer(&self, buffer: &BufferRef)491 pub fn set_mask_buffer(&self, buffer: &BufferRef) {
492 unsafe { msg_send![self, setMaskBuffer: buffer] }
493 }
494
mask_buffer_offset(&self) -> NSUInteger495 pub fn mask_buffer_offset(&self) -> NSUInteger {
496 unsafe { msg_send![self, maskBufferOffset] }
497 }
498
set_mask_buffer_offset(&self, offset: NSUInteger)499 pub fn set_mask_buffer_offset(&self, offset: NSUInteger) {
500 unsafe { msg_send![self, setMaskBufferOffset: offset] }
501 }
502
instance_count(&self) -> NSUInteger503 pub fn instance_count(&self) -> NSUInteger {
504 unsafe { msg_send![self, instanceCount] }
505 }
506
set_instance_count(&self, count: NSUInteger)507 pub fn set_instance_count(&self, count: NSUInteger) {
508 unsafe { msg_send![self, setInstanceCount: count] }
509 }
510 }
511
/// Three tightly packed `f32` values (12 bytes, no padding), matching MPS's packed float3.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct MPSPackedFloat3 {
    pub elements: [f32; 3],
}

/// Represents a 3D ray with an origin, a direction, and an intersection distance range from the origin.
///
/// `#[repr(C)]` keeps the field order and layout used when rays are written into a Metal
/// buffer; the derives make it easy to build and copy rays on the CPU side.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct MPSRayOriginMinDistanceDirectionMaxDistance {
    /// Ray origin. The intersection test will be skipped if the origin contains NaNs or infinities.
    pub origin: MPSPackedFloat3,
    /// Minimum intersection distance from the origin along the ray direction.
    /// The intersection test will be skipped if the minimum distance is equal to positive infinity or NaN.
    pub min_distance: f32,
    /// Ray direction. Does not need to be normalized. The intersection test will be skipped if
    /// the direction has length zero or contains NaNs or infinities.
    pub direction: MPSPackedFloat3,
    /// Maximum intersection distance from the origin along the ray direction. May be infinite.
    /// The intersection test will be skipped if the maximum distance is less than zero, NaN, or
    /// less than the minimum intersection distance.
    pub max_distance: f32,
}

/// Intersection result which contains the distance from the ray origin to the intersection point,
/// the index of the intersected primitive, and the first two barycentric coordinates of the intersection point.
///
/// `#[repr(C)]` keeps the layout used when results are read back from a Metal buffer.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct MPSIntersectionDistancePrimitiveIndexCoordinates {
    /// Distance from the ray origin to the intersection point along the ray direction vector such
    /// that `intersection = ray.origin + ray.direction * distance`.
    /// Is negative if there is no intersection. If the intersection type is `MPSIntersectionTypeAny`,
    /// is a positive value for a hit or a negative value for a miss.
    pub distance: f32,
    /// Index of the intersected primitive. Undefined if the ray does not intersect a primitive or
    /// if the intersection type is `MPSIntersectionTypeAny`.
    pub primitive_index: u32,
    /// The first two barycentric coordinates `U` and `V` of the intersection point.
    /// The third coordinate `W = 1 - U - V`. Undefined if the ray does not intersect a primitive or
    /// if the intersection type is `MPSIntersectionTypeAny`.
    pub coordinates: [f32; 2],
}
551