[mlir][vector] Refactor TransferReadToVectorLoadLowering

* TransferReadToVectorLoadLowering no longer generates memref.load ops.
* Add a new pattern, VectorLoadToMemrefLoadLowering, that lowers scalar (single-element) vector.load ops to memref.load ops.
* Add a vector::BroadcastOp canonicalization pattern that folds broadcast chains.
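Taken together, a single-element transfer_read now reaches memref.load in two composable steps instead of one special case. A rough sketch of the rewrite chain (illustrative IR, mirroring the @transfer_scalar test added below):

    // Input:
    %res = vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<1xf32>
    // After TransferReadToVectorLoadLowering:
    %res = vector.load %mem[%i, %i] : memref<?x?xf32>, vector<1xf32>
    // After VectorLoadToMemrefLoadLowering:
    %0 = memref.load %mem[%i, %i] : memref<?x?xf32>
    %res = vector.broadcast %0 : f32 to vector<1xf32>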

Differential Revision: https://reviews.llvm.org/D106117
Author: Matthias Springer
Date:   2021-07-17 13:52:20 +09:00
Commit: 4a3defa629 (parent: f4ec30d808)

4 changed files with 68 additions and 19 deletions


@@ -1346,11 +1346,25 @@ public:
   }
 };
 
+// Fold broadcast1(broadcast2(x)) into broadcast1(x).
+struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
+  using OpRewritePattern<BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcBroadcast = broadcastOp.source().getDefiningOp<BroadcastOp>();
+    if (!srcBroadcast)
+      return failure();
+    rewriter.replaceOpWithNewOp<BroadcastOp>(
+        broadcastOp, broadcastOp.getVectorType(), srcBroadcast.source());
+    return success();
+  }
+};
+
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<BroadcastToShapeCast>(context);
+  results.add<BroadcastToShapeCast, BroadcastFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//
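For illustration, the new BroadcastFolder pattern rewrites a chain such as (a sketch, mirroring the @fold_consecutive_broadcasts test added below):

    %1 = vector.broadcast %a : i32 to vector<16xi32>
    %2 = vector.broadcast %1 : vector<16xi32> to vector<4x16xi32>

into a single broadcast from the original source:

    %2 = vector.broadcast %a : i32 to vector<4x16xi32>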


@@ -2510,32 +2510,39 @@ struct TransferReadToVectorLoadLowering
       return failure();
     if (read.mask())
       return failure();
-    Operation *loadOp;
-    if (!broadcastedDims.empty() &&
-        unbroadcastedVectorType.getNumElements() == 1) {
-      // If broadcasting is required and the number of loaded elements is 1 then
-      // we can create `memref.load` instead of `vector.load`.
-      loadOp = rewriter.create<memref::LoadOp>(read.getLoc(), read.source(),
-                                               read.indices());
-    } else {
-      // Otherwise create `vector.load`.
-      loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
-                                               unbroadcastedVectorType,
-                                               read.source(), read.indices());
-    }
+    auto loadOp = rewriter.create<vector::LoadOp>(
+        read.getLoc(), unbroadcastedVectorType, read.source(), read.indices());
 
     // Insert a broadcasting op if required.
     if (!broadcastedDims.empty()) {
       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-          read, read.getVectorType(), loadOp->getResult(0));
+          read, read.getVectorType(), loadOp.result());
     } else {
-      rewriter.replaceOp(read, loadOp->getResult(0));
+      rewriter.replaceOp(read, loadOp.result());
     }
 
     return success();
   }
 };
 
+/// Replace a scalar vector.load with a memref.load.
+struct VectorLoadToMemrefLoadLowering
+    : public OpRewritePattern<vector::LoadOp> {
+  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    auto vecType = loadOp.getVectorType();
+    if (vecType.getNumElements() != 1)
+      return failure();
+    auto memrefLoad = rewriter.create<memref::LoadOp>(
+        loadOp.getLoc(), loadOp.base(), loadOp.indices());
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        loadOp, VectorType::get({1}, vecType.getElementType()), memrefLoad);
+    return success();
+  }
+};
 
 /// Progressive lowering of transfer_write. This pattern supports lowering of
 /// `vector.transfer_write` to `vector.store` if all of the following hold:
 /// - The op writes to a memref with the default layout.
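For reference, when all of its conditions hold, TransferWriteToVectorStoreLowering rewrites an in-bounds write such as the following (illustrative IR, not taken from this commit's tests):

    vector.transfer_write %v, %mem[%i] {in_bounds = [true]} : vector<4xf32>, memref<8xf32>

into a plain store:

    vector.store %v, %mem[%i] : memref<8xf32>, vector<4xf32>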
@@ -3674,8 +3681,9 @@ void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
 
 void mlir::vector::populateVectorTransferLoweringPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<TransferReadToVectorLoadLowering,
-               TransferWriteToVectorStoreLowering>(patterns.getContext());
+  patterns
+      .add<TransferReadToVectorLoadLowering, TransferWriteToVectorStoreLowering,
+           VectorLoadToMemrefLoadLowering>(patterns.getContext());
   populateVectorTransferPermutationMapLoweringPatterns(patterns);
 }
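The BroadcastFolder canonicalization is what keeps the composed lowering clean: for a broadcasting read of a single element, the two lowering patterns stack two broadcasts, and the new fold collapses them (which is also why -canonicalize is added to the test RUN line below). A sketch with hypothetical IR, following the patterns in this commit:

    %res = vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (0)>} : memref<8x8xf32>, vector<4xf32>
    // After TransferReadToVectorLoadLowering:
    %0 = vector.load %mem[%i, %i] : memref<8x8xf32>, vector<1xf32>
    %res = vector.broadcast %0 : vector<1xf32> to vector<4xf32>
    // After VectorLoadToMemrefLoadLowering:
    %1 = memref.load %mem[%i, %i] : memref<8x8xf32>
    %2 = vector.broadcast %1 : f32 to vector<1xf32>
    %res = vector.broadcast %2 : vector<1xf32> to vector<4xf32>
    // After the BroadcastFolder canonicalization:
    %res = vector.broadcast %1 : f32 to vector<4xf32>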


@@ -613,6 +613,18 @@ func @broadcast_folding2() -> vector<4x16xi32> {
 
 // -----
 
+// CHECK-LABEL: @fold_consecutive_broadcasts(
+//  CHECK-SAME: %[[ARG0:.*]]: i32
+//       CHECK: %[[RESULT:.*]] = vector.broadcast %[[ARG0]] : i32 to vector<4x16xi32>
+//       CHECK: return %[[RESULT]]
+func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> {
+  %1 = vector.broadcast %a : i32 to vector<16xi32>
+  %2 = vector.broadcast %1 : vector<16xi32> to vector<4x16xi32>
+  return %2 : vector<4x16xi32>
+}
+
+// -----
+
 // CHECK-LABEL: shape_cast_constant
 // CHECK-DAG: %[[CST1:.*]] = constant dense<1> : vector<3x4x2xi32>
 // CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<20x2xf32>


@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -canonicalize -split-input-file | FileCheck %s
 
 // transfer_read/write are lowered to vector.load/store
 // CHECK-LABEL: func @transfer_to_load(
@@ -174,6 +174,21 @@ func @transfer_broadcasting(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32>
 
 // -----
 
+// CHECK-LABEL: func @transfer_scalar(
+//  CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
+//  CHECK-SAME: %[[IDX:.*]]: index) -> vector<1xf32> {
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<?x?xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<1xf32>
+// CHECK-NEXT: return %[[RES]] : vector<1xf32>
+// CHECK-NEXT: }
+func @transfer_scalar(%mem : memref<?x?xf32>, %i : index) -> vector<1xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<1xf32>
+  return %res : vector<1xf32>
+}
+
+// -----
+
 // An example with two broadcasted dimensions.
 // CHECK-LABEL: func @transfer_broadcasting_2D(
 //  CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,