forked from OSchip/llvm-project
[mlir][bufferize] Improve resolveConflicts for ExtractSliceOp
It is sometimes better to make a copy of the OpResult instead of the OpOperand — e.g., when bufferizing tensor.extract_slice. This implementation will eventually make parts of extract_slice's `bufferize` implementation obsolete (and simplify it); that implementation will then only need to handle in-place OpOperands. Differential Revision: https://reviews.llvm.org/D126819
This commit is contained in:
parent
72a049d778
commit
87b46776c4
|
@ -44,7 +44,12 @@ constexpr const ::llvm::StringLiteral
|
|||
|
||||
/// Resolve RaW conflicts for this op by inserting tensor copies
/// (bufferization.alloc_tensor ops) for all out-of-place tensor OpOperands.
///
/// For each out-of-place OpOperand, either the OpOperand itself or its
/// aliasing OpResult is copied:
/// * If the op does not write to the operand and creates exactly one alias,
///   the OpResult is copied instead of the OpOperand. The OpResult can be
///   smaller than the OpOperand (e.g., tensor.extract_slice, where the result
///   is usually a smaller part of the source), so this saves a copy.
/// * In all other cases, the OpOperand is copied and the op is updated to use
///   the copy.
///
/// Returns failure if an unranked tensor would have to be copied (not
/// supported); otherwise success.
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  SmallVector<OpResult> outOfPlaceOpResults;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!operandType.isa<TensorType>())
      continue;
    // NOTE(review): two context lines are elided between diff hunks at this
    // point; presumably the in-place guard below — confirm against upstream.
    if (state.isInPlace(opOperand))
      continue;
    if (operandType.isa<UnrankedTensorType>())
      return op->emitError("copies of unranked tensors are not supported");
    auto tensorType = operandType.dyn_cast<RankedTensorType>();
    if (!tensorType)
      continue;

    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(opOperand);
    if (aliasingOpResults.size() == 1 &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can sometimes
      // be smaller than the OpOperand (e.g., in the case of an extract_slice,
      // where the result is usually a smaller part of the source).
      outOfPlaceOpResults.push_back(aliasingOpResults.front());
    } else {
      // In all other cases, make a copy of the OpOperand.
      outOfPlaceOpOperands.push_back(&opOperand);
    }
  }

  // Insert copies of OpOperands: allocate before the op and redirect the
  // operand to the allocation.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    auto tensorType = opOperand->get().getType().cast<RankedTensorType>();
    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(*opOperand);
    // The allocation must escape (not be deallocated) if any aliasing result
    // is yielded from a block.
    bool escape = llvm::any_of(
        aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
    Value copy = rewriter.create<AllocTensorOp>(
        op->getLoc(), tensorType, ValueRange(), opOperand->get(), escape);
    rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); });
  }

  // Insert copies of OpResults: allocate after the op and replace all uses of
  // the result with the allocation.
  rewriter.setInsertionPointAfter(op);
  for (OpResult opResult : outOfPlaceOpResults) {
    auto tensorType = opResult.getType().cast<RankedTensorType>();
    bool escape = state.isTensorYielded(opResult);
    Value copy = rewriter.create<AllocTensorOp>(op->getLoc(), tensorType,
                                                ValueRange(), opResult, escape);
    // Snapshot the uses first: replacing a use while iterating getUses()
    // would invalidate the iterator.
    SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
        opResult.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() != copy.getDefiningOp())
        rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); });
    }
  }

  return success();
}
|
||||
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
// RUN: mlir-opt %s -tensor-copy-insertion -split-input-file | FileCheck %s
// RUN: mlir-opt %s -tensor-copy-insertion="bufferize-function-boundaries allow-return-allocs" -split-input-file | FileCheck %s --check-prefix=CHECK-FUNC

// Test that tensor-copy-insertion copies the (smaller) OpResult of
// tensor.extract_slice instead of its source operand.

// CHECK-LABEL: func @extract_slice(
//  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>
// CHECK-FUNC-LABEL: func @extract_slice(
func.func @extract_slice(%t: tensor<?xf32>, %idx: index, %f: f32)
  -> (tensor<5xf32>, tensor<?xf32>)
{
  // CHECK: %[[extract_slice:.*]] = tensor.extract_slice %[[t]][10] [5] [1]
  %0 = tensor.extract_slice %t[10][5][1] : tensor<?xf32> to tensor<5xf32>
  // CHECK: %[[alloc:.*]] = bufferization.alloc_tensor() copy(%[[extract_slice]]) {escape = false} : tensor<5xf32>
  // CHECK-FUNC: bufferization.alloc_tensor() copy(%{{.*}}) {escape = true} : tensor<5xf32>
  // CHECK: %[[insert:.*]] = tensor.insert %{{.*}} into %[[alloc]]
  %1 = tensor.insert %f into %0[%idx] : tensor<5xf32>
  // CHECK: return %[[insert]], %[[t]]
  return %1, %t : tensor<5xf32>, tensor<?xf32>
}
|
Loading…
Reference in New Issue