/*! Module analyzer.

Figures out the following properties:
  - control flow uniformity
  - texture/sampler pairs
  - expression reference counts
  - global variable usage (read / write / query)
!*/
8
9 use super::{CallError, ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
10 use crate::{
11 arena::{Arena, Handle},
12 proc::{ResolveContext, TypeResolution},
13 };
14 use std::ops;
15
/// Handle of the expression that makes a value non-uniform, or `None` if the
/// value is uniform.
pub type NonUniformResult = Option<Handle<crate::Expression>>;
17
bitflags::bitflags! {
    /// Kinds of expressions that require uniform control flow.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    pub struct UniformityRequirements: u8 {
        /// Required by `Statement::Barrier`.
        const WORK_GROUP_BARRIER = 0x1;
        /// Required by explicit derivatives (`Expression::Derivative`).
        const DERIVATIVE = 0x2;
        /// Required by `Expression::ImageSample` when its level relies on
        /// implicit derivatives.
        const IMPLICIT_LEVEL = 0x4;
    }
}
28
/// Uniform control flow characteristics.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(test, derive(PartialEq))]
pub struct Uniformity {
    /// A child expression with non-uniform result.
    ///
    /// This means, when the relevant invocations are scheduled on a compute unit,
    /// they have to use vector registers to store an individual value
    /// per invocation.
    ///
    /// Whenever the control flow is conditioned on such value,
    /// the hardware needs to keep track of the mask of invocations,
    /// and process all branches of the control flow.
    ///
    /// Any operations that depend on non-uniform results also produce non-uniform results.
    ///
    /// `None` means the value is uniform.
    pub non_uniform_result: NonUniformResult,
    /// If this expression requires uniform control flow, store the reason here.
    pub requirements: UniformityRequirements,
}
50
51 impl Uniformity {
new() -> Self52 fn new() -> Self {
53 Uniformity {
54 non_uniform_result: None,
55 requirements: UniformityRequirements::empty(),
56 }
57 }
58 }
59
bitflags::bitflags! {
    /// Ways in which control flow may leave a sequence of statements early.
    struct ExitFlags: u8 {
        /// Control flow may return from the function, which makes all the
        /// subsequent statements within the current function (only!)
        /// execute in non-uniform control flow.
        const MAY_RETURN = 0x1;
        /// Control flow may be killed. Anything after `Statement::Kill` is
        /// considered inside non-uniform context.
        const MAY_KILL = 0x2;
    }
}
71
/// Uniformity characteristics of a function, or of a block of statements
/// within it.
#[cfg_attr(test, derive(Debug, PartialEq))]
struct FunctionUniformity {
    /// Non-uniformity of returned values, plus the union of the uniformity
    /// requirements of the analyzed statements.
    result: Uniformity,
    /// How control flow may exit early from the analyzed statements.
    exit: ExitFlags,
}
78
79 impl ops::BitOr for FunctionUniformity {
80 type Output = Self;
bitor(self, other: Self) -> Self81 fn bitor(self, other: Self) -> Self {
82 FunctionUniformity {
83 result: Uniformity {
84 non_uniform_result: self
85 .result
86 .non_uniform_result
87 .or(other.result.non_uniform_result),
88 requirements: self.result.requirements | other.result.requirements,
89 },
90 exit: self.exit | other.exit,
91 }
92 }
93 }
94
95 impl FunctionUniformity {
new() -> Self96 fn new() -> Self {
97 FunctionUniformity {
98 result: Uniformity::new(),
99 exit: ExitFlags::empty(),
100 }
101 }
102
103 /// Returns a disruptor based on the stored exit flags, if any.
exit_disruptor(&self) -> Option<UniformityDisruptor>104 fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
105 if self.exit.contains(ExitFlags::MAY_RETURN) {
106 Some(UniformityDisruptor::Return)
107 } else if self.exit.contains(ExitFlags::MAY_KILL) {
108 Some(UniformityDisruptor::Discard)
109 } else {
110 None
111 }
112 }
113 }
114
bitflags::bitflags! {
    /// Indicates how a global variable is used.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    pub struct GlobalUse: u8 {
        /// Data will be read from the variable.
        const READ = 0x1;
        /// Data will be written to the variable.
        const WRITE = 0x2;
        /// The information about the data is queried
        /// (e.g. by `ImageQuery` or `ArrayLength`).
        const QUERY = 0x4;
    }
}
128
/// An image/sampler pair that is used together by an `Expression::ImageSample`.
///
/// Both handles refer to global variables of the module.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
    /// The sampled image global.
    pub image: Handle<crate::GlobalVariable>,
    /// The sampler global used with `image`.
    pub sampler: Handle<crate::GlobalVariable>,
}
136
/// Analysis results for a single expression.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ExpressionInfo {
    /// Uniformity of the value and the control-flow requirements it imposes.
    pub uniformity: Uniformity,
    /// Number of references to this expression from other expressions and
    /// statements.
    pub ref_count: usize,
    /// Global variable this expression refers to, directly or through
    /// `Access`/`AccessIndex`; used to attribute reads/writes/queries to globals.
    assignable_global: Option<Handle<crate::GlobalVariable>>,
    /// Resolved type of the expression.
    pub ty: TypeResolution,
}
146
147 impl ExpressionInfo {
new() -> Self148 fn new() -> Self {
149 ExpressionInfo {
150 uniformity: Uniformity::new(),
151 ref_count: 0,
152 assignable_global: None,
153 // this doesn't matter at this point, will be overwritten
154 ty: TypeResolution::Value(crate::TypeInner::Scalar {
155 kind: crate::ScalarKind::Bool,
156 width: 0,
157 }),
158 }
159 }
160 }
161
/// Analysis results for a single function: uniformity, global variable usage,
/// sampling pairs, and per-expression information.
#[derive(Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
    /// Validation flags.
    flags: ValidationFlags,
    /// Set of shader stages where calling this function is valid.
    pub available_stages: ShaderStages,
    /// Uniformity characteristics.
    pub uniformity: Uniformity,
    /// Function may kill the invocation.
    pub may_kill: bool,
    /// Set of image-sampler pairs used with sampling.
    pub sampling_set: crate::FastHashSet<SamplingKey>,
    /// Vector of global variable usages.
    ///
    /// Each item corresponds to a global variable in the module.
    global_uses: Box<[GlobalUse]>,
    /// Vector of expression infos.
    ///
    /// Each item corresponds to an expression in the function.
    expressions: Box<[ExpressionInfo]>,
}
185
186 impl FunctionInfo {
global_variable_count(&self) -> usize187 pub fn global_variable_count(&self) -> usize {
188 self.global_uses.len()
189 }
expression_count(&self) -> usize190 pub fn expression_count(&self) -> usize {
191 self.expressions.len()
192 }
dominates_global_use(&self, other: &Self) -> bool193 pub fn dominates_global_use(&self, other: &Self) -> bool {
194 for (self_global_uses, other_global_uses) in
195 self.global_uses.iter().zip(other.global_uses.iter())
196 {
197 if !self_global_uses.contains(*other_global_uses) {
198 return false;
199 }
200 }
201 true
202 }
203 }
204
205 impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
206 type Output = GlobalUse;
index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse207 fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
208 &self.global_uses[handle.index()]
209 }
210 }
211
212 impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
213 type Output = ExpressionInfo;
index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo214 fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
215 &self.expressions[handle.index()]
216 }
217 }
218
/// Disruptor of the uniform control flow.
///
/// Reported as the cause when an expression that requires uniform control
/// flow is emitted in non-uniform control flow.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum UniformityDisruptor {
    /// Control flow is conditioned on the given non-uniform expression.
    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
    Expression(Handle<crate::Expression>),
    /// An earlier statement may have returned from the function.
    #[error("There is a Return earlier in the control flow of the function")]
    Return,
    /// An earlier statement (or called function) may have killed the invocation.
    #[error("There is a Discard earlier in the entry point across all called functions")]
    Discard,
}
230
impl FunctionInfo {
    /// Adds a reference to an expression.
    ///
    /// Bumps the expression's reference count and, if the expression refers to a
    /// global variable, records `global_use` for that global. Returns the
    /// expression's non-uniform result, if any.
    #[must_use]
    fn add_ref_impl(
        &mut self,
        handle: Handle<crate::Expression>,
        global_use: GlobalUse,
    ) -> NonUniformResult {
        let info = &mut self.expressions[handle.index()];
        info.ref_count += 1;
        // mark the used global with the given kind of use (read, write or query)
        if let Some(global) = info.assignable_global {
            self.global_uses[global.index()] |= global_use;
        }
        info.uniformity.non_uniform_result
    }

