[mlir][SCF] Fix scf.while bufferization

Before this fix, the bufferization implementation made the incorrect assumption that the values yielded from the "before" region must match with the values yielded from the "after" region.

Differential Revision: https://reviews.llvm.org/D125835
This commit is contained in:
Matthias Springer 2022-05-17 22:58:54 +02:00
parent 79ca4ed3e7
commit 996834e681
5 changed files with 98 additions and 36 deletions

View File

@ -543,11 +543,11 @@ struct BufferizationState {
Optional<ForceInPlacability> overrideInPlace = None,
Optional<Operation *> customCopyInsertionPoint = None);
/// Return the buffer type for a given OpOperand (tensor) after bufferization.
/// Return the buffer type for a given Value (tensor) after bufferization.
///
/// Note: Op implementations should preferrably call `getBuffer()->getType()`.
/// This function should only be used if `getBuffer` cannot be used.
BaseMemRefType getBufferType(OpOperand &opOperand) const;
BaseMemRefType getBufferType(Value value) const;
/// Return a reference to the BufferizationOptions.
const BufferizationOptions &getOptions() const {

View File

@ -333,13 +333,12 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
return resultBuffer;
}
/// Return the buffer type for a given OpOperand (tensor) after bufferization.
BaseMemRefType BufferizationState::getBufferType(OpOperand &opOperand) const {
Value tensor = opOperand.get();
auto tensorType = tensor.getType().dyn_cast<TensorType>();
/// Return the buffer type for a given Value (tensor) after bufferization.
BaseMemRefType BufferizationState::getBufferType(Value value) const {
auto tensorType = value.getType().dyn_cast<TensorType>();
assert(tensorType && "unexpected non-tensor type");
if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
return toTensorOp.memref().getType().cast<BaseMemRefType>();
return getMemRefType(tensorType, getOptions());

View File

@ -276,14 +276,14 @@ static DenseSet<int64_t> getTensorIndices(ValueRange values) {
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
ValueRange yieldedValues,
const AnalysisState &state) {
unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
DenseSet<int64_t> result;
int64_t counter = 0;
for (const auto &it : llvm::zip(bbArgs, yieldedValues)) {
if (!std::get<0>(it).getType().isa<TensorType>())
for (unsigned int i = 0; i < minSize; ++i) {
if (!bbArgs[i].getType().isa<TensorType>() ||
!yieldedValues[i].getType().isa<TensorType>())
continue;
if (state.areEquivalentBufferizedValues(std::get<0>(it), std::get<1>(it)))
result.insert(counter);
counter++;
if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
result.insert(i);
}
return result;
}
@ -486,8 +486,6 @@ struct ForOpInterface
// The new memref init_args of the loop.
SmallVector<Value> initArgs =
getBuffers(rewriter, forOp.getIterOpOperands(), state);
if (initArgs.size() != indices.size())
return failure();
// Construct a new scf.for op with memref instead of tensor values.
auto newForOp = rewriter.create<scf::ForOp>(
@ -578,7 +576,16 @@ struct WhileOpInterface
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
auto whileOp = cast<scf::WhileOp>(op);
return {whileOp->getResult(opOperand.getOperandNumber())};
unsigned int idx = opOperand.getOperandNumber();
// The OpResults and OpOperands may not match. They may not even have the
// same type. The number of OpResults and OpOperands can also differ.
if (idx >= op->getNumResults() ||
opOperand.get().getType() != op->getResult(idx).getType())
return {};
// The only aliasing OpResult may be the one at the same index.
return {whileOp->getResult(idx)};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
@ -589,6 +596,13 @@ struct WhileOpInterface
unsigned int resultNumber = opResult.getResultNumber();
auto whileOp = cast<scf::WhileOp>(op);
// The "before" region bbArgs and the OpResults may not match.
if (resultNumber >= whileOp.getBeforeArguments().size())
return BufferRelation::None;
if (opResult.getType() !=
whileOp.getBeforeArguments()[resultNumber].getType())
return BufferRelation::None;
auto conditionOp = whileOp.getConditionOp();
BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
Value conditionOperand = conditionOp.getArgs()[resultNumber];
@ -627,9 +641,12 @@ struct WhileOpInterface
"regions with multiple blocks not supported");
Block *afterBody = &whileOp.getAfter().front();
// Indices of all iter_args that have tensor type. These are the ones that
// are bufferized.
DenseSet<int64_t> indices = getTensorIndices(whileOp.getInits());
// Indices of all bbArgs that have tensor type. These are the ones that
// are bufferized. The "before" and "after" regions may have different args.
DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
DenseSet<int64_t> indicesAfter =
getTensorIndices(whileOp.getAfterArguments());
// For every yielded value, is the value equivalent to its corresponding
// bbArg?
DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
@ -642,51 +659,64 @@ struct WhileOpInterface
// The new memref init_args of the loop.
SmallVector<Value> initArgs =
getBuffers(rewriter, whileOp->getOpOperands(), state);
if (initArgs.size() != indices.size())
return failure();
// The result types of a WhileOp are the same as the "after" bbArg types.
SmallVector<Type> argsTypesAfter = llvm::to_vector(
llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
return state.getBufferType(bbArg).cast<Type>();
}));
// Construct a new scf.while op with memref instead of tensor values.
ValueRange argsRange(initArgs);
TypeRange argsTypes(argsRange);
auto newWhileOp =
rewriter.create<scf::WhileOp>(whileOp.getLoc(), argsTypes, initArgs);
ValueRange argsRangeBefore(initArgs);
TypeRange argsTypesBefore(argsRangeBefore);
auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
argsTypesAfter, initArgs);
// Add before/after regions to the new op.
SmallVector<Location> bbArgLocs(initArgs.size(), whileOp.getLoc());
SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
whileOp.getLoc());
Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
newWhileOp.getBefore().addArguments(argsTypes, bbArgLocs);
newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
newWhileOp.getAfter().addArguments(argsTypes, bbArgLocs);
newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);
// Set up new iter_args and move the loop condition block to the new op.
// The old block uses tensors, so wrap the (memref) bbArgs of the new block
// in ToTensorOps.
rewriter.setInsertionPointToStart(newBeforeBody);
SmallVector<Value> newBeforeArgs = getBbArgReplacements(
rewriter, newWhileOp.getBeforeArguments(), indices);
rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);
// Update scf.condition of new loop.
auto newConditionOp = newWhileOp.getConditionOp();
rewriter.setInsertionPoint(newConditionOp);
// Only equivalent buffers or new buffer allocations may be yielded to the
// "after" region.
// TODO: This could be relaxed for better bufferization results.
SmallVector<Value> newConditionArgs =
getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypes, indices,
equivalentYieldsBefore, state);
getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
indicesAfter, equivalentYieldsBefore, state);
newConditionOp.getArgsMutable().assign(newConditionArgs);
// Set up new iter_args and move the loop body block to the new op.
// The old block uses tensors, so wrap the (memref) bbArgs of the new block
// in ToTensorOps.
rewriter.setInsertionPointToStart(newAfterBody);
SmallVector<Value> newAfterArgs =
getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(), indices);
SmallVector<Value> newAfterArgs = getBbArgReplacements(
rewriter, newWhileOp.getAfterArguments(), indicesAfter);
rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);
// Update scf.yield of the new loop.
auto newYieldOp = newWhileOp.getYieldOp();
rewriter.setInsertionPoint(newYieldOp);
// Only equivalent buffers or new buffer allocations may be yielded to the
// "before" region.
// TODO: This could be relaxed for better bufferization results.
SmallVector<Value> newYieldValues =
getYieldedValues(rewriter, newYieldOp.getResults(), argsTypes, indices,
equivalentYieldsAfter, state);
getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
indicesBefore, equivalentYieldsAfter, state);
newYieldOp.getResultsMutable().assign(newYieldValues);
// Replace loop results.

