forked from OSchip/llvm-project
[mlir][vector] Fix TransferOpReduceRank for 0-D tensors
We cannot unconditionally generate memref.load ops for such cases; need to check the source's type. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D114376
This commit is contained in:
parent
9c5982ef8e
commit
93284120f2
|
@ -224,9 +224,15 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
|
|||
// https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
|
||||
// In the meantime, lower these to a scalar load when they pop up.
|
||||
if (reducedShapeRank == 0) {
|
||||
Value newRead = rewriter.create<memref::LoadOp>(
|
||||
Value newRead;
|
||||
if (op.getShapedType().isa<TensorType>()) {
|
||||
newRead = rewriter.create<tensor::ExtractOp>(op.getLoc(), op.source(),
|
||||
op.indices());
|
||||
} else {
|
||||
newRead = rewriter.create<memref::LoadOp>(
|
||||
op.getLoc(), originalVecType.getElementType(), op.source(),
|
||||
op.indices());
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
|
||||
newRead);
|
||||
return success();
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -canonicalize -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @vector_transfer_ops_0d(
|
||||
// CHECK-LABEL: func @vector_transfer_ops_0d_memref(
|
||||
// CHECK-SAME: %[[MEM:.*]]: memref<f32>
|
||||
// CHECK-SAME: %[[VV:.*]]: vector<1x1x1xf32>
|
||||
func @vector_transfer_ops_0d(%M: memref<f32>, %v: vector<1x1x1xf32>) {
|
||||
func @vector_transfer_ops_0d_memref(%M: memref<f32>, %v: vector<1x1x1xf32>) {
|
||||
%f0 = arith.constant 0.0 : f32
|
||||
|
||||
// CHECK-NEXT: %[[V:.*]] = memref.load %[[MEM]][] : memref<f32>
|
||||
|
@ -23,6 +23,22 @@ func @vector_transfer_ops_0d(%M: memref<f32>, %v: vector<1x1x1xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @vector_transfer_ops_0d_tensor(
|
||||
// CHECK-SAME: %[[SOURCE:.*]]: tensor<f32>
|
||||
func @vector_transfer_ops_0d_tensor(%M: tensor<f32>) -> vector<1xf32> {
|
||||
%f0 = arith.constant 0.0 : f32
|
||||
|
||||
// CHECK-NEXT: %[[S:.*]] = tensor.extract %[[SOURCE]][] : tensor<f32>
|
||||
// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<1xf32>
|
||||
%0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} :
|
||||
tensor<f32>, vector<1xf32>
|
||||
|
||||
// CHECK-NEXT: return %[[V]]
|
||||
return %0: vector<1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// transfer_read/write are lowered to vector.load/store
|
||||
// CHECK-LABEL: func @transfer_to_load(
|
||||
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
|
||||
|
|
Loading…
Reference in New Issue