    /// Adds a value-type reference to an expression.
    ///
    /// Any global variable the expression refers to is marked as `GlobalUse::READ`.
    #[must_use]
    fn add_ref(&mut self, handle: Handle<crate::Expression>) -> NonUniformResult {
        self.add_ref_impl(handle, GlobalUse::READ)
    }

    /// Adds a potentially assignable reference to an expression.
    /// These are destinations for `Store` and `ImageStore` statements,
    /// which can transit through `Access` and `AccessIndex`.
    ///
    /// No global use is recorded here. Instead, the global variable that the
    /// referenced expression points to (if any) is written into
    /// `assignable_global`, so the referencing expression inherits it.
    #[must_use]
    fn add_assignable_ref(
        &mut self,
        handle: Handle<crate::Expression>,
        assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
    ) -> NonUniformResult {
        let info = &mut self.expressions[handle.index()];
        info.ref_count += 1;
        // propagate the assignable global up the chain, till it either hits
        // a value-type expression, or the assignment statement.
        if let Some(global) = info.assignable_global {
            // `process_expression` starts each expression with a `None` slot and
            // makes at most one assignable reference per expression, so the slot
            // cannot already be filled here.
            if let Some(_old) = assignable_global.replace(global) {
                unreachable!()
            }
        }
        info.uniformity.non_uniform_result
    }

