1 use crate::{
2     arena::{Arena, Handle},
3     BinaryOperator, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, TypeInner,
4     UnaryOperator,
5 };
6 
7 #[derive(Debug)]
8 pub struct ConstantSolver<'a> {
9     pub types: &'a Arena<Type>,
10     pub expressions: &'a Arena<Expression>,
11     pub constants: &'a mut Arena<Constant>,
12 }
13 
14 #[derive(Clone, Debug, PartialEq, thiserror::Error)]
15 pub enum ConstantSolvingError {
16     #[error("Constants cannot access function arguments")]
17     FunctionArg,
18     #[error("Constants cannot access global variables")]
19     GlobalVariable,
20     #[error("Constants cannot access local variables")]
21     LocalVariable,
22     #[error("Cannot get the array length of a non array type")]
23     InvalidArrayLengthArg,
24     #[error("Constants cannot get the array length of a dynamically sized array")]
25     ArrayLengthDynamic,
26     #[error("Constants cannot call functions")]
27     Call,
28     #[error("Constants don't support relational functions")]
29     Relational,
30     #[error("Constants don't support derivative functions")]
31     Derivative,
32     #[error("Constants don't support select expressions")]
33     Select,
34     #[error("Constants don't support load expressions")]
35     Load,
36     #[error("Constants don't support image expressions")]
37     ImageExpression,
38     #[error("Cannot access the type")]
39     InvalidAccessBase,
40     #[error("Cannot access at the index")]
41     InvalidAccessIndex,
42     #[error("Cannot access with index of type")]
43     InvalidAccessIndexTy,
44     #[error("Constants don't support bitcasts")]
45     Bitcast,
46     #[error("Cannot cast type")]
47     InvalidCastArg,
48     #[error("Cannot apply the unary op to the argument")]
49     InvalidUnaryOpArg,
50     #[error("Cannot apply the binary op to the arguments")]
51     InvalidBinaryOpArgs,
52     #[error("Splat/swizzle type is not registered")]
53     DestinationTypeNotFound,
54 }
55 
56 impl<'a> ConstantSolver<'a> {
solve( &mut self, expr: Handle<Expression>, ) -> Result<Handle<Constant>, ConstantSolvingError>57     pub fn solve(
58         &mut self,
59         expr: Handle<Expression>,
60     ) -> Result<Handle<Constant>, ConstantSolvingError> {
61         match self.expressions[expr] {
62             Expression::Constant(constant) => Ok(constant),
63             Expression::AccessIndex { base, index } => self.access(base, index as usize),
64             Expression::Access { base, index } => {
65                 let index = self.solve(index)?;
66 
67                 self.access(base, self.constant_index(index)?)
68             }
69             Expression::Splat {
70                 size,
71                 value: splat_value,
72             } => {
73                 let value_constant = self.solve(splat_value)?;
74                 let ty = match self.constants[value_constant].inner {
75                     ConstantInner::Scalar { ref value, width } => {
76                         let kind = value.scalar_kind();
77                         self.types
78                             .fetch_if(|t| t.inner == crate::TypeInner::Vector { size, kind, width })
79                     }
80                     ConstantInner::Composite { .. } => None,
81                 };
82 
83                 //TODO: register the new type if needed
84                 let ty = ty.ok_or(ConstantSolvingError::DestinationTypeNotFound)?;
85                 Ok(self.constants.fetch_or_append(Constant {
86                     name: None,
87                     specialization: None,
88                     inner: ConstantInner::Composite {
89                         ty,
90                         components: vec![value_constant; size as usize],
91                     },
92                 }))
93             }
94             Expression::Swizzle {
95                 size,
96                 vector: src_vector,
97                 pattern,
98             } => {
99                 let src_constant = self.solve(src_vector)?;
100                 let (ty, src_components) = match self.constants[src_constant].inner {
101                     ConstantInner::Scalar { .. } => (None, &[][..]),
102                     ConstantInner::Composite {
103                         ty,
104                         components: ref src_components,
105                     } => match self.types[ty].inner {
106                         crate::TypeInner::Vector {
107                             size: _,
108                             kind,
109                             width,
110                         } => {
111                             let dst_ty = self.types.fetch_if(|t| {
112                                 t.inner == crate::TypeInner::Vector { size, kind, width }
113                             });
114                             (dst_ty, &src_components[..])
115                         }
116                         _ => (None, &[][..]),
117                     },
118                 };
119                 //TODO: register the new type if needed
120                 let ty = ty.ok_or(ConstantSolvingError::DestinationTypeNotFound)?;
121                 let components = pattern
122                     .iter()
123                     .map(|&sc| src_components[sc as usize])
124                     .collect();
125 
126                 Ok(self.constants.fetch_or_append(Constant {
127                     name: None,
128                     specialization: None,
129                     inner: ConstantInner::Composite { ty, components },
130                 }))
131             }
132             Expression::Compose { ty, ref components } => {
133                 let components = components
134                     .iter()
135                     .map(|c| self.solve(*c))
136                     .collect::<Result<_, _>>()?;
137 
138                 Ok(self.constants.fetch_or_append(Constant {
139                     name: None,
140                     specialization: None,
141                     inner: ConstantInner::Composite { ty, components },
142                 }))
143             }
144             Expression::Unary { expr, op } => {
145                 let expr_constant = self.solve(expr)?;
146 
147                 self.unary_op(op, expr_constant)
148             }
149             Expression::Binary { left, right, op } => {
150                 let left_constant = self.solve(left)?;
151                 let right_constant = self.solve(right)?;
152 
153                 self.binary_op(op, left_constant, right_constant)
154             }
155             Expression::Math { .. } => todo!(),
156             Expression::As {
157                 convert,
158                 expr,
159                 kind,
160             } => {
161                 let expr_constant = self.solve(expr)?;
162 
163                 match convert {
164                     Some(width) => self.cast(expr_constant, kind, width),
165                     None => Err(ConstantSolvingError::Bitcast),
166                 }
167             }
168             Expression::ArrayLength(expr) => {
169                 let array = self.solve(expr)?;
170 
171                 match self.constants[array].inner {
172                     ConstantInner::Scalar { .. } => {
173                         Err(ConstantSolvingError::InvalidArrayLengthArg)
174                     }
175                     ConstantInner::Composite { ty, .. } => match self.types[ty].inner {
176                         TypeInner::Array { size, .. } => match size {
177                             crate::ArraySize::Constant(constant) => Ok(constant),
178                             crate::ArraySize::Dynamic => {
179                                 Err(ConstantSolvingError::ArrayLengthDynamic)
180                             }
181                         },
182                         _ => Err(ConstantSolvingError::InvalidArrayLengthArg),
183                     },
184                 }
185             }
186 
187             Expression::Load { .. } => Err(ConstantSolvingError::Load),
188             Expression::Select { .. } => Err(ConstantSolvingError::Select),
189             Expression::LocalVariable(_) => Err(ConstantSolvingError::LocalVariable),
190             Expression::Derivative { .. } => Err(ConstantSolvingError::Derivative),
191             Expression::Relational { .. } => Err(ConstantSolvingError::Relational),
192             Expression::Call { .. } => Err(ConstantSolvingError::Call),
193             Expression::FunctionArgument(_) => Err(ConstantSolvingError::FunctionArg),
194             Expression::GlobalVariable(_) => Err(ConstantSolvingError::GlobalVariable),
195             Expression::ImageSample { .. }
196             | Expression::ImageLoad { .. }
197             | Expression::ImageQuery { .. } => Err(ConstantSolvingError::ImageExpression),
198         }
199     }
200 
access( &mut self, base: Handle<Expression>, index: usize, ) -> Result<Handle<Constant>, ConstantSolvingError>201     fn access(
202         &mut self,
203         base: Handle<Expression>,
204         index: usize,
205     ) -> Result<Handle<Constant>, ConstantSolvingError> {
206         let base = self.solve(base)?;
207 
208         match self.constants[base].inner {
209             ConstantInner::Scalar { .. } => Err(ConstantSolvingError::InvalidAccessBase),
210             ConstantInner::Composite { ty, ref components } => {
211                 match self.types[ty].inner {
212                     TypeInner::Vector { .. }
213                     | TypeInner::Matrix { .. }
214                     | TypeInner::Array { .. }
215                     | TypeInner::Struct { .. } => (),
216                     _ => return Err(ConstantSolvingError::InvalidAccessBase),
217                 }
218 
219                 components
220                     .get(index)
221                     .copied()
222                     .ok_or(ConstantSolvingError::InvalidAccessIndex)
223             }
224         }
225     }
226 
constant_index(&self, constant: Handle<Constant>) -> Result<usize, ConstantSolvingError>227     fn constant_index(&self, constant: Handle<Constant>) -> Result<usize, ConstantSolvingError> {
228         match self.constants[constant].inner {
229             ConstantInner::Scalar {
230                 value: ScalarValue::Uint(index),
231                 ..
232             } => Ok(index as usize),
233             _ => Err(ConstantSolvingError::InvalidAccessIndexTy),
234         }
235     }
236 
cast( &mut self, constant: Handle<Constant>, kind: ScalarKind, target_width: crate::Bytes, ) -> Result<Handle<Constant>, ConstantSolvingError>237     fn cast(
238         &mut self,
239         constant: Handle<Constant>,
240         kind: ScalarKind,
241         target_width: crate::Bytes,
242     ) -> Result<Handle<Constant>, ConstantSolvingError> {
243         fn inner_cast<A: num_traits::FromPrimitive>(value: ScalarValue) -> A {
244             match value {
245                 ScalarValue::Sint(v) => A::from_i64(v),
246                 ScalarValue::Uint(v) => A::from_u64(v),
247                 ScalarValue::Float(v) => A::from_f64(v),
248                 ScalarValue::Bool(v) => A::from_u64(v as u64),
249             }
250             .unwrap()
251         }
252 
253         let mut inner = self.constants[constant].inner.clone();
254 
255         match inner {
256             ConstantInner::Scalar {
257                 ref mut value,
258                 ref mut width,
259             } => {
260                 *width = target_width;
261                 *value = match kind {
262                     ScalarKind::Sint => ScalarValue::Sint(inner_cast(*value)),
263                     ScalarKind::Uint => ScalarValue::Uint(inner_cast(*value)),
264                     ScalarKind::Float => ScalarValue::Float(inner_cast(*value)),
265                     ScalarKind::Bool => ScalarValue::Bool(inner_cast::<u64>(*value) != 0),
266                 }
267             }
268             ConstantInner::Composite {
269                 ty,
270                 ref mut components,
271             } => {
272                 match self.types[ty].inner {
273                     TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
274                     _ => return Err(ConstantSolvingError::InvalidCastArg),
275                 }
276 
277                 for component in components {
278                     *component = self.cast(*component, kind, target_width)?;
279                 }
280             }
281         }
282 
283         Ok(self.constants.fetch_or_append(Constant {
284             name: None,
285             specialization: None,
286             inner,
287         }))
288     }
289 
unary_op( &mut self, op: UnaryOperator, constant: Handle<Constant>, ) -> Result<Handle<Constant>, ConstantSolvingError>290     fn unary_op(
291         &mut self,
292         op: UnaryOperator,
293         constant: Handle<Constant>,
294     ) -> Result<Handle<Constant>, ConstantSolvingError> {
295         let mut inner = self.constants[constant].inner.clone();
296 
297         match inner {
298             ConstantInner::Scalar { ref mut value, .. } => match op {
299                 UnaryOperator::Negate => match *value {
300                     ScalarValue::Sint(ref mut v) => *v = -*v,
301                     ScalarValue::Float(ref mut v) => *v = -*v,
302                     _ => return Err(ConstantSolvingError::InvalidUnaryOpArg),
303                 },
304                 UnaryOperator::Not => match *value {
305                     ScalarValue::Sint(ref mut v) => *v = !*v,
306                     ScalarValue::Uint(ref mut v) => *v = !*v,
307                     ScalarValue::Bool(ref mut v) => *v = !*v,
308                     _ => return Err(ConstantSolvingError::InvalidUnaryOpArg),
309                 },
310             },
311             ConstantInner::Composite {
312                 ty,
313                 ref mut components,
314             } => {
315                 match self.types[ty].inner {
316                     TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
317                     _ => return Err(ConstantSolvingError::InvalidCastArg),
318                 }
319 
320                 for component in components {
321                     *component = self.unary_op(op, *component)?
322                 }
323             }
324         }
325 
326         Ok(self.constants.fetch_or_append(Constant {
327             name: None,
328             specialization: None,
329             inner,
330         }))
331     }
332 
binary_op( &mut self, op: BinaryOperator, left: Handle<Constant>, right: Handle<Constant>, ) -> Result<Handle<Constant>, ConstantSolvingError>333     fn binary_op(
334         &mut self,
335         op: BinaryOperator,
336         left: Handle<Constant>,
337         right: Handle<Constant>,
338     ) -> Result<Handle<Constant>, ConstantSolvingError> {
339         let left = &self.constants[left].inner;
340         let right = &self.constants[right].inner;
341 
342         let inner = match (left, right) {
343             (
344                 &ConstantInner::Scalar {
345                     value: left_value,
346                     width,
347                 },
348                 &ConstantInner::Scalar {
349                     value: right_value,
350                     width: _,
351                 },
352             ) => {
353                 let value = match op {
354                     BinaryOperator::Equal => ScalarValue::Bool(left_value == right_value),
355                     BinaryOperator::NotEqual => ScalarValue::Bool(left_value != right_value),
356                     BinaryOperator::Less => ScalarValue::Bool(left_value < right_value),
357                     BinaryOperator::LessEqual => ScalarValue::Bool(left_value <= right_value),
358                     BinaryOperator::Greater => ScalarValue::Bool(left_value > right_value),
359                     BinaryOperator::GreaterEqual => ScalarValue::Bool(left_value >= right_value),
360 
361                     _ => match (left_value, right_value) {
362                         (ScalarValue::Sint(a), ScalarValue::Sint(b)) => {
363                             ScalarValue::Sint(match op {
364                                 BinaryOperator::Add => a + b,
365                                 BinaryOperator::Subtract => a - b,
366                                 BinaryOperator::Multiply => a * b,
367                                 BinaryOperator::Divide => a / b,
368                                 BinaryOperator::Modulo => a % b,
369                                 BinaryOperator::And => a & b,
370                                 BinaryOperator::ExclusiveOr => a ^ b,
371                                 BinaryOperator::InclusiveOr => a | b,
372                                 BinaryOperator::ShiftLeft => a << b,
373                                 BinaryOperator::ShiftRight => a >> b,
374                                 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
375                             })
376                         }
377                         (ScalarValue::Uint(a), ScalarValue::Uint(b)) => {
378                             ScalarValue::Uint(match op {
379                                 BinaryOperator::Add => a + b,
380                                 BinaryOperator::Subtract => a - b,
381                                 BinaryOperator::Multiply => a * b,
382                                 BinaryOperator::Divide => a / b,
383                                 BinaryOperator::Modulo => a % b,
384                                 BinaryOperator::And => a & b,
385                                 BinaryOperator::ExclusiveOr => a ^ b,
386                                 BinaryOperator::InclusiveOr => a | b,
387                                 BinaryOperator::ShiftLeft => a << b,
388                                 BinaryOperator::ShiftRight => a >> b,
389                                 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
390                             })
391                         }
392                         (ScalarValue::Float(a), ScalarValue::Float(b)) => {
393                             ScalarValue::Float(match op {
394                                 BinaryOperator::Add => a + b,
395                                 BinaryOperator::Subtract => a - b,
396                                 BinaryOperator::Multiply => a * b,
397                                 BinaryOperator::Divide => a / b,
398                                 BinaryOperator::Modulo => a % b,
399                                 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
400                             })
401                         }
402                         (ScalarValue::Bool(a), ScalarValue::Bool(b)) => {
403                             ScalarValue::Bool(match op {
404                                 BinaryOperator::LogicalAnd => a && b,
405                                 BinaryOperator::LogicalOr => a || b,
406                                 _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
407                             })
408                         }
409                         _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
410                     },
411                 };
412 
413                 ConstantInner::Scalar { value, width }
414             }
415             _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
416         };
417 
418         Ok(self.constants.fetch_or_append(Constant {
419             name: None,
420             specialization: None,
421             inner,
422         }))
423     }
424 }
425 
#[cfg(test)]
mod tests {
    use std::vec;

