forked from OSchip/llvm-project
Use dialect hook registration for constant folding hook.
Deletes specialized mechanism for registering constant folding hook and uses dialect hooks registration mechanism instead. PiperOrigin-RevId: 235535410
This commit is contained in:
parent
a51d21538c
commit
8b99d1bdbf
|
@ -178,26 +178,11 @@ private:
|
|||
};
|
||||
|
||||
using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
|
||||
using ConstantFoldHookAllocator = std::function<void(MLIRContext *)>;
|
||||
|
||||
/// Registers a specific dialect creation function with the system, typically
|
||||
/// used through the DialectRegistration template.
|
||||
void registerDialectAllocator(const DialectAllocatorFunction &function);
|
||||
|
||||
/// Registers a constant fold hook for one or multiple dialects. The
|
||||
/// ConstantFoldHookAllocator defines how the hook gets mapped to the targeted
|
||||
/// dialect(s) in the context.
|
||||
/// Exmaple:
|
||||
/// registerConstantFoldHook([&](MLIRContext *ctx) {
|
||||
/// auto dialects = ctx->getRegisteredDialects();
|
||||
/// // then iterate and select the target dialect from dialects, or
|
||||
/// // get one dialect directly by the prefix:
|
||||
/// auto dialect = ctx->getRegisteredDialect("TARGET_PREFIX")
|
||||
///
|
||||
/// dialect->constantFoldHook = MyConstantFoldHook;
|
||||
/// });
|
||||
void registerConstantFoldHook(const ConstantFoldHookAllocator &function);
|
||||
|
||||
/// Registers all dialects with the specified MLIRContext.
|
||||
void registerAllDialects(MLIRContext *context);
|
||||
|
||||
|
|
|
@ -38,6 +38,8 @@ using DialectHooksSetter = std::function<void(MLIRContext *)>;
|
|||
/// The subclass should override DialectHook methods for supported hooks.
|
||||
class DialectHooks {
|
||||
public:
|
||||
// Returns hook to constant fold an operation.
|
||||
DialectConstantFoldHook getConstantFoldHook() { return nullptr; }
|
||||
// Returns hook to decode opaque constant tensor.
|
||||
DialectConstantDecodeHook getDecodeHook() { return nullptr; }
|
||||
// Returns hook to extract an element of an opaque constant tensor.
|
||||
|
@ -65,6 +67,8 @@ template <typename ConcreteHooks> struct DialectHooksRegistration {
|
|||
}
|
||||
// Set hooks.
|
||||
ConcreteHooks hooks;
|
||||
if (auto h = hooks.getConstantFoldHook())
|
||||
dialect->constantFoldHook = h;
|
||||
if (auto h = hooks.getDecodeHook())
|
||||
dialect->decodeHook = h;
|
||||
if (auto h = hooks.getExtractElementHook())
|
||||
|
|
|
@ -25,10 +25,6 @@ using namespace mlir;
|
|||
static llvm::ManagedStatic<SmallVector<DialectAllocatorFunction, 8>>
|
||||
dialectRegistry;
|
||||
|
||||
// Registry for dialect's constant fold hooks.
|
||||
static llvm::ManagedStatic<SmallVector<ConstantFoldHookAllocator, 8>>
|
||||
constantFoldHookRegistry;
|
||||
|
||||
// Registry for functions that set dialect hooks.
|
||||
static llvm::ManagedStatic<SmallVector<DialectHooksSetter, 8>>
|
||||
dialectHooksRegistry;
|
||||
|
@ -41,14 +37,6 @@ void mlir::registerDialectAllocator(const DialectAllocatorFunction &function) {
|
|||
dialectRegistry->push_back(function);
|
||||
}
|
||||
|
||||
/// Registers a constant fold hook for a specific dialect with the system.
|
||||
void mlir::registerConstantFoldHook(const ConstantFoldHookAllocator &function) {
|
||||
assert(
|
||||
function &&
|
||||
"Attempting to register an empty constant fold hook initialize function");
|
||||
constantFoldHookRegistry->push_back(function);
|
||||
}
|
||||
|
||||
/// Registers a function to set specific hooks for a specific dialect, typically
|
||||
/// used through the DialectHooksRegistreation template.
|
||||
void mlir::registerDialectHooksSetter(const DialectHooksSetter &function) {
|
||||
|
@ -64,8 +52,6 @@ void mlir::registerDialectHooksSetter(const DialectHooksSetter &function) {
|
|||
void mlir::registerAllDialects(MLIRContext *context) {
|
||||
for (const auto &fn : *dialectRegistry)
|
||||
fn(context);
|
||||
for (const auto &fn : *constantFoldHookRegistry)
|
||||
fn(context);
|
||||
for (const auto &fn : *dialectHooksRegistry) {
|
||||
fn(context);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue