forked from OSchip/llvm-project
[MLIR] Replace a not of a comparison with appropriate comparison
Differential Revision: https://reviews.llvm.org/D101710
This commit is contained in:
parent
54bff1522f
commit
93297e4bac
|
@ -2292,6 +2292,7 @@ def XOrOp : IntBinaryOp<"xor", [Commutative]> {
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -3011,6 +3011,80 @@ OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
|
||||||
[](APInt a, APInt b) { return a ^ b; });
|
[](APInt a, APInt b) { return a ^ b; });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/// Replace a not of a comparison operation, for example: not(cmp eq A, B) =>
|
||||||
|
/// cmp ne A, B. Note that a logical not is implemented as xor 1, val
|
||||||
|
struct NotICmp : public OpRewritePattern<XOrOp> {
|
||||||
|
using OpRewritePattern<XOrOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(XOrOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
|
APInt constValue;
|
||||||
|
if (!matchPattern(op.getOperand(1), m_ConstantInt(&constValue)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (constValue != 1)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto prev = op.getOperand(0).getDefiningOp<CmpIOp>();
|
||||||
|
if (!prev)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
switch (prev.predicate()) {
|
||||||
|
case CmpIPredicate::eq:
|
||||||
|
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ne, prev.lhs(),
|
||||||
|
prev.rhs());
|
||||||
|
return success();
|
||||||
|
case CmpIPredicate::ne:
|
||||||
|
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::eq, prev.lhs(),
|
||||||
|
prev.rhs());
|
||||||
|
return success();
|
||||||
|
|
||||||
|
case CmpIPredicate::slt:
|
||||||
|
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::sge, prev.lhs(),
|
||||||
|
prev.rhs());
|
||||||
|
return success();
|
||||||
|
case CmpIPredicate::sle:
|
||||||
|
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::sgt, prev.lhs(),
|
||||||
|
prev.rhs());
|
||||||
|
return success();
|
||||||
|
case CmpIPredicate::sgt:
|
||||||
|
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::sle, prev.lhs(),
|
||||||
|
prev.rhs());
|
||||||
|
return success();
|
||||||
|
case CmpIPredicate::sge:
|
||||||
|
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::slt, prev.lhs(),
|
||||||
|
prev.rhs());
|
||||||
|
return success();
|
||||||
|
|
||||||
|
case CmpIPredicate::ult:
|
||||||
|
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::uge, prev.lhs(),
|
||||||
|
prev.rhs());
|
||||||
|
return success();
|
||||||
|
case CmpIPredicate::ule:
|
||||||
|
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ugt, prev.lhs(),
|
||||||
|
prev.rhs());
|
||||||
|
return success();
|
||||||
|
case CmpIPredicate::ugt:
|
||||||
|
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ule, prev.lhs(),
|
||||||
|
prev.rhs());
|
||||||
|
return success();
|
||||||
|
case CmpIPredicate::uge:
|
||||||
|
rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ult, prev.lhs(),
|
||||||
|
prev.rhs());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void XOrOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
|
MLIRContext *context) {
|
||||||
|
results.insert<NotICmp>(context);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ZeroExtendIOp
|
// ZeroExtendIOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -538,3 +538,113 @@ func @tripleSubSub3(%arg0: index) -> index {
|
||||||
%add2 = subi %add1, %c42 : index
|
%add2 = subi %add1, %c42 : index
|
||||||
return %add2 : index
|
return %add2 : index
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @notCmpEQ
|
||||||
|
// CHECK: %[[cres:.+]] = cmpi ne, %arg0, %arg1 : i8
|
||||||
|
// CHECK: return %[[cres]]
|
||||||
|
func @notCmpEQ(%arg0: i8, %arg1: i8) -> i1 {
|
||||||
|
%true = constant true
|
||||||
|
%cmp = cmpi "eq", %arg0, %arg1 : i8
|
||||||
|
%ncmp = xor %cmp, %true : i1
|
||||||
|
return %ncmp : i1
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @notCmpEQ2
|
||||||
|
// CHECK: %[[cres:.+]] = cmpi ne, %arg0, %arg1 : i8
|
||||||
|
// CHECK: return %[[cres]]
|
||||||
|
func @notCmpEQ2(%arg0: i8, %arg1: i8) -> i1 {
|
||||||
|
%true = constant true
|
||||||
|
%cmp = cmpi "eq", %arg0, %arg1 : i8
|
||||||
|
%ncmp = xor %true, %cmp : i1
|
||||||
|
return %ncmp : i1
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @notCmpNE
|
||||||
|
// CHECK: %[[cres:.+]] = cmpi eq, %arg0, %arg1 : i8
|
||||||
|
// CHECK: return %[[cres]]
|
||||||
|
func @notCmpNE(%arg0: i8, %arg1: i8) -> i1 {
|
||||||
|
%true = constant true
|
||||||
|
%cmp = cmpi "ne", %arg0, %arg1 : i8
|
||||||
|
%ncmp = xor %cmp, %true : i1
|
||||||
|
return %ncmp : i1
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @notCmpSLT
|
||||||
|
// CHECK: %[[cres:.+]] = cmpi sge, %arg0, %arg1 : i8
|
||||||
|
// CHECK: return %[[cres]]
|
||||||
|
func @notCmpSLT(%arg0: i8, %arg1: i8) -> i1 {
|
||||||
|
%true = constant true
|
||||||
|
%cmp = cmpi "slt", %arg0, %arg1 : i8
|
||||||
|
%ncmp = xor %cmp, %true : i1
|
||||||
|
return %ncmp : i1
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @notCmpSLE
|
||||||
|
// CHECK: %[[cres:.+]] = cmpi sgt, %arg0, %arg1 : i8
|
||||||
|
// CHECK: return %[[cres]]
|
||||||
|
func @notCmpSLE(%arg0: i8, %arg1: i8) -> i1 {
|
||||||
|
%true = constant true
|
||||||
|
%cmp = cmpi "sle", %arg0, %arg1 : i8
|
||||||
|
%ncmp = xor %cmp, %true : i1
|
||||||
|
return %ncmp : i1
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @notCmpSGT
|
||||||
|
// CHECK: %[[cres:.+]] = cmpi sle, %arg0, %arg1 : i8
|
||||||
|
// CHECK: return %[[cres]]
|
||||||
|
func @notCmpSGT(%arg0: i8, %arg1: i8) -> i1 {
|
||||||
|
%true = constant true
|
||||||
|
%cmp = cmpi "sgt", %arg0, %arg1 : i8
|
||||||
|
%ncmp = xor %cmp, %true : i1
|
||||||
|
return %ncmp : i1
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @notCmpSGE
|
||||||
|
// CHECK: %[[cres:.+]] = cmpi slt, %arg0, %arg1 : i8
|
||||||
|
// CHECK: return %[[cres]]
|
||||||
|
func @notCmpSGE(%arg0: i8, %arg1: i8) -> i1 {
|
||||||
|
%true = constant true
|
||||||
|
%cmp = cmpi "sge", %arg0, %arg1 : i8
|
||||||
|
%ncmp = xor %cmp, %true : i1
|
||||||
|
return %ncmp : i1
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @notCmpULT
|
||||||
|
// CHECK: %[[cres:.+]] = cmpi uge, %arg0, %arg1 : i8
|
||||||
|
// CHECK: return %[[cres]]
|
||||||
|
func @notCmpULT(%arg0: i8, %arg1: i8) -> i1 {
|
||||||
|
%true = constant true
|
||||||
|
%cmp = cmpi "ult", %arg0, %arg1 : i8
|
||||||
|
%ncmp = xor %cmp, %true : i1
|
||||||
|
return %ncmp : i1
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @notCmpULE
|
||||||
|
// CHECK: %[[cres:.+]] = cmpi ugt, %arg0, %arg1 : i8
|
||||||
|
// CHECK: return %[[cres]]
|
||||||
|
func @notCmpULE(%arg0: i8, %arg1: i8) -> i1 {
|
||||||
|
%true = constant true
|
||||||
|
%cmp = cmpi "ule", %arg0, %arg1 : i8
|
||||||
|
%ncmp = xor %cmp, %true : i1
|
||||||
|
return %ncmp : i1
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @notCmpUGT
|
||||||
|
// CHECK: %[[cres:.+]] = cmpi ule, %arg0, %arg1 : i8
|
||||||
|
// CHECK: return %[[cres]]
|
||||||
|
func @notCmpUGT(%arg0: i8, %arg1: i8) -> i1 {
|
||||||
|
%true = constant true
|
||||||
|
%cmp = cmpi "ugt", %arg0, %arg1 : i8
|
||||||
|
%ncmp = xor %cmp, %true : i1
|
||||||
|
return %ncmp : i1
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @notCmpUGE
|
||||||
|
// CHECK: %[[cres:.+]] = cmpi ult, %arg0, %arg1 : i8
|
||||||
|
// CHECK: return %[[cres]]
|
||||||
|
func @notCmpUGE(%arg0: i8, %arg1: i8) -> i1 {
|
||||||
|
%true = constant true
|
||||||
|
%cmp = cmpi "uge", %arg0, %arg1 : i8
|
||||||
|
%ncmp = xor %cmp, %true : i1
|
||||||
|
return %ncmp : i1
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue