[mlir][tensor][bufferize] Fix deallocation of GenerateOp/FromElementsOp

Both ops allocate a buffer. There were cases in which the buffer was not deallocated.

Differential Revision: https://reviews.llvm.org/D130469
This commit is contained in:
Matthias Springer 2022-07-25 12:24:24 +02:00
parent 333ee218ce
commit 664ffa46bb
5 changed files with 76 additions and 26 deletions

View File

@ -511,6 +511,12 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
return newOp;
}
/// Return `true` if the buffer of given OpResult should be deallocated. This
/// function should be called during `BufferizableOpInterface::bufferize`
/// implementations that allocate a new buffer for the given OpResult.
bool shouldDeallocateOpResult(OpResult opResult,
const BufferizationOptions &options);
/// Return a MemRefType to which the type of the given value can be bufferized.
///
/// If possible, op bufferization implementations should not use this function

View File

@ -206,6 +206,29 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
return success();
}
bool bufferization::shouldDeallocateOpResult(
OpResult opResult, const BufferizationOptions &options) {
Operation *op = opResult.getOwner();
assert(options.dynCastBufferizableOp(op).bufferizesToAllocation(opResult) &&
"expected that op allocates");
AnalysisState analysisState(options);
if (op->hasAttr(BufferizationDialect::kEscapeAttrName)) {
// AllocTensorOp has one result.
ArrayAttr escapeAttr =
op->getAttr(BufferizationDialect::kEscapeAttrName).cast<ArrayAttr>();
return !escapeAttr[0].cast<BoolAttr>().getValue();
}
// No "escape" annotation found.
if (options.createDeallocs) {
// Perform an ad-hoc analysis.
return !analysisState.isTensorYielded(opResult);
}
return false;
}
//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

View File

@ -204,22 +204,8 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
}
// Should the buffer be deallocated?
AnalysisState analysisState(options);
bool dealloc;
if (op->hasAttr(BufferizationDialect::kEscapeAttrName)) {
// AllocTensorOp has one result.
ArrayAttr escapeAttr =
op->getAttr(BufferizationDialect::kEscapeAttrName).cast<ArrayAttr>();
dealloc = !escapeAttr[0].cast<BoolAttr>().getValue();
} else {
// No "escape" annotation found.
if (options.createDeallocs) {
// Perform an ad-hoc analysis.
dealloc = !analysisState.isTensorYielded(getResult());
} else {
dealloc = false;
}
}
bool dealloc =
shouldDeallocateOpResult(getResult().cast<OpResult>(), options);
// Replace op.
replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

View File

@ -363,9 +363,17 @@ static void createStores(RewriterBase &rewriter, Location loc, int dim,
struct FromElementsOpInterface
: public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
tensor::FromElementsOp> {
bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
return true;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto fromElementsOp = cast<tensor::FromElementsOp>(op);
// Should the buffer be deallocated?
bool dealloc = shouldDeallocateOpResult(
fromElementsOp.getResult().cast<OpResult>(), options);
// TODO: Implement memory space for this op.
if (options.defaultMemorySpace != static_cast<unsigned>(0))
@ -376,11 +384,10 @@ struct FromElementsOpInterface
auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
auto shape = tensorType.getShape();
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
AnalysisState analysisState(options);
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
rewriter, loc, fromElementsOp.getResult(),
analysisState.isTensorYielded(fromElementsOp.getResult()), options,
/*copy=*/false);
FailureOr<Value> tensorAlloc =
allocateTensorForShapedValue(rewriter, loc, fromElementsOp.getResult(),
/*escape=*/!dealloc, options,
/*copy=*/false);
if (failed(tensorAlloc))
return failure();
auto memrefType =
@ -416,6 +423,7 @@ struct FromElementsOpInterface
indices);
replaceOpWithBufferizedValues(rewriter, op, buffer);
return success();
}
};
@ -424,9 +432,17 @@ struct FromElementsOpInterface
struct GenerateOpInterface
: public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
tensor::GenerateOp> {
bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
return true;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto generateOp = cast<tensor::GenerateOp>(op);
// Should the buffer be deallocated?
bool dealloc = shouldDeallocateOpResult(
generateOp.getResult().cast<OpResult>(), options);
// TODO: Implement memory space for this op.
if (options.defaultMemorySpace != static_cast<unsigned>(0))
@ -436,11 +452,10 @@ struct GenerateOpInterface
// Allocate memory.
Location loc = op->getLoc();
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
AnalysisState analysisState(options);
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
rewriter, loc, generateOp.getResult(),
analysisState.isTensorYielded(generateOp.getResult()), options,
/*copy=*/false);
FailureOr<Value> tensorAlloc =
allocateTensorForShapedValue(rewriter, loc, generateOp.getResult(),
/*escape=*/!dealloc, options,
/*copy=*/false);
if (failed(tensorAlloc))
return failure();
auto memrefType =
@ -484,6 +499,7 @@ struct GenerateOpInterface
parallelBody->getArguments());
replaceOpWithBufferizedValues(rewriter, op, buffer);
return success();
}
};

View File

@ -217,3 +217,22 @@ func.func @rank_reducing_parallel_insert_slice(%in: tensor<100xf32>, %out: tenso
// CHECK: }
return
}
// -----
// CHECK-LABEL: func @dealloc_generate_buffer
func.func @dealloc_generate_buffer(%arg: tensor<*xf32>, %sz: index, %idx: index)
-> index
{
// CHECK: memref.alloc
// CHECK: scf.parallel
// CHECK: memref.load
// CHECK: memref.dealloc
%0 = tensor.generate %sz {
^bb0(%i : index):
%elem = tensor.dim %arg, %i : tensor<*xf32>
tensor.yield %elem : index
} : tensor<?xindex>
%r = tensor.extract %0[%idx] : tensor<?xindex>
return %r : index
}