[mlir] Fail early if AnalysisState::getBuffer() returns failure

This patch updates calls to AnalysisState::getBuffer() so that we return
early with a failure if the call does not succeed.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D125251
This commit is contained in:
Ashay Rane 2022-05-09 11:22:43 -07:00
parent 93a8225da1
commit 53ff0daa7e
No known key found for this signature in database
GPG Key ID: 0DF50B3E307F5706
1 changed files with 39 additions and 24 deletions

View File

@ -115,7 +115,9 @@ struct CollapseShapeOpInterface
if (tensorResultType.getRank() == 0) {
// 0-d collapses must go through a different op builder.
Value buffer = *state.getBuffer(rewriter, srcOperand);
auto buffer = state.getBuffer(rewriter, srcOperand);
if (failed(buffer))
return failure();
MemRefType resultType;
if (bufferType.getLayout().isIdentity()) {
@ -138,7 +140,7 @@ struct CollapseShapeOpInterface
}
replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
rewriter, op, resultType, *buffer, collapseShapeOp.reassociation());
return success();
}
@ -152,11 +154,13 @@ struct CollapseShapeOpInterface
? None
: Optional<BufferizationState::ForceInPlacability>(
BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE);
Value buffer = *state.getBuffer(rewriter, srcOperand, overrideInPlace);
auto buffer = state.getBuffer(rewriter, srcOperand, overrideInPlace);
if (failed(buffer))
return failure();
// Result type is inferred by the builder.
replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
rewriter, op, *buffer, collapseShapeOp.getReassociationIndices());
return success();
}
};
@ -183,8 +187,11 @@ struct DimOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto dimOp = cast<tensor::DimOp>(op);
Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
auto v = state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
if (failed(v))
return failure();
replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
dimOp.index());
return success();
}
};
@ -219,13 +226,15 @@ struct ExpandShapeOpInterface
BufferizationState &state) const {
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
auto tensorResultType = expandShapeOp.getResultType();
Value buffer =
*state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
auto buffer =
state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
if (failed(buffer))
return failure();
// Memref result type is inferred by the builder based on reassociation
// indices and result shape.
replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
rewriter, op, tensorResultType.getShape(), buffer,
rewriter, op, tensorResultType.getShape(), *buffer,
expandShapeOp.getReassociationIndices());
return success();
}
@ -264,10 +273,12 @@ struct ExtractSliceOpInterface
// Even if this op was decided to bufferize out-of-place, do not insert the
// buffer copy yet. This is done later in this function.
Value srcMemref =
*state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
BufferizationState::ForceInPlacability::FORCE_INPLACE);
auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
auto srcMemref =
state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
BufferizationState::ForceInPlacability::FORCE_INPLACE);
if (failed(srcMemref))
return failure();
auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
auto dstTensorType =
extractSliceOp.result().getType().cast<RankedTensorType>();
@ -289,7 +300,7 @@ struct ExtractSliceOpInterface
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
OffsetSizeAndStrideOpInterface::expandToRank(
srcMemref, mixedOffsets, mixedSizes, mixedStrides,
*srcMemref, mixedOffsets, mixedSizes, mixedStrides,
[&](Value target, int64_t dim) -> OpFoldResult {
auto shapedType = target.getType().cast<ShapedType>();
if (shapedType.isDynamicDim(dim))
@ -302,7 +313,7 @@ struct ExtractSliceOpInterface
mixedOffsets, mixedSizes, mixedStrides)
.cast<MemRefType>();
Value subView = rewriter.create<memref::SubViewOp>(
loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
mixedStrides);
// If not inplaceable, copy.
@ -342,9 +353,11 @@ struct ExtractOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto extractOp = cast<tensor::ExtractOp>(op);
Value srcMemref =
*state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
auto srcMemref =
state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
if (failed(srcMemref))
return failure();
replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
extractOp.indices());
return success();
}
@ -703,10 +716,10 @@ struct InsertSliceOpInterface
// Copy tensor. If this tensor.insert_slice has a matching
// tensor.extract_slice, the copy operation will eventually fold away.
Value srcMemref =
*state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
state.getOptions())))
auto srcMemref =
state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
if (failed(srcMemref) || failed(createMemCpy(rewriter, loc, *srcMemref,
subView, state.getOptions())))
return failure();
replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
@ -736,9 +749,11 @@ struct RankOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto rankOp = cast<tensor::RankOp>(op);
Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
auto v = state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
if (failed(v))
return failure();
replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
v);
*v);
return success();
}
};