[mlir][linalg][bufferize] Remove buffer equivalence from bufferize

Remove all function calls related to buffer equivalence from bufferize implementations.

Add a new PostAnalysisStep for scf.for that ensures that yielded values are equivalent to their corresponding BBArgs. (This was previously checked in `bufferize`.) This requirement will be relaxed in a subsequent commit.
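
For illustration, a minimal sketch (not part of the patch; %t0, %lb, %ub, %step and %f are assumed to be defined elsewhere) of a loop that satisfies this requirement: the yielded value is an in-place update of the region iter_arg, so the yield operand and the BBArg bufferize to equivalent buffers, assuming the tensor.insert bufferizes in place.

  %r = scf.for %iv = %lb to %ub step %step iter_args(%t = %t0) -> (tensor<?xf32>) {
    // Updates the iter_arg in place; %u and %t share a buffer after bufferization.
    %u = tensor.insert %f into %t[%iv] : tensor<?xf32>
    scf.yield %u : tensor<?xf32>
  }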

Note: This commit changes two test cases. They yielded values that were not
equivalent to the corresponding BBArgs and should not have passed. The new
scf.for PostAnalysisStep catches this bug, and the tests have been updated
accordingly.

Differential Revision: https://reviews.llvm.org/D114927
Matthias Springer 2021-12-06 17:40:08 +09:00
parent a96d828510
commit e9fb4dc9e9
9 changed files with 49 additions and 60 deletions

@@ -19,6 +19,13 @@ namespace linalg {
namespace comprehensive_bufferize {
namespace scf_ext {
/// Equivalence analysis for scf.for. Raise an error if iter_args are not
/// equivalent to their corresponding loop yield values.
struct AssertDestinationPassingStyle : public PostAnalysisStep {
LogicalResult run(FuncOp funcOp, BufferizationState &state,
SmallVector<Operation *> &newOps) override;
};
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace scf_ext

@@ -37,7 +37,6 @@ struct ConstantOpInterface
auto globalMemref = globalCreator.getGlobalFor(constantOp);
Value memref = b.create<memref::GetGlobalOp>(
constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
state.mapBuffer(constantOp, memref);
return success();
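
(Aside, not part of the diff: roughly what ConstantOpInterface produces; the global's name is made up and the exact constant op may differ.)

  // Before bufferization:
  %cst = arith.constant dense<[1.0, 2.0]> : tensor<2xf32>
  // After bufferization, the value is read from a module-level global:
  memref.global "private" constant @__constant_2xf32 : memref<2xf32> = dense<[1.0, 2.0]>
  %m = memref.get_global @__constant_2xf32 : memref<2xf32>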

@@ -141,22 +141,7 @@ void BufferizationAliasInfo::setBufferizesToWritableMemory(Value v) {
/// Return `true` if a value was marked as in-place bufferized.
bool BufferizationAliasInfo::isInPlace(OpResult opResult) const {
bool inplace = inplaceBufferized.contains(opResult);
#ifndef NDEBUG
if (inplace) {
auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(opResult.getDefiningOp());
assert(bufferizableOp &&
"expected that in-place bufferized op is bufferizable");
SmallVector<OpOperand *> operands =
bufferizableOp.getAliasingOpOperand(opResult);
for (OpOperand *operand : operands)
assert(areAliasingBufferizedValues(operand->get(), opResult) &&
"expected that in-place bufferized OpResult aliases with "
"aliasing OpOperand");
}
#endif // NDEBUG
return inplace;
return inplaceBufferized.contains(opResult);
}
/// Set the inPlace bufferization spec to true.
@@ -593,7 +578,6 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
Value casted = allocated.getValue();
if (memRefType && memRefType != allocMemRefType) {
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
}
// 2. Create memory deallocation.

@@ -253,8 +253,6 @@ struct TiledLoopOpInterface
return failure();
// Insert mapping and aliasing info.
state.aliasInfo.createAliasInfoEntry(resultBuffer);
state.aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer);
state.mapBuffer(opResult, resultBuffer);
// Insert new operand and bbArg.
@@ -263,9 +261,6 @@ struct TiledLoopOpInterface
body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType());
BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex);
// Insert mapping and aliasing info.
state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
newBufferBBArg);
state.mapBuffer(oldTensorBBArg, newBufferBBArg);
// Set operand of `linalg.yield` to the bbArg so it just canonicalizes
@@ -303,9 +298,6 @@ struct TiledLoopOpInterface
BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex);
// Insert mapping and aliasing info.
state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
newBufferBBArg);
state.mapBuffer(oldTensorBBArg, newBufferBBArg);
// Increment indices.

@@ -223,7 +223,6 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
BufferizationState &state) {
LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
// If nothing to do then we are done.
if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
@@ -321,15 +320,12 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
auto castOp = b.create<memref::CastOp>(
funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
toMemrefOp.memref().replaceAllUsesWith(castOp);
aliasInfo.insertNewBufferEquivalence(castOp.dest(),
toMemrefOp.memref());
}
}
// Replace all remaining uses by a to_tensor.
if (!bbArg.use_empty()) {
auto toTensorOp =
b.create<bufferization::ToTensorOp>(funcOp.getLoc(), memref);
aliasInfo.insertNewBufferEquivalence(toTensorOp, bbArg);
bbArg.replaceAllUsesWith(toTensorOp);
}
frontBlock.eraseArgument(0);
@@ -562,7 +558,6 @@ struct CallOpInterface
Value buffer = state.lookupBuffer(callOp->getOperand(idx));
// Add CallOp operand/result equivalence: this is interprocedural
// info.
state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
state.mapBuffer(oldRes, buffer);
// Add a ToTensorOp to kill all uses of the CallOp return.
// Replace all uses of the CallOp results so we can erase the CallOp.
@@ -572,7 +567,6 @@ struct CallOpInterface
b.create<bufferization::ToTensorOp>(callOp.getLoc(), buffer);
oldRes.replaceAllUsesWith(toTensorOp);
// Add new op equivalence info.
state.aliasInfo.insertNewBufferEquivalence(toTensorOp, buffer);
state.mapBuffer(toTensorOp, buffer);
continue;
}
@@ -615,7 +609,6 @@ struct CallOpInterface
Value castBuffer =
b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
// Add new op equivalence info.
state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
state.mapBuffer(tensorOperand, castBuffer);
buffer = castBuffer;
}
@@ -663,7 +656,6 @@ struct ReturnOpInterface
Value returnTensor = b.create<bufferization::ToTensorOp>(
returnOp.getLoc(), v);
operand.set(returnTensor);
state.aliasInfo.insertNewBufferEquivalence(returnTensor, v);
state.mapBuffer(returnTensor, v);
}
return success();
@@ -690,7 +682,6 @@ struct FuncOpInterface
: getContiguousOrUnrankedMemRefType(tensorType);
Value bufferCast = b.create<bufferization::ToMemrefOp>(funcOp.getLoc(),
memRefType, bbArg);
state.aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
state.mapBuffer(bbArg, bufferCast);
}
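
(Aside, not part of the diff: a hedged sketch of the boundary rewrite performed here; types are simplified and the function name is made up. A tensor bbArg becomes a memref bbArg, and any remaining tensor uses of the old argument are routed through bufferization.to_tensor before the argument is erased.)

  // Before:
  func @fn(%t: tensor<?xf32>) {
    // ... tensor ops using %t ...
    return
  }
  // After the function boundary is bufferized:
  func @fn(%m: memref<?xf32>) {
    %t = bufferization.to_tensor %m : memref<?xf32>
    // ... not-yet-rewritten uses of the old bbArg now use %t ...
    return
  }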

@@ -147,7 +147,6 @@ struct IfOpInterface
if (!resultBuffer)
return failure();
state.aliasInfo.createAliasInfoEntry(resultBuffer);
state.mapBuffer(opResult, resultBuffer);
}
@@ -237,8 +236,6 @@ struct ForOpInterface
OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
state.aliasInfo.createAliasInfoEntry(resultBuffer);
state.aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
state.mapBuffer(bbArg, resultBuffer);
state.mapBuffer(opResult, resultBuffer);
}
@@ -257,15 +254,6 @@ struct ForOpInterface
OpOperand &forOperand = forOp.getOpOperandForResult(
forOp->getResult(operand.getOperandNumber()));
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
if (!state.aliasInfo.areEquivalentBufferizedValues(operand.get(),
bbArg)) {
// TODO: this could get resolved with copies but it can also turn into
// swaps so we need to be careful about order of copies.
return yieldOp->emitError()
<< "Yield operand #" << operand.getOperandNumber()
<< " does not bufferize to an equivalent buffer to the matching"
<< " enclosing scf::for operand";
}
// Buffers are equivalent so the work is already done and we just yield
// the bbArg so that it later canonicalizes away.
@@ -275,6 +263,41 @@
}
};
LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
AssertDestinationPassingStyle::run(FuncOp funcOp, BufferizationState &state,
SmallVector<Operation *> &newOps) {
LogicalResult status = success();
funcOp->walk([&](scf::YieldOp yieldOp) {
auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
if (!forOp)
return WalkResult::advance();
for (OpOperand &operand : yieldOp->getOpOperands()) {
auto tensorType = operand.get().getType().dyn_cast<TensorType>();
if (!tensorType)
continue;
OpOperand &forOperand = forOp.getOpOperandForResult(
forOp->getResult(operand.getOperandNumber()));
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
if (!state.aliasInfo.areEquivalentBufferizedValues(operand.get(),
bbArg)) {
// TODO: this could get resolved with copies but it can also turn into
// swaps so we need to be careful about order of copies.
status =
yieldOp->emitError()
<< "Yield operand #" << operand.getOperandNumber()
<< " does not bufferize to an equivalent buffer to the matching"
<< " enclosing scf::for operand";
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
return status;
}
struct YieldOpInterface
: public BufferizableOpInterface::ExternalModel<YieldOpInterface,
scf::YieldOp> {
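
(Aside, not part of the diff: a made-up example of IR that the new step rejects. Yielding a slice of the iter_arg aliases the loop buffer but is not equivalent to it, so the step emits the error added above; %t0, %lb, %ub, %step and %sz are assumed to be defined elsewhere.)

  %r = scf.for %iv = %lb to %ub step %step iter_args(%t = %t0) -> (tensor<?xf32>) {
    %e = tensor.extract_slice %t[0][%sz][1] : tensor<?xf32> to tensor<?xf32>
    // error: Yield operand #0 does not bufferize to an equivalent buffer to
    // the matching enclosing scf::for operand
    scf.yield %e : tensor<?xf32>
  }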

@@ -80,7 +80,6 @@ struct CastOpInterface
castOp.getResult().getType(), layout, memorySpace);
Value res =
b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
state.mapBuffer(castOp.getResult(), res);
return success();
}
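
(Aside, not part of the diff: the shape of this rewrite, with illustrative types and names; the actual result type may carry a layout map and memory space.)

  // %r = tensor.cast %t : tensor<4xf32> to tensor<?xf32> bufferizes to roughly:
  %rm = memref.cast %m : memref<4xf32> to memref<?xf32>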
@@ -233,7 +232,6 @@ struct InsertOpInterface
b.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
insertOp.indices());
state.mapBuffer(insertOp, destMemref);
state.aliasInfo.insertNewBufferAlias(insertOp, destMemref);
return success();
}
@@ -421,8 +419,6 @@ struct InsertSliceOpInterface
Value subView = b.create<memref::SubViewOp>(
loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
// Insert new alias.
state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
// Copy tensor.
Value srcMemref = state.lookupBuffer(insertSliceOp.source());
state.options.allocationFns->memCpyFn(b, insertSliceOp.getLoc(),
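
(Aside, not part of the diff: a hedged sketch of this rewrite; the layout map and buffer names are illustrative. The destination buffer is sliced with a matching memref.subview and the source buffer is then copied into that view through the configured memCpyFn.)

  #map = affine_map<(d0)[s0] -> (d0 + s0)>
  // tensor.insert_slice %src into %dst[%off][%sz][1] becomes roughly:
  %sv = memref.subview %dst_buf[%off][%sz][1] : memref<?xf32> to memref<?xf32, #map>
  // ... followed by a copy of %src_buf into %sv ...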

@@ -96,6 +96,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
// TODO: Find a way to enable this step automatically when bufferizing tensor
// dialect ops.
options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
ModuleOp moduleOp = getOperation();
applyEnablingTransformations(moduleOp);

@@ -1113,7 +1113,7 @@ func @reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
// Read from %t1 via alias %e.
%v2 = vector.transfer_read %e[%s], %cst : tensor<?xf32>, vector<5xf32>
scf.yield %e, %v2 : tensor<?xf32>, vector<5xf32>
scf.yield %t2, %v2 : tensor<?xf32>, vector<5xf32>
}
// CHECK: __inplace_results_attr__ = ["true", "false"]
@@ -1154,14 +1154,10 @@ func @non_reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
// This loop does not read from %t1. It only writes to it.
// CHECK: scf.for
%r, %v3 = scf.for %i = %c0 to %s step %c1 iter_args(%t2 = %t1, %v0 = %v) -> (tensor<?xf32>, vector<5xf32>) {
// CHECK: tensor.extract_slice
// CHECK-SAME: __inplace_results_attr__ = ["true"]
%e = tensor.extract_slice %t2[%s][%s][1] : tensor<?xf32> to tensor<?xf32>
// Write to %t1 via alias. (Overwrite %t3.)
// Write to %t1 via %t2. (Overwrite %t3.)
// CHECK: linalg.generic
// CHECK-SAME: __inplace_results_attr__ = ["true"]
%o2 = linalg.generic #trait outs (%e : tensor<?xf32>) {
%o2 = linalg.generic #trait outs (%t2 : tensor<?xf32>) {
^bb(%0: f32) :
linalg.yield %cst : f32
} -> (tensor<?xf32>)
@@ -1172,8 +1168,8 @@ func @non_reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
}
// Use %t3 in some way without reading it, so that it does not get DCE'd.
// CHECK: linalg.generic
// CHECK-SAME: __inplace_results_attr__ = ["true"]
// CHECK: linalg.generic
// CHECK-SAME: __inplace_results_attr__ = ["true"]
%o = linalg.generic #trait outs (%t3 : tensor<?xf32>) {
^bb(%0: f32) :
linalg.yield %cst : f32