From 93284120f28c82503138f3e594358349ed0ab37f Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Mon, 22 Nov 2021 12:28:39 -0500 Subject: [PATCH] [mlir][vector] Fix TransferOpReduceRank for 0-D tensors We cannot unconditionally generate memref.load ops for such cases; need to check the source's type. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D114376 --- ...rTransferPermutationMapRewritePatterns.cpp | 12 ++++++++--- .../vector-transfer-to-vector-load-store.mlir | 20 +++++++++++++++++-- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp index 3f5c3127a286..a27ebfc8e5c6 100644 --- a/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp @@ -224,9 +224,15 @@ struct TransferOpReduceRank : public OpRewritePattern { // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097. // In the meantime, lower these to a scalar load when they pop up. if (reducedShapeRank == 0) { - Value newRead = rewriter.create( - op.getLoc(), originalVecType.getElementType(), op.source(), - op.indices()); + Value newRead; + if (op.getShapedType().isa()) { + newRead = rewriter.create(op.getLoc(), op.source(), + op.indices()); + } else { + newRead = rewriter.create( + op.getLoc(), originalVecType.getElementType(), op.source(), + op.indices()); + } rewriter.replaceOpWithNewOp(op, originalVecType, newRead); return success(); diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir index 866d791c7c19..a5c0cb584b11 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -1,9 +1,9 @@ // RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -canonicalize -split-input-file | FileCheck %s -// CHECK-LABEL: func @vector_transfer_ops_0d( +// CHECK-LABEL: func @vector_transfer_ops_0d_memref( // CHECK-SAME: %[[MEM:.*]]: memref // CHECK-SAME: %[[VV:.*]]: vector<1x1x1xf32> -func @vector_transfer_ops_0d(%M: memref, %v: vector<1x1x1xf32>) { +func @vector_transfer_ops_0d_memref(%M: memref, %v: vector<1x1x1xf32>) { %f0 = arith.constant 0.0 : f32 // CHECK-NEXT: %[[V:.*]] = memref.load %[[MEM]][] : memref @@ -23,6 +23,22 @@ func @vector_transfer_ops_0d(%M: memref, %v: vector<1x1x1xf32>) { // ----- +// CHECK-LABEL: func @vector_transfer_ops_0d_tensor( +// CHECK-SAME: %[[SOURCE:.*]]: tensor +func @vector_transfer_ops_0d_tensor(%M: tensor) -> vector<1xf32> { + %f0 = arith.constant 0.0 : f32 + +// CHECK-NEXT: %[[S:.*]] = tensor.extract %[[SOURCE]][] : tensor +// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<1xf32> + %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} : + tensor, vector<1xf32> + +// CHECK-NEXT: return %[[V]] + return %0: vector<1xf32> +} + +// ----- + // transfer_read/write are lowered to vector.load/store // CHECK-LABEL: func @transfer_to_load( // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,