1 use crate::{ 2 arena::{Arena, Handle}, 3 BinaryOperator, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, TypeInner, 4 UnaryOperator, 5 }; 6 7 #[derive(Debug)] 8 pub struct ConstantSolver<'a> { 9 pub types: &'a Arena<Type>, 10 pub expressions: &'a Arena<Expression>, 11 pub constants: &'a mut Arena<Constant>, 12 } 13 14 #[derive(Clone, Debug, PartialEq, thiserror::Error)] 15 pub enum ConstantSolvingError { 16 #[error("Constants cannot access function arguments")] 17 FunctionArg, 18 #[error("Constants cannot access global variables")] 19 GlobalVariable, 20 #[error("Constants cannot access local variables")] 21 LocalVariable, 22 #[error("Cannot get the array length of a non array type")] 23 InvalidArrayLengthArg, 24 #[error("Constants cannot get the array length of a dynamically sized array")] 25 ArrayLengthDynamic, 26 #[error("Constants cannot call functions")] 27 Call, 28 #[error("Constants don't support relational functions")] 29 Relational, 30 #[error("Constants don't support derivative functions")] 31 Derivative, 32 #[error("Constants don't support select expressions")] 33 Select, 34 #[error("Constants don't support load expressions")] 35 Load, 36 #[error("Constants don't support image expressions")] 37 ImageExpression, 38 #[error("Cannot access the type")] 39 InvalidAccessBase, 40 #[error("Cannot access at the index")] 41 InvalidAccessIndex, 42 #[error("Cannot access with index of type")] 43 InvalidAccessIndexTy, 44 #[error("Constants don't support bitcasts")] 45 Bitcast, 46 #[error("Cannot cast type")] 47 InvalidCastArg, 48 #[error("Cannot apply the unary op to the argument")] 49 InvalidUnaryOpArg, 50 #[error("Cannot apply the binary op to the arguments")] 51 InvalidBinaryOpArgs, 52 #[error("Splat/swizzle type is not registered")] 53 DestinationTypeNotFound, 54 } 55 56 impl<'a> ConstantSolver<'a> { solve( &mut self, expr: Handle<Expression>, ) -> Result<Handle<Constant>, ConstantSolvingError>57 pub fn solve( 58 &mut self, 59 expr: Handle<Expression>, 60 ) -> Result<Handle<Constant>, ConstantSolvingError> { 61 match self.expressions[expr] { 62 Expression::Constant(constant) => Ok(constant), 63 Expression::AccessIndex { base, index } => self.access(base, index as usize), 64 Expression::Access { base, index } => { 65 let index = self.solve(index)?; 66 67 self.access(base, self.constant_index(index)?) 68 } 69 Expression::Splat { 70 size, 71 value: splat_value, 72 } => { 73 let value_constant = self.solve(splat_value)?; 74 let ty = match self.constants[value_constant].inner { 75 ConstantInner::Scalar { ref value, width } => { 76 let kind = value.scalar_kind(); 77 self.types 78 .fetch_if(|t| t.inner == crate::TypeInner::Vector { size, kind, width }) 79 } 80 ConstantInner::Composite { .. } => None, 81 }; 82 83 //TODO: register the new type if needed 84 let ty = ty.ok_or(ConstantSolvingError::DestinationTypeNotFound)?; 85 Ok(self.constants.fetch_or_append(Constant { 86 name: None, 87 specialization: None, 88 inner: ConstantInner::Composite { 89 ty, 90 components: vec![value_constant; size as usize], 91 }, 92 })) 93 } 94 Expression::Swizzle { 95 size, 96 vector: src_vector, 97 pattern, 98 } => { 99 let src_constant = self.solve(src_vector)?; 100 let (ty, src_components) = match self.constants[src_constant].inner { 101 ConstantInner::Scalar { .. } => (None, &[][..]), 102 ConstantInner::Composite { 103 ty, 104 components: ref src_components, 105 } => match self.types[ty].inner { 106 crate::TypeInner::Vector { 107 size: _, 108 kind, 109 width, 110 } => { 111 let dst_ty = self.types.fetch_if(|t| { 112 t.inner == crate::TypeInner::Vector { size, kind, width } 113 }); 114 (dst_ty, &src_components[..]) 115 } 116 _ => (None, &[][..]), 117 }, 118 }; 119 //TODO: register the new type if needed 120 let ty = ty.ok_or(ConstantSolvingError::DestinationTypeNotFound)?; 121 let components = pattern 122 .iter() 123 .map(|&sc| src_components[sc as usize]) 124 .collect(); 125 126 Ok(self.constants.fetch_or_append(Constant { 127 name: None, 128 specialization: None, 129 inner: ConstantInner::Composite { ty, components }, 130 })) 131 } 132 Expression::Compose { ty, ref components } => { 133 let components = components 134 .iter() 135 .map(|c| self.solve(*c)) 136 .collect::<Result<_, _>>()?; 137 138 Ok(self.constants.fetch_or_append(Constant { 139 name: None, 140 specialization: None, 141 inner: ConstantInner::Composite { ty, components }, 142 })) 143 } 144 Expression::Unary { expr, op } => { 145 let expr_constant = self.solve(expr)?; 146 147 self.unary_op(op, expr_constant) 148 } 149 Expression::Binary { left, right, op } => { 150 let left_constant = self.solve(left)?; 151 let right_constant = self.solve(right)?; 152 153 self.binary_op(op, left_constant, right_constant) 154 } 155 Expression::Math { .. } => todo!(), 156 Expression::As { 157 convert, 158 expr, 159 kind, 160 } => { 161 let expr_constant = self.solve(expr)?; 162 163 match convert { 164 Some(width) => self.cast(expr_constant, kind, width), 165 None => Err(ConstantSolvingError::Bitcast), 166 } 167 } 168 Expression::ArrayLength(expr) => { 169 let array = self.solve(expr)?; 170 171 match self.constants[array].inner { 172 ConstantInner::Scalar { .. } => { 173 Err(ConstantSolvingError::InvalidArrayLengthArg) 174 } 175 ConstantInner::Composite { ty, .. } => match self.types[ty].inner { 176 TypeInner::Array { size, .. } => match size { 177 crate::ArraySize::Constant(constant) => Ok(constant), 178 crate::ArraySize::Dynamic => { 179 Err(ConstantSolvingError::ArrayLengthDynamic) 180 } 181 }, 182 _ => Err(ConstantSolvingError::InvalidArrayLengthArg), 183 }, 184 } 185 } 186 187 Expression::Load { .. } => Err(ConstantSolvingError::Load), 188 Expression::Select { .. } => Err(ConstantSolvingError::Select), 189 Expression::LocalVariable(_) => Err(ConstantSolvingError::LocalVariable), 190 Expression::Derivative { .. } => Err(ConstantSolvingError::Derivative), 191 Expression::Relational { .. } => Err(ConstantSolvingError::Relational), 192 Expression::Call { .. } => Err(ConstantSolvingError::Call), 193 Expression::FunctionArgument(_) => Err(ConstantSolvingError::FunctionArg), 194 Expression::GlobalVariable(_) => Err(ConstantSolvingError::GlobalVariable), 195 Expression::ImageSample { .. } 196 | Expression::ImageLoad { .. } 197 | Expression::ImageQuery { .. } => Err(ConstantSolvingError::ImageExpression), 198 } 199 } 200 access( &mut self, base: Handle<Expression>, index: usize, ) -> Result<Handle<Constant>, ConstantSolvingError>201 fn access( 202 &mut self, 203 base: Handle<Expression>, 204 index: usize, 205 ) -> Result<Handle<Constant>, ConstantSolvingError> { 206 let base = self.solve(base)?; 207 208 match self.constants[base].inner { 209 ConstantInner::Scalar { .. } => Err(ConstantSolvingError::InvalidAccessBase), 210 ConstantInner::Composite { ty, ref components } => { 211 match self.types[ty].inner { 212 TypeInner::Vector { .. } 213 | TypeInner::Matrix { .. } 214 | TypeInner::Array { .. } 215 | TypeInner::Struct { .. } => (), 216 _ => return Err(ConstantSolvingError::InvalidAccessBase), 217 } 218 219 components 220 .get(index) 221 .copied() 222 .ok_or(ConstantSolvingError::InvalidAccessIndex) 223 } 224 } 225 } 226 constant_index(&self, constant: Handle<Constant>) -> Result<usize, ConstantSolvingError>227 fn constant_index(&self, constant: Handle<Constant>) -> Result<usize, ConstantSolvingError> { 228 match self.constants[constant].inner { 229 ConstantInner::Scalar { 230 value: ScalarValue::Uint(index), 231 .. 232 } => Ok(index as usize), 233 _ => Err(ConstantSolvingError::InvalidAccessIndexTy), 234 } 235 } 236 cast( &mut self, constant: Handle<Constant>, kind: ScalarKind, target_width: crate::Bytes, ) -> Result<Handle<Constant>, ConstantSolvingError>237 fn cast( 238 &mut self, 239 constant: Handle<Constant>, 240 kind: ScalarKind, 241 target_width: crate::Bytes, 242 ) -> Result<Handle<Constant>, ConstantSolvingError> { 243 fn inner_cast<A: num_traits::FromPrimitive>(value: ScalarValue) -> A { 244 match value { 245 ScalarValue::Sint(v) => A::from_i64(v), 246 ScalarValue::Uint(v) => A::from_u64(v), 247 ScalarValue::Float(v) => A::from_f64(v), 248 ScalarValue::Bool(v) => A::from_u64(v as u64), 249 } 250 .unwrap() 251 } 252 253 let mut inner = self.constants[constant].inner.clone(); 254 255 match inner { 256 ConstantInner::Scalar { 257 ref mut value, 258 ref mut width, 259 } => { 260 *width = target_width; 261 *value = match kind { 262 ScalarKind::Sint => ScalarValue::Sint(inner_cast(*value)), 263 ScalarKind::Uint => ScalarValue::Uint(inner_cast(*value)), 264 ScalarKind::Float => ScalarValue::Float(inner_cast(*value)), 265 ScalarKind::Bool => ScalarValue::Bool(inner_cast::<u64>(*value) != 0), 266 } 267 } 268 ConstantInner::Composite { 269 ty, 270 ref mut components, 271 } => { 272 match self.types[ty].inner { 273 TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (), 274 _ => return Err(ConstantSolvingError::InvalidCastArg), 275 } 276 277 for component in components { 278 *component = self.cast(*component, kind, target_width)?; 279 } 280 } 281 } 282 283 Ok(self.constants.fetch_or_append(Constant { 284 name: None, 285 specialization: None, 286 inner, 287 })) 288 } 289 unary_op( &mut self, op: UnaryOperator, constant: Handle<Constant>, ) -> Result<Handle<Constant>, ConstantSolvingError>290 fn unary_op( 291 &mut self, 292 op: UnaryOperator, 293 constant: Handle<Constant>, 294 ) -> Result<Handle<Constant>, ConstantSolvingError> { 295 let mut inner = self.constants[constant].inner.clone(); 296 297 match inner { 298 ConstantInner::Scalar { ref mut value, .. } => match op { 299 UnaryOperator::Negate => match *value { 300 ScalarValue::Sint(ref mut v) => *v = -*v, 301 ScalarValue::Float(ref mut v) => *v = -*v, 302 _ => return Err(ConstantSolvingError::InvalidUnaryOpArg), 303 }, 304 UnaryOperator::Not => match *value { 305 ScalarValue::Sint(ref mut v) => *v = !*v, 306 ScalarValue::Uint(ref mut v) => *v = !*v, 307 ScalarValue::Bool(ref mut v) => *v = !*v, 308 _ => return Err(ConstantSolvingError::InvalidUnaryOpArg), 309 }, 310 }, 311 ConstantInner::Composite { 312 ty, 313 ref mut components, 314 } => { 315 match self.types[ty].inner { 316 TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (), 317 _ => return Err(ConstantSolvingError::InvalidCastArg), 318 } 319 320 for component in components { 321 *component = self.unary_op(op, *component)? 322 } 323 } 324 } 325 326 Ok(self.constants.fetch_or_append(Constant { 327 name: None, 328 specialization: None, 329 inner, 330 })) 331 } 332 binary_op( &mut self, op: BinaryOperator, left: Handle<Constant>, right: Handle<Constant>, ) -> Result<Handle<Constant>, ConstantSolvingError>333 fn binary_op( 334 &mut self, 335 op: BinaryOperator, 336 left: Handle<Constant>, 337 right: Handle<Constant>, 338 ) -> Result<Handle<Constant>, ConstantSolvingError> { 339 let left = &self.constants[left].inner; 340 let right = &self.constants[right].inner; 341 342 let inner = match (left, right) { 343 ( 344 &ConstantInner::Scalar { 345 value: left_value, 346 width, 347 }, 348 &ConstantInner::Scalar { 349 value: right_value, 350 width: _, 351 }, 352 ) => { 353 let value = match op { 354 BinaryOperator::Equal => ScalarValue::Bool(left_value == right_value), 355 BinaryOperator::NotEqual => ScalarValue::Bool(left_value != right_value), 356 BinaryOperator::Less => ScalarValue::Bool(left_value < right_value), 357 BinaryOperator::LessEqual => ScalarValue::Bool(left_value <= right_value), 358 BinaryOperator::Greater => ScalarValue::Bool(left_value > right_value), 359 BinaryOperator::GreaterEqual => ScalarValue::Bool(left_value >= right_value), 360 361 _ => match (left_value, right_value) { 362 (ScalarValue::Sint(a), ScalarValue::Sint(b)) => { 363 ScalarValue::Sint(match op { 364 BinaryOperator::Add => a + b, 365 BinaryOperator::Subtract => a - b, 366 BinaryOperator::Multiply => a * b, 367 BinaryOperator::Divide => a / b, 368 BinaryOperator::Modulo => a % b, 369 BinaryOperator::And => a & b, 370 BinaryOperator::ExclusiveOr => a ^ b, 371 BinaryOperator::InclusiveOr => a | b, 372 BinaryOperator::ShiftLeft => a << b, 373 BinaryOperator::ShiftRight => a >> b, 374 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs), 375 }) 376 } 377 (ScalarValue::Uint(a), ScalarValue::Uint(b)) => { 378 ScalarValue::Uint(match op { 379 BinaryOperator::Add => a + b, 380 BinaryOperator::Subtract => a - b, 381 BinaryOperator::Multiply => a * b, 382 BinaryOperator::Divide => a / b, 383 BinaryOperator::Modulo => a % b, 384 BinaryOperator::And => a & b, 385 BinaryOperator::ExclusiveOr => a ^ b, 386 BinaryOperator::InclusiveOr => a | b, 387 BinaryOperator::ShiftLeft => a << b, 388 BinaryOperator::ShiftRight => a >> b, 389 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs), 390 }) 391 } 392 (ScalarValue::Float(a), ScalarValue::Float(b)) => { 393 ScalarValue::Float(match op { 394 BinaryOperator::Add => a + b, 395 BinaryOperator::Subtract => a - b, 396 BinaryOperator::Multiply => a * b, 397 BinaryOperator::Divide => a / b, 398 BinaryOperator::Modulo => a % b, 399 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs), 400 }) 401 } 402 (ScalarValue::Bool(a), ScalarValue::Bool(b)) => { 403 ScalarValue::Bool(match op { 404 BinaryOperator::LogicalAnd => a && b, 405 BinaryOperator::LogicalOr => a || b, 406 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs), 407 }) 408 } 409 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs), 410 }, 411 }; 412 413 ConstantInner::Scalar { value, width } 414 } 415 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs), 416 }; 417 418 Ok(self.constants.fetch_or_append(Constant { 419 name: None, 420 specialization: None, 421 inner, 422 })) 423 } 424 } 425 426 #[cfg(test)] 427 mod tests { 428 use std::vec; 429 430 use crate::{ 431 Arena, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, TypeInner, 432 UnaryOperator, VectorSize, 433 }; 434 435 use super::ConstantSolver; 436 437 #[test] unary_op()438 fn unary_op() { 439 let mut types = Arena::new(); 440 let mut expressions = Arena::new(); 441 let mut constants = Arena::new(); 442 443 let vec_ty = types.append(Type { 444 name: None, 445 inner: TypeInner::Vector { 446 size: VectorSize::Bi, 447 kind: ScalarKind::Sint, 448 width: 4, 449 }, 450 }); 451 452 let h = constants.append(Constant { 453 name: None, 454 specialization: None, 455 inner: ConstantInner::Scalar { 456 width: 4, 457 value: ScalarValue::Sint(4), 458 }, 459 }); 460 461 let h1 = constants.append(Constant { 462 name: None, 463 specialization: None, 464 inner: ConstantInner::Scalar { 465 width: 4, 466 value: ScalarValue::Sint(8), 467 }, 468 }); 469 470 let vec_h = constants.append(Constant { 471 name: None, 472 specialization: None, 473 inner: ConstantInner::Composite { 474 ty: vec_ty, 475 components: vec![h, h1], 476 }, 477 }); 478 479 let expr = expressions.append(Expression::Constant(h)); 480 let expr1 = expressions.append(Expression::Constant(vec_h)); 481 482 let root1 = expressions.append(Expression::Unary { 483 op: UnaryOperator::Negate, 484 expr, 485 }); 486 487 let root2 = expressions.append(Expression::Unary { 488 op: UnaryOperator::Not, 489 expr, 490 }); 491 492 let root3 = expressions.append(Expression::Unary { 493 op: UnaryOperator::Not, 494 expr: expr1, 495 }); 496 497 let mut solver = ConstantSolver { 498 types: &types, 499 expressions: &expressions, 500 constants: &mut constants, 501 }; 502 503 let res1 = solver.solve(root1).unwrap(); 504 let res2 = solver.solve(root2).unwrap(); 505 let res3 = solver.solve(root3).unwrap(); 506 507 assert_eq!( 508 constants[res1].inner, 509 ConstantInner::Scalar { 510 width: 4, 511 value: ScalarValue::Sint(-4), 512 }, 513 ); 514 515 assert_eq!( 516 constants[res2].inner, 517 ConstantInner::Scalar { 518 width: 4, 519 value: ScalarValue::Sint(!4), 520 }, 521 ); 522 523 let res3_inner = &constants[res3].inner; 524 525 match res3_inner { 526 ConstantInner::Composite { ty, components } => { 527 assert_eq!(*ty, vec_ty); 528 let mut components_iter = components.iter().copied(); 529 assert_eq!( 530 constants[components_iter.next().unwrap()].inner, 531 ConstantInner::Scalar { 532 width: 4, 533 value: ScalarValue::Sint(!4), 534 }, 535 ); 536 assert_eq!( 537 constants[components_iter.next().unwrap()].inner, 538 ConstantInner::Scalar { 539 width: 4, 540 value: ScalarValue::Sint(!8), 541 }, 542 ); 543 assert!(components_iter.next().is_none()); 544 } 545 _ => panic!("Expected vector"), 546 } 547 } 548 549 #[test] cast()550 fn cast() { 551 let mut expressions = Arena::new(); 552 let mut constants = Arena::new(); 553 554 let h = constants.append(Constant { 555 name: None, 556 specialization: None, 557 inner: ConstantInner::Scalar { 558 width: 4, 559 value: ScalarValue::Sint(4), 560 }, 561 }); 562 563 let expr = expressions.append(Expression::Constant(h)); 564 565 let root = expressions.append(Expression::As { 566 expr, 567 kind: ScalarKind::Bool, 568 convert: Some(crate::BOOL_WIDTH), 569 }); 570 571 let mut solver = ConstantSolver { 572 types: &Arena::new(), 573 expressions: &expressions, 574 constants: &mut constants, 575 }; 576 577 let res = solver.solve(root).unwrap(); 578 579 assert_eq!( 580 constants[res].inner, 581 ConstantInner::Scalar { 582 width: crate::BOOL_WIDTH, 583 value: ScalarValue::Bool(true), 584 }, 585 ); 586 } 587 588 #[test] access()589 fn access() { 590 let mut types = Arena::new(); 591 let mut expressions = Arena::new(); 592 let mut constants = Arena::new(); 593 594 let matrix_ty = types.append(Type { 595 name: None, 596 inner: TypeInner::Matrix { 597 columns: VectorSize::Bi, 598 rows: VectorSize::Tri, 599 width: 4, 600 }, 601 }); 602 603 let vec_ty = types.append(Type { 604 name: None, 605 inner: TypeInner::Vector { 606 size: VectorSize::Tri, 607 kind: ScalarKind::Float, 608 width: 4, 609 }, 610 }); 611 612 let mut vec1_components = Vec::with_capacity(3); 613 let mut vec2_components = Vec::with_capacity(3); 614 615 for i in 0..3 { 616 let h = constants.append(Constant { 617 name: None, 618 specialization: None, 619 inner: ConstantInner::Scalar { 620 width: 4, 621 value: ScalarValue::Float(i as f64), 622 }, 623 }); 624 625 vec1_components.push(h) 626 } 627 628 for i in 3..6 { 629 let h = constants.append(Constant { 630 name: None, 631 specialization: None, 632 inner: ConstantInner::Scalar { 633 width: 4, 634 value: ScalarValue::Float(i as f64), 635 }, 636 }); 637 638 vec2_components.push(h) 639 } 640 641 let vec1 = constants.append(Constant { 642 name: None, 643 specialization: None, 644 inner: ConstantInner::Composite { 645 ty: vec_ty, 646 components: vec1_components, 647 }, 648 }); 649 650 let vec2 = constants.append(Constant { 651 name: None, 652 specialization: None, 653 inner: ConstantInner::Composite { 654 ty: vec_ty, 655 components: vec2_components, 656 }, 657 }); 658 659 let h = constants.append(Constant { 660 name: None, 661 specialization: None, 662 inner: ConstantInner::Composite { 663 ty: matrix_ty, 664 components: vec![vec1, vec2], 665 }, 666 }); 667 668 let base = expressions.append(Expression::Constant(h)); 669 let root1 = expressions.append(Expression::AccessIndex { base, index: 1 }); 670 let root2 = expressions.append(Expression::AccessIndex { 671 base: root1, 672 index: 2, 673 }); 674 675 let mut solver = ConstantSolver { 676 types: &types, 677 expressions: &expressions, 678 constants: &mut constants, 679 }; 680 681 let res1 = solver.solve(root1).unwrap(); 682 let res2 = solver.solve(root2).unwrap(); 683 684 let res1_inner = &constants[res1].inner; 685 686 match res1_inner { 687 ConstantInner::Composite { ty, components } => { 688 assert_eq!(*ty, vec_ty); 689 let mut components_iter = components.iter().copied(); 690 assert_eq!( 691 constants[components_iter.next().unwrap()].inner, 692 ConstantInner::Scalar { 693 width: 4, 694 value: ScalarValue::Float(3.), 695 }, 696 ); 697 assert_eq!( 698 constants[components_iter.next().unwrap()].inner, 699 ConstantInner::Scalar { 700 width: 4, 701 value: ScalarValue::Float(4.), 702 }, 703 ); 704 assert_eq!( 705 constants[components_iter.next().unwrap()].inner, 706 ConstantInner::Scalar { 707 width: 4, 708 value: ScalarValue::Float(5.), 709 }, 710 ); 711 assert!(components_iter.next().is_none()); 712 } 713 _ => panic!("Expected vector"), 714 } 715 716 assert_eq!( 717 constants[res2].inner, 718 ConstantInner::Scalar { 719 width: 4, 720 value: ScalarValue::Float(5.), 721 }, 722 ); 723 } 724 } 725