1 mod analyzer; 2 mod compose; 3 mod expression; 4 mod function; 5 mod interface; 6 mod r#type; 7 8 #[cfg(feature = "validate")] 9 use crate::arena::{Arena, UniqueArena}; 10 11 use crate::{ 12 arena::Handle, 13 proc::{InvalidBaseType, Layouter}, 14 FastHashSet, 15 }; 16 use bit_set::BitSet; 17 use std::ops; 18 19 //TODO: analyze the model at the same time as we validate it, 20 // merge the corresponding matches over expressions and statements. 21 22 use crate::span::{AddSpan as _, WithSpan}; 23 pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements}; 24 pub use compose::ComposeError; 25 pub use expression::ExpressionError; 26 pub use function::{CallError, FunctionError, LocalVariableError}; 27 pub use interface::{EntryPointError, GlobalVariableError, VaryingError}; 28 pub use r#type::{Disalignment, TypeError, TypeFlags}; 29 30 bitflags::bitflags! { 31 /// Validation flags. 32 #[cfg_attr(feature = "serialize", derive(serde::Serialize))] 33 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] 34 pub struct ValidationFlags: u8 { 35 /// Expressions. 36 #[cfg(feature = "validate")] 37 const EXPRESSIONS = 0x1; 38 /// Statements and blocks of them. 39 #[cfg(feature = "validate")] 40 const BLOCKS = 0x2; 41 /// Uniformity of control flow for operations that require it. 42 #[cfg(feature = "validate")] 43 const CONTROL_FLOW_UNIFORMITY = 0x4; 44 /// Host-shareable structure layouts. 45 #[cfg(feature = "validate")] 46 const STRUCT_LAYOUTS = 0x8; 47 /// Constants. 48 #[cfg(feature = "validate")] 49 const CONSTANTS = 0x10; 50 } 51 } 52 53 impl Default for ValidationFlags { default() -> Self54 fn default() -> Self { 55 Self::all() 56 } 57 } 58 59 bitflags::bitflags! { 60 /// Allowed IR capabilities. 61 #[must_use] 62 #[derive(Default)] 63 #[cfg_attr(feature = "serialize", derive(serde::Serialize))] 64 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] 65 pub struct Capabilities: u8 { 66 /// Support for `StorageClass:PushConstant`. 67 const PUSH_CONSTANT = 0x1; 68 /// Float values with width = 8. 69 const FLOAT64 = 0x2; 70 /// Support for `Builtin:PrimitiveIndex`. 71 const PRIMITIVE_INDEX = 0x4; 72 } 73 } 74 75 bitflags::bitflags! { 76 /// Validation flags. 77 #[cfg_attr(feature = "serialize", derive(serde::Serialize))] 78 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] 79 pub struct ShaderStages: u8 { 80 const VERTEX = 0x1; 81 const FRAGMENT = 0x2; 82 const COMPUTE = 0x4; 83 } 84 } 85 86 #[derive(Debug)] 87 #[cfg_attr(feature = "serialize", derive(serde::Serialize))] 88 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] 89 pub struct ModuleInfo { 90 functions: Vec<FunctionInfo>, 91 entry_points: Vec<FunctionInfo>, 92 } 93 94 impl ops::Index<Handle<crate::Function>> for ModuleInfo { 95 type Output = FunctionInfo; index(&self, handle: Handle<crate::Function>) -> &Self::Output96 fn index(&self, handle: Handle<crate::Function>) -> &Self::Output { 97 &self.functions[handle.index()] 98 } 99 } 100 101 #[derive(Debug)] 102 pub struct Validator { 103 flags: ValidationFlags, 104 capabilities: Capabilities, 105 types: Vec<r#type::TypeInfo>, 106 layouter: Layouter, 107 location_mask: BitSet, 108 bind_group_masks: Vec<BitSet>, 109 #[allow(dead_code)] 110 select_cases: FastHashSet<i32>, 111 valid_expression_list: Vec<Handle<crate::Expression>>, 112 valid_expression_set: BitSet, 113 } 114 115 #[derive(Clone, Debug, thiserror::Error)] 116 pub enum ConstantError { 117 #[error("The type doesn't match the constant")] 118 InvalidType, 119 #[error("The component handle {0:?} can not be resolved")] 120 UnresolvedComponent(Handle<crate::Constant>), 121 #[error("The array size handle {0:?} can not be resolved")] 122 UnresolvedSize(Handle<crate::Constant>), 123 #[error(transparent)] 124 Compose(#[from] ComposeError), 125 } 126 127 #[derive(Clone, Debug, thiserror::Error)] 128 pub enum ValidationError { 129 #[error(transparent)] 130 Layouter(#[from] InvalidBaseType), 131 #[error("Type {handle:?} '{name}' is invalid")] 132 Type { 133 handle: Handle<crate::Type>, 134 name: String, 135 #[source] 136 error: TypeError, 137 }, 138 #[error("Constant {handle:?} '{name}' is invalid")] 139 Constant { 140 handle: Handle<crate::Constant>, 141 name: String, 142 #[source] 143 error: ConstantError, 144 }, 145 #[error("Global variable {handle:?} '{name}' is invalid")] 146 GlobalVariable { 147 handle: Handle<crate::GlobalVariable>, 148 name: String, 149 #[source] 150 error: GlobalVariableError, 151 }, 152 #[error("Function {handle:?} '{name}' is invalid")] 153 Function { 154 handle: Handle<crate::Function>, 155 name: String, 156 #[source] 157 error: FunctionError, 158 }, 159 #[error("Entry point {name} at {stage:?} is invalid")] 160 EntryPoint { 161 stage: crate::ShaderStage, 162 name: String, 163 #[source] 164 error: EntryPointError, 165 }, 166 #[error("Module is corrupted")] 167 Corrupted, 168 } 169 170 impl crate::TypeInner { 171 #[cfg(feature = "validate")] is_sized(&self) -> bool172 fn is_sized(&self) -> bool { 173 match *self { 174 Self::Scalar { .. } 175 | Self::Vector { .. } 176 | Self::Matrix { .. } 177 | Self::Array { 178 size: crate::ArraySize::Constant(_), 179 .. 180 } 181 | Self::Atomic { .. } 182 | Self::Pointer { .. } 183 | Self::ValuePointer { .. } 184 | Self::Struct { .. } => true, 185 Self::Array { .. } | Self::Image { .. } | Self::Sampler { .. } => false, 186 } 187 } 188 189 /// Return the `ImageDimension` for which `self` is an appropriate coordinate. 190 #[cfg(feature = "validate")] image_storage_coordinates(&self) -> Option<crate::ImageDimension>191 fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> { 192 match *self { 193 Self::Scalar { 194 kind: crate::ScalarKind::Sint, 195 .. 196 } => Some(crate::ImageDimension::D1), 197 Self::Vector { 198 size: crate::VectorSize::Bi, 199 kind: crate::ScalarKind::Sint, 200 .. 201 } => Some(crate::ImageDimension::D2), 202 Self::Vector { 203 size: crate::VectorSize::Tri, 204 kind: crate::ScalarKind::Sint, 205 .. 206 } => Some(crate::ImageDimension::D3), 207 _ => None, 208 } 209 } 210 } 211 212 impl Validator { 213 /// Construct a new validator instance. new(flags: ValidationFlags, capabilities: Capabilities) -> Self214 pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self { 215 Validator { 216 flags, 217 capabilities, 218 types: Vec::new(), 219 layouter: Layouter::default(), 220 location_mask: BitSet::new(), 221 bind_group_masks: Vec::new(), 222 select_cases: FastHashSet::default(), 223 valid_expression_list: Vec::new(), 224 valid_expression_set: BitSet::new(), 225 } 226 } 227 228 #[cfg(feature = "validate")] validate_constant( &self, handle: Handle<crate::Constant>, constants: &Arena<crate::Constant>, types: &UniqueArena<crate::Type>, ) -> Result<(), ConstantError>229 fn validate_constant( 230 &self, 231 handle: Handle<crate::Constant>, 232 constants: &Arena<crate::Constant>, 233 types: &UniqueArena<crate::Type>, 234 ) -> Result<(), ConstantError> { 235 let con = &constants[handle]; 236 match con.inner { 237 crate::ConstantInner::Scalar { width, ref value } => { 238 if !self.check_width(value.scalar_kind(), width) { 239 return Err(ConstantError::InvalidType); 240 } 241 } 242 crate::ConstantInner::Composite { ty, ref components } => { 243 match types[ty].inner { 244 crate::TypeInner::Array { 245 size: crate::ArraySize::Constant(size_handle), 246 .. 247 } if handle <= size_handle => { 248 return Err(ConstantError::UnresolvedSize(size_handle)); 249 } 250 _ => {} 251 } 252 if let Some(&comp) = components.iter().find(|&&comp| handle <= comp) { 253 return Err(ConstantError::UnresolvedComponent(comp)); 254 } 255 compose::validate_compose( 256 ty, 257 constants, 258 types, 259 components 260 .iter() 261 .map(|&component| constants[component].inner.resolve_type()), 262 )?; 263 } 264 } 265 Ok(()) 266 } 267 268 /// Check the given module to be valid. validate( &mut self, module: &crate::Module, ) -> Result<ModuleInfo, WithSpan<ValidationError>>269 pub fn validate( 270 &mut self, 271 module: &crate::Module, 272 ) -> Result<ModuleInfo, WithSpan<ValidationError>> { 273 self.reset_types(module.types.len()); 274 self.layouter 275 .update(&module.types, &module.constants) 276 .map_err(|e| { 277 let InvalidBaseType(handle) = e; 278 ValidationError::from(e).with_span_handle(handle, &module.types) 279 })?; 280 281 #[cfg(feature = "validate")] 282 if self.flags.contains(ValidationFlags::CONSTANTS) { 283 for (handle, constant) in module.constants.iter() { 284 self.validate_constant(handle, &module.constants, &module.types) 285 .map_err(|error| { 286 ValidationError::Constant { 287 handle, 288 name: constant.name.clone().unwrap_or_default(), 289 error, 290 } 291 .with_span_handle(handle, &module.constants) 292 })? 293 } 294 } 295 296 for (handle, ty) in module.types.iter() { 297 let ty_info = self 298 .validate_type(handle, &module.types, &module.constants) 299 .map_err(|error| { 300 ValidationError::Type { 301 handle, 302 name: ty.name.clone().unwrap_or_default(), 303 error, 304 } 305 .with_span_handle(handle, &module.types) 306 })?; 307 self.types[handle.index()] = ty_info; 308 } 309 310 #[cfg(feature = "validate")] 311 for (var_handle, var) in module.global_variables.iter() { 312 self.validate_global_var(var, &module.types) 313 .map_err(|error| { 314 ValidationError::GlobalVariable { 315 handle: var_handle, 316 name: var.name.clone().unwrap_or_default(), 317 error, 318 } 319 .with_span_handle(var_handle, &module.global_variables) 320 })?; 321 } 322 323 let mut mod_info = ModuleInfo { 324 functions: Vec::with_capacity(module.functions.len()), 325 entry_points: Vec::with_capacity(module.entry_points.len()), 326 }; 327 328 for (handle, fun) in module.functions.iter() { 329 match self.validate_function(fun, module, &mod_info) { 330 Ok(info) => mod_info.functions.push(info), 331 Err(error) => { 332 return Err(error.and_then(|error| { 333 ValidationError::Function { 334 handle, 335 name: fun.name.clone().unwrap_or_default(), 336 error, 337 } 338 .with_span_handle(handle, &module.functions) 339 })) 340 } 341 } 342 } 343 344 let mut ep_map = FastHashSet::default(); 345 for ep in module.entry_points.iter() { 346 if !ep_map.insert((ep.stage, &ep.name)) { 347 return Err(ValidationError::EntryPoint { 348 stage: ep.stage, 349 name: ep.name.clone(), 350 error: EntryPointError::Conflict, 351 } 352 .with_span()); // TODO: keep some EP span information? 353 } 354 355 match self.validate_entry_point(ep, module, &mod_info) { 356 Ok(info) => mod_info.entry_points.push(info), 357 Err(error) => { 358 return Err(error.and_then(|inner| { 359 ValidationError::EntryPoint { 360 stage: ep.stage, 361 name: ep.name.clone(), 362 error: inner, 363 } 364 .with_span() 365 })) 366 } 367 } 368 } 369 370 Ok(mod_info) 371 } 372 } 373