diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index 40897ef53b05..2092d93f3ece 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -373,8 +373,8 @@ class FlattenContiguousRowMajorTransferReadPattern // Contiguity check is valid on tensors only. if (!sourceType) return failure(); - if (vectorType.getRank() == 1 && sourceType.getRank() == 1) - // Already 1D, nothing to do. + if (vectorType.getRank() <= 1) + // Already 0D/1D, nothing to do. return failure(); if (!isStaticShapeAndContiguousRowMajor(sourceType)) return failure(); @@ -425,8 +425,8 @@ class FlattenContiguousRowMajorTransferWritePattern // Contiguity check is valid on tensors only. if (!sourceType) return failure(); - if (vectorType.getRank() == 1 && sourceType.getRank() == 1) - // Already 1D, nothing to do. + if (vectorType.getRank() <= 1) + // Already 0D/1D, nothing to do. return failure(); if (!isStaticShapeAndContiguousRowMajor(sourceType)) return failure(); diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir index 68a6779461d6..65025b4c9b9c 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir @@ -33,3 +33,29 @@ func @transfer_write_flattenable_with_offset( // C-HECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8> // C-HECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]] +// ----- + +func @transfer_write_0d(%arg : memref, %vec : vector) { + vector.transfer_write %vec, %arg[] : vector, memref + return +} + +// CHECK-LABEL: func @transfer_write_0d +// CHECK-SAME: %[[ARG:.+]]: memref +// CHECK-SAME: %[[VEC:.+]]: vector +// CHECK: vector.transfer_write %[[VEC]], %[[ARG]][] : vector, memref +// CHECK: return + +// ----- + +func @transfer_read_0d(%arg : memref) -> vector { + %cst = arith.constant 0 : i8 + %0 = vector.transfer_read %arg[], %cst : memref, vector + return %0 : vector +} + +// CHECK-LABEL: func @transfer_read_0d +// CHECK-SAME: %[[ARG:.+]]: memref +// CHECK: %[[CST:.+]] = arith.constant 0 : i8 +// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref +// CHECK: return %[[READ]]