[MLIR] Replace a not of a comparison with appropriate comparison

Differential Revision: https://reviews.llvm.org/D101710
This commit is contained in:
William S. Moses 2021-05-02 00:16:41 -04:00
parent 54bff1522f
commit 93297e4bac
3 changed files with 185 additions and 0 deletions

View File

@ -2292,6 +2292,7 @@ def XOrOp : IntBinaryOp<"xor", [Commutative]> {
```
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//

View File

@ -3011,6 +3011,80 @@ OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
[](APInt a, APInt b) { return a ^ b; });
}
namespace {
/// Replace a not of a comparison operation, for example: not(cmp eq A, B) =>
/// cmp ne A, B. Note that a logical not is implemented as xor 1, val
struct NotICmp : public OpRewritePattern<XOrOp> {
using OpRewritePattern<XOrOp>::OpRewritePattern;
LogicalResult matchAndRewrite(XOrOp op,
PatternRewriter &rewriter) const override {
APInt constValue;
if (!matchPattern(op.getOperand(1), m_ConstantInt(&constValue)))
return failure();
if (constValue != 1)
return failure();
auto prev = op.getOperand(0).getDefiningOp<CmpIOp>();
if (!prev)
return failure();
switch (prev.predicate()) {
case CmpIPredicate::eq:
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ne, prev.lhs(),
prev.rhs());
return success();
case CmpIPredicate::ne:
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::eq, prev.lhs(),
prev.rhs());
return success();
case CmpIPredicate::slt:
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::sge, prev.lhs(),
prev.rhs());
return success();
case CmpIPredicate::sle:
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::sgt, prev.lhs(),
prev.rhs());
return success();
case CmpIPredicate::sgt:
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::sle, prev.lhs(),
prev.rhs());
return success();
case CmpIPredicate::sge:
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::slt, prev.lhs(),
prev.rhs());
return success();
case CmpIPredicate::ult:
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::uge, prev.lhs(),
prev.rhs());
return success();
case CmpIPredicate::ule:
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ugt, prev.lhs(),
prev.rhs());
return success();
case CmpIPredicate::ugt:
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ule, prev.lhs(),
prev.rhs());
return success();
case CmpIPredicate::uge:
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ult, prev.lhs(),
prev.rhs());
return success();
}
return failure();
}
};
} // namespace
void XOrOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<NotICmp>(context);
}
//===----------------------------------------------------------------------===//
// ZeroExtendIOp
//===----------------------------------------------------------------------===//

View File

@ -538,3 +538,113 @@ func @tripleSubSub3(%arg0: index) -> index {
%add2 = subi %add1, %c42 : index
return %add2 : index
}
// CHECK-LABEL: @notCmpEQ
// CHECK: %[[cres:.+]] = cmpi ne, %arg0, %arg1 : i8
// CHECK: return %[[cres]]
func @notCmpEQ(%arg0: i8, %arg1: i8) -> i1 {
%true = constant true
%cmp = cmpi "eq", %arg0, %arg1 : i8
%ncmp = xor %cmp, %true : i1
return %ncmp : i1
}
// CHECK-LABEL: @notCmpEQ2
// CHECK: %[[cres:.+]] = cmpi ne, %arg0, %arg1 : i8
// CHECK: return %[[cres]]
func @notCmpEQ2(%arg0: i8, %arg1: i8) -> i1 {
%true = constant true
%cmp = cmpi "eq", %arg0, %arg1 : i8
%ncmp = xor %true, %cmp : i1
return %ncmp : i1
}
// CHECK-LABEL: @notCmpNE
// CHECK: %[[cres:.+]] = cmpi eq, %arg0, %arg1 : i8
// CHECK: return %[[cres]]
func @notCmpNE(%arg0: i8, %arg1: i8) -> i1 {
%true = constant true
%cmp = cmpi "ne", %arg0, %arg1 : i8
%ncmp = xor %cmp, %true : i1
return %ncmp : i1
}
// CHECK-LABEL: @notCmpSLT
// CHECK: %[[cres:.+]] = cmpi sge, %arg0, %arg1 : i8
// CHECK: return %[[cres]]
func @notCmpSLT(%arg0: i8, %arg1: i8) -> i1 {
%true = constant true
%cmp = cmpi "slt", %arg0, %arg1 : i8
%ncmp = xor %cmp, %true : i1
return %ncmp : i1
}
// CHECK-LABEL: @notCmpSLE
// CHECK: %[[cres:.+]] = cmpi sgt, %arg0, %arg1 : i8
// CHECK: return %[[cres]]
func @notCmpSLE(%arg0: i8, %arg1: i8) -> i1 {
%true = constant true
%cmp = cmpi "sle", %arg0, %arg1 : i8
%ncmp = xor %cmp, %true : i1
return %ncmp : i1
}
// CHECK-LABEL: @notCmpSGT
// CHECK: %[[cres:.+]] = cmpi sle, %arg0, %arg1 : i8
// CHECK: return %[[cres]]
func @notCmpSGT(%arg0: i8, %arg1: i8) -> i1 {
%true = constant true
%cmp = cmpi "sgt", %arg0, %arg1 : i8
%ncmp = xor %cmp, %true : i1
return %ncmp : i1
}
// CHECK-LABEL: @notCmpSGE
// CHECK: %[[cres:.+]] = cmpi slt, %arg0, %arg1 : i8
// CHECK: return %[[cres]]
func @notCmpSGE(%arg0: i8, %arg1: i8) -> i1 {
%true = constant true
%cmp = cmpi "sge", %arg0, %arg1 : i8
%ncmp = xor %cmp, %true : i1
return %ncmp : i1
}
// CHECK-LABEL: @notCmpULT
// CHECK: %[[cres:.+]] = cmpi uge, %arg0, %arg1 : i8
// CHECK: return %[[cres]]
func @notCmpULT(%arg0: i8, %arg1: i8) -> i1 {
%true = constant true
%cmp = cmpi "ult", %arg0, %arg1 : i8
%ncmp = xor %cmp, %true : i1
return %ncmp : i1
}
// CHECK-LABEL: @notCmpULE
// CHECK: %[[cres:.+]] = cmpi ugt, %arg0, %arg1 : i8
// CHECK: return %[[cres]]
func @notCmpULE(%arg0: i8, %arg1: i8) -> i1 {
%true = constant true
%cmp = cmpi "ule", %arg0, %arg1 : i8
%ncmp = xor %cmp, %true : i1
return %ncmp : i1
}
// CHECK-LABEL: @notCmpUGT
// CHECK: %[[cres:.+]] = cmpi ule, %arg0, %arg1 : i8
// CHECK: return %[[cres]]
func @notCmpUGT(%arg0: i8, %arg1: i8) -> i1 {
%true = constant true
%cmp = cmpi "ugt", %arg0, %arg1 : i8
%ncmp = xor %cmp, %true : i1
return %ncmp : i1
}
// CHECK-LABEL: @notCmpUGE
// CHECK: %[[cres:.+]] = cmpi ult, %arg0, %arg1 : i8
// CHECK: return %[[cres]]
func @notCmpUGE(%arg0: i8, %arg1: i8) -> i1 {
%true = constant true
%cmp = cmpi "uge", %arg0, %arg1 : i8
%ncmp = xor %cmp, %true : i1
return %ncmp : i1
}