[mlir][arith] cmpi: move constant to the right side

Convert arith.cmpi to the canonical form with constants on the right side
to simplify further optimizations and open more opportunities for CSE.


Differential Revision: https://reviews.llvm.org/D129929
This commit is contained in:
Ivan Butygin 2022-07-16 12:05:03 +02:00
parent 8de1f04c77
commit 917e4519bc
7 changed files with 75 additions and 13 deletions

View File

@ -93,7 +93,7 @@ func.func @_QPtest_proc_dummy_other(%arg0: !fir.boxproc<() -> ()>) {
// CHECK: %[[VAL_4:.*]] = load { ptr, i64 }, ptr %[[VAL_3]], align 8
// CHECK: %[[VAL_5:.*]] = extractvalue { ptr, i64 } %[[VAL_4]], 0
// CHECK: %[[VAL_6:.*]] = extractvalue { ptr, i64 } %[[VAL_4]], 1
// CHECK: %[[VAL_8:.*]] = icmp slt i64 10, %[[VAL_6]]
// CHECK: %[[VAL_8:.*]] = icmp sgt i64 %[[VAL_6]], 10
// CHECK: %[[VAL_9:.*]] = select i1 %[[VAL_8]], i64 10, i64 %[[VAL_6]]
// CHECK: call void @llvm.memmove.p0.p0.i64(ptr %[[VAL_0]], ptr %[[VAL_5]], i64 %[[VAL_9]], i1 false)
// CHECK: %[[VAL_10:.*]] = sub i64 10, %[[VAL_9]]
@ -129,7 +129,7 @@ func.func @_QPtest_proc_dummy_other(%arg0: !fir.boxproc<() -> ()>) {
// CHECK: %[[VAL_27:.*]] = load [1 x i8], ptr %[[VAL_26]], align 1
// CHECK: %[[VAL_29:.*]] = getelementptr [1 x i8], ptr %[[VAL_14]], i64 %[[VAL_18]]
// CHECK: store [1 x i8] %[[VAL_27]], ptr %[[VAL_29]], align 1
// CHECK: %[[VAL_30:.*]] = icmp slt i64 40, %[[VAL_13]]
// CHECK: %[[VAL_30:.*]] = icmp sgt i64 %[[VAL_13]], 40
// CHECK: %[[VAL_31:.*]] = select i1 %[[VAL_30]], i64 40, i64 %[[VAL_13]]
// CHECK: call void @llvm.memmove.p0.p0.i64(ptr %[[VAL_0]], ptr %[[VAL_14]], i64 %[[VAL_31]], i1 false)
// CHECK: %[[VAL_32:.*]] = sub i64 40, %[[VAL_31]]

View File

@ -22,7 +22,7 @@ subroutine issue(c1, c2)
! CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_7]] : index
! CHECK: %[[VAL_17:.*]] = fir.array_coor %[[VAL_11]](%[[VAL_12]]) %[[VAL_16]] typeparams %[[VAL_10]]#1 : (!fir.ref<!fir.array<3x!fir.char<1,?>>>, !fir.shape<1>, index, index) -> !fir.ref<!fir.char<1,?>>
! CHECK: %[[VAL_18:.*]] = fir.array_coor %[[VAL_9]](%[[VAL_12]]) %[[VAL_16]] : (!fir.ref<!fir.array<3x!fir.char<1,4>>>, !fir.shape<1>, index) -> !fir.ref<!fir.char<1,4>>
! CHECK: %[[VAL_19:.*]] = arith.cmpi slt, %[[VAL_5]], %[[VAL_10]]#1 : index
! CHECK: %[[VAL_19:.*]] = arith.cmpi sgt, %[[VAL_10]]#1, %[[VAL_5]] : index
! CHECK: %[[VAL_20:.*]] = arith.select %[[VAL_19]], %[[VAL_5]], %[[VAL_10]]#1 : index
! CHECK: %[[VAL_21:.*]] = fir.convert %[[VAL_20]] : (index) -> i64
! CHECK: %[[VAL_22:.*]] = fir.convert %[[VAL_18]] : (!fir.ref<!fir.char<1,4>>) -> !fir.ref<i8>

View File

@ -540,7 +540,7 @@ end subroutine test_proc_dummy_other
! CHECK: %[[VAL_10:.*]] = fir.load %[[VAL_9]] : !fir.ref<!fir.boxchar<1>>
! CHECK: %[[VAL_11:.*]]:2 = fir.unboxchar %[[VAL_10]] : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
! CHECK: %[[VAL_12:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<!fir.char<1,10>>) -> !fir.ref<!fir.char<1,?>>
! CHECK: %[[VAL_13:.*]] = arith.cmpi slt, %[[VAL_4]], %[[VAL_11]]#1 : index
! CHECK: %[[VAL_13:.*]] = arith.cmpi sgt, %[[VAL_11]]#1, %[[VAL_4]] : index
! CHECK: %[[VAL_14:.*]] = arith.select %[[VAL_13]], %[[VAL_4]], %[[VAL_11]]#1 : index
! CHECK: %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (index) -> i64
! CHECK: %[[VAL_16:.*]] = fir.convert %[[VAL_12]] : (!fir.ref<!fir.char<1,?>>) -> !fir.ref<i8>
@ -607,7 +607,7 @@ end subroutine test_proc_dummy_other
! CHECK: %[[VAL_34:.*]] = arith.subi %[[VAL_25]], %[[VAL_6]] : index
! CHECK: br ^bb1(%[[VAL_33]], %[[VAL_34]] : index, index)
! CHECK: ^bb3:
! CHECK: %[[VAL_35:.*]] = arith.cmpi slt, %[[VAL_3]], %[[VAL_19]] : index
! CHECK: %[[VAL_35:.*]] = arith.cmpi sgt, %[[VAL_19]], %[[VAL_3]] : index
! CHECK: %[[VAL_36:.*]] = arith.select %[[VAL_35]], %[[VAL_3]], %[[VAL_19]] : index
! CHECK: %[[VAL_37:.*]] = fir.convert %[[VAL_36]] : (index) -> i64
! CHECK: %[[VAL_38:.*]] = fir.convert %[[VAL_9]] : (!fir.ref<!fir.char<1,?>>) -> !fir.ref<i8>

View File

@ -1332,11 +1332,38 @@ OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
}
}
// Move constant to the right side.
if (operands[0] && !operands[1]) {
// Do not use invertPredicate, as it will change eq to ne and vice versa.
using Pred = CmpIPredicate;
const std::pair<Pred, Pred> invPreds[] = {
{Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
{Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
{Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
{Pred::ne, Pred::ne},
};
Pred origPred = getPredicate();
for (auto pred : invPreds) {
if (origPred == pred.first) {
setPredicateAttr(CmpIPredicateAttr::get(getContext(), pred.second));
Value lhs = getLhs();
Value rhs = getRhs();
getLhsMutable().assign(rhs);
getRhsMutable().assign(lhs);
return getResult();
}
}
llvm_unreachable("unknown cmpi predicate kind");
}
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
if (!lhs || !rhs)
if (!lhs)
return {};
// We are moving constants to the right side; So if lhs is constant rhs is
// guaranteed to be a constant.
auto rhs = operands.back().cast<IntegerAttr>();
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
return BoolAttr::get(getContext(), val);
}

View File

@ -127,6 +127,41 @@ func.func @cmpi_equal_vector_operands(%arg0: vector<1x8xi64>)
// -----
// Test case: Move constant to the right side.
// CHECK-LABEL: @cmpi_const_right(
// CHECK-SAME: %[[ARG:.*]]:
// CHECK: %[[C:.*]] = arith.constant 1 : i64
// CHECK: %[[R0:.*]] = arith.cmpi eq, %[[ARG]], %[[C]] : i64
// CHECK: %[[R1:.*]] = arith.cmpi sge, %[[ARG]], %[[C]] : i64
// CHECK: %[[R2:.*]] = arith.cmpi sle, %[[ARG]], %[[C]] : i64
// CHECK: %[[R3:.*]] = arith.cmpi uge, %[[ARG]], %[[C]] : i64
// CHECK: %[[R4:.*]] = arith.cmpi ule, %[[ARG]], %[[C]] : i64
// CHECK: %[[R5:.*]] = arith.cmpi ne, %[[ARG]], %[[C]] : i64
// CHECK: %[[R6:.*]] = arith.cmpi sgt, %[[ARG]], %[[C]] : i64
// CHECK: %[[R7:.*]] = arith.cmpi slt, %[[ARG]], %[[C]] : i64
// CHECK: %[[R8:.*]] = arith.cmpi ugt, %[[ARG]], %[[C]] : i64
// CHECK: %[[R9:.*]] = arith.cmpi ult, %[[ARG]], %[[C]] : i64
// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]],
// CHECK-SAME: %[[R5]], %[[R6]], %[[R7]], %[[R8]], %[[R9]]
func.func @cmpi_const_right(%arg0: i64)
-> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
%c1 = arith.constant 1 : i64
%0 = arith.cmpi eq, %c1, %arg0 : i64
%1 = arith.cmpi sle, %c1, %arg0 : i64
%2 = arith.cmpi sge, %c1, %arg0 : i64
%3 = arith.cmpi ule, %c1, %arg0 : i64
%4 = arith.cmpi uge, %c1, %arg0 : i64
%5 = arith.cmpi ne, %c1, %arg0 : i64
%6 = arith.cmpi slt, %c1, %arg0 : i64
%7 = arith.cmpi sgt, %c1, %arg0 : i64
%8 = arith.cmpi ult, %c1, %arg0 : i64
%9 = arith.cmpi ugt, %c1, %arg0 : i64
return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
: i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
}
// -----
// CHECK-LABEL: @cmpOfExtSI
// CHECK-NEXT: return %arg0
func.func @cmpOfExtSI(%arg0: i1) -> i1 {

View File

@ -819,10 +819,10 @@ func.func @genbool_var_1d(%arg0: index) -> vector<3xi1> {
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[c1:.*]] = arith.constant 1 : index
// CHECK: %[[T0:.*]] = vector.create_mask %[[B]] : vector<3xi1>
// CHECK: %[[T1:.*]] = arith.cmpi slt, %[[c0]], %[[A]] : index
// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index
// CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<3xi1>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1>
// CHECK: %[[T4:.*]] = arith.cmpi slt, %[[c1]], %[[A]] : index
// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index
// CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T0]], %[[C1]] : vector<3xi1>
// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : vector<3xi1> into vector<2x3xi1>
// CHECK: return %[[T6]] : vector<2x3xi1>
@ -842,13 +842,13 @@ func.func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> {
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK: %[[T0:.*]] = vector.create_mask %[[C]] : vector<7xi1>
// CHECK: %[[T1:.*]] = arith.cmpi slt, %[[c0]], %[[B]] : index
// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[B]], %[[c0]] : index
// CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<7xi1>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<7xi1> into vector<1x7xi1>
// CHECK: %[[T4:.*]] = arith.cmpi slt, %[[c0]], %[[A]] : index
// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index
// CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T3]], %[[C2]] : vector<1x7xi1>
// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[C3]] [0] : vector<1x7xi1> into vector<2x1x7xi1>
// CHECK: %[[T7:.*]] = arith.cmpi slt, %[[c1]], %[[A]] : index
// CHECK: %[[T7:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index
// CHECK: %[[T8:.*]] = arith.select %[[T7]], %[[T3]], %[[C2]] : vector<1x7xi1>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [1] : vector<1x7xi1> into vector<2x1x7xi1>
// CHECK: return %[[T9]] : vector<2x1x7xi1>

View File

@ -141,7 +141,7 @@ func.func @loop_region_branch_terminator_op(%arg1 : i32) {
%c2_i32 = arith.constant 2 : i32
%0 = scf.while (%arg2 = %c2_i32) : (i32) -> (i32) {
%1 = arith.cmpi slt, %arg2, %arg1 : i32
%1 = arith.cmpi sgt, %arg1, %arg2 : i32
scf.condition(%1) %arg2 : i32
} do {
^bb0(%arg2: i32):