diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 556465ba670d..218ba047a2ea 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -33867,25 +33867,40 @@ static SDValue combineCompareEqual(SDNode *N, SelectionDAG &DAG, return SDValue(); } +// Try to match (and (xor X, -1), Y) logic pattern for (andnp X, Y) combines. +static bool matchANDXORWithAllOnesAsANDNP(SDNode *N, SDValue &X, SDValue &Y) { + if (N->getOpcode() != ISD::AND) + return false; + + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + if (N0.getOpcode() == ISD::XOR && + ISD::isBuildVectorAllOnes(N0.getOperand(1).getNode())) { + X = N0.getOperand(0); + Y = N1; + return true; + } + if (N1.getOpcode() == ISD::XOR && + ISD::isBuildVectorAllOnes(N1.getOperand(1).getNode())) { + X = N1.getOperand(0); + Y = N0; + return true; + } + + return false; +} + /// Try to fold: (and (xor X, -1), Y) -> (andnp X, Y). static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) { assert(N->getOpcode() == ISD::AND); EVT VT = N->getValueType(0); - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - SDLoc DL(N); - if (VT != MVT::v2i64 && VT != MVT::v4i64 && VT != MVT::v8i64) return SDValue(); - if (N0.getOpcode() == ISD::XOR && - ISD::isBuildVectorAllOnes(N0.getOperand(1).getNode())) - return DAG.getNode(X86ISD::ANDNP, DL, VT, N0.getOperand(0), N1); - - if (N1.getOpcode() == ISD::XOR && - ISD::isBuildVectorAllOnes(N1.getOperand(1).getNode())) - return DAG.getNode(X86ISD::ANDNP, DL, VT, N1.getOperand(0), N0); + SDValue X, Y; + if (matchANDXORWithAllOnesAsANDNP(N, X, Y)) + return DAG.getNode(X86ISD::ANDNP, SDLoc(N), VT, X, Y); return SDValue(); }