[mlir][Vector] Add linalg.copy-based pattern for splitting vector.transfer_read into full and partial copies.
This revision adds a transformation and a pattern that rewrites a "maybe masked"
`vector.transfer_read %view[...], %pad` into a pattern resembling:

```
%1:3 = scf.if (%inBounds) {
  scf.yield %view : memref<A...>, index, index
} else {
  %2 = linalg.fill(%extra_alloc, %pad)
  %3 = subview %view [...][...][...]
  linalg.copy(%3, %alloc)
  memref_cast %extra_alloc: memref<B...> to memref<A...>
  scf.yield %4 : memref<A...>, index, index
}
%res = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
```

where `extra_alloc` is a buffer of one vector, alloca'ed at the top of the function.

This rewrite makes it possible to realize the "always full tile" abstraction, where
vector.transfer_read operations are guaranteed to read from a padded full buffer.
The extra work only occurs on the boundary tiles.
This commit is contained in:
parent 04e45ae1c6
commit 1a4263d394
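For orientation before the diff, here is a minimal sketch (not part of the patch) of how a pass can opt into the new linalg.copy-based split. It mirrors the TestVectorTransferFullPartialSplitPatterns update at the bottom of this commit; the surrounding pass class, registration, and includes are assumed rather than shown.

```cpp
// Sketch only: enable the linalg.copy-based split from inside a FunctionPass.
// Assumes the VectorTransforms.h declarations added in this commit are visible
// and that this body lives in a pass derived from PassWrapper<..., FunctionPass>.
void runOnFunction() override {
  MLIRContext *ctx = &getContext();

  // Request the fill + copy slow path instead of a masked vector.transfer.
  VectorTransformsOptions options;
  options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);

  OwningRewritePatternList patterns;
  patterns.insert<VectorTransferFullPartialRewriter>(ctx, options);
  applyPatternsAndFoldGreedily(getFunction(), patterns);
}
```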
@@ -56,22 +56,48 @@ enum class VectorContractLowering {
};
/// Enum to control the lowering of `vector.transpose` operations.
enum class VectorTransposeLowering {
  // Lower transpose into element-wise extract and inserts.
  /// Lower transpose into element-wise extract and inserts.
  EltWise = 0,
  /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
  /// intrinsics.
  Flat = 1,
};
/// Enum to control the splitting of `vector.transfer` operations into masked
/// and unmasked variants.
enum class VectorTransferSplit {
  /// Do not split vector transfer operations.
  None = 0,
  /// Split using masked + unmasked vector.transfer operations.
  VectorTransfer = 1,
  /// Split using an unmasked vector.transfer + linalg.fill + linalg.copy
  /// operations.
  LinalgCopy = 2,
  /// Do not split vector transfer operation but instead mark it as "unmasked".
  ForceUnmasked = 3
};
/// Structure to control the behavior of vector transform patterns.
struct VectorTransformsOptions {
  /// Option to control the lowering of vector.contract.
  VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
  VectorTransposeLowering vectorTransposeLowering =
      VectorTransposeLowering::EltWise;
  VectorTransformsOptions &
  setVectorTransformsOptions(VectorContractLowering opt) {
    vectorContractLowering = opt;
    return *this;
  }
  /// Option to control the lowering of vector.transpose.
  VectorTransposeLowering vectorTransposeLowering =
      VectorTransposeLowering::EltWise;
  VectorTransformsOptions &
  setVectorTransposeLowering(VectorTransposeLowering opt) {
    vectorTransposeLowering = opt;
    return *this;
  }
  /// Option to control the splitting of vector transfers.
  VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
  VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
    vectorTransferSplit = opt;
    return *this;
  }
};

/// Collect a set of transformation patterns that are related to contracting
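Since each setter returns `*this`, a full configuration can be chained in one expression. A small illustrative snippet using only the names introduced above (the particular enumerators picked here are arbitrary examples, not recommendations):

```cpp
// Illustrative chaining of the options introduced above.
VectorTransformsOptions options;
options.setVectorTransformsOptions(VectorContractLowering::Dot)
    .setVectorTransposeLowering(VectorTransposeLowering::Flat)
    .setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
```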
@@ -109,13 +109,13 @@ private:
  FilterConstraintType filter;
};

/// Split a vector.transfer operation into an unmasked fastpath vector.transfer
/// and a slowpath masked vector.transfer. If `ifOp` is not null and the result
/// is `success`, the `ifOp` points to the newly created conditional upon
/// function return. To accommodate for the fact that the original
/// vector.transfer indexing may be arbitrary and the slow path indexes @[0...0]
/// in the temporary buffer, the scf.if op returns a view and values of type
/// index. At this time, only vector.transfer_read is implemented.
/// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate for the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
@@ -124,17 +124,17 @@ private:
/// is transformed into:
/// ```
/// %1:3 = scf.if (%inBounds) {
///   scf.yield %0 : memref<A...>, index, index
/// } else {
///   %2 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
///   %3 = vector.type_cast %extra_alloc : memref<...> to
///   memref<vector<...>> store %2, %3[] : memref<vector<...>> %4 =
///   memref_cast %extra_alloc: memref<B...> to memref<A...> scf.yield %4 :
///   memref<A...>, index, index
///   // fastpath, direct cast
///   memref_cast %A: memref<A...> to compatibleMemRefType
///   scf.yield %view : compatibleMemRefType, index, index
/// } else {
///   // slowpath, masked vector.transfer or linalg.copy.
///   memref_cast %alloc: memref<B...> to compatibleMemRefType
///   scf.yield %4 : compatibleMemRefType, index, index
/// }
/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
/// ```
/// where `extra_alloc` is a top-of-the-function alloca'ed buffer of one vector.
/// where `alloc` is a top-of-the-function alloca'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
@@ -143,9 +143,10 @@ private:
///    rank-reducing subviews.
LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp);
LogicalResult splitFullAndPartialTransfer(OpBuilder &b,
                                          VectorTransferOpInterface xferOp,
                                          scf::IfOp *ifOp = nullptr);
LogicalResult splitFullAndPartialTransfer(
    OpBuilder &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options = VectorTransformsOptions(),
    scf::IfOp *ifOp = nullptr);

/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
@@ -155,16 +156,19 @@ struct VectorTransferFullPartialRewriter : public RewritePattern {

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(benefit, MatchAnyOpTypeTag()), filter(filter) {}
      : RewritePattern(benefit, MatchAnyOpTypeTag()), options(options),
        filter(filter) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  VectorTransformsOptions options;
  FilterConstraintType filter;
};

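The pattern's filter hook makes it possible to restrict the split to a subset of transfers. A hypothetical registration sketch that only splits rank-2 transfers might look as follows; the `ctx` variable (an `MLIRContext *`) and the pattern-application driver around it are assumed.

```cpp
// Hypothetical registration sketch: split only rank-2 transfers.
// Assumes `using namespace mlir;` and `using namespace mlir::vector;`.
VectorTransformsOptions options;
options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);

OwningRewritePatternList patterns;
patterns.insert<VectorTransferFullPartialRewriter>(
    ctx, options,
    /*filter=*/[](VectorTransferOpInterface op) {
      return success(op.getTransferRank() == 2);
    });
```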
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRVector
  MLIRIR
  MLIRStandardOps
  MLIRAffineOps
  MLIRLinalgOps
  MLIRSCF
  MLIRLoopAnalysis
  MLIRSideEffectInterfaces
@@ -14,6 +14,7 @@

#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -2056,7 +2057,16 @@ LogicalResult mlir::vector::splitFullAndPartialTransferPrecondition(
  return success();
}

MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///   strides and:
///     a. keeping the ones that are static and equal across `aT` and `bT`.
///     b. using a dynamic shape and/or stride for the dimensions that don't
///     agree.
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (MemRefCastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
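To make the rule above concrete, here is a purely illustrative call reproducing the first test case below, where the transfer source is `memref<?x8xf32>` and the staging alloc is `memref<4x8xf32>`. The helper is file-local to this translation unit, and `using namespace mlir;` is assumed.

```cpp
// Illustration only (the helper is static to VectorTransforms.cpp).
MLIRContext ctx;
Builder b(&ctx);
MemRefType aT = MemRefType::get({ShapedType::kDynamicSize, 8}, b.getF32Type()); // memref<?x8xf32>
MemRefType bT = MemRefType::get({4, 8}, b.getF32Type());                        // memref<4x8xf32>
// Both types can be memref_cast to memref<?x8xf32>, which is what the helper
// returns here; the tests below memref_cast the alloc accordingly.
MemRefType compat = getCastCompatibleMemRefType(aT, bT); // memref<?x8xf32>
```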
@@ -2086,13 +2096,154 @@ MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
      makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
}

/// Split a vector.transfer operation into an unmasked fastpath vector.transfer
/// and a slowpath masked vector.transfer. If `ifOp` is not null and the result
/// is `success`, the `ifOp` points to the newly created conditional upon
/// function return. To accommodate for the fact that the original
/// vector.transfer indexing may be arbitrary and the slow path indexes @[0...0]
/// in the temporary buffer, the scf.if op returns a view and values of type
/// index. At this time, only vector.transfer_read is implemented.
/// Operates under a scoped context to build the intersection between the
/// view `xferOp.memref()` @ `xferOp.indices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp,
                                             Value alloc) {
  using namespace edsc::intrinsics;
  int64_t memrefRank = xferOp.getMemRefType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
         "Expected memref rank to match the alloc rank");
  Value one = std_constant_index(1);
  ValueRange leadingIndices =
      xferOp.indices().take_front(xferOp.getLeadingMemRefRank());
  SmallVector<Value, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = std_dim(xferOp.memref(), indicesIdx);
    Value dimAlloc = std_dim(alloc, resultIdx);
    Value index = xferOp.indices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}});
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = affine_min(index.getType(), maps[0],
                                 ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });
  return std_sub_view(xferOp.memref(), xferOp.indices(), sizes,
                      SmallVector<Value, 4>(memrefRank, one));
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref_cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%alloc, %pad)
///      %3 = subview %view [...][...][...]
///      linalg.copy(%3, %alloc)
///      memref_cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createScopedFullPartialLinalgCopy(
    vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond,
    MemRefType compatibleMemRefType, Value alloc) {
  using namespace edsc;
  using namespace edsc::intrinsics;
  scf::IfOp fullPartialIfOp;
  Value zero = std_constant_index(0);
  Value memref = xferOp.memref();
  conditionBuilder(
      returnTypes, inBoundsCond,
      [&]() -> scf::ValueVector {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getMemRefType())
          res = std_memref_cast(memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
                              xferOp.indices().end());
        return viewAndIndices;
      },
      [&]() -> scf::ValueVector {
        linalg_fill(alloc, xferOp.padding());
        // Take partial subview of memref which guarantees no dimension
        // overflows.
        Value memRefSubView = createScopedSubViewIntersection(
            cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
        linalg_copy(memRefSubView, alloc);
        Value casted = std_memref_cast(alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        return viewAndIndices;
      },
      &fullPartialIfOp);
  return fullPartialIfOp;
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref_cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref_cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createScopedFullPartialVectorTransferRead(
    vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond,
    MemRefType compatibleMemRefType, Value alloc) {
  using namespace edsc;
  using namespace edsc::intrinsics;
  scf::IfOp fullPartialIfOp;
  Value zero = std_constant_index(0);
  Value memref = xferOp.memref();
  conditionBuilder(
      returnTypes, inBoundsCond,
      [&]() -> scf::ValueVector {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getMemRefType())
          res = std_memref_cast(memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
                              xferOp.indices().end());
        return viewAndIndices;
      },
      [&]() -> scf::ValueVector {
        Operation *newXfer =
            ScopedContext::getBuilderRef().clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
        std_store(vector, vector_type_cast(
                              MemRefType::get({}, vector.getType()), alloc));

        Value casted = std_memref_cast(alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);

        return viewAndIndices;
      },
      &fullPartialIfOp);
  return fullPartialIfOp;
}

/// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate for the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
@@ -2101,17 +2252,17 @@ MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
/// is transformed into:
/// ```
/// %1:3 = scf.if (%inBounds) {
///   scf.yield %0 : memref<A...>, index, index
/// } else {
///   %2 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
///   %3 = vector.type_cast %extra_alloc : memref<...> to
///   memref<vector<...>> store %2, %3[] : memref<vector<...>> %4 =
///   memref_cast %extra_alloc: memref<B...> to memref<A...> scf.yield %4 :
///   memref<A...>, index, index
///   // fastpath, direct cast
///   memref_cast %A: memref<A...> to compatibleMemRefType
///   scf.yield %view : compatibleMemRefType, index, index
/// } else {
///   // slowpath, masked vector.transfer or linalg.copy.
///   memref_cast %alloc: memref<B...> to compatibleMemRefType
///   scf.yield %4 : compatibleMemRefType, index, index
/// }
/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
/// ```
/// where `extra_alloc` is a top-of-the-function alloca'ed buffer of one vector.
/// where `alloc` is a top-of-the-function alloca'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
@@ -2119,10 +2270,21 @@ MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
///    must be equal. This will be relaxed in the future but requires
///    rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    OpBuilder &b, VectorTransferOpInterface xferOp, scf::IfOp *ifOp) {
    OpBuilder &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  using namespace edsc;
  using namespace edsc::intrinsics;

  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), false);
  auto unmaskedAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceUnmasked) {
    xferOp.setAttr(vector::TransferReadOp::getMaskedAttrName(), unmaskedAttr);
    return success();
  }

  assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
         "Expected splitFullAndPartialTransferPrecondition to hold");
  auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
@@ -2154,45 +2316,21 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
                           b.getI64IntegerAttr(32));
  }

  Value memref = xferOp.memref();
  SmallVector<bool, 4> bools(xferOp.getTransferRank(), false);
  auto unmaskedAttr = b.getBoolArrayAttr(bools);

  MemRefType compatibleMemRefType = getCastCompatibleMemRefType(
      xferOp.getMemRefType(), alloc.getType().cast<MemRefType>());

  // Read case: full fill + partial copy -> unmasked vector.xfer_read.
  Value zero = std_constant_index(0);
  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;
  scf::IfOp fullPartialIfOp;
  conditionBuilder(
      returnTypes, inBoundsCond,
      [&]() -> scf::ValueVector {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getMemRefType())
          res = std_memref_cast(memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
                              xferOp.indices().end());
        return viewAndIndices;
      },
      [&]() -> scf::ValueVector {
        Operation *newXfer =
            ScopedContext::getBuilderRef().clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
        std_store(vector, vector_type_cast(
                              MemRefType::get({}, vector.getType()), alloc));

        Value casted = std_memref_cast(alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);

        return viewAndIndices;
      },
      &fullPartialIfOp);
  scf::IfOp fullPartialIfOp =
      options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
          ? createScopedFullPartialVectorTransferRead(
                xferReadOp, returnTypes, inBoundsCond, compatibleMemRefType,
                alloc)
          : createScopedFullPartialLinalgCopy(xferReadOp, returnTypes,
                                              inBoundsCond,
                                              compatibleMemRefType, alloc);
  if (ifOp)
    *ifOp = fullPartialIfOp;

@@ -2211,7 +2349,7 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
      failed(filter(xferOp)))
    return failure();
  rewriter.startRootUpdate(xferOp);
  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp))) {
  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
    rewriter.finalizeRootUpdate(xferOp);
    return success();
  }
@@ -1,13 +1,26 @@
// RUN: mlir-opt %s -test-vector-transfer-full-partial-split | FileCheck %s
// RUN: mlir-opt %s -test-vector-transfer-full-partial-split=use-linalg-copy | FileCheck %s --check-prefix=LINALG

// CHECK-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)>
// CHECK-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)>
// CHECK-DAG: #[[$map_2d_stride_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>

// LINALG-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)>
// LINALG-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)>
// LINALG-DAG: #[[$map_2d_stride_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// LINALG-DAG: #[[$map_2d_dynamic:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
// LINALG-DAG: #[[$bounds_map_4:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 4)>
// LINALG-DAG: #[[$bounds_map_8:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 8)>

// CHECK-LABEL: split_vector_transfer_read_2d(
//  CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref
//  CHECK-SAME: %[[i:[a-zA-Z0-9]*]]: index
//  CHECK-SAME: %[[j:[a-zA-Z0-9]*]]: index

// LINALG-LABEL: split_vector_transfer_read_2d(
//  LINALG-SAME: %[[A:[a-zA-Z0-9]*]]: memref
//  LINALG-SAME: %[[i:[a-zA-Z0-9]*]]: index
//  LINALG-SAME: %[[j:[a-zA-Z0-9]*]]: index
func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -> vector<4x8xf32> {
  %c0 = constant 0 : index
  %f0 = constant 0.0 : f32
@@ -43,9 +56,45 @@ func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -
  // CHECK: }
  // CHECK: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %[[cst]]
  // CHECK-SAME: {masked = [false, false]} : memref<?x8xf32>, vector<4x8xf32>

  // LINALG-DAG: %[[c0:.*]] = constant 0 : index
  // LINALG-DAG: %[[c1:.*]] = constant 1 : index
  // LINALG-DAG: %[[c4:.*]] = constant 4 : index
  // LINALG-DAG: %[[c8:.*]] = constant 8 : index
  // LINALG-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32
  // alloca for boundary full tile
  // LINALG: %[[alloc:.*]] = alloca() {alignment = 32 : i64} : memref<4x8xf32>
  // %i + 4 <= dim(%A, 0)
  // LINALG: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
  // LINALG: %[[d0:.*]] = dim %[[A]], %[[c0]] : memref<?x8xf32>
  // LINALG: %[[cmp0:.*]] = cmpi "sle", %[[idx0]], %[[d0]] : index
  // %j + 8 <= dim(%A, 1)
  // LINALG: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
  // LINALG: %[[cmp1:.*]] = cmpi "sle", %[[idx1]], %[[c8]] : index
  // are both conds true
  // LINALG: %[[cond:.*]] = and %[[cmp0]], %[[cmp1]] : i1
  // LINALG: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32>, index, index) {
  // inBounds, just yield %A
  // LINALG: scf.yield %[[A]], %[[i]], %[[j]] : memref<?x8xf32>, index, index
  // LINALG: } else {
  // slow path, fill tmp alloc and yield a memref_casted version of it
  // LINALG: linalg.fill(%[[alloc]], %[[cst]]) : memref<4x8xf32>, f32
  // LINALG: %[[d0:.*]] = dim %[[A]], %[[c0]] : memref<?x8xf32>
  // LINALG: %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[d0]], %[[i]], %[[c4]])
  // LINALG: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
  // LINALG: %[[sv:.*]] = subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [%[[c1]], %[[c1]]]
  // LINALG-SAME: memref<?x8xf32> to memref<?x?xf32, #[[$map_2d_dynamic]]>
  // LINALG: linalg.copy(%[[sv]], %[[alloc]]) : memref<?x?xf32, #[[$map_2d_dynamic]]>, memref<4x8xf32>
  // LINALG: %[[yielded:.*]] = memref_cast %[[alloc]] :
  // LINALG-SAME: memref<4x8xf32> to memref<?x8xf32>
  // LINALG: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
  // LINALG-SAME: memref<?x8xf32>, index, index
  // LINALG: }
  // LINALG: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %[[cst]]
  // LINALG-SAME: {masked = [false, false]} : memref<?x8xf32>, vector<4x8xf32>
  %1 = vector.transfer_read %A[%i, %j], %f0 : memref<?x8xf32>, vector<4x8xf32>

  // CHECK: return %[[res]] : vector<4x8xf32>
  // LINALG: return %[[res]] : vector<4x8xf32>
  return %1: vector<4x8xf32>
}

@@ -53,6 +102,11 @@ func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -
//  CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref
//  CHECK-SAME: %[[i:[a-zA-Z0-9]*]]: index
//  CHECK-SAME: %[[j:[a-zA-Z0-9]*]]: index

// LINALG-LABEL: split_vector_transfer_read_strided_2d(
//  LINALG-SAME: %[[A:[a-zA-Z0-9]*]]: memref
//  LINALG-SAME: %[[i:[a-zA-Z0-9]*]]: index
//  LINALG-SAME: %[[j:[a-zA-Z0-9]*]]: index
func @split_vector_transfer_read_strided_2d(
    %A: memref<7x8xf32, offset:?, strides:[?, 1]>,
    %i: index, %j: index) -> vector<4x8xf32> {
@@ -94,6 +148,44 @@ func @split_vector_transfer_read_strided_2d(
  // CHECK: }
  // CHECK: %[[res:.*]] = vector.transfer_read {{.*}} {masked = [false, false]} :
  // CHECK-SAME: memref<?x8xf32, #[[$map_2d_stride_1]]>, vector<4x8xf32>

  // LINALG-DAG: %[[c0:.*]] = constant 0 : index
  // LINALG-DAG: %[[c1:.*]] = constant 1 : index
  // LINALG-DAG: %[[c4:.*]] = constant 4 : index
  // LINALG-DAG: %[[c7:.*]] = constant 7 : index
  // LINALG-DAG: %[[c8:.*]] = constant 8 : index
  // LINALG-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32
  // alloca for boundary full tile
  // LINALG: %[[alloc:.*]] = alloca() {alignment = 32 : i64} : memref<4x8xf32>
  // %i + 4 <= dim(%A, 0)
  // LINALG: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
  // LINALG: %[[cmp0:.*]] = cmpi "sle", %[[idx0]], %[[c7]] : index
  // %j + 8 <= dim(%A, 1)
  // LINALG: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
  // LINALG: %[[cmp1:.*]] = cmpi "sle", %[[idx1]], %[[c8]] : index
  // are both conds true
  // LINALG: %[[cond:.*]] = and %[[cmp0]], %[[cmp1]] : i1
  // LINALG: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index) {
  // inBounds but not cast-compatible: yield a memref_casted form of %A
  // LINALG: %[[casted:.*]] = memref_cast %arg0 :
  // LINALG-SAME: memref<7x8xf32, #[[$map_2d_stride_1]]> to memref<?x8xf32, #[[$map_2d_stride_1]]>
  // LINALG: scf.yield %[[casted]], %[[i]], %[[j]] :
  // LINALG-SAME: memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index
  // LINALG: } else {
  // slow path, fill tmp alloc and yield a memref_casted version of it
  // LINALG: linalg.fill(%[[alloc]], %[[cst]]) : memref<4x8xf32>, f32
  // LINALG: %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[c7]], %[[i]], %[[c4]])
  // LINALG: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
  // LINALG: %[[sv:.*]] = subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [%[[c1]], %[[c1]]]
  // LINALG-SAME: memref<7x8xf32, #[[$map_2d_stride_1]]> to memref<?x?xf32, #[[$map_2d_dynamic]]>
  // LINALG: linalg.copy(%[[sv]], %[[alloc]]) : memref<?x?xf32, #[[$map_2d_dynamic]]>, memref<4x8xf32>
  // LINALG: %[[yielded:.*]] = memref_cast %[[alloc]] :
  // LINALG-SAME: memref<4x8xf32> to memref<?x8xf32, #[[$map_2d_stride_1]]>
  // LINALG: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
  // LINALG-SAME: memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index
  // LINALG: }
  // LINALG: %[[res:.*]] = vector.transfer_read {{.*}} {masked = [false, false]} :
  // LINALG-SAME: memref<?x8xf32, #[[$map_2d_stride_1]]>, vector<4x8xf32>
  %1 = vector.transfer_read %A[%i, %j], %f0 :
      memref<7x8xf32, offset:?, strides:[?, 1]>, vector<4x8xf32>

@@ -125,10 +125,23 @@ struct TestVectorUnrollingPatterns
struct TestVectorTransferFullPartialSplitPatterns
    : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
                         FunctionPass> {
  TestVectorTransferFullPartialSplitPatterns() = default;
  TestVectorTransferFullPartialSplitPatterns(
      const TestVectorTransferFullPartialSplitPatterns &pass) {}
  Option<bool> useLinalgOps{
      *this, "use-linalg-copy",
      llvm::cl::desc("Split using an unmasked vector.transfer + linalg.fill + "
                     "linalg.copy operations."),
      llvm::cl::init(false)};
  void runOnFunction() override {
    MLIRContext *ctx = &getContext();
    OwningRewritePatternList patterns;
    patterns.insert<VectorTransferFullPartialRewriter>(ctx);
    VectorTransformsOptions options;
    if (useLinalgOps)
      options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
    else
      options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
    patterns.insert<VectorTransferFullPartialRewriter>(ctx, options);
    applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};