[mlir] Support masked 1D vector transfer ops in ProgressiveVectorToSCF

Support for masked N-D vector transfer ops will be added in a subsequent commit.

Differential Revision: https://reviews.llvm.org/D101132
Author: Matthias Springer
Date: 2021-04-23 18:04:58 +09:00
commit 545f98efc7 (parent c2297544c0)
2 changed files with 171 additions and 106 deletions
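For illustration, a masked 1D transfer read that cannot be lowered to a single
LLVM vector load, such as

  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
      {permutation_map = affine_map<(d0, d1) -> (d0)>}
      : memref<?x?xf32>, vector<9xf32>

is now lowered to a scalar loop in which each iteration is guarded by the
conjunction of the in-bounds condition and the iv-th mask bit. The sketch
below is hand-written to show the shape of the generated code; constants and
SSA names are illustrative, not verbatim pass output:

  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c9 = constant 9 : index
  %d = memref.dim %A, %c0 : memref<?x?xf32>
  %init = splat %fm42 : vector<9xf32>
  %f = scf.for %iv = %c0 to %c9 step %c1
      iter_args(%vec = %init) -> (vector<9xf32>) {
    %idx = addi %base1, %iv : index
    %bounds_ok = cmpi sgt, %d, %idx : index
    %iv_i32 = index_cast %iv : index to i32
    %mask_bit = vector.extractelement %mask[%iv_i32 : i32] : vector<9xi1>
    %cond = and %bounds_ok, %mask_bit : i1
    %next = scf.if %cond -> (vector<9xf32>) {
      %val = memref.load %A[%idx, %base2] : memref<?x?xf32>
      %v = vector.insertelement %val, %vec[%iv_i32 : i32] : vector<9xf32>
      scf.yield %v : vector<9xf32>
    } else {
      // Out-of-bounds or masked-off lane: keep the padding value.
      scf.yield %vec : vector<9xf32>
    }
    scf.yield %next : vector<9xf32>
  }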


@@ -74,9 +74,9 @@ static Value setAllocAtFunctionEntry(MemRefType type, Operation *op) {
 template <typename OpTy>
 static Optional<int64_t> unpackedDim(OpTy xferOp) {
   auto map = xferOp.permutation_map();
-  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>())
+  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
     return expr.getPosition();
+  }
   assert(map.getResult(0).template isa<AffineConstantExpr>() &&
          "Expected AffineDimExpr or AffineConstantExpr");
   return None;
@@ -88,7 +88,8 @@ static Optional<int64_t> unpackedDim(OpTy xferOp) {
 template <typename OpTy>
 static AffineMap unpackedPermutationMap(OpTy xferOp, OpBuilder &builder) {
   auto map = xferOp.permutation_map();
-  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
+  return AffineMap::get(
+      map.getNumDims(), 0, map.getResults().drop_front(),
       builder.getContext());
 }
@@ -114,8 +115,8 @@ static void getXferIndices(OpTy xferOp, Value iv,
   }
 }
 
-static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
-                            Value value) {
+static void maybeYieldValue(
+    bool hasRetVal, OpBuilder builder, Location loc, Value value) {
   if (hasRetVal) {
     builder.create<scf::YieldOp>(loc, value);
   } else {
@@ -123,6 +124,20 @@ static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
   }
 }
 
+/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
+/// is set. Does not return a Value if the transfer op is not 1D or if the
+/// transfer op does not have a mask.
+template <typename OpTy>
+static Value maybeGenerateMaskCheck(OpBuilder &builder, OpTy xferOp, Value iv) {
+  if (xferOp.getVectorType().getRank() != 1)
+    return Value();
+  if (!xferOp.mask())
+    return Value();
+
+  auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
+  return vector_extract_element(xferOp.mask(), ivI32).value;
+}
+
 /// Helper function for TransferOpConversion and TransferOp1dConversion.
 /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
 /// specified dimension `dim` with the loop iteration variable `iv`.
@@ -140,6 +155,10 @@ static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
 /// (out-of-bounds case)
 /// }
 /// ```
+///
+/// If the transfer is 1D and has a mask, this function generates a more
+/// complex check that also accounts for potentially masked-out elements.
+///
 /// This function variant returns the value returned by `inBoundsCase` or
 /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
 /// `resultTypes`.
@@ -150,24 +169,36 @@ static Value generateInBoundsCheck(
     function_ref<Value(OpBuilder &, Location)> inBoundsCase,
     function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
   bool hasRetVal = !resultTypes.empty();
+  Value cond; // Condition to be built...
+
+  // Condition check 1: Access in-bounds?
   bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
   if (!xferOp.isDimInBounds(0) && !isBroadcast) {
     auto memrefDim =
         memref_dim(xferOp.source(), std_constant_index(dim.getValue()));
     using edsc::op::operator+;
     auto memrefIdx = xferOp.indices()[dim.getValue()] + iv;
-    auto cond = std_cmpi_sgt(memrefDim.value, memrefIdx);
+    cond = std_cmpi_sgt(memrefDim.value, memrefIdx);
+  }
+
+  // Condition check 2: Masked in?
+  if (auto maskCond = maybeGenerateMaskCheck(builder, xferOp, iv)) {
+    if (cond) {
+      cond = builder.create<AndOp>(xferOp.getLoc(), cond, maskCond);
+    } else {
+      cond = maskCond;
+    }
+  }
+
+  // If the condition is non-empty, generate an SCF::IfOp.
+  if (cond) {
     auto check = builder.create<scf::IfOp>(
         xferOp.getLoc(), resultTypes, cond,
-        /*thenBuilder=*/
-        [&](OpBuilder &builder, Location loc) {
+        /*thenBuilder=*/[&](OpBuilder &builder, Location loc) {
           maybeYieldValue(hasRetVal, builder, loc, inBoundsCase(builder, loc));
-        },
-        /*elseBuilder=*/
-        [&](OpBuilder &builder, Location loc) {
+        }, /*elseBuilder=*/[&](OpBuilder &builder, Location loc) {
           if (outOfBoundsCase) {
-            maybeYieldValue(hasRetVal, builder, loc,
-                            outOfBoundsCase(builder, loc));
+            maybeYieldValue(hasRetVal, builder, loc, outOfBoundsCase(builder, loc));
           } else {
             builder.create<scf::YieldOp>(loc);
           }
@@ -176,7 +207,7 @@ static Value generateInBoundsCheck(
     return hasRetVal ? check.getResult(0) : Value();
   }
 
-  // No runtime check needed if dim is guaranteed to be in-bounds.
+  // Condition is empty, no need for an SCF::IfOp.
   return inBoundsCase(builder, xferOp.getLoc());
 }
@@ -189,13 +220,11 @@ static void generateInBoundsCheck(
     function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
   generateInBoundsCheck(
       xferOp, iv, builder, dim, /*resultTypes=*/TypeRange(),
-      /*inBoundsCase=*/
-      [&](OpBuilder &builder, Location loc) {
+      /*inBoundsCase=*/[&](OpBuilder &builder, Location loc) {
        inBoundsCase(builder, loc);
        return Value();
      },
-      /*outOfBoundsCase=*/
-      [&](OpBuilder &builder, Location loc) {
+      /*outOfBoundsCase=*/[&](OpBuilder &builder, Location loc) {
        if (outOfBoundsCase)
          outOfBoundsCase(builder, loc);
        return Value();
@@ -271,8 +300,8 @@ struct Strategy<TransferReadOp> {
   ///
   /// Note: The loop and type cast are generated in TransferOpConversion.
   ///       The original TransferReadOp and store op are deleted in `cleanup`.
-  static void rewriteOp(OpBuilder &builder, TransferReadOp xferOp, Value buffer,
-                        Value iv) {
+  static void rewriteOp(OpBuilder &builder, TransferReadOp xferOp,
+                        Value buffer, Value iv) {
     SmallVector<Value, 8> storeIndices;
     getStoreIndices(xferOp, storeIndices);
     storeIndices.push_back(iv);
@@ -283,12 +312,10 @@ struct Strategy<TransferReadOp> {
     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
     auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr());
-    auto newXfer =
-        vector_transfer_read(
+    auto newXfer = vector_transfer_read(
         vecType, xferOp.source(), xferIndices,
         AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)),
-            xferOp.padding(), Value(), inBoundsAttr)
-            .value;
+        xferOp.padding(), Value(), inBoundsAttr).value;
 
     if (vecType.getRank() > kTargetRank)
       newXfer.getDefiningOp()->setAttr(kPassLabel, builder.getUnitAttr());
@@ -298,8 +325,8 @@ struct Strategy<TransferReadOp> {
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
   /// padding value to the temporary buffer.
-  static void handleOutOfBoundsDim(OpBuilder & /*builder*/,
-                                   TransferReadOp xferOp, Value buffer,
+  static void handleOutOfBoundsDim(
+      OpBuilder &/*builder*/, TransferReadOp xferOp, Value buffer,
       Value iv) {
     SmallVector<Value, 8> storeIndices;
     getStoreIndices(xferOp, storeIndices);
@@ -365,16 +392,17 @@ struct Strategy<TransferWriteOp> {
     auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr());
     auto newXfer = vector_transfer_write(
         Type(), vec, xferOp.source(), xferIndices,
-        AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)), Value(),
-        inBoundsAttr);
+        AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)),
+        Value(), inBoundsAttr);
 
     if (vecType.getRank() > kTargetRank)
       newXfer.op->setAttr(kPassLabel, builder.getUnitAttr());
   }
 
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
-  static void handleOutOfBoundsDim(OpBuilder &builder, TransferWriteOp xferOp,
-                                   Value buffer, Value iv) {}
+  static void handleOutOfBoundsDim(
+      OpBuilder &builder, TransferWriteOp xferOp, Value buffer,
+      Value iv) {}
 
   /// Cleanup after rewriting the op.
   static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp) {
@@ -522,16 +550,14 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
     // Generate for loop.
     rewriter.create<scf::ForOp>(
         xferOp.getLoc(), lb, ub, step, ValueRange(),
-        [&](OpBuilder &b, Location loc, Value iv, ValueRange /*loopState*/) {
+        [&](OpBuilder &b, Location loc, Value iv,
+            ValueRange /*loopState*/) {
           ScopedContext scope(b, loc);
           generateInBoundsCheck(
               xferOp, iv, b, unpackedDim(xferOp),
-              /*inBoundsCase=*/
-              [&](OpBuilder &b, Location /*loc*/) {
+              /*inBoundsCase=*/[&](OpBuilder &b, Location /*loc*/) {
                 Strategy<OpTy>::rewriteOp(b, xferOp, casted, iv);
-              },
-              /*outOfBoundsCase=*/
-              [&](OpBuilder &b, Location /*loc*/) {
+              }, /*outOfBoundsCase=*/[&](OpBuilder &b, Location /*loc*/) {
                 Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp, casted, iv);
               });
           b.create<scf::YieldOp>(loc);
@@ -546,9 +572,8 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
 /// part of TransferOp1dConversion. Return the memref dimension on which
 /// the transfer is operating. A return value of None indicates a broadcast.
 template <typename OpTy>
-static Optional<int64_t>
-get1dMemrefIndices(OpTy xferOp, Value iv,
-                   SmallVector<Value, 8> &memrefIndices) {
+static Optional<int64_t> get1dMemrefIndices(
+    OpTy xferOp, Value iv, SmallVector<Value, 8> &memrefIndices) {
   auto indices = xferOp.indices();
   auto map = xferOp.permutation_map();
@@ -575,25 +600,25 @@ struct Strategy1d;
 /// Codegen strategy for TransferReadOp.
 template <>
 struct Strategy1d<TransferReadOp> {
-  static void generateForLoopBody(OpBuilder &builder, Location loc,
-                                  TransferReadOp xferOp, Value iv,
+  static void generateForLoopBody(
+      OpBuilder &builder, Location loc, TransferReadOp xferOp, Value iv,
       ValueRange loopState) {
     SmallVector<Value, 8> indices;
     auto dim = get1dMemrefIndices(xferOp, iv, indices);
-    auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
+    auto ivI32 = std_index_cast(
+        IntegerType::get(builder.getContext(), 32), iv);
     auto vec = loopState[0];
 
     // In case of out-of-bounds access, leave `vec` as is (was initialized with
     // padding value).
     auto nextVec = generateInBoundsCheck(
         xferOp, iv, builder, dim, TypeRange(xferOp.getVectorType()),
-        /*inBoundsCase=*/
-        [&](OpBuilder & /*b*/, Location loc) {
+        /*inBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) {
           auto val = memref_load(xferOp.source(), indices);
           return vector_insert_element(val, vec, ivI32.value).value;
-        },
-        /*outOfBoundsCase=*/
-        [&](OpBuilder & /*b*/, Location loc) { return vec; });
+        }, /*outOfBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) {
+          return vec;
+        });
     builder.create<scf::YieldOp>(loc, nextVec);
   }
@@ -606,12 +631,13 @@ struct Strategy1d<TransferReadOp> {
 /// Codegen strategy for TransferWriteOp.
 template <>
 struct Strategy1d<TransferWriteOp> {
-  static void generateForLoopBody(OpBuilder &builder, Location loc,
-                                  TransferWriteOp xferOp, Value iv,
+  static void generateForLoopBody(
+      OpBuilder &builder, Location loc, TransferWriteOp xferOp, Value iv,
       ValueRange /*loopState*/) {
     SmallVector<Value, 8> indices;
     auto dim = get1dMemrefIndices(xferOp, iv, indices);
-    auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
+    auto ivI32 = std_index_cast(
+        IntegerType::get(builder.getContext(), 32), iv);
 
     // Nothing to do in case of out-of-bounds access.
     generateInBoundsCheck(
@@ -623,7 +649,9 @@ struct Strategy1d<TransferWriteOp> {
     builder.create<scf::YieldOp>(loc);
   }
 
-  static Value initialLoopState(TransferWriteOp xferOp) { return Value(); }
+  static Value initialLoopState(TransferWriteOp xferOp) {
+    return Value();
+  }
 };
 
 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
@@ -670,8 +698,6 @@ struct TransferOp1dConversion : public OpRewritePattern<OpTy> {
       return failure();
     if (map.isMinorIdentity()) // Handled by ConvertVectorToLLVM
       return failure();
-    if (xferOp.mask())
-      return failure();
 
     // Loop bounds, step, state...
     auto vecType = xferOp.getVectorType();
@@ -685,8 +711,8 @@ struct TransferOp1dConversion : public OpRewritePattern<OpTy> {
         xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
         [&](OpBuilder &builder, Location loc, Value iv, ValueRange loopState) {
           ScopedContext nestedScope(builder, loc);
-          Strategy1d<OpTy>::generateForLoopBody(builder, loc, xferOp, iv,
-                                                loopState);
+          Strategy1d<OpTy>::generateForLoopBody(
+              builder, loc, xferOp, iv, loopState);
         });
 
     return success();
@@ -699,7 +725,8 @@ namespace mlir {
 void populateProgressiveVectorToSCFConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion,
+  patterns.add<PrepareTransferReadConversion,
+               PrepareTransferWriteConversion,
                TransferOpConversion<TransferReadOp>,
                TransferOpConversion<TransferWriteOp>>(patterns.getContext());
@@ -725,4 +752,3 @@ std::unique_ptr<Pass>
 mlir::createProgressiveConvertVectorToSCFPass() {
   return std::make_unique<ConvertProgressiveVectorToSCFPass>();
 }


@@ -1,8 +1,3 @@
-// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
-// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
-// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
-
 // RUN: mlir-opt %s -test-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
@@ -10,15 +5,6 @@
 // Test for special cases of 1D vector transfer ops.
 
-func @transfer_read_2d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
-  %fm42 = constant -42.0: f32
-  %f = vector.transfer_read %A[%base1, %base2], %fm42
-      {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
-      : memref<?x?xf32>, vector<5x6xf32>
-  vector.print %f: vector<5x6xf32>
-  return
-}
-
 func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
   %f = vector.transfer_read %A[%base1, %base2], %fm42
@@ -38,6 +24,38 @@ func @transfer_read_1d_broadcast(
   return
 }
 
+func @transfer_read_1d_in_bounds(
+    %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+  %fm42 = constant -42.0: f32
+  %f = vector.transfer_read %A[%base1, %base2], %fm42
+      {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
+      : memref<?x?xf32>, vector<3xf32>
+  vector.print %f: vector<3xf32>
+  return
+}
+
+func @transfer_read_1d_mask(
+    %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[1, 0, 1, 0, 1, 1, 1, 0, 1]> : vector<9xi1>
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+      {permutation_map = affine_map<(d0, d1) -> (d0)>}
+      : memref<?x?xf32>, vector<9xf32>
+  vector.print %f: vector<9xf32>
+  return
+}
+
+func @transfer_read_1d_mask_in_bounds(
+    %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[1, 0, 1]> : vector<3xi1>
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+      {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
+      : memref<?x?xf32>, vector<3xf32>
+  vector.print %f: vector<3xf32>
+  return
+}
+
 func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fn1 = constant -1.0 : f32
   %vf0 = splat %fn1 : vector<7xf32>
@@ -69,14 +87,35 @@ func @entry() {
     }
   }
 
+  // Read from 2D memref on first dimension. Cannot be lowered to an LLVM
+  // vector load. Instead, generates scalar loads.
   call @transfer_read_1d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
+  // Write to 2D memref on first dimension. Cannot be lowered to an LLVM
+  // vector store. Instead, generates scalar stores.
   call @transfer_write_1d(%A, %c3, %c2) : (memref<?x?xf32>, index, index) -> ()
+  // (Same as above.)
   call @transfer_read_1d(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
+  // Read a scalar from a 2D memref and broadcast the value to a 1D vector.
+  // Generates a loop with vector.insertelement.
   call @transfer_read_1d_broadcast(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
+  // Read from 2D memref on first dimension. Accesses are in-bounds, so no
+  // if-check is generated inside the generated loop.
+  call @transfer_read_1d_in_bounds(%A, %c1, %c2)
+      : (memref<?x?xf32>, index, index) -> ()
+  // Optional mask attribute is specified and, in addition, there may be
+  // out-of-bounds accesses.
+  call @transfer_read_1d_mask(%A, %c1, %c2)
+      : (memref<?x?xf32>, index, index) -> ()
+  // Same as above, but accesses are in-bounds.
+  call @transfer_read_1d_mask_in_bounds(%A, %c1, %c2)
+      : (memref<?x?xf32>, index, index) -> ()
   return
 }
 
 // CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )
 // CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )
 // CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 )
+// CHECK: ( 12, 22, -1 )
+// CHECK: ( 12, -42, -1, -42, -42, -42, -42, -42, -42 )
+// CHECK: ( 12, -42, -1 )
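
Note on the three new CHECK lines: for a masked transfer_read, a lane whose
mask bit is 0 takes the padding value (-42 here) even when the access is in
bounds, and out-of-bounds lanes take the padding value as well. This is why
@transfer_read_1d_mask_in_bounds with mask [1, 0, 1] prints ( 12, -42, -1 )
where the unmasked @transfer_read_1d_in_bounds prints ( 12, 22, -1 ).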