1 use std::cell::RefCell;
2 use std::collections::{BTreeSet, HashSet};
3 use std::fmt;
4 use std::hash;
5 use std::iter::FromIterator;
6 use std::ops;
7 use std::rc::Rc;
8 
9 use crate::cdsl::types::{LaneType, ReferenceType, SpecialType, ValueType};
10 
/// Upper bound on the number of lanes in a SIMD vector type.
const MAX_LANES: u16 = 256;
/// Upper bound, in bits, on integer and boolean lane widths.
const MAX_BITS: u16 = 128;
/// Upper bound, in bits, on floating-point lane widths.
const MAX_FLOAT_BITS: u16 = 64;
14 
15 /// Type variables can be used in place of concrete types when defining
16 /// instructions. This makes the instructions *polymorphic*.
17 ///
18 /// A type variable is restricted to vary over a subset of the value types.
19 /// This subset is specified by a set of flags that control the permitted base
20 /// types and whether the type variable can assume scalar or vector types, or
21 /// both.
22 #[derive(Debug)]
23 pub(crate) struct TypeVarContent {
24     /// Short name of type variable used in instruction descriptions.
25     pub name: String,
26 
27     /// Documentation string.
28     pub doc: String,
29 
30     /// Type set associated to the type variable.
31     /// This field must remain private; use `get_typeset()` or `get_raw_typeset()` to get the
32     /// information you want.
33     type_set: TypeSet,
34 
35     pub base: Option<TypeVarParent>,
36 }
37 
38 #[derive(Clone, Debug)]
39 pub(crate) struct TypeVar {
40     content: Rc<RefCell<TypeVarContent>>,
41 }
42 
43 impl TypeVar {
new(name: impl Into<String>, doc: impl Into<String>, type_set: TypeSet) -> Self44     pub fn new(name: impl Into<String>, doc: impl Into<String>, type_set: TypeSet) -> Self {
45         Self {
46             content: Rc::new(RefCell::new(TypeVarContent {
47                 name: name.into(),
48                 doc: doc.into(),
49                 type_set,
50                 base: None,
51             })),
52         }
53     }
54 
new_singleton(value_type: ValueType) -> Self55     pub fn new_singleton(value_type: ValueType) -> Self {
56         let (name, doc) = (value_type.to_string(), value_type.doc());
57         let mut builder = TypeSetBuilder::new();
58 
59         let (scalar_type, num_lanes) = match value_type {
60             ValueType::Special(special_type) => {
61                 return TypeVar::new(name, doc, builder.specials(vec![special_type]).build());
62             }
63             ValueType::Reference(ReferenceType(reference_type)) => {
64                 let bits = reference_type as RangeBound;
65                 return TypeVar::new(name, doc, builder.refs(bits..bits).build());
66             }
67             ValueType::Lane(lane_type) => (lane_type, 1),
68             ValueType::Vector(vec_type) => {
69                 (vec_type.lane_type(), vec_type.lane_count() as RangeBound)
70             }
71         };
72 
73         builder = builder.simd_lanes(num_lanes..num_lanes);
74 
75         let builder = match scalar_type {
76             LaneType::Int(int_type) => {
77                 let bits = int_type as RangeBound;
78                 builder.ints(bits..bits)
79             }
80             LaneType::Float(float_type) => {
81                 let bits = float_type as RangeBound;
82                 builder.floats(bits..bits)
83             }
84             LaneType::Bool(bool_type) => {
85                 let bits = bool_type as RangeBound;
86                 builder.bools(bits..bits)
87             }
88         };
89         TypeVar::new(name, doc, builder.build())
90     }
91 
92     /// Get a fresh copy of self, named after `name`. Can only be called on non-derived typevars.
copy_from(other: &TypeVar, name: String) -> TypeVar93     pub fn copy_from(other: &TypeVar, name: String) -> TypeVar {
94         assert!(
95             other.base.is_none(),
96             "copy_from() can only be called on non-derived type variables"
97         );
98         TypeVar {
99             content: Rc::new(RefCell::new(TypeVarContent {
100                 name,
101                 doc: "".into(),
102                 type_set: other.type_set.clone(),
103                 base: None,
104             })),
105         }
106     }
107 
108     /// Returns the typeset for this TV. If the TV is derived, computes it recursively from the
109     /// derived function and the base's typeset.
110     /// Note this can't be done non-lazily in the constructor, because the TypeSet of the base may
111     /// change over time.
get_typeset(&self) -> TypeSet112     pub fn get_typeset(&self) -> TypeSet {
113         match &self.base {
114             Some(base) => base.type_var.get_typeset().image(base.derived_func),
115             None => self.type_set.clone(),
116         }
117     }
118 
119     /// Returns this typevar's type set, assuming this type var has no parent.
get_raw_typeset(&self) -> &TypeSet120     pub fn get_raw_typeset(&self) -> &TypeSet {
121         assert_eq!(self.type_set, self.get_typeset());
122         &self.type_set
123     }
124 
125     /// If the associated typeset has a single type return it. Otherwise return None.
singleton_type(&self) -> Option<ValueType>126     pub fn singleton_type(&self) -> Option<ValueType> {
127         let type_set = self.get_typeset();
128         if type_set.size() == 1 {
129             Some(type_set.get_singleton())
130         } else {
131             None
132         }
133     }
134 
135     /// Get the free type variable controlling this one.
free_typevar(&self) -> Option<TypeVar>136     pub fn free_typevar(&self) -> Option<TypeVar> {
137         match &self.base {
138             Some(base) => base.type_var.free_typevar(),
139             None => {
140                 match self.singleton_type() {
141                     // A singleton type isn't a proper free variable.
142                     Some(_) => None,
143                     None => Some(self.clone()),
144                 }
145             }
146         }
147     }
148 
149     /// Create a type variable that is a function of another.
derived(&self, derived_func: DerivedFunc) -> TypeVar150     pub fn derived(&self, derived_func: DerivedFunc) -> TypeVar {
151         let ts = self.get_typeset();
152 
153         // Safety checks to avoid over/underflows.
154         assert!(ts.specials.is_empty(), "can't derive from special types");
155         match derived_func {
156             DerivedFunc::HalfWidth => {
157                 assert!(
158                     ts.ints.is_empty() || *ts.ints.iter().min().unwrap() > 8,
159                     "can't halve all integer types"
160                 );
161                 assert!(
162                     ts.floats.is_empty() || *ts.floats.iter().min().unwrap() > 32,
163                     "can't halve all float types"
164                 );
165                 assert!(
166                     ts.bools.is_empty() || *ts.bools.iter().min().unwrap() > 8,
167                     "can't halve all boolean types"
168                 );
169             }
170             DerivedFunc::DoubleWidth => {
171                 assert!(
172                     ts.ints.is_empty() || *ts.ints.iter().max().unwrap() < MAX_BITS,
173                     "can't double all integer types"
174                 );
175                 assert!(
176                     ts.floats.is_empty() || *ts.floats.iter().max().unwrap() < MAX_FLOAT_BITS,
177                     "can't double all float types"
178                 );
179                 assert!(
180                     ts.bools.is_empty() || *ts.bools.iter().max().unwrap() < MAX_BITS,
181                     "can't double all boolean types"
182                 );
183             }
184             DerivedFunc::HalfVector => {
185                 assert!(
186                     *ts.lanes.iter().min().unwrap() > 1,
187                     "can't halve a scalar type"
188                 );
189             }
190             DerivedFunc::DoubleVector => {
191                 assert!(
192                     *ts.lanes.iter().max().unwrap() < MAX_LANES,
193                     "can't double 256 lanes"
194                 );
195             }
196             DerivedFunc::SplitLanes => {
197                 assert!(
198                     ts.ints.is_empty() || *ts.ints.iter().min().unwrap() > 8,
199                     "can't halve all integer types"
200                 );
201                 assert!(
202                     ts.floats.is_empty() || *ts.floats.iter().min().unwrap() > 32,
203                     "can't halve all float types"
204                 );
205                 assert!(
206                     ts.bools.is_empty() || *ts.bools.iter().min().unwrap() > 8,
207                     "can't halve all boolean types"
208                 );
209                 assert!(
210                     *ts.lanes.iter().max().unwrap() < MAX_LANES,
211                     "can't double 256 lanes"
212                 );
213             }
214             DerivedFunc::MergeLanes => {
215                 assert!(
216                     ts.ints.is_empty() || *ts.ints.iter().max().unwrap() < MAX_BITS,
217                     "can't double all integer types"
218                 );
219                 assert!(
220                     ts.floats.is_empty() || *ts.floats.iter().max().unwrap() < MAX_FLOAT_BITS,
221                     "can't double all float types"
222                 );
223                 assert!(
224                     ts.bools.is_empty() || *ts.bools.iter().max().unwrap() < MAX_BITS,
225                     "can't double all boolean types"
226                 );
227                 assert!(
228                     *ts.lanes.iter().min().unwrap() > 1,
229                     "can't halve a scalar type"
230                 );
231             }
232             DerivedFunc::LaneOf | DerivedFunc::AsBool => { /* no particular assertions */ }
233         }
234 
235         TypeVar {
236             content: Rc::new(RefCell::new(TypeVarContent {
237                 name: format!("{}({})", derived_func.name(), self.name),
238                 doc: "".into(),
239                 type_set: ts,
240                 base: Some(TypeVarParent {
241                     type_var: self.clone(),
242                     derived_func,
243                 }),
244             })),
245         }
246     }
247 
lane_of(&self) -> TypeVar248     pub fn lane_of(&self) -> TypeVar {
249         self.derived(DerivedFunc::LaneOf)
250     }
as_bool(&self) -> TypeVar251     pub fn as_bool(&self) -> TypeVar {
252         self.derived(DerivedFunc::AsBool)
253     }
half_width(&self) -> TypeVar254     pub fn half_width(&self) -> TypeVar {
255         self.derived(DerivedFunc::HalfWidth)
256     }
double_width(&self) -> TypeVar257     pub fn double_width(&self) -> TypeVar {
258         self.derived(DerivedFunc::DoubleWidth)
259     }
half_vector(&self) -> TypeVar260     pub fn half_vector(&self) -> TypeVar {
261         self.derived(DerivedFunc::HalfVector)
262     }
double_vector(&self) -> TypeVar263     pub fn double_vector(&self) -> TypeVar {
264         self.derived(DerivedFunc::DoubleVector)
265     }
split_lanes(&self) -> TypeVar266     pub fn split_lanes(&self) -> TypeVar {
267         self.derived(DerivedFunc::SplitLanes)
268     }
merge_lanes(&self) -> TypeVar269     pub fn merge_lanes(&self) -> TypeVar {
270         self.derived(DerivedFunc::MergeLanes)
271     }
272 
273     /// Constrain the range of types this variable can assume to a subset of those in the typeset
274     /// ts.
275     /// May mutate itself if it's not derived, or its parent if it is.
constrain_types_by_ts(&self, type_set: TypeSet)276     pub fn constrain_types_by_ts(&self, type_set: TypeSet) {
277         match &self.base {
278             Some(base) => {
279                 base.type_var
280                     .constrain_types_by_ts(type_set.preimage(base.derived_func));
281             }
282             None => {
283                 self.content
284                     .borrow_mut()
285                     .type_set
286                     .inplace_intersect_with(&type_set);
287             }
288         }
289     }
290 
291     /// Constrain the range of types this variable can assume to a subset of those `other` can
292     /// assume.
293     /// May mutate itself if it's not derived, or its parent if it is.
constrain_types(&self, other: TypeVar)294     pub fn constrain_types(&self, other: TypeVar) {
295         if self == &other {
296             return;
297         }
298         self.constrain_types_by_ts(other.get_typeset());
299     }
300 
301     /// Get a Rust expression that computes the type of this type variable.
to_rust_code(&self) -> String302     pub fn to_rust_code(&self) -> String {
303         match &self.base {
304             Some(base) => format!(
305                 "{}.{}().unwrap()",
306                 base.type_var.to_rust_code(),
307                 base.derived_func.name()
308             ),
309             None => {
310                 if let Some(singleton) = self.singleton_type() {
311                     singleton.rust_name()
312                 } else {
313                     self.name.clone()
314                 }
315             }
316         }
317     }
318 }
319 
320 impl Into<TypeVar> for &TypeVar {
into(self) -> TypeVar321     fn into(self) -> TypeVar {
322         self.clone()
323     }
324 }
325 impl Into<TypeVar> for ValueType {
into(self) -> TypeVar326     fn into(self) -> TypeVar {
327         TypeVar::new_singleton(self)
328     }
329 }
330 
331 // Hash TypeVars by pointers.
332 // There might be a better way to do this, but since TypeVar's content (namely TypeSet) can be
333 // mutated, it makes sense to use pointer equality/hashing here.
334 impl hash::Hash for TypeVar {
hash<H: hash::Hasher>(&self, h: &mut H)335     fn hash<H: hash::Hasher>(&self, h: &mut H) {
336         match &self.base {
337             Some(base) => {
338                 base.type_var.hash(h);
339                 base.derived_func.hash(h);
340             }
341             None => {
342                 (&**self as *const TypeVarContent).hash(h);
343             }
344         }
345     }
346 }
347 
348 impl PartialEq for TypeVar {
eq(&self, other: &TypeVar) -> bool349     fn eq(&self, other: &TypeVar) -> bool {
350         match (&self.base, &other.base) {
351             (Some(base1), Some(base2)) => {
352                 base1.type_var.eq(&base2.type_var) && base1.derived_func == base2.derived_func
353             }
354             (None, None) => Rc::ptr_eq(&self.content, &other.content),
355             _ => false,
356         }
357     }
358 }
359 
360 // Allow TypeVar as map keys, based on pointer equality (see also above PartialEq impl).
361 impl Eq for TypeVar {}
362 
363 impl ops::Deref for TypeVar {
364     type Target = TypeVarContent;
deref(&self) -> &Self::Target365     fn deref(&self) -> &Self::Target {
366         unsafe { self.content.as_ptr().as_ref().unwrap() }
367     }
368 }
369 
/// A function that derives one type variable from another.
///
/// Each variant corresponds to the `TypeVar` method whose name is given by
/// [`DerivedFunc::name`].
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub(crate) enum DerivedFunc {
    /// Make every type a scalar (its lane type).
    LaneOf,
    /// Map every type to a boolean type; scalars become `b1`.
    AsBool,
    /// Halve the lane width.
    HalfWidth,
    /// Double the lane width.
    DoubleWidth,
    /// Halve the number of lanes.
    HalfVector,
    /// Double the number of lanes.
    DoubleVector,
    /// Halve the lane width and double the number of lanes.
    SplitLanes,
    /// Double the lane width and halve the number of lanes.
    MergeLanes,
}