    use crate::{
        Arena, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, TypeInner,
        UnaryOperator, VectorSize,
    };

    use super::ConstantSolver;

    /// Wraps `inner` in an anonymous, non-specialized constant.
    fn unnamed(inner: ConstantInner) -> Constant {
        Constant {
            name: None,
            specialization: None,
            inner,
        }
    }

    /// A scalar constant body of width 4 holding `value`.
    fn scalar(value: ScalarValue) -> ConstantInner {
        ConstantInner::Scalar { width: 4, value }
    }

    /// An anonymous type described by `inner`.
    fn unnamed_type(inner: TypeInner) -> Type {
        Type { name: None, inner }
    }

    /// Asserts that `handle` is a composite of type `expected_ty` whose
    /// components resolve, in order, to exactly the `expected` bodies.
    fn assert_composite(
        constants: &Arena<Constant>,
        handle: crate::arena::Handle<Constant>,
        expected_ty: crate::arena::Handle<Type>,
        expected: &[ConstantInner],
    ) {
        match &constants[handle].inner {
            ConstantInner::Composite { ty, components } => {
                assert_eq!(*ty, expected_ty);
                assert_eq!(components.len(), expected.len());
                for (&component, want) in components.iter().zip(expected) {
                    assert_eq!(constants[component].inner, *want);
                }
            }
            _ => panic!("Expected vector"),
        }
    }

