[mlir][Vector] Add transformation + pattern to split vector.transfer_read into full and partial copies.
This revision adds a transformation and a pattern that rewrites a "maybe masked" `vector.transfer_read %view[...], %pad` into a pattern resembling:

```
%1:3 = scf.if (%inBounds) {
  scf.yield %view : memref<A...>, index, index
} else {
  %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
  %3 = vector.type_cast %extra_alloc : memref<...> to memref<vector<...>>
  store %2, %3[] : memref<vector<...>>
  %4 = memref_cast %extra_alloc : memref<B...> to memref<A...>
  scf.yield %4 : memref<A...>, index, index
}
%res = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
```

where `extra_alloc` is a buffer of one vector, alloca'ed at the top of the function. This rewrite makes it possible to realize the "always full tile" abstraction, in which vector.transfer_read operations are guaranteed to read from a padded full buffer. The extra work only occurs on the boundary tiles.

Differential Revision: https://reviews.llvm.org/D84631
This commit is contained in: parent 8aeb212887, commit 35b65be041
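For context on how the new rewrite is meant to be driven, here is a minimal sketch modeled on the `TestVectorTransferFullPartialSplitPatterns` pass added at the end of this commit; the wrapper function and its name are illustrative, not part of the change.

```
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: collect the new pattern and run the greedy driver over
// a function, splitting every eligible vector.transfer_read into the
// scf.if fast/slow path shown above.
static void splitTransfersInFunction(FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  OwningRewritePatternList patterns;
  // The optional filter argument defaults to accepting every
  // VectorTransferOpInterface op.
  patterns.insert<vector::VectorTransferFullPartialRewriter>(ctx);
  applyPatternsAndFoldGreedily(funcOp, patterns);
}
```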
@@ -17,6 +17,11 @@

namespace mlir {
class MLIRContext;
class OwningRewritePatternList;
class VectorTransferOpInterface;

namespace scf {
class IfOp;
} // namespace scf

/// Collect a set of patterns to convert from the Vector dialect to itself.
/// Should be merged with populateVectorToSCFLoweringPattern.
@@ -104,6 +109,65 @@ private:

  FilterConstraintType filter;
};

/// Split a vector.transfer operation into an unmasked fastpath vector.transfer
/// and a slowpath masked vector.transfer. If `ifOp` is not null and the result
/// is `success`, the `ifOp` points to the newly created conditional upon
/// function return. To accommodate the fact that the original vector.transfer
/// indexing may be arbitrary and the slow path indexes @[0...0] in the
/// temporary buffer, the scf.if op returns a view and values of type index.
/// At this time, only vector.transfer_read is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      scf.yield %0 : memref<A...>, index, index
///    } else {
///      %2 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc : memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref_cast %extra_alloc : memref<B...> to memref<A...>
///      scf.yield %4 : memref<A...>, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
/// ```
/// where `extra_alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp);
LogicalResult splitFullAndPartialTransfer(OpBuilder &b,
                                          VectorTransferOpInterface xferOp,
                                          scf::IfOp *ifOp = nullptr);

/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(benefit, MatchAnyOpTypeTag()), filter(filter) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  FilterConstraintType filter;
};

} // namespace vector

//===----------------------------------------------------------------------===//
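To make the intended use of the two entry points above concrete, a small hedged sketch follows; the helper name and the assumption that the builder `b` is already positioned inside the enclosing function are mine, not part of the patch.

```
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Illustrative only: split one candidate transfer op, checking the
// precondition first and capturing the generated scf.if.
static void trySplitTransfer(OpBuilder &b,
                             VectorTransferOpInterface candidate) {
  if (failed(vector::splitFullAndPartialTransferPrecondition(candidate)))
    return;
  scf::IfOp ifOp; // set to the newly created conditional on success
  if (succeeded(vector::splitFullAndPartialTransfer(b, candidate, &ifOp))) {
    // ifOp now yields the (possibly memref_cast'ed) view plus the adjusted
    // indices that the now-unmasked transfer reads from.
    (void)ifOp;
  }
}
```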
@@ -160,6 +160,19 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {

      /*defaultImplementation=*/
        "return $_op.getMemRefType().getRank() - $_op.getTransferRank();"
    >,
    InterfaceMethod<
      /*desc=*/[{ Returns true if at least one of the dimensions is masked.}],
      /*retTy=*/"bool",
      /*methodName=*/"hasMaskedDim",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        for (unsigned idx = 0, e = $_op.getTransferRank(); idx < e; ++idx)
          if ($_op.isMaskedDim(idx))
            return true;
        return false;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Helper function to account for the fact that `permutationMap` results and
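The new interface method becomes a plain C++ query on the generated interface class. A tiny sketch of the call-site shape (the wrapper below is illustrative, but the check itself is exactly what the split precondition added later in this commit performs; the generated header path is assumed to be the one of this era):

```
#include "mlir/Interfaces/VectorInterfaces.h"

// Only transfers with at least one masked (potentially out-of-bounds)
// dimension are interesting candidates for the full/partial split.
static bool isSplitCandidate(mlir::VectorTransferOpInterface xferOp) {
  return xferOp.hasMaskedDim();
}
```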
@@ -12,9 +12,13 @@

#include <type_traits>

#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
@@ -1985,6 +1989,236 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,

} // namespace mlir

static Optional<int64_t> extractConstantIndex(Value v) {
  if (auto cstOp = v.getDefiningOp<ConstantIndexOp>())
    return cstOp.getValue();
  if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
    if (affineApplyOp.getAffineMap().isSingleConstant())
      return affineApplyOp.getAffineMap().getSingleConstantResult();
  return None;
}

// Missing foldings of scf.if make it necessary to perform poor man's folding
// eagerly, especially in the case of unrolling. In the future, this should go
// away once scf.if folds properly.
static Value createScopedFoldedSLE(Value v, Value ub) {
  using namespace edsc::op;
  auto maybeCstV = extractConstantIndex(v);
  auto maybeCstUb = extractConstantIndex(ub);
  if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
    return Value();
  return sle(v, ub);
}

// Operates under a scoped context to build the condition to ensure that a
// particular VectorTransferOpInterface is unmasked.
static Value createScopedInBoundsCond(VectorTransferOpInterface xferOp) {
  assert(xferOp.permutation_map().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be unmasked, it does not participate in the
    // construction of `inBoundsCond`.
    if (!xferOp.isMaskedDim(resultIdx))
      return;
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    using namespace edsc::op;
    using namespace edsc::intrinsics;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Value sum = xferOp.indices()[indicesIdx] + std_constant_index(vectorSize);
    Value cond =
        createScopedFoldedSLE(sum, std_dim(xferOp.memref(), indicesIdx));
    if (!cond)
      return;
    // Conjunction over all dims for which we are in-bounds.
    inBoundsCond = inBoundsCond ? inBoundsCond && cond : cond;
  });
  return inBoundsCond;
}

LogicalResult mlir::vector::splitFullAndPartialTransferPrecondition(
    VectorTransferOpInterface xferOp) {
  // TODO: expand support to these 2 cases.
  if (!xferOp.permutation_map().isMinorIdentity())
    return failure();
  // TODO: relax this precondition. This will require rank-reducing subviews.
  if (xferOp.getMemRefType().getRank() != xferOp.getTransferRank())
    return failure();
  // Must have some masked dimension to be a candidate for splitting.
  if (!xferOp.hasMaskedDim())
    return failure();
  // Don't split transfer operations under IfOp; this avoids applying the
  // pattern recursively.
  // TODO: improve the condition to make it more applicable.
  if (xferOp.getParentOfType<scf::IfOp>())
    return failure();
  return success();
}

MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (MemRefCastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
      failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : MemRefType::kDynamicSize;
    resStrides[idx] = (aStrides[idx] == bStrides[idx])
                          ? aStrides[idx]
                          : MemRefType::kDynamicStrideOrOffset;
  }
  resOffset =
      (aOffset == bOffset) ? aOffset : MemRefType::kDynamicStrideOrOffset;
  return MemRefType::get(
      resShape, aT.getElementType(),
      makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
}

/// Split a vector.transfer operation into an unmasked fastpath vector.transfer
/// and a slowpath masked vector.transfer. If `ifOp` is not null and the result
/// is `success`, the `ifOp` points to the newly created conditional upon
/// function return. To accommodate the fact that the original vector.transfer
/// indexing may be arbitrary and the slow path indexes @[0...0] in the
/// temporary buffer, the scf.if op returns a view and values of type index.
/// At this time, only vector.transfer_read is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      scf.yield %0 : memref<A...>, index, index
///    } else {
///      %2 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc : memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref_cast %extra_alloc : memref<B...> to memref<A...>
///      scf.yield %4 : memref<A...>, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
/// ```
/// where `extra_alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    OpBuilder &b, VectorTransferOpInterface xferOp, scf::IfOp *ifOp) {
  using namespace edsc;
  using namespace edsc::intrinsics;

  assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
         "Expected splitFullAndPartialTransferPrecondition to hold");
  auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());

  // TODO: add support for write case.
  if (!xferReadOp)
    return failure();

  OpBuilder::InsertionGuard guard(b);
  if (xferOp.memref().getDefiningOp())
    b.setInsertionPointAfter(xferOp.memref().getDefiningOp());
  else
    b.setInsertionPoint(xferOp);
  ScopedContext scope(b, xferOp.getLoc());
  Value inBoundsCond = createScopedInBoundsCond(
      cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top of the function `alloc` for transient storage.
  Value alloc;
  {
    FuncOp funcOp = xferOp.getParentOfType<FuncOp>();
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(&funcOp.getRegion().front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = std_alloca(MemRefType::get(shape, elementType), ValueRange{},
                       b.getI64IntegerAttr(32));
  }

  Value memref = xferOp.memref();
  SmallVector<bool, 4> bools(xferOp.getTransferRank(), false);
  auto unmaskedAttr = b.getBoolArrayAttr(bools);

  MemRefType compatibleMemRefType = getCastCompatibleMemRefType(
      xferOp.getMemRefType(), alloc.getType().cast<MemRefType>());

  // Read case: full fill + partial copy -> unmasked vector.xfer_read.
  Value zero = std_constant_index(0);
  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;
  scf::IfOp fullPartialIfOp;
  conditionBuilder(
      returnTypes, inBoundsCond,
      [&]() -> scf::ValueVector {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getMemRefType())
          res = std_memref_cast(memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
                              xferOp.indices().end());
        return viewAndIndices;
      },
      [&]() -> scf::ValueVector {
        Operation *newXfer =
            ScopedContext::getBuilderRef().clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
        std_store(vector, vector_type_cast(
                              MemRefType::get({}, vector.getType()), alloc));

        Value casted = std_memref_cast(alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);

        return viewAndIndices;
      },
      &fullPartialIfOp);
  if (ifOp)
    *ifOp = fullPartialIfOp;

  // Unmask the existing read op, it always reads from a full buffer.
  for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
    xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
  xferOp.setAttr(vector::TransferReadOp::getMaskedAttrName(), unmaskedAttr);

  return success();
}

LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  rewriter.startRootUpdate(xferOp);
  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp))) {
    rewriter.finalizeRootUpdate(xferOp);
    return success();
  }
  rewriter.cancelRootUpdate(xferOp);
  return failure();
}

// TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
// TODO: Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
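To make the shape/stride "join" computed by `getCastCompatibleMemRefType` concrete, here is a standalone sketch of the same element-wise rule on plain integers; `kDynamic` stands in for `MemRefType::kDynamicSize` / `kDynamicStrideOrOffset`, and the example mirrors the strided test case below, where the `memref<7x8xf32, ...>` view and the `memref<4x8xf32>` alloca are both cast to `memref<?x8xf32, ...>`.

```
#include <cassert>
#include <cstdint>
#include <vector>

// Illustration only: sizes (or strides/offsets) that agree are kept; anything
// that differs is relaxed to a dynamic value, so both originals remain
// cast-compatible with the result type.
constexpr int64_t kDynamic = -1;

static std::vector<int64_t> joinShapes(const std::vector<int64_t> &a,
                                       const std::vector<int64_t> &b) {
  assert(a.size() == b.size() && "rank mismatch");
  std::vector<int64_t> res(a.size());
  for (size_t i = 0, e = a.size(); i < e; ++i)
    res[i] = (a[i] == b[i]) ? a[i] : kDynamic;
  return res;
}

int main() {
  // {7, 8} joined with {4, 8} -> {kDynamic, 8}, i.e. memref<?x8xf32>.
  auto res = joinShapes({7, 8}, {4, 8});
  assert(res[0] == kDynamic && res[1] == 8);
  return 0;
}
```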
@@ -0,0 +1,102 @@

// RUN: mlir-opt %s -test-vector-transfer-full-partial-split | FileCheck %s

// CHECK-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)>
// CHECK-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)>
// CHECK-DAG: #[[$map_2d_stride_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>

// CHECK-LABEL: split_vector_transfer_read_2d(
//  CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref
//  CHECK-SAME: %[[i:[a-zA-Z0-9]*]]: index
//  CHECK-SAME: %[[j:[a-zA-Z0-9]*]]: index
func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -> vector<4x8xf32> {
  %c0 = constant 0 : index
  %f0 = constant 0.0 : f32

  //  CHECK-DAG: %[[c0:.*]] = constant 0 : index
  //  CHECK-DAG: %[[c8:.*]] = constant 8 : index
  //  CHECK-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32
  // alloca for boundary full tile
  //      CHECK: %[[alloc:.*]] = alloca() {alignment = 32 : i64} : memref<4x8xf32>
  // %i + 4 <= dim(%A, 0)
  //      CHECK: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
  //      CHECK: %[[d0:.*]] = dim %[[A]], %[[c0]] : memref<?x8xf32>
  //      CHECK: %[[cmp0:.*]] = cmpi "sle", %[[idx0]], %[[d0]] : index
  // %j + 8 <= dim(%A, 1)
  //      CHECK: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
  //      CHECK: %[[cmp1:.*]] = cmpi "sle", %[[idx1]], %[[c8]] : index
  // are both conds true
  //      CHECK: %[[cond:.*]] = and %[[cmp0]], %[[cmp1]] : i1
  //      CHECK: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32>, index, index) {
  //               inBounds, just yield %A
  //      CHECK:   scf.yield %[[A]], %[[i]], %[[j]] : memref<?x8xf32>, index, index
  //      CHECK: } else {
  //               slow path, fill tmp alloc and yield a memref_casted version of it
  //      CHECK:   %[[slow:.*]] = vector.transfer_read %[[A]][%[[i]], %[[j]]], %cst :
  // CHECK-SAME:     memref<?x8xf32>, vector<4x8xf32>
  //      CHECK:   %[[cast_alloc:.*]] = vector.type_cast %[[alloc]] :
  // CHECK-SAME:     memref<4x8xf32> to memref<vector<4x8xf32>>
  //      CHECK:   store %[[slow]], %[[cast_alloc]][] : memref<vector<4x8xf32>>
  //      CHECK:   %[[yielded:.*]] = memref_cast %[[alloc]] :
  // CHECK-SAME:     memref<4x8xf32> to memref<?x8xf32>
  //      CHECK:   scf.yield %[[yielded]], %[[c0]], %[[c0]] :
  // CHECK-SAME:     memref<?x8xf32>, index, index
  //      CHECK: }
  //      CHECK: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %[[cst]]
  // CHECK-SAME:   {masked = [false, false]} : memref<?x8xf32>, vector<4x8xf32>
  %1 = vector.transfer_read %A[%i, %j], %f0 : memref<?x8xf32>, vector<4x8xf32>

  // CHECK: return %[[res]] : vector<4x8xf32>
  return %1: vector<4x8xf32>
}

// CHECK-LABEL: split_vector_transfer_read_strided_2d(
//  CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref
//  CHECK-SAME: %[[i:[a-zA-Z0-9]*]]: index
//  CHECK-SAME: %[[j:[a-zA-Z0-9]*]]: index
func @split_vector_transfer_read_strided_2d(
    %A: memref<7x8xf32, offset:?, strides:[?, 1]>,
    %i: index, %j: index) -> vector<4x8xf32> {
  %c0 = constant 0 : index
  %f0 = constant 0.0 : f32

  //  CHECK-DAG: %[[c0:.*]] = constant 0 : index
  //  CHECK-DAG: %[[c7:.*]] = constant 7 : index
  //  CHECK-DAG: %[[c8:.*]] = constant 8 : index
  //  CHECK-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32
  // alloca for boundary full tile
  //      CHECK: %[[alloc:.*]] = alloca() {alignment = 32 : i64} : memref<4x8xf32>
  // %i + 4 <= dim(%A, 0)
  //      CHECK: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
  //      CHECK: %[[cmp0:.*]] = cmpi "sle", %[[idx0]], %[[c7]] : index
  // %j + 8 <= dim(%A, 1)
  //      CHECK: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
  //      CHECK: %[[cmp1:.*]] = cmpi "sle", %[[idx1]], %[[c8]] : index
  // are both conds true
  //      CHECK: %[[cond:.*]] = and %[[cmp0]], %[[cmp1]] : i1
  //      CHECK: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index) {
  //               inBounds but not cast-compatible: yield a memref_casted form of %A
  //      CHECK:   %[[casted:.*]] = memref_cast %arg0 :
  // CHECK-SAME:     memref<7x8xf32, #[[$map_2d_stride_1]]> to memref<?x8xf32, #[[$map_2d_stride_1]]>
  //      CHECK:   scf.yield %[[casted]], %[[i]], %[[j]] :
  // CHECK-SAME:     memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index
  //      CHECK: } else {
  //               slow path, fill tmp alloc and yield a memref_casted version of it
  //      CHECK:   %[[slow:.*]] = vector.transfer_read %[[A]][%[[i]], %[[j]]], %cst :
  // CHECK-SAME:     memref<7x8xf32, #[[$map_2d_stride_1]]>, vector<4x8xf32>
  //      CHECK:   %[[cast_alloc:.*]] = vector.type_cast %[[alloc]] :
  // CHECK-SAME:     memref<4x8xf32> to memref<vector<4x8xf32>>
  //      CHECK:   store %[[slow]], %[[cast_alloc]][] :
  // CHECK-SAME:     memref<vector<4x8xf32>>
  //      CHECK:   %[[yielded:.*]] = memref_cast %[[alloc]] :
  // CHECK-SAME:     memref<4x8xf32> to memref<?x8xf32, #[[$map_2d_stride_1]]>
  //      CHECK:   scf.yield %[[yielded]], %[[c0]], %[[c0]] :
  // CHECK-SAME:     memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index
  //      CHECK: }
  //      CHECK: %[[res:.*]] = vector.transfer_read {{.*}} {masked = [false, false]} :
  // CHECK-SAME:   memref<?x8xf32, #[[$map_2d_stride_1]]>, vector<4x8xf32>
  %1 = vector.transfer_read %A[%i, %j], %f0 :
    memref<7x8xf32, offset:?, strides:[?, 1]>, vector<4x8xf32>

  // CHECK: return %[[res]] : vector<4x8xf32>
  return %1 : vector<4x8xf32>
}
@@ -122,6 +122,17 @@ struct TestVectorUnrollingPatterns

  }
};

struct TestVectorTransferFullPartialSplitPatterns
    : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
                         FunctionPass> {
  void runOnFunction() override {
    MLIRContext *ctx = &getContext();
    OwningRewritePatternList patterns;
    patterns.insert<VectorTransferFullPartialRewriter>(ctx);
    applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};

} // end anonymous namespace

namespace mlir {
@@ -141,5 +152,10 @@ void registerTestVectorConversions() {

  PassRegistration<TestVectorUnrollingPatterns> contractionUnrollingPass(
      "test-vector-unrolling-patterns",
      "Test conversion patterns to unroll contract ops in the vector dialect");

  PassRegistration<TestVectorTransferFullPartialSplitPatterns>
      vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
                                     "Test conversion patterns to split "
                                     "transfer ops via scf.if + linalg ops");
}
} // namespace mlir