1 //===- DialectTest.cpp - Dialect unit tests -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/DialectInterface.h"
11 #include "gtest/gtest.h"
12 
13 using namespace mlir;
14 using namespace mlir::detail;
15 
16 namespace {
17 struct TestDialect : public Dialect {
getDialectNamespace__anon81f12dd30111::TestDialect18   static StringRef getDialectNamespace() { return "test"; };
TestDialect__anon81f12dd30111::TestDialect19   TestDialect(MLIRContext *context)
20       : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {}
21 };
22 struct AnotherTestDialect : public Dialect {
getDialectNamespace__anon81f12dd30111::AnotherTestDialect23   static StringRef getDialectNamespace() { return "test"; };
AnotherTestDialect__anon81f12dd30111::AnotherTestDialect24   AnotherTestDialect(MLIRContext *context)
25       : Dialect(getDialectNamespace(), context,
26                 TypeID::get<AnotherTestDialect>()) {}
27 };
28 
/// Loading two different dialect classes that both use the "test" namespace
/// into one context must terminate the process.
TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
  MLIRContext context;

  // The first dialect to claim the "test" namespace loads normally.
  context.loadDialect<TestDialect>();

  // A different dialect class claiming the same namespace is fatal. The empty
  // regex accepts any death message.
  auto loadConflictingDialect = [&context] {
    context.loadDialect<AnotherTestDialect>();
  };
  ASSERT_DEATH(loadConflictingDialect(), "");
}
37 
38 struct SecondTestDialect : public Dialect {
getDialectNamespace__anon81f12dd30111::SecondTestDialect39   static StringRef getDialectNamespace() { return "test2"; }
SecondTestDialect__anon81f12dd30111::SecondTestDialect40   SecondTestDialect(MLIRContext *context)
41       : Dialect(getDialectNamespace(), context,
42                 TypeID::get<SecondTestDialect>()) {}
43 };
44 
45 struct TestDialectInterfaceBase
46     : public DialectInterface::Base<TestDialectInterfaceBase> {
TestDialectInterfaceBase__anon81f12dd30111::TestDialectInterfaceBase47   TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {}
function__anon81f12dd30111::TestDialectInterfaceBase48   virtual int function() const { return 42; }
49 };
50 
51 struct TestDialectInterface : public TestDialectInterfaceBase {
52   using TestDialectInterfaceBase::TestDialectInterfaceBase;
function__anon81f12dd30111::TestDialectInterface53   int function() const final { return 56; }
54 };
55 
56 struct SecondTestDialectInterface : public TestDialectInterfaceBase {
57   using TestDialectInterfaceBase::TestDialectInterfaceBase;
function__anon81f12dd30111::SecondTestDialectInterface58   int function() const final { return 78; }
59 };
60 
/// Interfaces added to a DialectRegistry are attached when the dialect is
/// loaded, only to the dialect they were registered for, and can also be
/// attached later to an already-loaded dialect via appendDialectRegistry.
TEST(Dialect, DelayedInterfaceRegistration) {
  DialectRegistry registry;
  registry.insert<TestDialect, SecondTestDialect>();

  // Delayed registration of an interface for TestDialect.
  registry.addDialectInterface<TestDialect, TestDialectInterface>();

  MLIRContext context(registry);

  // Load the TestDialect and check that the interface got registered for it.
  auto *testDialect = context.getOrLoadDialect<TestDialect>();
  ASSERT_TRUE(testDialect != nullptr);
  auto *testDialectInterface =
      testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
  // ASSERT (not EXPECT): the pointer is dereferenced just below.
  ASSERT_TRUE(testDialectInterface != nullptr);
  // Verify the concrete implementation that was attached, not just that some
  // object with this interface ID exists.
  EXPECT_EQ(testDialectInterface->function(), 56);

  // Load the SecondTestDialect and check that the interface is not registered
  // for it.
  auto *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
  ASSERT_TRUE(secondTestDialect != nullptr);
  auto *secondTestDialectInterface =
      secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
  EXPECT_TRUE(secondTestDialectInterface == nullptr);

  // Use the same mechanism as for delayed registration but for an already
  // loaded dialect and check that the interface is now registered.
  DialectRegistry secondRegistry;
  secondRegistry.insert<SecondTestDialect>();
  secondRegistry
      .addDialectInterface<SecondTestDialect, SecondTestDialectInterface>();
  context.appendDialectRegistry(secondRegistry);
  secondTestDialectInterface =
      secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
  ASSERT_TRUE(secondTestDialectInterface != nullptr);
  EXPECT_EQ(secondTestDialectInterface->function(), 78);
}
96 
/// Re-applying a registry that carries an interface the loaded dialect already
/// has must not crash, and the dialect must still expose that interface.
TEST(Dialect, RepeatedDelayedRegistration) {
  // Set up the delayed registration.
  DialectRegistry registry;
  registry.insert<TestDialect>();
  registry.addDialectInterface<TestDialect, TestDialectInterface>();
  MLIRContext context(registry);

  // Load the TestDialect and check that the interface got registered for it.
  auto *testDialect = context.getOrLoadDialect<TestDialect>();
  ASSERT_TRUE(testDialect != nullptr);
  auto *testDialectInterface =
      testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
  // ASSERT (not EXPECT): the pointer is dereferenced just below.
  ASSERT_TRUE(testDialectInterface != nullptr);
  EXPECT_EQ(testDialectInterface->function(), 56);

  // Try adding the same dialect interface again and check that we don't crash
  // on repeated interface registration.
  DialectRegistry secondRegistry;
  secondRegistry.insert<TestDialect>();
  secondRegistry.addDialectInterface<TestDialect, TestDialectInterface>();
  context.appendDialectRegistry(secondRegistry);
  testDialectInterface =
      testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
  ASSERT_TRUE(testDialectInterface != nullptr);
  // The dialect still exposes a TestDialectInterface implementation.
  EXPECT_EQ(testDialectInterface->function(), 56);
}
121 
122 // A dialect that registers two interfaces with the same InterfaceID, triggering
123 // an assertion failure.
124 struct RepeatedRegistrationDialect : public Dialect {
getDialectNamespace__anon81f12dd30111::RepeatedRegistrationDialect125   static StringRef getDialectNamespace() { return "repeatedreg"; }
RepeatedRegistrationDialect__anon81f12dd30111::RepeatedRegistrationDialect126   RepeatedRegistrationDialect(MLIRContext *context)
127       : Dialect(getDialectNamespace(), context,
128                 TypeID::get<RepeatedRegistrationDialect>()) {
129     addInterfaces<TestDialectInterface>();
130     addInterfaces<SecondTestDialectInterface>();
131   }
132 };
133 
/// Loading RepeatedRegistrationDialect must die on the duplicate-interface
/// assertion.
// NOTE(review): GoogleTest recommends naming suites that contain death tests
// "*DeathTest" so they run first; the existing name is kept so test filters
// that reference it keep working.
TEST(Dialect, RepeatedInterfaceRegistrationDeath) {
  // The duplicate-interface check is an assertion, so it only fires in builds
  // with assertions enabled. Create the context only in that configuration,
  // so release builds neither construct an unused context nor need a
  // `(void)` cast to silence the unused-variable warning.
#ifndef NDEBUG
  MLIRContext context;
  ASSERT_DEATH(context.loadDialect<RepeatedRegistrationDialect>(),
               "interface kind has already been registered");
#endif
}
144 
145 } // end namespace
146