[mlir] Enable delayed registration of attribute/operation/type interfaces

This functionality is similar to delayed registration of dialect interfaces. It
allows external interface models to be registered before the dialect containing
the attribute/operation/type interface is loaded, or even before the context is
created.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D104397
This commit is contained in:
Alex Zinenko 2021-06-16 18:53:21 +02:00
parent ccc0f777f6
commit d7e8912134
7 changed files with 273 additions and 19 deletions

View File

@ -50,6 +50,12 @@ public:
return interfaceMap.lookup<T>();
}
/// Returns true if the attribute has the interface with the given ID
/// registered.
bool hasInterface(TypeID interfaceID) const {
return interfaceMap.contains(interfaceID);
}
/// Return the unique identifier representing the concrete attribute class.
TypeID getTypeID() const { return typeID; }

View File

@ -27,8 +27,9 @@ class Type;
using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
using InterfaceAllocatorFunction =
using DialectInterfaceAllocatorFunction =
std::function<std::unique_ptr<DialectInterface>(Dialect *)>;
using ObjectInterfaceAllocatorFunction = std::function<void(MLIRContext *)>;
/// Dialects are groups of MLIR operations, types and attributes, as well as
/// behavior associated with the entire group. For example, hooks into other
@ -278,11 +279,19 @@ private:
/// dialects loaded in the Context. The parser in particular will lazily load
/// dialects in the Context as operations are encountered.
class DialectRegistry {
/// Lists of interfaces that need to be registered when the dialect is loaded.
struct DelayedInterfaces {
/// Dialect interfaces.
SmallVector<std::pair<TypeID, DialectInterfaceAllocatorFunction>, 2>
dialectInterfaces;
/// Attribute/Operation/Type interfaces.
SmallVector<std::pair<TypeID, ObjectInterfaceAllocatorFunction>, 2>
objectInterfaces;
};
using MapTy =
std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
using InterfaceMapTy =
DenseMap<TypeID,
SmallVector<std::pair<TypeID, InterfaceAllocatorFunction>, 2>>;
using InterfaceMapTy = DenseMap<TypeID, DelayedInterfaces>;
public:
explicit DialectRegistry() {}
@ -336,7 +345,7 @@ public:
/// the registry.
template <typename DialectTy>
void addDialectInterface(TypeID interfaceTypeID,
InterfaceAllocatorFunction allocator) {
DialectInterfaceAllocatorFunction allocator) {
addDialectInterface(DialectTy::getDialectNamespace(), interfaceTypeID,
allocator);
}
@ -351,6 +360,36 @@ public:
});
}
/// Add an external op interface model for an op that belongs to a dialect,
/// both provided as template parameters. The dialect must be present in the
/// registry.
template <typename OpTy, typename ModelTy>
void addOpInterface() {
StringRef opName = OpTy::getOperationName();
StringRef dialectName = opName.split('.').first;
addObjectInterface(dialectName == opName ? "" : dialectName,
ModelTy::Interface::getInterfaceID(),
[](MLIRContext *context) {
OpTy::template attachInterface<ModelTy>(*context);
});
}
/// Add an external attribute interface model for an attribute type `AttrTy`
/// that is going to belong to `DialectTy`. The dialect must be present in the
/// registry.
template <typename DialectTy, typename AttrTy, typename ModelTy>
void addAttrInterface() {
addStorageUserInterface<AttrTy, ModelTy>(DialectTy::getDialectNamespace());
}
/// Add an external type interface model for an type class `TypeTy` that is
/// going to belong to `DialectTy`. The dialect must be present in the
/// registry.
template <typename DialectTy, typename TypeTy, typename ModelTy>
void addTypeInterface() {
addStorageUserInterface<TypeTy, ModelTy>(DialectTy::getDialectNamespace());
}
/// Register any interfaces required for the given dialect (based on its
/// TypeID). Users are not expected to call this directly.
void registerDelayedInterfaces(Dialect *dialect) const;
@ -359,7 +398,22 @@ private:
/// Add an interface constructed with the given allocation function to the
/// dialect identified by its namespace.
void addDialectInterface(StringRef dialectName, TypeID interfaceTypeID,
InterfaceAllocatorFunction allocator);
DialectInterfaceAllocatorFunction allocator);
/// Add an attribute/operation/type interface constructible with the given
/// allocation function to the dialect identified by its namespace.
void addObjectInterface(StringRef dialectName, TypeID interfaceTypeID,
ObjectInterfaceAllocatorFunction allocator);
/// Add an external model for an attribute/type interface to the dialect
/// identified by its namespace.
template <typename ObjectTy, typename ModelTy>
void addStorageUserInterface(StringRef dialectName) {
addObjectInterface(dialectName, ModelTy::Interface::getInterfaceID(),
[](MLIRContext *context) {
ObjectTy::template attachInterface<ModelTy>(*context);
});
}
MapTy registry;
InterfaceMapTy interfaces;

