forked from OSchip/llvm-project
[mlir] Fail early if AnalysisState::getBuffer() returns failure
This patch updates calls to AnalysisState::getBuffer() so that we return early with a failure if the call does not succeed. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D125251
This commit is contained in:
parent
93a8225da1
commit
53ff0daa7e
|
@ -115,7 +115,9 @@ struct CollapseShapeOpInterface
|
||||||
|
|
||||||
if (tensorResultType.getRank() == 0) {
|
if (tensorResultType.getRank() == 0) {
|
||||||
// 0-d collapses must go through a different op builder.
|
// 0-d collapses must go through a different op builder.
|
||||||
Value buffer = *state.getBuffer(rewriter, srcOperand);
|
auto buffer = state.getBuffer(rewriter, srcOperand);
|
||||||
|
if (failed(buffer))
|
||||||
|
return failure();
|
||||||
MemRefType resultType;
|
MemRefType resultType;
|
||||||
|
|
||||||
if (bufferType.getLayout().isIdentity()) {
|
if (bufferType.getLayout().isIdentity()) {
|
||||||
|
@ -138,7 +140,7 @@ struct CollapseShapeOpInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
|
replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
|
||||||
rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
|
rewriter, op, resultType, *buffer, collapseShapeOp.reassociation());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,11 +154,13 @@ struct CollapseShapeOpInterface
|
||||||
? None
|
? None
|
||||||
: Optional<BufferizationState::ForceInPlacability>(
|
: Optional<BufferizationState::ForceInPlacability>(
|
||||||
BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE);
|
BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE);
|
||||||
Value buffer = *state.getBuffer(rewriter, srcOperand, overrideInPlace);
|
auto buffer = state.getBuffer(rewriter, srcOperand, overrideInPlace);
|
||||||
|
if (failed(buffer))
|
||||||
|
return failure();
|
||||||
|
|
||||||
// Result type is inferred by the builder.
|
// Result type is inferred by the builder.
|
||||||
replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
|
replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
|
||||||
rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
|
rewriter, op, *buffer, collapseShapeOp.getReassociationIndices());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -183,8 +187,11 @@ struct DimOpInterface
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
BufferizationState &state) const {
|
||||||
auto dimOp = cast<tensor::DimOp>(op);
|
auto dimOp = cast<tensor::DimOp>(op);
|
||||||
Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
|
auto v = state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
|
||||||
replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
|
if (failed(v))
|
||||||
|
return failure();
|
||||||
|
replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
|
||||||
|
dimOp.index());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -219,13 +226,15 @@ struct ExpandShapeOpInterface
|
||||||
BufferizationState &state) const {
|
BufferizationState &state) const {
|
||||||
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
|
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
|
||||||
auto tensorResultType = expandShapeOp.getResultType();
|
auto tensorResultType = expandShapeOp.getResultType();
|
||||||
Value buffer =
|
auto buffer =
|
||||||
*state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
|
state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
|
||||||
|
if (failed(buffer))
|
||||||
|
return failure();
|
||||||
|
|
||||||
// Memref result type is inferred by the builder based on reassociation
|
// Memref result type is inferred by the builder based on reassociation
|
||||||
// indices and result shape.
|
// indices and result shape.
|
||||||
replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
|
replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
|
||||||
rewriter, op, tensorResultType.getShape(), buffer,
|
rewriter, op, tensorResultType.getShape(), *buffer,
|
||||||
expandShapeOp.getReassociationIndices());
|
expandShapeOp.getReassociationIndices());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -264,10 +273,12 @@ struct ExtractSliceOpInterface
|
||||||
|
|
||||||
// Even if this op was decided to bufferize out-of-place, do not insert the
|
// Even if this op was decided to bufferize out-of-place, do not insert the
|
||||||
// buffer copy yet. This is done later in this function.
|
// buffer copy yet. This is done later in this function.
|
||||||
Value srcMemref =
|
auto srcMemref =
|
||||||
*state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
|
state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
|
||||||
BufferizationState::ForceInPlacability::FORCE_INPLACE);
|
BufferizationState::ForceInPlacability::FORCE_INPLACE);
|
||||||
auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
|
if (failed(srcMemref))
|
||||||
|
return failure();
|
||||||
|
auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
|
||||||
auto dstTensorType =
|
auto dstTensorType =
|
||||||
extractSliceOp.result().getType().cast<RankedTensorType>();
|
extractSliceOp.result().getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
|
@ -289,7 +300,7 @@ struct ExtractSliceOpInterface
|
||||||
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
|
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
|
||||||
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
|
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
|
||||||
OffsetSizeAndStrideOpInterface::expandToRank(
|
OffsetSizeAndStrideOpInterface::expandToRank(
|
||||||
srcMemref, mixedOffsets, mixedSizes, mixedStrides,
|
*srcMemref, mixedOffsets, mixedSizes, mixedStrides,
|
||||||
[&](Value target, int64_t dim) -> OpFoldResult {
|
[&](Value target, int64_t dim) -> OpFoldResult {
|
||||||
auto shapedType = target.getType().cast<ShapedType>();
|
auto shapedType = target.getType().cast<ShapedType>();
|
||||||
if (shapedType.isDynamicDim(dim))
|
if (shapedType.isDynamicDim(dim))
|
||||||
|
@ -302,7 +313,7 @@ struct ExtractSliceOpInterface
|
||||||
mixedOffsets, mixedSizes, mixedStrides)
|
mixedOffsets, mixedSizes, mixedStrides)
|
||||||
.cast<MemRefType>();
|
.cast<MemRefType>();
|
||||||
Value subView = rewriter.create<memref::SubViewOp>(
|
Value subView = rewriter.create<memref::SubViewOp>(
|
||||||
loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
|
loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
|
||||||
mixedStrides);
|
mixedStrides);
|
||||||
|
|
||||||
// If not inplaceable, copy.
|
// If not inplaceable, copy.
|
||||||
|
@ -342,9 +353,11 @@ struct ExtractOpInterface
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
BufferizationState &state) const {
|
||||||
auto extractOp = cast<tensor::ExtractOp>(op);
|
auto extractOp = cast<tensor::ExtractOp>(op);
|
||||||
Value srcMemref =
|
auto srcMemref =
|
||||||
*state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
|
state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
|
||||||
replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
|
if (failed(srcMemref))
|
||||||
|
return failure();
|
||||||
|
replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
|
||||||
extractOp.indices());
|
extractOp.indices());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -703,10 +716,10 @@ struct InsertSliceOpInterface
|
||||||
|
|
||||||
// Copy tensor. If this tensor.insert_slice has a matching
|
// Copy tensor. If this tensor.insert_slice has a matching
|
||||||
// tensor.extract_slice, the copy operation will eventually fold away.
|
// tensor.extract_slice, the copy operation will eventually fold away.
|
||||||
Value srcMemref =
|
auto srcMemref =
|
||||||
*state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
|
state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
|
||||||
if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
|
if (failed(srcMemref) || failed(createMemCpy(rewriter, loc, *srcMemref,
|
||||||
state.getOptions())))
|
subView, state.getOptions())))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
|
replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
|
||||||
|
@ -736,9 +749,11 @@ struct RankOpInterface
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
BufferizationState &state) const {
|
||||||
auto rankOp = cast<tensor::RankOp>(op);
|
auto rankOp = cast<tensor::RankOp>(op);
|
||||||
Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
|
auto v = state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
|
||||||
|
if (failed(v))
|
||||||
|
return failure();
|
||||||
replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
|
replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
|
||||||
v);
|
*v);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue