forked from OSchip/llvm-project
[mlir][std] Fold comparisons when the operands are equal
For equal operands, comparisons can be decided statically. Differential Revision: https://reviews.llvm.org/D91856
This commit is contained in:
parent
0caa82e2ac
commit
cb778c3423
|
@ -916,17 +916,41 @@ bool mlir::applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
|
||||||
llvm_unreachable("unknown comparison predicate");
|
llvm_unreachable("unknown comparison predicate");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns true if the predicate is true for two equal operands.
|
||||||
|
static bool applyCmpPredicateToEqualOperands(CmpIPredicate predicate) {
|
||||||
|
switch (predicate) {
|
||||||
|
case CmpIPredicate::eq:
|
||||||
|
case CmpIPredicate::sle:
|
||||||
|
case CmpIPredicate::sge:
|
||||||
|
case CmpIPredicate::ule:
|
||||||
|
case CmpIPredicate::uge:
|
||||||
|
return true;
|
||||||
|
case CmpIPredicate::ne:
|
||||||
|
case CmpIPredicate::slt:
|
||||||
|
case CmpIPredicate::sgt:
|
||||||
|
case CmpIPredicate::ult:
|
||||||
|
case CmpIPredicate::ugt:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
llvm_unreachable("unknown comparison predicate");
|
||||||
|
}
|
||||||
|
|
||||||
// Constant folding hook for comparisons.
|
// Constant folding hook for comparisons.
|
||||||
OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
|
||||||
assert(operands.size() == 2 && "cmpi takes two arguments");
|
assert(operands.size() == 2 && "cmpi takes two arguments");
|
||||||
|
|
||||||
|
if (lhs() == rhs()) {
|
||||||
|
auto val = applyCmpPredicateToEqualOperands(getPredicate());
|
||||||
|
return BoolAttr::get(val, getContext());
|
||||||
|
}
|
||||||
|
|
||||||
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
|
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
|
||||||
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
|
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
|
||||||
if (!lhs || !rhs)
|
if (!lhs || !rhs)
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
|
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
|
||||||
return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
|
return BoolAttr::get(val, getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -59,3 +59,25 @@ func @dim_of_dynamic_tensor_from_elements(%arg0: index, %arg1: index) -> index {
|
||||||
%1 = dim %0, %c3 : tensor<2x?x4x?x5xindex>
|
%1 = dim %0, %c3 : tensor<2x?x4x?x5xindex>
|
||||||
return %1 : index
|
return %1 : index
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test case: Folding of comparisons with equal operands.
|
||||||
|
// CHECK-LABEL: @cmpi_equal_operands
|
||||||
|
// CHECK-DAG: %[[T:.*]] = constant true
|
||||||
|
// CHECK-DAG: %[[F:.*]] = constant false
|
||||||
|
// CHECK: return %[[T]], %[[T]], %[[T]], %[[T]], %[[T]],
|
||||||
|
// CHECK-SAME: %[[F]], %[[F]], %[[F]], %[[F]], %[[F]]
|
||||||
|
func @cmpi_equal_operands(%arg0: i64)
|
||||||
|
-> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
|
||||||
|
%0 = cmpi "eq", %arg0, %arg0 : i64
|
||||||
|
%1 = cmpi "sle", %arg0, %arg0 : i64
|
||||||
|
%2 = cmpi "sge", %arg0, %arg0 : i64
|
||||||
|
%3 = cmpi "ule", %arg0, %arg0 : i64
|
||||||
|
%4 = cmpi "uge", %arg0, %arg0 : i64
|
||||||
|
%5 = cmpi "ne", %arg0, %arg0 : i64
|
||||||
|
%6 = cmpi "slt", %arg0, %arg0 : i64
|
||||||
|
%7 = cmpi "sgt", %arg0, %arg0 : i64
|
||||||
|
%8 = cmpi "ult", %arg0, %arg0 : i64
|
||||||
|
%9 = cmpi "ugt", %arg0, %arg0 : i64
|
||||||
|
return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
|
||||||
|
: i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue