[mlir][VectorOps] Fail fast when a strided memref is passed to vector_transfer

Otherwise we'll silently miscompile things.

Differential Revision: https://reviews.llvm.org/D86951
This commit is contained in:
Benjamin Kramer 2020-09-01 17:21:27 +02:00
parent 21d02dc595
commit 2bf491c729
1 changed files with 25 additions and 15 deletions

View File

@ -1025,6 +1025,25 @@ public:
bool hasBoundedRewriteRecursion() const final { return true; }
};
/// Returns true if the memory underlying `memRefType` has a contiguous layout.
/// Strides are written to `strides`.
static bool isContiguous(MemRefType memRefType,
SmallVectorImpl<int64_t> &strides) {
int64_t offset;
auto successStrides = getStridesAndOffset(memRefType, strides, offset);
bool isContiguous = (strides.back() == 1);
if (isContiguous) {
auto sizes = memRefType.getShape();
for (int index = 0, e = strides.size() - 2; index < e; ++index) {
if (strides[index] != strides[index + 1] * sizes[index + 1]) {
isContiguous = false;
break;
}
}
}
return succeeded(successStrides) && isContiguous;
}
class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorTypeCastOpConversion(MLIRContext *context,
@ -1058,22 +1077,9 @@ public:
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return failure();
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides =
getStridesAndOffset(sourceMemRefType, strides, offset);
bool isContiguous = (strides.back() == 1);
if (isContiguous) {
auto sizes = sourceMemRefType.getShape();
for (int index = 0, e = strides.size() - 2; index < e; ++index) {
if (strides[index] != strides[index + 1] * sizes[index + 1]) {
isContiguous = false;
break;
}
}
}
// Only contiguous source tensors supported atm.
if (failed(successStrides) || !isContiguous)
SmallVector<int64_t, 4> strides;
if (!isContiguous(sourceMemRefType, strides))
return failure();
auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
@ -1141,6 +1147,10 @@ public:
xferOp.getVectorType().getRank(),
op->getContext()))
return failure();
// Only contiguous source tensors supported atm.
SmallVector<int64_t, 4> strides;
if (!isContiguous(xferOp.getMemRefType(), strides))
return failure();
auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };