//===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/Attributes.h"
10 #include "mlir/IR/StandardTypes.h"
11 #include "gtest/gtest.h"
12
13 using namespace mlir;
14 using namespace mlir::detail;
15
16 template <typename EltTy>
testSplat(Type eltType,const EltTy & splatElt)17 static void testSplat(Type eltType, const EltTy &splatElt) {
18 VectorType shape = VectorType::get({2, 1}, eltType);
19
20 // Check that the generated splat is the same for 1 element and N elements.
21 DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
22 EXPECT_TRUE(splat.isSplat());
23
24 auto detectedSplat =
25 DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt}));
26 EXPECT_EQ(detectedSplat, splat);
27 }
28
29 namespace {
TEST(DenseSplatTest,BoolSplat)30 TEST(DenseSplatTest, BoolSplat) {
31 MLIRContext context;
32 IntegerType boolTy = IntegerType::get(1, &context);
33 VectorType shape = VectorType::get({2, 2}, boolTy);
34
35 // Check that splat is automatically detected for boolean values.
36 /// True.
37 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
38 EXPECT_TRUE(trueSplat.isSplat());
39 /// False.
40 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
41 EXPECT_TRUE(falseSplat.isSplat());
42 EXPECT_NE(falseSplat, trueSplat);
43
44 /// Detect and handle splat within 8 elements (bool values are bit-packed).
45 /// True.
46 auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true});
47 EXPECT_EQ(detectedSplat, trueSplat);
48 /// False.
49 detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
50 EXPECT_EQ(detectedSplat, falseSplat);
51 }
52
TEST(DenseSplatTest,LargeBoolSplat)53 TEST(DenseSplatTest, LargeBoolSplat) {
54 constexpr int64_t boolCount = 56;
55
56 MLIRContext context;
57 IntegerType boolTy = IntegerType::get(1, &context);
58 VectorType shape = VectorType::get({boolCount}, boolTy);
59
60 // Check that splat is automatically detected for boolean values.
61 /// True.
62 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
63 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
64 EXPECT_TRUE(trueSplat.isSplat());
65 EXPECT_TRUE(falseSplat.isSplat());
66
67 /// Detect that the large boolean arrays are properly splatted.
68 /// True.
69 SmallVector<bool, 64> trueValues(boolCount, true);
70 auto detectedSplat = DenseElementsAttr::get(shape, trueValues);
71 EXPECT_EQ(detectedSplat, trueSplat);
72 /// False.
73 SmallVector<bool, 64> falseValues(boolCount, false);
74 detectedSplat = DenseElementsAttr::get(shape, falseValues);
75 EXPECT_EQ(detectedSplat, falseSplat);
76 }
77
TEST(DenseSplatTest,BoolNonSplat)78 TEST(DenseSplatTest, BoolNonSplat) {
79 MLIRContext context;
80 IntegerType boolTy = IntegerType::get(1, &context);
81 VectorType shape = VectorType::get({6}, boolTy);
82
83 // Check that we properly handle non-splat values.
84 DenseElementsAttr nonSplat =
85 DenseElementsAttr::get(shape, {false, false, true, false, false, true});
86 EXPECT_FALSE(nonSplat.isSplat());
87 }
88
TEST(DenseSplatTest,OddIntSplat)89 TEST(DenseSplatTest, OddIntSplat) {
90 // Test detecting a splat with an odd(non 8-bit) integer bitwidth.
91 MLIRContext context;
92 constexpr size_t intWidth = 19;
93 IntegerType intTy = IntegerType::get(intWidth, &context);
94 APInt value(intWidth, 10);
95
96 testSplat(intTy, value);
97 }
98
TEST(DenseSplatTest,Int32Splat)99 TEST(DenseSplatTest, Int32Splat) {
100 MLIRContext context;
101 IntegerType intTy = IntegerType::get(32, &context);
102 int value = 64;
103
104 testSplat(intTy, value);
105 }
106
TEST(DenseSplatTest,IntAttrSplat)107 TEST(DenseSplatTest, IntAttrSplat) {
108 MLIRContext context;
109 IntegerType intTy = IntegerType::get(85, &context);
110 Attribute value = IntegerAttr::get(intTy, 109);
111
112 testSplat(intTy, value);
113 }
114
TEST(DenseSplatTest,F32Splat)115 TEST(DenseSplatTest, F32Splat) {
116 MLIRContext context;
117 FloatType floatTy = FloatType::getF32(&context);
118 float value = 10.0;
119
120 testSplat(floatTy, value);
121 }
122
TEST(DenseSplatTest,F64Splat)123 TEST(DenseSplatTest, F64Splat) {
124 MLIRContext context;
125 FloatType floatTy = FloatType::getF64(&context);
126 double value = 10.0;
127
128 testSplat(floatTy, APFloat(value));
129 }
130
TEST(DenseSplatTest,FloatAttrSplat)131 TEST(DenseSplatTest, FloatAttrSplat) {
132 MLIRContext context;
133 FloatType floatTy = FloatType::getBF16(&context);
134 Attribute value = FloatAttr::get(floatTy, 10.0);
135
136 testSplat(floatTy, value);
137 }
138
TEST(DenseSplatTest,BF16Splat)139 TEST(DenseSplatTest, BF16Splat) {
140 MLIRContext context;
141 FloatType floatTy = FloatType::getBF16(&context);
142 // Note: We currently use double to represent bfloat16.
143 double value = 10.0;
144
145 testSplat(floatTy, value);
146 }
147
148 } // end namespace
149