1 mod analyzer;
2 mod compose;
3 mod expression;
4 mod function;
5 mod interface;
6 mod r#type;
7 
8 #[cfg(feature = "validate")]
9 use crate::arena::{Arena, UniqueArena};
10 
11 use crate::{
12     arena::Handle,
13     proc::{InvalidBaseType, Layouter},
14     FastHashSet,
15 };
16 use bit_set::BitSet;
17 use std::ops;
18 
// TODO: analyze the model at the same time as we validate it, and
// merge the corresponding matches over expressions and statements.
21 
22 use crate::span::{AddSpan as _, WithSpan};
23 pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
24 pub use compose::ComposeError;
25 pub use expression::ExpressionError;
26 pub use function::{CallError, FunctionError, LocalVariableError};
27 pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
28 pub use r#type::{Disalignment, TypeError, TypeFlags};
29 
30 bitflags::bitflags! {
31     /// Validation flags.
32     #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
33     #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
34     pub struct ValidationFlags: u8 {
35         /// Expressions.
36         #[cfg(feature = "validate")]
37         const EXPRESSIONS = 0x1;
38         /// Statements and blocks of them.
39         #[cfg(feature = "validate")]
40         const BLOCKS = 0x2;
41         /// Uniformity of control flow for operations that require it.
42         #[cfg(feature = "validate")]
43         const CONTROL_FLOW_UNIFORMITY = 0x4;
44         /// Host-shareable structure layouts.
45         #[cfg(feature = "validate")]
46         const STRUCT_LAYOUTS = 0x8;
47         /// Constants.
48         #[cfg(feature = "validate")]
49         const CONSTANTS = 0x10;
50     }
51 }
52 
53 impl Default for ValidationFlags {
default() -> Self54     fn default() -> Self {
55         Self::all()
56     }
57 }
58 
59 bitflags::bitflags! {
60     /// Allowed IR capabilities.
61     #[must_use]
62     #[derive(Default)]
63     #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
64     #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
65     pub struct Capabilities: u8 {
66         /// Support for `StorageClass:PushConstant`.
67         const PUSH_CONSTANT = 0x1;
68         /// Float values with width = 8.
69         const FLOAT64 = 0x2;
70         /// Support for `Builtin:PrimitiveIndex`.
71         const PRIMITIVE_INDEX = 0x4;
72     }
73 }
74 
75 bitflags::bitflags! {
76     /// Validation flags.
77     #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
78     #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
79     pub struct ShaderStages: u8 {
80         const VERTEX = 0x1;
81         const FRAGMENT = 0x2;
82         const COMPUTE = 0x4;
83     }
84 }
85 
/// Analysis results produced by [`Validator::validate`] for a module.
///
/// One [`FunctionInfo`] is recorded per function and per entry point, in the
/// order in which `validate` visits them.
#[derive(Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
    // Per-function info, pushed while iterating `module.functions`; looked up
    // by `Handle<crate::Function>::index()` in the `Index` impl.
    functions: Vec<FunctionInfo>,
    // Per-entry-point info, pushed in `module.entry_points` order.
    entry_points: Vec<FunctionInfo>,
}
93 
94 impl ops::Index<Handle<crate::Function>> for ModuleInfo {
95     type Output = FunctionInfo;
index(&self, handle: Handle<crate::Function>) -> &Self::Output96     fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
97         &self.functions[handle.index()]
98     }
99 }
100 
/// Validates a `crate::Module` and produces its [`ModuleInfo`].
///
/// The validator owns scratch state (per-type info, a layouter, masks and
/// expression bookkeeping) that `validate` mutates through `&mut self`;
/// presumably this lets one instance be reused across modules — confirm.
#[derive(Debug)]
pub struct Validator {
    // Which validation passes are enabled (e.g. `CONSTANTS` gates constant checks).
    flags: ValidationFlags,
    // IR capabilities the validated module is allowed to use.
    capabilities: Capabilities,
    // Per-type validation results, indexed by `Handle<crate::Type>::index()`;
    // resized by `reset_types` at the start of `validate`.
    types: Vec<r#type::TypeInfo>,
    // Type layout calculator, updated from the module's types and constants
    // at the start of `validate`.
    layouter: Layouter,
    // NOTE(review): presumably tracks varying locations already in use — confirm.
    location_mask: BitSet,
    // NOTE(review): presumably one binding mask per bind group — confirm.
    bind_group_masks: Vec<BitSet>,
    // NOTE(review): not read in this file; the allow suggests it is only used
    // under some feature configurations — confirm and document why.
    #[allow(dead_code)]
    select_cases: FastHashSet<i32>,
    // NOTE(review): presumably bookkeeping of expressions already validated,
    // kept both as a list and as a set — confirm.
    valid_expression_list: Vec<Handle<crate::Expression>>,
    valid_expression_set: BitSet,
}
114 
/// Reasons a module-level constant can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
pub enum ConstantError {
    /// A scalar constant's width is not accepted for its scalar kind.
    #[error("The type doesn't match the constant")]
    InvalidType,
    /// A composite references a component constant whose handle is not
    /// strictly before the composite's own handle.
    #[error("The component handle {0:?} can not be resolved")]
    UnresolvedComponent(Handle<crate::Constant>),
    /// A composite's array type is sized by a constant whose handle is not
    /// strictly before the composite's own handle.
    #[error("The array size handle {0:?} can not be resolved")]
    UnresolvedSize(Handle<crate::Constant>),
    /// The components do not form a valid value of the composite's type.
    #[error(transparent)]
    Compose(#[from] ComposeError),
}
126 
/// Top-level error returned by `Validator::validate`.
///
/// Most variants wrap a more specific error together with the handle and name
/// of the module item that failed.
#[derive(Clone, Debug, thiserror::Error)]
pub enum ValidationError {
    /// The layouter rejected a type (see `InvalidBaseType`).
    #[error(transparent)]
    Layouter(#[from] InvalidBaseType),
    /// A type failed validation.
    #[error("Type {handle:?} '{name}' is invalid")]
    Type {
        handle: Handle<crate::Type>,
        // Empty when the type has no name.
        name: String,
        #[source]
        error: TypeError,
    },
    /// A constant failed validation.
    #[error("Constant {handle:?} '{name}' is invalid")]
    Constant {
        handle: Handle<crate::Constant>,
        // Empty when the constant has no name.
        name: String,
        #[source]
        error: ConstantError,
    },
    /// A global variable failed validation.
    #[error("Global variable {handle:?} '{name}' is invalid")]
    GlobalVariable {
        handle: Handle<crate::GlobalVariable>,
        // Empty when the variable has no name.
        name: String,
        #[source]
        error: GlobalVariableError,
    },
    /// A function failed validation.
    #[error("Function {handle:?} '{name}' is invalid")]
    Function {
        handle: Handle<crate::Function>,
        // Empty when the function has no name.
        name: String,
        #[source]
        error: FunctionError,
    },
    /// An entry point failed validation, or duplicates another entry point
    /// with the same stage and name (`EntryPointError::Conflict`).
    #[error("Entry point {name} at {stage:?} is invalid")]
    EntryPoint {
        stage: crate::ShaderStage,
        name: String,
        #[source]
        error: EntryPointError,
    },
    /// The module is internally inconsistent.
    #[error("Module is corrupted")]
    Corrupted,
}
169 
170 impl crate::TypeInner {
171     #[cfg(feature = "validate")]
is_sized(&self) -> bool172     fn is_sized(&self) -> bool {
173         match *self {
174             Self::Scalar { .. }
175             | Self::Vector { .. }
176             | Self::Matrix { .. }
177             | Self::Array {
178                 size: crate::ArraySize::Constant(_),
179                 ..
180             }
181             | Self::Atomic { .. }
182             | Self::Pointer { .. }
183             | Self::ValuePointer { .. }
184             | Self::Struct { .. } => true,
185             Self::Array { .. } | Self::Image { .. } | Self::Sampler { .. } => false,
186         }
187     }
188 
189     /// Return the `ImageDimension` for which `self` is an appropriate coordinate.
190     #[cfg(feature = "validate")]
image_storage_coordinates(&self) -> Option<crate::ImageDimension>191     fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
192         match *self {
193             Self::Scalar {
194                 kind: crate::ScalarKind::Sint,
195                 ..
196             } => Some(crate::ImageDimension::D1),
197             Self::Vector {
198                 size: crate::VectorSize::Bi,
199                 kind: crate::ScalarKind::Sint,
200                 ..
201             } => Some(crate::ImageDimension::D2),
202             Self::Vector {
203                 size: crate::VectorSize::Tri,
204                 kind: crate::ScalarKind::Sint,
205                 ..
206             } => Some(crate::ImageDimension::D3),
207             _ => None,
208         }
209     }
210 }
211 
212 impl Validator {
213     /// Construct a new validator instance.
new(flags: ValidationFlags, capabilities: Capabilities) -> Self214     pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
215         Validator {
216             flags,
217             capabilities,
218             types: Vec::new(),
219             layouter: Layouter::default(),
220             location_mask: BitSet::new(),
221             bind_group_masks: Vec::new(),
222             select_cases: FastHashSet::default(),
223             valid_expression_list: Vec::new(),
224             valid_expression_set: BitSet::new(),
225         }
226     }
227 
228     #[cfg(feature = "validate")]
validate_constant( &self, handle: Handle<crate::Constant>, constants: &Arena<crate::Constant>, types: &UniqueArena<crate::Type>, ) -> Result<(), ConstantError>229     fn validate_constant(
230         &self,
231         handle: Handle<crate::Constant>,
232         constants: &Arena<crate::Constant>,
233         types: &UniqueArena<crate::Type>,
234     ) -> Result<(), ConstantError> {
235         let con = &constants[handle];
236         match con.inner {
237             crate::ConstantInner::Scalar { width, ref value } => {
238                 if !self.check_width(value.scalar_kind(), width) {
239                     return Err(ConstantError::InvalidType);
240                 }
241             }
242             crate::ConstantInner::Composite { ty, ref components } => {
243                 match types[ty].inner {
244                     crate::TypeInner::Array {
245                         size: crate::ArraySize::Constant(size_handle),
246                         ..
247                     } if handle <= size_handle => {
248                         return Err(ConstantError::UnresolvedSize(size_handle));
249                     }
250                     _ => {}
251                 }
252                 if let Some(&comp) = components.iter().find(|&&comp| handle <= comp) {
253                     return Err(ConstantError::UnresolvedComponent(comp));
254                 }
255                 compose::validate_compose(
256                     ty,
257                     constants,
258                     types,
259                     components
260                         .iter()
261                         .map(|&component| constants[component].inner.resolve_type()),
262                 )?;
263             }
264         }
265         Ok(())
266     }
267 
268     /// Check the given module to be valid.
validate( &mut self, module: &crate::Module, ) -> Result<ModuleInfo, WithSpan<ValidationError>>269     pub fn validate(
270         &mut self,
271         module: &crate::Module,
272     ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
273         self.reset_types(module.types.len());
274         self.layouter
275             .update(&module.types, &module.constants)
276             .map_err(|e| {
277                 let InvalidBaseType(handle) = e;
278                 ValidationError::from(e).with_span_handle(handle, &module.types)
279             })?;
280 
281         #[cfg(feature = "validate")]
282         if self.flags.contains(ValidationFlags::CONSTANTS) {
283             for (handle, constant) in module.constants.iter() {
284                 self.validate_constant(handle, &module.constants, &module.types)
285                     .map_err(|error| {
286                         ValidationError::Constant {
287                             handle,
288                             name: constant.name.clone().unwrap_or_default(),
289                             error,
290                         }
291                         .with_span_handle(handle, &module.constants)
292                     })?
293             }
294         }
295 
296         for (handle, ty) in module.types.iter() {
297             let ty_info = self
298                 .validate_type(handle, &module.types, &module.constants)
299                 .map_err(|error| {
300                     ValidationError::Type {
301                         handle,
302                         name: ty.name.clone().unwrap_or_default(),
303                         error,
304                     }
305                     .with_span_handle(handle, &module.types)
306                 })?;
307             self.types[handle.index()] = ty_info;
308         }
309 
310         #[cfg(feature = "validate")]
311         for (var_handle, var) in module.global_variables.iter() {
312             self.validate_global_var(var, &module.types)
313                 .map_err(|error| {
314                     ValidationError::GlobalVariable {
315                         handle: var_handle,
316                         name: var.name.clone().unwrap_or_default(),
317                         error,
318                     }
319                     .with_span_handle(var_handle, &module.global_variables)
320                 })?;
321         }
322 
323         let mut mod_info = ModuleInfo {
324             functions: Vec::with_capacity(module.functions.len()),
325             entry_points: Vec::with_capacity(module.entry_points.len()),
326         };
327 
328         for (handle, fun) in module.functions.iter() {
329             match self.validate_function(fun, module, &mod_info) {
330                 Ok(info) => mod_info.functions.push(info),
331                 Err(error) => {
332                     return Err(error.and_then(|error| {
333                         ValidationError::Function {
334                             handle,
335                             name: fun.name.clone().unwrap_or_default(),
336                             error,
337                         }
338                         .with_span_handle(handle, &module.functions)
339                     }))
340                 }
341             }
342         }
343 
344         let mut ep_map = FastHashSet::default();
345         for ep in module.entry_points.iter() {
346             if !ep_map.insert((ep.stage, &ep.name)) {
347                 return Err(ValidationError::EntryPoint {
348                     stage: ep.stage,
349                     name: ep.name.clone(),
350                     error: EntryPointError::Conflict,
351                 }
352                 .with_span()); // TODO: keep some EP span information?
353             }
354 
355             match self.validate_entry_point(ep, module, &mod_info) {
356                 Ok(info) => mod_info.entry_points.push(info),
357                 Err(error) => {
358                     return Err(error.and_then(|inner| {
359                         ValidationError::EntryPoint {
360                             stage: ep.stage,
361                             name: ep.name.clone(),
362                             error: inner,
363                         }
364                         .with_span()
365                     }))
366                 }
367             }
368         }
369 
370         Ok(mod_info)
371     }
372 }
373