[mlir][bufferization] Add extra filter mechanism to bufferizeOp

Differential Revision: https://reviews.llvm.org/D126569
This commit is contained in:
Matthias Springer 2022-05-28 04:48:36 +02:00
parent f470f8cbce
commit 2f0a634c5e
2 changed files with 26 additions and 10 deletions

View File

@ -27,6 +27,7 @@ namespace bufferization {
class AnalysisState;
struct BufferizationState;
struct BufferizationOptions;
class OpFilter;
/// A helper type converter class that automatically populates the relevant
/// materializations and type conversions for bufferization.
@ -84,8 +85,8 @@ BufferizationOptions getPartialBufferizationOptions();
/// Reuse an existing `BufferizationState`.
///
/// Note: This function overload is useful for extending the bufferization.
LogicalResult bufferizeOp(Operation *op,
BufferizationState &bufferizationState);
LogicalResult bufferizeOp(Operation *op, BufferizationState &bufferizationState,
const OpFilter *opFilter = nullptr);
/// Finalize all buffer allocations: Create alloc/dealloc ops as specified by
/// the bufferization options.

View File

@ -345,9 +345,10 @@ class BufferizationRewriter : public IRRewriter {
public:
BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
DenseSet<Operation *> &toMemrefOps,
const BufferizationOptions &options)
const BufferizationOptions &options,
const OpFilter *opFilter)
: IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
options(options) {}
options(options), opFilter(opFilter) {}
protected:
void notifyOperationRemoved(Operation *op) override {
@ -370,10 +371,18 @@ protected:
if (isa<ToTensorOp>(op))
return;
// Skip non-tensor ops.
if (!hasTensorSemantics(op))
return;
// Skip ops that are not allowed.
if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op)))
return;
// Adding new bufferizable ops is not allowed during bufferization. Such ops
// would not be analyzed and can lead to surprising behavior.
assert((!hasTensorSemantics(op) || !options.isOpAllowed(op)) &&
"creating new tensor ops is not allowed during bufferization");
llvm_unreachable(
"creating new tensor ops is not allowed during bufferization");
}
private:
@ -387,12 +396,14 @@ private:
/// Used for debug modes.
LLVM_ATTRIBUTE_UNUSED
const BufferizationOptions &options;
const OpFilter *opFilter;
};
} // namespace
LogicalResult
bufferization::bufferizeOp(Operation *op,
BufferizationState &bufferizationState) {
LogicalResult bufferization::bufferizeOp(Operation *op,
BufferizationState &bufferizationState,
const OpFilter *opFilter) {
const auto &options = bufferizationState.getOptions();
assert(options.unknownTypeConversion !=
BufferizationOptions::LayoutMapOption::InferLayoutMap &&
@ -420,7 +431,7 @@ bufferization::bufferizeOp(Operation *op,
// Bufferize all ops.
BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
bufferizationState.getOptions());
bufferizationState.getOptions(), opFilter);
for (unsigned i = 0; i < worklist.size(); ++i) {
Operation *op = worklist[i];
// Skip ops that were erased.
@ -430,6 +441,8 @@ bufferization::bufferizeOp(Operation *op,
auto bufferizableOp = options.dynCastBufferizableOp(op);
if (!bufferizableOp)
continue;
if (opFilter && !opFilter->isOpAllowed(op))
continue;
// Skip ops that no longer have tensor semantics.
if (!hasTensorSemantics(op))
continue;
@ -462,6 +475,8 @@ bufferization::bufferizeOp(Operation *op,
// Continue ops that are not allowed.
if (!options.isOpAllowed(op))
continue;
if (opFilter && !opFilter->isOpAllowed(op))
continue;
// Ops without any uses and no side effects will fold away.
if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
continue;