1 use crate::{
2     arena::{Arena, Handle, UniqueArena},
3     BinaryOperator, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, TypeInner,
4     UnaryOperator,
5 };
6 
/// Evaluates expressions down to constants.
///
/// [`solve`](Self::solve) reads the expression tree from `expressions` and registers every
/// constant (and every vector type) it has to create in `constants` and `types`.
#[derive(Debug)]
pub struct ConstantSolver<'a> {
    /// Type arena; new vector types are inserted here, e.g. when solving splats and swizzles.
    pub types: &'a mut UniqueArena<Type>,
    /// Arena holding the expressions to evaluate; it is only read.
    pub expressions: &'a Arena<Expression>,
    /// Constant arena; operands are looked up here and newly computed constants are added.
    pub constants: &'a mut Arena<Constant>,
}
13 
/// Reasons why [`ConstantSolver`] could not evaluate an expression to a constant.
#[derive(Clone, Debug, PartialEq, thiserror::Error)]
pub enum ConstantSolvingError {
    /// The expression refers to a function argument.
    #[error("Constants cannot access function arguments")]
    FunctionArg,
    /// The expression refers to a global variable.
    #[error("Constants cannot access global variables")]
    GlobalVariable,
    /// The expression refers to a local variable.
    #[error("Constants cannot access local variables")]
    LocalVariable,
    /// The array-length operand is a scalar or a composite whose type is not an array.
    #[error("Cannot get the array length of a non array type")]
    InvalidArrayLengthArg,
    /// The array-length operand is an array whose size is dynamic.
    #[error("Constants cannot get the array length of a dynamically sized array")]
    ArrayLengthDynamic,
    /// The expression is the result of a function call.
    #[error("Constants cannot call functions")]
    Call,
    /// The expression is the result of an atomic operation.
    #[error("Constants don't support atomic functions")]
    Atomic,
    /// The expression is a relational function.
    #[error("Constants don't support relational functions")]
    Relational,
    /// The expression is a derivative.
    #[error("Constants don't support derivative functions")]
    Derivative,
    /// The expression is a select.
    #[error("Constants don't support select expressions")]
    Select,
    /// The expression is a load.
    #[error("Constants don't support load expressions")]
    Load,
    /// The expression is an image sample, load or query.
    #[error("Constants don't support image expressions")]
    ImageExpression,
    /// The accessed value is a scalar, or a composite that is not a vector, matrix, array
    /// or struct.
    #[error("Cannot access the type")]
    InvalidAccessBase,
    /// The index does not refer to an existing component of the accessed composite.
    #[error("Cannot access at the index")]
    InvalidAccessIndex,
    /// The constant used as a dynamic index is not of a kind usable as an index.
    #[error("Cannot access with index of type")]
    InvalidAccessIndexTy,
    /// The expression is an `As` without a conversion width, i.e. a bitcast.
    #[error("Constants don't support bitcasts")]
    Bitcast,
    /// Returned when a composite that is neither a vector nor a matrix is cast.
    #[error("Cannot cast type")]
    InvalidCastArg,
    /// The unary operator is not defined for the kind of its operand.
    #[error("Cannot apply the unary op to the argument")]
    InvalidUnaryOpArg,
    /// The binary operator is not defined for the combination of its operands.
    #[error("Cannot apply the binary op to the arguments")]
    InvalidBinaryOpArgs,
    /// The math function received arguments it cannot be evaluated on.
    #[error("Cannot apply math function to type")]
    InvalidMathArg,
    /// The value of a splat evaluated to a composite instead of a scalar.
    #[error("Splat is defined only on scalar values")]
    SplatScalarOnly,
    /// The source of a swizzle did not evaluate to a vector.
    #[error("Can only swizzle vector constants")]
    SwizzleVectorOnly,
    /// The operation is not supported yet; the payload is the `Debug` rendering of the
    /// unsupported item (e.g. the math function).
    #[error("Not implemented: {0}")]
    NotImplemented(String),
}
63 
64 impl<'a> ConstantSolver<'a> {
solve( &mut self, expr: Handle<Expression>, ) -> Result<Handle<Constant>, ConstantSolvingError>65     pub fn solve(
66         &mut self,
67         expr: Handle<Expression>,
68     ) -> Result<Handle<Constant>, ConstantSolvingError> {
69         let span = self.expressions.get_span(expr);
70         match self.expressions[expr] {
71             Expression::Constant(constant) => Ok(constant),
72             Expression::AccessIndex { base, index } => self.access(base, index as usize),
73             Expression::Access { base, index } => {
74                 let index = self.solve(index)?;
75 
76                 self.access(base, self.constant_index(index)?)
77             }
78             Expression::Splat {
79                 size,
80                 value: splat_value,
81             } => {
82                 let value_constant = self.solve(splat_value)?;
83                 let ty = match self.constants[value_constant].inner {
84                     ConstantInner::Scalar { ref value, width } => {
85                         let kind = value.scalar_kind();
86                         self.types.insert(
87                             Type {
88                                 name: None,
89                                 inner: TypeInner::Vector { size, kind, width },
90                             },
91                             span,
92                         )
93                     }
94                     ConstantInner::Composite { .. } => {
95                         return Err(ConstantSolvingError::SplatScalarOnly);
96                     }
97                 };
98 
99                 let inner = ConstantInner::Composite {
100                     ty,
101                     components: vec![value_constant; size as usize],
102                 };
103                 Ok(self.register_constant(inner, span))
104             }
105             Expression::Swizzle {
106                 size,
107                 vector: src_vector,
108                 pattern,
109             } => {
110                 let src_constant = self.solve(src_vector)?;
111                 let (ty, src_components) = match self.constants[src_constant].inner {
112                     ConstantInner::Scalar { .. } => {
113                         return Err(ConstantSolvingError::SwizzleVectorOnly);
114                     }
115                     ConstantInner::Composite {
116                         ty,
117                         components: ref src_components,
118                     } => match self.types[ty].inner {
119                         crate::TypeInner::Vector {
120                             size: _,
121                             kind,
122                             width,
123                         } => {
124                             let dst_ty = self.types.insert(
125                                 Type {
126                                     name: None,
127                                     inner: crate::TypeInner::Vector { size, kind, width },
128                                 },
129                                 span,
130                             );
131                             (dst_ty, &src_components[..])
132                         }
133                         _ => {
134                             return Err(ConstantSolvingError::SwizzleVectorOnly);
135                         }
136                     },
137                 };
138 
139                 let components = pattern
140                     .iter()
141                     .map(|&sc| src_components[sc as usize])
142                     .collect();
143                 let inner = ConstantInner::Composite { ty, components };
144 
145                 Ok(self.register_constant(inner, span))
146             }
147             Expression::Compose { ty, ref components } => {
148                 let components = components
149                     .iter()
150                     .map(|c| self.solve(*c))
151                     .collect::<Result<_, _>>()?;
152                 let inner = ConstantInner::Composite { ty, components };
153 
154                 Ok(self.register_constant(inner, span))
155             }
156             Expression::Unary { expr, op } => {
157                 let expr_constant = self.solve(expr)?;
158 
159                 self.unary_op(op, expr_constant, span)
160             }
161             Expression::Binary { left, right, op } => {
162                 let left_constant = self.solve(left)?;
163                 let right_constant = self.solve(right)?;
164 
165                 self.binary_op(op, left_constant, right_constant, span)
166             }
167             Expression::Math { fun, arg, arg1, .. } => {
168                 let arg = self.solve(arg)?;
169                 let arg1 = arg1.map(|arg| self.solve(arg)).transpose()?;
170 
171                 let const0 = &self.constants[arg].inner;
172                 let const1 = arg1.map(|arg| &self.constants[arg].inner);
173 
174                 match fun {
175                     crate::MathFunction::Pow => {
176                         let (value, width) = match (const0, const1.unwrap()) {
177                             (
178                                 &ConstantInner::Scalar {
179                                     width,
180                                     value: value0,
181                                 },
182                                 &ConstantInner::Scalar { value: value1, .. },
183                             ) => (
184                                 match (value0, value1) {
185                                     (ScalarValue::Sint(a), ScalarValue::Sint(b)) => {
186                                         ScalarValue::Sint(a.pow(b as u32))
187                                     }
188                                     (ScalarValue::Uint(a), ScalarValue::Uint(b)) => {
189                                         ScalarValue::Uint(a.pow(b as u32))
190                                     }
191                                     (ScalarValue::Float(a), ScalarValue::Float(b)) => {
192                                         ScalarValue::Float(a.powf(b))
193                                     }
194                                     _ => return Err(ConstantSolvingError::InvalidMathArg),
195                                 },
196                                 width,
197                             ),
198                             _ => return Err(ConstantSolvingError::InvalidMathArg),
199                         };
200 
201                         let inner = ConstantInner::Scalar { width, value };
202                         Ok(self.register_constant(inner, span))
203                     }
204                     _ => Err(ConstantSolvingError::NotImplemented(format!("{:?}", fun))),
205                 }
206             }
207             Expression::As {
208                 convert,
209                 expr,
210                 kind,
211             } => {
212                 let expr_constant = self.solve(expr)?;
213 
214                 match convert {
215                     Some(width) => self.cast(expr_constant, kind, width, span),
216                     None => Err(ConstantSolvingError::Bitcast),
217                 }
218             }
219             Expression::ArrayLength(expr) => {
220                 let array = self.solve(expr)?;
221 
222                 match self.constants[array].inner {
223                     ConstantInner::Scalar { .. } => {
224                         Err(ConstantSolvingError::InvalidArrayLengthArg)
225                     }
226                     ConstantInner::Composite { ty, .. } => match self.types[ty].inner {
227                         TypeInner::Array { size, .. } => match size {
228                             crate::ArraySize::Constant(constant) => Ok(constant),
229                             crate::ArraySize::Dynamic => {
230                                 Err(ConstantSolvingError::ArrayLengthDynamic)
231                             }
232                         },
233                         _ => Err(ConstantSolvingError::InvalidArrayLengthArg),
234                     },
235                 }
236             }
237 
238             Expression::Load { .. } => Err(ConstantSolvingError::Load),
239             Expression::Select { .. } => Err(ConstantSolvingError::Select),
240             Expression::LocalVariable(_) => Err(ConstantSolvingError::LocalVariable),
241             Expression::Derivative { .. } => Err(ConstantSolvingError::Derivative),
242             Expression::Relational { .. } => Err(ConstantSolvingError::Relational),
243             Expression::CallResult { .. } => Err(ConstantSolvingError::Call),
244             Expression::AtomicResult { .. } => Err(ConstantSolvingError::Atomic),
245             Expression::FunctionArgument(_) => Err(ConstantSolvingError::FunctionArg),
246             Expression::GlobalVariable(_) => Err(ConstantSolvingError::GlobalVariable),
247             Expression::ImageSample { .. }
248             | Expression::ImageLoad { .. }
249             | Expression::ImageQuery { .. } => Err(ConstantSolvingError::ImageExpression),
250         }
251     }
252 
access( &mut self, base: Handle<Expression>, index: usize, ) -> Result<Handle<Constant>, ConstantSolvingError>253     fn access(
254         &mut self,
255         base: Handle<Expression>,
256         index: usize,
257     ) -> Result<Handle<Constant>, ConstantSolvingError> {
258         let base = self.solve(base)?;
259 
260         match self.constants[base].inner {
261             ConstantInner::Scalar { .. } => Err(ConstantSolvingError::InvalidAccessBase),
262             ConstantInner::Composite { ty, ref components } => {
263                 match self.types[ty].inner {
264                     TypeInner::Vector { .. }
265                     | TypeInner::Matrix { .. }
266                     | TypeInner::Array { .. }
267                     | TypeInner::Struct { .. } => (),
268                     _ => return Err(ConstantSolvingError::InvalidAccessBase),
269                 }
270 
271                 components
272                     .get(index)
273                     .copied()
274                     .ok_or(ConstantSolvingError::InvalidAccessIndex)
275             }
276         }
277     }
278 
constant_index(&self, constant: Handle<Constant>) -> Result<usize, ConstantSolvingError>279     fn constant_index(&self, constant: Handle<Constant>) -> Result<usize, ConstantSolvingError> {
280         match self.constants[constant].inner {
281             ConstantInner::Scalar {
282                 value: ScalarValue::Uint(index),
283                 ..
284             } => Ok(index as usize),
285             _ => Err(ConstantSolvingError::InvalidAccessIndexTy),
286         }
287     }
288 
cast( &mut self, constant: Handle<Constant>, kind: ScalarKind, target_width: crate::Bytes, span: crate::Span, ) -> Result<Handle<Constant>, ConstantSolvingError>289     fn cast(
290         &mut self,
291         constant: Handle<Constant>,
292         kind: ScalarKind,
293         target_width: crate::Bytes,
294         span: crate::Span,
295     ) -> Result<Handle<Constant>, ConstantSolvingError> {
296         let mut inner = self.constants[constant].inner.clone();
297 
298         match inner {
299             ConstantInner::Scalar {
300                 ref mut value,
301                 ref mut width,
302             } => {
303                 *width = target_width;
304                 *value = match kind {
305                     ScalarKind::Sint => ScalarValue::Sint(match *value {
306                         ScalarValue::Sint(v) => v,
307                         ScalarValue::Uint(v) => v as i64,
308                         ScalarValue::Float(v) => v as i64,
309                         ScalarValue::Bool(v) => v as i64,
310                     }),
311                     ScalarKind::Uint => ScalarValue::Uint(match *value {
312                         ScalarValue::Sint(v) => v as u64,
313                         ScalarValue::Uint(v) => v,
314                         ScalarValue::Float(v) => v as u64,
315                         ScalarValue::Bool(v) => v as u64,
316                     }),
317                     ScalarKind::Float => ScalarValue::Float(match *value {
318                         ScalarValue::Sint(v) => v as f64,
319                         ScalarValue::Uint(v) => v as f64,
320                         ScalarValue::Float(v) => v,
321                         ScalarValue::Bool(v) => v as u64 as f64,
322                     }),
323                     ScalarKind::Bool => ScalarValue::Bool(match *value {
324                         ScalarValue::Sint(v) => v != 0,
325                         ScalarValue::Uint(v) => v != 0,
326                         ScalarValue::Float(v) => v != 0.0,
327                         ScalarValue::Bool(v) => v,
328                     }),
329                 }
330             }
331             ConstantInner::Composite {
332                 ty,
333                 ref mut components,
334             } => {
335                 match self.types[ty].inner {
336                     TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
337                     _ => return Err(ConstantSolvingError::InvalidCastArg),
338                 }
339 
340                 for component in components {
341                     *component = self.cast(*component, kind, target_width, span)?;
342                 }
343             }
344         }
345 
346         Ok(self.register_constant(inner, span))
347     }
348 
unary_op( &mut self, op: UnaryOperator, constant: Handle<Constant>, span: crate::Span, ) -> Result<Handle<Constant>, ConstantSolvingError>349     fn unary_op(
350         &mut self,
351         op: UnaryOperator,
352         constant: Handle<Constant>,
353         span: crate::Span,
354     ) -> Result<Handle<Constant>, ConstantSolvingError> {
355         let mut inner = self.constants[constant].inner.clone();
356 
357         match inner {
358             ConstantInner::Scalar { ref mut value, .. } => match op {
359                 UnaryOperator::Negate => match *value {
360                     ScalarValue::Sint(ref mut v) => *v = -*v,
361                     ScalarValue::Float(ref mut v) => *v = -*v,
362                     _ => return Err(ConstantSolvingError::InvalidUnaryOpArg),
363                 },
364                 UnaryOperator::Not => match *value {
365                     ScalarValue::Sint(ref mut v) => *v = !*v,
366                     ScalarValue::Uint(ref mut v) => *v = !*v,
367                     ScalarValue::Bool(ref mut v) => *v = !*v,
368                     _ => return Err(ConstantSolvingError::InvalidUnaryOpArg),
369                 },
370             },
371             ConstantInner::Composite {
372                 ty,
373                 ref mut components,
374             } => {
375                 match self.types[ty].inner {
376                     TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
377                     _ => return Err(ConstantSolvingError::InvalidCastArg),
378                 }
379 
380                 for component in components {
381                     *component = self.unary_op(op, *component, span)?
382                 }
383             }
384         }
385 
386         Ok(self.register_constant(inner, span))
387     }
388 
binary_op( &mut self, op: BinaryOperator, left: Handle<Constant>, right: Handle<Constant>, span: crate::Span, ) -> Result<Handle<Constant>, ConstantSolvingError>389     fn binary_op(
390         &mut self,
391         op: BinaryOperator,
392         left: Handle<Constant>,
393         right: Handle<Constant>,
394         span: crate::Span,
395     ) -> Result<Handle<Constant>, ConstantSolvingError> {
396         let left_inner = &self.constants[left].inner;
397         let right_inner = &self.constants[right].inner;
398 
399         let inner = match (left_inner, right_inner) {
400             (
401                 &ConstantInner::Scalar {
402                     value: left_value,
403                     width,
404                 },
405                 &ConstantInner::Scalar {
406                     value: right_value,
407                     width: _,
408                 },
409             ) => {
410                 let value = match op {
411                     BinaryOperator::Equal => ScalarValue::Bool(left_value == right_value),
412                     BinaryOperator::NotEqual => ScalarValue::Bool(left_value != right_value),
413                     BinaryOperator::Less => ScalarValue::Bool(left_value < right_value),
414                     BinaryOperator::LessEqual => ScalarValue::Bool(left_value <= right_value),
415                     BinaryOperator::Greater => ScalarValue::Bool(left_value > right_value),
416                     BinaryOperator::GreaterEqual => ScalarValue::Bool(left_value >= right_value),
417 
418                     _ => match (left_value, right_value) {
419                         (ScalarValue::Sint(a), ScalarValue::Sint(b)) => {
420                             ScalarValue::Sint(match op {
421                                 BinaryOperator::Add => a.wrapping_add(b),
422                                 BinaryOperator::Subtract => a.wrapping_sub(b),
423                                 BinaryOperator::Multiply => a.wrapping_mul(b),
424                                 BinaryOperator::Divide => a.checked_div(b).unwrap_or(0),
425                                 BinaryOperator::Modulo => a.checked_rem(b).unwrap_or(0),
426                                 BinaryOperator::And => a & b,
427                                 BinaryOperator::ExclusiveOr => a ^ b,
428                                 BinaryOperator::InclusiveOr => a | b,
429                                 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
430                             })
431                         }
432                         (ScalarValue::Sint(a), ScalarValue::Uint(b)) => {
433                             ScalarValue::Sint(match op {
434                                 BinaryOperator::ShiftLeft => a.wrapping_shl(b as u32),
435                                 BinaryOperator::ShiftRight => a.wrapping_shr(b as u32),
436                                 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
437                             })
438                         }
439                         (ScalarValue::Uint(a), ScalarValue::Uint(b)) => {
440                             ScalarValue::Uint(match op {
441                                 BinaryOperator::Add => a.wrapping_add(b),
442                                 BinaryOperator::Subtract => a.wrapping_sub(b),
443                                 BinaryOperator::Multiply => a.wrapping_mul(b),
444                                 BinaryOperator::Divide => a.checked_div(b).unwrap_or(0),
445                                 BinaryOperator::Modulo => a.checked_rem(b).unwrap_or(0),
446                                 BinaryOperator::And => a & b,
447                                 BinaryOperator::ExclusiveOr => a ^ b,
448                                 BinaryOperator::InclusiveOr => a | b,
449                                 BinaryOperator::ShiftLeft => a.wrapping_shl(b as u32),
450                                 BinaryOperator::ShiftRight => a.wrapping_shr(b as u32),
451                                 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
452                             })
453                         }
454                         (ScalarValue::Float(a), ScalarValue::Float(b)) => {
455                             ScalarValue::Float(match op {
456                                 BinaryOperator::Add => a + b,
457                                 BinaryOperator::Subtract => a - b,
458                                 BinaryOperator::Multiply => a * b,
459                                 BinaryOperator::Divide => a / b,
460                                 BinaryOperator::Modulo => a % b,
461                                 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
462                             })
463                         }
464                         (ScalarValue::Bool(a), ScalarValue::Bool(b)) => {
465                             ScalarValue::Bool(match op {
466                                 BinaryOperator::LogicalAnd => a && b,
467                                 BinaryOperator::LogicalOr => a || b,
468                                 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
469                             })
470                         }
471                         _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
472                     },
473                 };
474 
475                 ConstantInner::Scalar { value, width }
476             }
477             (&ConstantInner::Composite { ref components, ty }, &ConstantInner::Scalar { .. }) => {
478                 let mut components = components.clone();
479                 for comp in components.iter_mut() {
480                     *comp = self.binary_op(op, *comp, right, span)?;
481                 }
482                 ConstantInner::Composite { ty, components }
483             }
484             (&ConstantInner::Scalar { .. }, &ConstantInner::Composite { ref components, ty }) => {
485                 let mut components = components.clone();
486                 for comp in components.iter_mut() {
487                     *comp = self.binary_op(op, left, *comp, span)?;
488                 }
489                 ConstantInner::Composite { ty, components }
490             }
491             _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
492         };
493 
494         Ok(self.register_constant(inner, span))
495     }
496 
register_constant(&mut self, inner: ConstantInner, span: crate::Span) -> Handle<Constant>497     fn register_constant(&mut self, inner: ConstantInner, span: crate::Span) -> Handle<Constant> {
498         self.constants.fetch_or_append(
499             Constant {
500                 name: None,
501                 specialization: None,
502                 inner,
503             },
504             span,
505         )
506     }
507 }
508 
509 #[cfg(test)]
510 mod tests {
511     use std::vec;
512 
513     use crate::{
514         Arena, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, TypeInner,
515         UnaryOperator, UniqueArena, VectorSize,
516     };
517 
518     use super::ConstantSolver;
519 
520     #[test]
unary_op()521     fn unary_op() {
522         let mut types = UniqueArena::new();
523         let mut expressions = Arena::new();
524         let mut constants = Arena::new();
525 
526         let vec_ty = types.insert(
527             Type {
528                 name: None,
529                 inner: TypeInner::Vector {
530                     size: VectorSize::Bi,
531                     kind: ScalarKind::Sint,
532                     width: 4,
533                 },
534             },
535             Default::default(),
536         );
537 
538         let h = constants.append(
539             Constant {
540                 name: None,
541                 specialization: None,
542                 inner: ConstantInner::Scalar {
543                     width: 4,
544                     value: ScalarValue::Sint(4),
545                 },
546             },
547             Default::default(),
548         );
549 
550         let h1 = constants.append(
551             Constant {
552                 name: None,
553                 specialization: None,
554                 inner: ConstantInner::Scalar {
555                     width: 4,
556                     value: ScalarValue::Sint(8),
557                 },
558             },
559             Default::default(),
560         );
561 
562         let vec_h = constants.append(
563             Constant {
564                 name: None,
565                 specialization: None,
566                 inner: ConstantInner::Composite {
567                     ty: vec_ty,
568                     components: vec![h, h1],
569                 },
570             },
571             Default::default(),
572         );
573 
574         let expr = expressions.append(Expression::Constant(h), Default::default());
575         let expr1 = expressions.append(Expression::Constant(vec_h), Default::default());
576 
577         let root1 = expressions.append(
578             Expression::Unary {
579                 op: UnaryOperator::Negate,
580                 expr,
581             },
582             Default::default(),
583         );
584 
585         let root2 = expressions.append(
586             Expression::Unary {
587                 op: UnaryOperator::Not,
588                 expr,
589             },
590             Default::default(),
591         );
592 
593         let root3 = expressions.append(
594             Expression::Unary {
595                 op: UnaryOperator::Not,
596                 expr: expr1,
597             },
598             Default::default(),
599         );
600 
601         let mut solver = ConstantSolver {
602             types: &mut types,
603             expressions: &expressions,
604             constants: &mut constants,
605         };
606 
607         let res1 = solver.solve(root1).unwrap();
608         let res2 = solver.solve(root2).unwrap();
609         let res3 = solver.solve(root3).unwrap();
610 
611         assert_eq!(
612             constants[res1].inner,
613             ConstantInner::Scalar {
614                 width: 4,
615                 value: ScalarValue::Sint(-4),
616             },
617         );
618 
619         assert_eq!(
620             constants[res2].inner,
621             ConstantInner::Scalar {
622                 width: 4,
623                 value: ScalarValue::Sint(!4),
624             },
625         );
626 
627         let res3_inner = &constants[res3].inner;
628 
629         match res3_inner {
630             ConstantInner::Composite { ty, components } => {
631                 assert_eq!(*ty, vec_ty);
632                 let mut components_iter = components.iter().copied();
633                 assert_eq!(
634                     constants[components_iter.next().unwrap()].inner,
635                     ConstantInner::Scalar {
636                         width: 4,
637                         value: ScalarValue::Sint(!4),
638                     },
639                 );
640                 assert_eq!(
641                     constants[components_iter.next().unwrap()].inner,
642                     ConstantInner::Scalar {
643                         width: 4,
644                         value: ScalarValue::Sint(!8),
645                     },
646                 );
647                 assert!(components_iter.next().is_none());
648             }
649             _ => panic!("Expected vector"),
650         }
651     }
652 
653     #[test]
cast()654     fn cast() {
655         let mut expressions = Arena::new();
656         let mut constants = Arena::new();
657 
658         let h = constants.append(
659             Constant {
660                 name: None,
661                 specialization: None,
662                 inner: ConstantInner::Scalar {
663                     width: 4,
664                     value: ScalarValue::Sint(4),
665                 },
666             },
667             Default::default(),
668         );
669 
670         let expr = expressions.append(Expression::Constant(h), Default::default());
671 
672         let root = expressions.append(
673             Expression::As {
674                 expr,
675                 kind: ScalarKind::Bool,
676                 convert: Some(crate::BOOL_WIDTH),
677             },
678             Default::default(),
679         );
680 
681         let mut solver = ConstantSolver {
682             types: &mut UniqueArena::new(),
683             expressions: &expressions,
684             constants: &mut constants,
685         };
686 
687         let res = solver.solve(root).unwrap();
688 
689         assert_eq!(
690             constants[res].inner,
691             ConstantInner::Scalar {
692                 width: crate::BOOL_WIDTH,
693                 value: ScalarValue::Bool(true),
694             },
695         );
696     }
697 
698     #[test]
access()699     fn access() {
700         let mut types = UniqueArena::new();
701         let mut expressions = Arena::new();
702         let mut constants = Arena::new();
703 
704         let matrix_ty = types.insert(
705             Type {
706                 name: None,
707                 inner: TypeInner::Matrix {
708                     columns: VectorSize::Bi,
709                     rows: VectorSize::Tri,
710                     width: 4,
711                 },
712             },
713             Default::default(),
714         );
715 
716         let vec_ty = types.insert(
717             Type {
718                 name: None,
719                 inner: TypeInner::Vector {
720                     size: VectorSize::Tri,
721                     kind: ScalarKind::Float,
722                     width: 4,
723                 },
724             },
725             Default::default(),
726         );
727 
728         let mut vec1_components = Vec::with_capacity(3);
729         let mut vec2_components = Vec::with_capacity(3);
730 
731         for i in 0..3 {
732             let h = constants.append(
733                 Constant {
734                     name: None,
735                     specialization: None,
736                     inner: ConstantInner::Scalar {
737                         width: 4,
738                         value: ScalarValue::Float(i as f64),
739                     },
740                 },
741                 Default::default(),
742             );
743 
744             vec1_components.push(h)
745         }
746 
747         for i in 3..6 {
748             let h = constants.append(
749                 Constant {
750                     name: None,
751                     specialization: None,
752                     inner: ConstantInner::Scalar {
753                         width: 4,
754                         value: ScalarValue::Float(i as f64),
755                     },
756                 },
757                 Default::default(),
758             );
759 
760             vec2_components.push(h)
761         }
762 
763         let vec1 = constants.append(
764             Constant {
765                 name: None,
766                 specialization: None,
767                 inner: ConstantInner::Composite {
768                     ty: vec_ty,
769                     components: vec1_components,
770                 },
771             },
772             Default::default(),
773         );
774 
775         let vec2 = constants.append(
776             Constant {
777                 name: None,
778                 specialization: None,
779                 inner: ConstantInner::Composite {
780                     ty: vec_ty,
781                     components: vec2_components,
782                 },
783             },
784             Default::default(),
785         );
786 
787         let h = constants.append(
788             Constant {
789                 name: None,
790                 specialization: None,
791                 inner: ConstantInner::Composite {
792                     ty: matrix_ty,
793                     components: vec![vec1, vec2],
794                 },
795             },
796             Default::default(),
797         );
798 
799         let base = expressions.append(Expression::Constant(h), Default::default());
800         let root1 = expressions.append(
801             Expression::AccessIndex { base, index: 1 },
802             Default::default(),
803         );
804         let root2 = expressions.append(
805             Expression::AccessIndex {
806                 base: root1,
807                 index: 2,
808             },
809             Default::default(),
810         );
811 
812         let mut solver = ConstantSolver {
813             types: &mut types,
814             expressions: &expressions,
815             constants: &mut constants,
816         };
817 
818         let res1 = solver.solve(root1).unwrap();
819         let res2 = solver.solve(root2).unwrap();
820 
821         let res1_inner = &constants[res1].inner;
822 
823         match res1_inner {
824             ConstantInner::Composite { ty, components } => {
825                 assert_eq!(*ty, vec_ty);
826                 let mut components_iter = components.iter().copied();
827                 assert_eq!(
828                     constants[components_iter.next().unwrap()].inner,
829                     ConstantInner::Scalar {
830                         width: 4,
831                         value: ScalarValue::Float(3.),
832                     },
833                 );
834                 assert_eq!(
835                     constants[components_iter.next().unwrap()].inner,
836                     ConstantInner::Scalar {
837                         width: 4,
838                         value: ScalarValue::Float(4.),
839                     },
840                 );
841                 assert_eq!(
842                     constants[components_iter.next().unwrap()].inner,
843                     ConstantInner::Scalar {
844                         width: 4,
845                         value: ScalarValue::Float(5.),
846                     },
847                 );
848                 assert!(components_iter.next().is_none());
849             }
850             _ => panic!("Expected vector"),
851         }
852 
853         assert_eq!(
854             constants[res2].inner,
855             ConstantInner::Scalar {
856                 width: 4,
857                 value: ScalarValue::Float(5.),
858             },
859         );
860     }
861 }
862