forked from OSchip/llvm-project
[mlir] Enable delayed registration of attribute/operation/type interfaces
This functionality is similar to delayed registration of dialect interfaces. It allows external interface models to be registered before the dialect containing the attribute/operation/type interface is loaded, or even before the context is created. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D104397
This commit is contained in:
parent
ccc0f777f6
commit
d7e8912134
|
@ -50,6 +50,12 @@ public:
|
|||
return interfaceMap.lookup<T>();
|
||||
}
|
||||
|
||||
/// Returns true if the attribute has the interface with the given ID
|
||||
/// registered.
|
||||
bool hasInterface(TypeID interfaceID) const {
|
||||
return interfaceMap.contains(interfaceID);
|
||||
}
|
||||
|
||||
/// Return the unique identifier representing the concrete attribute class.
|
||||
TypeID getTypeID() const { return typeID; }
|
||||
|
||||
|
|
|
@ -27,8 +27,9 @@ class Type;
|
|||
|
||||
using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
|
||||
using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
|
||||
using InterfaceAllocatorFunction =
|
||||
using DialectInterfaceAllocatorFunction =
|
||||
std::function<std::unique_ptr<DialectInterface>(Dialect *)>;
|
||||
using ObjectInterfaceAllocatorFunction = std::function<void(MLIRContext *)>;
|
||||
|
||||
/// Dialects are groups of MLIR operations, types and attributes, as well as
|
||||
/// behavior associated with the entire group. For example, hooks into other
|
||||
|
@ -278,11 +279,19 @@ private:
|
|||
/// dialects loaded in the Context. The parser in particular will lazily load
|
||||
/// dialects in the Context as operations are encountered.
|
||||
class DialectRegistry {
|
||||
/// Lists of interfaces that need to be registered when the dialect is loaded.
|
||||
struct DelayedInterfaces {
|
||||
/// Dialect interfaces.
|
||||
SmallVector<std::pair<TypeID, DialectInterfaceAllocatorFunction>, 2>
|
||||
dialectInterfaces;
|
||||
/// Attribute/Operation/Type interfaces.
|
||||
SmallVector<std::pair<TypeID, ObjectInterfaceAllocatorFunction>, 2>
|
||||
objectInterfaces;
|
||||
};
|
||||
|
||||
using MapTy =
|
||||
std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
|
||||
using InterfaceMapTy =
|
||||
DenseMap<TypeID,
|
||||
SmallVector<std::pair<TypeID, InterfaceAllocatorFunction>, 2>>;
|
||||
using InterfaceMapTy = DenseMap<TypeID, DelayedInterfaces>;
|
||||
|
||||
public:
|
||||
explicit DialectRegistry() {}
|
||||
|
@ -336,7 +345,7 @@ public:
|
|||
/// the registry.
|
||||
template <typename DialectTy>
|
||||
void addDialectInterface(TypeID interfaceTypeID,
|
||||
InterfaceAllocatorFunction allocator) {
|
||||
DialectInterfaceAllocatorFunction allocator) {
|
||||
addDialectInterface(DialectTy::getDialectNamespace(), interfaceTypeID,
|
||||
allocator);
|
||||
}
|
||||
|
@ -351,6 +360,36 @@ public:
|
|||
});
|
||||
}
|
||||
|
||||
/// Add an external op interface model for an op that belongs to a dialect,
|
||||
/// both provided as template parameters. The dialect must be present in the
|
||||
/// registry.
|
||||
template <typename OpTy, typename ModelTy>
|
||||
void addOpInterface() {
|
||||
StringRef opName = OpTy::getOperationName();
|
||||
StringRef dialectName = opName.split('.').first;
|
||||
addObjectInterface(dialectName == opName ? "" : dialectName,
|
||||
ModelTy::Interface::getInterfaceID(),
|
||||
[](MLIRContext *context) {
|
||||
OpTy::template attachInterface<ModelTy>(*context);
|
||||
});
|
||||
}
|
||||
|
||||
/// Add an external attribute interface model for an attribute type `AttrTy`
|
||||
/// that is going to belong to `DialectTy`. The dialect must be present in the
|
||||
/// registry.
|
||||
template <typename DialectTy, typename AttrTy, typename ModelTy>
|
||||
void addAttrInterface() {
|
||||
addStorageUserInterface<AttrTy, ModelTy>(DialectTy::getDialectNamespace());
|
||||
}
|
||||
|
||||
/// Add an external type interface model for an type class `TypeTy` that is
|
||||
/// going to belong to `DialectTy`. The dialect must be present in the
|
||||
/// registry.
|
||||
template <typename DialectTy, typename TypeTy, typename ModelTy>
|
||||
void addTypeInterface() {
|
||||
addStorageUserInterface<TypeTy, ModelTy>(DialectTy::getDialectNamespace());
|
||||
}
|
||||
|
||||
/// Register any interfaces required for the given dialect (based on its
|
||||
/// TypeID). Users are not expected to call this directly.
|
||||
void registerDelayedInterfaces(Dialect *dialect) const;
|
||||
|
@ -359,7 +398,22 @@ private:
|
|||
/// Add an interface constructed with the given allocation function to the
|
||||
/// dialect identified by its namespace.
|
||||
void addDialectInterface(StringRef dialectName, TypeID interfaceTypeID,
|
||||
InterfaceAllocatorFunction allocator);
|
||||
DialectInterfaceAllocatorFunction allocator);
|
||||
|
||||
/// Add an attribute/operation/type interface constructible with the given
|
||||
/// allocation function to the dialect identified by its namespace.
|
||||
void addObjectInterface(StringRef dialectName, TypeID interfaceTypeID,
|
||||
ObjectInterfaceAllocatorFunction allocator);
|
||||
|
||||
/// Add an external model for an attribute/type interface to the dialect
|
||||
/// identified by its namespace.
|
||||
template <typename ObjectTy, typename ModelTy>
|
||||
void addStorageUserInterface(StringRef dialectName) {
|
||||
addObjectInterface(dialectName, ModelTy::Interface::getInterfaceID(),
|
||||
[](MLIRContext *context) {
|
||||
ObjectTy::template attachInterface<ModelTy>(*context);
|
||||
});
|
||||
}
|
||||
|
||||
MapTy registry;
|
||||
InterfaceMapTy interfaces;
|
||||
|
|
|
@ -58,6 +58,11 @@ public:
|
|||
return interfaceMap.lookup<T>();
|
||||
}
|
||||
|
||||
/// Returns true if the type has the interface with the given ID.
|
||||
bool hasInterface(TypeID interfaceID) const {
|
||||
return interfaceMap.contains(interfaceID);
|
||||
}
|
||||
|
||||
/// Return the unique identifier representing the concrete type class.
|
||||
TypeID getTypeID() const { return typeID; }
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "mlir/Support/TypeID.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/TypeName.h"
|
||||
|
||||
namespace mlir {
|
||||
|
@ -236,8 +237,10 @@ public:
|
|||
llvm::lower_bound(interfaces, id, [](const auto &it, TypeID id) {
|
||||
return compare(it.first, id);
|
||||
});
|
||||
if (it != interfaces.end() && it->first == id)
|
||||
llvm::report_fatal_error("Interface already registered");
|
||||
if (it != interfaces.end() && it->first == id) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Ignoring repeated interface registration");
|
||||
continue;
|
||||
}
|
||||
interfaces.insert(it, element);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/DialectInterface.h"
|
||||
|
@ -31,7 +32,7 @@ DialectAsmParser::~DialectAsmParser() {}
|
|||
|
||||
void DialectRegistry::addDialectInterface(
|
||||
StringRef dialectName, TypeID interfaceTypeID,
|
||||
InterfaceAllocatorFunction allocator) {
|
||||
DialectInterfaceAllocatorFunction allocator) {
|
||||
assert(allocator && "unexpected null interface allocation function");
|
||||
auto it = registry.find(dialectName.str());
|
||||
assert(it != registry.end() &&
|
||||
|
@ -40,8 +41,8 @@ void DialectRegistry::addDialectInterface(
|
|||
// Bail out if the interface with the given ID is already in the registry for
|
||||
// the given dialect. We expect a small number (dozens) of interfaces so a
|
||||
// linear search is fine here.
|
||||
auto &dialectInterfaces = interfaces[it->second.first];
|
||||
for (const auto &kvp : dialectInterfaces) {
|
||||
auto &ifaces = interfaces[it->second.first];
|
||||
for (const auto &kvp : ifaces.dialectInterfaces) {
|
||||
if (kvp.first == interfaceTypeID) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "[" DEBUG_TYPE
|
||||
|
@ -51,7 +52,36 @@ void DialectRegistry::addDialectInterface(
|
|||
}
|
||||
}
|
||||
|
||||
dialectInterfaces.emplace_back(interfaceTypeID, allocator);
|
||||
ifaces.dialectInterfaces.emplace_back(interfaceTypeID, allocator);
|
||||
}
|
||||
|
||||
void DialectRegistry::addObjectInterface(
|
||||
StringRef dialectName, TypeID interfaceTypeID,
|
||||
ObjectInterfaceAllocatorFunction allocator) {
|
||||
assert(allocator && "unexpected null interface allocation function");
|
||||
|
||||
// Builtin dialect has an empty prefix and is always registered.
|
||||
TypeID dialectTypeID;
|
||||
if (!dialectName.empty()) {
|
||||
auto it = registry.find(dialectName.str());
|
||||
assert(it != registry.end() &&
|
||||
"adding an interface for an op from an unregistered dialect");
|
||||
dialectTypeID = it->second.first;
|
||||
} else {
|
||||
dialectTypeID = TypeID::get<BuiltinDialect>();
|
||||
}
|
||||
|
||||
auto &ifaces = interfaces[dialectTypeID];
|
||||
for (const auto &kvp : ifaces.objectInterfaces) {
|
||||
if (kvp.first == interfaceTypeID) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "[" DEBUG_TYPE
|
||||
"] repeated interface object interface registration");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
ifaces.objectInterfaces.emplace_back(interfaceTypeID, allocator);
|
||||
}
|
||||
|
||||
DialectAllocatorFunctionRef
|
||||
|
@ -79,11 +109,15 @@ void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
|
|||
return;
|
||||
|
||||
// Add an interface if it is not already present.
|
||||
for (const auto &kvp : it->second) {
|
||||
for (const auto &kvp : it->getSecond().dialectInterfaces) {
|
||||
if (dialect->getRegisteredInterface(kvp.first))
|
||||
continue;
|
||||
dialect->addInterface(kvp.second(dialect));
|
||||
}
|
||||
|
||||
// Add attribute, operation and type interfaces.
|
||||
for (const auto &kvp : it->getSecond().objectInterfaces)
|
||||
kvp.second(dialect->getContext());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -356,12 +356,12 @@ MLIRContext::MLIRContext(const DialectRegistry ®istry)
|
|||
printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
|
||||
}
|
||||
|
||||
// Ensure the builtin dialect is always pre-loaded.
|
||||
getOrLoadDialect<BuiltinDialect>();
|
||||
|
||||
// Pre-populate the registry.
|
||||
registry.appendTo(impl->dialectsRegistry);
|
||||
|
||||
// Ensure the builtin dialect is always pre-loaded.
|
||||
getOrLoadDialect<BuiltinDialect>();
|
||||
|
||||
// Initialize several common attributes and types to avoid the need to lock
|
||||
// the context when accessing them.
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
@ -87,6 +88,74 @@ TEST(InterfaceAttachment, Type) {
|
|||
EXPECT_FALSE(i8other.isa<TestExternalTypeInterface>());
|
||||
}
|
||||
|
||||
/// External interface model for the test type from the test dialect.
|
||||
struct TestTypeModel
|
||||
: public TestExternalTypeInterface::ExternalModel<TestTypeModel,
|
||||
test::TestType> {
|
||||
unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; }
|
||||
|
||||
static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; }
|
||||
};
|
||||
|
||||
TEST(InterfaceAttachment, TypeDelayedContextConstruct) {
|
||||
// Put the interface in the registry.
|
||||
DialectRegistry registry;
|
||||
registry.insert<test::TestDialect>();
|
||||
registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>();
|
||||
|
||||
// Check that when a context is constructed with the given registry, the type
|
||||
// interface gets registered.
|
||||
MLIRContext context(registry);
|
||||
context.loadDialect<test::TestDialect>();
|
||||
test::TestType testType = test::TestType::get(&context);
|
||||
auto iface = testType.dyn_cast<TestExternalTypeInterface>();
|
||||
ASSERT_TRUE(iface != nullptr);
|
||||
EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u);
|
||||
EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u);
|
||||
}
|
||||
|
||||
TEST(InterfaceAttachment, TypeDelayedContextAppend) {
|
||||
// Put the interface in the registry.
|
||||
DialectRegistry registry;
|
||||
registry.insert<test::TestDialect>();
|
||||
registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>();
|
||||
|
||||
// Check that when the registry gets appended to the context, the interface
|
||||
// becomes available for objects in loaded dialects.
|
||||
MLIRContext context;
|
||||
context.loadDialect<test::TestDialect>();
|
||||
test::TestType testType = test::TestType::get(&context);
|
||||
EXPECT_FALSE(testType.isa<TestExternalTypeInterface>());
|
||||
context.appendDialectRegistry(registry);
|
||||
EXPECT_TRUE(testType.isa<TestExternalTypeInterface>());
|
||||
}
|
||||
|
||||
TEST(InterfaceAttachment, RepeatedRegistration) {
|
||||
DialectRegistry registry;
|
||||
registry.addTypeInterface<BuiltinDialect, IntegerType, Model>();
|
||||
MLIRContext context(registry);
|
||||
|
||||
// Should't fail on repeated registration through the dialect registry.
|
||||
context.appendDialectRegistry(registry);
|
||||
}
|
||||
|
||||
TEST(InterfaceAttachment, TypeBuiltinDelayed) {
|
||||
// Builtin dialect needs to registration or loading, but delayed interface
|
||||
// registration must still work.
|
||||
DialectRegistry registry;
|
||||
registry.addTypeInterface<BuiltinDialect, IntegerType, Model>();
|
||||
|
||||
MLIRContext context(registry);
|
||||
IntegerType i16 = IntegerType::get(&context, 16);
|
||||
EXPECT_TRUE(i16.isa<TestExternalTypeInterface>());
|
||||
|
||||
MLIRContext initiallyEmpty;
|
||||
IntegerType i32 = IntegerType::get(&initiallyEmpty, 32);
|
||||
EXPECT_FALSE(i32.isa<TestExternalTypeInterface>());
|
||||
initiallyEmpty.appendDialectRegistry(registry);
|
||||
EXPECT_TRUE(i32.isa<TestExternalTypeInterface>());
|
||||
}
|
||||
|
||||
/// The interface provides a default implementation that expects
|
||||
/// ConcreteType::getWidth to exist, which is the case for IntegerType. So this
|
||||
/// just derives from the ExternalModel.
|
||||
|
@ -128,9 +197,9 @@ TEST(InterfaceAttachment, Fallback) {
|
|||
}
|
||||
|
||||
/// External model for attribute interfaces.
|
||||
struct TextExternalIntegerAttrModel
|
||||
struct TestExternalIntegerAttrModel
|
||||
: public TestExternalAttrInterface::ExternalModel<
|
||||
TextExternalIntegerAttrModel, IntegerAttr> {
|
||||
TestExternalIntegerAttrModel, IntegerAttr> {
|
||||
const Dialect *getDialectPtr(Attribute attr) const {
|
||||
return &attr.cast<IntegerAttr>().getDialect();
|
||||
}
|
||||
|
@ -145,13 +214,45 @@ TEST(InterfaceAttachment, Attribute) {
|
|||
// that the basics work for attributes.
|
||||
IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42);
|
||||
ASSERT_FALSE(attr.isa<TestExternalAttrInterface>());
|
||||
IntegerAttr::attachInterface<TextExternalIntegerAttrModel>(context);
|
||||
IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context);
|
||||
auto iface = attr.dyn_cast<TestExternalAttrInterface>();
|
||||
ASSERT_TRUE(iface != nullptr);
|
||||
EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect());
|
||||
EXPECT_EQ(iface.getSomeNumber(), 42);
|
||||
}
|
||||
|
||||
/// External model for an interface attachable to a non-builtin attribute.
|
||||
struct TestExternalSimpleAAttrModel
|
||||
: public TestExternalAttrInterface::ExternalModel<
|
||||
TestExternalSimpleAAttrModel, test::SimpleAAttr> {
|
||||
const Dialect *getDialectPtr(Attribute attr) const {
|
||||
return &attr.getDialect();
|
||||
}
|
||||
|
||||
static int getSomeNumber() { return 21; }
|
||||
};
|
||||
|
||||
TEST(InterfaceAttachmentTest, AttributeDelayed) {
|
||||
// Attribute interfaces use the exact same mechanism as types, so just check
|
||||
// that the delayed registration work for attributes.
|
||||
DialectRegistry registry;
|
||||
registry.insert<test::TestDialect>();
|
||||
registry.addAttrInterface<test::TestDialect, test::SimpleAAttr,
|
||||
TestExternalSimpleAAttrModel>();
|
||||
|
||||
MLIRContext context(registry);
|
||||
context.loadDialect<test::TestDialect>();
|
||||
auto attr = test::SimpleAAttr::get(&context);
|
||||
EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());
|
||||
|
||||
MLIRContext initiallyEmpty;
|
||||
initiallyEmpty.loadDialect<test::TestDialect>();
|
||||
attr = test::SimpleAAttr::get(&initiallyEmpty);
|
||||
EXPECT_FALSE(attr.isa<TestExternalAttrInterface>());
|
||||
initiallyEmpty.appendDialectRegistry(registry);
|
||||
EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());
|
||||
}
|
||||
|
||||
/// External interface model for the module operation. Only provides non-default
|
||||
/// methods.
|
||||
struct TestExternalOpModel
|
||||
|
@ -220,4 +321,55 @@ TEST(InterfaceAttachment, Operation) {
|
|||
ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp.getOperation()));
|
||||
}
|
||||
|
||||
struct TestExternalTestOpModel
|
||||
: public TestExternalOpInterface::ExternalModel<TestExternalTestOpModel,
|
||||
test::OpJ> {
|
||||
unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
|
||||
return op->getName().getStringRef().size() + arg;
|
||||
}
|
||||
|
||||
static unsigned getNameLengthPlusArgTwice(unsigned arg) {
|
||||
return test::OpJ::getOperationName().size() + 2 * arg;
|
||||
}
|
||||
};
|
||||
|
||||
TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
|
||||
DialectRegistry registry;
|
||||
registry.insert<test::TestDialect>();
|
||||
registry.addOpInterface<ModuleOp, TestExternalOpModel>();
|
||||
registry.addOpInterface<test::OpJ, TestExternalTestOpModel>();
|
||||
|
||||
// Construct the context directly from a registry. The interfaces are expected
|
||||
// to be readily available on operations.
|
||||
MLIRContext context(registry);
|
||||
context.loadDialect<test::TestDialect>();
|
||||
ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
|
||||
OpBuilder builder(module);
|
||||
auto op =
|
||||
builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
|
||||
EXPECT_TRUE(isa<TestExternalOpInterface>(module.getOperation()));
|
||||
EXPECT_TRUE(isa<TestExternalOpInterface>(op.getOperation()));
|
||||
}
|
||||
|
||||
TEST(InterfaceAttachment, OperationDelayedContextAppend) {
|
||||
DialectRegistry registry;
|
||||
registry.insert<test::TestDialect>();
|
||||
registry.addOpInterface<ModuleOp, TestExternalOpModel>();
|
||||
registry.addOpInterface<test::OpJ, TestExternalTestOpModel>();
|
||||
|
||||
// Construct the context, create ops, and only then append the registry. The
|
||||
// interfaces are expected to be available after appending the registry.
|
||||
MLIRContext context;
|
||||
context.loadDialect<test::TestDialect>();
|
||||
ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
|
||||
OpBuilder builder(module);
|
||||
auto op =
|
||||
builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
|
||||
EXPECT_FALSE(isa<TestExternalOpInterface>(module.getOperation()));
|
||||
EXPECT_FALSE(isa<TestExternalOpInterface>(op.getOperation()));
|
||||
context.appendDialectRegistry(registry);
|
||||
EXPECT_TRUE(isa<TestExternalOpInterface>(module.getOperation()));
|
||||
EXPECT_TRUE(isa<TestExternalOpInterface>(op.getOperation()));
|
||||
}
|
||||
|
||||
} // end namespace
|
||||
|
|
Loading…
Reference in New Issue