[mlir][Vector] Lowering of transfer_read/write to vector.load/store

This patch introduces progressive lowering patterns that rewrite
vector.transfer_read/write into vector.load/store (plus vector.broadcast
where needed). The lowering applies when the transfer is unmasked, the
memref has the default layout, and the permutation map is a minor identity
(with broadcasting additionally allowed for reads).
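
For example, an unmasked 1-D transfer on a memref with the default layout
(the transfer_to_load case in the new test file) is rewritten as follows:

    // Before:
    %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xf32>, vector<4xf32>
    vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<4xf32>, memref<8x8xf32>

    // After:
    %res = vector.load %mem[%i, %i] : memref<8x8xf32>, vector<4xf32>
    vector.store %res, %mem[%i, %i] : memref<8x8xf32>, vector<4xf32>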

Reviewed By: dcaballe, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D97822
Sergei Grechanik 2021-03-11 18:07:07 -08:00
parent 5eaeb0fa67
commit fd2b08969b
6 changed files with 383 additions and 0 deletions


@@ -85,6 +85,13 @@ void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns,
void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
                                          MLIRContext *context);

/// Collect a set of transfer read/write lowering patterns.
///
/// These patterns lower transfer ops to simpler ops like `vector.load`,
/// `vector.store` and `vector.broadcast`.
void populateVectorTransferLoweringPatterns(OwningRewritePatternList &patterns,
                                            MLIRContext *context);

/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
class CombiningKindAttr


@@ -104,6 +104,15 @@ public:
  /// affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
  bool isMinorIdentity() const;

  /// Returns true if this affine map is a minor identity up to broadcasted
  /// dimensions which are indicated by value 0 in the result. If
  /// `broadcastedDims` is not null, it will be populated with the indices of
  /// the broadcasted dimensions in the result array.
  /// Example: affine_map<(d0, d1, d2, d3, d4) -> (0, d2, 0, d4)>
  ///          (`broadcastedDims` will contain [0, 2])
  bool isMinorIdentityWithBroadcasting(
      SmallVectorImpl<unsigned> *broadcastedDims = nullptr) const;

  /// Returns true if this affine map is an empty map, i.e., () -> ().
  bool isEmpty() const;


@@ -37,6 +37,7 @@
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
@@ -2729,6 +2730,116 @@ struct TransferWriteInsertPattern
}
};

/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast` if all of the following hold:
/// - The op reads from a memref with the default layout.
/// - Masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
struct TransferReadToVectorLoadLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context)
      : OpRewritePattern<vector::TransferReadOp>(context) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    SmallVector<unsigned, 4> broadcastedDims;
    // TODO: Support permutations.
    if (!read.permutation_map().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return failure();
    auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();

    // If there is broadcasting involved then we first load the unbroadcasted
    // vector, and then broadcast it with `vector.broadcast`.
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
                                                     vectorShape.end());
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = VectorType::get(
        unbroadcastedVectorShape, read.getVectorType().getElementType());

    // `vector.load` supports vector types as memref's elements only when the
    // resulting vector type is the same as the element type.
    if (memRefType.getElementType().isa<VectorType>() &&
        memRefType.getElementType() != unbroadcastedVectorType)
      return failure();

    // Only the default layout is supported by `vector.load`.
    // TODO: Support non-default layouts.
    if (!memRefType.getAffineMaps().empty())
      return failure();

    // TODO: When masking is required, we can create a MaskedLoadOp.
    if (read.hasMaskedDim())
      return failure();

    Operation *loadOp;
    if (!broadcastedDims.empty() &&
        unbroadcastedVectorType.getNumElements() == 1) {
      // If broadcasting is required and the number of loaded elements is 1
      // then we can create `std.load` instead of `vector.load`.
      loadOp = rewriter.create<mlir::LoadOp>(read.getLoc(), read.source(),
                                             read.indices());
    } else {
      // Otherwise create `vector.load`.
      loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
                                               unbroadcastedVectorType,
                                               read.source(), read.indices());
    }

    // Insert a broadcasting op if required.
    if (!broadcastedDims.empty()) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          read, read.getVectorType(), loadOp->getResult(0));
    } else {
      rewriter.replaceOp(read, loadOp->getResult(0));
    }
    return success();
  }
};
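
As a sketch of what this pattern produces in the broadcast case (this mirrors
the transfer_broadcasting test added below): a read whose permutation map
broadcasts a single loaded element becomes a scalar load followed by a
vector.broadcast.

    %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false], permutation_map = affine_map<(d0, d1) -> (0)>} : memref<8x8xf32>, vector<4xf32>
    // becomes
    %ld = load %mem[%i, %i] : memref<8x8xf32>
    %res = vector.broadcast %ld : f32 to vector<4xf32>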

/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - The op writes to a memref with the default layout.
/// - Masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
///   broadcasting is allowed).
struct TransferWriteToVectorStoreLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context)
      : OpRewritePattern<vector::TransferWriteOp>(context) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // TODO: Support non-minor-identity maps.
    if (!write.permutation_map().isMinorIdentity())
      return failure();
    auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();

    // `vector.store` supports vector types as memref's elements only when the
    // type of the vector value being written is the same as the element type.
    if (memRefType.getElementType().isa<VectorType>() &&
        memRefType.getElementType() != write.getVectorType())
      return failure();

    // Only the default layout is supported by `vector.store`.
    // TODO: Support non-default layouts.
    if (!memRefType.getAffineMaps().empty())
      return failure();

    // TODO: When masking is required, we can create a MaskedStoreOp.
    if (write.hasMaskedDim())
      return failure();

    rewriter.replaceOpWithNewOp<vector::StoreOp>(
        write, write.vector(), write.source(), write.indices());
    return success();
  }
};
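
For illustration (both cases appear in the tests below), an unmasked
minor-identity write becomes a plain vector.store, while a write with any
other permutation map is left untouched:

    // Rewritten:
    vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<4xf32>, memref<8x8xf32>
    // becomes
    vector.store %res, %mem[%i, %i] : memref<8x8xf32>, vector<4xf32>

    // Not rewritten (the permutation map is not the minor identity):
    vector.transfer_write %res, %mem[%i, %i] {masked = [false], permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<4xf32>, memref<8x8xf32>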
// Trims leading one dimensions from `oldType` and returns the result type.
// Returns `vector<1xT>` if `oldType` only has one element.
static VectorType trimLeadingOneDims(VectorType oldType) {
@@ -3201,3 +3312,9 @@ void mlir::vector::populateVectorContractLoweringPatterns(
ContractionOpToOuterProductOpLowering>(parameters, context);
// clang-format on
}

void mlir::vector::populateVectorTransferLoweringPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<TransferReadToVectorLoadLowering,
                  TransferWriteToVectorStoreLowering>(context);
}


@@ -110,6 +110,35 @@ bool AffineMap::isMinorIdentity() const {
getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
}

/// Returns true if this affine map is a minor identity up to broadcasted
/// dimensions which are indicated by value 0 in the result.
bool AffineMap::isMinorIdentityWithBroadcasting(
    SmallVectorImpl<unsigned> *broadcastedDims) const {
  if (broadcastedDims)
    broadcastedDims->clear();
  if (getNumDims() < getNumResults())
    return false;
  unsigned suffixStart = getNumDims() - getNumResults();
  for (auto idxAndExpr : llvm::enumerate(getResults())) {
    unsigned resIdx = idxAndExpr.index();
    AffineExpr expr = idxAndExpr.value();
    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
      // Each result may be either a constant 0 (broadcasted dimension).
      if (constExpr.getValue() != 0)
        return false;
      if (broadcastedDims)
        broadcastedDims->push_back(resIdx);
    } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
      // Or it may be the input dimension corresponding to this result
      // position.
      if (dimExpr.getPosition() != suffixStart + resIdx)
        return false;
    } else {
      return false;
    }
  }
  return true;
}
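
A few illustrative maps (not part of the patch) and how this predicate
classifies them:

    // Accepted: minor identity with broadcasting; broadcastedDims = [0, 2].
    affine_map<(d0, d1, d2, d3, d4) -> (0, d2, 0, d4)>
    // Accepted: plain minor identity; broadcastedDims stays empty.
    affine_map<(d0, d1, d2) -> (d1, d2)>
    // Rejected: d0 is not the input dimension for the last result position.
    affine_map<(d0, d1) -> (d0)>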
/// Returns an AffineMap representing a permutation.
AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
MLIRContext *context) {


@@ -0,0 +1,208 @@
// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -split-input-file | FileCheck %s
// transfer_read/write are lowered to vector.load/store
// CHECK-LABEL: func @transfer_to_load(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<4xf32>
// CHECK-NEXT: vector.store %[[RES:.*]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<4xf32>
// CHECK-NEXT: return %[[RES]] : vector<4xf32>
// CHECK-NEXT: }
func @transfer_to_load(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
%cf0 = constant 0.0 : f32
%res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xf32>, vector<4xf32>
vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<4xf32>, memref<8x8xf32>
return %res : vector<4xf32>
}
// -----
// n-D results are also supported.
// CHECK-LABEL: func @transfer_2D(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<2x4xf32> {
// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<2x4xf32>
// CHECK-NEXT: vector.store %[[RES:.*]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<2x4xf32>
// CHECK-NEXT: return %[[RES]] : vector<2x4xf32>
// CHECK-NEXT: }
func @transfer_2D(%mem : memref<8x8xf32>, %i : index) -> vector<2x4xf32> {
%cf0 = constant 0.0 : f32
%res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false, false]} : memref<8x8xf32>, vector<2x4xf32>
vector.transfer_write %res, %mem[%i, %i] {masked = [false, false]} : vector<2x4xf32>, memref<8x8xf32>
return %res : vector<2x4xf32>
}
// -----
// Vector element types are supported when the result has the same type.
// CHECK-LABEL: func @transfer_vector_element(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xvector<2x4xf32>>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<2x4xf32> {
// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
// CHECK-NEXT: vector.store %[[RES:.*]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
// CHECK-NEXT: return %[[RES]] : vector<2x4xf32>
// CHECK-NEXT: }
func @transfer_vector_element(%mem : memref<8x8xvector<2x4xf32>>, %i : index) -> vector<2x4xf32> {
%cf0 = constant dense<0.0> : vector<2x4xf32>
%res = vector.transfer_read %mem[%i, %i], %cf0 : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
vector.transfer_write %res, %mem[%i, %i] : vector<2x4xf32>, memref<8x8xvector<2x4xf32>>
return %res : vector<2x4xf32>
}
// -----
// TODO: Vector element types are not supported yet when the result has a
// different type.
// CHECK-LABEL: func @transfer_vector_element_different_types(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xvector<2x4xf32>>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<1x2x4xf32> {
// CHECK-NEXT: %[[CF0:.*]] = constant dense<0.000000e+00> : vector<2x4xf32>
// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false]} : memref<8x8xvector<2x4xf32>>, vector<1x2x4xf32>
// CHECK-NEXT: vector.transfer_write %[[RES:.*]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [false]} : vector<1x2x4xf32>, memref<8x8xvector<2x4xf32>>
// CHECK-NEXT: return %[[RES]] : vector<1x2x4xf32>
// CHECK-NEXT: }
func @transfer_vector_element_different_types(%mem : memref<8x8xvector<2x4xf32>>, %i : index) -> vector<1x2x4xf32> {
%cf0 = constant dense<0.0> : vector<2x4xf32>
%res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xvector<2x4xf32>>, vector<1x2x4xf32>
vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<1x2x4xf32>, memref<8x8xvector<2x4xf32>>
return %res : vector<1x2x4xf32>
}
// -----
// TODO: transfer_read/write cannot be lowered because one of the dimensions
// is masked.
// CHECK-LABEL: func @transfer_2D_masked(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<2x4xf32> {
// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false, true]} : memref<8x8xf32>, vector<2x4xf32>
// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [true, false]} : vector<2x4xf32>, memref<8x8xf32>
// CHECK-NEXT: return %[[RES]] : vector<2x4xf32>
// CHECK-NEXT: }
func @transfer_2D_masked(%mem : memref<8x8xf32>, %i : index) -> vector<2x4xf32> {
%cf0 = constant 0.0 : f32
%res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false, true]} : memref<8x8xf32>, vector<2x4xf32>
vector.transfer_write %res, %mem[%i, %i] {masked = [true, false]} : vector<2x4xf32>, memref<8x8xf32>
return %res : vector<2x4xf32>
}
// -----
// TODO: transfer_read/write cannot be lowered because they are masked.
// CHECK-LABEL: func @transfer_masked(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] : memref<8x8xf32>, vector<4xf32>
// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : vector<4xf32>, memref<8x8xf32>
// CHECK-NEXT: return %[[RES]] : vector<4xf32>
// CHECK-NEXT: }
func @transfer_masked(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
%cf0 = constant 0.0 : f32
%res = vector.transfer_read %mem[%i, %i], %cf0 : memref<8x8xf32>, vector<4xf32>
vector.transfer_write %res, %mem[%i, %i] : vector<4xf32>, memref<8x8xf32>
return %res : vector<4xf32>
}
// -----
// TODO: transfer_read/write cannot be lowered to vector.load/store because the
// memref has a non-default layout.
// CHECK-LABEL: func @transfer_nondefault_layout(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32, #{{.*}}>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false]} : memref<8x8xf32, #{{.*}}>, vector<4xf32>
// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [false]} : vector<4xf32>, memref<8x8xf32, #{{.*}}>
// CHECK-NEXT: return %[[RES]] : vector<4xf32>
// CHECK-NEXT: }
#layout = affine_map<(d0, d1) -> (d0*16 + d1)>
func @transfer_nondefault_layout(%mem : memref<8x8xf32, #layout>, %i : index) -> vector<4xf32> {
%cf0 = constant 0.0 : f32
%res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xf32, #layout>, vector<4xf32>
vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<4xf32>, memref<8x8xf32, #layout>
return %res : vector<4xf32>
}
// -----
// TODO: transfer_read/write cannot be lowered to vector.load/store yet when the
// permutation map is not the minor identity map (up to broadcasting).
// CHECK-LABEL: func @transfer_perm_map(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false], permutation_map = #{{.*}}} : memref<8x8xf32>, vector<4xf32>
// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [false], permutation_map = #{{.*}}} : vector<4xf32>, memref<8x8xf32>
// CHECK-NEXT: return %[[RES]] : vector<4xf32>
// CHECK-NEXT: }
func @transfer_perm_map(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
%cf0 = constant 0.0 : f32
%res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<8x8xf32>, vector<4xf32>
vector.transfer_write %res, %mem[%i, %i] {masked = [false], permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<4xf32>, memref<8x8xf32>
return %res : vector<4xf32>
}
// -----
// Lowering of transfer_read with broadcasting is supported (note that a `load`
// is generated instead of a `vector.load`).
// CHECK-LABEL: func @transfer_broadcasting(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
// CHECK-NEXT: %[[LOAD:.*]] = load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4xf32>
// CHECK-NEXT: return %[[RES]] : vector<4xf32>
// CHECK-NEXT: }
#broadcast = affine_map<(d0, d1) -> (0)>
func @transfer_broadcasting(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
%cf0 = constant 0.0 : f32
%res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false], permutation_map = #broadcast} : memref<8x8xf32>, vector<4xf32>
return %res : vector<4xf32>
}
// -----
// An example with two broadcasted dimensions.
// CHECK-LABEL: func @transfer_broadcasting_2D(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4x4xf32> {
// CHECK-NEXT: %[[LOAD:.*]] = load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4x4xf32>
// CHECK-NEXT: return %[[RES]] : vector<4x4xf32>
// CHECK-NEXT: }
#broadcast = affine_map<(d0, d1) -> (0, 0)>
func @transfer_broadcasting_2D(%mem : memref<8x8xf32>, %i : index) -> vector<4x4xf32> {
%cf0 = constant 0.0 : f32
%res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false, false], permutation_map = #broadcast} : memref<8x8xf32>, vector<4x4xf32>
return %res : vector<4x4xf32>
}
// -----
// More complex broadcasting case (here a `vector.load` is generated).
// CHECK-LABEL: func @transfer_broadcasting_complex(
// CHECK-SAME: %[[MEM:.*]]: memref<10x20x30x8x8xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<3x2x4x5xf32> {
// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]] : memref<10x20x30x8x8xf32>, vector<3x1x1x5xf32>
// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<3x1x1x5xf32> to vector<3x2x4x5xf32>
// CHECK-NEXT: return %[[RES]] : vector<3x2x4x5xf32>
// CHECK-NEXT: }
#broadcast = affine_map<(d0, d1, d2, d3, d4) -> (d1, 0, 0, d4)>
func @transfer_broadcasting_complex(%mem : memref<10x20x30x8x8xf32>, %i : index) -> vector<3x2x4x5xf32> {
%cf0 = constant 0.0 : f32
%res = vector.transfer_read %mem[%i, %i, %i, %i, %i], %cf0 {masked = [false, false, false, false], permutation_map = #broadcast} : memref<10x20x30x8x8xf32>, vector<3x2x4x5xf32>
return %res : vector<3x2x4x5xf32>
}


@@ -361,6 +361,15 @@ struct TestVectorTransferOpt
void runOnFunction() override { transferOpflowOpt(getFunction()); }
};

struct TestVectorTransferLoweringPatterns
    : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> {
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    populateVectorTransferLoweringPatterns(patterns, &getContext());
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};
} // end anonymous namespace
namespace mlir {
@@ -403,6 +412,10 @@ void registerTestVectorConversions() {
  PassRegistration<TestVectorTransferOpt> transferOpOpt(
      "test-vector-transferop-opt",
      "Test optimization transformations for transfer ops");

  PassRegistration<TestVectorTransferLoweringPatterns> transferOpLoweringPass(
      "test-vector-transfer-lowering-patterns",
      "Test conversion patterns to lower transfer ops to other vector ops");
}
} // namespace test
} // namespace mlir