[mlir][linalg][bufferize] Generalize InitTensorOp elimination

This allows for external users of Comprehensive Bufferize to specify their own InitTensorOp elimination procedures.

Differential Revision: https://reviews.llvm.org/D112686
Matthias Springer 2021-11-04 13:51:30 +09:00
parent 4ae8c83104
commit bb83520dce
2 changed files with 116 additions and 54 deletions


@@ -195,6 +195,29 @@ bufferizeOp(Operation *op, BlockAndValueMapping &bvm,
/// Register external models implemented for the `BufferizableOpInterface`.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
/// Try to eliminate InitTensorOps inside `funcOp`.
///
/// * `rewriteFunc` generates the replacement for the InitTensorOp.
/// * Only InitTensorOps that are anchored on a matching OpOperand as per
///   `anchorMatchFunc` are considered. "Anchored" means that there is a path
///   on the reverse SSA use-def chain, starting from the OpOperand and always
///   following the aliasing OpOperand, that eventually ends at a single
///   InitTensorOp.
/// * The result of `rewriteFunc` must usually be analyzed for inplaceability.
///   This analysis can be skipped with `skipAnalysis`.
LogicalResult initTensorElimination(
    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo,
    std::function<bool(OpOperand &)> anchorMatchFunc,
    std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
    bool skipAnalysis = false);

/// Try to eliminate InitTensorOps inside `funcOp` that are anchored on an
/// InsertSliceOp, i.e., InitTensorOps whose result is eventually inserted
/// into another tensor (and some other conditions are met).
LogicalResult eliminateInsertSliceAnchoredInitTensorOps(
    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo);
} // namespace linalg
} // namespace mlir
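
With these declarations, a downstream user of Comprehensive Bufferize can express its own elimination procedure as a pair of callbacks. Below is a minimal sketch of such a use; it is not part of this change, `MyWriteOp` is a hypothetical placeholder for whatever destination-style op a downstream dialect anchors on, and the Comprehensive Bufferize includes are elided:

// Hypothetical downstream elimination built on the new entry point. MyWriteOp
// stands in for whatever destination-style op a downstream dialect anchors on.
using namespace mlir;
using namespace mlir::linalg;

static LogicalResult eliminateMyWriteAnchoredInitTensorOps(
    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo) {
  return initTensorElimination(
      funcOp, aliasInfo, domInfo,
      /*anchorMatchFunc=*/
      [](OpOperand &operand) {
        // Anchor on the value operand of MyWriteOp; the reverse use-def chain
        // of this operand must end at the InitTensorOp to be eliminated.
        return isa<MyWriteOp>(operand.getOwner()) &&
               operand.getOperandNumber() == 0;
      },
      /*rewriteFunc=*/
      [](OpBuilder &b, Location loc, OpOperand &operand) -> Value {
        // Build the replacement for the InitTensorOp here (the builder is
        // already positioned at the InitTensorOp), e.g. a load/extract of the
        // destination that MyWriteOp writes into. Returning a null Value
        // skips this anchor.
        return Value();
      });
}

The in-tree InsertSliceOp-anchored elimination below is expressed through exactly this mechanism.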


@@ -2150,6 +2150,78 @@ static void layoutPostProcessing(ModuleOp moduleOp) {
}
}
/// Try to eliminate InitTensorOps inside funcOp. An InitTensorOp is replaced
/// with the result of `rewriteFunc` if it is anchored on a matching
/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
/// chain, starting from the OpOperand and always following the aliasing
/// OpOperand, that eventually ends at a single InitTensorOp.
LogicalResult mlir::linalg::initTensorElimination(
    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo,
    std::function<bool(OpOperand &)> anchorMatchFunc,
    std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
    bool skipAnalysis) {
  OpBuilder b(funcOp->getContext());
  WalkResult status = funcOp->walk([&](Operation *op) {
    for (OpOperand &operand : op->getOpOperands()) {
      // Is this a matching OpOperand?
      if (!anchorMatchFunc(operand))
        continue;

      SetVector<Value> maybeInitTensor =
          findValueInReverseUseDefChain(operand.get(), [](Value val) {
            // Continue traversal until this function returns true.
            OpResult opResult = val.dyn_cast<OpResult>();
            if (!opResult)
              return true;
            if (getInPlace(opResult) != InPlaceSpec::True)
              return true;
            // Only equivalent tensors are supported at the moment.
            // TODO: Support cases such as extract_slice(init_tensor).
            SmallVector<OpOperand *> opOperands =
                getAliasingOpOperand(opResult);
            if (!llvm::all_of(opOperands, [](OpOperand *operand) {
                  return bufferRelation(*operand) == BufferRelation::Equivalent;
                }))
              return true;
            return false;
          });

      // Replace only if the reverse use-def chain ends at exactly one
      // InitTensorOp.
      if (maybeInitTensor.size() != 1 ||
          !maybeInitTensor.front().getDefiningOp<InitTensorOp>())
        return WalkResult::skip();
      Value initTensor = maybeInitTensor.front();

      // Create a replacement for the InitTensorOp.
      b.setInsertionPoint(initTensor.getDefiningOp());
      Value replacement = rewriteFunc(b, initTensor.getLoc(), operand);
      if (!replacement)
        continue;

      // Uses of the InitTensorOp are replaced here, but the op is not deleted.
      // InitTensorOps without uses are ignored by the bufferization.
      initTensor.replaceAllUsesWith(replacement);
      aliasInfo.createAliasInfoEntry(replacement);

      // Run analysis on the newly created op.
      if (auto opResult = replacement.dyn_cast<OpResult>()) {
        if (!skipAnalysis) {
          SmallVector<Operation *> ops(1, replacement.getDefiningOp());
          if (failed(inPlaceAnalysis(ops, aliasInfo, domInfo)))
            return WalkResult::interrupt();
        }
      }
    }

    // Advance to the next operation.
    return WalkResult::advance();
  });

  return failure(status.wasInterrupted());
}
/// Try to eliminate InitTensorOps inside funcOp. An InitTensorOp can be
/// eliminated if it is eventually inserted into another tensor (and some other
/// conditions are met).
@@ -2178,60 +2250,26 @@
///
/// Note that the newly inserted ExtractSliceOp may have to bufferize
/// out-of-place due to RaW conflicts.
static LogicalResult runInitTensorElimination(FuncOp funcOp,
                                              BufferizationAliasInfo &aliasInfo,
                                              DominanceInfo &domInfo) {
  OpBuilder b(funcOp->getContext());
  WalkResult status = funcOp->walk([&](tensor::InsertSliceOp insertOp) {
    // Only inplace bufferized InsertSliceOps are eligible.
    if (getInPlace(insertOp->getOpResult(0)) != InPlaceSpec::True)
      return WalkResult::skip();

    SetVector<Value> maybeInitTensor =
        findValueInReverseUseDefChain(insertOp.source(), [](Value val) {
          // Continue traversal until this function returns true.
          OpResult opResult = val.dyn_cast<OpResult>();
          if (!opResult)
            return true;
          if (getInPlace(opResult) != InPlaceSpec::True)
            return true;
          // Only equivalent tensors are supported at the moment. E.g., when
          // taking a tensor.extract_slice of an init_tensor, we can currently
          // not eliminate the init_tensor.
          SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
          if (!llvm::all_of(opOperands, [](OpOperand *operand) {
                return bufferRelation(*operand) == BufferRelation::Equivalent;
              }))
            return true;
          return false;
        });

    // Replace only if the InsertSliceOp source originates from exactly one
    // InitTensorOp.
    if (maybeInitTensor.size() != 1 ||
        !maybeInitTensor.front().getDefiningOp<InitTensorOp>())
      return WalkResult::skip();
    Value initTensor = maybeInitTensor.front();
    b.setInsertionPoint(initTensor.getDefiningOp());
    auto extractOp = b.create<tensor::ExtractSliceOp>(
        initTensor.getLoc(), insertOp.dest(), insertOp.getMixedOffsets(),
        insertOp.getMixedSizes(), insertOp.getMixedStrides());
    // Uses of the InitTensorOp are replaced here, but the op is not deleted.
    // InitTensorOps without uses are ignored by the bufferization.
    initTensor.replaceAllUsesWith(extractOp.result());
    aliasInfo.createAliasInfoEntry(extractOp.result());

    // Run analysis on the ExtractSliceOp.
    if (failed(bufferizableInPlaceAnalysisAliasOnlyOp(
            extractOp->getOpOperand(0), aliasInfo, domInfo)))
      return WalkResult::interrupt();

    // Advance to the next operation.
    return WalkResult::advance();
  });

  return failure(status.wasInterrupted());
}

LogicalResult mlir::linalg::eliminateInsertSliceAnchoredInitTensorOps(
    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo) {
  return initTensorElimination(
      funcOp, aliasInfo, domInfo,
      [](OpOperand &operand) {
        auto insertSliceOp = dyn_cast<InsertSliceOp>(operand.getOwner());
        if (!insertSliceOp)
          return false;
        // Only inplace bufferized InsertSliceOps are eligible.
        if (getInPlace(insertSliceOp->getOpResult(0)) != InPlaceSpec::True)
          return false;
        return &operand == &insertSliceOp->getOpOperand(0) /*source*/;
      },
      [](OpBuilder &b, Location loc, OpOperand &operand) {
        auto insertSliceOp = cast<InsertSliceOp>(operand.getOwner());
        auto extractOp = b.create<tensor::ExtractSliceOp>(
            loc, insertSliceOp.dest(), insertSliceOp.getMixedOffsets(),
            insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
        return extractOp.result();
      });
}
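
For callers outside the LinalgComprehensiveModuleBufferize pass, invoking the new public entry point looks roughly as follows. This is a sketch only: it assumes the in-place analysis has already annotated `funcOp` (otherwise the `getInPlace` checks in the anchor predicate never match) and that `BufferizationAliasInfo` is constructed from the root op, as the pass does at module scope.

// Sketch: run InsertSliceOp-anchored InitTensorOp elimination on one function.
static LogicalResult runElimination(FuncOp funcOp) {
  DominanceInfo domInfo(funcOp);
  BufferizationAliasInfo aliasInfo(funcOp);
  return mlir::linalg::eliminateInsertSliceAnchoredInitTensorOps(
      funcOp, aliasInfo, domInfo);
}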
void LinalgComprehensiveModuleBufferize::runOnOperation() {
@@ -2291,7 +2329,8 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
// Try to eliminate InitTensorOps to avoid new allocations during the
// bufferization phase.
if (failed(runInitTensorElimination(funcOp, aliasInfo, domInfo))) {
if (failed(eliminateInsertSliceAnchoredInitTensorOps(funcOp, aliasInfo,
domInfo))) {
signalPassFailure();
return;
}