[mlir] Fix masked vector transfer ops with broadcasts
Broadcast dimensions of a vector transfer op have no corresponding dimension in the mask vector. E.g., a 2-D TransferReadOp, where one dimension is a broadcast, can have a 1-D `mask` attribute.

This commit also adds a few additional transfer op integration tests for various combinations of broadcasts, masking, dim transposes, etc.

Differential Revision: https://reviews.llvm.org/D101745
commit c52cbe63e4
parent a0ca4c46ca
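As an illustration of the mask rule (a sketch, not part of this commit's diff; `%A`, `%i`, `%j`, `%pad`, and `%mask` are hypothetical values): in the read below, only `d0` appears as a result of the permutation map, so the mask covers just that one dimension.

    %f = vector.transfer_read %A[%i, %j], %pad, %mask
        {permutation_map = affine_map<(d0, d1) -> (d0, 0)>}
        : memref<?x?xf32>, vector<7x4xf32>
    // %mask : vector<7xi1> -- result dim 0 (size 7) maps to d0, so it is
    // masked; result dim 1 (size 4) is the broadcast constant 0 and has no
    // corresponding mask dimension.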
mlir/include/mlir/Interfaces/VectorInterfaces.h
@@ -17,6 +17,18 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 
+namespace mlir {
+namespace vector {
+namespace detail {
+
+/// Given the vector type and the permutation map of a vector transfer op,
+/// compute the expected mask type.
+VectorType transferMaskType(VectorType vecType, AffineMap map);
+
+} // namespace detail
+} // namespace vector
+} // namespace mlir
+
 /// Include the generated interface declarations.
 #include "mlir/Interfaces/VectorInterfaces.h.inc"
mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -114,6 +114,18 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       /*methodBody=*/"return $_op.permutation_map();"
       /*defaultImplementation=*/
     >,
+    InterfaceMethod<
+      /*desc=*/[{ Returns true if the specified dimension is a broadcast. }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isBroadcastDim",
+      /*args=*/(ins "unsigned":$idx),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto expr = $_op.permutation_map().getResult(idx);
+        return expr.template isa<AffineConstantExpr>() &&
+               expr.template dyn_cast<AffineConstantExpr>().getValue() == 0;
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{ Returns true if at least one of the dimensions in the
                   permutation map is a broadcast.}],
@@ -122,11 +134,11 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return llvm::any_of(
-            $_op.permutation_map().getResults(),
-            [](AffineExpr e) {
-              return e.isa<AffineConstantExpr>() &&
-                     e.dyn_cast<AffineConstantExpr>().getValue() == 0; });
+        for (unsigned i = 0; i < $_op.permutation_map().getNumResults(); ++i) {
+          if ($_op.isBroadcastDim(i))
+            return true;
+        }
+        return false;
       }]
     >,
     InterfaceMethod<
@@ -156,6 +168,19 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
         return $_op.vector().getType().template dyn_cast<VectorType>();
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Return the mask type if the op has a mask.",
+      /*retTy=*/"VectorType",
+      /*methodName=*/"getMaskType",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.mask()
+            ? mlir::vector::detail::transferMaskType(
+                  $_op.getVectorType(), $_op.permutation_map())
+            : VectorType();
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{ Return the number of dimensions that participate in the
                   permutation map.}],
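To read the new interface methods (an illustrative example, not part of the diff): a result of the permutation map counts as a broadcast exactly when it is the affine constant expression 0.

    // For affine_map<(d0, d1) -> (d1, 0)>:
    //   isBroadcastDim(0) == false  (result 0 is the dim expr d1)
    //   isBroadcastDim(1) == true   (result 1 is the constant 0)
    //   hasBroadcastDim() == true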
mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
@@ -79,8 +79,9 @@ static BufferAllocs allocBuffers(OpTy xferOp) {
 
   if (xferOp.mask()) {
     auto maskType = MemRefType::get({}, xferOp.mask().getType());
-    result.maskBuffer = memref_alloca(maskType).value;
-    memref_store(xferOp.mask(), result.maskBuffer);
+    Value maskBuffer = memref_alloca(maskType);
+    memref_store(xferOp.mask(), maskBuffer);
+    result.maskBuffer = memref_load(maskBuffer);
   }
 
   return result;
@@ -95,7 +96,7 @@ static Optional<int64_t> unpackedDim(OpTy xferOp) {
   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
     return expr.getPosition();
   }
-  assert(map.getResult(0).template isa<AffineConstantExpr>() &&
+  assert(xferOp.isBroadcastDim(0) &&
          "Expected AffineDimExpr or AffineConstantExpr");
   return None;
 }
@@ -143,14 +144,19 @@ static void maybeYieldValue(
 }
 
 /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
-/// is set to true. Does not return a Value if the transfer op is not 1D or
-/// if the transfer op does not have a mask.
+/// is set to true. No such check is generated under following circumstances:
+/// * xferOp does not have a mask.
+/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
+///   computed and attached to the new transfer op in the pattern.)
+/// * The to-be-unpacked dim of xferOp is a broadcast.
 template <typename OpTy>
-static Value maybeGenerateMaskCheck(OpBuilder &builder, OpTy xferOp, Value iv) {
-  if (xferOp.getVectorType().getRank() != 1)
-    return Value();
+static Value generateMaskCheck(OpBuilder &builder, OpTy xferOp, Value iv) {
   if (!xferOp.mask())
     return Value();
+  if (xferOp.getMaskType().getRank() != 1)
+    return Value();
+  if (xferOp.isBroadcastDim(0))
+    return Value();
 
   auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
   return vector_extract_element(xferOp.mask(), ivI32).value;
@@ -200,7 +206,7 @@ static Value generateInBoundsCheck(
   }
 
   // Condition check 2: Masked in?
-  if (auto maskCond = maybeGenerateMaskCheck(builder, xferOp, iv)) {
+  if (auto maskCond = generateMaskCheck(builder, xferOp, iv)) {
     if (cond) {
       cond = builder.create<AndOp>(xferOp.getLoc(), cond, maskCond);
     } else {
@@ -488,8 +494,8 @@ struct PrepareTransferReadConversion
     auto *newXfer = rewriter.clone(*xferOp.getOperation());
     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
     if (xferOp.mask()) {
-      auto loadedMask = memref_load(buffers.maskBuffer);
-      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(loadedMask);
+      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
+          buffers.maskBuffer);
     }
 
     memref_store(newXfer->getResult(0), buffers.dataBuffer);
@@ -541,9 +547,8 @@ struct PrepareTransferWriteConversion
         });
 
     if (xferOp.mask()) {
-      auto loadedMask = memref_load(buffers.maskBuffer);
       rewriter.updateRootInPlace(
-          xferOp, [&]() { xferOp.maskMutable().assign(loadedMask); });
+          xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
    }
 
    return success();
@@ -590,8 +595,17 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
       auto maskBuffer = getMaskBuffer(xferOp);
       auto maskBufferType =
           maskBuffer.getType().template dyn_cast<MemRefType>();
-      auto castedMaskType = unpackOneDim(maskBufferType);
-      castedMaskBuffer = vector_type_cast(castedMaskType, maskBuffer);
+      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
+        // Do not unpack a dimension of the mask, if:
+        // * To-be-unpacked transfer op dimension is a broadcast.
+        // * Mask is 1D, i.e., the mask cannot be further unpacked.
+        //   (That means that all remaining dimensions of the transfer op must
+        //   be broadcasted.)
+        castedMaskBuffer = maskBuffer;
+      } else {
+        auto castedMaskType = unpackOneDim(maskBufferType);
+        castedMaskBuffer = vector_type_cast(castedMaskType, maskBuffer);
+      }
     }
 
     // Loop bounds and step.
@@ -616,13 +630,20 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
           Strategy<OpTy>::rewriteOp(b, xferOp, castedDataBuffer, iv);
 
           // If old transfer op has a mask: Set mask on new transfer op.
-          if (xferOp.mask()) {
+          // Special case: If the mask of the old transfer op is 1D and the
+          // unpacked dim is not a broadcast, no mask is needed
+          // on the new transfer op.
+          if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
+                                xferOp.getMaskType().getRank() > 1)) {
            OpBuilder::InsertionGuard guard(b);
            b.setInsertionPoint(newXfer); // Insert load before newXfer.
 
            SmallVector<Value, 8> loadIndices;
            Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
-            loadIndices.push_back(iv);
+            // In case of broadcast: Use same indices to load from memref as
+            // before.
+            if (!xferOp.isBroadcastDim(0))
+              loadIndices.push_back(iv);
 
            auto mask = memref_load(castedMaskBuffer, loadIndices);
            rewriter.updateRootInPlace(
@@ -661,7 +682,7 @@ static Optional<int64_t> get1dMemrefIndices(
     return dim;
   }
 
-  assert(map.getResult(0).template isa<AffineConstantExpr>() &&
+  assert(xferOp.isBroadcastDim(0) &&
          "Expected AffineDimExpr or AffineConstantExpr");
   return None;
 }
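A sketch of what the adjusted lowering does for a masked 2-D read (hypothetical IR, heavily simplified): unpacking peels off the major dimension, and each resulting 1-D transfer receives the matching subvector of the mask, loaded from the mask buffer; if the unpacked dimension is a broadcast, the whole mask is passed through unchanged instead.

    // Before unpacking: %mask : vector<4x9xi1>
    %v = vector.transfer_read %A[%i, %j], %pad, %mask
        : memref<?x?xf32>, vector<4x9xf32>
    // After one unpacking step (conceptually), for each induction value %iv:
    //   read a vector<9xf32> slice, masked by row %iv of %mask, a vector<9xi1>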
mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2306,6 +2306,7 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
 
 static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
                                       VectorType vectorType,
+                                      VectorType maskType,
                                       AffineMap permutationMap,
                                       ArrayAttr inBounds) {
   if (op->hasAttr("masked")) {
@@ -2341,6 +2342,9 @@ static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
     if (permutationMap.getNumResults() != rankOffset)
       return op->emitOpError("requires a permutation_map with result dims of "
                              "the same rank as the vector type");
+
+    if (maskType)
+      return op->emitOpError("does not support masks with vector element type");
   } else {
     // Memref or tensor has scalar element type.
     unsigned resultVecSize =
@@ -2355,6 +2359,13 @@ static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
     if (permutationMap.getNumResults() != vectorType.getRank())
       return op->emitOpError("requires a permutation_map with result dims of "
                              "the same rank as the vector type");
+
+    VectorType expectedMaskType =
+        vector::detail::transferMaskType(vectorType, permutationMap);
+    if (maskType && expectedMaskType != maskType)
+      return op->emitOpError("expects mask type consistent with permutation "
+                             "map: ")
+             << maskType;
   }
 
   if (permutationMap.getNumSymbols() != 0)
@@ -2491,10 +2502,11 @@ static ParseResult parseTransferReadOp(OpAsmParser &parser,
   if (!vectorType)
     return parser.emitError(typesLoc, "requires vector type");
   auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
-  auto attr = result.attributes.get(permutationAttrName);
-  if (!attr) {
+  Attribute mapAttr = result.attributes.get(permutationAttrName);
+  if (!mapAttr) {
     auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
-    result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
+    mapAttr = AffineMapAttr::get(permMap);
+    result.attributes.set(permutationAttrName, mapAttr);
   }
   if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
       parser.resolveOperands(indexInfo, indexType, result.operands) ||
@@ -2502,7 +2514,13 @@ static ParseResult parseTransferReadOp(OpAsmParser &parser,
                             result.operands))
     return failure();
   if (hasMask.succeeded()) {
-    auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
+    if (shapedType.getElementType().dyn_cast<VectorType>())
+      return parser.emitError(
+          maskInfo.location, "does not support masks with vector element type");
+    auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue();
+    // Instead of adding the mask type as an op type, compute it based on the
+    // vector type and the permutation map (to keep the type signature small).
+    auto maskType = mlir::vector::detail::transferMaskType(vectorType, map);
     if (parser.resolveOperand(maskInfo, maskType, result.operands))
       return failure();
   }
@@ -2517,6 +2535,7 @@ static LogicalResult verify(TransferReadOp op) {
   // Consistency of elemental types in source and vector.
   ShapedType shapedType = op.getShapedType();
   VectorType vectorType = op.getVectorType();
+  VectorType maskType = op.getMaskType();
   auto paddingType = op.padding().getType();
   auto permutationMap = op.permutation_map();
   auto sourceElementType = shapedType.getElementType();
@@ -2525,7 +2544,7 @@ static LogicalResult verify(TransferReadOp op) {
     return op.emitOpError("requires ") << shapedType.getRank() << " indices";
 
   if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
-                              permutationMap,
+                              maskType, permutationMap,
                               op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
     return failure();
 
@@ -2768,6 +2787,9 @@ static ParseResult parseTransferWriteOp(OpAsmParser &parser,
       parser.resolveOperands(indexInfo, indexType, result.operands))
     return failure();
   if (hasMask.succeeded()) {
+    if (shapedType.getElementType().dyn_cast<VectorType>())
+      return parser.emitError(
+          maskInfo.location, "does not support masks with vector element type");
     auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
     if (parser.resolveOperand(maskInfo, maskType, result.operands))
       return failure();
@@ -2793,6 +2815,7 @@ static LogicalResult verify(TransferWriteOp op) {
   // Consistency of elemental types in shape and vector.
   ShapedType shapedType = op.getShapedType();
   VectorType vectorType = op.getVectorType();
+  VectorType maskType = op.getMaskType();
   auto permutationMap = op.permutation_map();
 
   if (llvm::size(op.indices()) != shapedType.getRank())
@@ -2804,7 +2827,7 @@ static LogicalResult verify(TransferWriteOp op) {
     return op.emitOpError("should not have broadcast dimensions");
 
   if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
-                              permutationMap,
+                              maskType, permutationMap,
                               op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
     return failure();
 
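The parser change means the mask operand's type is never spelled in the assembly; it is recomputed from the vector type and the permutation map, and the verifier rejects operands that disagree ("expects mask type consistent with permutation map"). For example (hypothetical values):

    %f = vector.transfer_read %A[%i, %j], %pad, %mask
        {permutation_map = affine_map<(d0, d1) -> (d1, 0)>}
        : memref<?x?xf32>, vector<4x9xf32>
    // The parser resolves %mask as vector<4xi1>: result dim 0 follows d1,
    // while result dim 1 is a broadcast and contributes no mask dimension.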
mlir/lib/Interfaces/VectorInterfaces.cpp
@@ -10,6 +10,19 @@
 
 using namespace mlir;
 
+VectorType mlir::vector::detail::transferMaskType(VectorType vecType,
+                                                  AffineMap map) {
+  auto i1Type = IntegerType::get(map.getContext(), 1);
+  SmallVector<int64_t, 8> shape;
+  for (int64_t i = 0; i < vecType.getRank(); ++i) {
+    // Only result dims have a corresponding dim in the mask.
+    if (map.getResult(i).isa<AffineDimExpr>()) {
+      shape.push_back(vecType.getDimSize(i));
+    }
+  }
+  return shape.empty() ? VectorType() : VectorType::get(shape, i1Type);
+}
+
 //===----------------------------------------------------------------------===//
 // VectorUnroll Interfaces
 //===----------------------------------------------------------------------===//
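A worked example of `transferMaskType` (types are illustrative): dimensions whose map result is an `AffineDimExpr` keep their size, broadcast dimensions are dropped, and an all-broadcast map yields a null `VectorType`, i.e. no mask is expected.

    // vecType = vector<5x4x3xf32>, map = affine_map<(d0, d1, d2) -> (d2, 0, d0)>
    //   result 0: d2 -> keep dim size 5
    //   result 1: 0  -> broadcast, dropped
    //   result 2: d0 -> keep dim size 3
    // => mask type: vector<5x3xi1>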
mlir/test/Dialect/Vector/invalid.mlir
@@ -339,6 +339,18 @@ func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
 
 // -----
 
+func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
+  %c1 = constant 1 : i1
+  %c3 = constant 3 : index
+  %cst = constant 3.0 : f32
+  // expected-note@+1 {{prior use here}}
+  %mask = splat %c1 : vector<3x8x7xi1>
+  // expected-error@+1 {{expects different type than prior uses: 'vector<3x7xi1>' vs 'vector<3x8x7xi1>'}}
+  %0 = vector.transfer_read %arg0[%c3, %c3, %c3], %cst, %mask {permutation_map = affine_map<(d0, d1, d2)->(d0, 0, d2)>} : memref<?x?x?xf32>, vector<3x8x7xf32>
+}
+
+// -----
+
 func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
   %c3 = constant 3 : index
   %f0 = constant 0.0 : f32
@@ -369,6 +381,17 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
 
 // -----
 
+func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
+  %c3 = constant 3 : index
+  %f0 = constant 0.0 : f32
+  %c1 = constant 1 : i1
+  %vf0 = splat %f0 : vector<2x3xf32>
+  %mask = splat %c1 : vector<2x3xi1>
+  // expected-error@+1 {{does not support masks with vector element type}}
+  %0 = vector.transfer_read %arg0[%c3, %c3], %vf0, %mask {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
+}
+
+// -----
+
 func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
   %c3 = constant 3 : index
   %cst = constant 3.0 : f32
mlir/test/Dialect/Vector/ops.mlir
@@ -4,19 +4,21 @@
 func @vector_transfer_ops(%arg0: memref<?x?xf32>,
                           %arg1 : memref<?x?xvector<4x3xf32>>,
                           %arg2 : memref<?x?xvector<4x3xi32>>,
-                          %arg3 : memref<?x?xvector<4x3xindex>>) {
+                          %arg3 : memref<?x?xvector<4x3xindex>>,
+                          %arg4 : memref<?x?x?xf32>) {
   // CHECK: %[[C3:.*]] = constant 3 : index
   %c3 = constant 3 : index
   %cst = constant 3.0 : f32
   %f0 = constant 0.0 : f32
   %c0 = constant 0 : i32
   %i0 = constant 0 : index
+  %i1 = constant 1 : i1
 
   %vf0 = splat %f0 : vector<4x3xf32>
   %v0 = splat %c0 : vector<4x3xi32>
   %vi0 = splat %i0 : vector<4x3xindex>
   %m = constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
+  %m2 = splat %i1 : vector<5x4xi1>
 
   //
   // CHECK: vector.transfer_read
   %0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d0)>} : memref<?x?xf32>, vector<128xf32>
@@ -36,6 +38,8 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
   %7 = vector.transfer_read %arg3[%c3, %c3], %vi0 : memref<?x?xvector<4x3xindex>>, vector<5x48xi8>
   // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}}, %{{.*}} : memref<?x?xf32>, vector<5xf32>
   %8 = vector.transfer_read %arg0[%c3, %c3], %f0, %m : memref<?x?xf32>, vector<5xf32>
+  // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]], %[[C3]]], %{{.*}}, %{{.*}} : memref<?x?x?xf32>, vector<5x4x8xf32>
+  %9 = vector.transfer_read %arg4[%c3, %c3, %c3], %f0, %m2 {permutation_map = affine_map<(d0, d1, d2)->(d1, d0, 0)>} : memref<?x?x?xf32>, vector<5x4x8xf32>
 
   // CHECK: vector.transfer_write
   vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref<?x?xf32>
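For the new `%9` read in this test: the result type is vector<5x4x8xf32> under `(d0, d1, d2) -> (d1, d0, 0)`, so the last result dim is a broadcast and the mask `%m2` spans only the first two dims.

    // vector<5x4x8xf32> with affine_map<(d0, d1, d2) -> (d1, d0, 0)>
    // => mask %m2 : vector<5x4xi1> (the size-8 broadcast dim is unmasked)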
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
@@ -5,6 +5,14 @@
 
 // Test for special cases of 1D vector transfer ops.
 
+memref.global "private" @gv : memref<5x6xf32> =
+    dense<[[0. , 1. , 2. , 3. , 4. , 5. ],
+           [10., 11., 12., 13., 14., 15.],
+           [20., 21., 22., 23., 24., 25.],
+           [30., 31., 32., 33., 34., 35.],
+           [40., 41., 42., 43., 44., 45.]]>
+
 // Non-contiguous, strided load.
 func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
   %f = vector.transfer_read %A[%base1, %base2], %fm42
@@ -14,6 +22,7 @@ func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   return
 }
 
+// Broadcast.
 func @transfer_read_1d_broadcast(
     %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
@@ -24,6 +33,7 @@ func @transfer_read_1d_broadcast(
   return
 }
 
+// Non-contiguous, strided load.
 func @transfer_read_1d_in_bounds(
     %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
@@ -34,6 +44,7 @@ func @transfer_read_1d_in_bounds(
   return
 }
 
+// Non-contiguous, strided load.
 func @transfer_read_1d_mask(
     %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
@@ -45,6 +56,7 @@ func @transfer_read_1d_mask(
   return
 }
 
+// Non-contiguous, strided load.
 func @transfer_read_1d_mask_in_bounds(
     %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
@@ -56,6 +68,7 @@ func @transfer_read_1d_mask_in_bounds(
   return
 }
 
+// Non-contiguous, strided store.
 func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fn1 = constant -1.0 : f32
   %vf0 = splat %fn1 : vector<7xf32>
@@ -65,57 +78,68 @@ func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   return
 }
 
+// Non-contiguous, strided store.
+func @transfer_write_1d_mask(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+  %fn1 = constant -2.0 : f32
+  %vf0 = splat %fn1 : vector<7xf32>
+  %mask = constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>
+  vector.transfer_write %vf0, %A[%base1, %base2], %mask
+      {permutation_map = affine_map<(d0, d1) -> (d0)>}
+      : vector<7xf32>, memref<?x?xf32>
+  return
+}
+
 func @entry() {
   %c0 = constant 0: index
   %c1 = constant 1: index
   %c2 = constant 2: index
   %c3 = constant 3: index
-  %f10 = constant 10.0: f32
-  // work with dims of 4, not of 3
-  %first = constant 5: index
-  %second = constant 6: index
-  %A = memref.alloc(%first, %second) : memref<?x?xf32>
-  scf.for %i = %c0 to %first step %c1 {
-    %i32 = index_cast %i : index to i32
-    %fi = sitofp %i32 : i32 to f32
-    %fi10 = mulf %fi, %f10 : f32
-    scf.for %j = %c0 to %second step %c1 {
-      %j32 = index_cast %j : index to i32
-      %fj = sitofp %j32 : i32 to f32
-      %fres = addf %fi10, %fj : f32
-      memref.store %fres, %A[%i, %j] : memref<?x?xf32>
-    }
-  }
+  %0 = memref.get_global @gv : memref<5x6xf32>
+  %A = memref.cast %0 : memref<5x6xf32> to memref<?x?xf32>
 
-  // Read from 2D memref on first dimension. Cannot be lowered to an LLVM
-  // vector load. Instead, generates scalar loads.
+  // 1. Read from 2D memref on first dimension. Cannot be lowered to an LLVM
+  //    vector load. Instead, generates scalar loads.
   call @transfer_read_1d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
-  // Write to 2D memref on first dimension. Cannot be lowered to an LLVM
-  // vector store. Instead, generates scalar stores.
+  // CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )
+
+  // 2. Write to 2D memref on first dimension. Cannot be lowered to an LLVM
+  //    vector store. Instead, generates scalar stores.
   call @transfer_write_1d(%A, %c3, %c2) : (memref<?x?xf32>, index, index) -> ()
-  // (Same as above.)
+
+  // 3. (Same as 1. To check if 2 works correctly.)
   call @transfer_read_1d(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
-  // Read a scalar from a 2D memref and broadcast the value to a 1D vector.
-  // Generates a loop with vector.insertelement.
+  // CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )
+
+  // 4. Read a scalar from a 2D memref and broadcast the value to a 1D vector.
+  //    Generates a loop with vector.insertelement.
   call @transfer_read_1d_broadcast(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
-  // Read from 2D memref on first dimension. Accesses are in-bounds, so no
-  // if-check is generated inside the generated loop.
+  // CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 )
+
+  // 5. Read from 2D memref on first dimension. Accesses are in-bounds, so no
+  //    if-check is generated inside the generated loop.
   call @transfer_read_1d_in_bounds(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
-  // Optional mask attribute is specified and, in addition, there may be
-  // out-of-bounds accesses.
+  // CHECK: ( 12, 22, -1 )
+
+  // 6. Optional mask attribute is specified and, in addition, there may be
+  //    out-of-bounds accesses.
   call @transfer_read_1d_mask(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
-  // Same as above, but accesses are in-bounds.
+  // CHECK: ( 12, -42, -1, -42, -42, -42, -42, -42, -42 )
+
+  // 7. Same as 6, but accesses are in-bounds.
   call @transfer_read_1d_mask_in_bounds(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( 12, -42, -1 )
+
+  // 8. Write to 2D memref on first dimension with a mask.
+  call @transfer_write_1d_mask(%A, %c1, %c0)
+      : (memref<?x?xf32>, index, index) -> ()
+
+  // 9. (Same as 1. To check if 8 works correctly.)
+  call @transfer_read_1d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( 0, -2, 20, -2, 40, -42, -42, -42, -42 )
+
   return
 }
-
-// CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )
-// CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )
-// CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 )
-// CHECK: ( 12, 22, -1 )
-// CHECK: ( 12, -42, -1, -42, -42, -42, -42, -42, -42 )
-// CHECK: ( 12, -42, -1 )
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
@@ -3,6 +3,11 @@
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
+memref.global "private" @gv : memref<3x4xf32> = dense<[[0. , 1. , 2. , 3. ],
+                                                       [10., 11., 12., 13.],
+                                                       [20., 21., 22., 23.]]>
+
+// Vector load.
 func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fm42 = constant -42.0: f32
   %f = vector.transfer_read %A[%base1, %base2], %fm42
@@ -12,6 +17,7 @@ func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   return
 }
 
+// Vector load with mask.
 func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fm42 = constant -42.0: f32
   %mask = constant dense<[[1, 0, 1, 0, 1, 1, 1, 0, 1],
@@ -25,6 +31,47 @@ func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index)
   return
 }
 
+// Vector load with mask + transpose.
+func @transfer_read_2d_mask_transposed(
+    %A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[[1, 0, 1, 0], [0, 0, 1, 0],
+                          [1, 1, 1, 1], [0, 1, 1, 0],
+                          [1, 1, 1, 1], [1, 1, 1, 1],
+                          [1, 1, 1, 1], [0, 0, 0, 0],
+                          [1, 1, 1, 1]]> : vector<9x4xi1>
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
+      memref<?x?xf32>, vector<9x4xf32>
+  vector.print %f: vector<9x4xf32>
+  return
+}
+
+// Vector load with mask + broadcast.
+func @transfer_read_2d_mask_broadcast(
+    %A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[1, 0, 1, 0, 1, 1, 1, 0, 1]> : vector<9xi1>
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+      {permutation_map = affine_map<(d0, d1) -> (0, d1)>} :
+      memref<?x?xf32>, vector<4x9xf32>
+  vector.print %f: vector<4x9xf32>
+  return
+}
+
+// Transpose + vector load with mask + broadcast.
+func @transfer_read_2d_mask_transpose_broadcast_last_dim(
+    %A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[1, 0, 1, 1]> : vector<4xi1>
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+      {permutation_map = affine_map<(d0, d1) -> (d1, 0)>} :
+      memref<?x?xf32>, vector<4x9xf32>
+  vector.print %f: vector<4x9xf32>
+  return
+}
+
+// Load + transpose.
 func @transfer_read_2d_transposed(
     %A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fm42 = constant -42.0: f32
@@ -35,6 +82,7 @@ func @transfer_read_2d_transposed(
   return
 }
 
+// Load 1D + broadcast to 2D.
 func @transfer_read_2d_broadcast(
     %A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fm42 = constant -42.0: f32
@@ -45,6 +93,7 @@ func @transfer_read_2d_broadcast(
   return
 }
 
+// Vector store.
 func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fn1 = constant -1.0 : f32
   %vf0 = splat %fn1 : vector<1x4xf32>
@@ -54,55 +103,79 @@ func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   return
 }
 
+// Vector store with mask.
+func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fn1 = constant -2.0 : f32
+  %mask = constant dense<[[1, 0, 1, 0]]> : vector<1x4xi1>
+  %vf0 = splat %fn1 : vector<1x4xf32>
+  vector.transfer_write %vf0, %A[%base1, %base2], %mask
+      {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} :
+      vector<1x4xf32>, memref<?x?xf32>
+  return
+}
+
 func @entry() {
   %c0 = constant 0: index
   %c1 = constant 1: index
   %c2 = constant 2: index
   %c3 = constant 3: index
-  %c4 = constant 4: index
-  %c5 = constant 5: index
-  %c8 = constant 5: index
-  %f10 = constant 10.0: f32
-  // work with dims of 4, not of 3
-  %first = constant 3: index
-  %second = constant 4: index
-  %A = memref.alloc(%first, %second) : memref<?x?xf32>
-  scf.for %i = %c0 to %first step %c1 {
-    %i32 = index_cast %i : index to i32
-    %fi = sitofp %i32 : i32 to f32
-    %fi10 = mulf %fi, %f10 : f32
-    scf.for %j = %c0 to %second step %c1 {
-      %j32 = index_cast %j : index to i32
-      %fj = sitofp %j32 : i32 to f32
-      %fres = addf %fi10, %fj : f32
-      memref.store %fres, %A[%i, %j] : memref<?x?xf32>
-    }
-  }
-  // On input, memory contains [[ 0, 1, 2, ...], [10, 11, 12, ...], ...]
-  // Read shifted by 2 and pad with -42:
+  %0 = memref.get_global @gv : memref<3x4xf32>
+  %A = memref.cast %0 : memref<3x4xf32> to memref<?x?xf32>
+
+  // 1. Read 2D vector from 2D memref.
   call @transfer_read_2d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
-  // Same as above, but transposed
+  // CHECK: ( ( 12, 13, -42, -42, -42, -42, -42, -42, -42 ), ( 22, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  // 2. Read 2D vector from 2D memref at specified location and transpose the
+  //    result.
   call @transfer_read_2d_transposed(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
-  // Write into memory shifted by 3
-  call @transfer_write_2d(%A, %c3, %c1) : (memref<?x?xf32>, index, index) -> ()
-  // Read shifted by 0 and pad with -42:
-  call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
-  // Same as above, but apply a mask
+  // CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  // 3. Read 2D vector from 2D memref with a 2D mask. In addition, some
+  //    accesses are out-of-bounds.
   call @transfer_read_2d_mask(%A, %c0, %c0)
      : (memref<?x?xf32>, index, index) -> ()
-  // Same as above, but without mask and transposed
-  call @transfer_read_2d_transposed(%A, %c0, %c0)
+  // CHECK: ( ( 0, -42, 2, -42, -42, -42, -42, -42, -42 ), ( -42, -42, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  // 4. Same as 3, but transpose the result.
+  call @transfer_read_2d_mask_transposed(%A, %c0, %c0)
      : (memref<?x?xf32>, index, index) -> ()
-  // Second vector dimension is a broadcast
+  // CHECK: ( ( 0, -42, 20, -42 ), ( -42, -42, 21, -42 ), ( 2, 12, 22, -42 ), ( -42, 13, 23, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ) )
+
+  // 5. Read 1D vector from 2D memref at specified location and broadcast the
+  //    result to 2D.
   call @transfer_read_2d_broadcast(%A, %c1, %c2)
      : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ), ( 13, 13, 13, 13, 13, 13, 13, 13, 13 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  // 6. Read 1D vector from 2D memref at specified location with mask and
+  //    broadcast the result to 2D.
+  call @transfer_read_2d_mask_broadcast(%A, %c2, %c1)
+      : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ), ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ), ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ), ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ) )
+
+  // 7. Read 1D vector from 2D memref (second dimension) at specified location
+  //    with mask and broadcast the result to 2D. In this test case, mask
+  //    elements must be evaluated before lowering to an (N>1)-D transfer.
+  call @transfer_read_2d_mask_transpose_broadcast_last_dim(%A, %c0, %c1)
+      : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( ( 1, 1, 1, 1, 1, 1, 1, 1, 1 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( 3, 3, 3, 3, 3, 3, 3, 3, 3 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  // 8. Write 2D vector into 2D memref at specified location.
+  call @transfer_write_2d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
+
+  // 9. Read memref to verify step 8.
+  call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, -1, -1, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  // 10. Write 2D vector into 2D memref at specified location with mask.
+  call @transfer_write_2d_mask(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
+
+  // 11. Read memref to verify step 10.
+  call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( ( 0, 1, -2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, -1, -1, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
   return
 }
-
-// CHECK: ( ( 12, 13, -42, -42, -42, -42, -42, -42, -42 ), ( 22, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
-// CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
-// CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
-// CHECK: ( ( 0, -42, 2, -42, -42, -42, -42, -42, -42 ), ( -42, -42, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
-// CHECK: ( ( 0, 10, 20, -42, -42, -42, -42, -42, -42 ), ( 1, 11, 21, -42, -42, -42, -42, -42, -42 ), ( 2, 12, 22, -42, -42, -42, -42, -42, -42 ), ( 3, 13, 23, -42, -42, -42, -42, -42, -42 ) )
-// CHECK: ( ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ), ( 13, 13, 13, 13, 13, 13, 13, 13, 13 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
@@ -1,15 +1,8 @@
-// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
-// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
-// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
-
 // RUN: mlir-opt %s -test-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
 // Test case is based on test-transfer-read-2d.
 
 func @transfer_read_3d(%A : memref<?x?x?x?xf32>,
                        %o: index, %a: index, %b: index, %c: index) {
   %fm42 = constant -42.0: f32
@@ -29,6 +22,17 @@ func @transfer_read_3d_broadcast(%A : memref<?x?x?x?xf32>,
   return
 }
 
+func @transfer_read_3d_mask_broadcast(
+    %A : memref<?x?x?x?xf32>, %o: index, %a: index, %b: index, %c: index) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[0, 1]> : vector<2xi1>
+  %f = vector.transfer_read %A[%o, %a, %b, %c], %fm42, %mask
+      {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, 0, 0)>}
+      : memref<?x?x?x?xf32>, vector<2x5x3xf32>
+  vector.print %f: vector<2x5x3xf32>
+  return
+}
+
 func @transfer_read_3d_transposed(%A : memref<?x?x?x?xf32>,
                                   %o: index, %a: index, %b: index, %c: index) {
   %fm42 = constant -42.0: f32
@@ -80,20 +84,34 @@ func @entry() {
     }
   }
 
+  // 1. Read 3D vector from 4D memref.
   call @transfer_read_3d(%A, %c0, %c0, %c0, %c0)
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  // CHECK: ( ( ( 0, 0, -42 ), ( 2, 3, -42 ), ( 4, 6, -42 ), ( 6, 9, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, 33, -42 ), ( 24, 36, -42 ), ( 26, 39, -42 ), ( -42, -42, -42 ) ) )
+
+  // 2. Write 3D vector to 4D memref.
   call @transfer_write_3d(%A, %c0, %c0, %c1, %c1)
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+
+  // 3. Read memref to verify step 2.
   call @transfer_read_3d(%A, %c0, %c0, %c0, %c0)
      : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  // CHECK: ( ( ( 0, 0, -42 ), ( 2, -1, -42 ), ( 4, -1, -42 ), ( 6, -1, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, -1, -42 ), ( 24, -1, -42 ), ( 26, -1, -42 ), ( -42, -42, -42 ) ) )
+
+  // 4. Read 3D vector from 4D memref and transpose vector.
   call @transfer_read_3d_transposed(%A, %c0, %c0, %c0, %c0)
      : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  // CHECK: ( ( ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ) ), ( ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ) ), ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ) )
+
+  // 5. Read 1D vector from 4D memref and broadcast vector to 3D.
   call @transfer_read_3d_broadcast(%A, %c0, %c0, %c0, %c0)
      : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  // CHECK: ( ( ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ) ), ( ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ) ) )
+
+  // 6. Read 1D vector from 4D memref with mask and broadcast vector to 3D.
+  call @transfer_read_3d_mask_broadcast(%A, %c0, %c0, %c0, %c0)
+      : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  // CHECK: ( ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ), ( ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ) ) )
+
   return
 }
-
-// CHECK: ( ( ( 0, 0, -42 ), ( 2, 3, -42 ), ( 4, 6, -42 ), ( 6, 9, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, 33, -42 ), ( 24, 36, -42 ), ( 26, 39, -42 ), ( -42, -42, -42 ) ) )
-// CHECK: ( ( ( 0, 0, -42 ), ( 2, -1, -42 ), ( 4, -1, -42 ), ( 6, -1, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, -1, -42 ), ( 24, -1, -42 ), ( 26, -1, -42 ), ( -42, -42, -42 ) ) )
-// CHECK: ( ( ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ) ), ( ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ) ), ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ) )
-// CHECK: ( ( ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ) ), ( ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ) ) )