From 8b99d1bdbfebb48ff5b052216100b52c4cdf842d Mon Sep 17 00:00:00 2001 From: Tatiana Shpeisman Date: Mon, 25 Feb 2019 08:37:28 -0800 Subject: [PATCH] Use dialect hook registration for constant folding hook. Deletes specialized mechanism for registering constant folding hook and uses dialect hooks registration mechanism instead. PiperOrigin-RevId: 235535410 --- mlir/include/mlir/IR/Dialect.h | 15 --------------- mlir/include/mlir/IR/DialectHooks.h | 4 ++++ mlir/lib/IR/Dialect.cpp | 14 -------------- 3 files changed, 4 insertions(+), 29 deletions(-) diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index 9f7732e37766..55b6f7efd365 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -178,26 +178,11 @@ private: }; using DialectAllocatorFunction = std::function; -using ConstantFoldHookAllocator = std::function; /// Registers a specific dialect creation function with the system, typically /// used through the DialectRegistration template. void registerDialectAllocator(const DialectAllocatorFunction &function); -/// Registers a constant fold hook for one or multiple dialects. The -/// ConstantFoldHookAllocator defines how the hook gets mapped to the targeted -/// dialect(s) in the context. -/// Exmaple: -/// registerConstantFoldHook([&](MLIRContext *ctx) { -/// auto dialects = ctx->getRegisteredDialects(); -/// // then iterate and select the target dialect from dialects, or -/// // get one dialect directly by the prefix: -/// auto dialect = ctx->getRegisteredDialect("TARGET_PREFIX") -/// -/// dialect->constantFoldHook = MyConstantFoldHook; -/// }); -void registerConstantFoldHook(const ConstantFoldHookAllocator &function); - /// Registers all dialects with the specified MLIRContext. void registerAllDialects(MLIRContext *context); diff --git a/mlir/include/mlir/IR/DialectHooks.h b/mlir/include/mlir/IR/DialectHooks.h index dbfb1ab33c70..f368988b5b40 100644 --- a/mlir/include/mlir/IR/DialectHooks.h +++ b/mlir/include/mlir/IR/DialectHooks.h @@ -38,6 +38,8 @@ using DialectHooksSetter = std::function; /// The subclass should override DialectHook methods for supported hooks. class DialectHooks { public: + // Returns hook to constant fold an operation. + DialectConstantFoldHook getConstantFoldHook() { return nullptr; } // Returns hook to decode opaque constant tensor. DialectConstantDecodeHook getDecodeHook() { return nullptr; } // Returns hook to extract an element of an opaque constant tensor. @@ -65,6 +67,8 @@ template struct DialectHooksRegistration { } // Set hooks. ConcreteHooks hooks; + if (auto h = hooks.getConstantFoldHook()) + dialect->constantFoldHook = h; if (auto h = hooks.getDecodeHook()) dialect->decodeHook = h; if (auto h = hooks.getExtractElementHook()) diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index 249c9d84c1f7..338c918c3396 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -25,10 +25,6 @@ using namespace mlir; static llvm::ManagedStatic> dialectRegistry; -// Registry for dialect's constant fold hooks. -static llvm::ManagedStatic> - constantFoldHookRegistry; - // Registry for functions that set dialect hooks. static llvm::ManagedStatic> dialectHooksRegistry; @@ -41,14 +37,6 @@ void mlir::registerDialectAllocator(const DialectAllocatorFunction &function) { dialectRegistry->push_back(function); } -/// Registers a constant fold hook for a specific dialect with the system. -void mlir::registerConstantFoldHook(const ConstantFoldHookAllocator &function) { - assert( - function && - "Attempting to register an empty constant fold hook initialize function"); - constantFoldHookRegistry->push_back(function); -} - /// Registers a function to set specific hooks for a specific dialect, typically /// used through the DialectHooksRegistreation template. void mlir::registerDialectHooksSetter(const DialectHooksSetter &function) { @@ -64,8 +52,6 @@ void mlir::registerDialectHooksSetter(const DialectHooksSetter &function) { void mlir::registerAllDialects(MLIRContext *context) { for (const auto &fn : *dialectRegistry) fn(context); - for (const auto &fn : *constantFoldHookRegistry) - fn(context); for (const auto &fn : *dialectHooksRegistry) { fn(context); }