1 use crate::{ 2 arena::{Arena, Handle, UniqueArena}, 3 BinaryOperator, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, TypeInner, 4 UnaryOperator, 5 }; 6 7 #[derive(Debug)] 8 pub struct ConstantSolver<'a> { 9 pub types: &'a mut UniqueArena<Type>, 10 pub expressions: &'a Arena<Expression>, 11 pub constants: &'a mut Arena<Constant>, 12 } 13 14 #[derive(Clone, Debug, PartialEq, thiserror::Error)] 15 pub enum ConstantSolvingError { 16 #[error("Constants cannot access function arguments")] 17 FunctionArg, 18 #[error("Constants cannot access global variables")] 19 GlobalVariable, 20 #[error("Constants cannot access local variables")] 21 LocalVariable, 22 #[error("Cannot get the array length of a non array type")] 23 InvalidArrayLengthArg, 24 #[error("Constants cannot get the array length of a dynamically sized array")] 25 ArrayLengthDynamic, 26 #[error("Constants cannot call functions")] 27 Call, 28 #[error("Constants don't support atomic functions")] 29 Atomic, 30 #[error("Constants don't support relational functions")] 31 Relational, 32 #[error("Constants don't support derivative functions")] 33 Derivative, 34 #[error("Constants don't support select expressions")] 35 Select, 36 #[error("Constants don't support load expressions")] 37 Load, 38 #[error("Constants don't support image expressions")] 39 ImageExpression, 40 #[error("Cannot access the type")] 41 InvalidAccessBase, 42 #[error("Cannot access at the index")] 43 InvalidAccessIndex, 44 #[error("Cannot access with index of type")] 45 InvalidAccessIndexTy, 46 #[error("Constants don't support bitcasts")] 47 Bitcast, 48 #[error("Cannot cast type")] 49 InvalidCastArg, 50 #[error("Cannot apply the unary op to the argument")] 51 InvalidUnaryOpArg, 52 #[error("Cannot apply the binary op to the arguments")] 53 InvalidBinaryOpArgs, 54 #[error("Cannot apply math function to type")] 55 InvalidMathArg, 56 #[error("Splat is defined only on scalar values")] 57 SplatScalarOnly, 58 #[error("Can only swizzle vector constants")] 59 SwizzleVectorOnly, 60 #[error("Not implemented: {0}")] 61 NotImplemented(String), 62 } 63 64 impl<'a> ConstantSolver<'a> { solve( &mut self, expr: Handle<Expression>, ) -> Result<Handle<Constant>, ConstantSolvingError>65 pub fn solve( 66 &mut self, 67 expr: Handle<Expression>, 68 ) -> Result<Handle<Constant>, ConstantSolvingError> { 69 let span = self.expressions.get_span(expr); 70 match self.expressions[expr] { 71 Expression::Constant(constant) => Ok(constant), 72 Expression::AccessIndex { base, index } => self.access(base, index as usize), 73 Expression::Access { base, index } => { 74 let index = self.solve(index)?; 75 76 self.access(base, self.constant_index(index)?) 77 } 78 Expression::Splat { 79 size, 80 value: splat_value, 81 } => { 82 let value_constant = self.solve(splat_value)?; 83 let ty = match self.constants[value_constant].inner { 84 ConstantInner::Scalar { ref value, width } => { 85 let kind = value.scalar_kind(); 86 self.types.insert( 87 Type { 88 name: None, 89 inner: TypeInner::Vector { size, kind, width }, 90 }, 91 span, 92 ) 93 } 94 ConstantInner::Composite { .. } => { 95 return Err(ConstantSolvingError::SplatScalarOnly); 96 } 97 }; 98 99 let inner = ConstantInner::Composite { 100 ty, 101 components: vec![value_constant; size as usize], 102 }; 103 Ok(self.register_constant(inner, span)) 104 } 105 Expression::Swizzle { 106 size, 107 vector: src_vector, 108 pattern, 109 } => { 110 let src_constant = self.solve(src_vector)?; 111 let (ty, src_components) = match self.constants[src_constant].inner { 112 ConstantInner::Scalar { .. } => { 113 return Err(ConstantSolvingError::SwizzleVectorOnly); 114 } 115 ConstantInner::Composite { 116 ty, 117 components: ref src_components, 118 } => match self.types[ty].inner { 119 crate::TypeInner::Vector { 120 size: _, 121 kind, 122 width, 123 } => { 124 let dst_ty = self.types.insert( 125 Type { 126 name: None, 127 inner: crate::TypeInner::Vector { size, kind, width }, 128 }, 129 span, 130 ); 131 (dst_ty, &src_components[..]) 132 } 133 _ => { 134 return Err(ConstantSolvingError::SwizzleVectorOnly); 135 } 136 }, 137 }; 138 139 let components = pattern 140 .iter() 141 .map(|&sc| src_components[sc as usize]) 142 .collect(); 143 let inner = ConstantInner::Composite { ty, components }; 144 145 Ok(self.register_constant(inner, span)) 146 } 147 Expression::Compose { ty, ref components } => { 148 let components = components 149 .iter() 150 .map(|c| self.solve(*c)) 151 .collect::<Result<_, _>>()?; 152 let inner = ConstantInner::Composite { ty, components }; 153 154 Ok(self.register_constant(inner, span)) 155 } 156 Expression::Unary { expr, op } => { 157 let expr_constant = self.solve(expr)?; 158 159 self.unary_op(op, expr_constant, span) 160 } 161 Expression::Binary { left, right, op } => { 162 let left_constant = self.solve(left)?; 163 let right_constant = self.solve(right)?; 164 165 self.binary_op(op, left_constant, right_constant, span) 166 } 167 Expression::Math { fun, arg, arg1, .. } => { 168 let arg = self.solve(arg)?; 169 let arg1 = arg1.map(|arg| self.solve(arg)).transpose()?; 170 171 let const0 = &self.constants[arg].inner; 172 let const1 = arg1.map(|arg| &self.constants[arg].inner); 173 174 match fun { 175 crate::MathFunction::Pow => { 176 let (value, width) = match (const0, const1.unwrap()) { 177 ( 178 &ConstantInner::Scalar { 179 width, 180 value: value0, 181 }, 182 &ConstantInner::Scalar { value: value1, .. }, 183 ) => ( 184 match (value0, value1) { 185 (ScalarValue::Sint(a), ScalarValue::Sint(b)) => { 186 ScalarValue::Sint(a.pow(b as u32)) 187 } 188 (ScalarValue::Uint(a), ScalarValue::Uint(b)) => { 189 ScalarValue::Uint(a.pow(b as u32)) 190 } 191 (ScalarValue::Float(a), ScalarValue::Float(b)) => { 192 ScalarValue::Float(a.powf(b)) 193 } 194 _ => return Err(ConstantSolvingError::InvalidMathArg), 195 }, 196 width, 197 ), 198 _ => return Err(ConstantSolvingError::InvalidMathArg), 199 }; 200 201 let inner = ConstantInner::Scalar { width, value }; 202 Ok(self.register_constant(inner, span)) 203 } 204 _ => Err(ConstantSolvingError::NotImplemented(format!("{:?}", fun))), 205 } 206 } 207 Expression::As { 208 convert, 209 expr, 210 kind, 211 } => { 212 let expr_constant = self.solve(expr)?; 213 214 match convert { 215 Some(width) => self.cast(expr_constant, kind, width, span), 216 None => Err(ConstantSolvingError::Bitcast), 217 } 218 } 219 Expression::ArrayLength(expr) => { 220 let array = self.solve(expr)?; 221 222 match self.constants[array].inner { 223 ConstantInner::Scalar { .. } => { 224 Err(ConstantSolvingError::InvalidArrayLengthArg) 225 } 226 ConstantInner::Composite { ty, .. } => match self.types[ty].inner { 227 TypeInner::Array { size, .. } => match size { 228 crate::ArraySize::Constant(constant) => Ok(constant), 229 crate::ArraySize::Dynamic => { 230 Err(ConstantSolvingError::ArrayLengthDynamic) 231 } 232 }, 233 _ => Err(ConstantSolvingError::InvalidArrayLengthArg), 234 }, 235 } 236 } 237 238 Expression::Load { .. } => Err(ConstantSolvingError::Load), 239 Expression::Select { .. } => Err(ConstantSolvingError::Select), 240 Expression::LocalVariable(_) => Err(ConstantSolvingError::LocalVariable), 241 Expression::Derivative { .. } => Err(ConstantSolvingError::Derivative), 242 Expression::Relational { .. } => Err(ConstantSolvingError::Relational), 243 Expression::CallResult { .. } => Err(ConstantSolvingError::Call), 244 Expression::AtomicResult { .. } => Err(ConstantSolvingError::Atomic), 245 Expression::FunctionArgument(_) => Err(ConstantSolvingError::FunctionArg), 246 Expression::GlobalVariable(_) => Err(ConstantSolvingError::GlobalVariable), 247 Expression::ImageSample { .. } 248 | Expression::ImageLoad { .. } 249 | Expression::ImageQuery { .. } => Err(ConstantSolvingError::ImageExpression), 250 } 251 } 252 access( &mut self, base: Handle<Expression>, index: usize, ) -> Result<Handle<Constant>, ConstantSolvingError>253 fn access( 254 &mut self, 255 base: Handle<Expression>, 256 index: usize, 257 ) -> Result<Handle<Constant>, ConstantSolvingError> { 258 let base = self.solve(base)?; 259 260 match self.constants[base].inner { 261 ConstantInner::Scalar { .. } => Err(ConstantSolvingError::InvalidAccessBase), 262 ConstantInner::Composite { ty, ref components } => { 263 match self.types[ty].inner { 264 TypeInner::Vector { .. } 265 | TypeInner::Matrix { .. } 266 | TypeInner::Array { .. } 267 | TypeInner::Struct { .. } => (), 268 _ => return Err(ConstantSolvingError::InvalidAccessBase), 269 } 270 271 components 272 .get(index) 273 .copied() 274 .ok_or(ConstantSolvingError::InvalidAccessIndex) 275 } 276 } 277 } 278 constant_index(&self, constant: Handle<Constant>) -> Result<usize, ConstantSolvingError>279 fn constant_index(&self, constant: Handle<Constant>) -> Result<usize, ConstantSolvingError> { 280 match self.constants[constant].inner { 281 ConstantInner::Scalar { 282 value: ScalarValue::Uint(index), 283 .. 284 } => Ok(index as usize), 285 _ => Err(ConstantSolvingError::InvalidAccessIndexTy), 286 } 287 } 288 cast( &mut self, constant: Handle<Constant>, kind: ScalarKind, target_width: crate::Bytes, span: crate::Span, ) -> Result<Handle<Constant>, ConstantSolvingError>289 fn cast( 290 &mut self, 291 constant: Handle<Constant>, 292 kind: ScalarKind, 293 target_width: crate::Bytes, 294 span: crate::Span, 295 ) -> Result<Handle<Constant>, ConstantSolvingError> { 296 let mut inner = self.constants[constant].inner.clone(); 297 298 match inner { 299 ConstantInner::Scalar { 300 ref mut value, 301 ref mut width, 302 } => { 303 *width = target_width; 304 *value = match kind { 305 ScalarKind::Sint => ScalarValue::Sint(match *value { 306 ScalarValue::Sint(v) => v, 307 ScalarValue::Uint(v) => v as i64, 308 ScalarValue::Float(v) => v as i64, 309 ScalarValue::Bool(v) => v as i64, 310 }), 311 ScalarKind::Uint => ScalarValue::Uint(match *value { 312 ScalarValue::Sint(v) => v as u64, 313 ScalarValue::Uint(v) => v, 314 ScalarValue::Float(v) => v as u64, 315 ScalarValue::Bool(v) => v as u64, 316 }), 317 ScalarKind::Float => ScalarValue::Float(match *value { 318 ScalarValue::Sint(v) => v as f64, 319 ScalarValue::Uint(v) => v as f64, 320 ScalarValue::Float(v) => v, 321 ScalarValue::Bool(v) => v as u64 as f64, 322 }), 323 ScalarKind::Bool => ScalarValue::Bool(match *value { 324 ScalarValue::Sint(v) => v != 0, 325 ScalarValue::Uint(v) => v != 0, 326 ScalarValue::Float(v) => v != 0.0, 327 ScalarValue::Bool(v) => v, 328 }), 329 } 330 } 331 ConstantInner::Composite { 332 ty, 333 ref mut components, 334 } => { 335 match self.types[ty].inner { 336 TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (), 337 _ => return Err(ConstantSolvingError::InvalidCastArg), 338 } 339 340 for component in components { 341 *component = self.cast(*component, kind, target_width, span)?; 342 } 343 } 344 } 345 346 Ok(self.register_constant(inner, span)) 347 } 348 unary_op( &mut self, op: UnaryOperator, constant: Handle<Constant>, span: crate::Span, ) -> Result<Handle<Constant>, ConstantSolvingError>349 fn unary_op( 350 &mut self, 351 op: UnaryOperator, 352 constant: Handle<Constant>, 353 span: crate::Span, 354 ) -> Result<Handle<Constant>, ConstantSolvingError> { 355 let mut inner = self.constants[constant].inner.clone(); 356 357 match inner { 358 ConstantInner::Scalar { ref mut value, .. } => match op { 359 UnaryOperator::Negate => match *value { 360 ScalarValue::Sint(ref mut v) => *v = -*v, 361 ScalarValue::Float(ref mut v) => *v = -*v, 362 _ => return Err(ConstantSolvingError::InvalidUnaryOpArg), 363 }, 364 UnaryOperator::Not => match *value { 365 ScalarValue::Sint(ref mut v) => *v = !*v, 366 ScalarValue::Uint(ref mut v) => *v = !*v, 367 ScalarValue::Bool(ref mut v) => *v = !*v, 368 _ => return Err(ConstantSolvingError::InvalidUnaryOpArg), 369 }, 370 }, 371 ConstantInner::Composite { 372 ty, 373 ref mut components, 374 } => { 375 match self.types[ty].inner { 376 TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (), 377 _ => return Err(ConstantSolvingError::InvalidCastArg), 378 } 379 380 for component in components { 381 *component = self.unary_op(op, *component, span)? 382 } 383 } 384 } 385 386 Ok(self.register_constant(inner, span)) 387 } 388 binary_op( &mut self, op: BinaryOperator, left: Handle<Constant>, right: Handle<Constant>, span: crate::Span, ) -> Result<Handle<Constant>, ConstantSolvingError>389 fn binary_op( 390 &mut self, 391 op: BinaryOperator, 392 left: Handle<Constant>, 393 right: Handle<Constant>, 394 span: crate::Span, 395 ) -> Result<Handle<Constant>, ConstantSolvingError> { 396 let left_inner = &self.constants[left].inner; 397 let right_inner = &self.constants[right].inner; 398 399 let inner = match (left_inner, right_inner) { 400 ( 401 &ConstantInner::Scalar { 402 value: left_value, 403 width, 404 }, 405 &ConstantInner::Scalar { 406 value: right_value, 407 width: _, 408 }, 409 ) => { 410 let value = match op { 411 BinaryOperator::Equal => ScalarValue::Bool(left_value == right_value), 412 BinaryOperator::NotEqual => ScalarValue::Bool(left_value != right_value), 413 BinaryOperator::Less => ScalarValue::Bool(left_value < right_value), 414 BinaryOperator::LessEqual => ScalarValue::Bool(left_value <= right_value), 415 BinaryOperator::Greater => ScalarValue::Bool(left_value > right_value), 416 BinaryOperator::GreaterEqual => ScalarValue::Bool(left_value >= right_value), 417 418 _ => match (left_value, right_value) { 419 (ScalarValue::Sint(a), ScalarValue::Sint(b)) => { 420 ScalarValue::Sint(match op { 421 BinaryOperator::Add => a.wrapping_add(b), 422 BinaryOperator::Subtract => a.wrapping_sub(b), 423 BinaryOperator::Multiply => a.wrapping_mul(b), 424 BinaryOperator::Divide => a.checked_div(b).unwrap_or(0), 425 BinaryOperator::Modulo => a.checked_rem(b).unwrap_or(0), 426 BinaryOperator::And => a & b, 427 BinaryOperator::ExclusiveOr => a ^ b, 428 BinaryOperator::InclusiveOr => a | b, 429 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs), 430 }) 431 } 432 (ScalarValue::Sint(a), ScalarValue::Uint(b)) => { 433 ScalarValue::Sint(match op { 434 BinaryOperator::ShiftLeft => a.wrapping_shl(b as u32), 435 BinaryOperator::ShiftRight => a.wrapping_shr(b as u32), 436 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs), 437 }) 438 } 439 (ScalarValue::Uint(a), ScalarValue::Uint(b)) => { 440 ScalarValue::Uint(match op { 441 BinaryOperator::Add => a.wrapping_add(b), 442 BinaryOperator::Subtract => a.wrapping_sub(b), 443 BinaryOperator::Multiply => a.wrapping_mul(b), 444 BinaryOperator::Divide => a.checked_div(b).unwrap_or(0), 445 BinaryOperator::Modulo => a.checked_rem(b).unwrap_or(0), 446 BinaryOperator::And => a & b, 447 BinaryOperator::ExclusiveOr => a ^ b, 448 BinaryOperator::InclusiveOr => a | b, 449 BinaryOperator::ShiftLeft => a.wrapping_shl(b as u32), 450 BinaryOperator::ShiftRight => a.wrapping_shr(b as u32), 451 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs), 452 }) 453 } 454 (ScalarValue::Float(a), ScalarValue::Float(b)) => { 455 ScalarValue::Float(match op { 456 BinaryOperator::Add => a + b, 457 BinaryOperator::Subtract => a - b, 458 BinaryOperator::Multiply => a * b, 459 BinaryOperator::Divide => a / b, 460 BinaryOperator::Modulo => a % b, 461 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs), 462 }) 463 } 464 (ScalarValue::Bool(a), ScalarValue::Bool(b)) => { 465 ScalarValue::Bool(match op { 466 BinaryOperator::LogicalAnd => a && b, 467 BinaryOperator::LogicalOr => a || b, 468 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs), 469 }) 470 } 471 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs), 472 }, 473 }; 474 475 ConstantInner::Scalar { value, width } 476 } 477 (&ConstantInner::Composite { ref components, ty }, &ConstantInner::Scalar { .. }) => { 478 let mut components = components.clone(); 479 for comp in components.iter_mut() { 480 *comp = self.binary_op(op, *comp, right, span)?; 481 } 482 ConstantInner::Composite { ty, components } 483 } 484 (&ConstantInner::Scalar { .. }, &ConstantInner::Composite { ref components, ty }) => { 485 let mut components = components.clone(); 486 for comp in components.iter_mut() { 487 *comp = self.binary_op(op, left, *comp, span)?; 488 } 489 ConstantInner::Composite { ty, components } 490 } 491 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs), 492 }; 493 494 Ok(self.register_constant(inner, span)) 495 } 496 register_constant(&mut self, inner: ConstantInner, span: crate::Span) -> Handle<Constant>497 fn register_constant(&mut self, inner: ConstantInner, span: crate::Span) -> Handle<Constant> { 498 self.constants.fetch_or_append( 499 Constant { 500 name: None, 501 specialization: None, 502 inner, 503 }, 504 span, 505 ) 506 } 507 } 508 509 #[cfg(test)] 510 mod tests { 511 use std::vec; 512 513 use crate::{ 514 Arena, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, TypeInner, 515 UnaryOperator, UniqueArena, VectorSize, 516 }; 517 518 use super::ConstantSolver; 519 520 #[test] unary_op()521 fn unary_op() { 522 let mut types = UniqueArena::new(); 523 let mut expressions = Arena::new(); 524 let mut constants = Arena::new(); 525 526 let vec_ty = types.insert( 527 Type { 528 name: None, 529 inner: TypeInner::Vector { 530 size: VectorSize::Bi, 531 kind: ScalarKind::Sint, 532 width: 4, 533 }, 534 }, 535 Default::default(), 536 ); 537 538 let h = constants.append( 539 Constant { 540 name: None, 541 specialization: None, 542 inner: ConstantInner::Scalar { 543 width: 4, 544 value: ScalarValue::Sint(4), 545 }, 546 }, 547 Default::default(), 548 ); 549 550 let h1 = constants.append( 551 Constant { 552 name: None, 553 specialization: None, 554 inner: ConstantInner::Scalar { 555 width: 4, 556 value: ScalarValue::Sint(8), 557 }, 558 }, 559 Default::default(), 560 ); 561 562 let vec_h = constants.append( 563 Constant { 564 name: None, 565 specialization: None, 566 inner: ConstantInner::Composite { 567 ty: vec_ty, 568 components: vec![h, h1], 569 }, 570 }, 571 Default::default(), 572 ); 573 574 let expr = expressions.append(Expression::Constant(h), Default::default()); 575 let expr1 = expressions.append(Expression::Constant(vec_h), Default::default()); 576 577 let root1 = expressions.append( 578 Expression::Unary { 579 op: UnaryOperator::Negate, 580 expr, 581 }, 582 Default::default(), 583 ); 584 585 let root2 = expressions.append( 586 Expression::Unary { 587 op: UnaryOperator::Not, 588 expr, 589 }, 590 Default::default(), 591 ); 592 593 let root3 = expressions.append( 594 Expression::Unary { 595 op: UnaryOperator::Not, 596 expr: expr1, 597 }, 598 Default::default(), 599 ); 600 601 let mut solver = ConstantSolver { 602 types: &mut types, 603 expressions: &expressions, 604 constants: &mut constants, 605 }; 606 607 let res1 = solver.solve(root1).unwrap(); 608 let res2 = solver.solve(root2).unwrap(); 609 let res3 = solver.solve(root3).unwrap(); 610 611 assert_eq!( 612 constants[res1].inner, 613 ConstantInner::Scalar { 614 width: 4, 615 value: ScalarValue::Sint(-4), 616 }, 617 ); 618 619 assert_eq!( 620 constants[res2].inner, 621 ConstantInner::Scalar { 622 width: 4, 623 value: ScalarValue::Sint(!4), 624 }, 625 ); 626 627 let res3_inner = &constants[res3].inner; 628 629 match res3_inner { 630 ConstantInner::Composite { ty, components } => { 631 assert_eq!(*ty, vec_ty); 632 let mut components_iter = components.iter().copied(); 633 assert_eq!( 634 constants[components_iter.next().unwrap()].inner, 635 ConstantInner::Scalar { 636 width: 4, 637 value: ScalarValue::Sint(!4), 638 }, 639 ); 640 assert_eq!( 641 constants[components_iter.next().unwrap()].inner, 642 ConstantInner::Scalar { 643 width: 4, 644 value: ScalarValue::Sint(!8), 645 }, 646 ); 647 assert!(components_iter.next().is_none()); 648 } 649 _ => panic!("Expected vector"), 650 } 651 } 652 653 #[test] cast()654 fn cast() { 655 let mut expressions = Arena::new(); 656 let mut constants = Arena::new(); 657 658 let h = constants.append( 659 Constant { 660 name: None, 661 specialization: None, 662 inner: ConstantInner::Scalar { 663 width: 4, 664 value: ScalarValue::Sint(4), 665 }, 666 }, 667 Default::default(), 668 ); 669 670 let expr = expressions.append(Expression::Constant(h), Default::default()); 671 672 let root = expressions.append( 673 Expression::As { 674 expr, 675 kind: ScalarKind::Bool, 676 convert: Some(crate::BOOL_WIDTH), 677 }, 678 Default::default(), 679 ); 680 681 let mut solver = ConstantSolver { 682 types: &mut UniqueArena::new(), 683 expressions: &expressions, 684 constants: &mut constants, 685 }; 686 687 let res = solver.solve(root).unwrap(); 688 689 assert_eq!( 690 constants[res].inner, 691 ConstantInner::Scalar { 692 width: crate::BOOL_WIDTH, 693 value: ScalarValue::Bool(true), 694 }, 695 ); 696 } 697 698 #[test] access()699 fn access() { 700 let mut types = UniqueArena::new(); 701 let mut expressions = Arena::new(); 702 let mut constants = Arena::new(); 703 704 let matrix_ty = types.insert( 705 Type { 706 name: None, 707 inner: TypeInner::Matrix { 708 columns: VectorSize::Bi, 709 rows: VectorSize::Tri, 710 width: 4, 711 }, 712 }, 713 Default::default(), 714 ); 715 716 let vec_ty = types.insert( 717 Type { 718 name: None, 719 inner: TypeInner::Vector { 720 size: VectorSize::Tri, 721 kind: ScalarKind::Float, 722 width: 4, 723 }, 724 }, 725 Default::default(), 726 ); 727 728 let mut vec1_components = Vec::with_capacity(3); 729 let mut vec2_components = Vec::with_capacity(3); 730 731 for i in 0..3 { 732 let h = constants.append( 733 Constant { 734 name: None, 735 specialization: None, 736 inner: ConstantInner::Scalar { 737 width: 4, 738 value: ScalarValue::Float(i as f64), 739 }, 740 }, 741 Default::default(), 742 ); 743 744 vec1_components.push(h) 745 } 746 747 for i in 3..6 { 748 let h = constants.append( 749 Constant { 750 name: None, 751 specialization: None, 752 inner: ConstantInner::Scalar { 753 width: 4, 754 value: ScalarValue::Float(i as f64), 755 }, 756 }, 757 Default::default(), 758 ); 759 760 vec2_components.push(h) 761 } 762 763 let vec1 = constants.append( 764 Constant { 765 name: None, 766 specialization: None, 767 inner: ConstantInner::Composite { 768 ty: vec_ty, 769 components: vec1_components, 770 }, 771 }, 772 Default::default(), 773 ); 774 775 let vec2 = constants.append( 776 Constant { 777 name: None, 778 specialization: None, 779 inner: ConstantInner::Composite { 780 ty: vec_ty, 781 components: vec2_components, 782 }, 783 }, 784 Default::default(), 785 ); 786 787 let h = constants.append( 788 Constant { 789 name: None, 790 specialization: None, 791 inner: ConstantInner::Composite { 792 ty: matrix_ty, 793 components: vec![vec1, vec2], 794 }, 795 }, 796 Default::default(), 797 ); 798 799 let base = expressions.append(Expression::Constant(h), Default::default()); 800 let root1 = expressions.append( 801 Expression::AccessIndex { base, index: 1 }, 802 Default::default(), 803 ); 804 let root2 = expressions.append( 805 Expression::AccessIndex { 806 base: root1, 807 index: 2, 808 }, 809 Default::default(), 810 ); 811 812 let mut solver = ConstantSolver { 813 types: &mut types, 814 expressions: &expressions, 815 constants: &mut constants, 816 }; 817 818 let res1 = solver.solve(root1).unwrap(); 819 let res2 = solver.solve(root2).unwrap(); 820 821 let res1_inner = &constants[res1].inner; 822 823 match res1_inner { 824 ConstantInner::Composite { ty, components } => { 825 assert_eq!(*ty, vec_ty); 826 let mut components_iter = components.iter().copied(); 827 assert_eq!( 828 constants[components_iter.next().unwrap()].inner, 829 ConstantInner::Scalar { 830 width: 4, 831 value: ScalarValue::Float(3.), 832 }, 833 ); 834 assert_eq!( 835 constants[components_iter.next().unwrap()].inner, 836 ConstantInner::Scalar { 837 width: 4, 838 value: ScalarValue::Float(4.), 839 }, 840 ); 841 assert_eq!( 842 constants[components_iter.next().unwrap()].inner, 843 ConstantInner::Scalar { 844 width: 4, 845 value: ScalarValue::Float(5.), 846 }, 847 ); 848 assert!(components_iter.next().is_none()); 849 } 850 _ => panic!("Expected vector"), 851 } 852 853 assert_eq!( 854 constants[res2].inner, 855 ConstantInner::Scalar { 856 width: 4, 857 value: ScalarValue::Float(5.), 858 }, 859 ); 860 } 861 } 862