[mlir][spirv] Fix crash in convert-gpu-to-spirv pass with memrefs with affine maps

spirv::getElementPtr can return null (for memrefs with affine map) but patterns didn't handle this.

Differential Revision: https://reviews.llvm.org/D106988
This commit is contained in:
Butygin 2021-07-28 22:31:26 +03:00
parent 8eaa05d061
commit 1e9799e204
2 changed files with 16 additions and 0 deletions

View File

@ -147,6 +147,7 @@ Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
/// Performs the index computation to get to the element at `indices` of the
/// memory pointed to by `basePtr`, using the layout map of `baseType`.
/// Returns null if index computation cannot be performed.
// TODO: This method assumes that the `baseType` is a MemRefType with AffineMap
// that has static strides. Extend to handle dynamic strides.

View File

@ -268,6 +268,9 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
loadOperands.indices(), loc, rewriter);
if (!accessChainOp)
return failure();
int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
bool isBool = srcBits == 1;
if (isBool)
@ -358,6 +361,10 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
auto loadPtr = spirv::getElementPtr(
*getTypeConverter<SPIRVTypeConverter>(), memrefType,
loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
if (!loadPtr)
return failure();
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
return success();
}
@ -376,6 +383,10 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
spirv::AccessChainOp accessChainOp =
spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
storeOperands.indices(), loc, rewriter);
if (!accessChainOp)
return failure();
int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
bool isBool = srcBits == 1;
@ -467,6 +478,10 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
spirv::getElementPtr(*getTypeConverter<SPIRVTypeConverter>(), memrefType,
storeOperands.memref(), storeOperands.indices(),
storeOp.getLoc(), rewriter);
if (!storePtr)
return failure();
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
storeOperands.value());
return success();