1 use std::cell::RefCell;
2 use std::collections::{BTreeSet, HashSet};
3 use std::fmt;
4 use std::hash;
5 use std::iter::FromIterator;
6 use std::ops;
7 use std::rc::Rc;
8
9 use crate::cdsl::types::{LaneType, ReferenceType, SpecialType, ValueType};
10
/// Largest lane count a vector type may have; upper bound of the `simd_lanes` range and the
/// limit past which lane counts are not doubled.
const MAX_LANES: u16 = 256;
/// Largest bit width for integer and boolean lane types; upper bound of their ranges and the
/// limit past which their widths are not doubled.
const MAX_BITS: u16 = 128;
/// Largest bit width for floating-point lane types; float widths at this size are not doubled.
const MAX_FLOAT_BITS: u16 = 64;
14
/// Type variables can be used in place of concrete types when defining
/// instructions. This makes the instructions *polymorphic*.
///
/// A type variable is restricted to vary over a subset of the value types.
/// This subset is specified by a set of flags that control the permitted base
/// types and whether the type variable can assume scalar or vector types, or
/// both.
///
/// This struct is the shared payload behind a `TypeVar` handle, which stores it in an
/// `Rc<RefCell<_>>` so that every clone of the handle observes the same (mutable) type set.
#[derive(Debug)]
pub(crate) struct TypeVarContent {
    /// Short name of type variable used in instruction descriptions.
    pub name: String,

    /// Documentation string.
    pub doc: String,

    /// Type set associated to the type variable.
    /// This field must remain private; use `get_typeset()` or `get_raw_typeset()` to get the
    /// information you want.
    /// For a derived type variable this is only a snapshot of the parent's type set taken when
    /// the variable was derived; `get_typeset()` recomputes the real set from the parent.
    type_set: TypeSet,

