1 //===- DataLayoutInterfacesTest.cpp - Unit Tests for Data Layouts ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/DataLayoutInterfaces.h"
10 #include "mlir/Dialect/DLTI/DLTI.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/DialectImplementation.h"
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/IR/OpImplementation.h"
17 #include "mlir/Parser.h"
18 
19 #include <gtest/gtest.h>
20 
21 using namespace mlir;
22 
23 namespace {
24 constexpr static llvm::StringLiteral kAttrName = "dltest.layout";
25 
26 /// Trivial array storage for the custom data layout spec attribute, just a list
27 /// of entries.
28 class DataLayoutSpecStorage : public AttributeStorage {
29 public:
30   using KeyTy = ArrayRef<DataLayoutEntryInterface>;
31 
DataLayoutSpecStorage(ArrayRef<DataLayoutEntryInterface> entries)32   DataLayoutSpecStorage(ArrayRef<DataLayoutEntryInterface> entries)
33       : entries(entries) {}
34 
operator ==(const KeyTy & key) const35   bool operator==(const KeyTy &key) const { return key == entries; }
36 
construct(AttributeStorageAllocator & allocator,const KeyTy & key)37   static DataLayoutSpecStorage *construct(AttributeStorageAllocator &allocator,
38                                           const KeyTy &key) {
39     return new (allocator.allocate<DataLayoutSpecStorage>())
40         DataLayoutSpecStorage(allocator.copyInto(key));
41   }
42 
43   ArrayRef<DataLayoutEntryInterface> entries;
44 };
45 
46 /// Simple data layout spec containing a list of entries that always verifies
47 /// as valid.
48 struct CustomDataLayoutSpec
49     : public Attribute::AttrBase<CustomDataLayoutSpec, Attribute,
50                                  DataLayoutSpecStorage,
51                                  DataLayoutSpecInterface::Trait> {
52   using Base::Base;
get__anonf2cbba820111::CustomDataLayoutSpec53   static CustomDataLayoutSpec get(MLIRContext *ctx,
54                                   ArrayRef<DataLayoutEntryInterface> entries) {
55     return Base::get(ctx, entries);
56   }
57   CustomDataLayoutSpec
combineWith__anonf2cbba820111::CustomDataLayoutSpec58   combineWith(ArrayRef<DataLayoutSpecInterface> specs) const {
59     return *this;
60   }
getEntries__anonf2cbba820111::CustomDataLayoutSpec61   DataLayoutEntryListRef getEntries() const { return getImpl()->entries; }
verifySpec__anonf2cbba820111::CustomDataLayoutSpec62   LogicalResult verifySpec(Location loc) { return success(); }
63 };
64 
65 /// A type subject to data layout that exits the program if it is queried more
66 /// than once. Handy to check if the cache works.
67 struct SingleQueryType
68     : public Type::TypeBase<SingleQueryType, Type, TypeStorage,
69                             DataLayoutTypeInterface::Trait> {
70   using Base::Base;
71 
get__anonf2cbba820111::SingleQueryType72   static SingleQueryType get(MLIRContext *ctx) { return Base::get(ctx); }
73 
getTypeSizeInBits__anonf2cbba820111::SingleQueryType74   unsigned getTypeSizeInBits(const DataLayout &layout,
75                              DataLayoutEntryListRef params) const {
76     static bool executed = false;
77     if (executed)
78       llvm::report_fatal_error("repeated call");
79 
80     executed = true;
81     return 1;
82   }
83 
getABIAlignment__anonf2cbba820111::SingleQueryType84   unsigned getABIAlignment(const DataLayout &layout,
85                            DataLayoutEntryListRef params) {
86     static bool executed = false;
87     if (executed)
88       llvm::report_fatal_error("repeated call");
89 
90     executed = true;
91     return 2;
92   }
93 
getPreferredAlignment__anonf2cbba820111::SingleQueryType94   unsigned getPreferredAlignment(const DataLayout &layout,
95                                  DataLayoutEntryListRef params) {
96     static bool executed = false;
97     if (executed)
98       llvm::report_fatal_error("repeated call");
99 
100     executed = true;
101     return 4;
102   }
103 };
104 
105 /// A types that is not subject to data layout.
106 struct TypeNoLayout : public Type::TypeBase<TypeNoLayout, Type, TypeStorage> {
107   using Base::Base;
108 
get__anonf2cbba820111::TypeNoLayout109   static TypeNoLayout get(MLIRContext *ctx) { return Base::get(ctx); }
110 };
111 
112 /// An op that serves as scope for data layout queries with the relevant
113 /// attribute attached. This can handle data layout requests for the built-in
114 /// types itself.
115 struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> {
116   using Op::Op;
getAttributeNames__anonf2cbba820111::OpWithLayout117   static ArrayRef<StringRef> getAttributeNames() { return {}; }
118 
getOperationName__anonf2cbba820111::OpWithLayout119   static StringRef getOperationName() { return "dltest.op_with_layout"; }
120 
getDataLayoutSpec__anonf2cbba820111::OpWithLayout121   DataLayoutSpecInterface getDataLayoutSpec() {
122     return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
123   }
124 
getTypeSizeInBits__anonf2cbba820111::OpWithLayout125   static unsigned getTypeSizeInBits(Type type, const DataLayout &dataLayout,
126                                     DataLayoutEntryListRef params) {
127     // Make a recursive query.
128     if (type.isa<FloatType>())
129       return dataLayout.getTypeSizeInBits(
130           IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth()));
131 
132     // Handle built-in types that are not handled by the default process.
133     if (auto iType = type.dyn_cast<IntegerType>()) {
134       for (DataLayoutEntryInterface entry : params)
135         if (entry.getKey().dyn_cast<Type>() == type)
136           return 8 *
137                  entry.getValue().cast<IntegerAttr>().getValue().getZExtValue();
138       return 8 * iType.getIntOrFloatBitWidth();
139     }
140 
141     // Use the default process for everything else.
142     return detail::getDefaultTypeSize(type, dataLayout, params);
143   }
144 
getTypeABIAlignment__anonf2cbba820111::OpWithLayout145   static unsigned getTypeABIAlignment(Type type, const DataLayout &dataLayout,
146                                       DataLayoutEntryListRef params) {
147     return llvm::PowerOf2Ceil(getTypeSize(type, dataLayout, params));
148   }
149 
getTypePreferredAlignment__anonf2cbba820111::OpWithLayout150   static unsigned getTypePreferredAlignment(Type type,
151                                             const DataLayout &dataLayout,
152                                             DataLayoutEntryListRef params) {
153     return 2 * getTypeABIAlignment(type, dataLayout, params);
154   }
155 };
156 
157 struct OpWith7BitByte
158     : public Op<OpWith7BitByte, DataLayoutOpInterface::Trait> {
159   using Op::Op;
getAttributeNames__anonf2cbba820111::OpWith7BitByte160   static ArrayRef<StringRef> getAttributeNames() { return {}; }
161 
getOperationName__anonf2cbba820111::OpWith7BitByte162   static StringRef getOperationName() { return "dltest.op_with_7bit_byte"; }
163 
getDataLayoutSpec__anonf2cbba820111::OpWith7BitByte164   DataLayoutSpecInterface getDataLayoutSpec() {
165     return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
166   }
167 
168   // Bytes are assumed to be 7-bit here.
getTypeSize__anonf2cbba820111::OpWith7BitByte169   static unsigned getTypeSize(Type type, const DataLayout &dataLayout,
170                               DataLayoutEntryListRef params) {
171     return llvm::divideCeil(dataLayout.getTypeSizeInBits(type), 7);
172   }
173 };
174 
175 /// A dialect putting all the above together.
176 struct DLTestDialect : Dialect {
DLTestDialect__anonf2cbba820111::DLTestDialect177   explicit DLTestDialect(MLIRContext *ctx)
178       : Dialect(getDialectNamespace(), ctx, TypeID::get<DLTestDialect>()) {
179     ctx->getOrLoadDialect<DLTIDialect>();
180     addAttributes<CustomDataLayoutSpec>();
181     addOperations<OpWithLayout, OpWith7BitByte>();
182     addTypes<SingleQueryType, TypeNoLayout>();
183   }
getDialectNamespace__anonf2cbba820111::DLTestDialect184   static StringRef getDialectNamespace() { return "dltest"; }
185 
printAttribute__anonf2cbba820111::DLTestDialect186   void printAttribute(Attribute attr,
187                       DialectAsmPrinter &printer) const override {
188     printer << "spec<";
189     llvm::interleaveComma(attr.cast<CustomDataLayoutSpec>().getEntries(),
190                           printer);
191     printer << ">";
192   }
193 
parseAttribute__anonf2cbba820111::DLTestDialect194   Attribute parseAttribute(DialectAsmParser &parser, Type type) const override {
195     bool ok =
196         succeeded(parser.parseKeyword("spec")) && succeeded(parser.parseLess());
197     (void)ok;
198     assert(ok);
199     if (succeeded(parser.parseOptionalGreater()))
200       return CustomDataLayoutSpec::get(parser.getBuilder().getContext(), {});
201 
202     SmallVector<DataLayoutEntryInterface> entries;
203     do {
204       entries.emplace_back();
205       ok = succeeded(parser.parseAttribute(entries.back()));
206       assert(ok);
207     } while (succeeded(parser.parseOptionalComma()));
208     ok = succeeded(parser.parseGreater());
209     assert(ok);
210     return CustomDataLayoutSpec::get(parser.getBuilder().getContext(), entries);
211   }
212 
printType__anonf2cbba820111::DLTestDialect213   void printType(Type type, DialectAsmPrinter &printer) const override {
214     if (type.isa<SingleQueryType>())
215       printer << "single_query";
216     else
217       printer << "no_layout";
218   }
219 
parseType__anonf2cbba820111::DLTestDialect220   Type parseType(DialectAsmParser &parser) const override {
221     bool ok = succeeded(parser.parseKeyword("single_query"));
222     (void)ok;
223     assert(ok);
224     return SingleQueryType::get(parser.getBuilder().getContext());
225   }
226 };
227 
228 } // end namespace
229 
TEST(DataLayout, FallbackDefault) {
  const char *ir = R"MLIR(
module {}
  )MLIR";

  DialectRegistry registry;
  registry.insert<DLTIDialect, DLTestDialect>();
  MLIRContext ctx(registry);

  OwningModuleRef module = parseSourceString(ir, &ctx);
  DataLayout layout(module.get());

  // The bare module carries no dltest spec or scoping op, so every query
  // falls back to the default handling.
  Type i42 = IntegerType::get(&ctx, 42);
  Type f16 = Float16Type::get(&ctx);
  EXPECT_EQ(layout.getTypeSize(i42), 6u);
  EXPECT_EQ(layout.getTypeSize(f16), 2u);
  EXPECT_EQ(layout.getTypeSizeInBits(i42), 42u);
  EXPECT_EQ(layout.getTypeSizeInBits(f16), 16u);
  EXPECT_EQ(layout.getTypeABIAlignment(i42), 8u);
  EXPECT_EQ(layout.getTypeABIAlignment(f16), 2u);
  EXPECT_EQ(layout.getTypePreferredAlignment(i42), 8u);
  EXPECT_EQ(layout.getTypePreferredAlignment(f16), 2u);
}
250 
TEST(DataLayout, NullSpec) {
  const char *ir = R"MLIR(
"dltest.op_with_layout"() : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<DLTIDialect, DLTestDialect>();
  MLIRContext ctx(registry);

  OwningModuleRef module = parseSourceString(ir, &ctx);
  auto scope = cast<DataLayoutOpInterface>(module->getBody()->front());
  DataLayout layout(scope);

  // No spec attribute at all: OpWithLayout's own hooks answer every query.
  Type i42 = IntegerType::get(&ctx, 42);
  Type f16 = Float16Type::get(&ctx);
  EXPECT_EQ(layout.getTypeSize(i42), 42u);
  EXPECT_EQ(layout.getTypeSize(f16), 16u);
  EXPECT_EQ(layout.getTypeSizeInBits(i42), 8u * 42u);
  EXPECT_EQ(layout.getTypeSizeInBits(f16), 8u * 16u);
  EXPECT_EQ(layout.getTypeABIAlignment(i42), 64u);
  EXPECT_EQ(layout.getTypeABIAlignment(f16), 16u);
  EXPECT_EQ(layout.getTypePreferredAlignment(i42), 128u);
  EXPECT_EQ(layout.getTypePreferredAlignment(f16), 32u);
}
273 
TEST(DataLayout, EmptySpec) {
  const char *ir = R"MLIR(
"dltest.op_with_layout"() { dltest.layout = #dltest.spec< > } : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<DLTIDialect, DLTestDialect>();
  MLIRContext ctx(registry);

  OwningModuleRef module = parseSourceString(ir, &ctx);
  auto scope = cast<DataLayoutOpInterface>(module->getBody()->front());
  DataLayout layout(scope);

  // A spec with no entries must behave exactly like having no spec.
  Type i42 = IntegerType::get(&ctx, 42);
  Type f16 = Float16Type::get(&ctx);
  EXPECT_EQ(layout.getTypeSize(i42), 42u);
  EXPECT_EQ(layout.getTypeSize(f16), 16u);
  EXPECT_EQ(layout.getTypeSizeInBits(i42), 8u * 42u);
  EXPECT_EQ(layout.getTypeSizeInBits(f16), 8u * 16u);
  EXPECT_EQ(layout.getTypeABIAlignment(i42), 64u);
  EXPECT_EQ(layout.getTypeABIAlignment(f16), 16u);
  EXPECT_EQ(layout.getTypePreferredAlignment(i42), 128u);
  EXPECT_EQ(layout.getTypePreferredAlignment(f16), 32u);
}
296 
TEST(DataLayout, SpecWithEntries) {
  const char *ir = R"MLIR(
"dltest.op_with_layout"() { dltest.layout = #dltest.spec<
  #dlti.dl_entry<i42, 5>,
  #dlti.dl_entry<i16, 6>
> } : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<DLTIDialect, DLTestDialect>();
  MLIRContext ctx(registry);

  OwningModuleRef module = parseSourceString(ir, &ctx);
  auto scope = cast<DataLayoutOpInterface>(module->getBody()->front());
  DataLayout layout(scope);

  // Types covered by the spec: i42 directly, f16 via the i16 entry that
  // OpWithLayout's recursive float query reaches.
  Type i42 = IntegerType::get(&ctx, 42);
  Type f16 = Float16Type::get(&ctx);
  EXPECT_EQ(layout.getTypeSize(i42), 5u);
  EXPECT_EQ(layout.getTypeSize(f16), 6u);
  EXPECT_EQ(layout.getTypeSizeInBits(i42), 40u);
  EXPECT_EQ(layout.getTypeSizeInBits(f16), 48u);
  EXPECT_EQ(layout.getTypeABIAlignment(i42), 8u);
  EXPECT_EQ(layout.getTypeABIAlignment(f16), 8u);
  EXPECT_EQ(layout.getTypePreferredAlignment(i42), 16u);
  EXPECT_EQ(layout.getTypePreferredAlignment(f16), 16u);

  // Types with no matching entry use the op's built-in integer handling.
  Type i32 = IntegerType::get(&ctx, 32);
  Type f32 = Float32Type::get(&ctx);
  EXPECT_EQ(layout.getTypeSize(i32), 32u);
  EXPECT_EQ(layout.getTypeSize(f32), 32u);
  EXPECT_EQ(layout.getTypeSizeInBits(i32), 256u);
  EXPECT_EQ(layout.getTypeSizeInBits(f32), 256u);
  EXPECT_EQ(layout.getTypeABIAlignment(i32), 32u);
  EXPECT_EQ(layout.getTypeABIAlignment(f32), 32u);
  EXPECT_EQ(layout.getTypePreferredAlignment(i32), 64u);
  EXPECT_EQ(layout.getTypePreferredAlignment(f32), 64u);
}
331 
// Checks that a DataLayout caches query results: SingleQueryType's size hook
// reports a fatal error if it is ever called twice.
TEST(DataLayout, Caching) {
  const char *ir = R"MLIR(
"dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<DLTIDialect, DLTestDialect>();
  MLIRContext ctx(registry);

  OwningModuleRef module = parseSourceString(ir, &ctx);
  auto op =
      cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
  DataLayout layout(op);

  unsigned sum = 0;
  // The first query is dispatched down to SingleQueryType, whose size hook
  // reports 1 bit and records that it has run.
  sum += layout.getTypeSize(SingleQueryType::get(&ctx));
  // The second call should hit the cache. If it does not, the function in
  // SingleQueryType will be called and will abort the process.
  sum += layout.getTypeSize(SingleQueryType::get(&ctx));
  // Make sure the compiler doesn't optimize away the query code.
  EXPECT_EQ(sum, 2u);

  // A fresh data layout has a new cache, so the call to it should be dispatched
  // down to the type and abort the process.
  DataLayout second(op);
  ASSERT_DEATH(second.getTypeSize(SingleQueryType::get(&ctx)), "repeated call");
}
359 
TEST(DataLayout, CacheInvalidation) {
  const char *ir = R"MLIR(
"dltest.op_with_layout"() { dltest.layout = #dltest.spec<
  #dlti.dl_entry<i42, 5>,
  #dlti.dl_entry<i16, 6>
> } : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<DLTIDialect, DLTestDialect>();
  MLIRContext ctx(registry);

  OwningModuleRef module = parseSourceString(ir, &ctx);
  auto scope = cast<DataLayoutOpInterface>(module->getBody()->front());
  DataLayout layout(scope);
  Type f16 = Float16Type::get(&ctx);

  // While the original spec is in place, queries succeed.
  EXPECT_EQ(layout.getTypeSize(f16), 6u);

  // Swap the spec attribute for a new, empty one behind the layout's back.
  scope->setAttr(kAttrName, CustomDataLayoutSpec::get(&ctx, {}));

  // The stale layout must now refuse to answer. The check is only expected to
  // fire in builds with assertions enabled.
#ifndef NDEBUG
  ASSERT_DEATH(layout.getTypeSize(f16), "no longer valid");
#endif
}
388 
TEST(DataLayout, UnimplementedTypeInterface) {
  const char *ir = R"MLIR(
"dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<DLTIDialect, DLTestDialect>();
  MLIRContext ctx(registry);

  OwningModuleRef module = parseSourceString(ir, &ctx);
  auto scope = cast<DataLayoutOpInterface>(module->getBody()->front());
  DataLayout layout(scope);

  // TypeNoLayout has no layout hooks and OpWithLayout does not special-case
  // it, so the query has nowhere to go and must fail loudly.
  Type opaque = TypeNoLayout::get(&ctx);
  ASSERT_DEATH(layout.getTypeSize(opaque),
               "neither the scoping op nor the type class provide data layout "
               "information");
}
407 
TEST(DataLayout, SevenBitByte) {
  const char *ir = R"MLIR(
"dltest.op_with_7bit_byte"() { dltest.layout = #dltest.spec<> } : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<DLTIDialect, DLTestDialect>();
  MLIRContext ctx(registry);

  OwningModuleRef module = parseSourceString(ir, &ctx);
  auto scope = cast<DataLayoutOpInterface>(module->getBody()->front());
  DataLayout layout(scope);

  // Bit sizes are untouched; only the bits-to-bytes conversion uses 7-bit
  // bytes (42 / 7 = 6, ceil(32 / 7) = 5).
  Type i42 = IntegerType::get(&ctx, 42);
  Type i32 = IntegerType::get(&ctx, 32);
  EXPECT_EQ(layout.getTypeSizeInBits(i42), 42u);
  EXPECT_EQ(layout.getTypeSizeInBits(i32), 32u);
  EXPECT_EQ(layout.getTypeSize(i42), 6u);
  EXPECT_EQ(layout.getTypeSize(i32), 5u);
}
427