forked from OSchip/llvm-project
[mlir] Add isa/dyn_cast support for dialect interfaces
This matches the same API usage as attributes/ops/types. For example: ```c++ Dialect *dialect = ...; // Instead of this: if (auto *interface = dialect->getRegisteredInterface<DialectInlinerInterface>()) // You can do this: if (auto *interface = dyn_cast<DialectInlinerInterface>(dialect)) ``` Differential Revision: https://reviews.llvm.org/D117859
This commit is contained in:
parent
51ed14d224
commit
58e7bf78a3
|
@ -77,8 +77,7 @@ or transformation without the need to determine the specific dialect subclass:
|
|||
|
||||
```c++
|
||||
Dialect *dialect = ...;
|
||||
if (DialectInlinerInterface *interface
|
||||
= dialect->getRegisteredInterface<DialectInlinerInterface>()) {
|
||||
if (DialectInlinerInterface *interface = dyn_cast<DialectInlinerInterface>(dialect)) {
|
||||
// The dialect has provided an implementation of this interface.
|
||||
...
|
||||
}
|
||||
|
|
|
@ -440,11 +440,58 @@ private:
|
|||
|
||||
namespace llvm {
|
||||
/// Provide isa functionality for Dialects.
|
||||
template <typename T> struct isa_impl<T, ::mlir::Dialect> {
|
||||
template <typename T>
|
||||
struct isa_impl<T, ::mlir::Dialect,
|
||||
std::enable_if_t<std::is_base_of<::mlir::Dialect, T>::value>> {
|
||||
static inline bool doit(const ::mlir::Dialect &dialect) {
|
||||
return mlir::TypeID::get<T>() == dialect.getTypeID();
|
||||
}
|
||||
};
|
||||
template <typename T>
|
||||
struct isa_impl<
|
||||
T, ::mlir::Dialect,
|
||||
std::enable_if_t<std::is_base_of<::mlir::DialectInterface, T>::value>> {
|
||||
static inline bool doit(const ::mlir::Dialect &dialect) {
|
||||
return const_cast<::mlir::Dialect &>(dialect).getRegisteredInterface<T>();
|
||||
}
|
||||
};
|
||||
template <typename T>
|
||||
struct cast_retty_impl<T, ::mlir::Dialect *> {
|
||||
using ret_type =
|
||||
std::conditional_t<std::is_base_of<::mlir::Dialect, T>::value, T *,
|
||||
const T *>;
|
||||
};
|
||||
template <typename T>
|
||||
struct cast_retty_impl<T, ::mlir::Dialect> {
|
||||
using ret_type =
|
||||
std::conditional_t<std::is_base_of<::mlir::Dialect, T>::value, T &,
|
||||
const T &>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect> {
|
||||
template <typename To>
|
||||
static std::enable_if_t<std::is_base_of<::mlir::Dialect, To>::value, To &>
|
||||
doitImpl(::mlir::Dialect &dialect) {
|
||||
return static_cast<To &>(dialect);
|
||||
}
|
||||
template <typename To>
|
||||
static std::enable_if_t<std::is_base_of<::mlir::DialectInterface, To>::value,
|
||||
const To &>
|
||||
doitImpl(::mlir::Dialect &dialect) {
|
||||
return *dialect.getRegisteredInterface<To>();
|
||||
}
|
||||
|
||||
static auto &doit(::mlir::Dialect &dialect) { return doitImpl<T>(dialect); }
|
||||
};
|
||||
template <class T>
|
||||
struct cast_convert_val<T, ::mlir::Dialect *, ::mlir::Dialect *> {
|
||||
static auto doit(::mlir::Dialect *dialect) {
|
||||
return &cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect>::doit(
|
||||
*dialect);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace llvm
|
||||
|
||||
#endif
|
||||
|
|
|
@ -231,8 +231,8 @@ combineOneSpec(DataLayoutSpecInterface spec,
|
|||
// dialect is not loaded for some reason, use the default combinator
|
||||
// that conservatively accepts identical entries only.
|
||||
entriesForID[id] =
|
||||
dialect ? dialect->getRegisteredInterface<DataLayoutDialectInterface>()
|
||||
->combine(entriesForID[id], kvp.second)
|
||||
dialect ? cast<DataLayoutDialectInterface>(dialect)->combine(
|
||||
entriesForID[id], kvp.second)
|
||||
: DataLayoutDialectInterface::defaultCombine(entriesForID[id],
|
||||
kvp.second);
|
||||
if (!entriesForID[id])
|
||||
|
|
|
@ -1236,8 +1236,7 @@ bool OpaqueElementsAttr::decode(ElementsAttr &result) {
|
|||
Dialect *dialect = getContext()->getLoadedDialect(getDialect());
|
||||
if (!dialect)
|
||||
return true;
|
||||
auto *interface =
|
||||
dialect->getRegisteredInterface<DialectDecodeAttributesInterface>();
|
||||
auto *interface = llvm::dyn_cast<DialectDecodeAttributesInterface>(dialect);
|
||||
if (!interface)
|
||||
return true;
|
||||
return failed(interface->decode(*this, result));
|
||||
|
|
|
@ -506,7 +506,7 @@ LogicalResult Operation::fold(ArrayRef<Attribute> operands,
|
|||
if (!dialect)
|
||||
return failure();
|
||||
|
||||
auto *interface = dialect->getRegisteredInterface<DialectFoldInterface>();
|
||||
auto *interface = dyn_cast<DialectFoldInterface>(dialect);
|
||||
if (!interface)
|
||||
return failure();
|
||||
|
||||
|
|
|
@ -438,8 +438,7 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
|
|||
if (!dialect)
|
||||
continue;
|
||||
|
||||
const auto *iface =
|
||||
dialect->getRegisteredInterface<DataLayoutDialectInterface>();
|
||||
const auto *iface = dyn_cast<DataLayoutDialectInterface>(dialect);
|
||||
if (!iface) {
|
||||
return emitError(loc)
|
||||
<< "the '" << dialect->getNamespace()
|
||||
|
|
|
@ -68,18 +68,17 @@ TEST(Dialect, DelayedInterfaceRegistration) {
|
|||
MLIRContext context(registry);
|
||||
|
||||
// Load the TestDialect and check that the interface got registered for it.
|
||||
auto *testDialect = context.getOrLoadDialect<TestDialect>();
|
||||
Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
|
||||
ASSERT_TRUE(testDialect != nullptr);
|
||||
auto *testDialectInterface =
|
||||
testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
|
||||
auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
|
||||
EXPECT_TRUE(testDialectInterface != nullptr);
|
||||
|
||||
// Load the SecondTestDialect and check that the interface is not registered
|
||||
// for it.
|
||||
auto *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
|
||||
Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
|
||||
ASSERT_TRUE(secondTestDialect != nullptr);
|
||||
auto *secondTestDialectInterface =
|
||||
secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
|
||||
dyn_cast<SecondTestDialectInterface>(secondTestDialect);
|
||||
EXPECT_TRUE(secondTestDialectInterface == nullptr);
|
||||
|
||||
// Use the same mechanism as for delayed registration but for an already
|
||||
|
@ -90,7 +89,7 @@ TEST(Dialect, DelayedInterfaceRegistration) {
|
|||
.addDialectInterface<SecondTestDialect, SecondTestDialectInterface>();
|
||||
context.appendDialectRegistry(secondRegistry);
|
||||
secondTestDialectInterface =
|
||||
secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
|
||||
dyn_cast<SecondTestDialectInterface>(secondTestDialect);
|
||||
EXPECT_TRUE(secondTestDialectInterface != nullptr);
|
||||
}
|
||||
|
||||
|
@ -102,10 +101,9 @@ TEST(Dialect, RepeatedDelayedRegistration) {
|
|||
MLIRContext context(registry);
|
||||
|
||||
// Load the TestDialect and check that the interface got registered for it.
|
||||
auto *testDialect = context.getOrLoadDialect<TestDialect>();
|
||||
Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
|
||||
ASSERT_TRUE(testDialect != nullptr);
|
||||
auto *testDialectInterface =
|
||||
testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
|
||||
auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
|
||||
EXPECT_TRUE(testDialectInterface != nullptr);
|
||||
|
||||
// Try adding the same dialect interface again and check that we don't crash
|
||||
|
@ -114,8 +112,7 @@ TEST(Dialect, RepeatedDelayedRegistration) {
|
|||
secondRegistry.insert<TestDialect>();
|
||||
secondRegistry.addDialectInterface<TestDialect, TestDialectInterface>();
|
||||
context.appendDialectRegistry(secondRegistry);
|
||||
testDialectInterface =
|
||||
testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
|
||||
testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
|
||||
EXPECT_TRUE(testDialectInterface != nullptr);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue