forked from OSchip/llvm-project
[mlir] Fail early if AnalysisState::getBuffer() returns failure
This patch updates calls to AnalysisState::getBuffer() so that we return early with a failure if the call does not succeed. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D125251
This commit is contained in:
parent
93a8225da1
commit
53ff0daa7e
|
@ -115,7 +115,9 @@ struct CollapseShapeOpInterface
|
|||
|
||||
if (tensorResultType.getRank() == 0) {
|
||||
// 0-d collapses must go through a different op builder.
|
||||
Value buffer = *state.getBuffer(rewriter, srcOperand);
|
||||
auto buffer = state.getBuffer(rewriter, srcOperand);
|
||||
if (failed(buffer))
|
||||
return failure();
|
||||
MemRefType resultType;
|
||||
|
||||
if (bufferType.getLayout().isIdentity()) {
|
||||
|
@ -138,7 +140,7 @@ struct CollapseShapeOpInterface
|
|||
}
|
||||
|
||||
replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
|
||||
rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
|
||||
rewriter, op, resultType, *buffer, collapseShapeOp.reassociation());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -152,11 +154,13 @@ struct CollapseShapeOpInterface
|
|||
? None
|
||||
: Optional<BufferizationState::ForceInPlacability>(
|
||||
BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE);
|
||||
Value buffer = *state.getBuffer(rewriter, srcOperand, overrideInPlace);
|
||||
auto buffer = state.getBuffer(rewriter, srcOperand, overrideInPlace);
|
||||
if (failed(buffer))
|
||||
return failure();
|
||||
|
||||
// Result type is inferred by the builder.
|
||||
replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
|
||||
rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
|
||||
rewriter, op, *buffer, collapseShapeOp.getReassociationIndices());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -183,8 +187,11 @@ struct DimOpInterface
|
|||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
auto dimOp = cast<tensor::DimOp>(op);
|
||||
Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
|
||||
replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
|
||||
auto v = state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
|
||||
if (failed(v))
|
||||
return failure();
|
||||
replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
|
||||
dimOp.index());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -219,13 +226,15 @@ struct ExpandShapeOpInterface
|
|||
BufferizationState &state) const {
|
||||
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
|
||||
auto tensorResultType = expandShapeOp.getResultType();
|
||||
Value buffer =
|
||||
*state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
|
||||
auto buffer =
|
||||
state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
|
||||
if (failed(buffer))
|
||||
return failure();
|
||||
|
||||
// Memref result type is inferred by the builder based on reassociation
|
||||
// indices and result shape.
|
||||
replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
|
||||
rewriter, op, tensorResultType.getShape(), buffer,
|
||||
rewriter, op, tensorResultType.getShape(), *buffer,
|
||||
expandShapeOp.getReassociationIndices());
|
||||
return success();
|
||||
}
|
||||
|
@ -264,10 +273,12 @@ struct ExtractSliceOpInterface
|
|||
|
||||
// Even if this op was decided to bufferize out-of-place, do not insert the
|
||||
// buffer copy yet. This is done later in this function.
|
||||
Value srcMemref =
|
||||
*state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
|
||||
BufferizationState::ForceInPlacability::FORCE_INPLACE);
|
||||
auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
|
||||
auto srcMemref =
|
||||
state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
|
||||
BufferizationState::ForceInPlacability::FORCE_INPLACE);
|
||||
if (failed(srcMemref))
|
||||
return failure();
|
||||
auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
|
||||
auto dstTensorType =
|
||||
extractSliceOp.result().getType().cast<RankedTensorType>();
|
||||
|
||||
|
@ -289,7 +300,7 @@ struct ExtractSliceOpInterface
|
|||
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
|
||||
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
|
||||
OffsetSizeAndStrideOpInterface::expandToRank(
|
||||
srcMemref, mixedOffsets, mixedSizes, mixedStrides,
|
||||
*srcMemref, mixedOffsets, mixedSizes, mixedStrides,
|
||||
[&](Value target, int64_t dim) -> OpFoldResult {
|
||||
auto shapedType = target.getType().cast<ShapedType>();
|
||||
if (shapedType.isDynamicDim(dim))
|
||||
|
@ -302,7 +313,7 @@ struct ExtractSliceOpInterface
|
|||
mixedOffsets, mixedSizes, mixedStrides)
|
||||
.cast<MemRefType>();
|
||||
Value subView = rewriter.create<memref::SubViewOp>(
|
||||
loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
|
||||
loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
|
||||
mixedStrides);
|
||||
|
||||
// If not inplaceable, copy.
|
||||
|
@ -342,9 +353,11 @@ struct ExtractOpInterface
|
|||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
auto extractOp = cast<tensor::ExtractOp>(op);
|
||||
Value srcMemref =
|
||||
*state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
|
||||
replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
|
||||
auto srcMemref =
|
||||
state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
|
||||
if (failed(srcMemref))
|
||||
return failure();
|
||||
replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
|
||||
extractOp.indices());
|
||||
return success();
|
||||
}
|
||||
|
@ -703,10 +716,10 @@ struct InsertSliceOpInterface
|
|||
|
||||
// Copy tensor. If this tensor.insert_slice has a matching
|
||||
// tensor.extract_slice, the copy operation will eventually fold away.
|
||||
Value srcMemref =
|
||||
*state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
|
||||
if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
|
||||
state.getOptions())))
|
||||
auto srcMemref =
|
||||
state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
|
||||
if (failed(srcMemref) || failed(createMemCpy(rewriter, loc, *srcMemref,
|
||||
subView, state.getOptions())))
|
||||
return failure();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
|
||||
|
@ -736,9 +749,11 @@ struct RankOpInterface
|
|||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
auto rankOp = cast<tensor::RankOp>(op);
|
||||
Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
|
||||
auto v = state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
|
||||
if (failed(v))
|
||||
return failure();
|
||||
replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
|
||||
v);
|
||||
*v);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue