[mlir] Add isa/dyn_cast support for dialect interfaces

This matches the same API usage as attributes/ops/types. For example:

```c++
Dialect *dialect = ...;

// Instead of this:
if (auto *interface = dialect->getRegisteredInterface<DialectInlinerInterface>())

// You can do this:
if (auto *interface = dyn_cast<DialectInlinerInterface>(dialect))
```

Differential Revision: https://reviews.llvm.org/D117859
This commit is contained in:
River Riddle 2022-01-21 00:38:30 -08:00
parent 51ed14d224
commit 58e7bf78a3
7 changed files with 62 additions and 21 deletions

View File

@ -77,8 +77,7 @@ or transformation without the need to determine the specific dialect subclass:
```c++
Dialect *dialect = ...;
if (DialectInlinerInterface *interface
= dialect->getRegisteredInterface<DialectInlinerInterface>()) {
if (DialectInlinerInterface *interface = dyn_cast<DialectInlinerInterface>(dialect)) {
// The dialect has provided an implementation of this interface.
...
}

View File

@ -440,11 +440,58 @@ private:
namespace llvm {
/// Provide isa functionality for Dialects.
template <typename T> struct isa_impl<T, ::mlir::Dialect> {
template <typename T>
struct isa_impl<T, ::mlir::Dialect,
std::enable_if_t<std::is_base_of<::mlir::Dialect, T>::value>> {
static inline bool doit(const ::mlir::Dialect &dialect) {
return mlir::TypeID::get<T>() == dialect.getTypeID();
}
};
template <typename T>
struct isa_impl<
T, ::mlir::Dialect,
std::enable_if_t<std::is_base_of<::mlir::DialectInterface, T>::value>> {
static inline bool doit(const ::mlir::Dialect &dialect) {
return const_cast<::mlir::Dialect &>(dialect).getRegisteredInterface<T>();
}
};
template <typename T>
struct cast_retty_impl<T, ::mlir::Dialect *> {
using ret_type =
std::conditional_t<std::is_base_of<::mlir::Dialect, T>::value, T *,
const T *>;
};
template <typename T>
struct cast_retty_impl<T, ::mlir::Dialect> {
using ret_type =
std::conditional_t<std::is_base_of<::mlir::Dialect, T>::value, T &,
const T &>;
};
template <typename T>
struct cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect> {
template <typename To>
static std::enable_if_t<std::is_base_of<::mlir::Dialect, To>::value, To &>
doitImpl(::mlir::Dialect &dialect) {
return static_cast<To &>(dialect);
}
template <typename To>
static std::enable_if_t<std::is_base_of<::mlir::DialectInterface, To>::value,
const To &>
doitImpl(::mlir::Dialect &dialect) {
return *dialect.getRegisteredInterface<To>();
}
static auto &doit(::mlir::Dialect &dialect) { return doitImpl<T>(dialect); }
};
template <class T>
struct cast_convert_val<T, ::mlir::Dialect *, ::mlir::Dialect *> {
static auto doit(::mlir::Dialect *dialect) {
return &cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect>::doit(
*dialect);
}
};
} // namespace llvm
#endif

View File

@ -231,8 +231,8 @@ combineOneSpec(DataLayoutSpecInterface spec,
// dialect is not loaded for some reason, use the default combinator
// that conservatively accepts identical entries only.
entriesForID[id] =
dialect ? dialect->getRegisteredInterface<DataLayoutDialectInterface>()
->combine(entriesForID[id], kvp.second)
dialect ? cast<DataLayoutDialectInterface>(dialect)->combine(
entriesForID[id], kvp.second)
: DataLayoutDialectInterface::defaultCombine(entriesForID[id],
kvp.second);
if (!entriesForID[id])

View File

@ -1236,8 +1236,7 @@ bool OpaqueElementsAttr::decode(ElementsAttr &result) {
Dialect *dialect = getContext()->getLoadedDialect(getDialect());
if (!dialect)
return true;
auto *interface =
dialect->getRegisteredInterface<DialectDecodeAttributesInterface>();
auto *interface = llvm::dyn_cast<DialectDecodeAttributesInterface>(dialect);
if (!interface)
return true;
return failed(interface->decode(*this, result));

View File

@ -506,7 +506,7 @@ LogicalResult Operation::fold(ArrayRef<Attribute> operands,
if (!dialect)
return failure();
auto *interface = dialect->getRegisteredInterface<DialectFoldInterface>();
auto *interface = dyn_cast<DialectFoldInterface>(dialect);
if (!interface)
return failure();

View File

@ -438,8 +438,7 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
if (!dialect)
continue;
const auto *iface =
dialect->getRegisteredInterface<DataLayoutDialectInterface>();
const auto *iface = dyn_cast<DataLayoutDialectInterface>(dialect);
if (!iface) {
return emitError(loc)
<< "the '" << dialect->getNamespace()

View File

@ -68,18 +68,17 @@ TEST(Dialect, DelayedInterfaceRegistration) {
MLIRContext context(registry);
// Load the TestDialect and check that the interface got registered for it.
auto *testDialect = context.getOrLoadDialect<TestDialect>();
Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
ASSERT_TRUE(testDialect != nullptr);
auto *testDialectInterface =
testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
EXPECT_TRUE(testDialectInterface != nullptr);
// Load the SecondTestDialect and check that the interface is not registered
// for it.
auto *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
ASSERT_TRUE(secondTestDialect != nullptr);
auto *secondTestDialectInterface =
secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
dyn_cast<SecondTestDialectInterface>(secondTestDialect);
EXPECT_TRUE(secondTestDialectInterface == nullptr);
// Use the same mechanism as for delayed registration but for an already
@ -90,7 +89,7 @@ TEST(Dialect, DelayedInterfaceRegistration) {
.addDialectInterface<SecondTestDialect, SecondTestDialectInterface>();
context.appendDialectRegistry(secondRegistry);
secondTestDialectInterface =
secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
dyn_cast<SecondTestDialectInterface>(secondTestDialect);
EXPECT_TRUE(secondTestDialectInterface != nullptr);
}
@ -102,10 +101,9 @@ TEST(Dialect, RepeatedDelayedRegistration) {
MLIRContext context(registry);
// Load the TestDialect and check that the interface got registered for it.
auto *testDialect = context.getOrLoadDialect<TestDialect>();
Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
ASSERT_TRUE(testDialect != nullptr);
auto *testDialectInterface =
testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
EXPECT_TRUE(testDialectInterface != nullptr);
// Try adding the same dialect interface again and check that we don't crash
@ -114,8 +112,7 @@ TEST(Dialect, RepeatedDelayedRegistration) {
secondRegistry.insert<TestDialect>();
secondRegistry.addDialectInterface<TestDialect, TestDialectInterface>();
context.appendDialectRegistry(secondRegistry);
testDialectInterface =
testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
EXPECT_TRUE(testDialectInterface != nullptr);
}