[mlir][VectorOps] Fail fast when a strided memref is passed to vector_transfer

Otherwise we'll silently miscompile things.

Differential Revision: https://reviews.llvm.org/D86951
This commit is contained in: parent 21d02dc595, commit 2bf491c729
```diff
@@ -1025,6 +1025,25 @@ public:
   bool hasBoundedRewriteRecursion() const final { return true; }
 };
 
+/// Returns true if the memory underlying `memRefType` has a contiguous layout.
+/// Strides are written to `strides`.
+static bool isContiguous(MemRefType memRefType,
+                         SmallVectorImpl<int64_t> &strides) {
+  int64_t offset;
+  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
+  bool isContiguous = (strides.back() == 1);
+  if (isContiguous) {
+    auto sizes = memRefType.getShape();
+    for (int index = 0, e = strides.size() - 2; index < e; ++index) {
+      if (strides[index] != strides[index + 1] * sizes[index + 1]) {
+        isContiguous = false;
+        break;
+      }
+    }
+  }
+  return succeeded(successStrides) && isContiguous;
+}
+
 class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorTypeCastOpConversion(MLIRContext *context,
```
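For reference, a minimal standalone sketch of the contiguity test the new helper performs. The free-standing `isContiguousLayout`, the plain `std::vector` shape/stride arguments, and the `main` driver are illustrative stand-ins; the real helper operates on MLIR's `MemRefType` via `getStridesAndOffset`:

```cpp
// Minimal sketch of the contiguity predicate added above (illustrative names;
// the real helper works on MemRefType, not raw vectors).
#include <cstdint>
#include <vector>

static bool isContiguousLayout(const std::vector<int64_t> &sizes,
                               const std::vector<int64_t> &strides) {
  // The innermost dimension must be unit-stride...
  bool contiguous = !strides.empty() && strides.back() == 1;
  // ...and, mirroring the loop in isContiguous above, each checked outer
  // stride must equal the next stride times the next size.
  if (contiguous) {
    for (int64_t i = 0, e = static_cast<int64_t>(strides.size()) - 2; i < e;
         ++i) {
      if (strides[i] != strides[i + 1] * sizes[i + 1]) {
        contiguous = false;
        break;
      }
    }
  }
  return contiguous;
}

int main() {
  // memref<2x4x8xf32> with the identity layout: strides [32, 8, 1].
  bool dense = isContiguousLayout({2, 4, 8}, {32, 8, 1}); // true
  // A subview that skips every other plane: strides [64, 8, 1]. The lowering
  // must now reject this case instead of reading the wrong elements.
  bool strided = isContiguousLayout({2, 4, 8}, {64, 8, 1}); // false
  return (dense && !strided) ? 0 : 1;
}
```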
```diff
@@ -1058,22 +1077,9 @@ public:
     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
       return failure();
 
-    int64_t offset;
-    SmallVector<int64_t, 4> strides;
-    auto successStrides =
-        getStridesAndOffset(sourceMemRefType, strides, offset);
-    bool isContiguous = (strides.back() == 1);
-    if (isContiguous) {
-      auto sizes = sourceMemRefType.getShape();
-      for (int index = 0, e = strides.size() - 2; index < e; ++index) {
-        if (strides[index] != strides[index + 1] * sizes[index + 1]) {
-          isContiguous = false;
-          break;
-        }
-      }
-    }
     // Only contiguous source tensors supported atm.
-    if (failed(successStrides) || !isContiguous)
+    SmallVector<int64_t, 4> strides;
+    if (!isContiguous(sourceMemRefType, strides))
       return failure();
 
     auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
```
```diff
@@ -1141,6 +1147,10 @@ public:
             xferOp.getVectorType().getRank(),
             op->getContext()))
       return failure();
+    // Only contiguous source tensors supported atm.
+    SmallVector<int64_t, 4> strides;
+    if (!isContiguous(xferOp.getMemRefType(), strides))
+      return failure();
 
     auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
```
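To make the "silently miscompile" failure mode concrete, here is an illustrative sketch (hypothetical buffer and indices, not the lowering's actual address computation): an access linearized under a dense-layout assumption reads the wrong element from a strided view.

```cpp
// Illustration of the bug class this commit guards against: indexing a
// strided view with dense-layout arithmetic silently yields wrong data.
#include <cstdint>
#include <cstdio>

int main() {
  // A 4x16 row-major buffer; a strided subview sees it as 4x8 with
  // strides [16, 1] rather than the dense [8, 1].
  int32_t buf[4 * 16];
  for (int32_t i = 0; i < 4 * 16; ++i)
    buf[i] = i;

  int row = 1, col = 4;
  int32_t correct = buf[row * 16 + col]; // strided address: element 20
  int32_t wrong = buf[row * 8 + col];    // dense assumption: element 12
  std::printf("correct=%d, dense-assumption=%d\n", correct, wrong);
  return correct == wrong; // nonzero would mean the illustration failed
}
```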