View File

@ -58,6 +58,11 @@ public:
return interfaceMap.lookup<T>();
}
/// Returns true if the type has the interface with the given ID.
bool hasInterface(TypeID interfaceID) const {
return interfaceMap.contains(interfaceID);
}
/// Return the unique identifier representing the concrete type class.
TypeID getTypeID() const { return typeID; }

View File

@ -16,6 +16,7 @@
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/TypeName.h"
namespace mlir {
@ -236,8 +237,10 @@ public:
llvm::lower_bound(interfaces, id, [](const auto &it, TypeID id) {
return compare(it.first, id);
});
if (it != interfaces.end() && it->first == id)
llvm::report_fatal_error("Interface already registered");
if (it != interfaces.end() && it->first == id) {
LLVM_DEBUG(llvm::dbgs() << "Ignoring repeated interface registration");
continue;
}
interfaces.insert(it, element);
}
}

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/Dialect.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectInterface.h"
@ -31,7 +32,7 @@ DialectAsmParser::~DialectAsmParser() {}
void DialectRegistry::addDialectInterface(
StringRef dialectName, TypeID interfaceTypeID,
InterfaceAllocatorFunction allocator) {
DialectInterfaceAllocatorFunction allocator) {
assert(allocator && "unexpected null interface allocation function");
auto it = registry.find(dialectName.str());
assert(it != registry.end() &&
@ -40,8 +41,8 @@ void DialectRegistry::addDialectInterface(
// Bail out if the interface with the given ID is already in the registry for
// the given dialect. We expect a small number (dozens) of interfaces so a
// linear search is fine here.
auto &dialectInterfaces = interfaces[it->second.first];
for (const auto &kvp : dialectInterfaces) {
auto &ifaces = interfaces[it->second.first];
for (const auto &kvp : ifaces.dialectInterfaces) {
if (kvp.first == interfaceTypeID) {
LLVM_DEBUG(llvm::dbgs()
<< "[" DEBUG_TYPE
@ -51,7 +52,36 @@ void DialectRegistry::addDialectInterface(
}
}
dialectInterfaces.emplace_back(interfaceTypeID, allocator);
ifaces.dialectInterfaces.emplace_back(interfaceTypeID, allocator);
}
void DialectRegistry::addObjectInterface(
StringRef dialectName, TypeID interfaceTypeID,
ObjectInterfaceAllocatorFunction allocator) {
assert(allocator && "unexpected null interface allocation function");
// Builtin dialect has an empty prefix and is always registered.
TypeID dialectTypeID;
if (!dialectName.empty()) {
auto it = registry.find(dialectName.str());
assert(it != registry.end() &&
"adding an interface for an op from an unregistered dialect");
dialectTypeID = it->second.first;
} else {
dialectTypeID = TypeID::get<BuiltinDialect>();
}
auto &ifaces = interfaces[dialectTypeID];
for (const auto &kvp : ifaces.objectInterfaces) {
if (kvp.first == interfaceTypeID) {
LLVM_DEBUG(llvm::dbgs()
<< "[" DEBUG_TYPE
"] repeated interface object interface registration");
return;
}
}
ifaces.objectInterfaces.emplace_back(interfaceTypeID, allocator);
}
DialectAllocatorFunctionRef
@ -79,11 +109,15 @@ void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
return;
// Add an interface if it is not already present.
for (const auto &kvp : it->second) {
for (const auto &kvp : it->getSecond().dialectInterfaces) {
if (dialect->getRegisteredInterface(kvp.first))
continue;
dialect->addInterface(kvp.second(dialect));
}
// Add attribute, operation and type interfaces.
for (const auto &kvp : it->getSecond().objectInterfaces)
kvp.second(dialect->getContext());
}
//===----------------------------------------------------------------------===//

View File

@ -356,12 +356,12 @@ MLIRContext::MLIRContext(const DialectRegistry &registry)
printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
}
// Ensure the builtin dialect is always pre-loaded.
getOrLoadDialect<BuiltinDialect>();
// Pre-populate the registry.
registry.appendTo(impl->dialectsRegistry);
// Ensure the builtin dialect is always pre-loaded.
getOrLoadDialect<BuiltinDialect>();
// Initialize several common attributes and types to avoid the need to lock
// the context when accessing them.

View File

@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "gtest/gtest.h"
@ -87,6 +88,74 @@ TEST(InterfaceAttachment, Type) {
EXPECT_FALSE(i8other.isa<TestExternalTypeInterface>());
}
/// External interface model for the test type from the test dialect.
struct TestTypeModel
: public TestExternalTypeInterface::ExternalModel<TestTypeModel,
test::TestType> {
unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; }
static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; }
};
TEST(InterfaceAttachment, TypeDelayedContextConstruct) {
// Put the interface in the registry.
DialectRegistry registry;
registry.insert<test::TestDialect>();
registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>();
// Check that when a context is constructed with the given registry, the type
// interface gets registered.
MLIRContext context(registry);
context.loadDialect<test::TestDialect>();
test::TestType testType = test::TestType::get(&context);
auto iface = testType.dyn_cast<TestExternalTypeInterface>();
ASSERT_TRUE(iface != nullptr);
EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u);
EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u);
}
TEST(InterfaceAttachment, TypeDelayedContextAppend) {
// Put the interface in the registry.
DialectRegistry registry;
registry.insert<test::TestDialect>();
registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>();
// Check that when the registry gets appended to the context, the interface
// becomes available for objects in loaded dialects.
MLIRContext context;
context.loadDialect<test::TestDialect>();
test::TestType testType = test::TestType::get(&context);
EXPECT_FALSE(testType.isa<TestExternalTypeInterface>());
context.appendDialectRegistry(registry);
EXPECT_TRUE(testType.isa<TestExternalTypeInterface>());
}
TEST(InterfaceAttachment, RepeatedRegistration) {
DialectRegistry registry;
registry.addTypeInterface<BuiltinDialect, IntegerType, Model>();
MLIRContext context(registry);
// Should't fail on repeated registration through the dialect registry.
context.appendDialectRegistry(registry);
}
TEST(InterfaceAttachment, TypeBuiltinDelayed) {
// Builtin dialect needs to registration or loading, but delayed interface
// registration must still work.
DialectRegistry registry;
registry.addTypeInterface<BuiltinDialect, IntegerType, Model>();
MLIRContext context(registry);
IntegerType i16 = IntegerType::get(&context, 16);
EXPECT_TRUE(i16.isa<TestExternalTypeInterface>());
MLIRContext initiallyEmpty;
IntegerType i32 = IntegerType::get(&initiallyEmpty, 32);
EXPECT_FALSE(i32.isa<TestExternalTypeInterface>());
initiallyEmpty.appendDialectRegistry(registry);
EXPECT_TRUE(i32.isa<TestExternalTypeInterface>());
}
/// The interface provides a default implementation that expects
/// ConcreteType::getWidth to exist, which is the case for IntegerType. So this
/// just derives from the ExternalModel.
@ -128,9 +197,9 @@ TEST(InterfaceAttachment, Fallback) {
}
/// External model for attribute interfaces.
struct TextExternalIntegerAttrModel
struct TestExternalIntegerAttrModel
: public TestExternalAttrInterface::ExternalModel<
TextExternalIntegerAttrModel, IntegerAttr> {
TestExternalIntegerAttrModel, IntegerAttr> {
const Dialect *getDialectPtr(Attribute attr) const {
return &attr.cast<IntegerAttr>().getDialect();
}
@ -145,13 +214,45 @@ TEST(InterfaceAttachment, Attribute) {
// that the basics work for attributes.
IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42);
ASSERT_FALSE(attr.isa<TestExternalAttrInterface>());
IntegerAttr::attachInterface<TextExternalIntegerAttrModel>(context);
IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context);
auto iface = attr.dyn_cast<TestExternalAttrInterface>();
ASSERT_TRUE(iface != nullptr);
EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect());
EXPECT_EQ(iface.getSomeNumber(), 42);
}
/// External model for an interface attachable to a non-builtin attribute.
struct TestExternalSimpleAAttrModel
: public TestExternalAttrInterface::ExternalModel<
TestExternalSimpleAAttrModel, test::SimpleAAttr> {
const Dialect *getDialectPtr(Attribute attr) const {
return &attr.getDialect();
}
static int getSomeNumber() { return 21; }
};
TEST(InterfaceAttachmentTest, AttributeDelayed) {
// Attribute interfaces use the exact same mechanism as types, so just check
// that the delayed registration work for attributes.
DialectRegistry registry;
registry.insert<test::TestDialect>();
registry.addAttrInterface<test::TestDialect, test::SimpleAAttr,
TestExternalSimpleAAttrModel>();
MLIRContext context(registry);
context.loadDialect<test::TestDialect>();
auto attr = test::SimpleAAttr::get(&context);
EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());
MLIRContext initiallyEmpty;
initiallyEmpty.loadDialect<test::TestDialect>();
attr = test::SimpleAAttr::get(&initiallyEmpty);
EXPECT_FALSE(attr.isa<TestExternalAttrInterface>());
initiallyEmpty.appendDialectRegistry(registry);
EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());
}
/// External interface model for the module operation. Only provides non-default
/// methods.
struct TestExternalOpModel
@ -220,4 +321,55 @@ TEST(InterfaceAttachment, Operation) {
ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp.getOperation()));
}
struct TestExternalTestOpModel
: public TestExternalOpInterface::ExternalModel<TestExternalTestOpModel,
test::OpJ> {
unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
return op->getName().getStringRef().size() + arg;
}
static unsigned getNameLengthPlusArgTwice(unsigned arg) {
return test::OpJ::getOperationName().size() + 2 * arg;
}
};
TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
DialectRegistry registry;
registry.insert<test::TestDialect>();
registry.addOpInterface<ModuleOp, TestExternalOpModel>();
registry.addOpInterface<test::OpJ, TestExternalTestOpModel>();
// Construct the context directly from a registry. The interfaces are expected
// to be readily available on operations.
MLIRContext context(registry);
context.loadDialect<test::TestDialect>();
ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
OpBuilder builder(module);
auto op =
builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
EXPECT_TRUE(isa<TestExternalOpInterface>(module.getOperation()));
EXPECT_TRUE(isa<TestExternalOpInterface>(op.getOperation()));
}
TEST(InterfaceAttachment, OperationDelayedContextAppend) {
DialectRegistry registry;
registry.insert<test::TestDialect>();
registry.addOpInterface<ModuleOp, TestExternalOpModel>();
registry.addOpInterface<test::OpJ, TestExternalTestOpModel>();
// Construct the context, create ops, and only then append the registry. The
// interfaces are expected to be available after appending the registry.
MLIRContext context;
context.loadDialect<test::TestDialect>();
ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
OpBuilder builder(module);
auto op =
builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
EXPECT_FALSE(isa<TestExternalOpInterface>(module.getOperation()));
EXPECT_FALSE(isa<TestExternalOpInterface>(op.getOperation()));
context.appendDialectRegistry(registry);
EXPECT_TRUE(isa<TestExternalOpInterface>(module.getOperation()));
EXPECT_TRUE(isa<TestExternalOpInterface>(op.getOperation()));
}
} // end namespace