[mlir][Vector] Lowering of transfer_read/write to vector.load/store
This patch introduces progressive lowering patterns for rewriting
vector.transfer_read/write to vector.load/store and vector.broadcast
in certain supported cases.

Reviewed By: dcaballe, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D97822
This commit is contained in:
parent 5eaeb0fa67
commit fd2b08969b
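For a quick picture of what these patterns do, here is a minimal before/after sketch; shapes and values mirror the tests added below:

  // Before:
  %cf0 = constant 0.0 : f32
  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xf32>, vector<4xf32>
  vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<4xf32>, memref<8x8xf32>

  // After:
  %res = vector.load %mem[%i, %i] : memref<8x8xf32>, vector<4xf32>
  vector.store %res, %mem[%i, %i] : memref<8x8xf32>, vector<4xf32>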
@@ -85,6 +85,13 @@ void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns,
void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
                                          MLIRContext *context);

/// Collect a set of transfer read/write lowering patterns.
///
/// These patterns lower transfer ops to simpler ops like `vector.load`,
/// `vector.store` and `vector.broadcast`.
void populateVectorTransferLoweringPatterns(OwningRewritePatternList &patterns,
                                            MLIRContext *context);

/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
class CombiningKindAttr
@@ -104,6 +104,15 @@ public:
  /// affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
  bool isMinorIdentity() const;

  /// Returns true if this affine map is a minor identity up to broadcasted
  /// dimensions which are indicated by value 0 in the result. If
  /// `broadcastedDims` is not null, it will be populated with the indices of
  /// the broadcasted dimensions in the result array.
  /// Example: affine_map<(d0, d1, d2, d3, d4) -> (0, d2, 0, d4)>
  ///          (`broadcastedDims` will contain [0, 2])
  bool isMinorIdentityWithBroadcasting(
      SmallVectorImpl<unsigned> *broadcastedDims = nullptr) const;

  /// Returns true if this affine map is an empty map, i.e., () -> ().
  bool isEmpty() const;
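A few illustrative maps for the new predicate (the first is the example from the doc comment; the rest are illustrative only and follow the same rule):

  affine_map<(d0, d1, d2, d3, d4) -> (0, d2, 0, d4)>  // true, broadcastedDims = [0, 2]
  affine_map<(d0, d1) -> (d1)>                        // true (minor identity, no broadcasting)
  affine_map<(d0, d1) -> (d0)>                        // false (d0 is not the most minor dimension)
  affine_map<(d0, d1) -> (d1, d0)>                    // false (permutation)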
@@ -37,6 +37,7 @@
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
@@ -2729,6 +2730,116 @@ struct TransferWriteInsertPattern
  }
};

/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast` if all of the following hold:
/// - The op reads from a memref with the default layout.
/// - Masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
struct TransferReadToVectorLoadLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context)
      : OpRewritePattern<vector::TransferReadOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    SmallVector<unsigned, 4> broadcastedDims;
    // TODO: Support permutations.
    if (!read.permutation_map().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return failure();
    auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();

    // If there is broadcasting involved then we first load the unbroadcasted
    // vector, and then broadcast it with `vector.broadcast`.
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
                                                     vectorShape.end());
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = VectorType::get(
        unbroadcastedVectorShape, read.getVectorType().getElementType());

    // `vector.load` supports vector types as memref's elements only when the
    // resulting vector type is the same as the element type.
    if (memRefType.getElementType().isa<VectorType>() &&
        memRefType.getElementType() != unbroadcastedVectorType)
      return failure();
    // Only the default layout is supported by `vector.load`.
    // TODO: Support non-default layouts.
    if (!memRefType.getAffineMaps().empty())
      return failure();
    // TODO: When masking is required, we can create a MaskedLoadOp.
    if (read.hasMaskedDim())
      return failure();

    Operation *loadOp;
    if (!broadcastedDims.empty() &&
        unbroadcastedVectorType.getNumElements() == 1) {
      // If broadcasting is required and the number of loaded elements is 1
      // then we can create `std.load` instead of `vector.load`.
      loadOp = rewriter.create<mlir::LoadOp>(read.getLoc(), read.source(),
                                             read.indices());
    } else {
      // Otherwise create `vector.load`.
      loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
                                               unbroadcastedVectorType,
                                               read.source(), read.indices());
    }

    // Insert a broadcasting op if required.
    if (!broadcastedDims.empty()) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          read, read.getVectorType(), loadOp->getResult(0));
    } else {
      rewriter.replaceOp(read, loadOp->getResult(0));
    }

    return success();
  }
};

/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - The op writes to a memref with the default layout.
/// - Masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
///   broadcasting is allowed).
struct TransferWriteToVectorStoreLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context)
      : OpRewritePattern<vector::TransferWriteOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // TODO: Support non-minor-identity maps.
    if (!write.permutation_map().isMinorIdentity())
      return failure();
    auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();
    // `vector.store` supports vector types as memref's elements only when the
    // type of the vector value being written is the same as the element type.
    if (memRefType.getElementType().isa<VectorType>() &&
        memRefType.getElementType() != write.getVectorType())
      return failure();
    // Only the default layout is supported by `vector.store`.
    // TODO: Support non-default layouts.
    if (!memRefType.getAffineMaps().empty())
      return failure();
    // TODO: When masking is required, we can create a MaskedStoreOp.
    if (write.hasMaskedDim())
      return failure();
    rewriter.replaceOpWithNewOp<vector::StoreOp>(
        write, write.vector(), write.source(), write.indices());
    return success();
  }
};

// Trims leading one dimensions from `oldType` and returns the result type.
// Returns `vector<1xT>` if `oldType` only has one element.
static VectorType trimLeadingOneDims(VectorType oldType) {
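A sketch of the broadcast path implemented above, mirroring the `transfer_broadcasting` test added below. Since only a single element is loaded here, a scalar `load` plus `vector.broadcast` is emitted instead of a `vector.load`:

  // Before:
  %cf0 = constant 0.0 : f32
  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false], permutation_map = affine_map<(d0, d1) -> (0)>} : memref<8x8xf32>, vector<4xf32>

  // After:
  %e = load %mem[%i, %i] : memref<8x8xf32>
  %res = vector.broadcast %e : f32 to vector<4xf32>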
@@ -3201,3 +3312,9 @@ void mlir::vector::populateVectorContractLoweringPatterns(
                  ContractionOpToOuterProductOpLowering>(parameters, context);
  // clang-format on
}

void mlir::vector::populateVectorTransferLoweringPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<TransferReadToVectorLoadLowering,
                  TransferWriteToVectorStoreLowering>(context);
}
@@ -110,6 +110,35 @@ bool AffineMap::isMinorIdentity() const {
         getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
}

/// Returns true if this affine map is a minor identity up to broadcasted
/// dimensions which are indicated by value 0 in the result.
bool AffineMap::isMinorIdentityWithBroadcasting(
    SmallVectorImpl<unsigned> *broadcastedDims) const {
  if (broadcastedDims)
    broadcastedDims->clear();
  if (getNumDims() < getNumResults())
    return false;
  unsigned suffixStart = getNumDims() - getNumResults();
  for (auto idxAndExpr : llvm::enumerate(getResults())) {
    unsigned resIdx = idxAndExpr.index();
    AffineExpr expr = idxAndExpr.value();
    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
      // Each result must be either a constant 0 (a broadcasted dimension)...
      if (constExpr.getValue() != 0)
        return false;
      if (broadcastedDims)
        broadcastedDims->push_back(resIdx);
    } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
      // ...or the input dimension corresponding to this result position.
      if (dimExpr.getPosition() != suffixStart + resIdx)
        return false;
    } else {
      return false;
    }
  }
  return true;
}

/// Returns an AffineMap representing a permutation.
AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
                                       MLIRContext *context) {
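As a worked example of the suffix logic above: for affine_map<(d0, d1, d2, d3, d4) -> (0, d2, 0, d4)>, getNumDims() = 5 and getNumResults() = 4, so suffixStart = 1. Results 0 and 2 are the constant 0 and get recorded as broadcasted dimensions; result 1 is d2, which matches suffixStart + 1, and result 3 is d4, which matches suffixStart + 3. The map is therefore a minor identity with broadcasting, and `broadcastedDims` ends up as [0, 2].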
@@ -0,0 +1,208 @@
// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -split-input-file | FileCheck %s
// transfer_read/write are lowered to vector.load/store.
// CHECK-LABEL: func @transfer_to_load(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<4xf32>
// CHECK-NEXT: vector.store %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<4xf32>
// CHECK-NEXT: return %[[RES]] : vector<4xf32>
// CHECK-NEXT: }

func @transfer_to_load(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
  %cf0 = constant 0.0 : f32
  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xf32>, vector<4xf32>
  vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<4xf32>, memref<8x8xf32>
  return %res : vector<4xf32>
}

// -----
// n-D results are also supported.
// CHECK-LABEL: func @transfer_2D(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<2x4xf32> {
// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<2x4xf32>
// CHECK-NEXT: vector.store %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<2x4xf32>
// CHECK-NEXT: return %[[RES]] : vector<2x4xf32>
// CHECK-NEXT: }

func @transfer_2D(%mem : memref<8x8xf32>, %i : index) -> vector<2x4xf32> {
  %cf0 = constant 0.0 : f32
  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false, false]} : memref<8x8xf32>, vector<2x4xf32>
  vector.transfer_write %res, %mem[%i, %i] {masked = [false, false]} : vector<2x4xf32>, memref<8x8xf32>
  return %res : vector<2x4xf32>
}

// -----
// Vector element types are supported when the result has the same type.
// CHECK-LABEL: func @transfer_vector_element(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xvector<2x4xf32>>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<2x4xf32> {
// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
// CHECK-NEXT: vector.store %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
// CHECK-NEXT: return %[[RES]] : vector<2x4xf32>
// CHECK-NEXT: }

func @transfer_vector_element(%mem : memref<8x8xvector<2x4xf32>>, %i : index) -> vector<2x4xf32> {
  %cf0 = constant dense<0.0> : vector<2x4xf32>
  %res = vector.transfer_read %mem[%i, %i], %cf0 : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
  vector.transfer_write %res, %mem[%i, %i] : vector<2x4xf32>, memref<8x8xvector<2x4xf32>>
  return %res : vector<2x4xf32>
}

// -----
// TODO: Vector element types are not supported yet when the result has a
// different type.
// CHECK-LABEL: func @transfer_vector_element_different_types(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xvector<2x4xf32>>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<1x2x4xf32> {
// CHECK-NEXT: %[[CF0:.*]] = constant dense<0.000000e+00> : vector<2x4xf32>
// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false]} : memref<8x8xvector<2x4xf32>>, vector<1x2x4xf32>
// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [false]} : vector<1x2x4xf32>, memref<8x8xvector<2x4xf32>>
// CHECK-NEXT: return %[[RES]] : vector<1x2x4xf32>
// CHECK-NEXT: }

func @transfer_vector_element_different_types(%mem : memref<8x8xvector<2x4xf32>>, %i : index) -> vector<1x2x4xf32> {
  %cf0 = constant dense<0.0> : vector<2x4xf32>
  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xvector<2x4xf32>>, vector<1x2x4xf32>
  vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<1x2x4xf32>, memref<8x8xvector<2x4xf32>>
  return %res : vector<1x2x4xf32>
}

// -----
// TODO: transfer_read/write cannot be lowered because there is a masked
// dimension.
// CHECK-LABEL: func @transfer_2D_masked(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<2x4xf32> {
// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false, true]} : memref<8x8xf32>, vector<2x4xf32>
// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [true, false]} : vector<2x4xf32>, memref<8x8xf32>
// CHECK-NEXT: return %[[RES]] : vector<2x4xf32>
// CHECK-NEXT: }

func @transfer_2D_masked(%mem : memref<8x8xf32>, %i : index) -> vector<2x4xf32> {
  %cf0 = constant 0.0 : f32
  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false, true]} : memref<8x8xf32>, vector<2x4xf32>
  vector.transfer_write %res, %mem[%i, %i] {masked = [true, false]} : vector<2x4xf32>, memref<8x8xf32>
  return %res : vector<2x4xf32>
}

// -----
// TODO: transfer_read/write cannot be lowered because they are masked.
// CHECK-LABEL: func @transfer_masked(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] : memref<8x8xf32>, vector<4xf32>
// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : vector<4xf32>, memref<8x8xf32>
// CHECK-NEXT: return %[[RES]] : vector<4xf32>
// CHECK-NEXT: }

func @transfer_masked(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
  %cf0 = constant 0.0 : f32
  %res = vector.transfer_read %mem[%i, %i], %cf0 : memref<8x8xf32>, vector<4xf32>
  vector.transfer_write %res, %mem[%i, %i] : vector<4xf32>, memref<8x8xf32>
  return %res : vector<4xf32>
}

// -----
// TODO: transfer_read/write cannot be lowered to vector.load/store because the
// memref has a non-default layout.
// CHECK-LABEL: func @transfer_nondefault_layout(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32, #{{.*}}>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false]} : memref<8x8xf32, #{{.*}}>, vector<4xf32>
// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [false]} : vector<4xf32>, memref<8x8xf32, #{{.*}}>
// CHECK-NEXT: return %[[RES]] : vector<4xf32>
// CHECK-NEXT: }

#layout = affine_map<(d0, d1) -> (d0*16 + d1)>
func @transfer_nondefault_layout(%mem : memref<8x8xf32, #layout>, %i : index) -> vector<4xf32> {
  %cf0 = constant 0.0 : f32
  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xf32, #layout>, vector<4xf32>
  vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<4xf32>, memref<8x8xf32, #layout>
  return %res : vector<4xf32>
}

// -----
// TODO: transfer_read/write cannot be lowered to vector.load/store yet when the
// permutation map is not the minor identity map (up to broadcasting).
// CHECK-LABEL: func @transfer_perm_map(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false], permutation_map = #{{.*}}} : memref<8x8xf32>, vector<4xf32>
// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [false], permutation_map = #{{.*}}} : vector<4xf32>, memref<8x8xf32>
// CHECK-NEXT: return %[[RES]] : vector<4xf32>
// CHECK-NEXT: }

func @transfer_perm_map(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
  %cf0 = constant 0.0 : f32
  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<8x8xf32>, vector<4xf32>
  vector.transfer_write %res, %mem[%i, %i] {masked = [false], permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<4xf32>, memref<8x8xf32>
  return %res : vector<4xf32>
}

// -----
// Lowering of transfer_read with broadcasting is supported (note that a `load`
// is generated instead of a `vector.load`).
// CHECK-LABEL: func @transfer_broadcasting(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
// CHECK-NEXT: %[[LOAD:.*]] = load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4xf32>
// CHECK-NEXT: return %[[RES]] : vector<4xf32>
// CHECK-NEXT: }

#broadcast = affine_map<(d0, d1) -> (0)>
func @transfer_broadcasting(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
  %cf0 = constant 0.0 : f32
  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false], permutation_map = #broadcast} : memref<8x8xf32>, vector<4xf32>
  return %res : vector<4xf32>
}

// -----
// An example with two broadcasted dimensions.
// CHECK-LABEL: func @transfer_broadcasting_2D(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4x4xf32> {
// CHECK-NEXT: %[[LOAD:.*]] = load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4x4xf32>
// CHECK-NEXT: return %[[RES]] : vector<4x4xf32>
// CHECK-NEXT: }

#broadcast = affine_map<(d0, d1) -> (0, 0)>
func @transfer_broadcasting_2D(%mem : memref<8x8xf32>, %i : index) -> vector<4x4xf32> {
  %cf0 = constant 0.0 : f32
  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false, false], permutation_map = #broadcast} : memref<8x8xf32>, vector<4x4xf32>
  return %res : vector<4x4xf32>
}

// -----
// More complex broadcasting case (here a `vector.load` is generated).
// CHECK-LABEL: func @transfer_broadcasting_complex(
// CHECK-SAME: %[[MEM:.*]]: memref<10x20x30x8x8xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<3x2x4x5xf32> {
// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]] : memref<10x20x30x8x8xf32>, vector<3x1x1x5xf32>
// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<3x1x1x5xf32> to vector<3x2x4x5xf32>
// CHECK-NEXT: return %[[RES]] : vector<3x2x4x5xf32>
// CHECK-NEXT: }

#broadcast = affine_map<(d0, d1, d2, d3, d4) -> (d1, 0, 0, d4)>
func @transfer_broadcasting_complex(%mem : memref<10x20x30x8x8xf32>, %i : index) -> vector<3x2x4x5xf32> {
  %cf0 = constant 0.0 : f32
  %res = vector.transfer_read %mem[%i, %i, %i, %i, %i], %cf0 {masked = [false, false, false, false], permutation_map = #broadcast} : memref<10x20x30x8x8xf32>, vector<3x2x4x5xf32>
  return %res : vector<3x2x4x5xf32>
}
@@ -361,6 +361,15 @@ struct TestVectorTransferOpt
  void runOnFunction() override { transferOpflowOpt(getFunction()); }
};

struct TestVectorTransferLoweringPatterns
    : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> {
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    populateVectorTransferLoweringPatterns(patterns, &getContext());
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};

} // end anonymous namespace

namespace mlir {
@@ -403,6 +412,10 @@ void registerTestVectorConversions() {
  PassRegistration<TestVectorTransferOpt> transferOpOpt(
      "test-vector-transferop-opt",
      "Test optimization transformations for transfer ops");

  PassRegistration<TestVectorTransferLoweringPatterns> transferOpLoweringPass(
      "test-vector-transfer-lowering-patterns",
      "Test conversion patterns to lower transfer ops to other vector ops");
}
} // namespace test
} // namespace mlir