1 /*! Module analyzer.
2 
3 Figures out the following properties:
4   - control flow uniformity
5   - texture/sampler pairs
6   - expression reference counts
7 !*/
8 
9 use super::{CallError, ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
10 use crate::{
11     arena::{Arena, Handle},
12     proc::{ResolveContext, TypeResolution},
13 };
14 use std::ops;
15 
/// Handle of an expression that is the source of a non-uniform result, if any.
///
/// `None` means that no non-uniform source has been recorded.
pub type NonUniformResult = Option<Handle<crate::Expression>>;
17 
bitflags::bitflags! {
    /// Kinds of expressions that require uniform control flow.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    pub struct UniformityRequirements: u8 {
        /// Set by `Statement::Barrier`.
        const WORK_GROUP_BARRIER = 0x1;
        /// Set by `Expression::Derivative`.
        const DERIVATIVE = 0x2;
        /// Set by `Expression::ImageSample` when its sample level
        /// reports `implicit_derivatives()`.
        const IMPLICIT_LEVEL = 0x4;
    }
}
28 
/// Uniform control flow characteristics.
///
/// Pairs the source of non-uniformity of a value (if there is one) with the
/// set of reasons for which uniform control flow is required.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(test, derive(PartialEq))]
pub struct Uniformity {
    /// A child expression with non-uniform result.
    ///
    /// This means, when the relevant invocations are scheduled on a compute unit,
    /// they have to use vector registers to store an individual value
    /// per invocation.
    ///
    /// Whenever the control flow is conditioned on such a value,
    /// the hardware needs to keep track of the mask of invocations,
    /// and process all branches of the control flow.
    ///
    /// Any operation that depends on a non-uniform result also produces
    /// a non-uniform result.
    pub non_uniform_result: NonUniformResult,
    /// The reasons for which uniform control flow is required.
    ///
    /// Empty if there is no such requirement.
    pub requirements: UniformityRequirements,
}
50 
51 impl Uniformity {
new() -> Self52     fn new() -> Self {
53         Uniformity {
54             non_uniform_result: None,
55             requirements: UniformityRequirements::empty(),
56         }
57     }
58 }
59 
bitflags::bitflags! {
    /// Ways in which control flow may leave a block or a function early.
    struct ExitFlags: u8 {
        /// Control flow may return from the function, which makes all the
        /// subsequent statements within the current function (only!)
        /// to be executed in a non-uniform control flow.
        const MAY_RETURN = 0x1;
        /// Control flow may be killed. Anything after `Statement::Kill` is
        /// considered inside non-uniform context.
        const MAY_KILL = 0x2;
    }
}
71 
/// Uniformity characteristics of a function.
///
/// Also used as the per-statement and per-block result while a function
/// body is being analyzed; partial results are merged with `|`.
#[cfg_attr(test, derive(Debug, PartialEq))]
struct FunctionUniformity {
    /// Non-uniform result source and accumulated uniformity requirements.
    result: Uniformity,
    /// Early exits (return / kill) that may have happened.
    exit: ExitFlags,
}
78 
79 impl ops::BitOr for FunctionUniformity {
80     type Output = Self;
bitor(self, other: Self) -> Self81     fn bitor(self, other: Self) -> Self {
82         FunctionUniformity {
83             result: Uniformity {
84                 non_uniform_result: self
85                     .result
86                     .non_uniform_result
87                     .or(other.result.non_uniform_result),
88                 requirements: self.result.requirements | other.result.requirements,
89             },
90             exit: self.exit | other.exit,
91         }
92     }
93 }
94 
95 impl FunctionUniformity {
new() -> Self96     fn new() -> Self {
97         FunctionUniformity {
98             result: Uniformity::new(),
99             exit: ExitFlags::empty(),
100         }
101     }
102 
103     /// Returns a disruptor based on the stored exit flags, if any.
exit_disruptor(&self) -> Option<UniformityDisruptor>104     fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
105         if self.exit.contains(ExitFlags::MAY_RETURN) {
106             Some(UniformityDisruptor::Return)
107         } else if self.exit.contains(ExitFlags::MAY_KILL) {
108             Some(UniformityDisruptor::Discard)
109         } else {
110             None
111         }
112     }
113 }
114 
bitflags::bitflags! {
    /// Indicates how a global variable is used.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    pub struct GlobalUse: u8 {
        /// Data will be read from the variable.
        ///
        /// Recorded for every value-type reference to the variable.
        const READ = 0x1;
        /// Data will be written to the variable.
        ///
        /// Recorded for the destinations of `Store` and `ImageStore` statements.
        const WRITE = 0x2;
        /// The information about the data is queried.
        ///
        /// Recorded for the operands of `ImageQuery` and `ArrayLength` expressions.
        const QUERY = 0x4;
    }
}
128 
/// A pair of global variables used together by an `ImageSample` expression.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
    /// Global variable referenced by the `image` operand.
    pub image: Handle<crate::GlobalVariable>,
    /// Global variable referenced by the `sampler` operand.
    pub sampler: Handle<crate::GlobalVariable>,
}
136 
/// Analysis results for a single expression of a function.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ExpressionInfo {
    /// Uniformity characteristics of the expression.
    pub uniformity: Uniformity,
    /// Number of references to this expression registered by the analysis.
    pub ref_count: usize,
    /// Global variable reachable through this expression, if any.
    ///
    /// Set for `GlobalVariable` expressions and propagated through
    /// `Access` and `AccessIndex`.
    assignable_global: Option<Handle<crate::GlobalVariable>>,
    /// Resolved type of the expression.
    pub ty: TypeResolution,
}
146 
147 impl ExpressionInfo {
new() -> Self148     fn new() -> Self {
149         ExpressionInfo {
150             uniformity: Uniformity::new(),
151             ref_count: 0,
152             assignable_global: None,
153             // this doesn't matter at this point, will be overwritten
154             ty: TypeResolution::Value(crate::TypeInner::Scalar {
155                 kind: crate::ScalarKind::Bool,
156                 width: 0,
157             }),
158         }
159     }
160 }
161 
/// Analysis results for a whole function.
#[derive(Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
    /// Validation flags.
    flags: ValidationFlags,
    /// Set of shader stages where calling this function is valid.
    pub available_stages: ShaderStages,
    /// Uniformity characteristics.
    pub uniformity: Uniformity,
    /// Function may kill the invocation.
    pub may_kill: bool,
    /// Set of image-sampler pairs used with sampling.
    pub sampling_set: crate::FastHashSet<SamplingKey>,
    /// Vector of global variable usages.
    ///
    /// Each item corresponds to a global variable in the module.
    global_uses: Box<[GlobalUse]>,
    /// Vector of expression infos.
    ///
    /// Each item corresponds to an expression in the function.
    expressions: Box<[ExpressionInfo]>,
}
185 
186 impl FunctionInfo {
global_variable_count(&self) -> usize187     pub fn global_variable_count(&self) -> usize {
188         self.global_uses.len()
189     }
expression_count(&self) -> usize190     pub fn expression_count(&self) -> usize {
191         self.expressions.len()
192     }
dominates_global_use(&self, other: &Self) -> bool193     pub fn dominates_global_use(&self, other: &Self) -> bool {
194         for (self_global_uses, other_global_uses) in
195             self.global_uses.iter().zip(other.global_uses.iter())
196         {
197             if !self_global_uses.contains(*other_global_uses) {
198                 return false;
199             }
200         }
201         true
202     }
203 }
204 
205 impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
206     type Output = GlobalUse;
index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse207     fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
208         &self.global_uses[handle.index()]
209     }
210 }
211 
212 impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
213     type Output = ExpressionInfo;
index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo214     fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
215         &self.expressions[handle.index()]
216     }
217 }
218 
/// Disruptor of the uniform control flow.
///
/// Describes why the control flow at some point of a function is
/// considered non-uniform.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum UniformityDisruptor {
    /// Control flow branches on a non-uniform value; the handle identifies
    /// the expression recorded as the source of that non-uniformity.
    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
    Expression(Handle<crate::Expression>),
    /// An earlier statement may have returned from the function.
    #[error("There is a Return earlier in the control flow of the function")]
    Return,
    /// An earlier statement (possibly inside a called function) may have
    /// killed the invocation.
    #[error("There is a Discard earlier in the entry point across all called functions")]
    Discard,
}
230 
231 impl FunctionInfo {
232     /// Adds a value-type reference to an expression.
233     #[must_use]
add_ref_impl( &mut self, handle: Handle<crate::Expression>, global_use: GlobalUse, ) -> NonUniformResult234     fn add_ref_impl(
235         &mut self,
236         handle: Handle<crate::Expression>,
237         global_use: GlobalUse,
238     ) -> NonUniformResult {
239         let info = &mut self.expressions[handle.index()];
240         info.ref_count += 1;
241         // mark the used global as read
242         if let Some(global) = info.assignable_global {
243             self.global_uses[global.index()] |= global_use;
244         }
245         info.uniformity.non_uniform_result
246     }
247 
    /// Adds a value-type reference to an expression.
    ///
    /// Shorthand for `add_ref_impl` with `GlobalUse::READ`: if the expression
    /// has an associated global, that global gets marked as read.
    ///
    /// Returns the non-uniform result recorded for the referenced expression.
    #[must_use]
    fn add_ref(&mut self, handle: Handle<crate::Expression>) -> NonUniformResult {
        self.add_ref_impl(handle, GlobalUse::READ)
    }
253 
254     /// Adds a potentially assignable reference to an expression.
255     /// These are destinations for `Store` and `ImageStore` statements,
256     /// which can transit through `Access` and `AccessIndex`.
257     #[must_use]
add_assignable_ref( &mut self, handle: Handle<crate::Expression>, assignable_global: &mut Option<Handle<crate::GlobalVariable>>, ) -> NonUniformResult258     fn add_assignable_ref(
259         &mut self,
260         handle: Handle<crate::Expression>,
261         assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
262     ) -> NonUniformResult {
263         let info = &mut self.expressions[handle.index()];
264         info.ref_count += 1;
265         // propagate the assignable global up the chain, till it either hits
266         // a value-type expression, or the assignment statement.
267         if let Some(global) = info.assignable_global {
268             if let Some(_old) = assignable_global.replace(global) {
269                 unreachable!()
270             }
271         }
272         info.uniformity.non_uniform_result
273     }
274 
275     /// Inherit information from a called function.
process_call(&mut self, info: &Self) -> FunctionUniformity276     fn process_call(&mut self, info: &Self) -> FunctionUniformity {
277         for key in info.sampling_set.iter() {
278             self.sampling_set.insert(key.clone());
279         }
280         for (mine, other) in self.global_uses.iter_mut().zip(info.global_uses.iter()) {
281             *mine |= *other;
282         }
283         FunctionUniformity {
284             result: info.uniformity.clone(),
285             exit: if info.may_kill {
286                 ExitFlags::MAY_KILL
287             } else {
288                 ExitFlags::empty()
289             },
290         }
291     }
292 
    /// Computes the expression info and stores it in `self.expressions`.
    /// Also, bumps the reference counts on dependent expressions.
    ///
    /// * `handle` - handle of the expression being processed.
    /// * `expression` - the expression itself.
    /// * `expression_arena` - all expressions of the function; used to look
    ///   up the image and sampler operands of `ImageSample`.
    /// * `other_functions` - infos of the functions that may be called.
    /// * `type_arena`, `resolve_context` - inputs for resolving the type.
    ///
    /// # Errors
    ///
    /// * `ExpressionError::ExpectedGlobalVariable` if the image or sampler
    ///   operand of an `ImageSample` is not a `GlobalVariable` expression.
    /// * `ExpressionError::CallToUndeclaredFunction` if a called function
    ///   has no entry in `other_functions`.
    /// * Any error produced while resolving the type of the expression.
    //
    // `or_fun_call` is allowed on purpose: the arguments of `.or(..)` below
    // are `add_ref` calls, which have to be evaluated unconditionally in order
    // to bump the reference counts of all the operands.
    #[allow(clippy::or_fun_call)]
    fn process_expression(
        &mut self,
        handle: Handle<crate::Expression>,
        expression: &crate::Expression,
        expression_arena: &Arena<crate::Expression>,
        other_functions: &[FunctionInfo],
        type_arena: &Arena<crate::Type>,
        resolve_context: &ResolveContext,
    ) -> Result<(), ExpressionError> {
        use crate::{Expression as E, SampleLevel as Sl};

        // Global variable this expression gives access to, if any.
        let mut assignable_global = None;
        let uniformity = match *expression {
            // the base is a potentially assignable reference, the index is a value
            E::Access { base, index } => Uniformity {
                non_uniform_result: self
                    .add_assignable_ref(base, &mut assignable_global)
                    .or(self.add_ref(index)),
                requirements: UniformityRequirements::empty(),
            },
            E::AccessIndex { base, .. } => Uniformity {
                non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
                requirements: UniformityRequirements::empty(),
            },
            // always uniform
            E::Constant(_) => Uniformity::new(),
            E::Splat { size: _, value } => Uniformity {
                non_uniform_result: self.add_ref(value),
                requirements: UniformityRequirements::empty(),
            },
            E::Swizzle { vector, .. } => Uniformity {
                non_uniform_result: self.add_ref(vector),
                requirements: UniformityRequirements::empty(),
            },
            E::Compose { ref components, .. } => {
                // reference every component, keeping the first non-uniform one
                let non_uniform_result = components
                    .iter()
                    .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
                Uniformity {
                    non_uniform_result,
                    requirements: UniformityRequirements::empty(),
                }
            }
            // depends on the builtin or interpolation
            E::FunctionArgument(index) => {
                let arg = &resolve_context.arguments[index as usize];
                let uniform = match arg.binding {
                    Some(crate::Binding::BuiltIn(built_in)) => match built_in {
                        // per-polygon built-ins are uniform
                        crate::BuiltIn::FrontFacing
                        // per-work-group built-ins are uniform
                        | crate::BuiltIn::WorkGroupId
                        | crate::BuiltIn::WorkGroupSize => true,
                        _ => false,
                    },
                    // only flat inputs are uniform
                    Some(crate::Binding::Location {
                        interpolation: Some(crate::Interpolation::Flat),
                        ..
                    }) => true,
                    _ => false,
                };
                Uniformity {
                    // a non-uniform argument is its own source of non-uniformity
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            // depends on the storage class
            E::GlobalVariable(gh) => {
                use crate::StorageClass as Sc;
                assignable_global = Some(gh);
                let var = &resolve_context.global_vars[gh];
                let uniform = match var.class {
                    // local data is non-uniform
                    Sc::Function | Sc::Private => false,
                    // workgroup memory is exclusively accessed by the group
                    Sc::WorkGroup => true,
                    // uniform data
                    Sc::Uniform | Sc::PushConstant => true,
                    // storage data is only uniform when read-only
                    Sc::Handle | Sc::Storage => {
                        !var.storage_access.contains(crate::StorageAccess::STORE)
                    }
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            // local variables are always treated as non-uniform
            E::LocalVariable(_) => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::Load { pointer } => Uniformity {
                non_uniform_result: self.add_ref(pointer),
                requirements: UniformityRequirements::empty(),
            },
            E::ImageSample {
                image,
                sampler,
                coordinate,
                array_index,
                offset: _,
                level,
                depth_ref,
            } => {
                // remember which image is sampled with which sampler
                self.sampling_set.insert(SamplingKey {
                    image: match expression_arena[image] {
                        crate::Expression::GlobalVariable(var) => var,
                        _ => return Err(ExpressionError::ExpectedGlobalVariable),
                    },
                    sampler: match expression_arena[sampler] {
                        crate::Expression::GlobalVariable(var) => var,
                        _ => return Err(ExpressionError::ExpectedGlobalVariable),
                    },
                });
                // "nur" == "Non-Uniform Result"
                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let level_nur = match level {
                    Sl::Auto | Sl::Zero => None,
                    Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
                    Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
                };
                let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(sampler))
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(level_nur)
                        .or(dref_nur),
                    // sampling with implicit derivatives requires uniform control flow
                    requirements: if level.implicit_derivatives() {
                        UniformityRequirements::IMPLICIT_LEVEL
                    } else {
                        UniformityRequirements::empty()
                    },
                }
            }
            E::ImageLoad {
                image,
                coordinate,
                array_index,
                index,
            } => {
                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let index_nur = index.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(index_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::ImageQuery { image, query } => {
                let query_nur = match query {
                    crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
                    _ => None,
                };
                Uniformity {
                    // the image is only queried, not read
                    non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::Unary { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            E::Binary { left, right, .. } => Uniformity {
                non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
                requirements: UniformityRequirements::empty(),
            },
            E::Select {
                condition,
                accept,
                reject,
            } => Uniformity {
                non_uniform_result: self
                    .add_ref(condition)
                    .or(self.add_ref(accept))
                    .or(self.add_ref(reject)),
                requirements: UniformityRequirements::empty(),
            },
            // explicit derivatives require uniform
            E::Derivative { expr, .. } => Uniformity {
                //Note: taking a derivative of a uniform doesn't make it non-uniform
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::DERIVATIVE,
            },
            E::Relational { argument, .. } => Uniformity {
                non_uniform_result: self.add_ref(argument),
                requirements: UniformityRequirements::empty(),
            },
            E::Math {
                arg, arg1, arg2, ..
            } => {
                let arg1_nur = arg1.and_then(|h| self.add_ref(h));
                let arg2_nur = arg2.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::As { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            E::Call(function) => {
                let fun = other_functions
                    .get(function.index())
                    .ok_or(ExpressionError::CallToUndeclaredFunction(function))?;
                // NOTE(review): only the `result` part is used here, the exit
                // flags returned by `process_call` are dropped.
                self.process_call(fun).result
            }
            E::ArrayLength(expr) => Uniformity {
                // the array is only queried, not read
                non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
                requirements: UniformityRequirements::empty(),
            },
        };

        // Resolve the type, using the types stored in `self.expressions`
        // for the expressions this one refers to.
        let ty =
            resolve_context.resolve(expression, type_arena, |h| &self.expressions[h.index()].ty)?;
        // Replace the whole entry of this expression with the computed info.
        self.expressions[handle.index()] = ExpressionInfo {
            uniformity,
            ref_count: 0,
            assignable_global,
            ty,
        };
        Ok(())
    }
526 
    /// Analyzes the uniformity requirements of a block (as a sequence of statements).
    /// Returns the uniformity characteristics at the *function* level, i.e.
    /// whether or not the function requires to be called in uniform control flow,
    /// and whether the produced result is not disrupting the control flow.
    ///
    /// The parent control flow is uniform if `disruptor.is_none()`.
    ///
    /// * `statements` - the statements of the block, analyzed in order.
    /// * `other_functions` - infos of the functions that may be called.
    /// * `disruptor` - the reason why the control flow is already non-uniform
    ///   on entry to the block, if any.
    ///
    /// # Errors
    ///
    /// Returns a `NonUniformControlFlow` error if any of the expressions in the block
    /// require uniformity, but the current flow is non-uniform. This check is only
    /// done when `ValidationFlags::CONTROL_FLOW_UNIFORMITY` is set in `self.flags`.
    ///
    /// Returns an `InvalidCall` error if a called function has no entry in
    /// `other_functions`. Errors of nested blocks are propagated.
    #[allow(clippy::or_fun_call)]
    fn process_block(
        &mut self,
        statements: &[crate::Statement],
        other_functions: &[FunctionInfo],
        mut disruptor: Option<UniformityDisruptor>,
    ) -> Result<FunctionUniformity, FunctionError> {
        use crate::Statement as S;

        let mut combined_uniformity = FunctionUniformity::new();
        for statement in statements {
            let uniformity = match *statement {
                S::Emit(ref range) => {
                    // collect the requirements of all the emitted expressions
                    let mut requirements = UniformityRequirements::empty();
                    for expr in range.clone() {
                        let req = self.expressions[expr.index()].uniformity.requirements;
                        if self
                            .flags
                            .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                            && !req.is_empty()
                        {
                            // the expression needs uniform control flow,
                            // fail if the flow has been disrupted already
                            if let Some(cause) = disruptor {
                                return Err(FunctionError::NonUniformControlFlow(req, expr, cause));
                            }
                        }
                        requirements |= req;
                    }
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Break | S::Continue => FunctionUniformity::new(),
                S::Kill => FunctionUniformity {
                    result: Uniformity::new(),
                    exit: ExitFlags::MAY_KILL,
                },
                S::Barrier(_) => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                    },
                    exit: ExitFlags::empty(),
                },
                S::Block(ref b) => self.process_block(b, other_functions, disruptor)?,
                S::If {
                    condition,
                    ref accept,
                    ref reject,
                } => {
                    let condition_nur = self.add_ref(condition);
                    // both branches are non-uniform if the flow is already
                    // disrupted, or if the condition is non-uniform
                    let branch_disruptor =
                        disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                    let accept_uniformity =
                        self.process_block(accept, other_functions, branch_disruptor)?;
                    let reject_uniformity =
                        self.process_block(reject, other_functions, branch_disruptor)?;
                    accept_uniformity | reject_uniformity
                }
                S::Switch {
                    selector,
                    ref cases,
                    ref default,
                } => {
                    let selector_nur = self.add_ref(selector);
                    let branch_disruptor =
                        disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
                    let mut uniformity = FunctionUniformity::new();
                    let mut case_disruptor = branch_disruptor;
                    for case in cases.iter() {
                        let case_uniformity =
                            self.process_block(&case.body, other_functions, case_disruptor)?;
                        // a fall-through case passes its exits on to the next case,
                        // otherwise the next case starts from the branch disruptor
                        case_disruptor = if case.fall_through {
                            case_disruptor.or(case_uniformity.exit_disruptor())
                        } else {
                            branch_disruptor
                        };
                        uniformity = uniformity | case_uniformity;
                    }
                    // using the disruptor inherited from the last fall-through chain
                    let default_exit =
                        self.process_block(default, other_functions, case_disruptor)?;
                    uniformity | default_exit
                }
                S::Loop {
                    ref body,
                    ref continuing,
                } => {
                    let body_uniformity = self.process_block(body, other_functions, disruptor)?;
                    // the continuing block is analyzed as if it follows the body,
                    // so it is affected by the exits of the body
                    let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
                    let continuing_uniformity =
                        self.process_block(continuing, other_functions, continuing_disruptor)?;
                    body_uniformity | continuing_uniformity
                }
                S::Return { value } => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
                        requirements: UniformityRequirements::empty(),
                    },
                    //TODO: if we are in the uniform control flow, should this still be an exit flag?
                    exit: ExitFlags::MAY_RETURN,
                },
                // Here and below, the used expressions are already emitted,
                // and their results do not affect the function return value,
                // so we can ignore their non-uniformity.
                S::Store { pointer, value } => {
                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::ImageStore {
                    image,
                    coordinate,
                    array_index,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::WRITE);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(coordinate);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::Call {
                    function,
                    ref arguments,
                    result: _,
                } => {
                    for &argument in arguments {
                        let _ = self.add_ref(argument);
                    }
                    let info = other_functions.get(function.index()).ok_or(
                        FunctionError::InvalidCall {
                            function,
                            error: CallError::ForwardDeclaredFunction,
                        },
                    )?;
                    //Note: the result is validated by the Validator, not here
                    self.process_call(info)
                }
            };

            // A possible return or kill in this statement makes the control flow
            // of all the following statements in this block non-uniform.
            disruptor = disruptor.or(uniformity.exit_disruptor());
            combined_uniformity = combined_uniformity | uniformity;
        }
        Ok(combined_uniformity)
    }
687 }
688 
689 impl ModuleInfo {
690     /// Builds the `FunctionInfo` based on the function, and validates the
691     /// uniform control flow if required by the expressions of this function.
process_function( &self, fun: &crate::Function, module: &crate::Module, flags: ValidationFlags, ) -> Result<FunctionInfo, FunctionError>692     pub(super) fn process_function(
693         &self,
694         fun: &crate::Function,
695         module: &crate::Module,
696         flags: ValidationFlags,
697     ) -> Result<FunctionInfo, FunctionError> {
698         let mut info = FunctionInfo {
699             flags,
700             available_stages: ShaderStages::all(),
701             uniformity: Uniformity::new(),
702             may_kill: false,
703             sampling_set: crate::FastHashSet::default(),
704             global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
705             expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
706         };
707         let resolve_context = ResolveContext {
708             constants: &module.constants,
709             global_vars: &module.global_variables,
710             local_vars: &fun.local_variables,
711             functions: &module.functions,
712             arguments: &fun.arguments,
713         };
714 
715         for (handle, expr) in fun.expressions.iter() {
716             if let Err(error) = info.process_expression(
717                 handle,
718                 expr,
719                 &fun.expressions,
720                 &self.functions,
721                 &module.types,
722                 &resolve_context,
723             ) {
724                 return Err(FunctionError::Expression { handle, error });
725             }
726         }
727 
728         let uniformity = info.process_block(&fun.body, &self.functions, None)?;
729         info.uniformity = uniformity.result;
730         info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
731 
732         Ok(info)
733     }
734 
get_entry_point(&self, index: usize) -> &FunctionInfo735     pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
736         &self.entry_points[index]
737     }
738 }
739 
/// Exercises expression and block analysis on a hand-built function:
/// reference counting, global usage flags, and uniform control flow checks.
///
/// The same `FunctionInfo` is reused for every `process_block` call, so the
/// reference counts and global usage flags asserted below accumulate from
/// one step to the next; the order of the steps matters.
#[test]
fn uniform_control_flow() {
    use crate::{Expression as E, Statement as S};

    // An unsigned zero scalar of width 4, used as the derivative operand.
    let mut constants = Arena::new();
    let zero = constants.append(crate::Constant {
        name: None,
        specialization: None,
        inner: crate::ConstantInner::Scalar {
            width: 4,
            value: crate::ScalarValue::Uint(0),
        },
    });

    // Both globals share a single two-component float vector type.
    let mut types = Arena::new();
    let vec2_ty = types.append(crate::Type {
        name: None,
        inner: crate::TypeInner::Vector {
            size: crate::VectorSize::Bi,
            kind: crate::ScalarKind::Float,
            width: 4,
        },
    });

    // A storable handle-class global (the assertions below expect it to be
    // treated as non-uniform) and a uniform-class global with no access flags.
    let mut globals = Arena::new();
    let non_uniform_var = globals.append(crate::GlobalVariable {
        name: None,
        init: None,
        ty: vec2_ty,
        class: crate::StorageClass::Handle,
        binding: None,
        storage_access: crate::StorageAccess::STORE,
    });
    let uniform_var = globals.append(crate::GlobalVariable {
        name: None,
        init: None,
        ty: vec2_ty,
        binding: None,
        class: crate::StorageClass::Uniform,
        storage_access: crate::StorageAccess::empty(),
    });

    let mut exprs = Arena::new();
    // A constant, and a derivative taken of that constant.
    let const_expr = exprs.append(E::Constant(zero));
    let deriv_expr = exprs.append(E::Derivative {
        axis: crate::DerivativeAxis::X,
        expr: const_expr,
    });
    // Range taken from index 0 while only the two expressions above exist.
    let const_and_deriv_range = exprs.range_from(0);
    let non_uniform_var_expr = exprs.append(E::GlobalVariable(non_uniform_var));
    let uniform_var_expr = exprs.append(E::GlobalVariable(uniform_var));
    // Range taken from index 2 right after the two global expressions.
    let globals_range = exprs.range_from(2);

    // Drives the QUERY usage flag on the uniform global.
    let array_len_expr = exprs.append(E::ArrayLength(uniform_var_expr));
    // Drives the transitive WRITE usage flag on the non-uniform global.
    let member_access_expr = exprs.append(E::AccessIndex {
        base: non_uniform_var_expr,
        index: 1,
    });
    // Range taken from index 2 again, now that the two expressions above exist.
    let globals_onward_range = exprs.range_from(2);

    let mut info = FunctionInfo {
        flags: ValidationFlags::all(),
        available_stages: ShaderStages::all(),
        uniformity: Uniformity::new(),
        may_kill: false,
        sampling_set: crate::FastHashSet::default(),
        global_uses: vec![GlobalUse::empty(); globals.len()].into_boxed_slice(),
        expressions: vec![ExpressionInfo::new(); exprs.len()].into_boxed_slice(),
    };
    // The function under test has no locals, callees or arguments.
    let no_locals = Arena::new();
    let no_functions = Arena::new();
    let resolve_context = ResolveContext {
        constants: &constants,
        global_vars: &globals,
        local_vars: &no_locals,
        functions: &no_functions,
        arguments: &[],
    };

    // Step 0: analyze every expression on its own.
    for (handle, expression) in exprs.iter() {
        info.process_expression(
            handle,
            expression,
            &exprs,
            &[],
            &types,
            &resolve_context,
        )
        .unwrap();
    }
    // Each global expression is referenced by exactly one other expression;
    // the two derived expressions are not referenced by anything yet.
    assert_eq!(info[non_uniform_var_expr].ref_count, 1);
    assert_eq!(info[uniform_var_expr].ref_count, 1);
    assert_eq!(info[array_len_expr].ref_count, 0);
    assert_eq!(info[member_access_expr].ref_count, 0);
    assert_eq!(info[non_uniform_var], GlobalUse::empty());
    assert_eq!(info[uniform_var], GlobalUse::QUERY);

    // Branch body shared by both `If` statements: emit the constant and the
    // derivative, then store the derivative.
    let store_derivative = || {
        vec![
            S::Emit(const_and_deriv_range.clone()),
            S::Store {
                pointer: const_expr,
                value: deriv_expr,
            },
        ]
    };

    // Step 1: a derivative under a uniform condition is accepted, and the
    // derivative requirement is reported in the result.
    let uniform_branch = S::If {
        condition: uniform_var_expr,
        accept: Vec::new(),
        reject: store_derivative(),
    };
    let outcome = info.process_block(
        &[S::Emit(globals_range.clone()), uniform_branch],
        &[],
        None,
    );
    assert_eq!(
        outcome,
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::DERIVATIVE,
            },
            exit: ExitFlags::empty(),
        }),
    );
    assert_eq!(info[const_expr].ref_count, 2);
    assert_eq!(info[uniform_var], GlobalUse::READ | GlobalUse::QUERY);

    // Step 2: the same derivative under a non-uniform condition is rejected,
    // naming the condition expression as the disruptor.
    let non_uniform_branch = S::If {
        condition: non_uniform_var_expr,
        accept: store_derivative(),
        reject: Vec::new(),
    };
    let outcome = info.process_block(
        &[S::Emit(globals_range.clone()), non_uniform_branch],
        &[],
        None,
    );
    assert_eq!(
        outcome,
        Err(FunctionError::NonUniformControlFlow(
            UniformityRequirements::DERIVATIVE,
            deriv_expr,
            UniformityDisruptor::Expression(non_uniform_var_expr),
        )),
    );
    assert_eq!(info[deriv_expr].ref_count, 1);
    assert_eq!(info[non_uniform_var], GlobalUse::READ);

    // Step 3: returning the non-uniform value, with a `Return` disruptor
    // already in effect when the block starts.
    let return_non_uniform = S::Return {
        value: Some(non_uniform_var_expr),
    };
    let outcome = info.process_block(
        &[S::Emit(globals_range), return_non_uniform],
        &[],
        Some(UniformityDisruptor::Return),
    );
    assert_eq!(
        outcome,
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_var_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::MAY_RETURN,
        }),
    );
    assert_eq!(info[non_uniform_var_expr].ref_count, 3);

    // Step 4: check that uniformity requirements reach through a pointer:
    // store via the access expression, kill, then return that same pointer.
    let store_through_access = S::Store {
        pointer: member_access_expr,
        value: array_len_expr,
    };
    let return_access = S::Return {
        value: Some(member_access_expr),
    };
    let outcome = info.process_block(
        &[
            S::Emit(globals_onward_range),
            store_through_access,
            S::Kill,
            return_access,
        ],
        &[],
        Some(UniformityDisruptor::Discard),
    );
    assert_eq!(
        outcome,
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_var_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::all(),
        }),
    );
    assert_eq!(info[non_uniform_var], GlobalUse::READ | GlobalUse::WRITE);
}
930