forked from OSchip/llvm-project
[mlir][vector] Refactor TransferReadToVectorLoadLowering
* TransferReadToVectorLoadLowering no longer generates memref.load ops. * Add new pattern VectorLoadToMemrefLoadLowering that lowers scalar vector.loads to memref.loads. * Add vector::BroadcastOp canonicalization pattern that folds broadcast chains. Differential Revision: https://reviews.llvm.org/D106117
This commit is contained in:
parent
f4ec30d808
commit
4a3defa629
|
@ -1346,11 +1346,25 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// Fold broadcast1(broadcast2(x)) into broadcast1(x).
|
||||
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
|
||||
using OpRewritePattern<BroadcastOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto srcBroadcast = broadcastOp.source().getDefiningOp<BroadcastOp>();
|
||||
if (!srcBroadcast)
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<BroadcastOp>(
|
||||
broadcastOp, broadcastOp.getVectorType(), srcBroadcast.source());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/// Registers canonicalization patterns for vector.broadcast.
///
/// BroadcastToShapeCast rewrites broadcasts that only prepend unit dimensions
/// into cheaper shape_casts; BroadcastFolder collapses chains of consecutive
/// broadcasts into a single op.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  // Register each pattern exactly once; the previous code added
  // BroadcastToShapeCast twice via two separate add<> calls.
  results.add<BroadcastToShapeCast, BroadcastFolder>(context);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -2510,32 +2510,39 @@ struct TransferReadToVectorLoadLowering
|
|||
return failure();
|
||||
if (read.mask())
|
||||
return failure();
|
||||
Operation *loadOp;
|
||||
if (!broadcastedDims.empty() &&
|
||||
unbroadcastedVectorType.getNumElements() == 1) {
|
||||
// If broadcasting is required and the number of loaded elements is 1 then
|
||||
// we can create `memref.load` instead of `vector.load`.
|
||||
loadOp = rewriter.create<memref::LoadOp>(read.getLoc(), read.source(),
|
||||
read.indices());
|
||||
} else {
|
||||
// Otherwise create `vector.load`.
|
||||
loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
|
||||
unbroadcastedVectorType,
|
||||
read.source(), read.indices());
|
||||
}
|
||||
|
||||
auto loadOp = rewriter.create<vector::LoadOp>(
|
||||
read.getLoc(), unbroadcastedVectorType, read.source(), read.indices());
|
||||
// Insert a broadcasting op if required.
|
||||
if (!broadcastedDims.empty()) {
|
||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
|
||||
read, read.getVectorType(), loadOp->getResult(0));
|
||||
read, read.getVectorType(), loadOp.result());
|
||||
} else {
|
||||
rewriter.replaceOp(read, loadOp->getResult(0));
|
||||
rewriter.replaceOp(read, loadOp.result());
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Replace a scalar vector.load with a memref.load.
|
||||
struct VectorLoadToMemrefLoadLowering
|
||||
: public OpRewritePattern<vector::LoadOp> {
|
||||
using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::LoadOp loadOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto vecType = loadOp.getVectorType();
|
||||
if (vecType.getNumElements() != 1)
|
||||
return failure();
|
||||
auto memrefLoad = rewriter.create<memref::LoadOp>(
|
||||
loadOp.getLoc(), loadOp.base(), loadOp.indices());
|
||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
|
||||
loadOp, VectorType::get({1}, vecType.getElementType()), memrefLoad);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Progressive lowering of transfer_write. This pattern supports lowering of
|
||||
/// `vector.transfer_write` to `vector.store` if all of the following hold:
|
||||
/// - The op writes to a memref with the default layout.
|
||||
|
@ -3674,8 +3681,9 @@ void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
|
|||
|
||||
/// Populates `patterns` with lowerings for vector transfer ops.
///
/// Adds patterns that lower vector.transfer_read/transfer_write to
/// vector.load/vector.store, rewrite single-element vector.loads to scalar
/// memref.loads, and handle non-identity permutation maps.
void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns) {
  // Register each pattern exactly once; the previous code registered
  // TransferReadToVectorLoadLowering and TransferWriteToVectorStoreLowering
  // twice via two separate add<> calls.
  patterns
      .add<TransferReadToVectorLoadLowering, TransferWriteToVectorStoreLowering,
           VectorLoadToMemrefLoadLowering>(patterns.getContext());
  populateVectorTransferPermutationMapLoweringPatterns(patterns);
}
|
||||
|
||||
|
|
|
@ -613,6 +613,18 @@ func @broadcast_folding2() -> vector<4x16xi32> {
|
|||
|
||||
// -----
|
||||
|
||||
// Two stacked broadcasts must canonicalize to one broadcast straight from
// the scalar to the final vector type.
// CHECK-LABEL: @fold_consecutive_broadcasts(
//  CHECK-SAME: %[[IN:.*]]: i32
//       CHECK: %[[BC:.*]] = vector.broadcast %[[IN]] : i32 to vector<4x16xi32>
//       CHECK: return %[[BC]]
func @fold_consecutive_broadcasts(%scalar : i32) -> vector<4x16xi32> {
  %bc0 = vector.broadcast %scalar : i32 to vector<16xi32>
  %bc1 = vector.broadcast %bc0 : vector<16xi32> to vector<4x16xi32>
  return %bc1 : vector<4x16xi32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: shape_cast_constant
|
||||
// CHECK-DAG: %[[CST1:.*]] = constant dense<1> : vector<3x4x2xi32>
|
||||
// CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<20x2xf32>
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -canonicalize -split-input-file | FileCheck %s
|
||||
|
||||
// transfer_read/write are lowered to vector.load/store
|
||||
// CHECK-LABEL: func @transfer_to_load(
|
||||
|
@ -174,6 +174,21 @@ func @transfer_broadcasting(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32>
|
|||
|
||||
// -----
|
||||
|
||||
// A 1-element transfer_read lowers to a scalar memref.load plus a broadcast
// back to vector<1xf32>.
// CHECK-LABEL: func @transfer_scalar(
//  CHECK-SAME: %[[M:.*]]: memref<?x?xf32>,
//  CHECK-SAME: %[[I:.*]]: index) -> vector<1xf32> {
//  CHECK-NEXT: %[[S:.*]] = memref.load %[[M]][%[[I]], %[[I]]] : memref<?x?xf32>
//  CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<1xf32>
//  CHECK-NEXT: return %[[V]] : vector<1xf32>
//  CHECK-NEXT: }
func @transfer_scalar(%m : memref<?x?xf32>, %idx : index) -> vector<1xf32> {
  %pad = constant 0.0 : f32
  %read = vector.transfer_read %m[%idx, %idx], %pad {in_bounds = [true]} : memref<?x?xf32>, vector<1xf32>
  return %read : vector<1xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// An example with two broadcasted dimensions.
|
||||
// CHECK-LABEL: func @transfer_broadcasting_2D(
|
||||
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
|
||||
|
|
Loading…
Reference in New Issue