1 //===- DialectTest.cpp - Dialect unit tests -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/DialectInterface.h"
11 #include "gtest/gtest.h"
12
13 using namespace mlir;
14 using namespace mlir::detail;
15
16 namespace {
17 struct TestDialect : public Dialect {
getDialectNamespace__anon81f12dd30111::TestDialect18 static StringRef getDialectNamespace() { return "test"; };
TestDialect__anon81f12dd30111::TestDialect19 TestDialect(MLIRContext *context)
20 : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {}
21 };
22 struct AnotherTestDialect : public Dialect {
getDialectNamespace__anon81f12dd30111::AnotherTestDialect23 static StringRef getDialectNamespace() { return "test"; };
AnotherTestDialect__anon81f12dd30111::AnotherTestDialect24 AnotherTestDialect(MLIRContext *context)
25 : Dialect(getDialectNamespace(), context,
26 TypeID::get<AnotherTestDialect>()) {}
27 };
28
TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
  // TestDialect and AnotherTestDialect share the "test" namespace but have
  // distinct TypeIDs. Once the first is loaded, loading the second must
  // terminate the process.
  MLIRContext ctx;
  ctx.loadDialect<TestDialect>();
  ASSERT_DEATH(ctx.loadDialect<AnotherTestDialect>(), "");
}
37
38 struct SecondTestDialect : public Dialect {
getDialectNamespace__anon81f12dd30111::SecondTestDialect39 static StringRef getDialectNamespace() { return "test2"; }
SecondTestDialect__anon81f12dd30111::SecondTestDialect40 SecondTestDialect(MLIRContext *context)
41 : Dialect(getDialectNamespace(), context,
42 TypeID::get<SecondTestDialect>()) {}
43 };
44
45 struct TestDialectInterfaceBase
46 : public DialectInterface::Base<TestDialectInterfaceBase> {
TestDialectInterfaceBase__anon81f12dd30111::TestDialectInterfaceBase47 TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {}
function__anon81f12dd30111::TestDialectInterfaceBase48 virtual int function() const { return 42; }
49 };
50
51 struct TestDialectInterface : public TestDialectInterfaceBase {
52 using TestDialectInterfaceBase::TestDialectInterfaceBase;
function__anon81f12dd30111::TestDialectInterface53 int function() const final { return 56; }
54 };
55
56 struct SecondTestDialectInterface : public TestDialectInterfaceBase {
57 using TestDialectInterfaceBase::TestDialectInterfaceBase;
function__anon81f12dd30111::SecondTestDialectInterface58 int function() const final { return 78; }
59 };
60
TEST(Dialect, DelayedInterfaceRegistration) {
  DialectRegistry registry;
  registry.insert<TestDialect, SecondTestDialect>();

  // Delayed registration of an interface for TestDialect.
  registry.addDialectInterface<TestDialect, TestDialectInterface>();

  MLIRContext context(registry);

  // Load the TestDialect and check that the interface got registered for it.
  // Checking function() confirms the attached object is the requested
  // TestDialectInterface, not merely something with the same interface ID.
  auto *testDialect = context.getOrLoadDialect<TestDialect>();
  ASSERT_TRUE(testDialect != nullptr);
  auto *testDialectInterface =
      testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
  // ASSERT (not EXPECT): the pointer is dereferenced right below.
  ASSERT_TRUE(testDialectInterface != nullptr);
  EXPECT_EQ(testDialectInterface->function(), 56);

  // Load the SecondTestDialect and check that the interface is not registered
  // for it.
  auto *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
  ASSERT_TRUE(secondTestDialect != nullptr);
  auto *secondTestDialectInterface =
      secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
  EXPECT_TRUE(secondTestDialectInterface == nullptr);

  // Use the same mechanism as for delayed registration but for an already
  // loaded dialect and check that the interface is now registered.
  DialectRegistry secondRegistry;
  secondRegistry.insert<SecondTestDialect>();
  secondRegistry
      .addDialectInterface<SecondTestDialect, SecondTestDialectInterface>();
  context.appendDialectRegistry(secondRegistry);
  secondTestDialectInterface =
      secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
  ASSERT_TRUE(secondTestDialectInterface != nullptr);
  EXPECT_EQ(secondTestDialectInterface->function(), 78);
}
96
TEST(Dialect, RepeatedDelayedRegistration) {
  // Attach TestDialectInterface to TestDialect through a registry, before the
  // dialect is loaded, and build the context from that registry.
  DialectRegistry initialRegistry;
  initialRegistry.insert<TestDialect>();
  initialRegistry.addDialectInterface<TestDialect, TestDialectInterface>();
  MLIRContext context(initialRegistry);

  // Loading the dialect should pick up the delayed interface.
  auto *dialect = context.getOrLoadDialect<TestDialect>();
  ASSERT_TRUE(dialect != nullptr);
  const auto *iface =
      dialect->getRegisteredInterface<TestDialectInterfaceBase>();
  EXPECT_TRUE(iface != nullptr);

  // Append a registry that attaches the very same interface to the
  // already-loaded dialect. This must not crash, and the interface must still
  // be available afterwards.
  DialectRegistry duplicateRegistry;
  duplicateRegistry.insert<TestDialect>();
  duplicateRegistry.addDialectInterface<TestDialect, TestDialectInterface>();
  context.appendDialectRegistry(duplicateRegistry);
  iface = dialect->getRegisteredInterface<TestDialectInterfaceBase>();
  EXPECT_TRUE(iface != nullptr);
}
121
122 // A dialect that registers two interfaces with the same InterfaceID, triggering
123 // an assertion failure.
124 struct RepeatedRegistrationDialect : public Dialect {
getDialectNamespace__anon81f12dd30111::RepeatedRegistrationDialect125 static StringRef getDialectNamespace() { return "repeatedreg"; }
RepeatedRegistrationDialect__anon81f12dd30111::RepeatedRegistrationDialect126 RepeatedRegistrationDialect(MLIRContext *context)
127 : Dialect(getDialectNamespace(), context,
128 TypeID::get<RepeatedRegistrationDialect>()) {
129 addInterfaces<TestDialectInterface>();
130 addInterfaces<SecondTestDialectInterface>();
131 }
132 };
133
// Death tests belong in a suite whose name ends in "DeathTest", as googletest
// recommends, so they are scheduled before other tests in the binary.
TEST(DialectDeathTest, RepeatedInterfaceRegistrationDeath) {
  // The duplicate-interface check is an assertion, so it only fires when
  // assertions are enabled; in NDEBUG builds there is nothing to exercise.
#ifndef NDEBUG
  MLIRContext context;
  ASSERT_DEATH(context.loadDialect<RepeatedRegistrationDialect>(),
               "interface kind has already been registered");
#endif
}
144
145 } // end namespace
146