forked from OSchip/llvm-project
[mlir][Bufferize] NFC - Introduce areCastCompatible assertions to catch misformed CastOp early
Differential Revision: https://reviews.llvm.org/D116893
This commit is contained in:
parent
1ce01b7dfe
commit
9ba25ec92d
|
@ -549,6 +549,9 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
|
|||
return failure();
|
||||
Value casted = allocated.getValue();
|
||||
if (memRefType && memRefType != allocMemRefType) {
|
||||
assert(memref::CastOp::areCastCompatible(allocated.getValue().getType(),
|
||||
memRefType) &&
|
||||
"createAlloc: cast incompatible");
|
||||
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
|
||||
}
|
||||
|
||||
|
|
|
@ -77,9 +77,13 @@ struct ToMemrefOpInterface
|
|||
|
||||
// Insert cast in case to_memref(to_tensor(x))'s type is different from
|
||||
// x's type.
|
||||
if (toTensorOp.memref().getType() != toMemrefOp.getType())
|
||||
if (toTensorOp.memref().getType() != toMemrefOp.getType()) {
|
||||
assert(memref::CastOp::areCastCompatible(buffer.getType(),
|
||||
toMemrefOp.getType()) &&
|
||||
"ToMemrefOp::bufferize : cast incompatible");
|
||||
buffer = rewriter.create<memref::CastOp>(toMemrefOp.getLoc(), buffer,
|
||||
toMemrefOp.getType());
|
||||
}
|
||||
replaceOpWithBufferizedValues(rewriter, toMemrefOp, buffer);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -386,7 +386,10 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
|
|||
// Replace all uses of bbArg through a ToMemRefOp by a memref::CastOp.
|
||||
for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) {
|
||||
if (auto toMemrefOp =
|
||||
dyn_cast<bufferization::ToMemrefOp>(use.getOwner())) {
|
||||
dyn_cast<bufferization::ToMemrefOp>(use.getOwner())) {
|
||||
assert(memref::CastOp::areCastCompatible(
|
||||
memref.getType(), toMemrefOp.memref().getType()) &&
|
||||
"bufferizeFuncOpBoundary: cast incompatible");
|
||||
auto castOp = b.create<memref::CastOp>(
|
||||
funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
|
||||
toMemrefOp.memref().replaceAllUsesWith(castOp);
|
||||
|
@ -525,6 +528,8 @@ static void layoutPostProcessing(ModuleOp moduleOp) {
|
|||
bbArg.setType(desiredMemrefType);
|
||||
OpBuilder b(bbArg.getContext());
|
||||
b.setInsertionPointToStart(bbArg.getOwner());
|
||||
assert(memref::CastOp::areCastCompatible(bbArg.getType(), memrefType) &&
|
||||
"layoutPostProcessing: cast incompatible");
|
||||
// Cast back to the original memrefType and let it canonicalize.
|
||||
Value cast =
|
||||
b.create<memref::CastOp>(funcOp.getLoc(), memrefType, bbArg);
|
||||
|
@ -537,6 +542,10 @@ static void layoutPostProcessing(ModuleOp moduleOp) {
|
|||
// such cases.
|
||||
auto castArg = [&](Operation *caller) {
|
||||
OpBuilder b(caller);
|
||||
assert(
|
||||
memref::CastOp::areCastCompatible(
|
||||
caller->getOperand(argNumber).getType(), desiredMemrefType) &&
|
||||
"layoutPostProcessing.2: cast incompatible");
|
||||
Value newOperand = b.create<memref::CastOp>(
|
||||
funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber));
|
||||
operandsPerCaller.find(caller)->getSecond().push_back(newOperand);
|
||||
|
@ -703,6 +712,9 @@ struct CallOpInterface
|
|||
// that will either canonicalize away or fail compilation until we can do
|
||||
// something better.
|
||||
if (buffer.getType() != memRefType) {
|
||||
assert(
|
||||
memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
|
||||
"CallOp::bufferize: cast incompatible");
|
||||
Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
|
||||
memRefType, buffer);
|
||||
buffer = castBuffer;
|
||||
|
|
|
@ -77,6 +77,9 @@ struct CastOpInterface
|
|||
}
|
||||
|
||||
// Replace the op with a memref.cast.
|
||||
assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
|
||||
resultMemRefType) &&
|
||||
"CallOp::bufferize: cast incompatible");
|
||||
replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
|
||||
*resultBuffer);
|
||||
|
||||
|
|
Loading…
Reference in New Issue