    /// Inherit information from a called function.
    ///
    /// Merges the callee's sampling pairs and global variable uses into `self`.
    /// Returns the callee's uniformity, with `ExitFlags::MAY_KILL` set if the
    /// callee may kill the invocation.
    fn process_call(&mut self, info: &Self) -> FunctionUniformity {
        for key in info.sampling_set.iter() {
            self.sampling_set.insert(key.clone());
        }
        for (mine, other) in self.global_uses.iter_mut().zip(info.global_uses.iter()) {
            *mine |= *other;
        }
        FunctionUniformity {
            // NOTE(review): `info.uniformity.non_uniform_result` is a handle into
            // the *callee's* expression arena, not this function's; confirm that
            // consumers never resolve it against the caller's expressions.
            result: info.uniformity.clone(),
            exit: if info.may_kill {
                ExitFlags::MAY_KILL
            } else {
                ExitFlags::empty()
            },
        }
    }

    /// Computes the expression info and stores it in `self.expressions`.
    /// Also, bumps the reference counts on dependent expressions.
    ///
    /// # Errors
    /// - `ExpressionError::ExpectedGlobalVariable` if the image or sampler of an
    ///   `ImageSample` is not a `GlobalVariable` expression.
    /// - `ExpressionError::CallToUndeclaredFunction` if a `Call` refers to a
    ///   function with no entry in `other_functions`.
    /// - Errors propagated from `resolve_context.resolve`.
    // `Option::or` is deliberately used with eagerly evaluated `add_ref` calls:
    // every operand must have its reference count bumped, even when an earlier
    // operand is already known to be non-uniform.
    #[allow(clippy::or_fun_call)]
    fn process_expression(
        &mut self,
        handle: Handle<crate::Expression>,
        expression: &crate::Expression,
        expression_arena: &Arena<crate::Expression>,
        other_functions: &[FunctionInfo],
        type_arena: &Arena<crate::Type>,
        resolve_context: &ResolveContext,
    ) -> Result<(), ExpressionError> {
        use crate::{Expression as E, SampleLevel as Sl};

        // Filled by `GlobalVariable`, or inherited through `Access`/`AccessIndex`.
        let mut assignable_global = None;
        let uniformity = match *expression {
            E::Access { base, index } => Uniformity {
                non_uniform_result: self
                    .add_assignable_ref(base, &mut assignable_global)
                    .or(self.add_ref(index)),
                requirements: UniformityRequirements::empty(),
            },
            E::AccessIndex { base, .. } => Uniformity {
                non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
                requirements: UniformityRequirements::empty(),
            },
            // always uniform
            E::Constant(_) => Uniformity::new(),
            E::Splat { size: _, value } => Uniformity {
                non_uniform_result: self.add_ref(value),
                requirements: UniformityRequirements::empty(),
            },
            E::Swizzle { vector, .. } => Uniformity {
                non_uniform_result: self.add_ref(vector),
                requirements: UniformityRequirements::empty(),
            },
            E::Compose { ref components, .. } => {
                let non_uniform_result = components
                    .iter()
                    .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
                Uniformity {
                    non_uniform_result,
                    requirements: UniformityRequirements::empty(),
                }
            }
            // depends on the builtin or interpolation
            E::FunctionArgument(index) => {
                let arg = &resolve_context.arguments[index as usize];
                let uniform = match arg.binding {
                    Some(crate::Binding::BuiltIn(built_in)) => match built_in {
                        // per-polygon built-ins are uniform
                        crate::BuiltIn::FrontFacing
                        // per-work-group built-ins are uniform
                        | crate::BuiltIn::WorkGroupId
                        | crate::BuiltIn::WorkGroupSize => true,
                        _ => false,
                    },
                    // only flat inputs are uniform
                    Some(crate::Binding::Location {
                        interpolation: Some(crate::Interpolation::Flat),
                        ..
                    }) => true,
                    _ => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            // depends on the storage class
            E::GlobalVariable(gh) => {
                use crate::StorageClass as Sc;
                assignable_global = Some(gh);
                let var = &resolve_context.global_vars[gh];
                let uniform = match var.class {
                    // local data is non-uniform
                    Sc::Function | Sc::Private => false,
                    // workgroup memory is exclusively accessed by the group
                    Sc::WorkGroup => true,
                    // uniform data
                    Sc::Uniform | Sc::PushConstant => true,
                    // storage data is only uniform when read-only
                    Sc::Handle | Sc::Storage => {
                        !var.storage_access.contains(crate::StorageAccess::STORE)
                    }
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::LocalVariable(_) => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::Load { pointer } => Uniformity {
                non_uniform_result: self.add_ref(pointer),
                requirements: UniformityRequirements::empty(),
            },
            E::ImageSample {
                image,
                sampler,
                coordinate,
                array_index,
                offset: _,
                level,
                depth_ref,
            } => {
                // Record the image/sampler pair; both must be direct global variables.
                self.sampling_set.insert(SamplingKey {
                    image: match expression_arena[image] {
                        crate::Expression::GlobalVariable(var) => var,
                        _ => return Err(ExpressionError::ExpectedGlobalVariable),
                    },
                    sampler: match expression_arena[sampler] {
                        crate::Expression::GlobalVariable(var) => var,
                        _ => return Err(ExpressionError::ExpectedGlobalVariable),
                    },
                });
                // "nur" == "Non-Uniform Result"
                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let level_nur = match level {
                    Sl::Auto | Sl::Zero => None,
                    Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
                    Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
                };
                let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(sampler))
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(level_nur)
                        .or(dref_nur),
                    requirements: if level.implicit_derivatives() {
                        UniformityRequirements::IMPLICIT_LEVEL
                    } else {
                        UniformityRequirements::empty()
                    },
                }
            }
            E::ImageLoad {
                image,
                coordinate,
                array_index,
                index,
            } => {
                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let index_nur = index.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(index_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::ImageQuery { image, query } => {
                let query_nur = match query {
                    crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
                    _ => None,
                };
                Uniformity {
                    // the image is only queried, not read
                    non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::Unary { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            E::Binary { left, right, .. } => Uniformity {
                non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
                requirements: UniformityRequirements::empty(),
            },
            E::Select {
                condition,
                accept,
                reject,
            } => Uniformity {
                non_uniform_result: self
                    .add_ref(condition)
                    .or(self.add_ref(accept))
                    .or(self.add_ref(reject)),
                requirements: UniformityRequirements::empty(),
            },
            // explicit derivatives require uniform
            E::Derivative { expr, .. } => Uniformity {
                //Note: taking a derivative of a uniform doesn't make it non-uniform
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::DERIVATIVE,
            },
            E::Relational { argument, .. } => Uniformity {
                non_uniform_result: self.add_ref(argument),
                requirements: UniformityRequirements::empty(),
            },
            E::Math {
                arg, arg1, arg2, ..
            } => {
                let arg1_nur = arg1.and_then(|h| self.add_ref(h));
                let arg2_nur = arg2.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::As { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            E::Call(function) => {
                let fun = other_functions
                    .get(function.index())
                    .ok_or(ExpressionError::CallToUndeclaredFunction(function))?;
                // Only the uniformity is kept here; the callee's exit flags are dropped.
                self.process_call(fun).result
            }
            E::ArrayLength(expr) => Uniformity {
                // the array is only queried for its length, not read
                non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
                requirements: UniformityRequirements::empty(),
            },
        };

        let ty =
            resolve_context.resolve(expression, type_arena, |h| &self.expressions[h.index()].ty)?;
        // NOTE(review): this resets `ref_count` to zero, so any reference added to
        // this slot before it was processed is lost; operands are presumably
        // required to precede their users in the arena.
        self.expressions[handle.index()] = ExpressionInfo {
            uniformity,
            ref_count: 0,
            assignable_global,
            ty,
        };
        Ok(())
    }

    /// Analyzes the uniformity requirements of a block (as a sequence of statements).
    /// Returns the uniformity characteristics at the *function* level, i.e.
    /// whether or not the function needs to be called in uniform control flow,
    /// and whether the produced result is not disrupting the control flow.
    ///
    /// The parent control flow is uniform if `disruptor.is_none()`.
    ///
    /// # Errors
    /// - `FunctionError::NonUniformControlFlow` if any of the expressions in the
    ///   block require uniformity, but the current flow is non-uniform (only
    ///   checked when `ValidationFlags::CONTROL_FLOW_UNIFORMITY` is set).
    /// - `FunctionError::InvalidCall` if a called function has no entry in
    ///   `other_functions`.
    // Allowed for the cheap `Option::map` calls passed to `Option::or` below.
    #[allow(clippy::or_fun_call)]
    fn process_block(
        &mut self,
        statements: &[crate::Statement],
        other_functions: &[FunctionInfo],
        mut disruptor: Option<UniformityDisruptor>,
    ) -> Result<FunctionUniformity, FunctionError> {
        use crate::Statement as S;

        let mut combined_uniformity = FunctionUniformity::new();
        for statement in statements {
            let uniformity = match *statement {
                S::Emit(ref range) => {
                    let mut requirements = UniformityRequirements::empty();
                    for expr in range.clone() {
                        let req = self.expressions[expr.index()].uniformity.requirements;
                        if self
                            .flags
                            .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                            && !req.is_empty()
                        {
                            if let Some(cause) = disruptor {
                                return Err(FunctionError::NonUniformControlFlow(req, expr, cause));
                            }
                        }
                        requirements |= req;
                    }
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                // NOTE(review): `Break`/`Continue` are not treated as disruptors,
                // even when reached in non-uniform control flow.
                S::Break | S::Continue => FunctionUniformity::new(),
                S::Kill => FunctionUniformity {
                    result: Uniformity::new(),
                    exit: ExitFlags::MAY_KILL,
                },
                S::Barrier(_) => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                    },
                    exit: ExitFlags::empty(),
                },
                S::Block(ref b) => self.process_block(b, other_functions, disruptor)?,
                S::If {
                    condition,
                    ref accept,
                    ref reject,
                } => {
                    // A non-uniform condition makes both branches non-uniform.
                    let condition_nur = self.add_ref(condition);
                    let branch_disruptor =
                        disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                    let accept_uniformity =
                        self.process_block(accept, other_functions, branch_disruptor)?;
                    let reject_uniformity =
                        self.process_block(reject, other_functions, branch_disruptor)?;
                    accept_uniformity | reject_uniformity
                }
                S::Switch {
                    selector,
                    ref cases,
                    ref default,
                } => {
                    let selector_nur = self.add_ref(selector);
                    let branch_disruptor =
                        disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
                    let mut uniformity = FunctionUniformity::new();
                    let mut case_disruptor = branch_disruptor;
                    for case in cases.iter() {
                        let case_uniformity =
                            self.process_block(&case.body, other_functions, case_disruptor)?;
                        // A falling-through case passes its exit disruptor on to the
                        // next case; otherwise the next case starts afresh from the
                        // switch's own disruptor.
                        case_disruptor = if case.fall_through {
                            case_disruptor.or(case_uniformity.exit_disruptor())
                        } else {
                            branch_disruptor
                        };
                        uniformity = uniformity | case_uniformity;
                    }
                    // using the disruptor inherited from the last fall-through chain
                    let default_exit =
                        self.process_block(default, other_functions, case_disruptor)?;
                    uniformity | default_exit
                }
                S::Loop {
                    ref body,
                    ref continuing,
                } => {
                    let body_uniformity = self.process_block(body, other_functions, disruptor)?;
                    // An early exit in the body makes the continuing block non-uniform.
                    let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
                    let continuing_uniformity =
                        self.process_block(continuing, other_functions, continuing_disruptor)?;
                    body_uniformity | continuing_uniformity
                }
                S::Return { value } => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
                        requirements: UniformityRequirements::empty(),
                    },
                    //TODO: if we are in the uniform control flow, should this still be an exit flag?
                    exit: ExitFlags::MAY_RETURN,
                },
                // Here and below, the used expressions are already emitted,
                // and their results do not affect the function return value,
                // so we can ignore their non-uniformity.
                S::Store { pointer, value } => {
                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::ImageStore {
                    image,
                    coordinate,
                    array_index,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::WRITE);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(coordinate);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::Call {
                    function,
                    ref arguments,
                    result: _,
                } => {
                    for &argument in arguments {
                        let _ = self.add_ref(argument);
                    }
                    let info = other_functions.get(function.index()).ok_or(
                        FunctionError::InvalidCall {
                            function,
                            error: CallError::ForwardDeclaredFunction,
                        },
                    )?;
                    //Note: the result is validated by the Validator, not here
                    self.process_call(info)
                }
            };

            // Once a statement may return or kill, all following statements of
            // this block run in non-uniform control flow.
            disruptor = disruptor.or(uniformity.exit_disruptor());
            combined_uniformity = combined_uniformity | uniformity;
        }
        Ok(combined_uniformity)
    }
}
688
689 impl ModuleInfo {
690 /// Builds the `FunctionInfo` based on the function, and validates the
691 /// uniform control flow if required by the expressions of this function.
process_function( &self, fun: &crate::Function, module: &crate::Module, flags: ValidationFlags, ) -> Result<FunctionInfo, FunctionError>692 pub(super) fn process_function(
693 &self,
694 fun: &crate::Function,
695 module: &crate::Module,
696 flags: ValidationFlags,
697 ) -> Result<FunctionInfo, FunctionError> {
698 let mut info = FunctionInfo {
699 flags,
700 available_stages: ShaderStages::all(),
701 uniformity: Uniformity::new(),
702 may_kill: false,
703 sampling_set: crate::FastHashSet::default(),
704 global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
705 expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
706 };
707 let resolve_context = ResolveContext {
708 constants: &module.constants,
709 global_vars: &module.global_variables,
710 local_vars: &fun.local_variables,
711 functions: &module.functions,
712 arguments: &fun.arguments,
713 };
714
715 for (handle, expr) in fun.expressions.iter() {
716 if let Err(error) = info.process_expression(
717 handle,
718 expr,
719 &fun.expressions,
720 &self.functions,
721 &module.types,
722 &resolve_context,
723 ) {
724 return Err(FunctionError::Expression { handle, error });
725 }
726 }
727
728 let uniformity = info.process_block(&fun.body, &self.functions, None)?;
729 info.uniformity = uniformity.result;
730 info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
731
732 Ok(info)
733 }
734
get_entry_point(&self, index: usize) -> &FunctionInfo735 pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
736 &self.entry_points[index]
737 }
738 }
739
/// Exercises expression and block analysis on a hand-built function:
/// reference counting, global usage flags (READ/WRITE/QUERY), and the
/// detection of derivatives in non-uniform control flow.
#[test]
fn uniform_control_flow() {
    use crate::{Expression as E, Statement as S};

    let mut constant_arena = Arena::new();
    let constant = constant_arena.append(crate::Constant {
        name: None,
        specialization: None,
        inner: crate::ConstantInner::Scalar {
            width: 4,
            value: crate::ScalarValue::Uint(0),
        },
    });
    let mut type_arena = Arena::new();
    let ty = type_arena.append(crate::Type {
        name: None,
        inner: crate::TypeInner::Vector {
            size: crate::VectorSize::Bi,
            kind: crate::ScalarKind::Float,
            width: 4,
        },
    });
    let mut global_var_arena = Arena::new();
    // A writable `Handle` global is treated as non-uniform by the analyzer.
    let non_uniform_global = global_var_arena.append(crate::GlobalVariable {
        name: None,
        init: None,
        ty,
        class: crate::StorageClass::Handle,
        binding: None,
        storage_access: crate::StorageAccess::STORE,
    });
    // A `Uniform` global is treated as uniform.
    let uniform_global = global_var_arena.append(crate::GlobalVariable {
        name: None,
        init: None,
        ty,
        binding: None,
        class: crate::StorageClass::Uniform,
        storage_access: crate::StorageAccess::empty(),
    });

    let mut expressions = Arena::new();
    // checks the uniform control flow
    let constant_expr = expressions.append(E::Constant(constant));
    // checks the non-uniform control flow
    let derivative_expr = expressions.append(E::Derivative {
        axis: crate::DerivativeAxis::X,
        expr: constant_expr,
    });
    let emit_range_constant_derivative = expressions.range_from(0);
    let non_uniform_global_expr = expressions.append(E::GlobalVariable(non_uniform_global));
    let uniform_global_expr = expressions.append(E::GlobalVariable(uniform_global));
    let emit_range_globals = expressions.range_from(2);

    // checks the QUERY flag
    let query_expr = expressions.append(E::ArrayLength(uniform_global_expr));
    // checks the transitive WRITE flag
    let access_expr = expressions.append(E::AccessIndex {
        base: non_uniform_global_expr,
        index: 1,
    });
    let emit_range_query_access_globals = expressions.range_from(2);

    let mut info = FunctionInfo {
        flags: ValidationFlags::all(),
        available_stages: ShaderStages::all(),
        uniformity: Uniformity::new(),
        may_kill: false,
        sampling_set: crate::FastHashSet::default(),
        global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
        expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
    };
    let resolve_context = ResolveContext {
        constants: &constant_arena,
        global_vars: &global_var_arena,
        local_vars: &Arena::new(),
        functions: &Arena::new(),
        arguments: &[],
    };
    for (handle, expression) in expressions.iter() {
        info.process_expression(
            handle,
            expression,
            &expressions,
            &[],
            &type_arena,
            &resolve_context,
        )
        .unwrap();
    }
    // `access_expr` references `non_uniform_global_expr` (assignable, so no global use yet);
    // `query_expr` references `uniform_global_expr` with a QUERY use.
    assert_eq!(info[non_uniform_global_expr].ref_count, 1);
    assert_eq!(info[uniform_global_expr].ref_count, 1);
    assert_eq!(info[query_expr].ref_count, 0);
    assert_eq!(info[access_expr].ref_count, 0);
    assert_eq!(info[non_uniform_global], GlobalUse::empty());
    assert_eq!(info[uniform_global], GlobalUse::QUERY);

    // A derivative under a uniform condition is accepted.
    let stmt_emit1 = S::Emit(emit_range_globals.clone());
    let stmt_if_uniform = S::If {
        condition: uniform_global_expr,
        accept: Vec::new(),
        reject: vec![
            S::Emit(emit_range_constant_derivative.clone()),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ],
    };
    assert_eq!(
        info.process_block(&[stmt_emit1, stmt_if_uniform], &[], None),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::DERIVATIVE,
            },
            exit: ExitFlags::empty(),
        }),
    );
    // referenced by `derivative_expr` and by the `Store` pointer
    assert_eq!(info[constant_expr].ref_count, 2);
    // READ comes from the `If` condition
    assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);

    // A derivative under a non-uniform condition is rejected.
    let stmt_emit2 = S::Emit(emit_range_globals.clone());
    let stmt_if_non_uniform = S::If {
        condition: non_uniform_global_expr,
        accept: vec![
            S::Emit(emit_range_constant_derivative.clone()),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ],
        reject: Vec::new(),
    };
    assert_eq!(
        info.process_block(&[stmt_emit2, stmt_if_non_uniform], &[], None),
        Err(FunctionError::NonUniformControlFlow(
            UniformityRequirements::DERIVATIVE,
            derivative_expr,
            UniformityDisruptor::Expression(non_uniform_global_expr)
        )),
    );
    // The error fires at the `Emit`, before the `Store` could reference the derivative again.
    assert_eq!(info[derivative_expr].ref_count, 1);
    assert_eq!(info[non_uniform_global], GlobalUse::READ);

    // Returning a non-uniform value propagates it as the function result.
    let stmt_emit3 = S::Emit(emit_range_globals);
    let stmt_return_non_uniform = S::Return {
        value: Some(non_uniform_global_expr),
    };
    assert_eq!(
        info.process_block(
            &[stmt_emit3, stmt_return_non_uniform],
            &[],
            Some(UniformityDisruptor::Return)
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::MAY_RETURN,
        }),
    );
    // references: `AccessIndex` base, second `If` condition, `Return` value
    assert_eq!(info[non_uniform_global_expr].ref_count, 3);

    // Check that uniformity requirements reach through a pointer
    let stmt_emit4 = S::Emit(emit_range_query_access_globals);
    let stmt_assign = S::Store {
        pointer: access_expr,
        value: query_expr,
    };
    let stmt_return_pointer = S::Return {
        value: Some(access_expr),
    };
    let stmt_kill = S::Kill;
    assert_eq!(
        info.process_block(
            &[stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer],
            &[],
            Some(UniformityDisruptor::Discard)
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::all(),
        }),
    );
    // WRITE reaches the global through the `AccessIndex` store destination.
    assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
}
930