1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Attributes.h"
10 #include "mlir/IR/StandardTypes.h"
11 #include "gtest/gtest.h"
12 
13 using namespace mlir;
14 using namespace mlir::detail;
15 
16 template <typename EltTy>
testSplat(Type eltType,const EltTy & splatElt)17 static void testSplat(Type eltType, const EltTy &splatElt) {
18   VectorType shape = VectorType::get({2, 1}, eltType);
19 
20   // Check that the generated splat is the same for 1 element and N elements.
21   DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
22   EXPECT_TRUE(splat.isSplat());
23 
24   auto detectedSplat =
25       DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt}));
26   EXPECT_EQ(detectedSplat, splat);
27 }
28 
29 namespace {
TEST(DenseSplatTest,BoolSplat)30 TEST(DenseSplatTest, BoolSplat) {
31   MLIRContext context;
32   IntegerType boolTy = IntegerType::get(1, &context);
33   VectorType shape = VectorType::get({2, 2}, boolTy);
34 
35   // Check that splat is automatically detected for boolean values.
36   /// True.
37   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
38   EXPECT_TRUE(trueSplat.isSplat());
39   /// False.
40   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
41   EXPECT_TRUE(falseSplat.isSplat());
42   EXPECT_NE(falseSplat, trueSplat);
43 
44   /// Detect and handle splat within 8 elements (bool values are bit-packed).
45   /// True.
46   auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true});
47   EXPECT_EQ(detectedSplat, trueSplat);
48   /// False.
49   detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
50   EXPECT_EQ(detectedSplat, falseSplat);
51 }
52 
TEST(DenseSplatTest,LargeBoolSplat)53 TEST(DenseSplatTest, LargeBoolSplat) {
54   constexpr int64_t boolCount = 56;
55 
56   MLIRContext context;
57   IntegerType boolTy = IntegerType::get(1, &context);
58   VectorType shape = VectorType::get({boolCount}, boolTy);
59 
60   // Check that splat is automatically detected for boolean values.
61   /// True.
62   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
63   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
64   EXPECT_TRUE(trueSplat.isSplat());
65   EXPECT_TRUE(falseSplat.isSplat());
66 
67   /// Detect that the large boolean arrays are properly splatted.
68   /// True.
69   SmallVector<bool, 64> trueValues(boolCount, true);
70   auto detectedSplat = DenseElementsAttr::get(shape, trueValues);
71   EXPECT_EQ(detectedSplat, trueSplat);
72   /// False.
73   SmallVector<bool, 64> falseValues(boolCount, false);
74   detectedSplat = DenseElementsAttr::get(shape, falseValues);
75   EXPECT_EQ(detectedSplat, falseSplat);
76 }
77 
TEST(DenseSplatTest,BoolNonSplat)78 TEST(DenseSplatTest, BoolNonSplat) {
79   MLIRContext context;
80   IntegerType boolTy = IntegerType::get(1, &context);
81   VectorType shape = VectorType::get({6}, boolTy);
82 
83   // Check that we properly handle non-splat values.
84   DenseElementsAttr nonSplat =
85       DenseElementsAttr::get(shape, {false, false, true, false, false, true});
86   EXPECT_FALSE(nonSplat.isSplat());
87 }
88 
TEST(DenseSplatTest,OddIntSplat)89 TEST(DenseSplatTest, OddIntSplat) {
90   // Test detecting a splat with an odd(non 8-bit) integer bitwidth.
91   MLIRContext context;
92   constexpr size_t intWidth = 19;
93   IntegerType intTy = IntegerType::get(intWidth, &context);
94   APInt value(intWidth, 10);
95 
96   testSplat(intTy, value);
97 }
98 
TEST(DenseSplatTest,Int32Splat)99 TEST(DenseSplatTest, Int32Splat) {
100   MLIRContext context;
101   IntegerType intTy = IntegerType::get(32, &context);
102   int value = 64;
103 
104   testSplat(intTy, value);
105 }
106 
TEST(DenseSplatTest,IntAttrSplat)107 TEST(DenseSplatTest, IntAttrSplat) {
108   MLIRContext context;
109   IntegerType intTy = IntegerType::get(85, &context);
110   Attribute value = IntegerAttr::get(intTy, 109);
111 
112   testSplat(intTy, value);
113 }
114 
TEST(DenseSplatTest,F32Splat)115 TEST(DenseSplatTest, F32Splat) {
116   MLIRContext context;
117   FloatType floatTy = FloatType::getF32(&context);
118   float value = 10.0;
119 
120   testSplat(floatTy, value);
121 }
122 
TEST(DenseSplatTest,F64Splat)123 TEST(DenseSplatTest, F64Splat) {
124   MLIRContext context;
125   FloatType floatTy = FloatType::getF64(&context);
126   double value = 10.0;
127 
128   testSplat(floatTy, APFloat(value));
129 }
130 
TEST(DenseSplatTest,FloatAttrSplat)131 TEST(DenseSplatTest, FloatAttrSplat) {
132   MLIRContext context;
133   FloatType floatTy = FloatType::getBF16(&context);
134   Attribute value = FloatAttr::get(floatTy, 10.0);
135 
136   testSplat(floatTy, value);
137 }
138 
TEST(DenseSplatTest,BF16Splat)139 TEST(DenseSplatTest, BF16Splat) {
140   MLIRContext context;
141   FloatType floatTy = FloatType::getBF16(&context);
142   // Note: We currently use double to represent bfloat16.
143   double value = 10.0;
144 
145   testSplat(floatTy, value);
146 }
147 
148 } // end namespace
149