use crate::{
    binding_model::{
        BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError,
    },
    command::{
        bind::Binder,
        end_pipeline_statistics_query,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        BasePass, BasePassRef, CommandBuffer, CommandEncoderError, CommandEncoderStatus,
        MapPassErr, PassErrorScope, QueryUseError, StateChange,
    },
    device::MissingDownlevelFlags,
    error::{ErrorFormatter, PrettyError},
    hub::{Global, GlobalIdentityHandlerFactory, HalApi, Storage, Token},
    id,
    init_tracker::MemoryInitKind,
    resource::{Buffer, Texture},
    track::{StatefulTrackerSubset, TrackerSet, UsageConflict, UseExtendError},
    validation::{check_buffer_usage, MissingBufferUsageError},
    Label,
};

use hal::CommandEncoder as _;
use thiserror::Error;

use std::{fmt, mem, str};

#[doc(hidden)]
#[derive(Clone, Copy, Debug)]
#[cfg_attr(
    any(feature = "serial-pass", feature = "trace"),
    derive(serde::Serialize)
)]
#[cfg_attr(
    any(feature = "serial-pass", feature = "replay"),
    derive(serde::Deserialize)
)]
pub enum ComputeCommand {
    SetBindGroup {
        index: u8,
        num_dynamic_offsets: u8,
        bind_group_id: id::BindGroupId,
    },
    SetPipeline(id::ComputePipelineId),
    SetPushConstant {
        offset: u32,
        size_bytes: u32,
        values_offset: u32,
    },
    Dispatch([u32; 3]),
    DispatchIndirect {
        buffer_id: id::BufferId,
        offset: wgt::BufferAddress,
    },
    PushDebugGroup {
        color: u32,
        len: usize,
    },
    PopDebugGroup,
    InsertDebugMarker {
        color: u32,
        len: usize,
    },
    WriteTimestamp {
        query_set_id: id::QuerySetId,
        query_index: u32,
    },
    BeginPipelineStatisticsQuery {
        query_set_id: id::QuerySetId,
        query_index: u32,
    },
    EndPipelineStatisticsQuery,
}

#[cfg_attr(feature = "serial-pass", derive(serde::Deserialize, serde::Serialize))]
pub struct ComputePass {
    base: BasePass<ComputeCommand>,
    parent_id: id::CommandEncoderId,
}

impl ComputePass {
    pub fn new(parent_id: id::CommandEncoderId, desc: &ComputePassDescriptor) -> Self {
        Self {
            base: BasePass::new(&desc.label),
            parent_id,
        }
    }

    pub fn parent_id(&self) -> id::CommandEncoderId {
        self.parent_id
    }

    #[cfg(feature = "trace")]
    pub fn into_command(self) -> crate::device::trace::Command {
        crate::device::trace::Command::RunComputePass { base: self.base }
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "ComputePass {{ encoder_id: {:?}, data: {:?} commands and {:?} dynamic offsets }}",
            self.parent_id,
            self.base.commands.len(),
            self.base.dynamic_offsets.len()
        )
    }
}

#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a> {
    pub label: Label<'a>,
}
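
// A minimal recording sketch (not code in this module; `encoder_id` is an
// assumption for illustration, and `Label` is wgpu-core's optional
// `Cow<str>` alias):
//
//     use std::borrow::Cow;
//
//     let desc = ComputePassDescriptor {
//         label: Some(Cow::Borrowed("example pass")),
//     };
//     let pass = ComputePass::new(encoder_id, &desc);
//     assert_eq!(pass.parent_id(), encoder_id);
//
// Commands are only buffered in `pass.base` at this point (see the
// `compute_ffi` module below); validation and HAL encoding happen later in
// `command_encoder_run_compute_pass`.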

#[derive(Clone, Debug, Error, PartialEq)]
pub enum DispatchError {
    #[error("compute pipeline must be set")]
    MissingPipeline,
    #[error("current compute pipeline has a layout which is incompatible with a currently set bind group, first differing at entry index {index}")]
    IncompatibleBindGroup {
        index: u32,
        //expected: BindGroupLayoutId,
        //provided: Option<(BindGroupLayoutId, BindGroupId)>,
    },
    #[error(
        "each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Encoder(#[from] CommandEncoderError),
    #[error("bind group {0:?} is invalid")]
    InvalidBindGroup(id::BindGroupId),
    #[error("bind group index {index} is greater than the device's requested `max_bind_group` limit {max}")]
    BindGroupIndexOutOfRange { index: u8, max: u32 },
    #[error("compute pipeline {0:?} is invalid")]
    InvalidPipeline(id::ComputePipelineId),
    #[error("QuerySet {0:?} is invalid")]
    InvalidQuerySet(id::QuerySetId),
    #[error("indirect buffer {0:?} is invalid or destroyed")]
    InvalidIndirectBuffer(id::BufferId),
    #[error("indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error("buffer {0:?} is invalid or destroyed")]
    InvalidBuffer(id::BufferId),
    #[error(transparent)]
    ResourceUsageConflict(#[from] UsageConflict),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error("cannot pop debug group, because number of pushed debug groups is zero")]
    InvalidPopDebugGroup,
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
}

impl PrettyError for ComputePassErrorInner {
    fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
        fmt.error(self);
        match *self {
            Self::InvalidBindGroup(id) => {
                fmt.bind_group_label(&id);
            }
            Self::InvalidPipeline(id) => {
                fmt.compute_pipeline_label(&id);
            }
            Self::InvalidIndirectBuffer(id) => {
                fmt.buffer_label(&id);
            }
            _ => {}
        };
    }
}
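
// How the pieces compose (a sketch, not code in this module): every fallible
// step inside a pass maps its error through `MapPassErr` (implemented below)
// so the scope of the failing command is preserved alongside the cause:
//
//     let scope = PassErrorScope::Dispatch { indirect: false, pipeline: None };
//     let err: ComputePassError = Err::<(), _>(DispatchError::MissingPipeline)
//         .map_pass_err(scope)
//         .unwrap_err();
//     // err.scope identifies the command; err.inner holds the cause,
//     // here ComputePassErrorInner::Dispatch(DispatchError::MissingPipeline).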

/// Error encountered when performing a compute pass, together with the scope
/// in which it occurred.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    inner: ComputePassErrorInner,
}

impl PrettyError for ComputePassError {
    fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
        // This error is a wrapper for the inner error,
        // but the scope has useful labels.
        fmt.error(self);
        self.scope.fmt_pretty(fmt);
    }
}

impl<T, E> MapPassErr<T, ComputePassError> for Result<T, E>
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> Result<T, ComputePassError> {
        self.map_err(|inner| ComputePassError {
            scope,
            inner: inner.into(),
        })
    }
}

#[derive(Debug)]
struct State {
    binder: Binder,
    pipeline: StateChange<id::ComputePipelineId>,
    trackers: StatefulTrackerSubset,
    debug_scope_depth: u32,
}

impl State {
    fn is_ready(&self) -> Result<(), DispatchError> {
        let bind_mask = self.binder.invalid_mask();
        if bind_mask != 0 {
            //let (expected, provided) = self.binder.entries[index as usize].info();
            return Err(DispatchError::IncompatibleBindGroup {
                index: bind_mask.trailing_zeros(),
            });
        }
        if self.pipeline.is_unset() {
            return Err(DispatchError::MissingPipeline);
        }
        self.binder.check_late_buffer_bindings()?;

        Ok(())
    }
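
    // How `invalid_mask` maps to the reported index (illustration): the mask
    // has one bit per bind group slot, set when that slot is incompatible
    // with the pipeline layout, so the lowest set bit is the first
    // differing entry:
    //
    //     let bind_mask: u32 = 0b0100; // slot 2 is incompatible
    //     assert_eq!(bind_mask.trailing_zeros(), 2);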

    fn flush_states<A: HalApi>(
        &mut self,
        raw_encoder: &mut A::CommandEncoder,
        base_trackers: &mut TrackerSet,
        bind_group_guard: &Storage<BindGroup<A>, id::BindGroupId>,
        buffer_guard: &Storage<Buffer<A>, id::BufferId>,
        texture_guard: &Storage<Texture<A>, id::TextureId>,
    ) -> Result<(), UsageConflict> {
        for id in self.binder.list_active() {
            self.trackers.merge_extend(&bind_group_guard[id].used)?;
            //Note: stateless trackers are not merged: the lifetime reference
            // is held to the bind group itself.
        }

        log::trace!("Encoding dispatch barriers");

        CommandBuffer::insert_barriers(
            raw_encoder,
            base_trackers,
            &self.trackers.buffers,
            &self.trackers.textures,
            buffer_guard,
            texture_guard,
        );

        self.trackers.clear();
        Ok(())
    }
}

// Common routines between render/compute

impl<G: GlobalIdentityHandlerFactory> Global<G> {
    pub fn command_encoder_run_compute_pass<A: HalApi>(
        &self,
        encoder_id: id::CommandEncoderId,
        pass: &ComputePass,
    ) -> Result<(), ComputePassError> {
        self.command_encoder_run_compute_pass_impl::<A>(encoder_id, pass.base.as_ref())
    }

    #[doc(hidden)]
    pub fn command_encoder_run_compute_pass_impl<A: HalApi>(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePassRef<ComputeCommand>,
    ) -> Result<(), ComputePassError> {
        profiling::scope!("run_compute_pass", "CommandEncoder");
        let init_scope = PassErrorScope::Pass(encoder_id);

        let hub = A::hub(self);
        let mut token = Token::root();

        let (device_guard, mut token) = hub.devices.read(&mut token);

        let (mut cmd_buf_guard, mut token) = hub.command_buffers.write(&mut token);
        let cmd_buf = CommandBuffer::get_encoder_mut(&mut *cmd_buf_guard, encoder_id)
            .map_pass_err(init_scope)?;
        // will be reset to `Recording` if the pass is encoded without errors
        cmd_buf.status = CommandEncoderStatus::Error;
        let raw = cmd_buf.encoder.open();

        let device = &device_guard[cmd_buf.device_id.value];

        #[cfg(feature = "trace")]
        if let Some(ref mut list) = cmd_buf.commands {
            list.push(crate::device::trace::Command::RunComputePass {
                base: BasePass::from_ref(base),
            });
        }

        let (_, mut token) = hub.render_bundles.read(&mut token);
        let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token);
        let (bind_group_guard, mut token) = hub.bind_groups.read(&mut token);
        let (pipeline_guard, mut token) = hub.compute_pipelines.read(&mut token);
        let (query_set_guard, mut token) = hub.query_sets.read(&mut token);
        let (buffer_guard, mut token) = hub.buffers.read(&mut token);
        let (texture_guard, _) = hub.textures.read(&mut token);

        let mut state = State {
            binder: Binder::new(),
            pipeline: StateChange::new(),
            trackers: StatefulTrackerSubset::new(A::VARIANT),
            debug_scope_depth: 0,
        };
        let mut temp_offsets = Vec::new();
        let mut dynamic_offset_count = 0;
        let mut string_offset = 0;
        let mut active_query = None;

        let hal_desc = hal::ComputePassDescriptor { label: base.label };
        unsafe {
            raw.begin_compute_pass(&hal_desc);
        }
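
        // Shape of the replay loop below, in outline (comments only; the real
        // logic follows): `Set*` commands only update `state` and record raw
        // bindings, while each dispatch first validates and flushes the
        // accumulated resource state:
        //
        //     SetPipeline / SetBindGroup / SetPushConstant -> update `state`
        //     Dispatch -> state.is_ready()?        // pipeline + bind groups set
        //                 state.flush_states(..)?  // merge usages, emit barriers
        //                 raw.dispatch(groups)     // the actual HAL call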

        // Immediate texture inits required because of prior discards.
        // Need to be inserted before texture reads.
        let mut pending_discard_init_fixups = SurfacesInDiscardState::new();

        for command in base.commands {
            match *command {
                ComputeCommand::SetBindGroup {
                    index,
                    num_dynamic_offsets,
                    bind_group_id,
                } => {
                    let scope = PassErrorScope::SetBindGroup(bind_group_id);

                    let max_bind_groups = cmd_buf.limits.max_bind_groups;
                    if (index as u32) >= max_bind_groups {
                        return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
                            index,
                            max: max_bind_groups,
                        })
                        .map_pass_err(scope);
                    }

                    temp_offsets.clear();
                    temp_offsets.extend_from_slice(
                        &base.dynamic_offsets[dynamic_offset_count
                            ..dynamic_offset_count + (num_dynamic_offsets as usize)],
                    );
                    dynamic_offset_count += num_dynamic_offsets as usize;

                    let bind_group = cmd_buf
                        .trackers
                        .bind_groups
                        .use_extend(&*bind_group_guard, bind_group_id, (), ())
                        .map_err(|_| ComputePassErrorInner::InvalidBindGroup(bind_group_id))
                        .map_pass_err(scope)?;
                    bind_group
                        .validate_dynamic_bindings(&temp_offsets, &cmd_buf.limits)
                        .map_pass_err(scope)?;

                    cmd_buf.buffer_memory_init_actions.extend(
                        bind_group.used_buffer_ranges.iter().filter_map(
                            |action| match buffer_guard.get(action.id) {
                                Ok(buffer) => buffer.initialization_status.check_action(action),
                                Err(_) => None,
                            },
                        ),
                    );

                    for action in bind_group.used_texture_ranges.iter() {
                        pending_discard_init_fixups.extend(
                            cmd_buf
                                .texture_memory_actions
                                .register_init_action(action, &texture_guard),
                        );
                    }

                    let pipeline_layout_id = state.binder.pipeline_layout_id;
                    let entries = state.binder.assign_group(
                        index as usize,
                        id::Valid(bind_group_id),
                        bind_group,
                        &temp_offsets,
                    );
                    if !entries.is_empty() {
                        let pipeline_layout =
                            &pipeline_layout_guard[pipeline_layout_id.unwrap()].raw;
                        for (i, e) in entries.iter().enumerate() {
                            let raw_bg = &bind_group_guard[e.group_id.as_ref().unwrap().value].raw;
                            unsafe {
                                raw.set_bind_group(
                                    pipeline_layout,
                                    index as u32 + i as u32,
                                    raw_bg,
                                    &e.dynamic_offsets,
                                );
                            }
                        }
                    }
                }
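                // How the dynamic-offset side channel above works
                // (illustration): `SetBindGroup` stores only a count; the
                // offsets themselves live contiguously in
                // `base.dynamic_offsets` and are consumed in recording order.
                // E.g. two commands with 2 and 1 offsets respectively:
                //
                //     base.dynamic_offsets: [256, 512, 0]
                //     cmd 0 (num_dynamic_offsets: 2) -> offsets [256, 512]
                //     cmd 1 (num_dynamic_offsets: 1) -> offsets [0]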
                ComputeCommand::SetPipeline(pipeline_id) => {
                    let scope = PassErrorScope::SetPipelineCompute(pipeline_id);

                    if state.pipeline.set_and_check_redundant(pipeline_id) {
                        continue;
                    }

                    let pipeline = cmd_buf
                        .trackers
                        .compute_pipes
                        .use_extend(&*pipeline_guard, pipeline_id, (), ())
                        .map_err(|_| ComputePassErrorInner::InvalidPipeline(pipeline_id))
                        .map_pass_err(scope)?;

                    unsafe {
                        raw.set_compute_pipeline(&pipeline.raw);
                    }

                    // Rebind resources
                    if state.binder.pipeline_layout_id != Some(pipeline.layout_id.value) {
                        let pipeline_layout = &pipeline_layout_guard[pipeline.layout_id.value];

                        let (start_index, entries) = state.binder.change_pipeline_layout(
                            &*pipeline_layout_guard,
                            pipeline.layout_id.value,
                            &pipeline.late_sized_buffer_groups,
                        );
                        if !entries.is_empty() {
                            for (i, e) in entries.iter().enumerate() {
                                let raw_bg =
                                    &bind_group_guard[e.group_id.as_ref().unwrap().value].raw;
                                unsafe {
                                    raw.set_bind_group(
                                        &pipeline_layout.raw,
                                        start_index as u32 + i as u32,
                                        raw_bg,
                                        &e.dynamic_offsets,
                                    );
                                }
                            }
                        }

                        // Clear push constant ranges
                        let non_overlapping = super::bind::compute_nonoverlapping_ranges(
                            &pipeline_layout.push_constant_ranges,
                        );
                        for range in non_overlapping {
                            let offset = range.range.start;
                            let size_bytes = range.range.end - offset;
                            super::push_constant_clear(
                                offset,
                                size_bytes,
                                |clear_offset, clear_data| unsafe {
                                    raw.set_push_constants(
                                        &pipeline_layout.raw,
                                        wgt::ShaderStages::COMPUTE,
                                        clear_offset,
                                        clear_data,
                                    );
                                },
                            );
                        }
                    }
                }
                ComputeCommand::SetPushConstant {
                    offset,
                    size_bytes,
                    values_offset,
                } => {
                    let scope = PassErrorScope::SetPushConstant;

                    let end_offset_bytes = offset + size_bytes;
                    let values_end_offset =
                        (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                    let data_slice =
                        &base.push_constant_data[(values_offset as usize)..values_end_offset];

                    let pipeline_layout_id = state
                        .binder
                        .pipeline_layout_id
                        //TODO: don't error here, lazily update the push constants
                        .ok_or(ComputePassErrorInner::Dispatch(
                            DispatchError::MissingPipeline,
                        ))
                        .map_pass_err(scope)?;
                    let pipeline_layout = &pipeline_layout_guard[pipeline_layout_id];

                    pipeline_layout
                        .validate_push_constant_ranges(
                            wgt::ShaderStages::COMPUTE,
                            offset,
                            end_offset_bytes,
                        )
                        .map_pass_err(scope)?;

                    unsafe {
                        raw.set_push_constants(
                            &pipeline_layout.raw,
                            wgt::ShaderStages::COMPUTE,
                            offset,
                            data_slice,
                        );
                    }
                }
                ComputeCommand::Dispatch(groups) => {
                    let scope = PassErrorScope::Dispatch {
                        indirect: false,
                        pipeline: state.pipeline.last_state,
                    };

                    fixup_discarded_surfaces(
                        pending_discard_init_fixups.drain(..),
                        raw,
                        &texture_guard,
                        &mut cmd_buf.trackers.textures,
                        device,
                    );

                    state.is_ready().map_pass_err(scope)?;
                    state
                        .flush_states(
                            raw,
                            &mut cmd_buf.trackers,
                            &*bind_group_guard,
                            &*buffer_guard,
                            &*texture_guard,
                        )
                        .map_pass_err(scope)?;

                    let groups_size_limit = cmd_buf.limits.max_compute_workgroups_per_dimension;

                    if groups[0] > groups_size_limit
                        || groups[1] > groups_size_limit
                        || groups[2] > groups_size_limit
                    {
                        return Err(ComputePassErrorInner::Dispatch(
                            DispatchError::InvalidGroupSize {
                                current: groups,
                                limit: groups_size_limit,
                            },
                        ))
                        .map_pass_err(scope);
                    }

                    unsafe {
                        raw.dispatch(groups);
                    }
                }
                ComputeCommand::DispatchIndirect { buffer_id, offset } => {
                    let scope = PassErrorScope::Dispatch {
                        indirect: true,
                        pipeline: state.pipeline.last_state,
                    };

                    state.is_ready().map_pass_err(scope)?;

                    device
                        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)
                        .map_pass_err(scope)?;

                    let indirect_buffer = state
                        .trackers
                        .buffers
                        .use_extend(&*buffer_guard, buffer_id, (), hal::BufferUses::INDIRECT)
                        .map_err(|_| ComputePassErrorInner::InvalidIndirectBuffer(buffer_id))
                        .map_pass_err(scope)?;
                    check_buffer_usage(indirect_buffer.usage, wgt::BufferUsages::INDIRECT)
                        .map_pass_err(scope)?;

                    let end_offset = offset + mem::size_of::<wgt::DispatchIndirectArgs>() as u64;
                    if end_offset > indirect_buffer.size {
                        return Err(ComputePassErrorInner::IndirectBufferOverrun {
                            offset,
                            end_offset,
                            buffer_size: indirect_buffer.size,
                        })
                        .map_pass_err(scope);
                    }

                    let buf_raw = indirect_buffer
                        .raw
                        .as_ref()
                        .ok_or(ComputePassErrorInner::InvalidIndirectBuffer(buffer_id))
                        .map_pass_err(scope)?;

                    let stride = 3 * 4; // 3 integers, x/y/z group size
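
                    // Layout of the args read at `offset` (illustration;
                    // matches `wgt::DispatchIndirectArgs`, three u32s):
                    //
                    //     offset + 0 : x group count
                    //     offset + 4 : y group count
                    //     offset + 8 : z group count   (stride = 12 bytes)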
                    cmd_buf.buffer_memory_init_actions.extend(
                        indirect_buffer.initialization_status.create_action(
                            buffer_id,
                            offset..(offset + stride),
                            MemoryInitKind::NeedsInitializedMemory,
                        ),
                    );

                    state
                        .flush_states(
                            raw,
                            &mut cmd_buf.trackers,
                            &*bind_group_guard,
                            &*buffer_guard,
                            &*texture_guard,
                        )
                        .map_pass_err(scope)?;
                    unsafe {
                        raw.dispatch_indirect(buf_raw, offset);
                    }
                }
                ComputeCommand::PushDebugGroup { color: _, len } => {
                    state.debug_scope_depth += 1;
                    let label =
                        str::from_utf8(&base.string_data[string_offset..string_offset + len])
                            .unwrap();
                    string_offset += len;
                    unsafe {
                        raw.begin_debug_marker(label);
                    }
                }
                ComputeCommand::PopDebugGroup => {
                    let scope = PassErrorScope::PopDebugGroup;

                    if state.debug_scope_depth == 0 {
                        return Err(ComputePassErrorInner::InvalidPopDebugGroup)
                            .map_pass_err(scope);
                    }
                    state.debug_scope_depth -= 1;
                    unsafe {
                        raw.end_debug_marker();
                    }
                }
                ComputeCommand::InsertDebugMarker { color: _, len } => {
                    let label =
                        str::from_utf8(&base.string_data[string_offset..string_offset + len])
                            .unwrap();
                    string_offset += len;
                    unsafe { raw.insert_debug_marker(label) }
                }
                ComputeCommand::WriteTimestamp {
                    query_set_id,
                    query_index,
                } => {
                    let scope = PassErrorScope::WriteTimestamp;

                    let query_set = cmd_buf
                        .trackers
                        .query_sets
                        .use_extend(&*query_set_guard, query_set_id, (), ())
                        .map_err(|e| match e {
                            UseExtendError::InvalidResource => {
                                ComputePassErrorInner::InvalidQuerySet(query_set_id)
                            }
                            _ => unreachable!(),
                        })
                        .map_pass_err(scope)?;

                    query_set
                        .validate_and_write_timestamp(raw, query_set_id, query_index, None)
                        .map_pass_err(scope)?;
                }
                ComputeCommand::BeginPipelineStatisticsQuery {
                    query_set_id,
                    query_index,
                } => {
                    let scope = PassErrorScope::BeginPipelineStatisticsQuery;

                    let query_set = cmd_buf
                        .trackers
                        .query_sets
                        .use_extend(&*query_set_guard, query_set_id, (), ())
                        .map_err(|e| match e {
                            UseExtendError::InvalidResource => {
                                ComputePassErrorInner::InvalidQuerySet(query_set_id)
                            }
                            _ => unreachable!(),
                        })
                        .map_pass_err(scope)?;

                    query_set
                        .validate_and_begin_pipeline_statistics_query(
                            raw,
                            query_set_id,
                            query_index,
                            None,
                            &mut active_query,
                        )
                        .map_pass_err(scope)?;
                }
                ComputeCommand::EndPipelineStatisticsQuery => {
                    let scope = PassErrorScope::EndPipelineStatisticsQuery;

                    end_pipeline_statistics_query(raw, &*query_set_guard, &mut active_query)
                        .map_pass_err(scope)?;
                }
            }
        }

        unsafe {
            raw.end_compute_pass();
        }
        cmd_buf.status = CommandEncoderStatus::Recording;
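
        // Status contract (illustration): the encoder enters this function in
        // the `Error` state (set near the top) and is only restored here, so
        // an early `?` return leaves it poisoned:
        //
        //     global.command_encoder_run_compute_pass::<A>(id, &bad_pass)?; // fails
        //     // encoder status is still Error; finishing the encoder then
        //     // surfaces the failure instead of submitting partial work.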

        // There can be entries left in pending_discard_init_fixups if a bind
        // group was set, but not used (i.e. no Dispatch occurred). However, we
        // already altered the discard/init_action state on this cmd_buf, so we
        // need to apply the promised changes.
        fixup_discarded_surfaces(
            pending_discard_init_fixups.into_iter(),
            raw,
            &texture_guard,
            &mut cmd_buf.trackers.textures,
            device,
        );

        Ok(())
    }
}

pub mod compute_ffi {
    use super::{ComputeCommand, ComputePass};
    use crate::{id, RawString};
    use std::{convert::TryInto, ffi, slice};
    use wgt::{BufferAddress, DynamicOffset};

    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given pointer is
    /// valid for `offset_length` elements.
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_set_bind_group(
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: id::BindGroupId,
        offsets: *const DynamicOffset,
        offset_length: usize,
    ) {
        pass.base.commands.push(ComputeCommand::SetBindGroup {
            index: index.try_into().unwrap(),
            num_dynamic_offsets: offset_length.try_into().unwrap(),
            bind_group_id,
        });
        if offset_length != 0 {
            pass.base
                .dynamic_offsets
                .extend_from_slice(slice::from_raw_parts(offsets, offset_length));
        }
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_set_pipeline(
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::SetPipeline(pipeline_id));
    }
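
    // Calling the recorder from Rust (a sketch; `pass` and `bind_group_id`
    // are assumed to exist, and the raw-pointer contract in the Safety
    // section above applies):
    //
    //     let offsets: [DynamicOffset; 2] = [256, 512];
    //     unsafe {
    //         wgpu_compute_pass_set_bind_group(
    //             &mut pass, 0, bind_group_id,
    //             offsets.as_ptr(), offsets.len(),
    //         );
    //     }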

    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given pointer is
    /// valid for `size_bytes` bytes.
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_set_push_constant(
        pass: &mut ComputePass,
        offset: u32,
        size_bytes: u32,
        data: *const u8,
    ) {
        assert_eq!(
            offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
            0,
            "Push constant offset must be aligned to 4 bytes."
        );
        assert_eq!(
            size_bytes & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
            0,
            "Push constant size must be aligned to 4 bytes."
        );
        let data_slice = slice::from_raw_parts(data, size_bytes as usize);
        let value_offset = pass.base.push_constant_data.len().try_into().expect(
            "Ran out of push constant space. Don't set 4gb of push constants per ComputePass.",
        );

        pass.base.push_constant_data.extend(
            data_slice
                .chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        pass.base.commands.push(ComputeCommand::SetPushConstant {
            offset,
            size_bytes,
            values_offset: value_offset,
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_dispatch(
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::Dispatch([groups_x, groups_y, groups_z]));
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_dispatch_indirect(
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::DispatchIndirect { buffer_id, offset });
    }

    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given `label`
    /// is a valid null-terminated string.
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_push_debug_group(
        pass: &mut ComputePass,
        label: RawString,
        color: u32,
    ) {
        let bytes = ffi::CStr::from_ptr(label).to_bytes();
        pass.base.string_data.extend_from_slice(bytes);

        pass.base.commands.push(ComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_pop_debug_group(pass: &mut ComputePass) {
        pass.base.commands.push(ComputeCommand::PopDebugGroup);
    }
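
    // Pairing push/pop from Rust (a sketch; labels must be NUL-terminated,
    // hence the CString):
    //
    //     use std::ffi::CString;
    //
    //     let label = CString::new("prefix sum").unwrap();
    //     unsafe { wgpu_compute_pass_push_debug_group(&mut pass, label.as_ptr(), 0) };
    //     // ... dispatches ...
    //     wgpu_compute_pass_pop_debug_group(&mut pass);
    //
    // An unmatched pop is caught later, at encode time, as
    // `ComputePassErrorInner::InvalidPopDebugGroup`.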

    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given `label`
    /// is a valid null-terminated string.
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_insert_debug_marker(
        pass: &mut ComputePass,
        label: RawString,
        color: u32,
    ) {
        let bytes = ffi::CStr::from_ptr(label).to_bytes();
        pass.base.string_data.extend_from_slice(bytes);

        pass.base.commands.push(ComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_write_timestamp(
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) {
        pass.base.commands.push(ComputeCommand::WriteTimestamp {
            query_set_id,
            query_index,
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_begin_pipeline_statistics_query(
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::BeginPipelineStatisticsQuery {
                query_set_id,
                query_index,
            });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_end_pipeline_statistics_query(pass: &mut ComputePass) {
        pass.base
            .commands
            .push(ComputeCommand::EndPipelineStatisticsQuery);
    }
}
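
// End-to-end recording sketch (illustration only; `global`, `encoder_id`,
// `pipeline_id`, and `bind_group_id` are assumed to exist, and `A` is the
// active HAL backend type):
//
//     let mut pass = ComputePass::new(encoder_id, &ComputePassDescriptor::default());
//     compute_ffi::wgpu_compute_pass_set_pipeline(&mut pass, pipeline_id);
//     unsafe {
//         // null + zero length is fine: no offsets are read in that case
//         compute_ffi::wgpu_compute_pass_set_bind_group(
//             &mut pass, 0, bind_group_id, std::ptr::null(), 0,
//         );
//     }
//     compute_ffi::wgpu_compute_pass_dispatch(&mut pass, 64, 1, 1);
//     global.command_encoder_run_compute_pass::<A>(encoder_id, &pass)?;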