forked from OSchip/llvm-project
[X86] Recognize a splat of negate in isFNEG
Summary: Expand isFNEG so that we generate the appropriate F(N)M(ADD|SUB) instructions in more cases. For example, the following sequence a = _mm256_broadcast_ss(f) d = _mm256_fnmadd_ps(a, b, c) generates an fsub and fma without this patch and an fnma with this change. Reviewers: craig.topper Subscribers: llvm-commits, davidxl, wmi Differential Revision: https://reviews.llvm.org/D48467 llvm-svn: 339043
This commit is contained in:
parent
96e6ed645d
commit
10fd92dd94
|
@ -5633,6 +5633,24 @@ static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits,
|
|||
}
|
||||
return CastBitData(UndefSrcElts, SrcEltBits);
|
||||
}
|
||||
if (ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode())) {
|
||||
unsigned SrcEltSizeInBits = VT.getScalarSizeInBits();
|
||||
unsigned NumSrcElts = SizeInBits / SrcEltSizeInBits;
|
||||
|
||||
APInt UndefSrcElts(NumSrcElts, 0);
|
||||
SmallVector<APInt, 64> SrcEltBits(NumSrcElts, APInt(SrcEltSizeInBits, 0));
|
||||
for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) {
|
||||
const SDValue &Src = Op.getOperand(i);
|
||||
if (Src.isUndef()) {
|
||||
UndefSrcElts.setBit(i);
|
||||
continue;
|
||||
}
|
||||
auto *Cst = cast<ConstantFPSDNode>(Src);
|
||||
APInt RawBits = Cst->getValueAPF().bitcastToAPInt();
|
||||
SrcEltBits[i] = RawBits.zextOrTrunc(SrcEltSizeInBits);
|
||||
}
|
||||
return CastBitData(UndefSrcElts, SrcEltBits);
|
||||
}
|
||||
|
||||
// Extract constant bits from constant pool vector.
|
||||
if (auto *Cst = getTargetConstantFromNode(Op)) {
|
||||
|
@ -36971,29 +36989,72 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
|
|||
|
||||
/// Returns the negated value if the node \p N flips sign of FP value.
|
||||
///
|
||||
/// FP-negation node may have different forms: FNEG(x) or FXOR (x, 0x80000000).
|
||||
/// FP-negation node may have different forms: FNEG(x), FXOR (x, 0x80000000)
|
||||
/// or FSUB(0, x)
|
||||
/// AVX512F does not have FXOR, so FNEG is lowered as
|
||||
/// (bitcast (xor (bitcast x), (bitcast ConstantFP(0x80000000)))).
|
||||
/// In this case we go though all bitcasts.
|
||||
static SDValue isFNEG(SDNode *N) {
|
||||
/// This also recognizes splat of a negated value and returns the splat of that
|
||||
/// value.
|
||||
static SDValue isFNEG(SelectionDAG &DAG, SDNode *N) {
|
||||
if (N->getOpcode() == ISD::FNEG)
|
||||
return N->getOperand(0);
|
||||
|
||||
SDValue Op = peekThroughBitcasts(SDValue(N, 0));
|
||||
if (Op.getOpcode() != X86ISD::FXOR && Op.getOpcode() != ISD::XOR)
|
||||
auto VT = Op->getValueType(0);
|
||||
if (auto SVOp = dyn_cast<ShuffleVectorSDNode>(Op.getNode())) {
|
||||
// For a VECTOR_SHUFFLE(VEC1, VEC2), if the VEC2 is undef, then the negate
|
||||
// of this is VECTOR_SHUFFLE(-VEC1, UNDEF). The mask can be anything here.
|
||||
if (!SVOp->getOperand(1).isUndef())
|
||||
return SDValue();
|
||||
if (SDValue NegOp0 = isFNEG(DAG, SVOp->getOperand(0).getNode()))
|
||||
return DAG.getVectorShuffle(VT, SDLoc(SVOp), NegOp0, DAG.getUNDEF(VT),
|
||||
SVOp->getMask());
|
||||
return SDValue();
|
||||
}
|
||||
unsigned Opc = Op.getOpcode();
|
||||
if (Opc == ISD::INSERT_VECTOR_ELT) {
|
||||
// Negate of INSERT_VECTOR_ELT(UNDEF, V, INDEX) is INSERT_VECTOR_ELT(UNDEF,
|
||||
// -V, INDEX).
|
||||
SDValue InsVector = Op.getOperand(0);
|
||||
SDValue InsVal = Op.getOperand(1);
|
||||
if (!InsVector.isUndef())
|
||||
return SDValue();
|
||||
if (SDValue NegInsVal = isFNEG(DAG, InsVal.getNode()))
|
||||
return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Op), VT, InsVector,
|
||||
NegInsVal, Op.getOperand(2));
|
||||
return SDValue();
|
||||
}
|
||||
|
||||
if (Opc != X86ISD::FXOR && Opc != ISD::XOR && Opc != ISD::FSUB)
|
||||
return SDValue();
|
||||
|
||||
SDValue Op1 = peekThroughBitcasts(Op.getOperand(1));
|
||||
if (!Op1.getValueType().isFloatingPoint())
|
||||
return SDValue();
|
||||
|
||||
// Extract constant bits and see if they are all sign bit masks.
|
||||
SDValue Op0 = peekThroughBitcasts(Op.getOperand(0));
|
||||
|
||||
// For XOR and FXOR, we want to check if constant bits of Op1 are sign bit
|
||||
// masks. For FSUB, we have to check if constant bits of Op0 are sign bit
|
||||
// masks and hence we swap the operands.
|
||||
if (Opc == ISD::FSUB)
|
||||
std::swap(Op0, Op1);
|
||||
|
||||
APInt UndefElts;
|
||||
SmallVector<APInt, 16> EltBits;
|
||||
// Extract constant bits and see if they are all sign bit masks. Ignore the
|
||||
// undef elements.
|
||||
if (getTargetConstantBitsFromNode(Op1, Op1.getScalarValueSizeInBits(),
|
||||
UndefElts, EltBits, false, false))
|
||||
if (llvm::all_of(EltBits, [](APInt &I) { return I.isSignMask(); }))
|
||||
return peekThroughBitcasts(Op.getOperand(0));
|
||||
UndefElts, EltBits,
|
||||
/* AllowWholeUndefs */ true,
|
||||
/* AllowPartialUndefs */ false)) {
|
||||
for (unsigned I = 0, E = EltBits.size(); I < E; I++)
|
||||
if (!UndefElts[I] && !EltBits[I].isSignMask())
|
||||
return SDValue();
|
||||
|
||||
return peekThroughBitcasts(Op0);
|
||||
}
|
||||
|
||||
return SDValue();
|
||||
}
|
||||
|
@ -37002,8 +37063,9 @@ static SDValue isFNEG(SDNode *N) {
|
|||
static SDValue combineFneg(SDNode *N, SelectionDAG &DAG,
|
||||
const X86Subtarget &Subtarget) {
|
||||
EVT OrigVT = N->getValueType(0);
|
||||
SDValue Arg = isFNEG(N);
|
||||
assert(Arg.getNode() && "N is expected to be an FNEG node");
|
||||
SDValue Arg = isFNEG(DAG, N);
|
||||
if (!Arg)
|
||||
return SDValue();
|
||||
|
||||
EVT VT = Arg.getValueType();
|
||||
EVT SVT = VT.getScalarType();
|
||||
|
@ -37118,9 +37180,7 @@ static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
|
|||
if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, Subtarget))
|
||||
return FPLogic;
|
||||
|
||||
if (isFNEG(N))
|
||||
return combineFneg(N, DAG, Subtarget);
|
||||
return SDValue();
|
||||
return combineFneg(N, DAG, Subtarget);
|
||||
}
|
||||
|
||||
static SDValue combineBEXTR(SDNode *N, SelectionDAG &DAG,
|
||||
|
@ -37253,9 +37313,8 @@ static SDValue combineFOr(SDNode *N, SelectionDAG &DAG,
|
|||
if (isNullFPScalarOrVectorConst(N->getOperand(1)))
|
||||
return N->getOperand(0);
|
||||
|
||||
if (isFNEG(N))
|
||||
if (SDValue NewVal = combineFneg(N, DAG, Subtarget))
|
||||
return NewVal;
|
||||
if (SDValue NewVal = combineFneg(N, DAG, Subtarget))
|
||||
return NewVal;
|
||||
|
||||
return lowerX86FPLogicOp(N, DAG, Subtarget);
|
||||
}
|
||||
|
@ -37940,7 +37999,7 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG,
|
|||
SDValue C = N->getOperand(2);
|
||||
|
||||
auto invertIfNegative = [&DAG](SDValue &V) {
|
||||
if (SDValue NegVal = isFNEG(V.getNode())) {
|
||||
if (SDValue NegVal = isFNEG(DAG, V.getNode())) {
|
||||
V = DAG.getBitcast(V.getValueType(), NegVal);
|
||||
return true;
|
||||
}
|
||||
|
@ -37948,7 +38007,7 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG,
|
|||
// new extract from the FNEG input.
|
||||
if (V.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
|
||||
isNullConstant(V.getOperand(1))) {
|
||||
if (SDValue NegVal = isFNEG(V.getOperand(0).getNode())) {
|
||||
if (SDValue NegVal = isFNEG(DAG, V.getOperand(0).getNode())) {
|
||||
NegVal = DAG.getBitcast(V.getOperand(0).getValueType(), NegVal);
|
||||
V = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(V), V.getValueType(),
|
||||
NegVal, V.getOperand(1));
|
||||
|
@ -37981,7 +38040,7 @@ static SDValue combineFMADDSUB(SDNode *N, SelectionDAG &DAG,
|
|||
SDLoc dl(N);
|
||||
EVT VT = N->getValueType(0);
|
||||
|
||||
SDValue NegVal = isFNEG(N->getOperand(2).getNode());
|
||||
SDValue NegVal = isFNEG(DAG, N->getOperand(2).getNode());
|
||||
if (!NegVal)
|
||||
return SDValue();
|
||||
|
||||
|
|
|
@ -118,20 +118,14 @@ declare <2 x double> @llvm.x86.fma.vfmadd.pd(<2 x double> %a, <2 x double> %b, <
|
|||
define <8 x float> @test7(float %a, <8 x float> %b, <8 x float> %c) {
|
||||
; X32-LABEL: test7:
|
||||
; X32: # %bb.0: # %entry
|
||||
; X32-NEXT: vmovss {{.*#+}} xmm2 = mem[0],zero,zero,zero
|
||||
; X32-NEXT: vmovss {{.*#+}} xmm3 = mem[0],zero,zero,zero
|
||||
; X32-NEXT: vsubps %ymm2, %ymm3, %ymm2
|
||||
; X32-NEXT: vbroadcastss %xmm2, %ymm2
|
||||
; X32-NEXT: vfmadd213ps {{.*#+}} ymm0 = (ymm2 * ymm0) + ymm1
|
||||
; X32-NEXT: vbroadcastss {{[0-9]+}}(%esp), %ymm2
|
||||
; X32-NEXT: vfnmadd213ps {{.*#+}} ymm0 = -(ymm2 * ymm0) + ymm1
|
||||
; X32-NEXT: retl
|
||||
;
|
||||
; X64-LABEL: test7:
|
||||
; X64: # %bb.0: # %entry
|
||||
; X64-NEXT: # kill: def $xmm0 killed $xmm0 def $ymm0
|
||||
; X64-NEXT: vmovss {{.*#+}} xmm3 = mem[0],zero,zero,zero
|
||||
; X64-NEXT: vsubps %ymm0, %ymm3, %ymm0
|
||||
; X64-NEXT: vbroadcastss %xmm0, %ymm0
|
||||
; X64-NEXT: vfmadd213ps {{.*#+}} ymm0 = (ymm1 * ymm0) + ymm2
|
||||
; X64-NEXT: vfnmadd213ps {{.*#+}} ymm0 = -(ymm1 * ymm0) + ymm2
|
||||
; X64-NEXT: retq
|
||||
entry:
|
||||
%0 = insertelement <8 x float> undef, float %a, i32 0
|
||||
|
@ -145,19 +139,14 @@ entry:
|
|||
define <8 x float> @test8(float %a, <8 x float> %b, <8 x float> %c) {
|
||||
; X32-LABEL: test8:
|
||||
; X32: # %bb.0: # %entry
|
||||
; X32-NEXT: vmovss {{.*#+}} xmm2 = mem[0],zero,zero,zero
|
||||
; X32-NEXT: vbroadcastss {{.*#+}} xmm3 = [-0,-0,-0,-0]
|
||||
; X32-NEXT: vxorps %xmm3, %xmm2, %xmm2
|
||||
; X32-NEXT: vbroadcastss %xmm2, %ymm2
|
||||
; X32-NEXT: vfmadd213ps {{.*#+}} ymm0 = (ymm2 * ymm0) + ymm1
|
||||
; X32-NEXT: vbroadcastss {{[0-9]+}}(%esp), %ymm2
|
||||
; X32-NEXT: vfnmadd213ps {{.*#+}} ymm0 = -(ymm2 * ymm0) + ymm1
|
||||
; X32-NEXT: retl
|
||||
;
|
||||
; X64-LABEL: test8:
|
||||
; X64: # %bb.0: # %entry
|
||||
; X64-NEXT: vbroadcastss {{.*#+}} xmm3 = [-0,-0,-0,-0]
|
||||
; X64-NEXT: vxorps %xmm3, %xmm0, %xmm0
|
||||
; X64-NEXT: vbroadcastss %xmm0, %ymm0
|
||||
; X64-NEXT: vfmadd213ps {{.*#+}} ymm0 = (ymm1 * ymm0) + ymm2
|
||||
; X64-NEXT: vfnmadd213ps {{.*#+}} ymm0 = -(ymm1 * ymm0) + ymm2
|
||||
; X64-NEXT: retq
|
||||
entry:
|
||||
%0 = fsub float -0.0, %a
|
||||
|
|
Loading…
Reference in New Issue