[mlir][std] Fold comparisons when the operands are equal

For equal operands, comparisons can be decided statically.

Differential Revision: https://reviews.llvm.org/D91856
This commit is contained in:
Stephan Herhut 2020-11-20 11:46:22 +01:00
parent 0caa82e2ac
commit cb778c3423
2 changed files with 47 additions and 1 deletions

View File

@ -916,17 +916,41 @@ bool mlir::applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
llvm_unreachable("unknown comparison predicate"); llvm_unreachable("unknown comparison predicate");
} }
// Returns true if the predicate is true for two equal operands.
static bool applyCmpPredicateToEqualOperands(CmpIPredicate predicate) {
switch (predicate) {
case CmpIPredicate::eq:
case CmpIPredicate::sle:
case CmpIPredicate::sge:
case CmpIPredicate::ule:
case CmpIPredicate::uge:
return true;
case CmpIPredicate::ne:
case CmpIPredicate::slt:
case CmpIPredicate::sgt:
case CmpIPredicate::ult:
case CmpIPredicate::ugt:
return false;
}
llvm_unreachable("unknown comparison predicate");
}
// Constant folding hook for comparisons. // Constant folding hook for comparisons.
OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) { OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "cmpi takes two arguments"); assert(operands.size() == 2 && "cmpi takes two arguments");
if (lhs() == rhs()) {
auto val = applyCmpPredicateToEqualOperands(getPredicate());
return BoolAttr::get(val, getContext());
}
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
if (!lhs || !rhs) if (!lhs || !rhs)
return {}; return {};
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); return BoolAttr::get(val, getContext());
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -59,3 +59,25 @@ func @dim_of_dynamic_tensor_from_elements(%arg0: index, %arg1: index) -> index {
%1 = dim %0, %c3 : tensor<2x?x4x?x5xindex> %1 = dim %0, %c3 : tensor<2x?x4x?x5xindex>
return %1 : index return %1 : index
} }
// Test case: Folding of comparisons with equal operands.
// CHECK-LABEL: @cmpi_equal_operands
// CHECK-DAG: %[[T:.*]] = constant true
// CHECK-DAG: %[[F:.*]] = constant false
// CHECK: return %[[T]], %[[T]], %[[T]], %[[T]], %[[T]],
// CHECK-SAME: %[[F]], %[[F]], %[[F]], %[[F]], %[[F]]
func @cmpi_equal_operands(%arg0: i64)
-> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
%0 = cmpi "eq", %arg0, %arg0 : i64
%1 = cmpi "sle", %arg0, %arg0 : i64
%2 = cmpi "sge", %arg0, %arg0 : i64
%3 = cmpi "ule", %arg0, %arg0 : i64
%4 = cmpi "uge", %arg0, %arg0 : i64
%5 = cmpi "ne", %arg0, %arg0 : i64
%6 = cmpi "slt", %arg0, %arg0 : i64
%7 = cmpi "sgt", %arg0, %arg0 : i64
%8 = cmpi "ult", %arg0, %arg0 : i64
%9 = cmpi "ugt", %arg0, %arg0 : i64
return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
: i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
}