1 //===- Dialect.cpp - Dialect implementation -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/BuiltinDialect.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/DialectInterface.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "mlir/IR/Operation.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/Twine.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/ManagedStatic.h"
20 #include "llvm/Support/Regex.h"
21 
22 #define DEBUG_TYPE "dialect"
23 
24 using namespace mlir;
25 using namespace detail;
26 
~DialectAsmParser()27 DialectAsmParser::~DialectAsmParser() {}
28 
29 //===----------------------------------------------------------------------===//
30 // DialectRegistry
31 //===----------------------------------------------------------------------===//
32 
addDialectInterface(StringRef dialectName,TypeID interfaceTypeID,DialectInterfaceAllocatorFunction allocator)33 void DialectRegistry::addDialectInterface(
34     StringRef dialectName, TypeID interfaceTypeID,
35     DialectInterfaceAllocatorFunction allocator) {
36   assert(allocator && "unexpected null interface allocation function");
37   auto it = registry.find(dialectName.str());
38   assert(it != registry.end() &&
39          "adding an interface for an unregistered dialect");
40 
41   // Bail out if the interface with the given ID is already in the registry for
42   // the given dialect. We expect a small number (dozens) of interfaces so a
43   // linear search is fine here.
44   auto &ifaces = interfaces[it->second.first];
45   for (const auto &kvp : ifaces.dialectInterfaces) {
46     if (kvp.first == interfaceTypeID) {
47       LLVM_DEBUG(llvm::dbgs()
48                  << "[" DEBUG_TYPE
49                     "] repeated interface registration for dialect "
50                  << dialectName);
51       return;
52     }
53   }
54 
55   ifaces.dialectInterfaces.emplace_back(interfaceTypeID, allocator);
56 }
57 
addObjectInterface(StringRef dialectName,TypeID interfaceTypeID,ObjectInterfaceAllocatorFunction allocator)58 void DialectRegistry::addObjectInterface(
59     StringRef dialectName, TypeID interfaceTypeID,
60     ObjectInterfaceAllocatorFunction allocator) {
61   assert(allocator && "unexpected null interface allocation function");
62 
63   // Builtin dialect has an empty prefix and is always registered.
64   TypeID dialectTypeID;
65   if (!dialectName.empty()) {
66     auto it = registry.find(dialectName.str());
67     assert(it != registry.end() &&
68            "adding an interface for an op from an unregistered dialect");
69     dialectTypeID = it->second.first;
70   } else {
71     dialectTypeID = TypeID::get<BuiltinDialect>();
72   }
73 
74   auto &ifaces = interfaces[dialectTypeID];
75   for (const auto &kvp : ifaces.objectInterfaces) {
76     if (kvp.first == interfaceTypeID) {
77       LLVM_DEBUG(llvm::dbgs()
78                  << "[" DEBUG_TYPE
79                     "] repeated interface object interface registration");
80       return;
81     }
82   }
83 
84   ifaces.objectInterfaces.emplace_back(interfaceTypeID, allocator);
85 }
86 
87 DialectAllocatorFunctionRef
getDialectAllocator(StringRef name) const88 DialectRegistry::getDialectAllocator(StringRef name) const {
89   auto it = registry.find(name.str());
90   if (it == registry.end())
91     return nullptr;
92   return it->second.second;
93 }
94 
insert(TypeID typeID,StringRef name,DialectAllocatorFunction ctor)95 void DialectRegistry::insert(TypeID typeID, StringRef name,
96                              DialectAllocatorFunction ctor) {
97   auto inserted = registry.insert(
98       std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
99   if (!inserted.second && inserted.first->second.first != typeID) {
100     llvm::report_fatal_error(
101         "Trying to register different dialects for the same namespace: " +
102         name);
103   }
104 }
105 
registerDelayedInterfaces(Dialect * dialect) const106 void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
107   auto it = interfaces.find(dialect->getTypeID());
108   if (it == interfaces.end())
109     return;
110 
111   // Add an interface if it is not already present.
112   for (const auto &kvp : it->getSecond().dialectInterfaces) {
113     if (dialect->getRegisteredInterface(kvp.first))
114       continue;
115     dialect->addInterface(kvp.second(dialect));
116   }
117 
118   // Add attribute, operation and type interfaces.
119   for (const auto &kvp : it->getSecond().objectInterfaces)
120     kvp.second(dialect->getContext());
121 }
122 
123 //===----------------------------------------------------------------------===//
124 // Dialect
125 //===----------------------------------------------------------------------===//
126 
Dialect(StringRef name,MLIRContext * context,TypeID id)127 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
128     : name(name), dialectID(id), context(context) {
129   assert(isValidNamespace(name) && "invalid dialect namespace");
130 }
131 
~Dialect()132 Dialect::~Dialect() {}
133 
134 /// Verify an attribute from this dialect on the argument at 'argIndex' for
135 /// the region at 'regionIndex' on the given operation. Returns failure if
136 /// the verification failed, success otherwise. This hook may optionally be
137 /// invoked from any operation containing a region.
verifyRegionArgAttribute(Operation *,unsigned,unsigned,NamedAttribute)138 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
139                                                 NamedAttribute) {
140   return success();
141 }
142 
143 /// Verify an attribute from this dialect on the result at 'resultIndex' for
144 /// the region at 'regionIndex' on the given operation. Returns failure if
145 /// the verification failed, success otherwise. This hook may optionally be
146 /// invoked from any operation containing a region.
verifyRegionResultAttribute(Operation *,unsigned,unsigned,NamedAttribute)147 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
148                                                    unsigned, NamedAttribute) {
149   return success();
150 }
151 
152 /// Parse an attribute registered to this dialect.
parseAttribute(DialectAsmParser & parser,Type type) const153 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
154   parser.emitError(parser.getNameLoc())
155       << "dialect '" << getNamespace()
156       << "' provides no attribute parsing hook";
157   return Attribute();
158 }
159 
160 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const161 Type Dialect::parseType(DialectAsmParser &parser) const {
162   // If this dialect allows unknown types, then represent this with OpaqueType.
163   if (allowsUnknownTypes()) {
164     Identifier ns = Identifier::get(getNamespace(), getContext());
165     return OpaqueType::get(ns, parser.getFullSymbolSpec());
166   }
167 
168   parser.emitError(parser.getNameLoc())
169       << "dialect '" << getNamespace() << "' provides no type parsing hook";
170   return Type();
171 }
172 
173 Optional<Dialect::ParseOpHook>
getParseOperationHook(StringRef opName) const174 Dialect::getParseOperationHook(StringRef opName) const {
175   return None;
176 }
177 
printOperation(Operation * op,OpAsmPrinter & printer) const178 LogicalResult Dialect::printOperation(Operation *op,
179                                       OpAsmPrinter &printer) const {
180   assert(op->getDialect() == this &&
181          "Dialect hook invoked on non-dialect owned operation");
182   return failure();
183 }
184 
185 /// Utility function that returns if the given string is a valid dialect
186 /// namespace.
isValidNamespace(StringRef str)187 bool Dialect::isValidNamespace(StringRef str) {
188   if (str.empty())
189     return true;
190   llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
191   return dialectNameRegex.match(str);
192 }
193 
194 /// Register a set of dialect interfaces with this dialect instance.
addInterface(std::unique_ptr<DialectInterface> interface)195 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
196   auto it = registeredInterfaces.try_emplace(interface->getID(),
197                                              std::move(interface));
198   (void)it;
199   assert(it.second && "interface kind has already been registered");
200 }
201 
202 //===----------------------------------------------------------------------===//
203 // Dialect Interface
204 //===----------------------------------------------------------------------===//
205 
~DialectInterface()206 DialectInterface::~DialectInterface() {}
207 
DialectInterfaceCollectionBase(MLIRContext * ctx,TypeID interfaceKind)208 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
209     MLIRContext *ctx, TypeID interfaceKind) {
210   for (auto *dialect : ctx->getLoadedDialects()) {
211     if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
212       interfaces.insert(interface);
213       orderedInterfaces.push_back(interface);
214     }
215   }
216 }
217 
~DialectInterfaceCollectionBase()218 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {}
219 
220 /// Get the interface for the dialect of given operation, or null if one
221 /// is not registered.
222 const DialectInterface *
getInterfaceFor(Operation * op) const223 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
224   return getInterfaceFor(op->getDialect());
225 }
226