impl DerivedFunc {
    /// Snake-case name of this function. It is used in derived type variable names and as the
    /// method name in generated Rust code.
    pub fn name(self) -> &'static str {
        match self {
            DerivedFunc::LaneOf => "lane_of",
            DerivedFunc::AsBool => "as_bool",
            DerivedFunc::HalfWidth => "half_width",
            DerivedFunc::DoubleWidth => "double_width",
            DerivedFunc::HalfVector => "half_vector",
            DerivedFunc::DoubleVector => "double_vector",
            DerivedFunc::SplitLanes => "split_lanes",
            DerivedFunc::MergeLanes => "merge_lanes",
        }
    }

    /// Returns the inverse function of this one, if it is a bijection.
    ///
    /// Every variant is listed explicitly so that adding a new one forces a decision here.
    pub fn inverse(self) -> Option<DerivedFunc> {
        match self {
            DerivedFunc::HalfWidth => Some(DerivedFunc::DoubleWidth),
            DerivedFunc::DoubleWidth => Some(DerivedFunc::HalfWidth),
            DerivedFunc::HalfVector => Some(DerivedFunc::DoubleVector),
            DerivedFunc::DoubleVector => Some(DerivedFunc::HalfVector),
            DerivedFunc::MergeLanes => Some(DerivedFunc::SplitLanes),
            DerivedFunc::SplitLanes => Some(DerivedFunc::MergeLanes),
            // Many types share a lane type / boolean type, so these lose information.
            DerivedFunc::LaneOf | DerivedFunc::AsBool => None,
        }
    }
}
409 
410 #[derive(Debug, Hash)]
411 pub(crate) struct TypeVarParent {
412     pub type_var: TypeVar,
413     pub derived_func: DerivedFunc,
414 }
415 
416 /// A set of types.
417 ///
418 /// We don't allow arbitrary subsets of types, but use a parametrized approach
419 /// instead.
420 ///
421 /// Objects of this class can be used as dictionary keys.
422 ///
423 /// Parametrized type sets are specified in terms of ranges:
424 /// - The permitted range of vector lanes, where 1 indicates a scalar type.
425 /// - The permitted range of integer types.
426 /// - The permitted range of floating point types, and
427 /// - The permitted range of boolean types.
428 ///
429 /// The ranges are inclusive from smallest bit-width to largest bit-width.
430 ///
431 /// Finally, a type set can contain special types (derived from `SpecialType`)
432 /// which can't appear as lane types.
433 
434 type RangeBound = u16;
435 type Range = ops::Range<RangeBound>;
436 type NumSet = BTreeSet<RangeBound>;
437 
438 macro_rules! num_set {
439     ($($expr:expr),*) => {
440         NumSet::from_iter(vec![$($expr),*])
441     };
442 }
443 
444 #[derive(Clone, PartialEq, Eq, Hash)]
445 pub(crate) struct TypeSet {
446     pub lanes: NumSet,
447     pub ints: NumSet,
448     pub floats: NumSet,
449     pub bools: NumSet,
450     pub refs: NumSet,
451     pub specials: Vec<SpecialType>,
452 }
453 
454 impl TypeSet {
new( lanes: NumSet, ints: NumSet, floats: NumSet, bools: NumSet, refs: NumSet, specials: Vec<SpecialType>, ) -> Self455     fn new(
456         lanes: NumSet,
457         ints: NumSet,
458         floats: NumSet,
459         bools: NumSet,
460         refs: NumSet,
461         specials: Vec<SpecialType>,
462     ) -> Self {
463         Self {
464             lanes,
465             ints,
466             floats,
467             bools,
468             refs,
469             specials,
470         }
471     }
472 
473     /// Return the number of concrete types represented by this typeset.
size(&self) -> usize474     pub fn size(&self) -> usize {
475         self.lanes.len()
476             * (self.ints.len() + self.floats.len() + self.bools.len() + self.refs.len())
477             + self.specials.len()
478     }
479 
480     /// Return the image of self across the derived function func.
image(&self, derived_func: DerivedFunc) -> TypeSet481     fn image(&self, derived_func: DerivedFunc) -> TypeSet {
482         match derived_func {
483             DerivedFunc::LaneOf => self.lane_of(),
484             DerivedFunc::AsBool => self.as_bool(),
485             DerivedFunc::HalfWidth => self.half_width(),
486             DerivedFunc::DoubleWidth => self.double_width(),
487             DerivedFunc::HalfVector => self.half_vector(),
488             DerivedFunc::DoubleVector => self.double_vector(),
489             DerivedFunc::SplitLanes => self.half_width().double_vector(),
490             DerivedFunc::MergeLanes => self.double_width().half_vector(),
491         }
492     }
493 
494     /// Return a TypeSet describing the image of self across lane_of.
lane_of(&self) -> TypeSet495     fn lane_of(&self) -> TypeSet {
496         let mut copy = self.clone();
497         copy.lanes = num_set![1];
498         copy
499     }
500 
501     /// Return a TypeSet describing the image of self across as_bool.
as_bool(&self) -> TypeSet502     fn as_bool(&self) -> TypeSet {
503         let mut copy = self.clone();
504         copy.ints = NumSet::new();
505         copy.floats = NumSet::new();
506         copy.refs = NumSet::new();
507         if !(&self.lanes - &num_set![1]).is_empty() {
508             copy.bools = &self.ints | &self.floats;
509             copy.bools = &copy.bools | &self.bools;
510         }
511         if self.lanes.contains(&1) {
512             copy.bools.insert(1);
513         }
514         copy
515     }
516 
517     /// Return a TypeSet describing the image of self across halfwidth.
half_width(&self) -> TypeSet518     fn half_width(&self) -> TypeSet {
519         let mut copy = self.clone();
520         copy.ints = NumSet::from_iter(self.ints.iter().filter(|&&x| x > 8).map(|&x| x / 2));
521         copy.floats = NumSet::from_iter(self.floats.iter().filter(|&&x| x > 32).map(|&x| x / 2));
522         copy.bools = NumSet::from_iter(self.bools.iter().filter(|&&x| x > 8).map(|&x| x / 2));
523         copy.specials = Vec::new();
524         copy
525     }
526 
527     /// Return a TypeSet describing the image of self across doublewidth.
double_width(&self) -> TypeSet528     fn double_width(&self) -> TypeSet {
529         let mut copy = self.clone();
530         copy.ints = NumSet::from_iter(self.ints.iter().filter(|&&x| x < MAX_BITS).map(|&x| x * 2));
531         copy.floats = NumSet::from_iter(
532             self.floats
533                 .iter()
534                 .filter(|&&x| x < MAX_FLOAT_BITS)
535                 .map(|&x| x * 2),
536         );
537         copy.bools = NumSet::from_iter(
538             self.bools
539                 .iter()
540                 .filter(|&&x| x < MAX_BITS)
541                 .map(|&x| x * 2)
542                 .filter(|x| legal_bool(*x)),
543         );
544         copy.specials = Vec::new();
545         copy
546     }
547 
548     /// Return a TypeSet describing the image of self across halfvector.
half_vector(&self) -> TypeSet549     fn half_vector(&self) -> TypeSet {
550         let mut copy = self.clone();
551         copy.lanes = NumSet::from_iter(self.lanes.iter().filter(|&&x| x > 1).map(|&x| x / 2));
552         copy.specials = Vec::new();
553         copy
554     }
555 
556     /// Return a TypeSet describing the image of self across doublevector.
double_vector(&self) -> TypeSet557     fn double_vector(&self) -> TypeSet {
558         let mut copy = self.clone();
559         copy.lanes = NumSet::from_iter(
560             self.lanes
561                 .iter()
562                 .filter(|&&x| x < MAX_LANES)
563                 .map(|&x| x * 2),
564         );
565         copy.specials = Vec::new();
566         copy
567     }
568 
concrete_types(&self) -> Vec<ValueType>569     fn concrete_types(&self) -> Vec<ValueType> {
570         let mut ret = Vec::new();
571         for &num_lanes in &self.lanes {
572             for &bits in &self.ints {
573                 ret.push(LaneType::int_from_bits(bits).by(num_lanes));
574             }
575             for &bits in &self.floats {
576                 ret.push(LaneType::float_from_bits(bits).by(num_lanes));
577             }
578             for &bits in &self.bools {
579                 ret.push(LaneType::bool_from_bits(bits).by(num_lanes));
580             }
581             for &bits in &self.refs {
582                 ret.push(ReferenceType::ref_from_bits(bits).into());
583             }
584         }
585         for &special in &self.specials {
586             ret.push(special.into());
587         }
588         ret
589     }
590 
591     /// Return the singleton type represented by self. Can only call on typesets containing 1 type.
get_singleton(&self) -> ValueType592     fn get_singleton(&self) -> ValueType {
593         let mut types = self.concrete_types();
594         assert_eq!(types.len(), 1);
595         types.remove(0)
596     }
597 
598     /// Return the inverse image of self across the derived function func.
preimage(&self, func: DerivedFunc) -> TypeSet599     fn preimage(&self, func: DerivedFunc) -> TypeSet {
600         if self.size() == 0 {
601             // The inverse of the empty set is itself.
602             return self.clone();
603         }
604 
605         match func {
606             DerivedFunc::LaneOf => {
607                 let mut copy = self.clone();
608                 copy.lanes =
609                     NumSet::from_iter((0..=MAX_LANES.trailing_zeros()).map(|i| u16::pow(2, i)));
610                 copy
611             }
612             DerivedFunc::AsBool => {
613                 let mut copy = self.clone();
614                 if self.bools.contains(&1) {
615                     copy.ints = NumSet::from_iter(vec![8, 16, 32, 64, 128]);
616                     copy.floats = NumSet::from_iter(vec![32, 64]);
617                 } else {
618                     copy.ints = &self.bools - &NumSet::from_iter(vec![1]);
619                     copy.floats = &self.bools & &NumSet::from_iter(vec![32, 64]);
620                     // If b1 is not in our typeset, than lanes=1 cannot be in the pre-image, as
621                     // as_bool() of scalars is always b1.
622                     copy.lanes = &self.lanes - &NumSet::from_iter(vec![1]);
623                 }
624                 copy
625             }
626             DerivedFunc::HalfWidth => self.double_width(),
627             DerivedFunc::DoubleWidth => self.half_width(),
628             DerivedFunc::HalfVector => self.double_vector(),
629             DerivedFunc::DoubleVector => self.half_vector(),
630             DerivedFunc::SplitLanes => self.double_width().half_vector(),
631             DerivedFunc::MergeLanes => self.half_width().double_vector(),
632         }
633     }
634 
inplace_intersect_with(&mut self, other: &TypeSet)635     pub fn inplace_intersect_with(&mut self, other: &TypeSet) {
636         self.lanes = &self.lanes & &other.lanes;
637         self.ints = &self.ints & &other.ints;
638         self.floats = &self.floats & &other.floats;
639         self.bools = &self.bools & &other.bools;
640         self.refs = &self.refs & &other.refs;
641 
642         let mut new_specials = Vec::new();
643         for spec in &self.specials {
644             if let Some(spec) = other.specials.iter().find(|&other_spec| other_spec == spec) {
645                 new_specials.push(*spec);
646             }
647         }
648         self.specials = new_specials;
649     }
650 
is_subset(&self, other: &TypeSet) -> bool651     pub fn is_subset(&self, other: &TypeSet) -> bool {
652         self.lanes.is_subset(&other.lanes)
653             && self.ints.is_subset(&other.ints)
654             && self.floats.is_subset(&other.floats)
655             && self.bools.is_subset(&other.bools)
656             && self.refs.is_subset(&other.refs)
657             && {
658                 let specials: HashSet<SpecialType> = HashSet::from_iter(self.specials.clone());
659                 let other_specials = HashSet::from_iter(other.specials.clone());
660                 specials.is_subset(&other_specials)
661             }
662     }
663 
is_wider_or_equal(&self, other: &TypeSet) -> bool664     pub fn is_wider_or_equal(&self, other: &TypeSet) -> bool {
665         set_wider_or_equal(&self.ints, &other.ints)
666             && set_wider_or_equal(&self.floats, &other.floats)
667             && set_wider_or_equal(&self.bools, &other.bools)
668             && set_wider_or_equal(&self.refs, &other.refs)
669     }
670 
is_narrower(&self, other: &TypeSet) -> bool671     pub fn is_narrower(&self, other: &TypeSet) -> bool {
672         set_narrower(&self.ints, &other.ints)
673             && set_narrower(&self.floats, &other.floats)
674             && set_narrower(&self.bools, &other.bools)
675             && set_narrower(&self.refs, &other.refs)
676     }
677 }
678 
set_wider_or_equal(s1: &NumSet, s2: &NumSet) -> bool679 fn set_wider_or_equal(s1: &NumSet, s2: &NumSet) -> bool {
680     !s1.is_empty() && !s2.is_empty() && s1.iter().min() >= s2.iter().max()
681 }
682 
set_narrower(s1: &NumSet, s2: &NumSet) -> bool683 fn set_narrower(s1: &NumSet, s2: &NumSet) -> bool {
684     !s1.is_empty() && !s2.is_empty() && s1.iter().min() < s2.iter().max()
685 }
686 
687 impl fmt::Debug for TypeSet {
fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error>688     fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
689         write!(fmt, "TypeSet(")?;
690 
691         let mut subsets = Vec::new();
692         if !self.lanes.is_empty() {
693             subsets.push(format!(
694                 "lanes={{{}}}",
695                 Vec::from_iter(self.lanes.iter().map(|x| x.to_string())).join(", ")
696             ));
697         }
698         if !self.ints.is_empty() {
699             subsets.push(format!(
700                 "ints={{{}}}",
701                 Vec::from_iter(self.ints.iter().map(|x| x.to_string())).join(", ")
702             ));
703         }
704         if !self.floats.is_empty() {
705             subsets.push(format!(
706                 "floats={{{}}}",
707                 Vec::from_iter(self.floats.iter().map(|x| x.to_string())).join(", ")
708             ));
709         }
710         if !self.bools.is_empty() {
711             subsets.push(format!(
712                 "bools={{{}}}",
713                 Vec::from_iter(self.bools.iter().map(|x| x.to_string())).join(", ")
714             ));
715         }
716         if !self.refs.is_empty() {
717             subsets.push(format!(
718                 "refs={{{}}}",
719                 Vec::from_iter(self.refs.iter().map(|x| x.to_string())).join(", ")
720             ));
721         }
722         if !self.specials.is_empty() {
723             subsets.push(format!(
724                 "specials={{{}}}",
725                 Vec::from_iter(self.specials.iter().map(|x| x.to_string())).join(", ")
726             ));
727         }
728 
729         write!(fmt, "{})", subsets.join(", "))?;
730         Ok(())
731     }
732 }
733 
734 pub(crate) struct TypeSetBuilder {
735     ints: Interval,
736     floats: Interval,
737     bools: Interval,
738     refs: Interval,
739     includes_scalars: bool,
740     simd_lanes: Interval,
741     specials: Vec<SpecialType>,
742 }
743 
744 impl TypeSetBuilder {
new() -> Self745     pub fn new() -> Self {
746         Self {
747             ints: Interval::None,
748             floats: Interval::None,
749             bools: Interval::None,
750             refs: Interval::None,
751             includes_scalars: true,
752             simd_lanes: Interval::None,
753             specials: Vec::new(),
754         }
755     }
756 
ints(mut self, interval: impl Into<Interval>) -> Self757     pub fn ints(mut self, interval: impl Into<Interval>) -> Self {
758         assert!(self.ints == Interval::None);
759         self.ints = interval.into();
760         self
761     }
floats(mut self, interval: impl Into<Interval>) -> Self762     pub fn floats(mut self, interval: impl Into<Interval>) -> Self {
763         assert!(self.floats == Interval::None);
764         self.floats = interval.into();
765         self
766     }
bools(mut self, interval: impl Into<Interval>) -> Self767     pub fn bools(mut self, interval: impl Into<Interval>) -> Self {
768         assert!(self.bools == Interval::None);
769         self.bools = interval.into();
770         self
771     }
refs(mut self, interval: impl Into<Interval>) -> Self772     pub fn refs(mut self, interval: impl Into<Interval>) -> Self {
773         assert!(self.refs == Interval::None);
774         self.refs = interval.into();
775         self
776     }
includes_scalars(mut self, includes_scalars: bool) -> Self777     pub fn includes_scalars(mut self, includes_scalars: bool) -> Self {
778         self.includes_scalars = includes_scalars;
779         self
780     }
simd_lanes(mut self, interval: impl Into<Interval>) -> Self781     pub fn simd_lanes(mut self, interval: impl Into<Interval>) -> Self {
782         assert!(self.simd_lanes == Interval::None);
783         self.simd_lanes = interval.into();
784         self
785     }
specials(mut self, specials: Vec<SpecialType>) -> Self786     pub fn specials(mut self, specials: Vec<SpecialType>) -> Self {
787         assert!(self.specials.is_empty());
788         self.specials = specials;
789         self
790     }
791 
build(self) -> TypeSet792     pub fn build(self) -> TypeSet {
793         let min_lanes = if self.includes_scalars { 1 } else { 2 };
794 
795         let bools = range_to_set(self.bools.to_range(1..MAX_BITS, None))
796             .into_iter()
797             .filter(|x| legal_bool(*x))
798             .collect();
799 
800         TypeSet::new(
801             range_to_set(self.simd_lanes.to_range(min_lanes..MAX_LANES, Some(1))),
802             range_to_set(self.ints.to_range(8..MAX_BITS, None)),
803             range_to_set(self.floats.to_range(32..64, None)),
804             bools,
805             range_to_set(self.refs.to_range(32..64, None)),
806             self.specials,
807         )
808     }
809 
all() -> TypeSet810     pub fn all() -> TypeSet {
811         TypeSetBuilder::new()
812             .ints(Interval::All)
813             .floats(Interval::All)
814             .bools(Interval::All)
815             .refs(Interval::All)
816             .simd_lanes(Interval::All)
817             .specials(ValueType::all_special_types().collect())
818             .includes_scalars(true)
819             .build()
820     }
821 }
822 
823 #[derive(PartialEq)]
824 pub(crate) enum Interval {
825     None,
826     All,
827     Range(Range),
828 }
829 
830 impl Interval {
to_range(&self, full_range: Range, default: Option<RangeBound>) -> Option<Range>831     fn to_range(&self, full_range: Range, default: Option<RangeBound>) -> Option<Range> {
832         match self {
833             Interval::None => {
834                 if let Some(default_val) = default {
835                     Some(default_val..default_val)
836                 } else {
837                     None
838                 }
839             }
840 
841             Interval::All => Some(full_range),
842 
843             Interval::Range(range) => {
844                 let (low, high) = (range.start, range.end);
845                 assert!(low.is_power_of_two());
846                 assert!(high.is_power_of_two());
847                 assert!(low <= high);
848                 assert!(low >= full_range.start);
849                 assert!(high <= full_range.end);
850                 Some(low..high)
851             }
852         }
853     }
854 }
855 
856 impl Into<Interval> for Range {
into(self) -> Interval857     fn into(self) -> Interval {
858         Interval::Range(self)
859     }
860 }
861 
legal_bool(bits: RangeBound) -> bool862 fn legal_bool(bits: RangeBound) -> bool {
863     // Only allow legal bit widths for bool types.
864     bits == 1 || (bits >= 8 && bits <= MAX_BITS && bits.is_power_of_two())
865 }
866 
867 /// Generates a set with all the powers of two included in the range.
range_to_set(range: Option<Range>) -> NumSet868 fn range_to_set(range: Option<Range>) -> NumSet {
869     let mut set = NumSet::new();
870 
871     let (low, high) = match range {
872         Some(range) => (range.start, range.end),
873         None => return set,
874     };
875 
876     assert!(low.is_power_of_two());
877     assert!(high.is_power_of_two());
878     assert!(low <= high);
879 
880     for i in low.trailing_zeros()..=high.trailing_zeros() {
881         assert!(1 << i <= RangeBound::max_value());
882         set.insert(1 << i);
883     }
884     set
885 }
886 
#[test]
fn test_typevar_builder() {
    // Scalar-only builders that each enable exactly one lane-type category.
    let int_set = TypeSetBuilder::new().ints(Interval::All).build();
    assert_eq!(int_set.ints, num_set![8, 16, 32, 64, 128]);
    assert_eq!(int_set.lanes, num_set![1]);
    assert!(int_set.floats.is_empty());
    assert!(int_set.bools.is_empty());
    assert!(int_set.specials.is_empty());

    let bool_set = TypeSetBuilder::new().bools(Interval::All).build();
    assert_eq!(bool_set.bools, num_set![1, 8, 16, 32, 64, 128]);
    assert_eq!(bool_set.lanes, num_set![1]);
    assert!(bool_set.ints.is_empty());
    assert!(bool_set.floats.is_empty());
    assert!(bool_set.specials.is_empty());

    let float_set = TypeSetBuilder::new().floats(Interval::All).build();
    assert_eq!(float_set.floats, num_set![32, 64]);
    assert_eq!(float_set.lanes, num_set![1]);
    assert!(float_set.ints.is_empty());
    assert!(float_set.bools.is_empty());
    assert!(float_set.specials.is_empty());

    // Float vectors of every lane count, without and then with the single-lane case.
    let vector_cases = vec![
        (false, num_set![2, 4, 8, 16, 32, 64, 128, 256]),
        (true, num_set![1, 2, 4, 8, 16, 32, 64, 128, 256]),
    ];
    for (with_scalars, expected_lanes) in vector_cases {
        let vec_set = TypeSetBuilder::new()
            .floats(Interval::All)
            .simd_lanes(Interval::All)
            .includes_scalars(with_scalars)
            .build();
        assert_eq!(vec_set.lanes, expected_lanes);
        assert_eq!(vec_set.floats, num_set![32, 64]);
        assert!(vec_set.ints.is_empty());
        assert!(vec_set.bools.is_empty());
        assert!(vec_set.specials.is_empty());
    }

    // An explicit, bounded integer range.
    let narrow = TypeSetBuilder::new().ints(16..64).build();
    assert_eq!(narrow.ints, num_set![16, 32, 64]);
    assert_eq!(narrow.lanes, num_set![1]);
    assert!(narrow.floats.is_empty());
    assert!(narrow.bools.is_empty());
    assert!(narrow.specials.is_empty());
}
939 
#[test]
#[should_panic]
fn test_typevar_builder_too_high_bound_panic() {
    // Twice the widest integer is still a power of two, but it exceeds the legal maximum.
    let too_wide = 2 * MAX_BITS;
    let builder = TypeSetBuilder::new().ints(16..too_wide);
    builder.build();
}
945 
#[test]
#[should_panic]
fn test_typevar_builder_inverted_bounds_panic() {
    // The start (32) exceeds the end (16); building such a set must fail.
    let inverted = 32..16;
    let builder = TypeSetBuilder::new().ints(inverted);
    builder.build();
}
951 
#[test]
fn test_as_bool() {
    // Vectors of 2 to 8 lanes over i8 and f32.
    let vectors = TypeSetBuilder::new()
        .simd_lanes(2..8)
        .ints(8..8)
        .floats(32..32)
        .build();
    let lanes_only = TypeSetBuilder::new().ints(8..8).floats(32..32).build();
    assert_eq!(vectors.lane_of(), lanes_only);

    // The resulting bool widths (8 and 32) do not form a contiguous interval, so they are
    // assigned directly.
    let mut expected_vectors = TypeSetBuilder::new().simd_lanes(2..8).build();
    expected_vectors.bools = num_set![8, 32];
    assert_eq!(vectors.as_bool(), expected_vectors);

    // With the single-lane case included, b1 also appears in the expected result.
    let mixed = TypeSetBuilder::new()
        .simd_lanes(1..8)
        .ints(8..8)
        .floats(32..32)
        .build();
    let mut expected_mixed = TypeSetBuilder::new().simd_lanes(1..8).build();
    expected_mixed.bools = num_set![1, 8, 32];
    assert_eq!(mixed.as_bool(), expected_mixed);
}
978 
#[test]
fn test_forward_images() {
    // Shorthands for type sets that enable a single category.
    let lanes = |r: Range| TypeSetBuilder::new().simd_lanes(r).build();
    let ints = |r: Range| TypeSetBuilder::new().ints(r).build();
    let floats = |r: Range| TypeSetBuilder::new().floats(r).build();
    let bools = |r: Range| TypeSetBuilder::new().bools(r).build();
    let nothing = TypeSetBuilder::new().build();

    // Halving and doubling the lane count.
    assert_eq!(lanes(1..32).half_vector(), lanes(1..16));
    assert_eq!(lanes(1..32).double_vector(), lanes(2..64));
    assert_eq!(lanes(128..256).double_vector(), lanes(256..256));

    // Halving lane widths.
    assert_eq!(ints(8..32).half_width(), ints(8..16));
    assert_eq!(floats(32..32).half_width(), nothing);
    assert_eq!(floats(32..64).half_width(), floats(32..32));
    assert_eq!(bools(1..8).half_width(), nothing);
    assert_eq!(bools(1..32).half_width(), bools(8..16));

    // Doubling lane widths.
    assert_eq!(ints(8..32).double_width(), ints(16..64));
    assert_eq!(ints(32..64).double_width(), ints(64..128));
    assert_eq!(floats(32..32).double_width(), floats(64..64));
    assert_eq!(floats(32..64).double_width(), floats(64..64));
    assert_eq!(bools(1..16).double_width(), bools(16..32));
    assert_eq!(bools(32..64).double_width(), bools(64..128));
}
1056 
#[test]
fn test_backward_images() {
    let nothing = TypeSetBuilder::new().build();

    // LaneOf: a scalar i8/f32 set is expected to pull back to every lane count.
    let scalars = TypeSetBuilder::new()
        .simd_lanes(1..1)
        .ints(8..8)
        .floats(32..32)
        .build();
    let any_lanes = TypeSetBuilder::new()
        .simd_lanes(Interval::All)
        .ints(8..8)
        .floats(32..32)
        .build();
    assert_eq!(scalars.preimage(DerivedFunc::LaneOf), any_lanes);
    assert_eq!(nothing.preimage(DerivedFunc::LaneOf), nothing);

    // AsBool: bools of every width pull back to ints, floats and bools of every width.
    let bool_vectors = TypeSetBuilder::new()
        .simd_lanes(1..4)
        .bools(1..128)
        .build();
    let every_lane_type = TypeSetBuilder::new()
        .simd_lanes(1..4)
        .ints(Interval::All)
        .bools(Interval::All)
        .floats(Interval::All)
        .build();
    assert_eq!(bool_vectors.preimage(DerivedFunc::AsBool), every_lane_type);

    // DoubleVector.
    let single_i8 = TypeSetBuilder::new().simd_lanes(1..1).ints(8..8).build();
    assert_eq!(single_i8.preimage(DerivedFunc::DoubleVector).size(), 0);
    let wide_vectors = TypeSetBuilder::new()
        .simd_lanes(1..16)
        .ints(8..16)
        .floats(32..32)
        .build();
    let narrow_vectors = TypeSetBuilder::new()
        .simd_lanes(1..8)
        .ints(8..16)
        .floats(32..32)
        .build();
    assert_eq!(
        wide_vectors.preimage(DerivedFunc::DoubleVector),
        narrow_vectors,
    );

    // HalfVector.
    let max_lanes = TypeSetBuilder::new()
        .simd_lanes(256..256)
        .ints(8..8)
        .build();
    assert_eq!(max_lanes.preimage(DerivedFunc::HalfVector).size(), 0);
    let halved_lanes = TypeSetBuilder::new()
        .simd_lanes(64..128)
        .bools(1..32)
        .build();
    let full_lanes = TypeSetBuilder::new()
        .simd_lanes(128..256)
        .bools(1..32)
        .build();
    assert_eq!(halved_lanes.preimage(DerivedFunc::HalfVector), full_lanes);

    // HalfWidth.
    let widest = TypeSetBuilder::new()
        .ints(128..128)
        .floats(64..64)
        .bools(128..128)
        .build();
    assert_eq!(widest.preimage(DerivedFunc::HalfWidth).size(), 0);
    let narrow_bools = TypeSetBuilder::new()
        .simd_lanes(64..256)
        .bools(1..64)
        .build();
    let wide_bools = TypeSetBuilder::new()
        .simd_lanes(64..256)
        .bools(16..128)
        .build();
    assert_eq!(narrow_bools.preimage(DerivedFunc::HalfWidth), wide_bools);

    // DoubleWidth.
    let narrowest = TypeSetBuilder::new()
        .ints(8..8)
        .floats(32..32)
        .bools(1..8)
        .build();
    assert_eq!(narrowest.preimage(DerivedFunc::DoubleWidth).size(), 0);
    let doubled = TypeSetBuilder::new()
        .simd_lanes(1..16)
        .ints(8..16)
        .floats(32..64)
        .build();
    let halved = TypeSetBuilder::new()
        .simd_lanes(1..16)
        .ints(8..8)
        .floats(32..32)
        .build();
    assert_eq!(doubled.preimage(DerivedFunc::DoubleWidth), halved);
}
1186 
#[test]
#[should_panic]
fn test_typeset_singleton_panic_nonsingleton_types() {
    // Two distinct lane types (i8 and f32) cannot be reduced to one value type.
    let mixed = TypeSetBuilder::new().ints(8..8).floats(32..32).build();
    mixed.get_singleton();
}
1196 
#[test]
#[should_panic]
fn test_typeset_singleton_panic_nonsingleton_lanes() {
    // A single lane type, but two possible lane counts (1 and 2).
    let two_lane_counts = TypeSetBuilder::new().simd_lanes(1..2).floats(32..32).build();
    two_lane_counts.get_singleton();
}
1206 
#[test]
fn test_typeset_singleton() {
    use crate::shared::types::{Bool, Float, Int};

    // Builds the set and extracts its single member.
    let singleton_of = |builder: TypeSetBuilder| builder.build().get_singleton();

    assert_eq!(
        singleton_of(TypeSetBuilder::new().ints(16..16)),
        ValueType::Lane(Int::I16.into())
    );
    assert_eq!(
        singleton_of(TypeSetBuilder::new().floats(64..64)),
        ValueType::Lane(Float::F64.into())
    );
    assert_eq!(
        singleton_of(TypeSetBuilder::new().bools(1..1)),
        ValueType::Lane(Bool::B1.into())
    );
    assert_eq!(
        singleton_of(TypeSetBuilder::new().simd_lanes(4..4).ints(32..32)),
        LaneType::from(Int::I32).by(4)
    );
}
1231 
#[test]
fn test_typevar_functions() {
    // A derived type variable's name wraps its parent's name in the derivation's name.
    let wide = TypeVar::new(
        "x",
        "i16 and up",
        TypeSetBuilder::new().ints(16..64).build(),
    );
    let halved = wide.half_width();
    assert_eq!(halved.name, "half_width(x)");
    assert_eq!(halved.double_width().name, "double_width(half_width(x))");

    let narrow = TypeVar::new("x", "up to i32", TypeSetBuilder::new().ints(8..32).build());
    let doubled = narrow.double_width();
    assert_eq!(doubled.name, "double_width(x)");
}
1248 
#[test]
fn test_typevar_singleton() {
    use crate::cdsl::types::VectorType;
    use crate::shared::types::{Float, Int};

    // Scalar i32: one lane, one integer width, nothing else.
    let i32_var = TypeVar::new_singleton(ValueType::Lane(LaneType::Int(Int::I32)));
    assert_eq!(i32_var.name, "i32");
    let scalar_set = &i32_var.type_set;
    assert_eq!(scalar_set.lanes, num_set![1]);
    assert_eq!(scalar_set.ints, num_set![32]);
    assert!(scalar_set.floats.is_empty());
    assert!(scalar_set.bools.is_empty());
    assert!(scalar_set.specials.is_empty());

    // f32x4: four lanes of a single float width.
    let f32x4_var = TypeVar::new_singleton(ValueType::Vector(VectorType::new(
        LaneType::Float(Float::F32),
        4,
    )));
    assert_eq!(f32x4_var.name, "f32x4");
    let vector_set = &f32x4_var.type_set;
    assert_eq!(vector_set.lanes, num_set![4]);
    assert_eq!(vector_set.floats, num_set![32]);
    assert!(vector_set.ints.is_empty());
    assert!(vector_set.bools.is_empty());
    assert!(vector_set.specials.is_empty());
}
1275