[mlir][linalg][bufferize] Remove buffer equivalence from bufferize
Remove all function calls related to buffer equivalence from bufferize implementations.

Add a new PostAnalysisStep for scf.for that ensures that yielded values are equivalent to the corresponding BBArgs. (This was previously checked in `bufferize`.) This will be relaxed in a subsequent commit.

Note: This commit changes two test cases. These were broken by design and should not have passed. With the new scf.for PostAnalysisStep, this bug was fixed.

Differential Revision: https://reviews.llvm.org/D114927
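For illustration (not part of the original commit message), here is a minimal MLIR sketch of the property the new `scf_ext::AssertDestinationPassingStyle` post-analysis step checks: every tensor yielded from an `scf.for` body must bufferize to a buffer equivalent to its corresponding `iter_args` block argument. All values and types below are hypothetical.

```mlir
// Accepted: %w is an in-place update of the iter_arg %t (assuming the
// transfer_write bufferizes in place), so the yielded value and %t are
// equivalent after bufferization.
%r0 = scf.for %i = %c0 to %c10 step %c1 iter_args(%t = %A) -> (tensor<?xf32>) {
  %w = vector.transfer_write %v, %t[%i] : vector<4xf32>, tensor<?xf32>
  scf.yield %w : tensor<?xf32>
}

// Rejected: %B is unrelated to the iter_arg %t, so the yield does not
// bufferize to an equivalent buffer and the step emits "does not bufferize
// to an equivalent buffer to the matching enclosing scf::for operand".
%r1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%t = %A) -> (tensor<?xf32>) {
  scf.yield %B : tensor<?xf32>
}
```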
commit e9fb4dc9e9 (parent a96d828510)

@@ -19,6 +19,13 @@ namespace linalg {
 namespace comprehensive_bufferize {
 namespace scf_ext {
 
+/// Equivalence analysis for scf.for. Raise an error if iter_args are not
+/// equivalent to their corresponding loop yield values.
+struct AssertDestinationPassingStyle : public PostAnalysisStep {
+  LogicalResult run(FuncOp funcOp, BufferizationState &state,
+                    SmallVector<Operation *> &newOps) override;
+};
+
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
 
 } // namespace scf_ext

@@ -37,7 +37,6 @@ struct ConstantOpInterface
     auto globalMemref = globalCreator.getGlobalFor(constantOp);
     Value memref = b.create<memref::GetGlobalOp>(
         constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
-    state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
     state.mapBuffer(constantOp, memref);
 
     return success();

@@ -141,22 +141,7 @@ void BufferizationAliasInfo::setBufferizesToWritableMemory(Value v) {
 
 /// Return `true` if a value was marked as in-place bufferized.
 bool BufferizationAliasInfo::isInPlace(OpResult opResult) const {
-  bool inplace = inplaceBufferized.contains(opResult);
-#ifndef NDEBUG
-  if (inplace) {
-    auto bufferizableOp =
-        dyn_cast<BufferizableOpInterface>(opResult.getDefiningOp());
-    assert(bufferizableOp &&
-           "expected that in-place bufferized op is bufferizable");
-    SmallVector<OpOperand *> operands =
-        bufferizableOp.getAliasingOpOperand(opResult);
-    for (OpOperand *operand : operands)
-      assert(areAliasingBufferizedValues(operand->get(), opResult) &&
-             "expected that in-place bufferized OpResult aliases with "
-             "aliasing OpOperand");
-  }
-#endif // NDEBUG
-  return inplace;
+  return inplaceBufferized.contains(opResult);
 }
 
 /// Set the inPlace bufferization spec to true.

@@ -593,7 +578,6 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
   Value casted = allocated.getValue();
   if (memRefType && memRefType != allocMemRefType) {
     casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
-    aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
   }
 
   // 2. Create memory deallocation.

@@ -253,8 +253,6 @@ struct TiledLoopOpInterface
         return failure();
 
       // Insert mapping and aliasing info.
-      state.aliasInfo.createAliasInfoEntry(resultBuffer);
-      state.aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer);
       state.mapBuffer(opResult, resultBuffer);
 
       // Insert new operand and bbArg.

@@ -263,9 +261,6 @@ struct TiledLoopOpInterface
       body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType());
       BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex);
       // Insert mapping and aliasing info.
-      state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
-      state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
-                                                 newBufferBBArg);
       state.mapBuffer(oldTensorBBArg, newBufferBBArg);
 
       // Set operand of `linalg.yield` to the bbArg so it just canonicalizes

@@ -303,9 +298,6 @@ struct TiledLoopOpInterface
       BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex);
 
       // Insert mapping and aliasing info.
-      state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
-      state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
-                                                 newBufferBBArg);
       state.mapBuffer(oldTensorBBArg, newBufferBBArg);
 
       // Increment indices.

@@ -223,7 +223,6 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
                                              BufferizationState &state) {
   LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
   ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
-  BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
   // If nothing to do then we are done.
   if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&

@@ -321,15 +320,12 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
         auto castOp = b.create<memref::CastOp>(
             funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
         toMemrefOp.memref().replaceAllUsesWith(castOp);
-        aliasInfo.insertNewBufferEquivalence(castOp.dest(),
-                                             toMemrefOp.memref());
       }
     }
     // Replace all remaining uses by a to_tensor.
     if (!bbArg.use_empty()) {
       auto toTensorOp =
           b.create<bufferization::ToTensorOp>(funcOp.getLoc(), memref);
-      aliasInfo.insertNewBufferEquivalence(toTensorOp, bbArg);
       bbArg.replaceAllUsesWith(toTensorOp);
     }
     frontBlock.eraseArgument(0);

@@ -562,7 +558,6 @@ struct CallOpInterface
         Value buffer = state.lookupBuffer(callOp->getOperand(idx));
         // Add CallOp operand/result equivalence: this is interprocedural
         // info.
-        state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
         state.mapBuffer(oldRes, buffer);
         // Add a ToTensorOp to kill all uses of the CallOp return.
         // Replace all uses of the CallOp results so we can erase the CallOp.

@@ -572,7 +567,6 @@ struct CallOpInterface
             b.create<bufferization::ToTensorOp>(callOp.getLoc(), buffer);
         oldRes.replaceAllUsesWith(toTensorOp);
         // Add new op equivalence info.
-        state.aliasInfo.insertNewBufferEquivalence(toTensorOp, buffer);
         state.mapBuffer(toTensorOp, buffer);
         continue;
       }

@@ -615,7 +609,6 @@ struct CallOpInterface
         Value castBuffer =
             b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
         // Add new op equivalence info.
-        state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
         state.mapBuffer(tensorOperand, castBuffer);
         buffer = castBuffer;
       }

@@ -663,7 +656,6 @@ struct ReturnOpInterface
       Value returnTensor = b.create<bufferization::ToTensorOp>(
           returnOp.getLoc(), v);
       operand.set(returnTensor);
-      state.aliasInfo.insertNewBufferEquivalence(returnTensor, v);
       state.mapBuffer(returnTensor, v);
     }
     return success();

@@ -690,7 +682,6 @@ struct FuncOpInterface
                             : getContiguousOrUnrankedMemRefType(tensorType);
       Value bufferCast = b.create<bufferization::ToMemrefOp>(funcOp.getLoc(),
                                                              memRefType, bbArg);
-      state.aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
       state.mapBuffer(bbArg, bufferCast);
     }
 

@@ -147,7 +147,6 @@ struct IfOpInterface
       if (!resultBuffer)
         return failure();
 
-      state.aliasInfo.createAliasInfoEntry(resultBuffer);
       state.mapBuffer(opResult, resultBuffer);
     }
 

@@ -237,8 +236,6 @@ struct ForOpInterface
 
       OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
       BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
-      state.aliasInfo.createAliasInfoEntry(resultBuffer);
-      state.aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
       state.mapBuffer(bbArg, resultBuffer);
       state.mapBuffer(opResult, resultBuffer);
     }

@@ -257,15 +254,6 @@ struct ForOpInterface
       OpOperand &forOperand = forOp.getOpOperandForResult(
           forOp->getResult(operand.getOperandNumber()));
       auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-      if (!state.aliasInfo.areEquivalentBufferizedValues(operand.get(),
-                                                         bbArg)) {
-        // TODO: this could get resolved with copies but it can also turn into
-        // swaps so we need to be careful about order of copies.
-        return yieldOp->emitError()
-               << "Yield operand #" << operand.getOperandNumber()
-               << " does not bufferize to an equivalent buffer to the matching"
-               << " enclosing scf::for operand";
-      }
 
       // Buffers are equivalent so the work is already done and we just yield
       // the bbArg so that it later canonicalizes away.

@@ -275,6 +263,41 @@ struct ForOpInterface
   }
 };
 
+LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
+    AssertDestinationPassingStyle::run(FuncOp funcOp, BufferizationState &state,
+                                       SmallVector<Operation *> &newOps) {
+  LogicalResult status = success();
+  funcOp->walk([&](scf::YieldOp yieldOp) {
+    auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
+    if (!forOp)
+      return WalkResult::advance();
+
+    for (OpOperand &operand : yieldOp->getOpOperands()) {
+      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+      if (!tensorType)
+        continue;
+
+      OpOperand &forOperand = forOp.getOpOperandForResult(
+          forOp->getResult(operand.getOperandNumber()));
+      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+      if (!state.aliasInfo.areEquivalentBufferizedValues(operand.get(),
+                                                         bbArg)) {
+        // TODO: this could get resolved with copies but it can also turn into
+        // swaps so we need to be careful about order of copies.
+        status =
+            yieldOp->emitError()
+            << "Yield operand #" << operand.getOperandNumber()
+            << " does not bufferize to an equivalent buffer to the matching"
+            << " enclosing scf::for operand";
+        return WalkResult::interrupt();
+      }
+    }
+
+    return WalkResult::advance();
+  });
+  return status;
+}
+
 struct YieldOpInterface
     : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                     scf::YieldOp> {

@@ -80,7 +80,6 @@ struct CastOpInterface
         castOp.getResult().getType(), layout, memorySpace);
     Value res =
         b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
-    state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
     state.mapBuffer(castOp.getResult(), res);
     return success();
   }

@@ -233,7 +232,6 @@ struct InsertOpInterface
     b.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
                               insertOp.indices());
     state.mapBuffer(insertOp, destMemref);
-    state.aliasInfo.insertNewBufferAlias(insertOp, destMemref);
     return success();
   }
 

@@ -421,8 +419,6 @@ struct InsertSliceOpInterface
     Value subView = b.create<memref::SubViewOp>(
         loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
         insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
-    // Insert new alias.
-    state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
     // Copy tensor.
     Value srcMemref = state.lookupBuffer(insertSliceOp.source());
     state.options.allocationFns->memCpyFn(b, insertSliceOp.getLoc(),

@@ -96,6 +96,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
   // TODO: Find a way to enable this step automatically when bufferizing tensor
   // dialect ops.
   options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
+  options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
 
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);

@@ -1113,7 +1113,7 @@ func @reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
 
     // Read from %t1 via alias %e.
     %v2 = vector.transfer_read %e[%s], %cst : tensor<?xf32>, vector<5xf32>
-    scf.yield %e, %v2 : tensor<?xf32>, vector<5xf32>
+    scf.yield %t2, %v2 : tensor<?xf32>, vector<5xf32>
   }
   // CHECK: __inplace_results_attr__ = ["true", "false"]
 

@@ -1154,14 +1154,10 @@ func @non_reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
   // This loop does not read from %t1. It only writes to it.
   // CHECK: scf.for
   %r, %v3 = scf.for %i = %c0 to %s step %c1 iter_args(%t2 = %t1, %v0 = %v) -> (tensor<?xf32>, vector<5xf32>) {
-    // CHECK: tensor.extract_slice
-    // CHECK-SAME: __inplace_results_attr__ = ["true"]
-    %e = tensor.extract_slice %t2[%s][%s][1] : tensor<?xf32> to tensor<?xf32>
-
-    // Write to %t1 via alias. (Overwrite %t3.)
+    // Write to %t1 via %t2. (Overwrite %t3.)
     // CHECK: linalg.generic
     // CHECK-SAME: __inplace_results_attr__ = ["true"]
-    %o2 = linalg.generic #trait outs (%e : tensor<?xf32>) {
+    %o2 = linalg.generic #trait outs (%t2 : tensor<?xf32>) {
       ^bb(%0: f32) :
         linalg.yield %cst : f32
     } -> (tensor<?xf32>)

@@ -1172,8 +1168,8 @@ func @non_reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
   }
 
   // Use %t3 in some way without reading it, so that it does not get DCE'd.
-  // CHECK: linalg.generic
-  // CHECK-SAME: __inplace_results_attr__ = ["true"]
+  // CHECK: linalg.generic
+  // CHECK-SAME: __inplace_results_attr__ = ["true"]
   %o = linalg.generic #trait outs (%t3 : tensor<?xf32>) {
     ^bb(%0: f32) :
       linalg.yield %cst : f32