[mlir][affine] Add single result affine.min/max -> affine.apply canonicalization.

Differential Revision: https://reviews.llvm.org/D106014
This commit is contained in:
Nicolas Vasilache 2021-07-14 20:33:29 +00:00
parent 7e496c29e2
commit df538fdaa9
2 changed files with 36 additions and 3 deletions

View File

@ -2538,6 +2538,20 @@ struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
}
};
template <typename T>
struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> {
using OpRewritePattern<T>::OpRewritePattern;
LogicalResult matchAndRewrite(T affineOp,
PatternRewriter &rewriter) const override {
if (affineOp.map().getNumResults() != 1)
return failure();
rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.map(),
affineOp.getOperands());
return success();
}
};
//===----------------------------------------------------------------------===//
// AffineMinOp
//===----------------------------------------------------------------------===//
@ -2551,7 +2565,8 @@ OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<DeduplicateAffineMinMaxExpressions<AffineMinOp>,
patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMinOp>,
DeduplicateAffineMinMaxExpressions<AffineMinOp>,
MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>>(
context);
}
@ -2569,7 +2584,8 @@ OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMaxOp>,
DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>>(
context);
}

View File

@ -870,7 +870,6 @@ func @dont_merge_affine_max_if_not_single_dim(%i0: index, %i1: index, %i2: index
return %1: index
}
// -----
// CHECK-LABEL: func @dont_merge_affine_max_if_not_single_sym
@ -936,3 +935,21 @@ func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
return
}
// -----
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 + 16)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 * 4)>
// CHECK: func @canonicalize_single_min_max
// CHECK-SAME: (%[[I0:.+]]: index, %[[I1:.+]]: index)
func @canonicalize_single_min_max(%i0: index, %i1: index) -> (index, index) {
// CHECK-NOT: affine.min
// CHECK-NEXT: affine.apply #[[$MAP0]]()[%[[I0]]]
%0 = affine.min affine_map<()[s0] -> (s0 + 16)> ()[%i0]
// CHECK-NOT: affine.max
// CHECK-NEXT: affine.apply #[[$MAP1]]()[%[[I1]]]
%1 = affine.min affine_map<()[s0] -> (s0 * 4)> ()[%i1]
return %0, %1: index, index
}