use crate::{
    binding_model::{
        BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError,
    },
    command::{
        bind::Binder,
        end_pipeline_statistics_query,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        BasePass, BasePassRef, CommandBuffer, CommandEncoderError, CommandEncoderStatus,
        MapPassErr, PassErrorScope, QueryUseError, StateChange,
    },
    device::MissingDownlevelFlags,
    error::{ErrorFormatter, PrettyError},
    hub::{Global, GlobalIdentityHandlerFactory, HalApi, Storage, Token},
    id,
    init_tracker::MemoryInitKind,
    resource::{Buffer, Texture},
    track::{StatefulTrackerSubset, TrackerSet, UsageConflict, UseExtendError},
    validation::{check_buffer_usage, MissingBufferUsageError},
    Label,
};

use hal::CommandEncoder as _;
use thiserror::Error;

use std::{fmt, mem, str};

#[doc(hidden)]
#[derive(Clone, Copy, Debug)]
#[cfg_attr(
    any(feature = "serial-pass", feature = "trace"),
    derive(serde::Serialize)
)]
#[cfg_attr(
    any(feature = "serial-pass", feature = "replay"),
    derive(serde::Deserialize)
)]
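/// A single command recorded into a [`ComputePass`] and replayed against the
/// HAL encoder when the pass is consumed. Variants carry ids and offsets
/// rather than resource references, so the stream can be serialized for
/// tracing and replay.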
pub enum ComputeCommand {
    SetBindGroup {
        index: u8,
        num_dynamic_offsets: u8,
        bind_group_id: id::BindGroupId,
    },
    SetPipeline(id::ComputePipelineId),
    SetPushConstant {
        offset: u32,
        size_bytes: u32,
        values_offset: u32,
    },
    Dispatch([u32; 3]),
    DispatchIndirect {
        buffer_id: id::BufferId,
        offset: wgt::BufferAddress,
    },
    PushDebugGroup {
        color: u32,
        len: usize,
    },
    PopDebugGroup,
    InsertDebugMarker {
        color: u32,
        len: usize,
    },
    WriteTimestamp {
        query_set_id: id::QuerySetId,
        query_index: u32,
    },
    BeginPipelineStatisticsQuery {
        query_set_id: id::QuerySetId,
        query_index: u32,
    },
    EndPipelineStatisticsQuery,
}

#[cfg_attr(feature = "serial-pass", derive(serde::Deserialize, serde::Serialize))]
pub struct ComputePass {
    base: BasePass<ComputeCommand>,
    parent_id: id::CommandEncoderId,
}

impl ComputePass {
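    /// Creates a new, empty compute pass recorder targeting the given encoder.
    ///
    /// A minimal usage sketch; `encoder_id` and the commented-out replay call
    /// are assumptions for illustration, not taken from this file:
    ///
    /// ```ignore
    /// let desc = ComputePassDescriptor {
    ///     label: Some(std::borrow::Cow::Borrowed("my-pass")),
    /// };
    /// let mut pass = ComputePass::new(encoder_id, &desc);
    /// // ... record commands through the `compute_ffi` helpers ...
    /// // global.command_encoder_run_compute_pass::<A>(encoder_id, &pass)?;
    /// ```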
    pub fn new(parent_id: id::CommandEncoderId, desc: &ComputePassDescriptor) -> Self {
        Self {
            base: BasePass::new(&desc.label),
            parent_id,
        }
    }

    pub fn parent_id(&self) -> id::CommandEncoderId {
        self.parent_id
    }

    #[cfg(feature = "trace")]
    pub fn into_command(self) -> crate::device::trace::Command {
        crate::device::trace::Command::RunComputePass { base: self.base }
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "ComputePass {{ encoder_id: {:?}, data: {:?} commands and {:?} dynamic offsets }}",
            self.parent_id,
            self.base.commands.len(),
            self.base.dynamic_offsets.len()
        )
    }
}

#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a> {
    pub label: Label<'a>,
}

#[derive(Clone, Debug, Error, PartialEq)]
pub enum DispatchError {
    #[error("compute pipeline must be set")]
    MissingPipeline,
    #[error("current compute pipeline has a layout which is incompatible with a currently set bind group, first differing at entry index {index}")]
    IncompatibleBindGroup {
        index: u32,
        //expected: BindGroupLayoutId,
        //provided: Option<(BindGroupLayoutId, BindGroupId)>,
    },
    #[error(
        "each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

/// The specific cause of a [`ComputePassError`].
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Encoder(#[from] CommandEncoderError),
    #[error("bind group {0:?} is invalid")]
    InvalidBindGroup(id::BindGroupId),
    #[error("bind group index {index} is greater than the device's requested `max_bind_groups` limit {max}")]
    BindGroupIndexOutOfRange { index: u8, max: u32 },
    #[error("compute pipeline {0:?} is invalid")]
    InvalidPipeline(id::ComputePipelineId),
    #[error("QuerySet {0:?} is invalid")]
    InvalidQuerySet(id::QuerySetId),
    #[error("indirect buffer {0:?} is invalid or destroyed")]
    InvalidIndirectBuffer(id::BufferId),
    #[error("indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error("buffer {0:?} is invalid or destroyed")]
    InvalidBuffer(id::BufferId),
    #[error(transparent)]
    ResourceUsageConflict(#[from] UsageConflict),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error("cannot pop debug group, because the number of pushed debug groups is zero")]
    InvalidPopDebugGroup,
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
}

impl PrettyError for ComputePassErrorInner {
    fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
        fmt.error(self);
        match *self {
            Self::InvalidBindGroup(id) => {
                fmt.bind_group_label(&id);
            }
            Self::InvalidPipeline(id) => {
                fmt.compute_pipeline_label(&id);
            }
            Self::InvalidIndirectBuffer(id) => {
                fmt.buffer_label(&id);
            }
            _ => {}
        };
    }
}

/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    inner: ComputePassErrorInner,
}
impl PrettyError for ComputePassError {
    fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
        // This error is a wrapper for the inner error,
        // but the scope has useful labels.
        fmt.error(self);
        self.scope.fmt_pretty(fmt);
    }
}

impl<T, E> MapPassErr<T, ComputePassError> for Result<T, E>
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> Result<T, ComputePassError> {
        self.map_err(|inner| ComputePassError {
            scope,
            inner: inner.into(),
        })
    }
}

#[derive(Debug)]
struct State {
    binder: Binder,
    pipeline: StateChange<id::ComputePipelineId>,
    trackers: StatefulTrackerSubset,
    debug_scope_depth: u32,
}

impl State {
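    /// Checks that the pass is ready to dispatch: every bind group slot
    /// required by the current pipeline layout holds a compatible group, a
    /// pipeline has been set, and late-bound buffer bindings satisfy their
    /// minimum sizes.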
    fn is_ready(&self) -> Result<(), DispatchError> {
        let bind_mask = self.binder.invalid_mask();
        if bind_mask != 0 {
            //let (expected, provided) = self.binder.entries[index as usize].info();
            return Err(DispatchError::IncompatibleBindGroup {
                index: bind_mask.trailing_zeros(),
            });
        }
        if self.pipeline.is_unset() {
            return Err(DispatchError::MissingPipeline);
        }
        self.binder.check_late_buffer_bindings()?;

        Ok(())
    }

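    /// Merges the resource usage of all active bind groups into this pass's
    /// trackers, encodes the resulting buffer/texture barriers into
    /// `raw_encoder`, and clears the tracker subset so the next dispatch
    /// starts fresh.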
    fn flush_states<A: HalApi>(
        &mut self,
        raw_encoder: &mut A::CommandEncoder,
        base_trackers: &mut TrackerSet,
        bind_group_guard: &Storage<BindGroup<A>, id::BindGroupId>,
        buffer_guard: &Storage<Buffer<A>, id::BufferId>,
        texture_guard: &Storage<Texture<A>, id::TextureId>,
    ) -> Result<(), UsageConflict> {
        for id in self.binder.list_active() {
            self.trackers.merge_extend(&bind_group_guard[id].used)?;
            // Note: stateless trackers are not merged: the lifetime reference
            // is held to the bind group itself.
        }

        log::trace!("Encoding dispatch barriers");

        CommandBuffer::insert_barriers(
            raw_encoder,
            base_trackers,
            &self.trackers.buffers,
            &self.trackers.textures,
            buffer_guard,
            texture_guard,
        );

        self.trackers.clear();
        Ok(())
    }
}

// Common routines between render/compute

impl<G: GlobalIdentityHandlerFactory> Global<G> {
    pub fn command_encoder_run_compute_pass<A: HalApi>(
        &self,
        encoder_id: id::CommandEncoderId,
        pass: &ComputePass,
    ) -> Result<(), ComputePassError> {
        self.command_encoder_run_compute_pass_impl::<A>(encoder_id, pass.base.as_ref())
    }

    #[doc(hidden)]
    pub fn command_encoder_run_compute_pass_impl<A: HalApi>(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePassRef<ComputeCommand>,
    ) -> Result<(), ComputePassError> {
        profiling::scope!("run_compute_pass", "CommandEncoder");
        let init_scope = PassErrorScope::Pass(encoder_id);

        let hub = A::hub(self);
        let mut token = Token::root();

        let (device_guard, mut token) = hub.devices.read(&mut token);

        let (mut cmd_buf_guard, mut token) = hub.command_buffers.write(&mut token);
        let cmd_buf = CommandBuffer::get_encoder_mut(&mut *cmd_buf_guard, encoder_id)
            .map_pass_err(init_scope)?;
        // will be reset to `Recording` if the pass is recorded without errors
        cmd_buf.status = CommandEncoderStatus::Error;
        let raw = cmd_buf.encoder.open();

        let device = &device_guard[cmd_buf.device_id.value];

        #[cfg(feature = "trace")]
        if let Some(ref mut list) = cmd_buf.commands {
            list.push(crate::device::trace::Command::RunComputePass {
                base: BasePass::from_ref(base),
            });
        }

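        // The remaining guards are taken in the hub's storage order (the
        // unused render-bundle guard exists only to advance the token chain);
        // the `Token` types make an out-of-order acquisition a compile error,
        // which is what rules out lock-order deadlocks here.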
        let (_, mut token) = hub.render_bundles.read(&mut token);
        let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token);
        let (bind_group_guard, mut token) = hub.bind_groups.read(&mut token);
        let (pipeline_guard, mut token) = hub.compute_pipelines.read(&mut token);
        let (query_set_guard, mut token) = hub.query_sets.read(&mut token);
        let (buffer_guard, mut token) = hub.buffers.read(&mut token);
        let (texture_guard, _) = hub.textures.read(&mut token);

        let mut state = State {
            binder: Binder::new(),
            pipeline: StateChange::new(),
            trackers: StatefulTrackerSubset::new(A::VARIANT),
            debug_scope_depth: 0,
        };
        let mut temp_offsets = Vec::new();
        let mut dynamic_offset_count = 0;
        let mut string_offset = 0;
        let mut active_query = None;

        let hal_desc = hal::ComputePassDescriptor { label: base.label };
        unsafe {
            raw.begin_compute_pass(&hal_desc);
        }

        // Immediate texture inits are required because of prior discards;
        // they need to be inserted before texture reads.
        let mut pending_discard_init_fixups = SurfacesInDiscardState::new();

        for command in base.commands {
            match *command {
                ComputeCommand::SetBindGroup {
                    index,
                    num_dynamic_offsets,
                    bind_group_id,
                } => {
                    let scope = PassErrorScope::SetBindGroup(bind_group_id);

                    let max_bind_groups = cmd_buf.limits.max_bind_groups;
                    if (index as u32) >= max_bind_groups {
                        return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
                            index,
                            max: max_bind_groups,
                        })
                        .map_pass_err(scope);
                    }

                    temp_offsets.clear();
                    temp_offsets.extend_from_slice(
                        &base.dynamic_offsets[dynamic_offset_count
                            ..dynamic_offset_count + (num_dynamic_offsets as usize)],
                    );
                    dynamic_offset_count += num_dynamic_offsets as usize;

                    let bind_group = cmd_buf
                        .trackers
                        .bind_groups
                        .use_extend(&*bind_group_guard, bind_group_id, (), ())
                        .map_err(|_| ComputePassErrorInner::InvalidBindGroup(bind_group_id))
                        .map_pass_err(scope)?;
                    bind_group
                        .validate_dynamic_bindings(&temp_offsets, &cmd_buf.limits)
                        .map_pass_err(scope)?;

                    cmd_buf.buffer_memory_init_actions.extend(
                        bind_group.used_buffer_ranges.iter().filter_map(
                            |action| match buffer_guard.get(action.id) {
                                Ok(buffer) => buffer.initialization_status.check_action(action),
                                Err(_) => None,
                            },
                        ),
                    );

                    for action in bind_group.used_texture_ranges.iter() {
                        pending_discard_init_fixups.extend(
                            cmd_buf
                                .texture_memory_actions
                                .register_init_action(action, &texture_guard),
                        );
                    }

                    let pipeline_layout_id = state.binder.pipeline_layout_id;
                    let entries = state.binder.assign_group(
                        index as usize,
                        id::Valid(bind_group_id),
                        bind_group,
                        &temp_offsets,
                    );
                    if !entries.is_empty() {
                        let pipeline_layout =
                            &pipeline_layout_guard[pipeline_layout_id.unwrap()].raw;
                        for (i, e) in entries.iter().enumerate() {
                            let raw_bg = &bind_group_guard[e.group_id.as_ref().unwrap().value].raw;
                            unsafe {
                                raw.set_bind_group(
                                    pipeline_layout,
                                    index as u32 + i as u32,
                                    raw_bg,
                                    &e.dynamic_offsets,
                                );
                            }
                        }
                    }
                }
                ComputeCommand::SetPipeline(pipeline_id) => {
                    let scope = PassErrorScope::SetPipelineCompute(pipeline_id);

                    if state.pipeline.set_and_check_redundant(pipeline_id) {
                        continue;
                    }

                    let pipeline = cmd_buf
                        .trackers
                        .compute_pipes
                        .use_extend(&*pipeline_guard, pipeline_id, (), ())
                        .map_err(|_| ComputePassErrorInner::InvalidPipeline(pipeline_id))
                        .map_pass_err(scope)?;

                    unsafe {
                        raw.set_compute_pipeline(&pipeline.raw);
                    }

                    // Rebind resources
                    if state.binder.pipeline_layout_id != Some(pipeline.layout_id.value) {
                        let pipeline_layout = &pipeline_layout_guard[pipeline.layout_id.value];

                        let (start_index, entries) = state.binder.change_pipeline_layout(
                            &*pipeline_layout_guard,
                            pipeline.layout_id.value,
                            &pipeline.late_sized_buffer_groups,
                        );
                        if !entries.is_empty() {
                            for (i, e) in entries.iter().enumerate() {
                                let raw_bg =
                                    &bind_group_guard[e.group_id.as_ref().unwrap().value].raw;
                                unsafe {
                                    raw.set_bind_group(
                                        &pipeline_layout.raw,
                                        start_index as u32 + i as u32,
                                        raw_bg,
                                        &e.dynamic_offsets,
                                    );
                                }
                            }
                        }

                        // Clear push constant ranges
                        let non_overlapping = super::bind::compute_nonoverlapping_ranges(
                            &pipeline_layout.push_constant_ranges,
                        );
                        for range in non_overlapping {
                            let offset = range.range.start;
                            let size_bytes = range.range.end - offset;
                            super::push_constant_clear(
                                offset,
                                size_bytes,
                                |clear_offset, clear_data| unsafe {
                                    raw.set_push_constants(
                                        &pipeline_layout.raw,
                                        wgt::ShaderStages::COMPUTE,
                                        clear_offset,
                                        clear_data,
                                    );
                                },
                            );
                        }
                    }
                }
                ComputeCommand::SetPushConstant {
                    offset,
                    size_bytes,
                    values_offset,
                } => {
                    let scope = PassErrorScope::SetPushConstant;

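                    // `push_constant_data` is a stash of 4-byte words, so the
                    // byte-based offset/size are converted to word counts here.
                    // Illustrative numbers (not from a real trace): offset = 8,
                    // size_bytes = 16, values_offset = 2 selects words 2..6.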
                    let end_offset_bytes = offset + size_bytes;
                    let values_end_offset =
                        (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                    let data_slice =
                        &base.push_constant_data[(values_offset as usize)..values_end_offset];

                    let pipeline_layout_id = state
                        .binder
                        .pipeline_layout_id
                        //TODO: don't error here, lazily update the push constants
                        .ok_or(ComputePassErrorInner::Dispatch(
                            DispatchError::MissingPipeline,
                        ))
                        .map_pass_err(scope)?;
                    let pipeline_layout = &pipeline_layout_guard[pipeline_layout_id];

                    pipeline_layout
                        .validate_push_constant_ranges(
                            wgt::ShaderStages::COMPUTE,
                            offset,
                            end_offset_bytes,
                        )
                        .map_pass_err(scope)?;

                    unsafe {
                        raw.set_push_constants(
                            &pipeline_layout.raw,
                            wgt::ShaderStages::COMPUTE,
                            offset,
                            data_slice,
                        );
                    }
                }
                ComputeCommand::Dispatch(groups) => {
                    let scope = PassErrorScope::Dispatch {
                        indirect: false,
                        pipeline: state.pipeline.last_state,
                    };

                    fixup_discarded_surfaces(
                        pending_discard_init_fixups.drain(..),
                        raw,
                        &texture_guard,
                        &mut cmd_buf.trackers.textures,
                        device,
                    );

                    state.is_ready().map_pass_err(scope)?;
                    state
                        .flush_states(
                            raw,
                            &mut cmd_buf.trackers,
                            &*bind_group_guard,
                            &*buffer_guard,
                            &*texture_guard,
                        )
                        .map_pass_err(scope)?;

                    let groups_size_limit = cmd_buf.limits.max_compute_workgroups_per_dimension;

                    if groups[0] > groups_size_limit
                        || groups[1] > groups_size_limit
                        || groups[2] > groups_size_limit
                    {
                        return Err(ComputePassErrorInner::Dispatch(
                            DispatchError::InvalidGroupSize {
                                current: groups,
                                limit: groups_size_limit,
                            },
                        ))
                        .map_pass_err(scope);
                    }

                    unsafe {
                        raw.dispatch(groups);
                    }
                }
                ComputeCommand::DispatchIndirect { buffer_id, offset } => {
                    let scope = PassErrorScope::Dispatch {
                        indirect: true,
                        pipeline: state.pipeline.last_state,
                    };

                    state.is_ready().map_pass_err(scope)?;

                    device
                        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)
                        .map_pass_err(scope)?;

                    let indirect_buffer = state
                        .trackers
                        .buffers
                        .use_extend(&*buffer_guard, buffer_id, (), hal::BufferUses::INDIRECT)
                        .map_err(|_| ComputePassErrorInner::InvalidIndirectBuffer(buffer_id))
                        .map_pass_err(scope)?;
                    check_buffer_usage(indirect_buffer.usage, wgt::BufferUsages::INDIRECT)
                        .map_pass_err(scope)?;

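                    // `wgt::DispatchIndirectArgs` is three u32 workgroup
                    // counts (x, y, z), so a full 12-byte set of args must
                    // fit in the buffer starting at `offset`.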
                    let end_offset = offset + mem::size_of::<wgt::DispatchIndirectArgs>() as u64;
                    if end_offset > indirect_buffer.size {
                        return Err(ComputePassErrorInner::IndirectBufferOverrun {
                            offset,
                            end_offset,
                            buffer_size: indirect_buffer.size,
                        })
                        .map_pass_err(scope);
                    }

                    let buf_raw = indirect_buffer
                        .raw
                        .as_ref()
                        .ok_or(ComputePassErrorInner::InvalidIndirectBuffer(buffer_id))
                        .map_pass_err(scope)?;

                    let stride = 3 * 4; // 3 integers, x/y/z group size

                    cmd_buf.buffer_memory_init_actions.extend(
                        indirect_buffer.initialization_status.create_action(
                            buffer_id,
                            offset..(offset + stride),
                            MemoryInitKind::NeedsInitializedMemory,
                        ),
                    );

                    state
                        .flush_states(
                            raw,
                            &mut cmd_buf.trackers,
                            &*bind_group_guard,
                            &*buffer_guard,
                            &*texture_guard,
                        )
                        .map_pass_err(scope)?;
                    unsafe {
                        raw.dispatch_indirect(buf_raw, offset);
                    }
                }
                ComputeCommand::PushDebugGroup { color: _, len } => {
                    state.debug_scope_depth += 1;
                    let label =
                        str::from_utf8(&base.string_data[string_offset..string_offset + len])
                            .unwrap();
                    string_offset += len;
                    unsafe {
                        raw.begin_debug_marker(label);
                    }
                }
                ComputeCommand::PopDebugGroup => {
                    let scope = PassErrorScope::PopDebugGroup;

                    if state.debug_scope_depth == 0 {
                        return Err(ComputePassErrorInner::InvalidPopDebugGroup)
                            .map_pass_err(scope);
                    }
                    state.debug_scope_depth -= 1;
                    unsafe {
                        raw.end_debug_marker();
                    }
                }
                ComputeCommand::InsertDebugMarker { color: _, len } => {
                    let label =
                        str::from_utf8(&base.string_data[string_offset..string_offset + len])
                            .unwrap();
                    string_offset += len;
                    unsafe { raw.insert_debug_marker(label) }
                }
                ComputeCommand::WriteTimestamp {
                    query_set_id,
                    query_index,
                } => {
                    let scope = PassErrorScope::WriteTimestamp;

                    let query_set = cmd_buf
                        .trackers
                        .query_sets
                        .use_extend(&*query_set_guard, query_set_id, (), ())
                        .map_err(|e| match e {
                            UseExtendError::InvalidResource => {
                                ComputePassErrorInner::InvalidQuerySet(query_set_id)
                            }
                            _ => unreachable!(),
                        })
                        .map_pass_err(scope)?;

                    query_set
                        .validate_and_write_timestamp(raw, query_set_id, query_index, None)
                        .map_pass_err(scope)?;
                }
                ComputeCommand::BeginPipelineStatisticsQuery {
                    query_set_id,
                    query_index,
                } => {
                    let scope = PassErrorScope::BeginPipelineStatisticsQuery;

                    let query_set = cmd_buf
                        .trackers
                        .query_sets
                        .use_extend(&*query_set_guard, query_set_id, (), ())
                        .map_err(|e| match e {
                            UseExtendError::InvalidResource => {
                                ComputePassErrorInner::InvalidQuerySet(query_set_id)
                            }
                            _ => unreachable!(),
                        })
                        .map_pass_err(scope)?;

                    query_set
                        .validate_and_begin_pipeline_statistics_query(
                            raw,
                            query_set_id,
                            query_index,
                            None,
                            &mut active_query,
                        )
                        .map_pass_err(scope)?;
                }
                ComputeCommand::EndPipelineStatisticsQuery => {
                    let scope = PassErrorScope::EndPipelineStatisticsQuery;

                    end_pipeline_statistics_query(raw, &*query_set_guard, &mut active_query)
                        .map_pass_err(scope)?;
                }
            }
        }

        unsafe {
            raw.end_compute_pass();
        }
        cmd_buf.status = CommandEncoderStatus::Recording;

        // There can be entries left in pending_discard_init_fixups if a bind
        // group was set but never used (i.e. no dispatch occurred). However, we
        // already altered the discard/init_action state on this cmd_buf, so we
        // need to apply the promised changes.
        fixup_discarded_surfaces(
            pending_discard_init_fixups.into_iter(),
            raw,
            &texture_guard,
            &mut cmd_buf.trackers.textures,
            device,
        );

        Ok(())
    }
}

pub mod compute_ffi {
    use super::{ComputeCommand, ComputePass};
    use crate::{id, RawString};
    use std::{convert::TryInto, ffi, slice};
    use wgt::{BufferAddress, DynamicOffset};

    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given pointer is
    /// valid for `offset_length` elements.
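    ///
    /// A call-site sketch; the ids and offset values are illustrative
    /// assumptions:
    ///
    /// ```ignore
    /// let offsets = [256u32]; // one dynamic offset per dynamic binding
    /// unsafe {
    ///     wgpu_compute_pass_set_bind_group(
    ///         &mut pass, 0, bind_group_id, offsets.as_ptr(), offsets.len(),
    ///     );
    /// }
    /// ```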
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_set_bind_group(
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: id::BindGroupId,
        offsets: *const DynamicOffset,
        offset_length: usize,
    ) {
        pass.base.commands.push(ComputeCommand::SetBindGroup {
            index: index.try_into().unwrap(),
            num_dynamic_offsets: offset_length.try_into().unwrap(),
            bind_group_id,
        });
        if offset_length != 0 {
            pass.base
                .dynamic_offsets
                .extend_from_slice(slice::from_raw_parts(offsets, offset_length));
        }
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_set_pipeline(
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::SetPipeline(pipeline_id));
    }

    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given pointer is
    /// valid for `size_bytes` bytes.
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_set_push_constant(
        pass: &mut ComputePass,
        offset: u32,
        size_bytes: u32,
        data: *const u8,
    ) {
        assert_eq!(
            offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
            0,
            "Push constant offset must be aligned to 4 bytes."
        );
        assert_eq!(
            size_bytes & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
            0,
            "Push constant size must be aligned to 4 bytes."
        );
        let data_slice = slice::from_raw_parts(data, size_bytes as usize);
        let value_offset = pass.base.push_constant_data.len().try_into().expect(
            "Ran out of push constant space. Don't set 4 GiB of push constants per ComputePass.",
        );

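        // Repack the raw bytes into native-endian u32 words; the alignment
        // asserts above guarantee `data_slice.len()` is a multiple of 4, so
        // `chunks_exact` leaves no remainder.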
        pass.base.push_constant_data.extend(
            data_slice
                .chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        pass.base.commands.push(ComputeCommand::SetPushConstant {
            offset,
            size_bytes,
            values_offset: value_offset,
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_dispatch(
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::Dispatch([groups_x, groups_y, groups_z]));
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_dispatch_indirect(
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::DispatchIndirect { buffer_id, offset });
    }

    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given `label`
    /// is a valid null-terminated string.
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_push_debug_group(
        pass: &mut ComputePass,
        label: RawString,
        color: u32,
    ) {
        let bytes = ffi::CStr::from_ptr(label).to_bytes();
        pass.base.string_data.extend_from_slice(bytes);

        pass.base.commands.push(ComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_pop_debug_group(pass: &mut ComputePass) {
        pass.base.commands.push(ComputeCommand::PopDebugGroup);
    }

    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given `label`
    /// is a valid null-terminated string.
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_insert_debug_marker(
        pass: &mut ComputePass,
        label: RawString,
        color: u32,
    ) {
        let bytes = ffi::CStr::from_ptr(label).to_bytes();
        pass.base.string_data.extend_from_slice(bytes);

        pass.base.commands.push(ComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_write_timestamp(
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) {
        pass.base.commands.push(ComputeCommand::WriteTimestamp {
            query_set_id,
            query_index,
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_begin_pipeline_statistics_query(
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::BeginPipelineStatisticsQuery {
                query_set_id,
                query_index,
            });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_end_pipeline_statistics_query(pass: &mut ComputePass) {
        pass.base
            .commands
            .push(ComputeCommand::EndPipelineStatisticsQuery);
    }
}