forked from OSchip/llvm-project
[mlir][spirv] Fix crash in convert-gpu-to-spirv pass with memrefs with affine maps
spirv::getElementPtr can return null (for memrefs with affine map) but patterns didn't handle this. Differential Revision: https://reviews.llvm.org/D106988
This commit is contained in:
parent
8eaa05d061
commit
1e9799e204
|
@ -147,6 +147,7 @@ Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
|
|||
|
||||
/// Performs the index computation to get to the element at `indices` of the
|
||||
/// memory pointed to by `basePtr`, using the layout map of `baseType`.
|
||||
/// Returns null if index computation cannot be performed.
|
||||
|
||||
// TODO: This method assumes that the `baseType` is a MemRefType with AffineMap
|
||||
// that has static strides. Extend to handle dynamic strides.
|
||||
|
|
|
@ -268,6 +268,9 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
|
|||
spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
|
||||
loadOperands.indices(), loc, rewriter);
|
||||
|
||||
if (!accessChainOp)
|
||||
return failure();
|
||||
|
||||
int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
|
||||
bool isBool = srcBits == 1;
|
||||
if (isBool)
|
||||
|
@ -358,6 +361,10 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
|
|||
auto loadPtr = spirv::getElementPtr(
|
||||
*getTypeConverter<SPIRVTypeConverter>(), memrefType,
|
||||
loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
|
||||
|
||||
if (!loadPtr)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
|
||||
return success();
|
||||
}
|
||||
|
@ -376,6 +383,10 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
|
|||
spirv::AccessChainOp accessChainOp =
|
||||
spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
|
||||
storeOperands.indices(), loc, rewriter);
|
||||
|
||||
if (!accessChainOp)
|
||||
return failure();
|
||||
|
||||
int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
|
||||
|
||||
bool isBool = srcBits == 1;
|
||||
|
@ -467,6 +478,10 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
|
|||
spirv::getElementPtr(*getTypeConverter<SPIRVTypeConverter>(), memrefType,
|
||||
storeOperands.memref(), storeOperands.indices(),
|
||||
storeOp.getLoc(), rewriter);
|
||||
|
||||
if (!storePtr)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
|
||||
storeOperands.value());
|
||||
return success();
|
||||
|
|
Loading…
Reference in New Issue