forked from OSchip/llvm-project
[mlir][linalg][bufferize] Generalize InitTensorOp elimination
This allows for external users of Comprehensive Bufferize to specify their own InitTensorOp elimination procedures. Differential Revision: https://reviews.llvm.org/D112686
This commit is contained in:
parent
4ae8c83104
commit
bb83520dce
|
@ -195,6 +195,29 @@ bufferizeOp(Operation *op, BlockAndValueMapping &bvm,
|
|||
|
||||
/// Register external models implemented for the `BufferizableOpInterface`.
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
|
||||
|
||||
/// Try to eliminate InitTensorOps inside `funcOp`.
|
||||
///
|
||||
/// * `rewriteFunc` generates the replacement for the InitTensorOp.
|
||||
/// * Only InitTensorOps that are anchored on a matching OpOperand as per
|
||||
/// `anchorMatchFunc` are considered. "Anchored" means that there is a path on
|
||||
/// the reverse SSA use-def chain, starting from the OpOperand and always
|
||||
/// following the aliasing OpOperand, that eventually ends at a single
|
||||
/// InitTensorOp.
|
||||
/// * The result of `rewriteFunc` must usually be analyzed for inplacability.
|
||||
/// This analysis can be skipped with `skipAnalysis`.
|
||||
LogicalResult initTensorElimination(
|
||||
FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo,
|
||||
std::function<bool(OpOperand &)> anchorMatchFunc,
|
||||
std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
|
||||
bool skipAnalysis = false);
|
||||
|
||||
/// Try to eliminate InitTensorOps inside funcOp that are anchored on an
|
||||
/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
|
||||
/// (and some other conditions are met).
|
||||
LogicalResult eliminateInsertSliceAnchoredInitTensorOps(
|
||||
FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo);
|
||||
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -2150,6 +2150,78 @@ static void layoutPostProcessing(ModuleOp moduleOp) {
|
|||
}
|
||||
}
|
||||
|
||||
/// Try to eliminate InitTensorOps inside funcOp. An InitTensorOp is replaced
|
||||
/// with the the result of `rewriteFunc` if it is anchored on a matching
|
||||
/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
|
||||
/// chain, starting from the OpOperand and always following the aliasing
|
||||
/// OpOperand, that eventually ends at a single InitTensorOp.
|
||||
LogicalResult mlir::linalg::initTensorElimination(
|
||||
FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo,
|
||||
std::function<bool(OpOperand &)> anchorMatchFunc,
|
||||
std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
|
||||
bool skipAnalysis) {
|
||||
OpBuilder b(funcOp->getContext());
|
||||
|
||||
WalkResult status = funcOp->walk([&](Operation *op) {
|
||||
for (OpOperand &operand : op->getOpOperands()) {
|
||||
// Is this a matching OpOperand?
|
||||
if (!anchorMatchFunc(operand))
|
||||
continue;
|
||||
|
||||
SetVector<Value> maybeInitTensor =
|
||||
findValueInReverseUseDefChain(operand.get(), [](Value val) {
|
||||
// Continue traversal until this function returns true.
|
||||
OpResult opResult = val.dyn_cast<OpResult>();
|
||||
if (!opResult)
|
||||
return true;
|
||||
if (getInPlace(opResult) != InPlaceSpec::True)
|
||||
return true;
|
||||
// Only equivalent tensors are supported at the moment.
|
||||
// TODO: Support cases such as extract_slice(init_tensor).
|
||||
SmallVector<OpOperand *> opOperands =
|
||||
getAliasingOpOperand(opResult);
|
||||
if (!llvm::all_of(opOperands, [](OpOperand *operand) {
|
||||
return bufferRelation(*operand) == BufferRelation::Equivalent;
|
||||
}))
|
||||
return true;
|
||||
return false;
|
||||
});
|
||||
|
||||
// Replace only if the reverse use-def chain ends at exactly one
|
||||
// InitTensorOp.
|
||||
if (maybeInitTensor.size() != 1 ||
|
||||
!maybeInitTensor.front().getDefiningOp<InitTensorOp>())
|
||||
return WalkResult::skip();
|
||||
Value initTensor = maybeInitTensor.front();
|
||||
|
||||
// Create a replacement for the InitTensorOp.
|
||||
b.setInsertionPoint(initTensor.getDefiningOp());
|
||||
Value replacement = rewriteFunc(b, initTensor.getLoc(), operand);
|
||||
if (!replacement)
|
||||
continue;
|
||||
|
||||
// Uses of the InitTensorOp are replaced here, but the op is not deleted.
|
||||
// InitTensorOps without uses are ignored by the bufferization.
|
||||
initTensor.replaceAllUsesWith(replacement);
|
||||
aliasInfo.createAliasInfoEntry(replacement);
|
||||
|
||||
// Run analysis on the newly created op.
|
||||
if (auto opResult = replacement.dyn_cast<OpResult>()) {
|
||||
if (!skipAnalysis) {
|
||||
SmallVector<Operation *> ops(1, replacement.getDefiningOp());
|
||||
if (failed(inPlaceAnalysis(ops, aliasInfo, domInfo)))
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Advance to the next operation.
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
return failure(status.wasInterrupted());
|
||||
}
|
||||
|
||||
/// Try to eliminate InitTensorOps inside funcOp. An InitTensorOp can be
|
||||
/// eliminated if it is eventually inserted into another tensor (and some other
|
||||
/// conditions are met).
|
||||
|
@ -2178,60 +2250,26 @@ static void layoutPostProcessing(ModuleOp moduleOp) {
|
|||
///
|
||||
/// Note that the newly inserted ExtractSliceOp may have to bufferize
|
||||
/// out-of-place due to RaW conflicts.
|
||||
static LogicalResult runInitTensorElimination(FuncOp funcOp,
|
||||
BufferizationAliasInfo &aliasInfo,
|
||||
DominanceInfo &domInfo) {
|
||||
OpBuilder b(funcOp->getContext());
|
||||
|
||||
WalkResult status = funcOp->walk([&](tensor::InsertSliceOp insertOp) {
|
||||
// Only inplace bufferized InsertSliceOps are eligible.
|
||||
if (getInPlace(insertOp->getOpResult(0)) != InPlaceSpec::True)
|
||||
return WalkResult::skip();
|
||||
|
||||
SetVector<Value> maybeInitTensor =
|
||||
findValueInReverseUseDefChain(insertOp.source(), [](Value val) {
|
||||
// Continue traversal until this function returns true.
|
||||
OpResult opResult = val.dyn_cast<OpResult>();
|
||||
if (!opResult)
|
||||
return true;
|
||||
if (getInPlace(opResult) != InPlaceSpec::True)
|
||||
return true;
|
||||
// Only equivalent tensors are supported at the moment. E.g., when
|
||||
// taking a tensor.extract_slice of an init_tensor, we can currently
|
||||
// not eliminate the init_tensor.
|
||||
SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
|
||||
if (!llvm::all_of(opOperands, [](OpOperand *operand) {
|
||||
return bufferRelation(*operand) == BufferRelation::Equivalent;
|
||||
}))
|
||||
return true;
|
||||
LogicalResult mlir::linalg::eliminateInsertSliceAnchoredInitTensorOps(
|
||||
FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo) {
|
||||
return initTensorElimination(
|
||||
funcOp, aliasInfo, domInfo,
|
||||
[](OpOperand &operand) {
|
||||
auto insertSliceOp = dyn_cast<InsertSliceOp>(operand.getOwner());
|
||||
if (!insertSliceOp)
|
||||
return false;
|
||||
});
|
||||
// Replace only if the InsertSliceOp source originates from exactly one
|
||||
// InitTensorOp.
|
||||
if (maybeInitTensor.size() != 1 ||
|
||||
!maybeInitTensor.front().getDefiningOp<InitTensorOp>())
|
||||
return WalkResult::skip();
|
||||
Value initTensor = maybeInitTensor.front();
|
||||
|
||||
b.setInsertionPoint(initTensor.getDefiningOp());
|
||||
auto extractOp = b.create<tensor::ExtractSliceOp>(
|
||||
initTensor.getLoc(), insertOp.dest(), insertOp.getMixedOffsets(),
|
||||
insertOp.getMixedSizes(), insertOp.getMixedStrides());
|
||||
// Uses of the InitTensorOp are replaced here, but the op is not deleted.
|
||||
// InitTensorOps without uses are ignored by the bufferization.
|
||||
initTensor.replaceAllUsesWith(extractOp.result());
|
||||
aliasInfo.createAliasInfoEntry(extractOp.result());
|
||||
|
||||
// Run analysis on the ExtractSliceOp.
|
||||
if (failed(bufferizableInPlaceAnalysisAliasOnlyOp(
|
||||
extractOp->getOpOperand(0), aliasInfo, domInfo)))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
// Advance to the next operation.
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
return failure(status.wasInterrupted());
|
||||
// Only inplace bufferized InsertSliceOps are eligible.
|
||||
if (getInPlace(insertSliceOp->getOpResult(0)) != InPlaceSpec::True)
|
||||
return false;
|
||||
return &operand == &insertSliceOp->getOpOperand(0) /*source*/;
|
||||
},
|
||||
[](OpBuilder &b, Location loc, OpOperand &operand) {
|
||||
auto insertSliceOp = cast<InsertSliceOp>(operand.getOwner());
|
||||
auto extractOp = b.create<tensor::ExtractSliceOp>(
|
||||
loc, insertSliceOp.dest(), insertSliceOp.getMixedOffsets(),
|
||||
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
|
||||
return extractOp.result();
|
||||
});
|
||||
}
|
||||
|
||||
void LinalgComprehensiveModuleBufferize::runOnOperation() {
|
||||
|
@ -2291,7 +2329,8 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
|
|||
|
||||
// Try to eliminate InitTensorOps to avoid new allocations during the
|
||||
// bufferization phase.
|
||||
if (failed(runInitTensorElimination(funcOp, aliasInfo, domInfo))) {
|
||||
if (failed(eliminateInsertSliceAnchoredInitTensorOps(funcOp, aliasInfo,
|
||||
domInfo))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue