[RISCV] Refactor the common combines for SELECT_CC and BR_CC into a helper function.

The only difference between the combines were the calls to getNode
that include the true/false values for SELECT_CC or the chain
and branch target for BR_CC.

Wrap the rest of the code into a helper that reads LHS, RHS, and
CC and outputs new values and a bool if a new node needs to be
created.
This commit is contained in:
Craig Topper 2022-07-20 21:07:07 -07:00
parent bba1f26f2e
commit 7dda6c71b1
1 changed files with 56 additions and 88 deletions

View File

@ -8712,6 +8712,52 @@ static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
DAG.getConstant(32 - ShAmt, DL, MVT::i64));
}
// Perform common combines for BR_CC and SELECT_CC condtions.
static bool combine_CC(SDValue &LHS, SDValue &RHS, SDValue &CC, const SDLoc &DL,
SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
ISD::CondCode CCVal = cast<CondCodeSDNode>(CC)->get();
if (!ISD::isIntEqualitySetCC(CCVal))
return false;
// Fold ((setlt X, Y), 0, ne) -> (X, Y, lt)
// Sometimes the setcc is introduced after br_cc/select_cc has been formed.
if (LHS.getOpcode() == ISD::SETCC && isNullConstant(RHS) &&
LHS.getOperand(0).getValueType() == Subtarget.getXLenVT()) {
// If we're looking for eq 0 instead of ne 0, we need to invert the
// condition.
bool Invert = CCVal == ISD::SETEQ;
CCVal = cast<CondCodeSDNode>(LHS.getOperand(2))->get();
if (Invert)
CCVal = ISD::getSetCCInverse(CCVal, LHS.getValueType());
RHS = LHS.getOperand(1);
LHS = LHS.getOperand(0);
translateSetCCForBranch(DL, LHS, RHS, CCVal, DAG);
CC = DAG.getCondCode(CCVal);
return true;
}
// Fold ((xor X, Y), 0, eq/ne) -> (X, Y, eq/ne)
if (LHS.getOpcode() == ISD::XOR && isNullConstant(RHS)) {
RHS = LHS.getOperand(1);
LHS = LHS.getOperand(0);
return true;
}
// (X, 1, setne) -> // (X, 0, seteq) if we can prove X is 0/1.
// This can occur when legalizing some floating point comparisons.
APInt Mask = APInt::getBitsSetFrom(LHS.getValueSizeInBits(), 1);
if (isOneConstant(RHS) && DAG.MaskedValueIsZero(LHS, Mask)) {
CCVal = ISD::getSetCCInverse(CCVal, LHS.getValueType());
CC = DAG.getCondCode(CCVal);
RHS = DAG.getConstant(0, DL, LHS.getValueType());
return true;
}
return false;
}
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
@ -8956,110 +9002,32 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
// Transform
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
SDValue CC = N->getOperand(2);
SDValue TrueV = N->getOperand(3);
SDValue FalseV = N->getOperand(4);
SDLoc DL(N);
// If the True and False values are the same, we don't need a select_cc.
if (TrueV == FalseV)
return TrueV;
ISD::CondCode CCVal = cast<CondCodeSDNode>(N->getOperand(2))->get();
if (!ISD::isIntEqualitySetCC(CCVal))
break;
// Fold (select_cc (setlt X, Y), 0, ne, trueV, falseV) ->
// (select_cc X, Y, lt, trueV, falseV)
// Sometimes the setcc is introduced after select_cc has been formed.
if (LHS.getOpcode() == ISD::SETCC && isNullConstant(RHS) &&
LHS.getOperand(0).getValueType() == Subtarget.getXLenVT()) {
// If we're looking for eq 0 instead of ne 0, we need to invert the
// condition.
bool Invert = CCVal == ISD::SETEQ;
CCVal = cast<CondCodeSDNode>(LHS.getOperand(2))->get();
if (Invert)
CCVal = ISD::getSetCCInverse(CCVal, LHS.getValueType());
SDLoc DL(N);
RHS = LHS.getOperand(1);
LHS = LHS.getOperand(0);
translateSetCCForBranch(DL, LHS, RHS, CCVal, DAG);
SDValue TargetCC = DAG.getCondCode(CCVal);
if (combine_CC(LHS, RHS, CC, DL, DAG, Subtarget))
return DAG.getNode(RISCVISD::SELECT_CC, DL, N->getValueType(0),
{LHS, RHS, TargetCC, TrueV, FalseV});
}
{LHS, RHS, CC, TrueV, FalseV});
// Fold (select_cc (xor X, Y), 0, eq/ne, trueV, falseV) ->
// (select_cc X, Y, eq/ne, trueV, falseV)
if (LHS.getOpcode() == ISD::XOR && isNullConstant(RHS))
return DAG.getNode(RISCVISD::SELECT_CC, SDLoc(N), N->getValueType(0),
{LHS.getOperand(0), LHS.getOperand(1),
N->getOperand(2), TrueV, FalseV});
// (select_cc X, 1, setne, trueV, falseV) ->
// (select_cc X, 0, seteq, trueV, falseV) if we can prove X is 0/1.
// This can occur when legalizing some floating point comparisons.
APInt Mask = APInt::getBitsSetFrom(LHS.getValueSizeInBits(), 1);
if (isOneConstant(RHS) && DAG.MaskedValueIsZero(LHS, Mask)) {
SDLoc DL(N);
CCVal = ISD::getSetCCInverse(CCVal, LHS.getValueType());
SDValue TargetCC = DAG.getCondCode(CCVal);
RHS = DAG.getConstant(0, DL, LHS.getValueType());
return DAG.getNode(RISCVISD::SELECT_CC, DL, N->getValueType(0),
{LHS, RHS, TargetCC, TrueV, FalseV});
}
break;
return SDValue();
}
case RISCVISD::BR_CC: {
SDValue LHS = N->getOperand(1);
SDValue RHS = N->getOperand(2);
ISD::CondCode CCVal = cast<CondCodeSDNode>(N->getOperand(3))->get();
if (!ISD::isIntEqualitySetCC(CCVal))
break;
// Fold (br_cc (setlt X, Y), 0, ne, dest) ->
// (br_cc X, Y, lt, dest)
// Sometimes the setcc is introduced after br_cc has been formed.
if (LHS.getOpcode() == ISD::SETCC && isNullConstant(RHS) &&
LHS.getOperand(0).getValueType() == Subtarget.getXLenVT()) {
// If we're looking for eq 0 instead of ne 0, we need to invert the
// condition.
bool Invert = CCVal == ISD::SETEQ;
CCVal = cast<CondCodeSDNode>(LHS.getOperand(2))->get();
if (Invert)
CCVal = ISD::getSetCCInverse(CCVal, LHS.getValueType());
SDLoc DL(N);
RHS = LHS.getOperand(1);
LHS = LHS.getOperand(0);
translateSetCCForBranch(DL, LHS, RHS, CCVal, DAG);
SDValue CC = N->getOperand(3);
SDLoc DL(N);
if (combine_CC(LHS, RHS, CC, DL, DAG, Subtarget))
return DAG.getNode(RISCVISD::BR_CC, DL, N->getValueType(0),
N->getOperand(0), LHS, RHS, DAG.getCondCode(CCVal),
N->getOperand(4));
}
N->getOperand(0), LHS, RHS, CC, N->getOperand(4));
// Fold (br_cc (xor X, Y), 0, eq/ne, dest) ->
// (br_cc X, Y, eq/ne, trueV, falseV)
if (LHS.getOpcode() == ISD::XOR && isNullConstant(RHS))
return DAG.getNode(RISCVISD::BR_CC, SDLoc(N), N->getValueType(0),
N->getOperand(0), LHS.getOperand(0), LHS.getOperand(1),
N->getOperand(3), N->getOperand(4));
// (br_cc X, 1, setne, br_cc) ->
// (br_cc X, 0, seteq, br_cc) if we can prove X is 0/1.
// This can occur when legalizing some floating point comparisons.
APInt Mask = APInt::getBitsSetFrom(LHS.getValueSizeInBits(), 1);
if (isOneConstant(RHS) && DAG.MaskedValueIsZero(LHS, Mask)) {
SDLoc DL(N);
CCVal = ISD::getSetCCInverse(CCVal, LHS.getValueType());
SDValue TargetCC = DAG.getCondCode(CCVal);
RHS = DAG.getConstant(0, DL, LHS.getValueType());
return DAG.getNode(RISCVISD::BR_CC, DL, N->getValueType(0),
N->getOperand(0), LHS, RHS, TargetCC,
N->getOperand(4));
}
break;
return SDValue();
}
case ISD::BITREVERSE:
return performBITREVERSECombine(N, DAG, Subtarget);