// llvm-project/mlir/unittests/IR/DialectTest.cpp
//
// 146 lines
// 5.3 KiB
// C++

//===- DialectTest.cpp - Dialect unit tests -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectInterface.h"
#include "gtest/gtest.h"
using namespace mlir;
using namespace mlir::detail;
namespace {
/// Minimal dialect registered under the "test" namespace. The tests in this
/// file load it into a context and attach dialect interfaces to it.
struct TestDialect : public Dialect {
  /// Returns the namespace this dialect is registered under.
  static StringRef getDialectNamespace() { return "test"; }

  /// Builds the dialect for `context`, identified by the TypeID of this class.
  TestDialect(MLIRContext *context)
      : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {}
};
/// A second, distinct dialect class that reuses the "test" namespace of
/// TestDialect, so that loading both into one context produces a namespace
/// collision.
struct AnotherTestDialect : public Dialect {
  /// Returns the namespace this dialect is registered under.
  static StringRef getDialectNamespace() { return "test"; }

  /// Builds the dialect for `context`, identified by the TypeID of this class.
  AnotherTestDialect(MLIRContext *context)
      : Dialect(getDialectNamespace(), context,
                TypeID::get<AnotherTestDialect>()) {}
};
/// Loading two different dialects that claim the same namespace into a single
/// context must terminate the process.
TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
  MLIRContext context;

  // The first dialect takes the "test" namespace.
  context.loadDialect<TestDialect>();

  // A different dialect using that same namespace should result in a failure.
  // The empty pattern places no constraint on the message printed on death.
  auto loadConflictingDialect = [&context] {
    context.loadDialect<AnotherTestDialect>();
  };
  ASSERT_DEATH(loadConflictingDialect(), "");
}
/// Dialect registered under the "test2" namespace, so it can be loaded into
/// the same context as a dialect that uses the "test" namespace.
struct SecondTestDialect : Dialect {
  /// Builds the dialect for `context`, identified by the TypeID of this class.
  SecondTestDialect(MLIRContext *context)
      : Dialect(getDialectNamespace(), context,
                TypeID::get<SecondTestDialect>()) {}

  /// Returns the namespace this dialect is registered under.
  static StringRef getDialectNamespace() { return "test2"; }
};
/// Base class of the dialect interfaces used by the tests in this file.
/// Concrete interfaces derive from it and override function().
struct TestDialectInterfaceBase
    : DialectInterface::Base<TestDialectInterfaceBase> {
  /// Overridable hook; the base implementation yields 42.
  virtual int function() const {
    return 42;
  }

  /// Attaches the interface to `dialect`.
  TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {}
};
/// Concrete interface derived from TestDialectInterfaceBase whose function()
/// yields 56.
struct TestDialectInterface : TestDialectInterfaceBase {
  int function() const final {
    return 56;
  }

  // Reuse the base class constructors unchanged.
  using TestDialectInterfaceBase::TestDialectInterfaceBase;
};
/// Another concrete interface derived from TestDialectInterfaceBase; its
/// function() yields 78.
struct SecondTestDialectInterface : TestDialectInterfaceBase {
  int function() const final {
    return 78;
  }

  // Reuse the base class constructors unchanged.
  using TestDialectInterfaceBase::TestDialectInterfaceBase;
};
/// Checks that an interface attached through a DialectRegistry becomes
/// available on the dialect, both when the dialect is loaded after the
/// registry was handed to the context and when the dialect is already loaded
/// at the time the registry is appended.
TEST(Dialect, DelayedInterfaceRegistration) {
  DialectRegistry registry;
  registry.insert<TestDialect, SecondTestDialect>();

  // Delayed registration of an interface for TestDialect.
  registry.addDialectInterface<TestDialect, TestDialectInterface>();
  MLIRContext context(registry);

  // Load the TestDialect and check that the interface got registered for it.
  // The ASSERT_* form stops the test here, before the pointer is dereferenced.
  auto *testDialect = context.getOrLoadDialect<TestDialect>();
  ASSERT_NE(testDialect, nullptr);
  auto *testDialectInterface =
      testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
  EXPECT_NE(testDialectInterface, nullptr);

  // Load the SecondTestDialect and check that the interface is not registered
  // for it.
  auto *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
  ASSERT_NE(secondTestDialect, nullptr);
  auto *secondTestDialectInterface =
      secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
  EXPECT_EQ(secondTestDialectInterface, nullptr);

  // Use the same mechanism as for delayed registration but for an already
  // loaded dialect and check that the interface is now registered.
  DialectRegistry secondRegistry;
  secondRegistry.insert<SecondTestDialect>();
  secondRegistry
      .addDialectInterface<SecondTestDialect, SecondTestDialectInterface>();
  context.appendDialectRegistry(secondRegistry);
  secondTestDialectInterface =
      secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
  EXPECT_NE(secondTestDialectInterface, nullptr);
}
/// Checks that appending a registry which requests an interface the dialect
/// already has does not crash and leaves the interface registered.
TEST(Dialect, RepeatedDelayedRegistration) {
  // Set up the delayed registration.
  DialectRegistry registry;
  registry.insert<TestDialect>();
  registry.addDialectInterface<TestDialect, TestDialectInterface>();
  MLIRContext context(registry);

  // Load the TestDialect and check that the interface got registered for it.
  // The ASSERT_* form stops the test here, before the pointer is dereferenced.
  auto *testDialect = context.getOrLoadDialect<TestDialect>();
  ASSERT_NE(testDialect, nullptr);
  auto *testDialectInterface =
      testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
  EXPECT_NE(testDialectInterface, nullptr);

  // Try adding the same dialect interface again and check that we don't crash
  // on repeated interface registration.
  DialectRegistry secondRegistry;
  secondRegistry.insert<TestDialect>();
  secondRegistry.addDialectInterface<TestDialect, TestDialectInterface>();
  context.appendDialectRegistry(secondRegistry);
  testDialectInterface =
      testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
  EXPECT_NE(testDialectInterface, nullptr);
}
// A dialect whose constructor registers two interfaces with the same
// InterfaceID, triggering an assertion failure.
struct RepeatedRegistrationDialect : Dialect {
  RepeatedRegistrationDialect(MLIRContext *context)
      : Dialect(getDialectNamespace(), context,
                TypeID::get<RepeatedRegistrationDialect>()) {
    // The second call is the repeated registration.
    addInterfaces<TestDialectInterface>();
    addInterfaces<SecondTestDialectInterface>();
  }

  /// Returns the namespace this dialect is registered under.
  static StringRef getDialectNamespace() { return "repeatedreg"; }
};
/// Loading RepeatedRegistrationDialect is expected to die with a message about
/// the interface kind already being registered.
TEST(Dialect, RepeatedInterfaceRegistrationDeath) {
  MLIRContext context;
  // The context has no other use when the check below is compiled out.
  (void)context;

  // This triggers an assertion in debug mode.
#ifndef NDEBUG
  auto loadRepeatedRegistrationDialect = [&context] {
    context.loadDialect<RepeatedRegistrationDialect>();
  };
  ASSERT_DEATH(loadRepeatedRegistrationDialect(),
               "interface kind has already been registered");
#endif
}
} // namespace