    /// For a derived type variable: the parent it is derived from and the function applied to
    /// the parent's type set. `None` for a free (non-derived) type variable.
    pub base: Option<TypeVarParent>,
}
37
/// A shared handle to a `TypeVarContent`.
///
/// Cloning the handle only bumps the `Rc` reference count: all clones share the same content,
/// so narrowing the type set through one clone is visible through all of them. Equality and
/// hashing are identity-based (pointer identity for free variables; parent plus derivation
/// function for derived ones), not content-based.
#[derive(Clone, Debug)]
pub(crate) struct TypeVar {
    content: Rc<RefCell<TypeVarContent>>,
}
42
43 impl TypeVar {
new(name: impl Into<String>, doc: impl Into<String>, type_set: TypeSet) -> Self44 pub fn new(name: impl Into<String>, doc: impl Into<String>, type_set: TypeSet) -> Self {
45 Self {
46 content: Rc::new(RefCell::new(TypeVarContent {
47 name: name.into(),
48 doc: doc.into(),
49 type_set,
50 base: None,
51 })),
52 }
53 }
54
new_singleton(value_type: ValueType) -> Self55 pub fn new_singleton(value_type: ValueType) -> Self {
56 let (name, doc) = (value_type.to_string(), value_type.doc());
57 let mut builder = TypeSetBuilder::new();
58
59 let (scalar_type, num_lanes) = match value_type {
60 ValueType::Special(special_type) => {
61 return TypeVar::new(name, doc, builder.specials(vec![special_type]).build());
62 }
63 ValueType::Reference(ReferenceType(reference_type)) => {
64 let bits = reference_type as RangeBound;
65 return TypeVar::new(name, doc, builder.refs(bits..bits).build());
66 }
67 ValueType::Lane(lane_type) => (lane_type, 1),
68 ValueType::Vector(vec_type) => {
69 (vec_type.lane_type(), vec_type.lane_count() as RangeBound)
70 }
71 };
72
73 builder = builder.simd_lanes(num_lanes..num_lanes);
74
75 let builder = match scalar_type {
76 LaneType::Int(int_type) => {
77 let bits = int_type as RangeBound;
78 builder.ints(bits..bits)
79 }
80 LaneType::Float(float_type) => {
81 let bits = float_type as RangeBound;
82 builder.floats(bits..bits)
83 }
84 LaneType::Bool(bool_type) => {
85 let bits = bool_type as RangeBound;
86 builder.bools(bits..bits)
87 }
88 };
89 TypeVar::new(name, doc, builder.build())
90 }
91
92 /// Get a fresh copy of self, named after `name`. Can only be called on non-derived typevars.
copy_from(other: &TypeVar, name: String) -> TypeVar93 pub fn copy_from(other: &TypeVar, name: String) -> TypeVar {
94 assert!(
95 other.base.is_none(),
96 "copy_from() can only be called on non-derived type variables"
97 );
98 TypeVar {
99 content: Rc::new(RefCell::new(TypeVarContent {
100 name,
101 doc: "".into(),
102 type_set: other.type_set.clone(),
103 base: None,
104 })),
105 }
106 }
107
108 /// Returns the typeset for this TV. If the TV is derived, computes it recursively from the
109 /// derived function and the base's typeset.
110 /// Note this can't be done non-lazily in the constructor, because the TypeSet of the base may
111 /// change over time.
get_typeset(&self) -> TypeSet112 pub fn get_typeset(&self) -> TypeSet {
113 match &self.base {
114 Some(base) => base.type_var.get_typeset().image(base.derived_func),
115 None => self.type_set.clone(),
116 }
117 }
118
119 /// Returns this typevar's type set, assuming this type var has no parent.
get_raw_typeset(&self) -> &TypeSet120 pub fn get_raw_typeset(&self) -> &TypeSet {
121 assert_eq!(self.type_set, self.get_typeset());
122 &self.type_set
123 }
124
125 /// If the associated typeset has a single type return it. Otherwise return None.
singleton_type(&self) -> Option<ValueType>126 pub fn singleton_type(&self) -> Option<ValueType> {
127 let type_set = self.get_typeset();
128 if type_set.size() == 1 {
129 Some(type_set.get_singleton())
130 } else {
131 None
132 }
133 }
134
135 /// Get the free type variable controlling this one.
free_typevar(&self) -> Option<TypeVar>136 pub fn free_typevar(&self) -> Option<TypeVar> {
137 match &self.base {
138 Some(base) => base.type_var.free_typevar(),
139 None => {
140 match self.singleton_type() {
141 // A singleton type isn't a proper free variable.
142 Some(_) => None,
143 None => Some(self.clone()),
144 }
145 }
146 }
147 }
148
149 /// Create a type variable that is a function of another.
derived(&self, derived_func: DerivedFunc) -> TypeVar150 pub fn derived(&self, derived_func: DerivedFunc) -> TypeVar {
151 let ts = self.get_typeset();
152
153 // Safety checks to avoid over/underflows.
154 assert!(ts.specials.is_empty(), "can't derive from special types");
155 match derived_func {
156 DerivedFunc::HalfWidth => {
157 assert!(
158 ts.ints.is_empty() || *ts.ints.iter().min().unwrap() > 8,
159 "can't halve all integer types"
160 );
161 assert!(
162 ts.floats.is_empty() || *ts.floats.iter().min().unwrap() > 32,
163 "can't halve all float types"
164 );
165 assert!(
166 ts.bools.is_empty() || *ts.bools.iter().min().unwrap() > 8,
167 "can't halve all boolean types"
168 );
169 }
170 DerivedFunc::DoubleWidth => {
171 assert!(
172 ts.ints.is_empty() || *ts.ints.iter().max().unwrap() < MAX_BITS,
173 "can't double all integer types"
174 );
175 assert!(
176 ts.floats.is_empty() || *ts.floats.iter().max().unwrap() < MAX_FLOAT_BITS,
177 "can't double all float types"
178 );
179 assert!(
180 ts.bools.is_empty() || *ts.bools.iter().max().unwrap() < MAX_BITS,
181 "can't double all boolean types"
182 );
183 }
184 DerivedFunc::HalfVector => {
185 assert!(
186 *ts.lanes.iter().min().unwrap() > 1,
187 "can't halve a scalar type"
188 );
189 }
190 DerivedFunc::DoubleVector => {
191 assert!(
192 *ts.lanes.iter().max().unwrap() < MAX_LANES,
193 "can't double 256 lanes"
194 );
195 }
196 DerivedFunc::SplitLanes => {
197 assert!(
198 ts.ints.is_empty() || *ts.ints.iter().min().unwrap() > 8,
199 "can't halve all integer types"
200 );
201 assert!(
202 ts.floats.is_empty() || *ts.floats.iter().min().unwrap() > 32,
203 "can't halve all float types"
204 );
205 assert!(
206 ts.bools.is_empty() || *ts.bools.iter().min().unwrap() > 8,
207 "can't halve all boolean types"
208 );
209 assert!(
210 *ts.lanes.iter().max().unwrap() < MAX_LANES,
211 "can't double 256 lanes"
212 );
213 }
214 DerivedFunc::MergeLanes => {
215 assert!(
216 ts.ints.is_empty() || *ts.ints.iter().max().unwrap() < MAX_BITS,
217 "can't double all integer types"
218 );
219 assert!(
220 ts.floats.is_empty() || *ts.floats.iter().max().unwrap() < MAX_FLOAT_BITS,
221 "can't double all float types"
222 );
223 assert!(
224 ts.bools.is_empty() || *ts.bools.iter().max().unwrap() < MAX_BITS,
225 "can't double all boolean types"
226 );
227 assert!(
228 *ts.lanes.iter().min().unwrap() > 1,
229 "can't halve a scalar type"
230 );
231 }
232 DerivedFunc::LaneOf | DerivedFunc::AsBool => { /* no particular assertions */ }
233 }
234
235 TypeVar {
236 content: Rc::new(RefCell::new(TypeVarContent {
237 name: format!("{}({})", derived_func.name(), self.name),
238 doc: "".into(),
239 type_set: ts,
240 base: Some(TypeVarParent {
241 type_var: self.clone(),
242 derived_func,
243 }),
244 })),
245 }
246 }
247
lane_of(&self) -> TypeVar248 pub fn lane_of(&self) -> TypeVar {
249 self.derived(DerivedFunc::LaneOf)
250 }
as_bool(&self) -> TypeVar251 pub fn as_bool(&self) -> TypeVar {
252 self.derived(DerivedFunc::AsBool)
253 }
half_width(&self) -> TypeVar254 pub fn half_width(&self) -> TypeVar {
255 self.derived(DerivedFunc::HalfWidth)
256 }
double_width(&self) -> TypeVar257 pub fn double_width(&self) -> TypeVar {
258 self.derived(DerivedFunc::DoubleWidth)
259 }
half_vector(&self) -> TypeVar260 pub fn half_vector(&self) -> TypeVar {
261 self.derived(DerivedFunc::HalfVector)
262 }
double_vector(&self) -> TypeVar263 pub fn double_vector(&self) -> TypeVar {
264 self.derived(DerivedFunc::DoubleVector)
265 }
split_lanes(&self) -> TypeVar266 pub fn split_lanes(&self) -> TypeVar {
267 self.derived(DerivedFunc::SplitLanes)
268 }
merge_lanes(&self) -> TypeVar269 pub fn merge_lanes(&self) -> TypeVar {
270 self.derived(DerivedFunc::MergeLanes)
271 }
272
273 /// Constrain the range of types this variable can assume to a subset of those in the typeset
274 /// ts.
275 /// May mutate itself if it's not derived, or its parent if it is.
constrain_types_by_ts(&self, type_set: TypeSet)276 pub fn constrain_types_by_ts(&self, type_set: TypeSet) {
277 match &self.base {
278 Some(base) => {
279 base.type_var
280 .constrain_types_by_ts(type_set.preimage(base.derived_func));
281 }
282 None => {
283 self.content
284 .borrow_mut()
285 .type_set
286 .inplace_intersect_with(&type_set);
287 }
288 }
289 }
290
291 /// Constrain the range of types this variable can assume to a subset of those `other` can
292 /// assume.
293 /// May mutate itself if it's not derived, or its parent if it is.
constrain_types(&self, other: TypeVar)294 pub fn constrain_types(&self, other: TypeVar) {
295 if self == &other {
296 return;
297 }
298 self.constrain_types_by_ts(other.get_typeset());
299 }
300
301 /// Get a Rust expression that computes the type of this type variable.
to_rust_code(&self) -> String302 pub fn to_rust_code(&self) -> String {
303 match &self.base {
304 Some(base) => format!(
305 "{}.{}().unwrap()",
306 base.type_var.to_rust_code(),
307 base.derived_func.name()
308 ),
309 None => {
310 if let Some(singleton) = self.singleton_type() {
311 singleton.rust_name()
312 } else {
313 self.name.clone()
314 }
315 }
316 }
317 }
318 }
319
320 impl Into<TypeVar> for &TypeVar {
into(self) -> TypeVar321 fn into(self) -> TypeVar {
322 self.clone()
323 }
324 }
325 impl Into<TypeVar> for ValueType {
into(self) -> TypeVar326 fn into(self) -> TypeVar {
327 TypeVar::new_singleton(self)
328 }
329 }
330
331 // Hash TypeVars by pointers.
332 // There might be a better way to do this, but since TypeVar's content (namely TypeSet) can be
333 // mutated, it makes sense to use pointer equality/hashing here.
334 impl hash::Hash for TypeVar {
hash<H: hash::Hasher>(&self, h: &mut H)335 fn hash<H: hash::Hasher>(&self, h: &mut H) {
336 match &self.base {
337 Some(base) => {
338 base.type_var.hash(h);
339 base.derived_func.hash(h);
340 }
341 None => {
342 (&**self as *const TypeVarContent).hash(h);
343 }
344 }
345 }
346 }
347
348 impl PartialEq for TypeVar {
eq(&self, other: &TypeVar) -> bool349 fn eq(&self, other: &TypeVar) -> bool {
350 match (&self.base, &other.base) {
351 (Some(base1), Some(base2)) => {
352 base1.type_var.eq(&base2.type_var) && base1.derived_func == base2.derived_func
353 }
354 (None, None) => Rc::ptr_eq(&self.content, &other.content),
355 _ => false,
356 }
357 }
358 }
359
// Allow TypeVar as map keys, based on pointer equality (see also above PartialEq impl).
// The identity-based `eq` (pointer identity, or parent + derivation function) is reflexive,
// symmetric and transitive, which is what `Eq` requires.
impl Eq for TypeVar {}
362
363 impl ops::Deref for TypeVar {
364 type Target = TypeVarContent;
deref(&self) -> &Self::Target365 fn deref(&self) -> &Self::Target {
366 unsafe { self.content.as_ptr().as_ref().unwrap() }
367 }
368 }
369
/// A function that maps the type set of a type variable to the type set of a variable derived
/// from it.
///
/// `Eq` is derived alongside `Hash` and `PartialEq` so values can serve as keys in hash maps
/// and sets.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub(crate) enum DerivedFunc {
    /// The lane type: every lane count collapses to 1 (scalar).
    LaneOf,
    /// A boolean type of the same shape: vector lanes keep their width, scalars become `b1`.
    AsBool,
    /// Halve the lane bit width, keeping the lane count.
    HalfWidth,
    /// Double the lane bit width, keeping the lane count.
    DoubleWidth,
    /// Halve the lane count, keeping the lane type.
    HalfVector,
    /// Double the lane count, keeping the lane type.
    DoubleVector,
    /// Halve the lane bit width and double the lane count.
    SplitLanes,
    /// Double the lane bit width and halve the lane count.
    MergeLanes,
}
381
382 impl DerivedFunc {
name(self) -> &'static str383 pub fn name(self) -> &'static str {
384 match self {
385 DerivedFunc::LaneOf => "lane_of",
386 DerivedFunc::AsBool => "as_bool",
387 DerivedFunc::HalfWidth => "half_width",
388 DerivedFunc::DoubleWidth => "double_width",
389 DerivedFunc::HalfVector => "half_vector",
390 DerivedFunc::DoubleVector => "double_vector",
391 DerivedFunc::SplitLanes => "split_lanes",
392 DerivedFunc::MergeLanes => "merge_lanes",
393 }
394 }
395
396 /// Returns the inverse function of this one, if it is a bijection.
inverse(self) -> Option<DerivedFunc>397 pub fn inverse(self) -> Option<DerivedFunc> {
398 match self {
399 DerivedFunc::HalfWidth => Some(DerivedFunc::DoubleWidth),
400 DerivedFunc::DoubleWidth => Some(DerivedFunc::HalfWidth),
401 DerivedFunc::HalfVector => Some(DerivedFunc::DoubleVector),
402 DerivedFunc::DoubleVector => Some(DerivedFunc::HalfVector),
403 DerivedFunc::MergeLanes => Some(DerivedFunc::SplitLanes),
404 DerivedFunc::SplitLanes => Some(DerivedFunc::MergeLanes),
405 _ => None,
406 }
407 }
408 }
409
/// Link from a derived type variable to the variable it was derived from.
#[derive(Debug, Hash)]
pub(crate) struct TypeVarParent {
    /// The type variable this one is derived from.
    pub type_var: TypeVar,
    /// The function applied to `type_var`'s type set to obtain the derived type set.
    pub derived_func: DerivedFunc,
}
415
416 /// A set of types.
417 ///
418 /// We don't allow arbitrary subsets of types, but use a parametrized approach
419 /// instead.
420 ///
421 /// Objects of this class can be used as dictionary keys.
422 ///
423 /// Parametrized type sets are specified in terms of ranges:
424 /// - The permitted range of vector lanes, where 1 indicates a scalar type.
425 /// - The permitted range of integer types.
426 /// - The permitted range of floating point types, and
427 /// - The permitted range of boolean types.
428 ///
429 /// The ranges are inclusive from smallest bit-width to largest bit-width.
430 ///
431 /// Finally, a type set can contain special types (derived from `SpecialType`)
432 /// which can't appear as lane types.
433
434 type RangeBound = u16;
435 type Range = ops::Range<RangeBound>;
436 type NumSet = BTreeSet<RangeBound>;
437
438 macro_rules! num_set {
439 ($($expr:expr),*) => {
440 NumSet::from_iter(vec![$($expr),*])
441 };
442 }
443
444 #[derive(Clone, PartialEq, Eq, Hash)]
445 pub(crate) struct TypeSet {
446 pub lanes: NumSet,
447 pub ints: NumSet,
448 pub floats: NumSet,
449 pub bools: NumSet,
450 pub refs: NumSet,
451 pub specials: Vec<SpecialType>,
452 }
453
454 impl TypeSet {
new( lanes: NumSet, ints: NumSet, floats: NumSet, bools: NumSet, refs: NumSet, specials: Vec<SpecialType>, ) -> Self455 fn new(
456 lanes: NumSet,
457 ints: NumSet,
458 floats: NumSet,
459 bools: NumSet,
460 refs: NumSet,
461 specials: Vec<SpecialType>,
462 ) -> Self {
463 Self {
464 lanes,
465 ints,
466 floats,
467 bools,
468 refs,
469 specials,
470 }
471 }
472
473 /// Return the number of concrete types represented by this typeset.
size(&self) -> usize474 pub fn size(&self) -> usize {
475 self.lanes.len()
476 * (self.ints.len() + self.floats.len() + self.bools.len() + self.refs.len())
477 + self.specials.len()
478 }
479
480 /// Return the image of self across the derived function func.
image(&self, derived_func: DerivedFunc) -> TypeSet481 fn image(&self, derived_func: DerivedFunc) -> TypeSet {
482 match derived_func {
483 DerivedFunc::LaneOf => self.lane_of(),
484 DerivedFunc::AsBool => self.as_bool(),
485 DerivedFunc::HalfWidth => self.half_width(),
486 DerivedFunc::DoubleWidth => self.double_width(),
487 DerivedFunc::HalfVector => self.half_vector(),
488 DerivedFunc::DoubleVector => self.double_vector(),
489 DerivedFunc::SplitLanes => self.half_width().double_vector(),
490 DerivedFunc::MergeLanes => self.double_width().half_vector(),
491 }
492 }
493
494 /// Return a TypeSet describing the image of self across lane_of.
lane_of(&self) -> TypeSet495 fn lane_of(&self) -> TypeSet {
496 let mut copy = self.clone();
497 copy.lanes = num_set![1];
498 copy
499 }
500
501 /// Return a TypeSet describing the image of self across as_bool.
as_bool(&self) -> TypeSet502 fn as_bool(&self) -> TypeSet {
503 let mut copy = self.clone();
504 copy.ints = NumSet::new();
505 copy.floats = NumSet::new();
506 copy.refs = NumSet::new();
507 if !(&self.lanes - &num_set![1]).is_empty() {
508 copy.bools = &self.ints | &self.floats;
509 copy.bools = ©.bools | &self.bools;
510 }
511 if self.lanes.contains(&1) {
512 copy.bools.insert(1);
513 }
514 copy
515 }
516
517 /// Return a TypeSet describing the image of self across halfwidth.
half_width(&self) -> TypeSet518 fn half_width(&self) -> TypeSet {
519 let mut copy = self.clone();
520 copy.ints = NumSet::from_iter(self.ints.iter().filter(|&&x| x > 8).map(|&x| x / 2));
521 copy.floats = NumSet::from_iter(self.floats.iter().filter(|&&x| x > 32).map(|&x| x / 2));
522 copy.bools = NumSet::from_iter(self.bools.iter().filter(|&&x| x > 8).map(|&x| x / 2));
523 copy.specials = Vec::new();
524 copy
525 }
526
527 /// Return a TypeSet describing the image of self across doublewidth.
double_width(&self) -> TypeSet528 fn double_width(&self) -> TypeSet {
529 let mut copy = self.clone();
530 copy.ints = NumSet::from_iter(self.ints.iter().filter(|&&x| x < MAX_BITS).map(|&x| x * 2));
531 copy.floats = NumSet::from_iter(
532 self.floats
533 .iter()
534 .filter(|&&x| x < MAX_FLOAT_BITS)
535 .map(|&x| x * 2),
536 );
537 copy.bools = NumSet::from_iter(
538 self.bools
539 .iter()
540 .filter(|&&x| x < MAX_BITS)
541 .map(|&x| x * 2)
542 .filter(|x| legal_bool(*x)),
543 );
544 copy.specials = Vec::new();
545 copy
546 }
547
548 /// Return a TypeSet describing the image of self across halfvector.
half_vector(&self) -> TypeSet549 fn half_vector(&self) -> TypeSet {
550 let mut copy = self.clone();
551 copy.lanes = NumSet::from_iter(self.lanes.iter().filter(|&&x| x > 1).map(|&x| x / 2));
552 copy.specials = Vec::new();
553 copy
554 }
555
556 /// Return a TypeSet describing the image of self across doublevector.
double_vector(&self) -> TypeSet557 fn double_vector(&self) -> TypeSet {
558 let mut copy = self.clone();
559 copy.lanes = NumSet::from_iter(
560 self.lanes
561 .iter()
562 .filter(|&&x| x < MAX_LANES)
563 .map(|&x| x * 2),
564 );
565 copy.specials = Vec::new();
566 copy
567 }
568
concrete_types(&self) -> Vec<ValueType>569 fn concrete_types(&self) -> Vec<ValueType> {
570 let mut ret = Vec::new();
571 for &num_lanes in &self.lanes {
572 for &bits in &self.ints {
573 ret.push(LaneType::int_from_bits(bits).by(num_lanes));
574 }
575 for &bits in &self.floats {
576 ret.push(LaneType::float_from_bits(bits).by(num_lanes));
577 }
578 for &bits in &self.bools {
579 ret.push(LaneType::bool_from_bits(bits).by(num_lanes));
580 }
581 for &bits in &self.refs {
582 ret.push(ReferenceType::ref_from_bits(bits).into());
583 }
584 }
585 for &special in &self.specials {
586 ret.push(special.into());
587 }
588 ret
589 }
590
591 /// Return the singleton type represented by self. Can only call on typesets containing 1 type.
get_singleton(&self) -> ValueType592 fn get_singleton(&self) -> ValueType {
593 let mut types = self.concrete_types();
594 assert_eq!(types.len(), 1);
595 types.remove(0)
596 }
597
598 /// Return the inverse image of self across the derived function func.
preimage(&self, func: DerivedFunc) -> TypeSet599 fn preimage(&self, func: DerivedFunc) -> TypeSet {
600 if self.size() == 0 {
601 // The inverse of the empty set is itself.
602 return self.clone();
603 }
604
605 match func {
606 DerivedFunc::LaneOf => {
607 let mut copy = self.clone();
608 copy.lanes =
609 NumSet::from_iter((0..=MAX_LANES.trailing_zeros()).map(|i| u16::pow(2, i)));
610 copy
611 }
612 DerivedFunc::AsBool => {
613 let mut copy = self.clone();
614 if self.bools.contains(&1) {
615 copy.ints = NumSet::from_iter(vec![8, 16, 32, 64, 128]);
616 copy.floats = NumSet::from_iter(vec![32, 64]);
617 } else {
618 copy.ints = &self.bools - &NumSet::from_iter(vec![1]);
619 copy.floats = &self.bools & &NumSet::from_iter(vec![32, 64]);
620 // If b1 is not in our typeset, than lanes=1 cannot be in the pre-image, as
621 // as_bool() of scalars is always b1.
622 copy.lanes = &self.lanes - &NumSet::from_iter(vec![1]);
623 }
624 copy
625 }
626 DerivedFunc::HalfWidth => self.double_width(),
627 DerivedFunc::DoubleWidth => self.half_width(),
628 DerivedFunc::HalfVector => self.double_vector(),
629 DerivedFunc::DoubleVector => self.half_vector(),
630 DerivedFunc::SplitLanes => self.double_width().half_vector(),
631 DerivedFunc::MergeLanes => self.half_width().double_vector(),
632 }
633 }
634
inplace_intersect_with(&mut self, other: &TypeSet)635 pub fn inplace_intersect_with(&mut self, other: &TypeSet) {
636 self.lanes = &self.lanes & &other.lanes;
637 self.ints = &self.ints & &other.ints;
638 self.floats = &self.floats & &other.floats;
639 self.bools = &self.bools & &other.bools;
640 self.refs = &self.refs & &other.refs;
641
642 let mut new_specials = Vec::new();
643 for spec in &self.specials {
644 if let Some(spec) = other.specials.iter().find(|&other_spec| other_spec == spec) {
645 new_specials.push(*spec);
646 }
647 }
648 self.specials = new_specials;
649 }
650
is_subset(&self, other: &TypeSet) -> bool651 pub fn is_subset(&self, other: &TypeSet) -> bool {
652 self.lanes.is_subset(&other.lanes)
653 && self.ints.is_subset(&other.ints)
654 && self.floats.is_subset(&other.floats)
655 && self.bools.is_subset(&other.bools)
656 && self.refs.is_subset(&other.refs)
657 && {
658 let specials: HashSet<SpecialType> = HashSet::from_iter(self.specials.clone());
659 let other_specials = HashSet::from_iter(other.specials.clone());
660 specials.is_subset(&other_specials)
661 }
662 }
663
is_wider_or_equal(&self, other: &TypeSet) -> bool664 pub fn is_wider_or_equal(&self, other: &TypeSet) -> bool {
665 set_wider_or_equal(&self.ints, &other.ints)
666 && set_wider_or_equal(&self.floats, &other.floats)
667 && set_wider_or_equal(&self.bools, &other.bools)
668 && set_wider_or_equal(&self.refs, &other.refs)
669 }
670
is_narrower(&self, other: &TypeSet) -> bool671 pub fn is_narrower(&self, other: &TypeSet) -> bool {
672 set_narrower(&self.ints, &other.ints)
673 && set_narrower(&self.floats, &other.floats)
674 && set_narrower(&self.bools, &other.bools)
675 && set_narrower(&self.refs, &other.refs)
676 }
677 }
678
set_wider_or_equal(s1: &NumSet, s2: &NumSet) -> bool679 fn set_wider_or_equal(s1: &NumSet, s2: &NumSet) -> bool {
680 !s1.is_empty() && !s2.is_empty() && s1.iter().min() >= s2.iter().max()
681 }
682
set_narrower(s1: &NumSet, s2: &NumSet) -> bool683 fn set_narrower(s1: &NumSet, s2: &NumSet) -> bool {
684 !s1.is_empty() && !s2.is_empty() && s1.iter().min() < s2.iter().max()
685 }
686
687 impl fmt::Debug for TypeSet {
fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error>688 fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
689 write!(fmt, "TypeSet(")?;
690
691 let mut subsets = Vec::new();
692 if !self.lanes.is_empty() {
693 subsets.push(format!(
694 "lanes={{{}}}",
695 Vec::from_iter(self.lanes.iter().map(|x| x.to_string())).join(", ")
696 ));
697 }
698 if !self.ints.is_empty() {
699 subsets.push(format!(
700 "ints={{{}}}",
701 Vec::from_iter(self.ints.iter().map(|x| x.to_string())).join(", ")
702 ));
703 }
704 if !self.floats.is_empty() {
705 subsets.push(format!(
706 "floats={{{}}}",
707 Vec::from_iter(self.floats.iter().map(|x| x.to_string())).join(", ")
708 ));
709 }
710 if !self.bools.is_empty() {
711 subsets.push(format!(
712 "bools={{{}}}",
713 Vec::from_iter(self.bools.iter().map(|x| x.to_string())).join(", ")
714 ));
715 }
716 if !self.refs.is_empty() {
717 subsets.push(format!(
718 "refs={{{}}}",
719 Vec::from_iter(self.refs.iter().map(|x| x.to_string())).join(", ")
720 ));
721 }
722 if !self.specials.is_empty() {
723 subsets.push(format!(
724 "specials={{{}}}",
725 Vec::from_iter(self.specials.iter().map(|x| x.to_string())).join(", ")
726 ));
727 }
728
729 write!(fmt, "{})", subsets.join(", "))?;
730 Ok(())
731 }
732 }
733
/// Describes a `TypeSet` in terms of intervals; each interval may be set at most once and is
/// resolved into concrete widths or lane counts by `build()`.
pub(crate) struct TypeSetBuilder {
    /// Permitted integer widths; left unset (`Interval::None`) the set has no integer types.
    ints: Interval,
    /// Permitted float widths; left unset the set has no float types.
    floats: Interval,
    /// Permitted bool widths (illegal bool widths are filtered out by `build()`); left unset
    /// the set has no bool types.
    bools: Interval,
    /// Permitted reference-type widths; left unset the set has no reference types.
    refs: Interval,
    /// Whether the lane-count range starts at 1 (scalars) rather than 2. Has no effect when
    /// `simd_lanes` is unset, which always yields scalars only.
    includes_scalars: bool,
    /// Permitted lane counts; left unset the set contains scalars only (lane count 1).
    simd_lanes: Interval,
    /// Special types, included as given.
    specials: Vec<SpecialType>,
}
743
744 impl TypeSetBuilder {
new() -> Self745 pub fn new() -> Self {
746 Self {
747 ints: Interval::None,
748 floats: Interval::None,
749 bools: Interval::None,
750 refs: Interval::None,
751 includes_scalars: true,
752 simd_lanes: Interval::None,
753 specials: Vec::new(),
754 }
755 }
756
ints(mut self, interval: impl Into<Interval>) -> Self757 pub fn ints(mut self, interval: impl Into<Interval>) -> Self {
758 assert!(self.ints == Interval::None);
759 self.ints = interval.into();
760 self
761 }
floats(mut self, interval: impl Into<Interval>) -> Self762 pub fn floats(mut self, interval: impl Into<Interval>) -> Self {
763 assert!(self.floats == Interval::None);
764 self.floats = interval.into();
765 self
766 }
bools(mut self, interval: impl Into<Interval>) -> Self767 pub fn bools(mut self, interval: impl Into<Interval>) -> Self {
768 assert!(self.bools == Interval::None);
769 self.bools = interval.into();
770 self
771 }
refs(mut self, interval: impl Into<Interval>) -> Self772 pub fn refs(mut self, interval: impl Into<Interval>) -> Self {
773 assert!(self.refs == Interval::None);
774 self.refs = interval.into();
775 self
776 }
includes_scalars(mut self, includes_scalars: bool) -> Self777 pub fn includes_scalars(mut self, includes_scalars: bool) -> Self {
778 self.includes_scalars = includes_scalars;
779 self
780 }
simd_lanes(mut self, interval: impl Into<Interval>) -> Self781 pub fn simd_lanes(mut self, interval: impl Into<Interval>) -> Self {
782 assert!(self.simd_lanes == Interval::None);
783 self.simd_lanes = interval.into();
784 self
785 }
specials(mut self, specials: Vec<SpecialType>) -> Self786 pub fn specials(mut self, specials: Vec<SpecialType>) -> Self {
787 assert!(self.specials.is_empty());
788 self.specials = specials;
789 self
790 }
791
build(self) -> TypeSet792 pub fn build(self) -> TypeSet {
793 let min_lanes = if self.includes_scalars { 1 } else { 2 };
794
795 let bools = range_to_set(self.bools.to_range(1..MAX_BITS, None))
796 .into_iter()
797 .filter(|x| legal_bool(*x))
798 .collect();
799
800 TypeSet::new(
801 range_to_set(self.simd_lanes.to_range(min_lanes..MAX_LANES, Some(1))),
802 range_to_set(self.ints.to_range(8..MAX_BITS, None)),
803 range_to_set(self.floats.to_range(32..64, None)),
804 bools,
805 range_to_set(self.refs.to_range(32..64, None)),
806 self.specials,
807 )
808 }
809
all() -> TypeSet810 pub fn all() -> TypeSet {
811 TypeSetBuilder::new()
812 .ints(Interval::All)
813 .floats(Interval::All)
814 .bools(Interval::All)
815 .refs(Interval::All)
816 .simd_lanes(Interval::All)
817 .specials(ValueType::all_special_types().collect())
818 .includes_scalars(true)
819 .build()
820 }
821 }
822
/// A requested span of lane counts or bit widths for one `TypeSetBuilder` category.
#[derive(PartialEq)]
pub(crate) enum Interval {
    /// Nothing requested: the category's default applies if it has one, otherwise it is empty.
    None,
    /// Every value the category permits.
    All,
    /// Powers of two from `start` to `end`, both inclusive.
    Range(Range),
}
829
830 impl Interval {
to_range(&self, full_range: Range, default: Option<RangeBound>) -> Option<Range>831 fn to_range(&self, full_range: Range, default: Option<RangeBound>) -> Option<Range> {
832 match self {
833 Interval::None => {
834 if let Some(default_val) = default {
835 Some(default_val..default_val)
836 } else {
837 None
838 }
839 }
840
841 Interval::All => Some(full_range),
842
843 Interval::Range(range) => {
844 let (low, high) = (range.start, range.end);
845 assert!(low.is_power_of_two());
846 assert!(high.is_power_of_two());
847 assert!(low <= high);
848 assert!(low >= full_range.start);
849 assert!(high <= full_range.end);
850 Some(low..high)
851 }
852 }
853 }
854 }
855
856 impl Into<Interval> for Range {
into(self) -> Interval857 fn into(self) -> Interval {
858 Interval::Range(self)
859 }
860 }
861
legal_bool(bits: RangeBound) -> bool862 fn legal_bool(bits: RangeBound) -> bool {
863 // Only allow legal bit widths for bool types.
864 bits == 1 || (bits >= 8 && bits <= MAX_BITS && bits.is_power_of_two())
865 }
866
867 /// Generates a set with all the powers of two included in the range.
range_to_set(range: Option<Range>) -> NumSet868 fn range_to_set(range: Option<Range>) -> NumSet {
869 let mut set = NumSet::new();
870
871 let (low, high) = match range {
872 Some(range) => (range.start, range.end),
873 None => return set,
874 };
875
876 assert!(low.is_power_of_two());
877 assert!(high.is_power_of_two());
878 assert!(low <= high);
879
880 for i in low.trailing_zeros()..=high.trailing_zeros() {
881 assert!(1 << i <= RangeBound::max_value());
882 set.insert(1 << i);
883 }
884 set
885 }
886
#[test]
fn test_typevar_builder() {
    // Each category alone: its full width range, scalars only, nothing else.
    let ints_only = TypeSetBuilder::new().ints(Interval::All).build();
    assert_eq!(ints_only.lanes, num_set![1]);
    assert_eq!(ints_only.ints, num_set![8, 16, 32, 64, 128]);
    assert!(ints_only.floats.is_empty());
    assert!(ints_only.bools.is_empty());
    assert!(ints_only.specials.is_empty());

    let bools_only = TypeSetBuilder::new().bools(Interval::All).build();
    assert_eq!(bools_only.lanes, num_set![1]);
    assert_eq!(bools_only.bools, num_set![1, 8, 16, 32, 64, 128]);
    assert!(bools_only.floats.is_empty());
    assert!(bools_only.ints.is_empty());
    assert!(bools_only.specials.is_empty());

    let floats_only = TypeSetBuilder::new().floats(Interval::All).build();
    assert_eq!(floats_only.lanes, num_set![1]);
    assert_eq!(floats_only.floats, num_set![32, 64]);
    assert!(floats_only.ints.is_empty());
    assert!(floats_only.bools.is_empty());
    assert!(floats_only.specials.is_empty());

    // Vector lanes, with and without scalars.
    let vector_floats = TypeSetBuilder::new()
        .floats(Interval::All)
        .simd_lanes(Interval::All)
        .includes_scalars(false)
        .build();
    assert_eq!(vector_floats.lanes, num_set![2, 4, 8, 16, 32, 64, 128, 256]);
    assert_eq!(vector_floats.floats, num_set![32, 64]);
    assert!(vector_floats.ints.is_empty());
    assert!(vector_floats.bools.is_empty());
    assert!(vector_floats.specials.is_empty());

    let any_lane_floats = TypeSetBuilder::new()
        .floats(Interval::All)
        .simd_lanes(Interval::All)
        .includes_scalars(true)
        .build();
    assert_eq!(any_lane_floats.lanes, num_set![1, 2, 4, 8, 16, 32, 64, 128, 256]);
    assert_eq!(any_lane_floats.floats, num_set![32, 64]);
    assert!(any_lane_floats.ints.is_empty());
    assert!(any_lane_floats.bools.is_empty());
    assert!(any_lane_floats.specials.is_empty());

    // An explicit range includes both endpoints.
    let mid_ints = TypeSetBuilder::new().ints(16..64).build();
    assert_eq!(mid_ints.lanes, num_set![1]);
    assert_eq!(mid_ints.ints, num_set![16, 32, 64]);
    assert!(mid_ints.floats.is_empty());
    assert!(mid_ints.bools.is_empty());
    assert!(mid_ints.specials.is_empty());
}
939
/// An upper bound above `MAX_BITS` must be rejected when building.
#[test]
#[should_panic]
fn test_typevar_builder_too_high_bound_panic() {
    let too_wide = 16..MAX_BITS * 2;
    TypeSetBuilder::new().ints(too_wide).build();
}
945
/// A range whose start exceeds its end must be rejected when building.
#[test]
#[should_panic]
fn test_typevar_builder_inverted_bounds_panic() {
    let inverted = 32..16;
    TypeSetBuilder::new().ints(inverted).build();
}
951
#[test]
fn test_as_bool() {
    let vectors_only = TypeSetBuilder::new()
        .simd_lanes(2..8)
        .ints(8..8)
        .floats(32..32)
        .build();
    let scalar_lanes = TypeSetBuilder::new().ints(8..8).floats(32..32).build();
    assert_eq!(vectors_only.lane_of(), scalar_lanes);

    // Test as_bool with disjoint intervals.
    let mut expected = TypeSetBuilder::new().simd_lanes(2..8).build();
    expected.bools = num_set![8, 32];
    assert_eq!(vectors_only.as_bool(), expected);

    // With scalars present, `b1` joins the same-width bools.
    let mixed_lanes = TypeSetBuilder::new()
        .simd_lanes(1..8)
        .ints(8..8)
        .floats(32..32)
        .build();
    let mut expected = TypeSetBuilder::new().simd_lanes(1..8).build();
    expected.bools = num_set![1, 8, 32];
    assert_eq!(mixed_lanes.as_bool(), expected);
}
978
#[test]
fn test_forward_images() {
    let empty_set = TypeSetBuilder::new().build();
    let lanes = |lo: RangeBound, hi: RangeBound| TypeSetBuilder::new().simd_lanes(lo..hi).build();
    let ints = |lo: RangeBound, hi: RangeBound| TypeSetBuilder::new().ints(lo..hi).build();
    let floats = |lo: RangeBound, hi: RangeBound| TypeSetBuilder::new().floats(lo..hi).build();
    let bools = |lo: RangeBound, hi: RangeBound| TypeSetBuilder::new().bools(lo..hi).build();

    // Half vector.
    assert_eq!(lanes(1, 32).half_vector(), lanes(1, 16));

    // Double vector.
    assert_eq!(lanes(1, 32).double_vector(), lanes(2, 64));
    assert_eq!(lanes(128, 256).double_vector(), lanes(256, 256));

    // Half width.
    assert_eq!(ints(8, 32).half_width(), ints(8, 16));
    assert_eq!(floats(32, 32).half_width(), empty_set);
    assert_eq!(floats(32, 64).half_width(), floats(32, 32));
    assert_eq!(bools(1, 8).half_width(), empty_set);
    assert_eq!(bools(1, 32).half_width(), bools(8, 16));

    // Double width.
    assert_eq!(ints(8, 32).double_width(), ints(16, 64));
    assert_eq!(ints(32, 64).double_width(), ints(64, 128));
    assert_eq!(floats(32, 32).double_width(), floats(64, 64));
    assert_eq!(floats(32, 64).double_width(), floats(64, 64));
    assert_eq!(bools(1, 16).double_width(), bools(16, 32));
    assert_eq!(bools(32, 64).double_width(), bools(64, 128));
}
1056
#[test]
fn test_backward_images() {
    let empty = TypeSetBuilder::new().build();

    // LaneOf: a scalar lane type can come from a vector of any lane count.
    let scalar_image = TypeSetBuilder::new()
        .simd_lanes(1..1)
        .ints(8..8)
        .floats(32..32)
        .build();
    let any_lane_count = TypeSetBuilder::new()
        .simd_lanes(Interval::All)
        .ints(8..8)
        .floats(32..32)
        .build();
    assert_eq!(scalar_image.preimage(DerivedFunc::LaneOf), any_lane_count);
    assert_eq!(empty.preimage(DerivedFunc::LaneOf), empty);

    // AsBool: every int, bool and float lane type with the same lane counts
    // belongs to the preimage of a boolean set.
    let bool_image = TypeSetBuilder::new()
        .simd_lanes(1..4)
        .bools(1..128)
        .build();
    let every_lane_type = TypeSetBuilder::new()
        .simd_lanes(1..4)
        .ints(Interval::All)
        .bools(Interval::All)
        .floats(Interval::All)
        .build();
    assert_eq!(bool_image.preimage(DerivedFunc::AsBool), every_lane_type);

    // DoubleVector: a single lane cannot be the result of doubling.
    let single_lane = TypeSetBuilder::new().simd_lanes(1..1).ints(8..8).build();
    assert_eq!(single_lane.preimage(DerivedFunc::DoubleVector).size(), 0);

    let doubled = TypeSetBuilder::new()
        .simd_lanes(1..16)
        .ints(8..16)
        .floats(32..32)
        .build();
    let before_doubling = TypeSetBuilder::new()
        .simd_lanes(1..8)
        .ints(8..16)
        .floats(32..32)
        .build();
    assert_eq!(doubled.preimage(DerivedFunc::DoubleVector), before_doubling);

    // HalfVector: 256 lanes cannot be the result of halving.
    let max_lanes = TypeSetBuilder::new()
        .simd_lanes(256..256)
        .ints(8..8)
        .build();
    assert_eq!(max_lanes.preimage(DerivedFunc::HalfVector).size(), 0);

    let halved = TypeSetBuilder::new()
        .simd_lanes(64..128)
        .bools(1..32)
        .build();
    let before_halving = TypeSetBuilder::new()
        .simd_lanes(128..256)
        .bools(1..32)
        .build();
    assert_eq!(halved.preimage(DerivedFunc::HalfVector), before_halving);

    // HalfWidth: these widest lane types have no wider source.
    let widest = TypeSetBuilder::new()
        .ints(128..128)
        .floats(64..64)
        .bools(128..128)
        .build();
    assert_eq!(widest.preimage(DerivedFunc::HalfWidth).size(), 0);

    let narrowed_bools = TypeSetBuilder::new()
        .simd_lanes(64..256)
        .bools(1..64)
        .build();
    let before_narrowing = TypeSetBuilder::new()
        .simd_lanes(64..256)
        .bools(16..128)
        .build();
    assert_eq!(
        narrowed_bools.preimage(DerivedFunc::HalfWidth),
        before_narrowing
    );

    // DoubleWidth: these narrowest lane types have no narrower source.
    let narrowest = TypeSetBuilder::new()
        .ints(8..8)
        .floats(32..32)
        .bools(1..8)
        .build();
    assert_eq!(narrowest.preimage(DerivedFunc::DoubleWidth).size(), 0);

    let widened = TypeSetBuilder::new()
        .simd_lanes(1..16)
        .ints(8..16)
        .floats(32..64)
        .build();
    let before_widening = TypeSetBuilder::new()
        .simd_lanes(1..16)
        .ints(8..8)
        .floats(32..32)
        .build();
    assert_eq!(widened.preimage(DerivedFunc::DoubleWidth), before_widening);
}
1186
#[test]
#[should_panic]
fn test_typeset_singleton_panic_nonsingleton_types() {
    // Two distinct lane types (i8 and f32) do not describe a single value
    // type, so asking for the singleton must panic.
    let mixed_lanes = TypeSetBuilder::new().ints(8..8).floats(32..32).build();
    mixed_lanes.get_singleton();
}
1196
#[test]
#[should_panic]
fn test_typeset_singleton_panic_nonsingleton_lanes() {
    // One lane type but two possible lane counts (1 and 2): not a singleton,
    // so asking for the singleton must panic.
    let varying_lanes = TypeSetBuilder::new()
        .simd_lanes(1..2)
        .floats(32..32)
        .build();
    varying_lanes.get_singleton();
}
1206
#[test]
fn test_typeset_singleton() {
    use crate::shared::types::{Bool, Float, Int};

    // Scalar singletons: exactly one lane type and the default single lane.
    let i16_only = TypeSetBuilder::new().ints(16..16).build();
    assert_eq!(i16_only.get_singleton(), ValueType::Lane(Int::I16.into()));

    let f64_only = TypeSetBuilder::new().floats(64..64).build();
    assert_eq!(f64_only.get_singleton(), ValueType::Lane(Float::F64.into()));

    let b1_only = TypeSetBuilder::new().bools(1..1).build();
    assert_eq!(b1_only.get_singleton(), ValueType::Lane(Bool::B1.into()));

    // Fixing the lane count at four yields the four-lane i32 vector type.
    let i32_by_4 = TypeSetBuilder::new()
        .simd_lanes(4..4)
        .ints(32..32)
        .build();
    assert_eq!(i32_by_4.get_singleton(), LaneType::from(Int::I32).by(4));
}
1231
#[test]
fn test_typevar_functions() {
    // Derived type variables are named after the function applied to their base.
    let i16_and_up = TypeSetBuilder::new().ints(16..64).build();
    let base = TypeVar::new("x", "i16 and up", i16_and_up);

    let halved = base.half_width();
    assert_eq!(halved.name, "half_width(x)");

    // Names nest when derivations are chained.
    let round_trip = base.half_width().double_width();
    assert_eq!(round_trip.name, "double_width(half_width(x))");

    let up_to_i32 = TypeVar::new("x", "up to i32", TypeSetBuilder::new().ints(8..32).build());
    let widened = up_to_i32.double_width();
    assert_eq!(widened.name, "double_width(x)");
}
1248
#[test]
fn test_typevar_singleton() {
    use crate::cdsl::types::VectorType;
    use crate::shared::types as shared_types;

    // Scalar i32: one lane, only the 32-bit integer width, nothing else.
    let i32_lane = LaneType::Int(shared_types::Int::I32);
    let scalar = TypeVar::new_singleton(ValueType::Lane(i32_lane));
    assert_eq!(scalar.name, "i32");
    let scalar_set = &scalar.type_set;
    assert_eq!(scalar_set.ints, num_set![32]);
    assert!(scalar_set.floats.is_empty());
    assert!(scalar_set.bools.is_empty());
    assert!(scalar_set.specials.is_empty());
    assert_eq!(scalar_set.lanes, num_set![1]);

    // Vector f32x4: four lanes, only the 32-bit float width.
    let f32_lane = LaneType::Float(shared_types::Float::F32);
    let vector = TypeVar::new_singleton(ValueType::Vector(VectorType::new(f32_lane, 4)));
    assert_eq!(vector.name, "f32x4");
    let vector_set = &vector.type_set;
    assert!(vector_set.ints.is_empty());
    assert_eq!(vector_set.floats, num_set![32]);
    assert_eq!(vector_set.lanes, num_set![4]);
    assert!(vector_set.bools.is_empty());
    assert!(vector_set.specials.is_empty());
}
1275