Use dialect hook registration for constant folding hook.

Deletes specialized mechanism for registering constant folding hook and uses dialect hooks registration mechanism instead.

PiperOrigin-RevId: 235535410
This commit is contained in:
Tatiana Shpeisman 2019-02-25 08:37:28 -08:00 committed by jpienaar
parent a51d21538c
commit 8b99d1bdbf
3 changed files with 4 additions and 29 deletions

View File

@ -178,26 +178,11 @@ private:
};
using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
using ConstantFoldHookAllocator = std::function<void(MLIRContext *)>;
/// Registers a specific dialect creation function with the system, typically
/// used through the DialectRegistration template.
void registerDialectAllocator(const DialectAllocatorFunction &function);
/// Registers a constant fold hook for one or multiple dialects. The
/// ConstantFoldHookAllocator defines how the hook gets mapped to the targeted
/// dialect(s) in the context.
/// Exmaple:
/// registerConstantFoldHook([&](MLIRContext *ctx) {
/// auto dialects = ctx->getRegisteredDialects();
/// // then iterate and select the target dialect from dialects, or
/// // get one dialect directly by the prefix:
/// auto dialect = ctx->getRegisteredDialect("TARGET_PREFIX")
///
/// dialect->constantFoldHook = MyConstantFoldHook;
/// });
void registerConstantFoldHook(const ConstantFoldHookAllocator &function);
/// Registers all dialects with the specified MLIRContext.
void registerAllDialects(MLIRContext *context);

View File

@ -38,6 +38,8 @@ using DialectHooksSetter = std::function<void(MLIRContext *)>;
/// The subclass should override DialectHook methods for supported hooks.
class DialectHooks {
public:
// Returns hook to constant fold an operation.
DialectConstantFoldHook getConstantFoldHook() { return nullptr; }
// Returns hook to decode opaque constant tensor.
DialectConstantDecodeHook getDecodeHook() { return nullptr; }
// Returns hook to extract an element of an opaque constant tensor.
@ -65,6 +67,8 @@ template <typename ConcreteHooks> struct DialectHooksRegistration {
}
// Set hooks.
ConcreteHooks hooks;
if (auto h = hooks.getConstantFoldHook())
dialect->constantFoldHook = h;
if (auto h = hooks.getDecodeHook())
dialect->decodeHook = h;
if (auto h = hooks.getExtractElementHook())

View File

@ -25,10 +25,6 @@ using namespace mlir;
static llvm::ManagedStatic<SmallVector<DialectAllocatorFunction, 8>>
dialectRegistry;
// Registry for dialect's constant fold hooks.
static llvm::ManagedStatic<SmallVector<ConstantFoldHookAllocator, 8>>
constantFoldHookRegistry;
// Registry for functions that set dialect hooks.
static llvm::ManagedStatic<SmallVector<DialectHooksSetter, 8>>
dialectHooksRegistry;
@ -41,14 +37,6 @@ void mlir::registerDialectAllocator(const DialectAllocatorFunction &function) {
dialectRegistry->push_back(function);
}
/// Registers a constant fold hook for a specific dialect with the system.
void mlir::registerConstantFoldHook(const ConstantFoldHookAllocator &function) {
assert(
function &&
"Attempting to register an empty constant fold hook initialize function");
constantFoldHookRegistry->push_back(function);
}
/// Registers a function to set specific hooks for a specific dialect, typically
/// used through the DialectHooksRegistreation template.
void mlir::registerDialectHooksSetter(const DialectHooksSetter &function) {
@ -64,8 +52,6 @@ void mlir::registerDialectHooksSetter(const DialectHooksSetter &function) {
void mlir::registerAllDialects(MLIRContext *context) {
for (const auto &fn : *dialectRegistry)
fn(context);
for (const auto &fn : *constantFoldHookRegistry)
fn(context);
for (const auto &fn : *dialectHooksRegistry) {
fn(context);
}