From 4a876b13fbba8ad9ad7375bd1fe22c7d71c4ad05 Mon Sep 17 00:00:00 2001 From: harsh Date: Tue, 8 Feb 2022 19:50:00 +0000 Subject: [PATCH] Add case to handle 0-D vectors in FlattenContiguousRowMajorTransferWritePattern and FlattenContiguousRowMajorTransferReadPattern. For 0-D as well as 1-D vectors, both these patterns should return a failure as there is no need to collapse the shape of the source. Currently, only 1-D vectors were handled. This patch handles the 0-D case as well. Reviewed By: Benoit, ThomasRaoux Differential Revision: https://reviews.llvm.org/D119202 --- .../Transforms/VectorTransferOpTransforms.cpp | 8 +++--- .../Vector/vector-transfer-flatten.mlir | 26 +++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index 40897ef53b05..2092d93f3ece 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -373,8 +373,8 @@ class FlattenContiguousRowMajorTransferReadPattern // Contiguity check is valid on tensors only. if (!sourceType) return failure(); - if (vectorType.getRank() == 1 && sourceType.getRank() == 1) - // Already 1D, nothing to do. + if (vectorType.getRank() <= 1) + // Already 0D/1D, nothing to do. return failure(); if (!isStaticShapeAndContiguousRowMajor(sourceType)) return failure(); @@ -425,8 +425,8 @@ class FlattenContiguousRowMajorTransferWritePattern // Contiguity check is valid on tensors only. if (!sourceType) return failure(); - if (vectorType.getRank() == 1 && sourceType.getRank() == 1) - // Already 1D, nothing to do. + if (vectorType.getRank() <= 1) + // Already 0D/1D, nothing to do. return failure(); if (!isStaticShapeAndContiguousRowMajor(sourceType)) return failure(); diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir index 68a6779461d6..65025b4c9b9c 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir @@ -33,3 +33,29 @@ func @transfer_write_flattenable_with_offset( // C-HECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8> // C-HECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]] +// ----- + +func @transfer_write_0d(%arg : memref, %vec : vector) { + vector.transfer_write %vec, %arg[] : vector, memref + return +} + +// CHECK-LABEL: func @transfer_write_0d +// CHECK-SAME: %[[ARG:.+]]: memref +// CHECK-SAME: %[[VEC:.+]]: vector +// CHECK: vector.transfer_write %[[VEC]], %[[ARG]][] : vector, memref +// CHECK: return + +// ----- + +func @transfer_read_0d(%arg : memref) -> vector { + %cst = arith.constant 0 : i8 + %0 = vector.transfer_read %arg[], %cst : memref, vector + return %0 : vector +} + +// CHECK-LABEL: func @transfer_read_0d +// CHECK-SAME: %[[ARG:.+]]: memref +// CHECK: %[[CST:.+]] = arith.constant 0 : i8 +// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref +// CHECK: return %[[READ]]