[X86] Remove vXi1 select optimization from LowerSELECT. Move it to DAG combine.

This commit is contained in:
Craig Topper 2020-02-18 23:38:36 -08:00
parent 80b2e3cc53
commit f69a29da5a
1 changed files with 33 additions and 48 deletions

View File

@ -8777,20 +8777,6 @@ static SDValue buildFromShuffleMostly(SDValue Op, SelectionDAG &DAG) {
return NV;
}
/// Pack a constant vector of i1 elements into a single scalar integer
/// constant. Element number i of the BUILD_VECTOR supplies bit i of the
/// result; undef elements leave their bit cleared. The resulting integer type
/// is as wide as the vector, but never narrower than 8 bits.
///
/// \param Op  A BUILD_VECTOR of constant nodes whose scalar size is 1 bit
///            (enforced by the assertion below).
/// \param DAG The DAG used to create the resulting constant node.
/// \returns   An integer constant node holding the packed bits.
static SDValue ConvertI1VectorToInteger(SDValue Op, SelectionDAG &DAG) {
  assert(ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) &&
         Op.getScalarValueSizeInBits() == 1 &&
         "Can not convert non-constant vector");

  // Accumulate one bit per vector element, lowest element in the lowest bit.
  uint64_t PackedBits = 0;
  const unsigned NumElts = Op.getNumOperands();
  for (unsigned EltIdx = 0; EltIdx != NumElts; ++EltIdx) {
    SDValue Elt = Op.getOperand(EltIdx);
    // Undef elements contribute nothing; their bit stays zero.
    if (Elt.isUndef())
      continue;
    // Only the low bit of the constant is meaningful for an i1 element.
    uint64_t Bit = cast<ConstantSDNode>(Elt)->getZExtValue() & 0x1;
    PackedBits |= Bit << EltIdx;
  }

  // Vectors with fewer than 8 elements are widened to an 8-bit integer.
  const int NumBits = std::max(static_cast<int>(Op.getValueSizeInBits()), 8);
  MVT IntVT = MVT::getIntegerVT(NumBits);
  SDLoc dl(Op);
  return DAG.getConstant(PackedBits, dl, IntVT);
}
// Lower BUILD_VECTOR operation for v8i1 and v16i1 types.
static SDValue LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
@ -22467,40 +22453,6 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const {
return DAG.getNode(X86ISD::SELECTS, DL, VT, Cmp, Op1, Op2);
}
// For v64i1 without 64-bit support we need to split and rejoin.
if (VT == MVT::v64i1 && !Subtarget.is64Bit()) {
assert(Subtarget.hasBWI() && "Expected BWI to be legal");
SDValue Op1Lo = extractSubVector(Op1, 0, DAG, DL, 32);
SDValue Op2Lo = extractSubVector(Op2, 0, DAG, DL, 32);
SDValue Op1Hi = extractSubVector(Op1, 32, DAG, DL, 32);
SDValue Op2Hi = extractSubVector(Op2, 32, DAG, DL, 32);
SDValue Lo = DAG.getSelect(DL, MVT::v32i1, Cond, Op1Lo, Op2Lo);
SDValue Hi = DAG.getSelect(DL, MVT::v32i1, Cond, Op1Hi, Op2Hi);
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
}
if (VT.isVector() && VT.getVectorElementType() == MVT::i1) {
SDValue Op1Scalar;
if (ISD::isBuildVectorOfConstantSDNodes(Op1.getNode()))
Op1Scalar = ConvertI1VectorToInteger(Op1, DAG);
else if (Op1.getOpcode() == ISD::BITCAST && Op1.getOperand(0))
Op1Scalar = Op1.getOperand(0);
SDValue Op2Scalar;
if (ISD::isBuildVectorOfConstantSDNodes(Op2.getNode()))
Op2Scalar = ConvertI1VectorToInteger(Op2, DAG);
else if (Op2.getOpcode() == ISD::BITCAST && Op2.getOperand(0))
Op2Scalar = Op2.getOperand(0);
if (Op1Scalar.getNode() && Op2Scalar.getNode()) {
SDValue newSelect = DAG.getSelect(DL, Op1Scalar.getValueType(), Cond,
Op1Scalar, Op2Scalar);
if (newSelect.getValueSizeInBits() == VT.getSizeInBits())
return DAG.getBitcast(VT, newSelect);
SDValue ExtVec = DAG.getBitcast(MVT::v8i1, newSelect);
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, ExtVec,
DAG.getIntPtrConstant(0, DL));
}
}
if (Cond.getOpcode() == ISD::SETCC) {
if (SDValue NewCond = LowerSETCC(Cond, DAG)) {
Cond = NewCond;
@ -38968,6 +38920,39 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
return DAG.getBitcast(VT, newSelect);
}
// Try to optimize vXi1 selects if both operands are either all constants or
// bitcasts from scalar integer type. In that case we can convert the operands
// to integer and use an integer select which will be converted to a CMOV.
// We need to take a little bit of care to avoid creating an i64 type after
// type legalization.
if (N->getOpcode() == ISD::SELECT && VT.isVector() &&
VT.getVectorElementType() == MVT::i1 &&
(DCI.isBeforeLegalize() || (VT != MVT::v64i1 || Subtarget.is64Bit()))) {
MVT IntVT = MVT::getIntegerVT(VT.getVectorNumElements());
bool LHSIsConst = ISD::isBuildVectorOfConstantSDNodes(LHS.getNode());
bool RHSIsConst = ISD::isBuildVectorOfConstantSDNodes(RHS.getNode());
if ((LHSIsConst ||
(LHS.getOpcode() == ISD::BITCAST &&
LHS.getOperand(0).getValueType() == IntVT)) &&
(RHSIsConst ||
(RHS.getOpcode() == ISD::BITCAST &&
RHS.getOperand(0).getValueType() == IntVT))) {
if (LHSIsConst)
LHS = combinevXi1ConstantToInteger(LHS, DAG);
else
LHS = LHS.getOperand(0);
if (RHSIsConst)
RHS = combinevXi1ConstantToInteger(RHS, DAG);
else
RHS = RHS.getOperand(0);
SDValue Select = DAG.getSelect(DL, IntVT, Cond, LHS, RHS);
return DAG.getBitcast(VT, Select);
}
}
return SDValue();
}