Revert "[mlir] Fix masked vector transfer ops with broadcasts"

This reverts commit c9087788f7.

An old version of the commit was accidentally pushed.
Author: Matthias Springer
Date:   2021-05-13 11:55:00 +09:00
parent c9087788f7
commit 6555e53ab0
8 changed files with 103 additions and 293 deletions
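For context: the reverted change made masked vector transfer ops with broadcast dimensions in their permutation map work, by giving the mask one dimension per non-broadcast map result. A representative op, adapted from the 2D integration test further down in this diff (the pre-revert form, where the mask is vector<9xi1> even though the result is vector<4x9xf32>):

%mask = constant dense<[1, 0, 1, 0, 1, 1, 1, 0, 1]> : vector<9xi1>
%f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
    {permutation_map = affine_map<(d0, d1) -> (0, d1)>} :
    memref<?x?xf32>, vector<4x9xf32>

After this revert, masks are full-shape again (vector<4x9xi1> here) and the broadcast-with-mask tests below are removed.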


@@ -17,18 +17,6 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
namespace mlir {
namespace vector {
namespace detail {
/// Given the vector type and the permutation map of a vector transfer op,
/// compute the expected mask type.
VectorType transferMaskType(VectorType vecType, AffineMap map);
} // namespace detail
} // namespace vector
} // namespace mlir
/// Include the generated interface declarations.
#include "mlir/Interfaces/VectorInterfaces.h.inc"


@@ -156,19 +156,6 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
return $_op.vector().getType().template dyn_cast<VectorType>();
}]
>,
InterfaceMethod<
/*desc=*/"Return the mask type if the op has a mask.",
/*retTy=*/"Optional<VectorType>",
/*methodName=*/"getMaskType",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.mask()
? llvm::Optional<VectorType>(mlir::vector::detail::transferMaskType(
$_op.getVectorType(), $_op.permutation_map()))
: llvm::None;
}]
>,
InterfaceMethod<
/*desc=*/[{ Return the number of dimensions that participate in the
permutation map.}],


@@ -79,20 +79,13 @@ static BufferAllocs allocBuffers(OpTy xferOp) {
if (xferOp.mask()) {
auto maskType = MemRefType::get({}, xferOp.mask().getType());
auto maskBuffer = memref_alloca(maskType).value;
memref_store(xferOp.mask(), maskBuffer);
result.maskBuffer = memref_load(maskBuffer);
result.maskBuffer = memref_alloca(maskType).value;
memref_store(xferOp.mask(), result.maskBuffer);
}
return result;
}
template <typename OpTy>
static bool isOutermostDimBroadcast(OpTy xferOp) {
auto map = xferOp.permutation_map();
return map.getResult(0).template isa<AffineConstantExpr>();
}
/// Given a vector transfer op, calculate which dimension of the `source`
/// memref should be unpacked in the next application of TransferOpConversion.
/// A return value of None indicates a broadcast.
@@ -102,7 +95,7 @@ static Optional<int64_t> unpackedDim(OpTy xferOp) {
if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
return expr.getPosition();
}
assert(isOutermostDimBroadcast(xferOp) &&
assert(map.getResult(0).template isa<AffineConstantExpr>() &&
"Expected AffineDimExpr or AffineConstantExpr");
return None;
}
@@ -150,17 +143,14 @@ static void maybeYieldValue(
}
/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
/// is set to true. Does not return a Value if the transfer op does not have a
/// mask, if the transfer op's mask is not 1D or if the to-be-unpacked dim of
/// the transfer op is a broadcast.
/// is set to true. Does not return a Value if the transfer op is not 1D or
/// if the transfer op does not have a mask.
template <typename OpTy>
static Value maybeGenerateMaskCheck(OpBuilder &builder, OpTy xferOp, Value iv) {
if (xferOp.getVectorType().getRank() != 1)
return Value();
if (!xferOp.mask())
return Value();
if (xferOp.getMaskType()->getRank() != 1)
return Value();
if (isOutermostDimBroadcast(xferOp))
return Value();
auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
return vector_extract_element(xferOp.mask(), ivI32).value;
@@ -498,8 +488,8 @@ struct PrepareTransferReadConversion
auto *newXfer = rewriter.clone(*xferOp.getOperation());
newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
if (xferOp.mask()) {
dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
buffers.maskBuffer);
auto loadedMask = memref_load(buffers.maskBuffer);
dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(loadedMask);
}
memref_store(newXfer->getResult(0), buffers.dataBuffer);
@@ -551,8 +541,9 @@ struct PrepareTransferWriteConversion
});
if (xferOp.mask()) {
auto loadedMask = memref_load(buffers.maskBuffer);
rewriter.updateRootInPlace(
xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
xferOp, [&]() { xferOp.maskMutable().assign(loadedMask); });
}
return success();
@@ -599,18 +590,8 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
auto maskBuffer = getMaskBuffer(xferOp);
auto maskBufferType =
maskBuffer.getType().template dyn_cast<MemRefType>();
if (isOutermostDimBroadcast(xferOp) ||
xferOp.getMaskType()->getRank() == 1) {
// Do not unpack a dimension of the mask, if:
// * To-be-unpacked transfer op dimension is a broadcast.
// * Mask is 1D, i.e., the mask cannot be further unpacked.
// (That means that all remaining dimensions of the transfer op must
// be broadcasts.)
castedMaskBuffer = maskBuffer;
} else {
auto castedMaskType = unpackOneDim(maskBufferType);
castedMaskBuffer = vector_type_cast(castedMaskType, maskBuffer);
}
auto castedMaskType = unpackOneDim(maskBufferType);
castedMaskBuffer = vector_type_cast(castedMaskType, maskBuffer);
}
// Loop bounds and step.
@@ -635,20 +616,13 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
Strategy<OpTy>::rewriteOp(b, xferOp, castedDataBuffer, iv);
// If old transfer op has a mask: Set mask on new transfer op.
// Special case: If the mask of the old transfer op is 1D and the
// unpacked dim is not a broadcast, no mask is needed
// on the new transfer op.
if (xferOp.mask() && (isOutermostDimBroadcast(xferOp) ||
xferOp.getMaskType()->getRank() > 1)) {
if (xferOp.mask()) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(newXfer); // Insert load before newXfer.
SmallVector<Value, 8> loadIndices;
Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
// In case of broadcast: Use same indices to load from memref as
// before.
if (!isOutermostDimBroadcast(xferOp))
loadIndices.push_back(iv);
loadIndices.push_back(iv);
auto mask = memref_load(castedMaskBuffer, loadIndices);
rewriter.updateRootInPlace(
@@ -687,7 +661,7 @@ static Optional<int64_t> get1dMemrefIndices(
return dim;
}
assert(isOutermostDimBroadcast(xferOp) &&
assert(map.getResult(0).template isa<AffineConstantExpr>() &&
"Expected AffineDimExpr or AffineConstantExpr");
return None;
}
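To make the restored unpackedDim contract concrete (a minimal sketch derived from the dyn_cast and assert in the hunks above): the helper inspects the first result of the permutation map to decide which memref dimension the next unpacking step iterates over.

// permutation_map = affine_map<(d0, d1) -> (d1, d0)>: getResult(0) is d1,
//   so unpackedDim returns 1 (unpack the second memref dimension).
// permutation_map = affine_map<(d0, d1) -> (0, d1)>: getResult(0) is the
//   constant 0, so unpackedDim returns None (a broadcast, nothing to unpack).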


@@ -2491,11 +2491,10 @@ static ParseResult parseTransferReadOp(OpAsmParser &parser,
if (!vectorType)
return parser.emitError(typesLoc, "requires vector type");
auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
Attribute mapAttr = result.attributes.get(permutationAttrName);
if (!mapAttr) {
auto attr = result.attributes.get(permutationAttrName);
if (!attr) {
auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
mapAttr = AffineMapAttr::get(permMap);
result.attributes.set(permutationAttrName, mapAttr);
result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
}
if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
parser.resolveOperands(indexInfo, indexType, result.operands) ||
@@ -2503,10 +2502,7 @@ static ParseResult parseTransferReadOp(OpAsmParser &parser,
result.operands))
return failure();
if (hasMask.succeeded()) {
auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue();
// Instead of adding the mask type as an op type, compute it based on the
// vector type and the permutation map (to keep the type signature small).
auto maskType = mlir::vector::detail::transferMaskType(vectorType, map);
auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
if (parser.resolveOperand(maskInfo, maskType, result.operands))
return failure();
}
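Per the VectorType::get(vectorType.getShape(), builder.getI1Type()) line restored above, the parser again derives the mask type directly from the result vector shape. A sketch of the consequence for the custom assembly format:

// Result type vector<4x9xf32> -> the mask operand is resolved as
// vector<4x9xi1>, regardless of the permutation map. The reverted code
// instead computed vector<9xi1> for a broadcast map like (d0, d1) -> (0, d1).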


@@ -10,26 +10,6 @@
using namespace mlir;
namespace mlir {
namespace vector {
namespace detail {
VectorType transferMaskType(VectorType vecType, AffineMap map) {
auto i1Type = IntegerType::get(map.getContext(), 1);
SmallVector<int64_t, 8> shape;
for (int64_t i = 0; i < vecType.getRank(); ++i) {
// Only result dims have a corresponding dim in the mask.
if (auto expr = map.getResult(i).template isa<AffineDimExpr>()) {
shape.push_back(vecType.getDimSize(i));
}
}
return VectorType::get(shape, i1Type);
}
} // namespace detail
} // namespace vector
} // namespace mlir
//===----------------------------------------------------------------------===//
// VectorUnroll Interfaces
//===----------------------------------------------------------------------===//


@@ -5,14 +5,6 @@
// Test for special cases of 1D vector transfer ops.
memref.global "private" @gv : memref<5x6xf32> =
dense<[[0. , 1. , 2. , 3. , 4. , 5. ],
[10., 11., 12., 13., 14., 15.],
[20., 21., 22., 23., 24., 25.],
[30., 31., 32., 33., 34., 35.],
[40., 41., 42., 43., 44., 45.]]>
// Non-contiguous, strided load
func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
%fm42 = constant -42.0: f32
%f = vector.transfer_read %A[%base1, %base2], %fm42
@@ -22,7 +14,6 @@ func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
return
}
// Broadcast
func @transfer_read_1d_broadcast(
%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
%fm42 = constant -42.0: f32
@@ -33,7 +24,6 @@ func @transfer_read_1d_broadcast(
return
}
// Non-contiguous, strided load
func @transfer_read_1d_in_bounds(
%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
%fm42 = constant -42.0: f32
@@ -44,7 +34,6 @@ func @transfer_read_1d_in_bounds(
return
}
// Non-contiguous, strided load
func @transfer_read_1d_mask(
%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
%fm42 = constant -42.0: f32
@@ -56,7 +45,6 @@ func @transfer_read_1d_mask(
return
}
// Non-contiguous, strided load
func @transfer_read_1d_mask_in_bounds(
%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
%fm42 = constant -42.0: f32
@@ -68,7 +56,6 @@ func @transfer_read_1d_mask_in_bounds(
return
}
// Non-contiguous, strided store
func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
%fn1 = constant -1.0 : f32
%vf0 = splat %fn1 : vector<7xf32>
@@ -78,68 +65,57 @@ func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
return
}
// Non-contiguous, strided store
func @transfer_write_1d_mask(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
%fn1 = constant -2.0 : f32
%vf0 = splat %fn1 : vector<7xf32>
%mask = constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>
vector.transfer_write %vf0, %A[%base1, %base2], %mask
{permutation_map = affine_map<(d0, d1) -> (d0)>}
: vector<7xf32>, memref<?x?xf32>
return
}
func @entry() {
%c0 = constant 0: index
%c1 = constant 1: index
%c2 = constant 2: index
%c3 = constant 3: index
%0 = memref.get_global @gv : memref<5x6xf32>
%A = memref.cast %0 : memref<5x6xf32> to memref<?x?xf32>
%f10 = constant 10.0: f32
// work with dims of 4, not of 3
%first = constant 5: index
%second = constant 6: index
%A = memref.alloc(%first, %second) : memref<?x?xf32>
scf.for %i = %c0 to %first step %c1 {
%i32 = index_cast %i : index to i32
%fi = sitofp %i32 : i32 to f32
%fi10 = mulf %fi, %f10 : f32
scf.for %j = %c0 to %second step %c1 {
%j32 = index_cast %j : index to i32
%fj = sitofp %j32 : i32 to f32
%fres = addf %fi10, %fj : f32
memref.store %fres, %A[%i, %j] : memref<?x?xf32>
}
}
// 1. Read from 2D memref on first dimension. Cannot be lowered to an LLVM
// vector load. Instead, generates scalar loads.
// Read from 2D memref on first dimension. Cannot be lowered to an LLVM
// vector load. Instead, generates scalar loads.
call @transfer_read_1d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
// CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )
// 2. Write to 2D memref on first dimension. Cannot be lowered to an LLVM
// vector store. Instead, generates scalar stores.
// Write to 2D memref on first dimension. Cannot be lowered to an LLVM
// vector store. Instead, generates scalar stores.
call @transfer_write_1d(%A, %c3, %c2) : (memref<?x?xf32>, index, index) -> ()
// 3. (Same as 1. To check if 2 works correctly.)
// (Same as above.)
call @transfer_read_1d(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
// CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )
// 4. Read a scalar from a 2D memref and broadcast the value to a 1D vector.
// Generates a loop with vector.insertelement.
// Read a scalar from a 2D memref and broadcast the value to a 1D vector.
// Generates a loop with vector.insertelement.
call @transfer_read_1d_broadcast(%A, %c1, %c2)
: (memref<?x?xf32>, index, index) -> ()
// CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 )
// 5. Read from 2D memref on first dimension. Accesses are in-bounds, so no
// if-check is generated inside the generated loop.
// Read from 2D memref on first dimension. Accesses are in-bounds, so no
// if-check is generated inside the generated loop.
call @transfer_read_1d_in_bounds(%A, %c1, %c2)
: (memref<?x?xf32>, index, index) -> ()
// CHECK: ( 12, 22, -1 )
// 6. Optional mask attribute is specified and, in addition, there may be
// out-of-bounds accesses.
// Optional mask attribute is specified and, in addition, there may be
// out-of-bounds accesses.
call @transfer_read_1d_mask(%A, %c1, %c2)
: (memref<?x?xf32>, index, index) -> ()
// CHECK: ( 12, -42, -1, -42, -42, -42, -42, -42, -42 )
// 7. Same as 6, but accesses are in-bounds.
// Same as above, but accesses are in-bounds.
call @transfer_read_1d_mask_in_bounds(%A, %c1, %c2)
: (memref<?x?xf32>, index, index) -> ()
// CHECK: ( 12, -42, -1 )
// 8. Write to 2D memref on first dimension with a mask.
call @transfer_write_1d_mask(%A, %c1, %c0)
: (memref<?x?xf32>, index, index) -> ()
// 9. (Same as 1. To check if 8 works correctly.)
call @transfer_read_1d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
// CHECK: ( 0, -2, 20, -2, 40, -42, -42, -42, -42 )
return
}
// CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )
// CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )
// CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 )
// CHECK: ( 12, 22, -1 )
// CHECK: ( 12, -42, -1, -42, -42, -42, -42, -42, -42 )
// CHECK: ( 12, -42, -1 )
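The last two steps of the test above encode the masked-write semantics worth spelling out: a clear mask bit leaves the corresponding memory element untouched. Working through transfer_write_1d_mask at %base1 = 1 with mask [1, 0, 1, 0, 1, 1, 1] and value -2:

// mask bits 0 and 2 are set   -> rows 1 and 3 of column 0 become -2;
// mask bits 1 and 3 are clear -> rows 2 and 4 keep their values 20 and 40;
// rows 5 and up are out of bounds and skipped.
// Hence the final read from [0, 0]: ( 0, -2, 20, -2, 40, -42, ... )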


@@ -3,11 +3,6 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
memref.global "private" @gv : memref<3x4xf32> = dense<[[0. , 1. , 2. , 3. ],
[10., 11., 12., 13.],
[20., 21., 22., 23.]]>
// Vector load
func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%fm42 = constant -42.0: f32
%f = vector.transfer_read %A[%base1, %base2], %fm42
@@ -17,7 +12,6 @@ func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
return
}
// Vector load with mask
func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%fm42 = constant -42.0: f32
%mask = constant dense<[[1, 0, 1, 0, 1, 1, 1, 0, 1],
@@ -31,47 +25,6 @@ func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index)
return
}
// Vector load with mask + transpose
func @transfer_read_2d_mask_transposed(
%A : memref<?x?xf32>, %base1: index, %base2: index) {
%fm42 = constant -42.0: f32
%mask = constant dense<[[1, 0, 1, 0], [0, 0, 1, 0],
[1, 1, 1, 1], [0, 1, 1, 0],
[1, 1, 1, 1], [1, 1, 1, 1],
[1, 1, 1, 1], [0, 0, 0, 0],
[1, 1, 1, 1]]> : vector<9x4xi1>
%f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
{permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
memref<?x?xf32>, vector<9x4xf32>
vector.print %f: vector<9x4xf32>
return
}
// Vector load with mask + broadcast
func @transfer_read_2d_mask_broadcast(
%A : memref<?x?xf32>, %base1: index, %base2: index) {
%fm42 = constant -42.0: f32
%mask = constant dense<[1, 0, 1, 0, 1, 1, 1, 0, 1]> : vector<9xi1>
%f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
{permutation_map = affine_map<(d0, d1) -> (0, d1)>} :
memref<?x?xf32>, vector<4x9xf32>
vector.print %f: vector<4x9xf32>
return
}
// Transpose + vector load with mask + broadcast
func @transfer_read_2d_mask_transpose_broadcast_last_dim(
%A : memref<?x?xf32>, %base1: index, %base2: index) {
%fm42 = constant -42.0: f32
%mask = constant dense<[1, 0, 1, 1]> : vector<4xi1>
%f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
{permutation_map = affine_map<(d0, d1) -> (d1, 0)>} :
memref<?x?xf32>, vector<4x9xf32>
vector.print %f: vector<4x9xf32>
return
}
// Load + transpose
func @transfer_read_2d_transposed(
%A : memref<?x?xf32>, %base1: index, %base2: index) {
%fm42 = constant -42.0: f32
@@ -82,7 +35,6 @@ func @transfer_read_2d_transposed(
return
}
// Load 1D + broadcast to 2D
func @transfer_read_2d_broadcast(
%A : memref<?x?xf32>, %base1: index, %base2: index) {
%fm42 = constant -42.0: f32
@@ -93,7 +45,6 @@ func @transfer_read_2d_broadcast(
return
}
// Vector store
func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%fn1 = constant -1.0 : f32
%vf0 = splat %fn1 : vector<1x4xf32>
@@ -103,79 +54,55 @@ func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
return
}
// Vector store with mask
func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%fn1 = constant -2.0 : f32
%mask = constant dense<[[1, 0, 1, 0]]> : vector<1x4xi1>
%vf0 = splat %fn1 : vector<1x4xf32>
vector.transfer_write %vf0, %A[%base1, %base2], %mask
{permutation_map = affine_map<(d0, d1) -> (d0, d1)>} :
vector<1x4xf32>, memref<?x?xf32>
return
}
func @entry() {
%c0 = constant 0: index
%c1 = constant 1: index
%c2 = constant 2: index
%c3 = constant 3: index
%0 = memref.get_global @gv : memref<3x4xf32>
%A = memref.cast %0 : memref<3x4xf32> to memref<?x?xf32>
// 1. Read 2D vector from 2D memref.
%c4 = constant 4: index
%c5 = constant 5: index
%c8 = constant 5: index
%f10 = constant 10.0: f32
// work with dims of 4, not of 3
%first = constant 3: index
%second = constant 4: index
%A = memref.alloc(%first, %second) : memref<?x?xf32>
scf.for %i = %c0 to %first step %c1 {
%i32 = index_cast %i : index to i32
%fi = sitofp %i32 : i32 to f32
%fi10 = mulf %fi, %f10 : f32
scf.for %j = %c0 to %second step %c1 {
%j32 = index_cast %j : index to i32
%fj = sitofp %j32 : i32 to f32
%fres = addf %fi10, %fj : f32
memref.store %fres, %A[%i, %j] : memref<?x?xf32>
}
}
// On input, memory contains [[ 0, 1, 2, ...], [10, 11, 12, ...], ...]
// Read shifted by 2 and pad with -42:
call @transfer_read_2d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
// CHECK: ( ( 12, 13, -42, -42, -42, -42, -42, -42, -42 ), ( 22, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
// 2. Read 2D vector from 2D memref at specified location and transpose the
// result.
// Same as above, but transposed
call @transfer_read_2d_transposed(%A, %c1, %c2)
: (memref<?x?xf32>, index, index) -> ()
// CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
// 3. Read 2D vector from 2D memref with a 2D mask. In addition, some
// accesses are out-of-bounds.
// Write into memory shifted by 3
call @transfer_write_2d(%A, %c3, %c1) : (memref<?x?xf32>, index, index) -> ()
// Read shifted by 0 and pad with -42:
call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
// Same as above, but apply a mask
call @transfer_read_2d_mask(%A, %c0, %c0)
: (memref<?x?xf32>, index, index) -> ()
// CHECK: ( ( 0, -42, 2, -42, -42, -42, -42, -42, -42 ), ( -42, -42, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
// 4. Same as 3, but transpose the result.
call @transfer_read_2d_mask_transposed(%A, %c0, %c0)
// Same as above, but without mask and transposed
call @transfer_read_2d_transposed(%A, %c0, %c0)
: (memref<?x?xf32>, index, index) -> ()
// CHECK: ( ( 0, -42, 20, -42 ), ( -42, -42, 21, -42 ), ( 2, 12, 22, -42 ), ( -42, 13, 23, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ) )
// 5. Read 1D vector from 2D memref at specified location and broadcast the
// result to 2D.
// Second vector dimension is a broadcast
call @transfer_read_2d_broadcast(%A, %c1, %c2)
: (memref<?x?xf32>, index, index) -> ()
// CHECK: ( ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ), ( 13, 13, 13, 13, 13, 13, 13, 13, 13 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
// 6. Read 1D vector from 2D memref at specified location with mask and
// broadcast the result to 2D.
call @transfer_read_2d_mask_broadcast(%A, %c2, %c1)
: (memref<?x?xf32>, index, index) -> ()
// CHECK: ( ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ), ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ), ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ), ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ) )
// 7. Read 1D vector from 2D memref (second dimension) at specified location
// with mask and broadcast the result to 2D. In this test case, mask
// elements must be evaluated before lowering to an (N>1)-D transfer.
call @transfer_read_2d_mask_transpose_broadcast_last_dim(%A, %c0, %c1)
: (memref<?x?xf32>, index, index) -> ()
// CHECK: ( ( 1, 1, 1, 1, 1, 1, 1, 1, 1 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( 3, 3, 3, 3, 3, 3, 3, 3, 3 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
// 8. Write 2D vector into 2D memref at specified location.
call @transfer_write_2d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
// 9. Read memref to verify step 8.
call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
// CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, -1, -1, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
// 10. Write 2D vector into 2D memref at specified location with mask.
call @transfer_write_2d_mask(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
// 11. Read memref to verify step 10.
call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
// CHECK: ( ( 0, 1, -2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, -1, -1, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
return
}
// CHECK: ( ( 12, 13, -42, -42, -42, -42, -42, -42, -42 ), ( 22, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
// CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
// CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
// CHECK: ( ( 0, -42, 2, -42, -42, -42, -42, -42, -42 ), ( -42, -42, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
// CHECK: ( ( 0, 10, 20, -42, -42, -42, -42, -42, -42 ), ( 1, 11, 21, -42, -42, -42, -42, -42, -42 ), ( 2, 12, 22, -42, -42, -42, -42, -42, -42 ), ( 3, 13, 23, -42, -42, -42, -42, -42, -42 ) )
// CHECK: ( ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ), ( 13, 13, 13, 13, 13, 13, 13, 13, 13 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
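Masked reads follow the dual rule: a clear mask bit yields the padding value instead of the loaded element (as do out-of-bounds lanes). For transfer_read_2d_mask at [0, 0], whose first mask row is [1, 0, 1, 0, 1, 1, 1, 0, 1]:

// row 0 of the result: lanes 0 and 2 load A[0][0] = 0 and A[0][2] = 2;
// lanes 1 and 3 are masked off, and out-of-bounds lanes pad with -42,
// giving ( 0, -42, 2, -42, -42, -42, -42, -42, -42 )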


@@ -1,8 +1,15 @@
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
// RUN: mlir-opt %s -test-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
// Test case is based on test-transfer-read-2d.
func @transfer_read_3d(%A : memref<?x?x?x?xf32>,
%o: index, %a: index, %b: index, %c: index) {
%fm42 = constant -42.0: f32
@@ -22,17 +29,6 @@ func @transfer_read_3d_broadcast(%A : memref<?x?x?x?xf32>,
return
}
func @transfer_read_3d_mask_broadcast(
%A : memref<?x?x?x?xf32>, %o: index, %a: index, %b: index, %c: index) {
%fm42 = constant -42.0: f32
%mask = constant dense<[0, 1]> : vector<2xi1>
%f = vector.transfer_read %A[%o, %a, %b, %c], %fm42, %mask
{permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, 0, 0)>}
: memref<?x?x?x?xf32>, vector<2x5x3xf32>
vector.print %f: vector<2x5x3xf32>
return
}
func @transfer_read_3d_transposed(%A : memref<?x?x?x?xf32>,
%o: index, %a: index, %b: index, %c: index) {
%fm42 = constant -42.0: f32
@@ -84,34 +80,20 @@ func @entry() {
}
}
// 1. Read 3D vector from 4D memref.
call @transfer_read_3d(%A, %c0, %c0, %c0, %c0)
: (memref<?x?x?x?xf32>, index, index, index, index) -> ()
// CHECK: ( ( ( 0, 0, -42 ), ( 2, 3, -42 ), ( 4, 6, -42 ), ( 6, 9, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, 33, -42 ), ( 24, 36, -42 ), ( 26, 39, -42 ), ( -42, -42, -42 ) ) )
// 2. Write 3D vector to 4D memref.
call @transfer_write_3d(%A, %c0, %c0, %c1, %c1)
: (memref<?x?x?x?xf32>, index, index, index, index) -> ()
// 3. Read memref to verify step 2.
call @transfer_read_3d(%A, %c0, %c0, %c0, %c0)
: (memref<?x?x?x?xf32>, index, index, index, index) -> ()
// CHECK: ( ( ( 0, 0, -42 ), ( 2, -1, -42 ), ( 4, -1, -42 ), ( 6, -1, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, -1, -42 ), ( 24, -1, -42 ), ( 26, -1, -42 ), ( -42, -42, -42 ) ) )
// 4. Read 3D vector from 4D memref and transpose vector.
call @transfer_read_3d_transposed(%A, %c0, %c0, %c0, %c0)
: (memref<?x?x?x?xf32>, index, index, index, index) -> ()
// CHECK: ( ( ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ) ), ( ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ) ), ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ) )
// 5. Read 1D vector from 4D memref and broadcast vector to 3D.
call @transfer_read_3d_broadcast(%A, %c0, %c0, %c0, %c0)
: (memref<?x?x?x?xf32>, index, index, index, index) -> ()
// CHECK: ( ( ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ) ), ( ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ) ) )
// 6. Read 1D vector from 4D memref with mask and broadcast vector to 3D.
call @transfer_read_3d_mask_broadcast(%A, %c0, %c0, %c0, %c0)
: (memref<?x?x?x?xf32>, index, index, index, index) -> ()
// CHECK: ( ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ), ( ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ) ) )
return
}
// CHECK: ( ( ( 0, 0, -42 ), ( 2, 3, -42 ), ( 4, 6, -42 ), ( 6, 9, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, 33, -42 ), ( 24, 36, -42 ), ( 26, 39, -42 ), ( -42, -42, -42 ) ) )
// CHECK: ( ( ( 0, 0, -42 ), ( 2, -1, -42 ), ( 4, -1, -42 ), ( 6, -1, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, -1, -42 ), ( 24, -1, -42 ), ( 26, -1, -42 ), ( -42, -42, -42 ) ) )
// CHECK: ( ( ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ) ), ( ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ) ), ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ) )
// CHECK: ( ( ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ) ), ( ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ) ) )
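The removed transfer_read_3d_mask_broadcast above exercises the case at the heart of the reverted fix: a vector<2xi1> mask guarding a vector<2x5x3xf32> result whose last two dimensions are broadcasts, so each mask bit has to be evaluated before the op is unpacked into lower-dimensional transfers. A sketch of the expected semantics, matching the removed CHECK line:

// mask = [0, 1], permutation_map = (d0, d1, d2, d3) -> (d1, 0, 0):
//   outer slice 0 is masked off -> every lane takes the padding -42;
//   outer slice 1 broadcasts its single loaded element across the 5x3
//   sub-vector (the value 20 in the removed CHECK output).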