[mlir][linalg][bufferize] Bufferize Operation* instead of FuncOp

This change mainly changes the API. There is no mentioning of FuncOps in ComprehensiveBufferize anymore.

Also, bufferize methods of the op interface are called for ops without tensor operands/results if they have a region.

Differential Revision: https://reviews.llvm.org/D115212
This commit is contained in:
Matthias Springer 2021-12-07 17:53:45 +09:00
parent 718a1c989a
commit 958ae8b2d4
12 changed files with 56 additions and 52 deletions

View File

@ -70,7 +70,7 @@ struct PostAnalysisStep {
/// Run the post analysis step. This function may modify the IR, but must keep
/// `aliasInfo` (inside `state`) consistent. Newly created operations and
/// operations that should be re-analyzed must be stored in `newOps`.
virtual LogicalResult run(FuncOp funcOp, BufferizationState &state,
virtual LogicalResult run(Operation *op, BufferizationState &state,
BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) = 0;
};
@ -299,9 +299,8 @@ struct DialectBufferizationState {
/// directly return a mapped buffer or allocate a new brand new buffer.
class BufferizationState {
public:
BufferizationState(ModuleOp moduleOp, const BufferizationOptions &options)
: aliasInfo(moduleOp), options(options),
builder(moduleOp->getContext()) {}
BufferizationState(Operation *op, const BufferizationOptions &options)
: aliasInfo(op), options(options), builder(op->getContext()) {}
// BufferizationState should be passed as a reference.
BufferizationState(const BufferizationState &) = delete;
@ -365,7 +364,7 @@ public:
private:
friend LogicalResult
runComprehensiveBufferize(FuncOp funcOp, const BufferizationOptions &options,
runComprehensiveBufferize(Operation *op, const BufferizationOptions &options,
BufferizationState &state,
const PostAnalysisStepList &extraSteps);

View File

@ -196,7 +196,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
before returning. Otherwise, nested ops will not be bufferized.
This method will never be called on ops that do not have at least one
tensor operand or result.
tensor operand/result or a region.
}],
/*retType=*/"LogicalResult",
/*methodName=*/"bufferize",

View File

@ -20,15 +20,16 @@ struct BufferizationOptions;
struct BufferizationState;
struct PostAnalysisStep;
/// Bufferize the given function. Does not bufferize the function boundary.
/// Reuses an existing BufferizationState object.
// TODO: This function is meant to be called from ModuleBufferize and not can
// not yet be called standalone.
/// Bufferize the given operation. Reuses an existing BufferizationState object.
LogicalResult runComprehensiveBufferize(
FuncOp funcOp, const BufferizationOptions &options,
Operation *op, const BufferizationOptions &options,
BufferizationState &state,
const std::vector<std::unique_ptr<PostAnalysisStep>> &extraSteps);
/// Bufferize the given operation.
LogicalResult runComprehensiveBufferize(Operation *op,
const BufferizationOptions &options);
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir

View File

@ -23,7 +23,7 @@ class BufferizationAliasInfo;
namespace linalg_ext {
struct InitTensorEliminationStep : public PostAnalysisStep {
/// Try to eliminate InitTensorOps inside `funcOp`.
/// Try to eliminate InitTensorOps inside `op`.
///
/// * `rewriteFunc` generates the replacement for the InitTensorOp.
/// * Only InitTensorOps that are anchored on a matching OpOperand as per
@ -34,19 +34,19 @@ struct InitTensorEliminationStep : public PostAnalysisStep {
/// * The result of `rewriteFunc` must usually be analyzed for inplacability.
/// This analysis can be skipped with `skipAnalysis`.
LogicalResult eliminateInitTensors(
FuncOp funcOp, BufferizationState &state,
Operation *op, BufferizationState &state,
BufferizationAliasInfo &aliasInfo,
std::function<bool(OpOperand &)> anchorMatchFunc,
std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
SmallVector<Operation *> &newOps);
};
/// Try to eliminate InitTensorOps inside funcOp that are anchored on an
/// Try to eliminate InitTensorOps inside `op` that are anchored on an
/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
/// (and some other conditions are met).
struct InsertSliceAnchoredInitTensorEliminationStep
: public InitTensorEliminationStep {
LogicalResult run(FuncOp funcOp, BufferizationState &state,
LogicalResult run(Operation *op, BufferizationState &state,
BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) override;
};

View File

@ -22,7 +22,7 @@ namespace scf_ext {
/// Equivalence analysis for scf.for. Raise an error if iter_args are not
/// equivalent to their corresponding loop yield values.
struct AssertDestinationPassingStyle : public PostAnalysisStep {
LogicalResult run(FuncOp funcOp, BufferizationState &state,
LogicalResult run(Operation *op, BufferizationState &state,
BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) override;
};

View File

@ -20,7 +20,7 @@ namespace comprehensive_bufferize {
namespace tensor_ext {
struct InplaceInsertSliceOpAnalysis : public PostAnalysisStep {
LogicalResult run(FuncOp funcOp, BufferizationState &state,
LogicalResult run(Operation *op, BufferizationState &state,
BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) override;
};

View File

@ -425,14 +425,11 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
bool hasRegions = !op->getRegions().empty();
// No tensor results or operands: Simply bufferize all nested ops.
if (!hasTensorResult && !hasTensorOperand) {
for (Region &region : op->getRegions())
if (failed(bufferize(&region, state)))
return failure();
// No tensor results/operands or regions. We are done.
if (!hasTensorResult && !hasTensorOperand && !hasRegions)
return success();
}
// Bufferize using `BufferizableOpInterface`. Interface implementations are
// responsible for bufferizing nested ops.
@ -449,6 +446,8 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
for (OpOperand &operand : op->getOpOperands()) {
if (operand.get().getType().isa<TensorType>() &&
state.isMapped(operand.get())) {
assert(state.getOptions().allowUnknownOps &&
"unsupported op error should have been emitted earlier");
b.setInsertionPoint(op);
Value toTensorOp = b.create<bufferization::ToTensorOp>(
op->getLoc(), state.lookupBuffer(operand.get()));
@ -456,6 +455,7 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
}
}
// Bufferize all regions.
for (Region &region : op->getRegions())
if (failed(bufferize(&region, state)))
return failure();

View File

@ -667,10 +667,10 @@ static void equivalenceAnalysis(Operation *op,
/// Assert that the current bufferization decisions are consistent.
static LogicalResult
checkAliasInfoConsistency(FuncOp funcOp, const DominanceInfo &domInfo,
checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
const BufferizationAliasInfo &aliasInfo) {
Operation *inconsistentOp = nullptr;
WalkResult walkResult = funcOp.walk([&](Operation *op) {
WalkResult walkResult = op->walk([&](Operation *op) {
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
for (OpOperand &opOperand : op->getOpOperands())
if (opOperand.get().getType().isa<TensorType>()) {
@ -710,20 +710,23 @@ annotateOpsWithBufferizationMarkers(Operation *op,
}
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
FuncOp funcOp, const BufferizationOptions &options,
Operation *op, const BufferizationOptions &options) {
BufferizationState state(op, options);
PostAnalysisStepList extraSteps;
return runComprehensiveBufferize(op, options, state, extraSteps);
}
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
Operation *op, const BufferizationOptions &options,
BufferizationState &state, const PostAnalysisStepList &extraSteps) {
DominanceInfo domInfo(funcOp);
DominanceInfo domInfo(op);
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
if (funcOp.body().empty())
return success();
if (failed(checkAliasInfoConsistency(funcOp, domInfo, aliasInfo)))
if (failed(checkAliasInfoConsistency(op, domInfo, aliasInfo)))
return failure();
// If the analysis fails, just return.
Operation *op = funcOp.getOperation();
if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
options.analysisFuzzerSeed)))
return failure();
@ -732,7 +735,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) {
for (const std::unique_ptr<PostAnalysisStep> &step : steps) {
SmallVector<Operation *> newOps;
if (failed(step->run(funcOp, state, aliasInfo, newOps)))
if (failed(step->run(op, state, aliasInfo, newOps)))
return failure();
// Analyze ops that were created by the PostAnalysisStep.
if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))
@ -749,16 +752,12 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
// Annotate operations if we only want to report the analysis.
if (options.testAnalysisOnly) {
annotateOpsWithBufferizationMarkers(funcOp, aliasInfo);
annotateOpsWithBufferizationMarkers(op, aliasInfo);
return success();
}
// Bufferize all ops in funcOp.
OpBuilder b(funcOp.getContext());
auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(funcOp.getOperation());
assert(bufferizableOp && "must use ModuleBufferization");
if (failed(bufferizableOp.bufferize(b, state)))
// Bufferize the op and its nested ops.
if (failed(bufferize(op, state)))
return failure();
// Erase all obsolete ops.

View File

@ -56,6 +56,10 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
// Nothing to do. This op is already bufferized.
if (op.hasBufferSemantics())
return success();
// Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
// basis.
if (!op.hasTensorSemantics())
@ -371,21 +375,21 @@ struct LinalgOpInterfaceHelper<> {
} // namespace
/// Try to eliminate InitTensorOps inside funcOp. An InitTensorOp is replaced
/// Try to eliminate InitTensorOps inside `op`. An InitTensorOp is replaced
/// with the the result of `rewriteFunc` if it is anchored on a matching
/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
/// chain, starting from the OpOperand and always following the aliasing
/// OpOperand, that eventually ends at a single InitTensorOp.
LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
InitTensorEliminationStep::eliminateInitTensors(
FuncOp funcOp, BufferizationState &state,
Operation *op, BufferizationState &state,
BufferizationAliasInfo &aliasInfo,
std::function<bool(OpOperand &)> anchorMatchFunc,
std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
SmallVector<Operation *> &newOps) {
OpBuilder b(funcOp->getContext());
OpBuilder b(op->getContext());
WalkResult status = funcOp->walk([&](Operation *op) {
WalkResult status = op->walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
// Is this a matching OpOperand?
if (!anchorMatchFunc(operand))
@ -443,7 +447,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
return failure(status.wasInterrupted());
}
/// Try to eliminate InitTensorOps inside funcOp. An InitTensorOp can be
/// Try to eliminate InitTensorOps inside `op`. An InitTensorOp can be
/// eliminated if it is eventually inserted into another tensor (and some other
/// conditions are met).
///
@ -473,10 +477,10 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
/// out-of-place due to RaW conflicts.
LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
InsertSliceAnchoredInitTensorEliminationStep::run(
FuncOp funcOp, BufferizationState &state,
Operation *op, BufferizationState &state,
BufferizationAliasInfo &aliasInfo, SmallVector<Operation *> &newOps) {
return eliminateInitTensors(
funcOp, state, aliasInfo,
op, state, aliasInfo,
[&](OpOperand &operand) {
auto insertSliceOp =
dyn_cast<tensor::InsertSliceOp>(operand.getOwner());

View File

@ -87,12 +87,13 @@ struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep {
op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
}
LogicalResult run(FuncOp funcOp, BufferizationState &state,
LogicalResult run(Operation *op, BufferizationState &state,
BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) override {
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
// Support only single return-terminated block in the function.
auto funcOp = cast<FuncOp>(op);
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");

View File

@ -264,11 +264,11 @@ struct ForOpInterface
};
LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
AssertDestinationPassingStyle::run(FuncOp funcOp, BufferizationState &state,
AssertDestinationPassingStyle::run(Operation *op, BufferizationState &state,
BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) {
LogicalResult status = success();
funcOp->walk([&](scf::YieldOp yieldOp) {
op->walk([&](scf::YieldOp yieldOp) {
auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
if (!forOp)
return WalkResult::advance();

View File

@ -432,11 +432,11 @@ struct InsertSliceOpInterface
} // namespace mlir
LogicalResult mlir::linalg::comprehensive_bufferize::tensor_ext::
InplaceInsertSliceOpAnalysis::run(FuncOp funcOp, BufferizationState &state,
InplaceInsertSliceOpAnalysis::run(Operation *op, BufferizationState &state,
BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) {
auto &tensorState = getTensorBufferizationState(state);
funcOp.walk([&](InsertSliceOp insertSliceOp) {
op->walk([&](InsertSliceOp insertSliceOp) {
// A copy of the source buffer is needed if either:
// - The producer of `source` is not inplace. This is the case where a
// slice is computed out of place into the inplace full tensor.