View File

@ -111,7 +111,7 @@ struct CollapseShapeOpInterface
auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
RankedTensorType tensorResultType = collapseShapeOp.getResultType();
OpOperand &srcOperand = collapseShapeOp->getOpOperand(0) /*src*/;
auto bufferType = state.getBufferType(srcOperand).cast<MemRefType>();
auto bufferType = state.getBufferType(srcOperand.get()).cast<MemRefType>();
if (tensorResultType.getRank() == 0) {
// 0-d collapses must go through a different op builder.

View File

@ -449,3 +449,36 @@ func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
// CHECK: return %[[loop]]#0, %[[loop]]#1
return %r0, %r1 : tensor<5xi1>, tensor<5xi1>
}
// -----
// CHECK-LABEL: func @scf_while_iter_arg_result_mismatch(
// CHECK-SAME: %[[arg0:.*]]: memref<5xi1, #{{.*}}>, %[[arg1:.*]]: memref<5xi1, #{{.*}}>
// CHECK: %[[alloc1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: scf.while (%[[arg3:.*]] = %[[arg1]]) : (memref<5xi1, #{{.*}}) -> () {
// CHECK: %[[load:.*]] = memref.load %[[arg0]]
// CHECK: scf.condition(%[[load]])
// CHECK: } do {
// CHECK: memref.copy %[[arg0]], %[[alloc2]]
// CHECK: memref.store %{{.*}}, %[[alloc2]]
// CHECK: memref.copy %[[alloc2]], %[[alloc1]]
// CHECK: %[[casted:.*]] = memref.cast %[[alloc1]] : memref<5xi1> to memref<5xi1, #{{.*}}>
// CHECK: scf.yield %[[casted]]
// CHECK: }
// CHECK-DAG: memref.dealloc %[[alloc1]]
// CHECK-DAG: memref.dealloc %[[alloc2]]
func.func @scf_while_iter_arg_result_mismatch(%arg0: tensor<5xi1>,
%arg1: tensor<5xi1>,
%arg2: index) {
scf.while (%arg3 = %arg1) : (tensor<5xi1>) -> () {
%0 = tensor.extract %arg0[%arg2] : tensor<5xi1>
scf.condition(%0)
} do {
%0 = "dummy.some_op"() : () -> index
%1 = "dummy.another_op"() : () -> i1
%2 = tensor.insert %1 into %arg0[%0] : tensor<5xi1>
scf.yield %2 : tensor<5xi1>
}
return
}