    #[test]
    fn unary_op() {
        let mut types = Arena::new();
        let mut expressions = Arena::new();
        let mut constants = Arena::new();

        let vec_ty = types.append(unnamed_type(TypeInner::Vector {
            size: VectorSize::Bi,
            kind: ScalarKind::Sint,
            width: 4,
        }));

        let four = constants.append(unnamed(scalar(ScalarValue::Sint(4))));
        let eight = constants.append(unnamed(scalar(ScalarValue::Sint(8))));
        let pair = constants.append(unnamed(ConstantInner::Composite {
            ty: vec_ty,
            components: vec![four, eight],
        }));

        let scalar_expr = expressions.append(Expression::Constant(four));
        let vector_expr = expressions.append(Expression::Constant(pair));

        let negated = expressions.append(Expression::Unary {
            op: UnaryOperator::Negate,
            expr: scalar_expr,
        });
        let inverted = expressions.append(Expression::Unary {
            op: UnaryOperator::Not,
            expr: scalar_expr,
        });
        let inverted_vector = expressions.append(Expression::Unary {
            op: UnaryOperator::Not,
            expr: vector_expr,
        });

        let mut solver = ConstantSolver {
            types: &types,
            expressions: &expressions,
            constants: &mut constants,
        };

        let negated_res = solver.solve(negated).unwrap();
        let inverted_res = solver.solve(inverted).unwrap();
        let inverted_vector_res = solver.solve(inverted_vector).unwrap();

        assert_eq!(constants[negated_res].inner, scalar(ScalarValue::Sint(-4)));
        assert_eq!(constants[inverted_res].inner, scalar(ScalarValue::Sint(!4)));
        assert_composite(
            &constants,
            inverted_vector_res,
            vec_ty,
            &[scalar(ScalarValue::Sint(!4)), scalar(ScalarValue::Sint(!8))],
        );
    }

