[MLIR][Arith] Canonicalize and/or with ext

Replace and(ext(a),ext(b)) with ext(and(a,b)). This both reduces one instruction, and results in the computation (and/or) being done on a smaller type.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D116519
This commit is contained in:
William S. Moses 2022-01-03 00:38:41 -05:00
parent 78389de4d3
commit 834cf3be22
4 changed files with 94 additions and 0 deletions

View File

@ -437,6 +437,7 @@ def Arith_AndIOp : Arith_IntBinaryOp<"andi", [Commutative, Idempotent]> {
```
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
@ -465,6 +466,7 @@ def Arith_OrIOp : Arith_IntBinaryOp<"ori", [Commutative, Idempotent]> {
```
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//

View File

@ -136,4 +136,32 @@ def BitcastOfBitcast :
def ExtSIOfExtUI :
Pat<(Arith_ExtSIOp (Arith_ExtUIOp $x)), (Arith_ExtUIOp $x)>;
//===----------------------------------------------------------------------===//
// AndIOp
//===----------------------------------------------------------------------===//
// and extui(x), extui(y) -> extui(and(x,y))
def AndOfExtUI :
Pat<(Arith_AndIOp (Arith_ExtUIOp $x), (Arith_ExtUIOp $y)), (Arith_ExtUIOp (Arith_AndIOp $x, $y)),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
// and extsi(x), extsi(y) -> extsi(and(x,y))
def AndOfExtSI :
Pat<(Arith_AndIOp (Arith_ExtSIOp $x), (Arith_ExtSIOp $y)), (Arith_ExtSIOp (Arith_AndIOp $x, $y)),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
//===----------------------------------------------------------------------===//
// OrIOp
//===----------------------------------------------------------------------===//
// or extui(x), extui(y) -> extui(or(x,y))
def OrOfExtUI :
Pat<(Arith_OrIOp (Arith_ExtUIOp $x), (Arith_ExtUIOp $y)), (Arith_ExtUIOp (Arith_OrIOp $x, $y)),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
// or extsi(x), extsi(y) -> extsi(or(x,y))
def OrOfExtSI :
Pat<(Arith_OrIOp (Arith_ExtSIOp $x), (Arith_ExtSIOp $y)), (Arith_ExtSIOp (Arith_OrIOp $x, $y)),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
#endif // ARITHMETIC_PATTERNS

View File

@ -901,6 +901,24 @@ bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
}
//===----------------------------------------------------------------------===//
// AndIOp
//===----------------------------------------------------------------------===//
void arith::AndIOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<AndOfExtUI, AndOfExtSI>(context);
}
//===----------------------------------------------------------------------===//
// OrIOp
//===----------------------------------------------------------------------===//
void arith::OrIOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<OrOfExtUI, OrOfExtSI>(context);
}
//===----------------------------------------------------------------------===//
// Verifiers for casts between integers and floats.
//===----------------------------------------------------------------------===//

View File

@ -99,6 +99,52 @@ func @extSIOfExtSI(%arg0: i1) -> i64 {
// -----
// CHECK-LABEL: @andOfExtSI
// CHECK: %[[comb:.+]] = arith.andi %arg0, %arg1 : i8
// CHECK: %[[ext:.+]] = arith.extsi %[[comb]] : i8 to i64
// CHECK: return %[[ext]]
func @andOfExtSI(%arg0: i8, %arg1: i8) -> i64 {
%ext0 = arith.extsi %arg0 : i8 to i64
%ext1 = arith.extsi %arg1 : i8 to i64
%res = arith.andi %ext0, %ext1 : i64
return %res : i64
}
// CHECK-LABEL: @andOfExtUI
// CHECK: %[[comb:.+]] = arith.andi %arg0, %arg1 : i8
// CHECK: %[[ext:.+]] = arith.extui %[[comb]] : i8 to i64
// CHECK: return %[[ext]]
func @andOfExtUI(%arg0: i8, %arg1: i8) -> i64 {
%ext0 = arith.extui %arg0 : i8 to i64
%ext1 = arith.extui %arg1 : i8 to i64
%res = arith.andi %ext0, %ext1 : i64
return %res : i64
}
// CHECK-LABEL: @orOfExtSI
// CHECK: %[[comb:.+]] = arith.ori %arg0, %arg1 : i8
// CHECK: %[[ext:.+]] = arith.extsi %[[comb]] : i8 to i64
// CHECK: return %[[ext]]
func @orOfExtSI(%arg0: i8, %arg1: i8) -> i64 {
%ext0 = arith.extsi %arg0 : i8 to i64
%ext1 = arith.extsi %arg1 : i8 to i64
%res = arith.ori %ext0, %ext1 : i64
return %res : i64
}
// CHECK-LABEL: @orOfExtUI
// CHECK: %[[comb:.+]] = arith.ori %arg0, %arg1 : i8
// CHECK: %[[ext:.+]] = arith.extui %[[comb]] : i8 to i64
// CHECK: return %[[ext]]
func @orOfExtUI(%arg0: i8, %arg1: i8) -> i64 {
%ext0 = arith.extui %arg0 : i8 to i64
%ext1 = arith.extui %arg1 : i8 to i64
%res = arith.ori %ext0, %ext1 : i64
return %res : i64
}
// -----
// CHECK-LABEL: @indexCastOfSignExtend
// CHECK: %[[res:.+]] = arith.index_cast %arg0 : i8 to index
// CHECK: return %[[res]]