[mlir] simplify affine maps and operands in affine.min/max

Affine dialect already has a map+operand simplification infrastructure in
place. Plug the recently added affine.min/max operations into this
infrastructure and add a simple test. More complex behavior of the simplifier
is already tested by other ops.

Addresses https://bugs.llvm.org/show_bug.cgi?id=45008.

Differential Revision: https://reviews.llvm.org/D75058
This commit is contained in:
Alex Zinenko 2020-02-24 17:43:01 +01:00
parent 3a1b34ff69
commit 5f9b543e8e
3 changed files with 48 additions and 6 deletions

View File

@ -238,13 +238,25 @@ class AffineMinMaxOpBase<string mnemonic, list<OpTrait> traits = []> :
Op<Affine_Dialect, mnemonic, traits> {
let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$operands);
let results = (outs Index);
let builders = [
OpBuilder<"Builder *builder, OperationState &result, AffineMap affineMap, "
"ValueRange mapOperands",
[{
build(builder, result, builder->getIndexType(), affineMap, mapOperands);
}]>
];
let extraClassDeclaration = [{
static StringRef getMapAttrName() { return "map"; }
AffineMap getAffineMap() { return map(); }
ValueRange getMapOperands() { return operands(); }
}];
let verifier = [{ return ::verifyAffineMinMaxOp(*this); }];
let printer = [{ return ::printAffineMinMaxOp(p, *this); }];
let parser = [{ return ::parseAffineMinMaxOp<$cppClass>(parser, result); }];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
def AffineMinOp : AffineMinMaxOpBase<"min", [NoSideEffect]> {

View File

@ -763,7 +763,9 @@ struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
static_assert(std::is_same<AffineOpTy, AffineLoadOp>::value ||
std::is_same<AffineOpTy, AffinePrefetchOp>::value ||
std::is_same<AffineOpTy, AffineStoreOp>::value ||
std::is_same<AffineOpTy, AffineApplyOp>::value,
std::is_same<AffineOpTy, AffineApplyOp>::value ||
std::is_same<AffineOpTy, AffineMinOp>::value ||
std::is_same<AffineOpTy, AffineMaxOp>::value,
"affine load/store/apply op expected");
auto map = affineOp.getAffineMap();
AffineMap oldMap = map;
@ -804,11 +806,13 @@ void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
rewriter.replaceOpWithNewOp<AffineStoreOp>(
store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
}
template <>
void SimplifyAffineOp<AffineApplyOp>::replaceAffineOp(
PatternRewriter &rewriter, AffineApplyOp apply, AffineMap map,
// Generic version for ops that don't have extra operands.
template <typename AffineOpTy>
void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
ArrayRef<Value> mapOperands) const {
rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, mapOperands);
rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
}
} // end anonymous namespace.
@ -2016,6 +2020,11 @@ OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
return results[minIndex];
}
void AffineMinOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<SimplifyAffineOp<AffineMinOp>>(context);
}
//===----------------------------------------------------------------------===//
// AffineMaxOp
//===----------------------------------------------------------------------===//
@ -2046,6 +2055,11 @@ OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
return results[maxIndex];
}
void AffineMaxOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<SimplifyAffineOp<AffineMaxOp>>(context);
}
//===----------------------------------------------------------------------===//
// AffinePrefetchOp
//===----------------------------------------------------------------------===//

View File

@ -552,3 +552,19 @@ func @affine_max(%arg0 : index, %arg1 : index, %arg2 : index) {
// CHECK-NEXT: return
return
}
// -----
// CHECK: #[[map:.*]] = affine_map<(d0, d1) -> (d0, d1 - 2)>
func @affine_min(%arg0: index) {
affine.for %i = 0 to %arg0 {
affine.for %j = 0 to %arg0 {
%c2 = constant 2 : index
// CHECK: affine.min #[[map]]
%0 = affine.min affine_map<(d0,d1,d2)->(d0, d1 - d2)>(%i, %j, %c2)
"consumer"(%0) : (index) -> ()
}
}
return
}