    #[test]
    fn cast() {
        let mut expressions = Arena::new();
        let mut constants = Arena::new();

        let four = constants.append(unnamed(scalar(ScalarValue::Sint(4))));
        let source = expressions.append(Expression::Constant(four));
        let to_bool = expressions.append(Expression::As {
            expr: source,
            kind: ScalarKind::Bool,
            convert: Some(crate::BOOL_WIDTH),
        });

        let mut solver = ConstantSolver {
            types: &Arena::new(),
            expressions: &expressions,
            constants: &mut constants,
        };

        let res = solver.solve(to_bool).unwrap();

        assert_eq!(
            constants[res].inner,
            ConstantInner::Scalar {
                width: crate::BOOL_WIDTH,
                value: ScalarValue::Bool(true),
            },
        );
    }

    #[test]
    fn access() {
        let mut types = Arena::new();
        let mut expressions = Arena::new();
        let mut constants = Arena::new();

        let matrix_ty = types.append(unnamed_type(TypeInner::Matrix {
            columns: VectorSize::Bi,
            rows: VectorSize::Tri,
            width: 4,
        }));
        let vec_ty = types.append(unnamed_type(TypeInner::Vector {
            size: VectorSize::Tri,
            kind: ScalarKind::Float,
            width: 4,
        }));

        // Six float scalars 0.0..=5.0, appended in order; the first three form
        // column 0 and the last three form column 1.
        let floats: Vec<_> = (0..6)
            .map(|i| constants.append(unnamed(scalar(ScalarValue::Float(i as f64)))))
            .collect();
        let columns: Vec<_> = floats
            .chunks(3)
            .map(|column| {
                constants.append(unnamed(ConstantInner::Composite {
                    ty: vec_ty,
                    components: column.to_vec(),
                }))
            })
            .collect();
        let matrix = constants.append(unnamed(ConstantInner::Composite {
            ty: matrix_ty,
            components: columns,
        }));

        let base = expressions.append(Expression::Constant(matrix));
        let column = expressions.append(Expression::AccessIndex { base, index: 1 });
        let element = expressions.append(Expression::AccessIndex {
            base: column,
            index: 2,
        });

        let mut solver = ConstantSolver {
            types: &types,
            expressions: &expressions,
            constants: &mut constants,
        };

        let column_res = solver.solve(column).unwrap();
        let element_res = solver.solve(element).unwrap();

        assert_composite(
            &constants,
            column_res,
            vec_ty,
            &[
                scalar(ScalarValue::Float(3.)),
                scalar(ScalarValue::Float(4.)),
                scalar(ScalarValue::Float(5.)),
            ],
        );
        assert_eq!(constants[element_res].inner, scalar(ScalarValue::Float(5.)));
    }
}
725