Adds ConstantFoldHook registry in MLIRContext

This reverts the previous method which needs to create a new dialect with the
constant fold hook from TensorFlow. This new method uses a function object in
dialect to store the constant fold hook. Once a hook is registered to the
dialect, this function object will be assigned when the dialect is added to the
MLIRContext.

For the operations which are not registered, a new method getRegisteredDialects
is added to the MLIRContext to query the dialects which matches their op name
prefixes.

PiperOrigin-RevId: 222310149
This commit is contained in:
Feng Liu 2018-11-20 14:47:10 -08:00 committed by jpienaar
parent 5041e13c96
commit a9d3e5ee38
5 changed files with 75 additions and 20 deletions

View File

@ -26,6 +26,9 @@
namespace mlir {
using DialectConstantFoldHook = std::function<bool(
const Operation *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
/// Dialects are groups of MLIR operations and behavior associated with the
/// entire group. For example, hooks into other systems for constant folding,
/// default named types for asm printing, etc.
@ -39,19 +42,16 @@ public:
StringRef getOperationPrefix() const { return opPrefix; }
/// Dialect implementations can implement this hook. It should attempt to
/// constant fold this operation with the specified constant operand values -
/// the elements in "operands" will correspond directly to the operands of the
/// operation, but may be null if non-constant. If constant folding is
/// successful, this returns false and fills in the `results` vector. If not,
/// this returns true and `results` is unspecified.
///
/// If not overridden, this fallback implementation always fails to fold.
///
virtual bool constantFold(const Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results) const {
return true;
}
/// Registered fallback constant fold hook for the dialect. Like the constant
/// fold hook of each operation, it attempts to constant fold the operation
/// with the specified constant operand values - the elements in "operands"
/// will correspond directly to the operands of the operation, but may be null
/// if non-constant. If constant folding is successful, this returns false
/// and fills in the `results` vector. If not, this returns true and
/// `results` is unspecified.
DialectConstantFoldHook constantFoldHook =
[](const Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results) { return true; };
// TODO: Hook to return the list of named types that are known.
@ -108,11 +108,26 @@ private:
};
using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
using ConstantFoldHookAllocator = std::function<void(MLIRContext *)>;
/// Register a specific dialect creation function with the system, typically
/// Registers a specific dialect creation function with the system, typically
/// used through the DialectRegistration template.
void registerDialectAllocator(const DialectAllocatorFunction &function);
/// Registers a constant fold hook for one or multiple dialects. The
/// ConstantFoldHookAllocator defines how the hook gets mapped to the targeted
/// dialect(s) in the context.
/// Exmaple:
/// registerConstantFoldHook([&](MLIRContext *ctx) {
/// auto dialects = ctx->getRegisteredDialects();
/// // then iterate and select the target dialect from dialects, or
/// // get one dialect directly by the prefix:
/// auto dialect = ctx->getRegisteredDialect("TARGET_PREFIX")
///
/// dialect->constantFoldHook = MyConstantFoldHook;
/// });
void registerConstantFoldHook(const ConstantFoldHookAllocator &function);
/// Registers all dialects with the specified MLIRContext.
void registerAllDialects(MLIRContext *context);

View File

@ -45,6 +45,10 @@ public:
/// Return information about all registered IR dialects.
std::vector<Dialect *> getRegisteredDialects() const;
/// Get registered IR dialect which has the longest matching with the given
/// prefix. If none is found, returns nullptr.
Dialect *getRegisteredDialect(StringRef prefix) const;
/// Return information about all registered operations. This isn't very
/// efficient: typically you should ask the operations about their properties
/// directly.
@ -97,4 +101,4 @@ private:
};
} // end namespace mlir
#endif // MLIR_IR_MLIRCONTEXT_H
#endif // MLIR_IR_MLIRCONTEXT_H

View File

@ -24,17 +24,33 @@ using namespace mlir;
static llvm::ManagedStatic<SmallVector<DialectAllocatorFunction, 8>>
dialectRegistry;
/// Register a specific dialect creation function with the system, typically
// Registry for dialect's constant fold hooks.
static llvm::ManagedStatic<SmallVector<ConstantFoldHookAllocator, 8>>
constantFoldHookRegistry;
/// Registers a specific dialect creation function with the system, typically
/// used through the DialectRegistration template.
void mlir::registerDialectAllocator(const DialectAllocatorFunction &function) {
assert(function && "Attempting to register an empty op initialize function");
assert(function &&
"Attempting to register an empty dialect initialize function");
dialectRegistry->push_back(function);
}
/// Registers all dialects with the specified MLIRContext.
/// Registers a constant fold hook for a specific dialect with the system.
void mlir::registerConstantFoldHook(const ConstantFoldHookAllocator &function) {
assert(
function &&
"Attempting to register an empty constant fold hook initialize function");
constantFoldHookRegistry->push_back(function);
}
/// Registers all dialects and their const folding hooks with the specified
/// MLIRContext.
void mlir::registerAllDialects(MLIRContext *context) {
for (const auto &fn : *dialectRegistry)
fn(context);
for (const auto &fn : *constantFoldHookRegistry)
fn(context);
}
Dialect::Dialect(StringRef opPrefix, MLIRContext *context)

View File

@ -516,6 +516,19 @@ std::vector<Dialect *> MLIRContext::getRegisteredDialects() const {
return result;
}
/// Get registered IR dialect which has the longest matching with the given
/// prefix. If none is found, returns nullptr.
Dialect *MLIRContext::getRegisteredDialect(StringRef prefix) const {
Dialect *result = nullptr;
for (auto &dialect : getImpl().dialects) {
if (prefix.startswith(dialect->getOperationPrefix()))
if (!result || result->getOperationPrefix().size() <
dialect->getOperationPrefix().size())
result = dialect.get();
}
return result;
}
/// Register this dialect object with the specified context. The context
/// takes ownership of the heap allocated dialect.
void Dialect::registerDialect(MLIRContext *context) {

View File

@ -304,9 +304,16 @@ bool Operation::constantFold(ArrayRef<Attribute> operands,
return false;
// Otherwise, fall back on the dialect hook to handle it.
Dialect &dialect = abstractOp->dialect;
return dialect.constantFold(this, operands, results);
return abstractOp->dialect.constantFoldHook(this, operands, results);
}
// If this operation hasn't been registered or doesn't have abstract
// operation, fall back to a dialect which matches the prefix.
auto opName = getName().getStringRef();
if (auto *dialect = getContext()->getRegisteredDialect(opName)) {
return dialect->constantFoldHook(this, operands, results);
}
return true;
}