1 mod analyzer;
2 mod compose;
3 mod expression;
4 mod function;
5 mod interface;
6 mod r#type;
7 
8 use crate::{
9     arena::{Arena, Handle},
10     FastHashSet,
11 };
12 use bit_set::BitSet;
13 use std::ops;
14 
15 //TODO: analyze the model at the same time as we validate it,
16 // merge the corresponding matches over expressions and statements.
17 
18 pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
19 pub use compose::ComposeError;
20 pub use expression::ExpressionError;
21 pub use function::{CallError, FunctionError, LocalVariableError};
22 pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
23 pub use r#type::{Disalignment, TypeError, TypeFlags};
24 
25 bitflags::bitflags! {
26     /// Validation flags.
27     #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
28     #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
29     pub struct ValidationFlags: u8 {
30         /// Expressions.
31         const EXPRESSIONS = 0x1;
32         /// Statements and blocks of them.
33         const BLOCKS = 0x2;
34         /// Uniformity of control flow for operations that require it.
35         const CONTROL_FLOW_UNIFORMITY = 0x4;
36         /// Host-shareable structure layouts.
37         const STRUCT_LAYOUTS = 0x8;
38         /// Constants.
39         const CONSTANTS = 0x10;
40     }
41 }
42 
43 impl Default for ValidationFlags {
default() -> Self44     fn default() -> Self {
45         Self::all()
46     }
47 }
48 
49 #[must_use]
50 bitflags::bitflags! {
51     /// Allowed IR capabilities.
52     #[derive(Default)]
53     #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
54     #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
55     pub struct Capabilities: u8 {
56         /// Support for `StorageClass:PushConstant`.
57         const PUSH_CONSTANT = 0x1;
58         /// Float values with width = 8.
59         const FLOAT64 = 0x2;
60     }
61 }
62 
bitflags::bitflags! {
    /// Bit set of shader stages (vertex, fragment, compute).
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    pub struct ShaderStages: u8 {
        /// The vertex stage.
        const VERTEX = 0x1;
        /// The fragment stage.
        const FRAGMENT = 0x2;
        /// The compute stage.
        const COMPUTE = 0x4;
    }
}
73 
/// Per-function analysis results collected by `Validator::validate` for a
/// whole module.
///
/// Index it with a `Handle<crate::Function>` to obtain the `FunctionInfo`
/// of a regular function.
#[derive(Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
    /// One entry per function, pushed in the iteration order of
    /// `module.functions` and looked up by `Handle::index()`.
    functions: Vec<FunctionInfo>,
    /// One entry per entry point, pushed in the order of
    /// `module.entry_points`.
    entry_points: Vec<FunctionInfo>,
}
81 
82 impl ops::Index<Handle<crate::Function>> for ModuleInfo {
83     type Output = FunctionInfo;
index(&self, handle: Handle<crate::Function>) -> &Self::Output84     fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
85         &self.functions[handle.index()]
86     }
87 }
88 
/// Module validator.
///
/// Created with `Validator::new` and driven through `Validator::validate`.
/// Besides its configuration it owns several collections that the
/// validation routines operate on.
#[derive(Debug)]
pub struct Validator {
    /// Validation stages to run (e.g. constants are only checked when
    /// `ValidationFlags::CONSTANTS` is set).
    flags: ValidationFlags,
    /// IR capabilities the validated module is allowed to use.
    capabilities: Capabilities,
    /// Per-type info, stored at the index of the type's handle by
    /// `validate`.
    types: Vec<r#type::TypeInfo>,
    // NOTE(review): the fields below are only initialised (empty) in this
    // file section; their users live in the sub-modules. The descriptions
    // are inferred from the names - confirm against those modules.
    /// Presumably tracks binding locations already seen.
    location_mask: BitSet,
    /// Presumably tracks, per bind group, the bindings already seen.
    bind_group_masks: Vec<BitSet>,
    /// Presumably the case values seen while checking a `switch`.
    select_cases: FastHashSet<i32>,
    /// Presumably the expressions currently considered valid, as a list.
    valid_expression_list: Vec<Handle<crate::Expression>>,
    /// Presumably the same expressions as a bit set for fast lookup.
    valid_expression_set: BitSet,
}
100 
/// Reasons a constant can fail validation (see
/// `Validator::validate_constant`).
#[derive(Clone, Debug, thiserror::Error)]
pub enum ConstantError {
    /// A scalar constant whose kind/width combination was rejected by the
    /// width check.
    #[error("The type doesn't match the constant")]
    InvalidType,
    /// A composite constant refers to a component that is not strictly
    /// earlier in the arena than the constant itself.
    #[error("The component handle {0:?} can not be resolved")]
    UnresolvedComponent(Handle<crate::Constant>),
    /// A composite constant of array type whose size constant is not
    /// strictly earlier in the arena than the constant itself.
    #[error("The array size handle {0:?} can not be resolved")]
    UnresolvedSize(Handle<crate::Constant>),
    /// The components do not compose into the declared type; forwarded
    /// from `compose::validate_compose`.
    #[error(transparent)]
    Compose(#[from] ComposeError),
}
112 
/// Top-level error returned by `Validator::validate`.
///
/// Each variant names the offending module item and carries the detailed
/// cause as its error source.
#[derive(Clone, Debug, thiserror::Error)]
pub enum ValidationError {
    /// A type failed validation.
    #[error("Type {handle:?} '{name}' is invalid")]
    Type {
        /// Handle of the offending type.
        handle: Handle<crate::Type>,
        /// Name of the type, or an empty string if it has none.
        name: String,
        /// Underlying cause.
        #[source]
        error: TypeError,
    },
    /// A constant failed validation.
    #[error("Constant {handle:?} '{name}' is invalid")]
    Constant {
        /// Handle of the offending constant.
        handle: Handle<crate::Constant>,
        /// Name of the constant, or an empty string if it has none.
        name: String,
        /// Underlying cause.
        #[source]
        error: ConstantError,
    },
    /// A global variable failed validation.
    #[error("Global variable {handle:?} '{name}' is invalid")]
    GlobalVariable {
        /// Handle of the offending global variable.
        handle: Handle<crate::GlobalVariable>,
        /// Name of the variable, or an empty string if it has none.
        name: String,
        /// Underlying cause.
        #[source]
        error: GlobalVariableError,
    },
    /// A function failed validation.
    #[error("Function {handle:?} '{name}' is invalid")]
    Function {
        /// Handle of the offending function.
        handle: Handle<crate::Function>,
        /// Name of the function, or an empty string if it has none.
        name: String,
        /// Underlying cause.
        #[source]
        error: FunctionError,
    },
    /// An entry point failed validation, or clashes with another entry
    /// point of the same stage and name.
    #[error("Entry point {name} at {stage:?} is invalid")]
    EntryPoint {
        /// Shader stage of the offending entry point.
        stage: crate::ShaderStage,
        /// Name of the entry point.
        name: String,
        /// Underlying cause.
        #[source]
        error: EntryPointError,
    },
    /// The module is corrupted.
    #[error("Module is corrupted")]
    Corrupted,
}
153 
154 impl crate::TypeInner {
is_sized(&self) -> bool155     fn is_sized(&self) -> bool {
156         match *self {
157             Self::Scalar { .. }
158             | Self::Vector { .. }
159             | Self::Matrix { .. }
160             | Self::Array {
161                 size: crate::ArraySize::Constant(_),
162                 ..
163             }
164             | Self::Pointer { .. }
165             | Self::ValuePointer { .. }
166             | Self::Struct { .. } => true,
167             Self::Array { .. } | Self::Image { .. } | Self::Sampler { .. } => false,
168         }
169     }
170 
image_storage_coordinates(&self) -> Option<crate::ImageDimension>171     fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
172         match *self {
173             Self::Scalar {
174                 kind: crate::ScalarKind::Sint,
175                 ..
176             } => Some(crate::ImageDimension::D1),
177             Self::Vector {
178                 size: crate::VectorSize::Bi,
179                 kind: crate::ScalarKind::Sint,
180                 ..
181             } => Some(crate::ImageDimension::D2),
182             Self::Vector {
183                 size: crate::VectorSize::Tri,
184                 kind: crate::ScalarKind::Sint,
185                 ..
186             } => Some(crate::ImageDimension::D3),
187             _ => None,
188         }
189     }
190 }
191 
192 impl Validator {
193     /// Construct a new validator instance.
new(flags: ValidationFlags, capabilities: Capabilities) -> Self194     pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
195         Validator {
196             flags,
197             capabilities,
198             types: Vec::new(),
199             location_mask: BitSet::new(),
200             bind_group_masks: Vec::new(),
201             select_cases: FastHashSet::default(),
202             valid_expression_list: Vec::new(),
203             valid_expression_set: BitSet::new(),
204         }
205     }
206 
validate_constant( &self, handle: Handle<crate::Constant>, constants: &Arena<crate::Constant>, types: &Arena<crate::Type>, ) -> Result<(), ConstantError>207     fn validate_constant(
208         &self,
209         handle: Handle<crate::Constant>,
210         constants: &Arena<crate::Constant>,
211         types: &Arena<crate::Type>,
212     ) -> Result<(), ConstantError> {
213         let con = &constants[handle];
214         match con.inner {
215             crate::ConstantInner::Scalar { width, ref value } => {
216                 if !self.check_width(value.scalar_kind(), width) {
217                     return Err(ConstantError::InvalidType);
218                 }
219             }
220             crate::ConstantInner::Composite { ty, ref components } => {
221                 match types[ty].inner {
222                     crate::TypeInner::Array {
223                         size: crate::ArraySize::Constant(size_handle),
224                         ..
225                     } if handle <= size_handle => {
226                         return Err(ConstantError::UnresolvedSize(size_handle));
227                     }
228                     _ => {}
229                 }
230                 if let Some(&comp) = components.iter().find(|&&comp| handle <= comp) {
231                     return Err(ConstantError::UnresolvedComponent(comp));
232                 }
233                 compose::validate_compose(
234                     ty,
235                     constants,
236                     types,
237                     components
238                         .iter()
239                         .map(|&component| constants[component].inner.resolve_type()),
240                 )?;
241             }
242         }
243         Ok(())
244     }
245 
246     /// Check the given module to be valid.
validate(&mut self, module: &crate::Module) -> Result<ModuleInfo, ValidationError>247     pub fn validate(&mut self, module: &crate::Module) -> Result<ModuleInfo, ValidationError> {
248         self.reset_types(module.types.len());
249 
250         if self.flags.contains(ValidationFlags::CONSTANTS) {
251             for (handle, constant) in module.constants.iter() {
252                 self.validate_constant(handle, &module.constants, &module.types)
253                     .map_err(|error| ValidationError::Constant {
254                         handle,
255                         name: constant.name.clone().unwrap_or_default(),
256                         error,
257                     })?;
258             }
259         }
260 
261         // doing after the globals, so that `type_flags` is ready
262         for (handle, ty) in module.types.iter() {
263             let ty_info = self
264                 .validate_type(handle, &module.types, &module.constants)
265                 .map_err(|error| ValidationError::Type {
266                     handle,
267                     name: ty.name.clone().unwrap_or_default(),
268                     error,
269                 })?;
270             self.types[handle.index()] = ty_info;
271         }
272 
273         for (var_handle, var) in module.global_variables.iter() {
274             self.validate_global_var(var, &module.types)
275                 .map_err(|error| ValidationError::GlobalVariable {
276                     handle: var_handle,
277                     name: var.name.clone().unwrap_or_default(),
278                     error,
279                 })?;
280         }
281 
282         let mut mod_info = ModuleInfo {
283             functions: Vec::with_capacity(module.functions.len()),
284             entry_points: Vec::with_capacity(module.entry_points.len()),
285         };
286 
287         for (handle, fun) in module.functions.iter() {
288             match self.validate_function(fun, module, &mod_info) {
289                 Ok(info) => mod_info.functions.push(info),
290                 Err(error) => {
291                     return Err(ValidationError::Function {
292                         handle,
293                         name: fun.name.clone().unwrap_or_default(),
294                         error,
295                     })
296                 }
297             }
298         }
299 
300         let mut ep_map = FastHashSet::default();
301         for ep in module.entry_points.iter() {
302             if !ep_map.insert((ep.stage, &ep.name)) {
303                 return Err(ValidationError::EntryPoint {
304                     stage: ep.stage,
305                     name: ep.name.clone(),
306                     error: EntryPointError::Conflict,
307                 });
308             }
309 
310             match self.validate_entry_point(ep, module, &mod_info) {
311                 Ok(info) => mod_info.entry_points.push(info),
312                 Err(error) => {
313                     return Err(ValidationError::EntryPoint {
314                         stage: ep.stage,
315                         name: ep.name.clone(),
316                         error,
317                     })
318                 }
319             }
320         }
321 
322         Ok(mod_info)
323     }
324 }
325