forked from OSchip/llvm-project
Adds ConstantFoldHook registry in MLIRContext
This reverts the previous method which needs to create a new dialect with the constant fold hook from TensorFlow. This new method uses a function object in dialect to store the constant fold hook. Once a hook is registered to the dialect, this function object will be assigned when the dialect is added to the MLIRContext. For the operations which are not registered, a new method getRegisteredDialects is added to the MLIRContext to query the dialects which matches their op name prefixes. PiperOrigin-RevId: 222310149
This commit is contained in:
parent
5041e13c96
commit
a9d3e5ee38
|
@ -26,6 +26,9 @@
|
|||
|
||||
namespace mlir {
|
||||
|
||||
using DialectConstantFoldHook = std::function<bool(
|
||||
const Operation *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
|
||||
|
||||
/// Dialects are groups of MLIR operations and behavior associated with the
|
||||
/// entire group. For example, hooks into other systems for constant folding,
|
||||
/// default named types for asm printing, etc.
|
||||
|
@ -39,19 +42,16 @@ public:
|
|||
|
||||
StringRef getOperationPrefix() const { return opPrefix; }
|
||||
|
||||
/// Dialect implementations can implement this hook. It should attempt to
|
||||
/// constant fold this operation with the specified constant operand values -
|
||||
/// the elements in "operands" will correspond directly to the operands of the
|
||||
/// operation, but may be null if non-constant. If constant folding is
|
||||
/// successful, this returns false and fills in the `results` vector. If not,
|
||||
/// this returns true and `results` is unspecified.
|
||||
///
|
||||
/// If not overridden, this fallback implementation always fails to fold.
|
||||
///
|
||||
virtual bool constantFold(const Operation *op, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results) const {
|
||||
return true;
|
||||
}
|
||||
/// Registered fallback constant fold hook for the dialect. Like the constant
|
||||
/// fold hook of each operation, it attempts to constant fold the operation
|
||||
/// with the specified constant operand values - the elements in "operands"
|
||||
/// will correspond directly to the operands of the operation, but may be null
|
||||
/// if non-constant. If constant folding is successful, this returns false
|
||||
/// and fills in the `results` vector. If not, this returns true and
|
||||
/// `results` is unspecified.
|
||||
DialectConstantFoldHook constantFoldHook =
|
||||
[](const Operation *op, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results) { return true; };
|
||||
|
||||
// TODO: Hook to return the list of named types that are known.
|
||||
|
||||
|
@ -108,11 +108,26 @@ private:
|
|||
};
|
||||
|
||||
using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
|
||||
using ConstantFoldHookAllocator = std::function<void(MLIRContext *)>;
|
||||
|
||||
/// Register a specific dialect creation function with the system, typically
|
||||
/// Registers a specific dialect creation function with the system, typically
|
||||
/// used through the DialectRegistration template.
|
||||
void registerDialectAllocator(const DialectAllocatorFunction &function);
|
||||
|
||||
/// Registers a constant fold hook for one or multiple dialects. The
|
||||
/// ConstantFoldHookAllocator defines how the hook gets mapped to the targeted
|
||||
/// dialect(s) in the context.
|
||||
/// Exmaple:
|
||||
/// registerConstantFoldHook([&](MLIRContext *ctx) {
|
||||
/// auto dialects = ctx->getRegisteredDialects();
|
||||
/// // then iterate and select the target dialect from dialects, or
|
||||
/// // get one dialect directly by the prefix:
|
||||
/// auto dialect = ctx->getRegisteredDialect("TARGET_PREFIX")
|
||||
///
|
||||
/// dialect->constantFoldHook = MyConstantFoldHook;
|
||||
/// });
|
||||
void registerConstantFoldHook(const ConstantFoldHookAllocator &function);
|
||||
|
||||
/// Registers all dialects with the specified MLIRContext.
|
||||
void registerAllDialects(MLIRContext *context);
|
||||
|
||||
|
|
|
@ -45,6 +45,10 @@ public:
|
|||
/// Return information about all registered IR dialects.
|
||||
std::vector<Dialect *> getRegisteredDialects() const;
|
||||
|
||||
/// Get registered IR dialect which has the longest matching with the given
|
||||
/// prefix. If none is found, returns nullptr.
|
||||
Dialect *getRegisteredDialect(StringRef prefix) const;
|
||||
|
||||
/// Return information about all registered operations. This isn't very
|
||||
/// efficient: typically you should ask the operations about their properties
|
||||
/// directly.
|
||||
|
@ -97,4 +101,4 @@ private:
|
|||
};
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_IR_MLIRCONTEXT_H
|
||||
#endif // MLIR_IR_MLIRCONTEXT_H
|
||||
|
|
|
@ -24,17 +24,33 @@ using namespace mlir;
|
|||
static llvm::ManagedStatic<SmallVector<DialectAllocatorFunction, 8>>
|
||||
dialectRegistry;
|
||||
|
||||
/// Register a specific dialect creation function with the system, typically
|
||||
// Registry for dialect's constant fold hooks.
|
||||
static llvm::ManagedStatic<SmallVector<ConstantFoldHookAllocator, 8>>
|
||||
constantFoldHookRegistry;
|
||||
|
||||
/// Registers a specific dialect creation function with the system, typically
|
||||
/// used through the DialectRegistration template.
|
||||
void mlir::registerDialectAllocator(const DialectAllocatorFunction &function) {
|
||||
assert(function && "Attempting to register an empty op initialize function");
|
||||
assert(function &&
|
||||
"Attempting to register an empty dialect initialize function");
|
||||
dialectRegistry->push_back(function);
|
||||
}
|
||||
|
||||
/// Registers all dialects with the specified MLIRContext.
|
||||
/// Registers a constant fold hook for a specific dialect with the system.
|
||||
void mlir::registerConstantFoldHook(const ConstantFoldHookAllocator &function) {
|
||||
assert(
|
||||
function &&
|
||||
"Attempting to register an empty constant fold hook initialize function");
|
||||
constantFoldHookRegistry->push_back(function);
|
||||
}
|
||||
|
||||
/// Registers all dialects and their const folding hooks with the specified
|
||||
/// MLIRContext.
|
||||
void mlir::registerAllDialects(MLIRContext *context) {
|
||||
for (const auto &fn : *dialectRegistry)
|
||||
fn(context);
|
||||
for (const auto &fn : *constantFoldHookRegistry)
|
||||
fn(context);
|
||||
}
|
||||
|
||||
Dialect::Dialect(StringRef opPrefix, MLIRContext *context)
|
||||
|
|
|
@ -516,6 +516,19 @@ std::vector<Dialect *> MLIRContext::getRegisteredDialects() const {
|
|||
return result;
|
||||
}
|
||||
|
||||
/// Get registered IR dialect which has the longest matching with the given
|
||||
/// prefix. If none is found, returns nullptr.
|
||||
Dialect *MLIRContext::getRegisteredDialect(StringRef prefix) const {
|
||||
Dialect *result = nullptr;
|
||||
for (auto &dialect : getImpl().dialects) {
|
||||
if (prefix.startswith(dialect->getOperationPrefix()))
|
||||
if (!result || result->getOperationPrefix().size() <
|
||||
dialect->getOperationPrefix().size())
|
||||
result = dialect.get();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Register this dialect object with the specified context. The context
|
||||
/// takes ownership of the heap allocated dialect.
|
||||
void Dialect::registerDialect(MLIRContext *context) {
|
||||
|
|
|
@ -304,9 +304,16 @@ bool Operation::constantFold(ArrayRef<Attribute> operands,
|
|||
return false;
|
||||
|
||||
// Otherwise, fall back on the dialect hook to handle it.
|
||||
Dialect &dialect = abstractOp->dialect;
|
||||
return dialect.constantFold(this, operands, results);
|
||||
return abstractOp->dialect.constantFoldHook(this, operands, results);
|
||||
}
|
||||
|
||||
// If this operation hasn't been registered or doesn't have abstract
|
||||
// operation, fall back to a dialect which matches the prefix.
|
||||
auto opName = getName().getStringRef();
|
||||
if (auto *dialect = getContext()->getRegisteredDialect(opName)) {
|
||||
return dialect->constantFoldHook(this, operands, results);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue