forked from OSchip/llvm-project
Support affine.load/store ops in fold-memref-subview-ops pass
Support affine.load/store ops in the fold-memref-subview-ops pass. The existing pass "inlines" the subview operation on loads/stores by inserting affine.apply ops in front of the memref load/store ops; this is by design always consistent with the semantics of affine.load/store ops, and the same rewrite works even more naturally/intuitively with the latter.
Differential Revision: https://reviews.llvm.org/D118565
This commit is contained in:
parent 73cfa982ba
commit f8a2cd67b9
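As context for the diff below: the existing fold rewrites a load/store through a memref.subview into a load/store on the original memref, with affine.apply ops computing the adjusted indices. A minimal sketch of the idea (hypothetical IR; value names, shapes, and the exact form of the affine maps are illustrative, not taken from the patch):

// Before the pass: a load through a strided subview (illustrative IR).
%sv = memref.subview %src[%off0, %off1] [4, 4] [2, 3]
    : memref<12x32xf32> to memref<4x4xf32, offset: ?, strides: [64, 3]>
%v = memref.load %sv[%i, %j] : memref<4x4xf32, offset: ?, strides: [64, 3]>

// After the pass: the subview is folded away; its offsets/strides are applied
// to the indices with affine.apply, and the load reads the source memref.
%i0 = affine.apply affine_map<(d0)[s0] -> (d0 * 2 + s0)>(%i)[%off0]
%j0 = affine.apply affine_map<(d0)[s0] -> (d0 * 3 + s0)>(%j)[%off1]
%v = memref.load %src[%i0, %j0] : memref<12x32xf32>

This commit makes the same patterns apply when the access is an affine.load or affine.store.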
@@ -90,12 +90,13 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter,
 }
 
 /// Helpers to access the memref operand for each op.
-static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); }
+template <typename LoadOrStoreOpTy>
+static Value getMemRefOperand(LoadOrStoreOpTy op) {
+  return op.memref();
+}
 
 static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }
 
-static Value getMemRefOperand(memref::StoreOp op) { return op.memref(); }
-
 static Value getMemRefOperand(vector::TransferWriteOp op) {
   return op.source();
 }
@@ -154,12 +155,12 @@ private:
                  PatternRewriter &rewriter) const;
 };
 
-template <>
-void LoadOpOfSubViewFolder<memref::LoadOp>::replaceOp(
-    memref::LoadOp loadOp, memref::SubViewOp subViewOp,
-    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
-  rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subViewOp.source(),
+template <typename LoadOpTy>
+void LoadOpOfSubViewFolder<LoadOpTy>::replaceOp(
+    LoadOpTy loadOp, memref::SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
+    PatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<LoadOpTy>(loadOp, subViewOp.source(),
                                         sourceIndices);
 }
 
 template <>
@@ -178,12 +179,12 @@ void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
       /*mask=*/Value(), transferReadOp.in_boundsAttr());
 }
 
-template <>
-void StoreOpOfSubViewFolder<memref::StoreOp>::replaceOp(
-    memref::StoreOp storeOp, memref::SubViewOp subViewOp,
+template <typename StoreOpTy>
+void StoreOpOfSubViewFolder<StoreOpTy>::replaceOp(
+    StoreOpTy storeOp, memref::SubViewOp subViewOp,
     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
-  rewriter.replaceOpWithNewOp<memref::StoreOp>(
-      storeOp, storeOp.value(), subViewOp.source(), sourceIndices);
+  rewriter.replaceOpWithNewOp<StoreOpTy>(storeOp, storeOp.value(),
+                                         subViewOp.source(), sourceIndices);
 }
 
 template <>
@@ -239,8 +240,10 @@ StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
 }
 
 void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
-  patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>,
+  patterns.add<LoadOpOfSubViewFolder<AffineLoadOp>,
+               LoadOpOfSubViewFolder<memref::LoadOp>,
                LoadOpOfSubViewFolder<vector::TransferReadOp>,
+               StoreOpOfSubViewFolder<AffineStoreOp>,
                StoreOpOfSubViewFolder<memref::StoreOp>,
                StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
       patterns.getContext());
@@ -251,3 +251,24 @@ func @fold_vector_transfer_write_with_inner_rank_reduced_subview(
 // CHECK-DAG:   %[[IDX1:.+]] = affine.apply #[[MAP1]](%[[ARG7]])[%[[ARG3]]]
 // CHECK-DAG:   vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]], %[[C0]]]
 // CHECK-SAME:  {in_bounds = [true], permutation_map = #[[MAP2]]} : vector<4xf32>, memref<?x?x?xf32
+
+// -----
+
+// Test with affine.load/store ops. We only do a basic test here since the
+// logic is identical to that with memref.load/store ops. The same affine.apply
+// ops would be generated.
+
+// CHECK-LABEL: func @fold_static_stride_subview_with_affine_load_store
+func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
+  %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
+  %1 = affine.load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
+  // CHECK-NEXT: affine.apply
+  // CHECK-NEXT: affine.apply
+  // CHECK-NEXT: affine.load
+  affine.store %1, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
+  // CHECK-NEXT: affine.apply
+  // CHECK-NEXT: affine.apply
+  // CHECK-NEXT: affine.store
+  // CHECK-NEXT: return
+  return %1 : f32
+}
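For reference, a rough sketch of the folded IR this new test is expected to produce (assumed output; map numbering, SSA names, and formatting are assigned by the pass and may differ):

#map0 = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
#map1 = affine_map<(d0)[s0] -> (d0 * 3 + s0)>
func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
  // Two affine.apply ops fold the subview offsets/strides into the load indices.
  %0 = affine.apply #map0(%arg3)[%arg1]
  %1 = affine.apply #map1(%arg4)[%arg2]
  %2 = affine.load %arg0[%0, %1] : memref<12x32xf32>
  // The store gets its own index computation, matching the CHECK-NEXT lines above.
  %3 = affine.apply #map0(%arg3)[%arg1]
  %4 = affine.apply #map1(%arg4)[%arg2]
  affine.store %2, %arg0[%3, %4] : memref<12x32xf32>
  return %2 : f32
}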