VectorToSCF: compare major induction variables against permuted memref bounds.

NDTransferOpHelper::emitInBoundsCondition used to zip the major induction
variables with memrefBounds.getUbs() directly. This patch first applies the
transfer op's permutation_map to those upper bounds and zips against the
result instead. The two new tests use the permutation map
(d0, d1, d2, d3) -> (d2, d3) on a rank-4 memref and check that the generated
`cmpi "slt"` compares the loop induction variable with `dim` of dimension 2.

NOTE(review): this patch text was recovered from a copy in which newlines were
collapsed and some angle-bracket arguments were stripped. The following were
reconstructed and should be confirmed against the upstream commit:
`SmallVector<Value, 4>`, `NDTransferOpHelper<ConcreteOp>` (hunk heading only),
`memref<?x?x?x?xf32>` and `memref<vector<3x3xf32>>`. Indentation inside the
hunks is likewise reconstructed.

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 95208ad231c9..0f428f887d12 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -16,6 +16,7 @@
 #include "../PassDetail.h"
 
 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SCF/EDSC/Builders.h"
 #include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
@@ -203,7 +204,10 @@ Value NDTransferOpHelper<ConcreteOp>::emitInBoundsCondition(
   Value inBoundsCondition;
   majorIvsPlusOffsets.reserve(majorIvs.size());
   unsigned idx = 0;
-  for (auto it : llvm::zip(majorIvs, majorOffsets, memrefBounds.getUbs())) {
+  SmallVector<Value, 4> bounds =
+      linalg::applyMapToValues(rewriter, xferOp.getLoc(),
+                               xferOp.permutation_map(), memrefBounds.getUbs());
+  for (auto it : llvm::zip(majorIvs, majorOffsets, bounds)) {
     Value iv = std::get<0>(it), off = std::get<1>(it), ub = std::get<2>(it);
     using namespace mlir::edsc::op;
     majorIvsPlusOffsets.push_back(iv + off);
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
index 213877cd36af..5c2da799d861 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
@@ -400,3 +400,60 @@ func @transfer_read_simple(%A : memref<2x2xf32>) -> vector<2x2xf32> {
   %0 = vector.transfer_read %A[%c0, %c0], %f0 : memref<2x2xf32>, vector<2x2xf32>
   return %0 : vector<2x2xf32>
 }
+
+func @transfer_read_minor_identity(%A : memref<?x?x?x?xf32>) -> vector<3x3xf32> {
+  %c0 = constant 0 : index
+  %f0 = constant 0.0 : f32
+  %0 = vector.transfer_read %A[%c0, %c0, %c0, %c0], %f0
+    { permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)> }
+      : memref<?x?x?x?xf32>, vector<3x3xf32>
+  return %0 : vector<3x3xf32>
+}
+
+// CHECK-LABEL: transfer_read_minor_identity(
+// CHECK-SAME: %[[A:.*]]: memref<?x?x?x?xf32>) -> vector<3x3xf32>
+// CHECK: %[[c0:.*]] = constant 0 : index
+// CHECK: %[[cst:.*]] = constant 0.000000e+00 : f32
+// CHECK: %[[c2:.*]] = constant 2 : index
+// CHECK: %[[cst0:.*]] = constant dense<0.000000e+00> : vector<3xf32>
+// CHECK: %[[m:.*]] = alloca() {alignment = 128 : i64} : memref<3xvector<3xf32>>
+// CHECK: %[[d:.*]] = dim %[[A]], %[[c2]] : memref<?x?x?x?xf32>
+// CHECK: affine.for %[[arg1:.*]] = 0 to 3 {
+// CHECK: %[[cmp:.*]] = cmpi "slt", %[[arg1]], %[[d]] : index
+// CHECK: scf.if %[[cmp]] {
+// CHECK: %[[tr:.*]] = vector.transfer_read %[[A]][%[[c0]], %[[c0]], %[[arg1]], %[[c0]]], %[[cst]] : memref<?x?x?x?xf32>, vector<3xf32>
+// CHECK: store %[[tr]], %[[m]][%[[arg1]]] : memref<3xvector<3xf32>>
+// CHECK: } else {
+// CHECK: store %[[cst0]], %[[m]][%[[arg1]]] : memref<3xvector<3xf32>>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[cast:.*]] = vector.type_cast %[[m]] : memref<3xvector<3xf32>> to memref<vector<3x3xf32>>
+// CHECK: %[[ret:.*]] = load %[[cast]][] : memref<vector<3x3xf32>>
+// CHECK: return %[[ret]] : vector<3x3xf32>
+
+func @transfer_write_minor_identity(%A : vector<3x3xf32>, %B : memref<?x?x?x?xf32>) {
+  %c0 = constant 0 : index
+  %f0 = constant 0.0 : f32
+  vector.transfer_write %A, %B[%c0, %c0, %c0, %c0]
+    { permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)> }
+      : vector<3x3xf32>, memref<?x?x?x?xf32>
+  return
+}
+
+// CHECK-LABEL: transfer_write_minor_identity(
+// CHECK-SAME: %[[A:.*]]: vector<3x3xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<?x?x?x?xf32>)
+// CHECK: %[[c0:.*]] = constant 0 : index
+// CHECK: %[[c2:.*]] = constant 2 : index
+// CHECK: %[[m:.*]] = alloca() {alignment = 128 : i64} : memref<3xvector<3xf32>>
+// CHECK: %[[cast:.*]] = vector.type_cast %[[m]] : memref<3xvector<3xf32>> to memref<vector<3x3xf32>>
+// CHECK: store %[[A]], %[[cast]][] : memref<vector<3x3xf32>>
+// CHECK: %[[d:.*]] = dim %[[B]], %[[c2]] : memref<?x?x?x?xf32>
+// CHECK: affine.for %[[arg2:.*]] = 0 to 3 {
+// CHECK: %[[cmp:.*]] = cmpi "slt", %[[arg2]], %[[d]] : index
+// CHECK: scf.if %[[cmp]] {
+// CHECK: %[[tmp:.*]] = load %[[m]][%[[arg2]]] : memref<3xvector<3xf32>>
+// CHECK: vector.transfer_write %[[tmp]], %[[B]][%[[c0]], %[[c0]], %[[arg2]], %[[c0]]] : vector<3xf32>, memref<?x?x?x?xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: return