1 mod analyzer; 2 mod compose; 3 mod expression; 4 mod function; 5 mod interface; 6 mod r#type; 7 8 use crate::{ 9 arena::{Arena, Handle}, 10 FastHashSet, 11 }; 12 use bit_set::BitSet; 13 use std::ops; 14 15 //TODO: analyze the model at the same time as we validate it, 16 // merge the corresponding matches over expressions and statements. 17 18 pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements}; 19 pub use compose::ComposeError; 20 pub use expression::ExpressionError; 21 pub use function::{CallError, FunctionError, LocalVariableError}; 22 pub use interface::{EntryPointError, GlobalVariableError, VaryingError}; 23 pub use r#type::{Disalignment, TypeError, TypeFlags}; 24 25 bitflags::bitflags! { 26 /// Validation flags. 27 #[cfg_attr(feature = "serialize", derive(serde::Serialize))] 28 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] 29 pub struct ValidationFlags: u8 { 30 /// Expressions. 31 const EXPRESSIONS = 0x1; 32 /// Statements and blocks of them. 33 const BLOCKS = 0x2; 34 /// Uniformity of control flow for operations that require it. 35 const CONTROL_FLOW_UNIFORMITY = 0x4; 36 /// Host-shareable structure layouts. 37 const STRUCT_LAYOUTS = 0x8; 38 /// Constants. 39 const CONSTANTS = 0x10; 40 } 41 } 42 43 impl Default for ValidationFlags { default() -> Self44 fn default() -> Self { 45 Self::all() 46 } 47 } 48 49 #[must_use] 50 bitflags::bitflags! { 51 /// Allowed IR capabilities. 52 #[derive(Default)] 53 #[cfg_attr(feature = "serialize", derive(serde::Serialize))] 54 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] 55 pub struct Capabilities: u8 { 56 /// Support for `StorageClass:PushConstant`. 57 const PUSH_CONSTANT = 0x1; 58 /// Float values with width = 8. 59 const FLOAT64 = 0x2; 60 } 61 } 62 63 bitflags::bitflags! { 64 /// Validation flags. 65 #[cfg_attr(feature = "serialize", derive(serde::Serialize))] 66 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] 67 pub struct ShaderStages: u8 { 68 const VERTEX = 0x1; 69 const FRAGMENT = 0x2; 70 const COMPUTE = 0x4; 71 } 72 } 73 74 #[derive(Debug)] 75 #[cfg_attr(feature = "serialize", derive(serde::Serialize))] 76 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] 77 pub struct ModuleInfo { 78 functions: Vec<FunctionInfo>, 79 entry_points: Vec<FunctionInfo>, 80 } 81 82 impl ops::Index<Handle<crate::Function>> for ModuleInfo { 83 type Output = FunctionInfo; index(&self, handle: Handle<crate::Function>) -> &Self::Output84 fn index(&self, handle: Handle<crate::Function>) -> &Self::Output { 85 &self.functions[handle.index()] 86 } 87 } 88 89 #[derive(Debug)] 90 pub struct Validator { 91 flags: ValidationFlags, 92 capabilities: Capabilities, 93 types: Vec<r#type::TypeInfo>, 94 location_mask: BitSet, 95 bind_group_masks: Vec<BitSet>, 96 select_cases: FastHashSet<i32>, 97 valid_expression_list: Vec<Handle<crate::Expression>>, 98 valid_expression_set: BitSet, 99 } 100 101 #[derive(Clone, Debug, thiserror::Error)] 102 pub enum ConstantError { 103 #[error("The type doesn't match the constant")] 104 InvalidType, 105 #[error("The component handle {0:?} can not be resolved")] 106 UnresolvedComponent(Handle<crate::Constant>), 107 #[error("The array size handle {0:?} can not be resolved")] 108 UnresolvedSize(Handle<crate::Constant>), 109 #[error(transparent)] 110 Compose(#[from] ComposeError), 111 } 112 113 #[derive(Clone, Debug, thiserror::Error)] 114 pub enum ValidationError { 115 #[error("Type {handle:?} '{name}' is invalid")] 116 Type { 117 handle: Handle<crate::Type>, 118 name: String, 119 #[source] 120 error: TypeError, 121 }, 122 #[error("Constant {handle:?} '{name}' is invalid")] 123 Constant { 124 handle: Handle<crate::Constant>, 125 name: String, 126 #[source] 127 error: ConstantError, 128 }, 129 #[error("Global variable {handle:?} '{name}' is invalid")] 130 GlobalVariable { 131 handle: Handle<crate::GlobalVariable>, 132 name: String, 133 #[source] 134 error: GlobalVariableError, 135 }, 136 #[error("Function {handle:?} '{name}' is invalid")] 137 Function { 138 handle: Handle<crate::Function>, 139 name: String, 140 #[source] 141 error: FunctionError, 142 }, 143 #[error("Entry point {name} at {stage:?} is invalid")] 144 EntryPoint { 145 stage: crate::ShaderStage, 146 name: String, 147 #[source] 148 error: EntryPointError, 149 }, 150 #[error("Module is corrupted")] 151 Corrupted, 152 } 153 154 impl crate::TypeInner { is_sized(&self) -> bool155 fn is_sized(&self) -> bool { 156 match *self { 157 Self::Scalar { .. } 158 | Self::Vector { .. } 159 | Self::Matrix { .. } 160 | Self::Array { 161 size: crate::ArraySize::Constant(_), 162 .. 163 } 164 | Self::Pointer { .. } 165 | Self::ValuePointer { .. } 166 | Self::Struct { .. } => true, 167 Self::Array { .. } | Self::Image { .. } | Self::Sampler { .. } => false, 168 } 169 } 170 image_storage_coordinates(&self) -> Option<crate::ImageDimension>171 fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> { 172 match *self { 173 Self::Scalar { 174 kind: crate::ScalarKind::Sint, 175 .. 176 } => Some(crate::ImageDimension::D1), 177 Self::Vector { 178 size: crate::VectorSize::Bi, 179 kind: crate::ScalarKind::Sint, 180 .. 181 } => Some(crate::ImageDimension::D2), 182 Self::Vector { 183 size: crate::VectorSize::Tri, 184 kind: crate::ScalarKind::Sint, 185 .. 186 } => Some(crate::ImageDimension::D3), 187 _ => None, 188 } 189 } 190 } 191 192 impl Validator { 193 /// Construct a new validator instance. new(flags: ValidationFlags, capabilities: Capabilities) -> Self194 pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self { 195 Validator { 196 flags, 197 capabilities, 198 types: Vec::new(), 199 location_mask: BitSet::new(), 200 bind_group_masks: Vec::new(), 201 select_cases: FastHashSet::default(), 202 valid_expression_list: Vec::new(), 203 valid_expression_set: BitSet::new(), 204 } 205 } 206 validate_constant( &self, handle: Handle<crate::Constant>, constants: &Arena<crate::Constant>, types: &Arena<crate::Type>, ) -> Result<(), ConstantError>207 fn validate_constant( 208 &self, 209 handle: Handle<crate::Constant>, 210 constants: &Arena<crate::Constant>, 211 types: &Arena<crate::Type>, 212 ) -> Result<(), ConstantError> { 213 let con = &constants[handle]; 214 match con.inner { 215 crate::ConstantInner::Scalar { width, ref value } => { 216 if !self.check_width(value.scalar_kind(), width) { 217 return Err(ConstantError::InvalidType); 218 } 219 } 220 crate::ConstantInner::Composite { ty, ref components } => { 221 match types[ty].inner { 222 crate::TypeInner::Array { 223 size: crate::ArraySize::Constant(size_handle), 224 .. 225 } if handle <= size_handle => { 226 return Err(ConstantError::UnresolvedSize(size_handle)); 227 } 228 _ => {} 229 } 230 if let Some(&comp) = components.iter().find(|&&comp| handle <= comp) { 231 return Err(ConstantError::UnresolvedComponent(comp)); 232 } 233 compose::validate_compose( 234 ty, 235 constants, 236 types, 237 components 238 .iter() 239 .map(|&component| constants[component].inner.resolve_type()), 240 )?; 241 } 242 } 243 Ok(()) 244 } 245 246 /// Check the given module to be valid. validate(&mut self, module: &crate::Module) -> Result<ModuleInfo, ValidationError>247 pub fn validate(&mut self, module: &crate::Module) -> Result<ModuleInfo, ValidationError> { 248 self.reset_types(module.types.len()); 249 250 if self.flags.contains(ValidationFlags::CONSTANTS) { 251 for (handle, constant) in module.constants.iter() { 252 self.validate_constant(handle, &module.constants, &module.types) 253 .map_err(|error| ValidationError::Constant { 254 handle, 255 name: constant.name.clone().unwrap_or_default(), 256 error, 257 })?; 258 } 259 } 260 261 // doing after the globals, so that `type_flags` is ready 262 for (handle, ty) in module.types.iter() { 263 let ty_info = self 264 .validate_type(handle, &module.types, &module.constants) 265 .map_err(|error| ValidationError::Type { 266 handle, 267 name: ty.name.clone().unwrap_or_default(), 268 error, 269 })?; 270 self.types[handle.index()] = ty_info; 271 } 272 273 for (var_handle, var) in module.global_variables.iter() { 274 self.validate_global_var(var, &module.types) 275 .map_err(|error| ValidationError::GlobalVariable { 276 handle: var_handle, 277 name: var.name.clone().unwrap_or_default(), 278 error, 279 })?; 280 } 281 282 let mut mod_info = ModuleInfo { 283 functions: Vec::with_capacity(module.functions.len()), 284 entry_points: Vec::with_capacity(module.entry_points.len()), 285 }; 286 287 for (handle, fun) in module.functions.iter() { 288 match self.validate_function(fun, module, &mod_info) { 289 Ok(info) => mod_info.functions.push(info), 290 Err(error) => { 291 return Err(ValidationError::Function { 292 handle, 293 name: fun.name.clone().unwrap_or_default(), 294 error, 295 }) 296 } 297 } 298 } 299 300 let mut ep_map = FastHashSet::default(); 301 for ep in module.entry_points.iter() { 302 if !ep_map.insert((ep.stage, &ep.name)) { 303 return Err(ValidationError::EntryPoint { 304 stage: ep.stage, 305 name: ep.name.clone(), 306 error: EntryPointError::Conflict, 307 }); 308 } 309 310 match self.validate_entry_point(ep, module, &mod_info) { 311 Ok(info) => mod_info.entry_points.push(info), 312 Err(error) => { 313 return Err(ValidationError::EntryPoint { 314 stage: ep.stage, 315 name: ep.name.clone(), 316 error, 317 }) 318 } 319 } 320 } 321 322 Ok(mod_info) 323 } 324 } 325