diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index 7dcac62a5850..d5e84314357c 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -746,10 +746,15 @@ def Vector_TransferReadOp : let description = [{ The `vector.transfer_read` op performs a blocking read from a slice within - a scalar [MemRef](../LangRef.md#memref-type) supplied as its first operand - into a [vector](../LangRef.md#vector-type) of the same elemental type. The - slice is further defined by a full-rank index within the MemRef, supplied as - the operands `2 .. 1 + rank(memref)`. The permutation_map + a [MemRef](../LangRef.md#memref-type) supplied as its first operand + into a [vector](../LangRef.md#vector-type) of the same base elemental type. + + A vector memref operand must have its vector element type match a suffix + (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>, + vector<1x1x4x3xf32>). + + The slice is further defined by a full-rank index within the MemRef, + supplied as the operands `2 .. 1 + rank(memref)`. The permutation_map [attribute](../LangRef.md#attributes) is an [affine-map](Affine.md#affine-maps) which specifies the transposition on the slice to match the vector shape. The size of the slice is specified by the @@ -854,6 +859,11 @@ def Vector_TransferReadOp : memref<?x?xf32>, vector<128xf32> } } + + // Read from a memref with vector element type. + %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 + {permutation_map = (d0, d1)->(d0, d1)} + : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32> ``` }]; @@ -878,10 +888,15 @@ def Vector_TransferWriteOp : let description = [{ The `vector.transfer_write` performs a blocking write from a [vector](../LangRef.md#vector-type), supplied as its first operand, into a - slice within a scalar [MemRef](../LangRef.md#memref-type) of the same - elemental type, supplied as its second operand. The slice is further defined - by a full-rank index within the MemRef, supplied as the operands - `3 .. 2 + rank(memref)`. + slice within a [MemRef](../LangRef.md#memref-type) of the same base + elemental type, supplied as its second operand. + + A vector memref operand must have its vector element type match a suffix + (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>, + vector<1x1x4x3xf32>). + + The slice is further defined by a full-rank index within the MemRef, + supplied as the operands `3 .. 2 + rank(memref)`. The permutation_map [attribute](../LangRef.md#attributes) is an [affine-map](Affine.md#affine-maps) which specifies the transposition on the slice to match the vector shape. The size of the slice is specified by the @@ -915,6 +930,11 @@ def Vector_TransferWriteOp : {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} : vector<16x32x64xf32>, memref<?x?x?x?xf32> }}}} + + // write to a memref with vector element type. + vector.transfer_write %4, %arg1[%c3, %c3] + {permutation_map = (d0, d1)->(d0, d1)} + : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>> ``` }]; @@ -1048,7 +1068,7 @@ def Vector_TupleOp : Note that this operation is used during the vector op unrolling transformation and should be removed before lowering to lower-level dialects. - + Examples: ``` diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 541b5427af91..8a6946792b2d 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -1420,6 +1420,59 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap, return success(); } +static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType, + VectorType vectorType, + AffineMap permutationMap) { + auto memrefElementType = memrefType.getElementType(); + if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) { + // Memref has vector element type. + + // Check that 'memrefVectorElementType' and vector element types match. + if (memrefVectorElementType.getElementType() != vectorType.getElementType()) + return op->emitOpError( + "requires memref and vector types of the same elemental type"); + + // Check that memref vector type is a suffix of 'vectorType. + unsigned memrefVecEltRank = memrefVectorElementType.getRank(); + unsigned resultVecRank = vectorType.getRank(); + if (memrefVecEltRank > resultVecRank) + return op->emitOpError( + "requires memref vector element and vector result ranks to match."); + // TODO(b/146516564) Move this to isSuffix in VectorOps/Utils.h. + unsigned rankOffset = resultVecRank - memrefVecEltRank; + auto memrefVecEltShape = memrefVectorElementType.getShape(); + auto resultVecShape = vectorType.getShape(); + for (unsigned i = 0; i < memrefVecEltRank; ++i) + if (memrefVecEltShape[i] != resultVecShape[rankOffset + i]) + return op->emitOpError( + "requires memref vector element shape to match suffix of " + "vector result shape."); + // Check that permutation map results match 'rankOffset' of vector type. + if (permutationMap.getNumResults() != rankOffset) + return op->emitOpError("requires a permutation_map with result dims of " + "the same rank as the vector type"); + } else { + // Memref has scalar element type. + + // Check that memref and vector element types match. + if (memrefType.getElementType() != vectorType.getElementType()) + return op->emitOpError( + "requires memref and vector types of the same elemental type"); + + // Check that permutation map results match rank of vector type. + if (permutationMap.getNumResults() != vectorType.getRank()) + return op->emitOpError("requires a permutation_map with result dims of " + "the same rank as the vector type"); + } + + if (permutationMap.getNumSymbols() != 0) + return op->emitOpError("requires permutation_map without symbols"); + if (permutationMap.getNumInputs() != memrefType.getRank()) + return op->emitOpError("requires a permutation_map with input dims of the " + "same rank as the memref type"); + return success(); +} + static void print(OpAsmPrinter &p, TransferReadOp op) { p << op.getOperationName() << " " << op.memref() << "[" << op.indices() << "], " << op.padding() << " "; @@ -1459,26 +1512,35 @@ static LogicalResult verify(TransferReadOp op) { // Consistency of elemental types in memref and vector. MemRefType memrefType = op.getMemRefType(); VectorType vectorType = op.getVectorType(); - if (memrefType.getElementType() != vectorType.getElementType()) - return op.emitOpError( - "requires memref and vector types of the same elemental type"); - auto elementalType = op.padding()->getType(); - if (!VectorType::isValidElementType(elementalType)) - return op.emitOpError("requires valid padding vector elemental type"); - if (elementalType != vectorType.getElementType()) - return op.emitOpError( - "requires formal padding and vector of the same elemental type"); - if (llvm::size(op.indices()) != memrefType.getRank()) - return op.emitOpError("requires ") << memrefType.getRank() << " indices"; + auto paddingType = op.padding()->getType(); auto permutationMap = op.permutation_map(); - if (permutationMap.getNumSymbols() != 0) - return op.emitOpError("requires permutation_map without symbols"); - if (permutationMap.getNumInputs() != memrefType.getRank()) - return op.emitOpError("requires a permutation_map with input dims of the " - "same rank as the memref type"); - if (permutationMap.getNumResults() != vectorType.getRank()) - return op.emitOpError("requires a permutation_map with result dims of the " - "same rank as the vector type"); + auto memrefElementType = memrefType.getElementType(); + + if (static_cast<int64_t>(op.indices().size()) != memrefType.getRank()) + return op.emitOpError("requires ") << memrefType.getRank() << " indices"; + + if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType, + permutationMap))) + return failure(); + + if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) { + // Memref has vector element type. + // Check that 'memrefVectorElementType' and 'paddingType' types match. + if (memrefVectorElementType != paddingType) + return op.emitOpError( + "requires memref element type and padding type to match."); + + } else { + // Check that 'paddingType' is valid to store in a vector type. + if (!VectorType::isValidElementType(paddingType)) + return op.emitOpError("requires valid padding vector elemental type"); + + // Check that padding type and vector element types match. + if (paddingType != vectorType.getElementType()) + return op.emitOpError( + "requires formal padding and vector of the same elemental type"); + } + return verifyPermutationMap(permutationMap, [&op](Twine t) { return op.emitOpError(t); }); } @@ -1519,24 +1581,15 @@ static LogicalResult verify(TransferWriteOp op) { // Consistency of elemental types in memref and vector. MemRefType memrefType = op.getMemRefType(); VectorType vectorType = op.getVectorType(); - if (memrefType.getElementType() != vectorType.getElementType()) - return op.emitOpError( - "requires memref and vector types of the same elemental type"); + auto permutationMap = op.permutation_map(); + if (llvm::size(op.indices()) != memrefType.getRank()) return op.emitOpError("requires ") << memrefType.getRank() << " indices"; - // Consistency of AffineMap attribute. - auto permutationMap = op.permutation_map(); - if (permutationMap.getNumSymbols() != 0) - return op.emitOpError("requires a symbol-less permutation_map"); - if (permutationMap.getNumInputs() != memrefType.getRank()) - return op.emitOpError("requires a permutation_map with input dims of the " - "same rank as the memref type: ") - << permutationMap.getNumInputs() << " vs " << memrefType; - if (permutationMap.getNumResults() != vectorType.getRank()) - return op.emitOpError("requires a permutation_map with result dims of the " - "same rank as the vector type.") - << permutationMap.getNumResults() << " vs " << vectorType; + if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType, + permutationMap))) + return failure(); + return verifyPermutationMap(permutationMap, [&op](Twine t) { return op.emitOpError(t); }); } diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index c208c92fc232..9ef39e251448 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -308,6 +308,36 @@ func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) { // ----- +func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) { + %c3 = constant 3 : index + %f0 = constant 0.0 : f32 + %vf0 = splat %f0 : vector<4x3xf32> + // expected-error@+1 {{requires memref and vector types of the same elemental type}} + %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xi32> +} + +// ----- + +func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) { + %c3 = constant 3 : index + %f0 = constant 0.0 : f32 + %vf0 = splat %f0 : vector<4x3xf32> + // expected-error@+1 {{requires memref vector element and vector result ranks to match}} + %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref<?x?xvector<4x3xf32>>, vector<3xf32> +} + +// ----- + +func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) { + %c3 = constant 3 : index + %f0 = constant 0.0 : f32 + %vf0 = splat %f0 : vector<4x3xf32> + // expected-error@+1 {{ requires memref vector element shape to match suffix of vector result shape}} + %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref<?x?xvector<4x3xf32>>, vector<1x1x2x3xf32> +} + +// ----- + func @test_vector.transfer_write(%arg0: memref<?x?xf32>) { %c3 = constant 3 : index %cst = constant dense<3.0> : vector<128 x f32> diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index e1607996cc2d..d99a7df0d2b3 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -1,24 +1,35 @@ // RUN: mlir-opt %s | mlir-opt | FileCheck %s +// CHECK-DAG: #[[MAP0:map[0-9]+]] = (d0, d1) -> (d0, d1) + // CHECK-LABEL: func @vector_transfer_ops( -func @vector_transfer_ops(%arg0: memref<?x?xf32>) { +func @vector_transfer_ops(%arg0: memref<?x?xf32>, + %arg1 : memref<?x?xvector<4x3xf32>>) { + // CHECK: %[[C3:.*]] = constant 3 : index %c3 = constant 3 : index %cst = constant 3.0 : f32 %f0 = constant 0.0 : f32 + %vf0 = splat %f0 : vector<4x3xf32> + // - // CHECK: %0 = vector.transfer_read + // CHECK: vector.transfer_read %0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = (d0, d1)->(d0)} : memref<?x?xf32>, vector<128xf32> - // CHECK: %1 = vector.transfer_read + // CHECK: vector.transfer_read %1 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = (d0, d1)->(d1, d0)} : memref<?x?xf32>, vector<3x7xf32> // CHECK: vector.transfer_read %2 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = (d0, d1)->(d0)} : memref<?x?xf32>, vector<128xf32> // CHECK: vector.transfer_read %3 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = (d0, d1)->(d1)} : memref<?x?xf32>, vector<128xf32> - // + // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32> + %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32> + // CHECK: vector.transfer_write vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32> // CHECK: vector.transfer_write vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = (d0, d1)->(d1, d0)} : vector<3x7xf32>, memref<?x?xf32> + // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] {permutation_map = #[[MAP0]]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>> + vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = (d0, d1)->(d0, d1)} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>> + return }