forked from OSchip/llvm-project
[x86] invert a vector select IR canonicalization with a binop identity constant
This is an intentionally limited/different form of D90113. That patch bravely tries to generalize folds where we pull a binop into the arms of a select: N0 + (Cond ? 0 : FVal) --> Cond ? N0 : (N0 + FVal) ...but it is not universally profitable. This is the inverse of IR canonicalization as discussed in D113442. We know that this transform is not entirely profitable even within x86, so we only handle x86 vector fadd/fsub as a 1st step. The intent is to prevent AVX512 regressions as mentioned in D113442. The plan is to port this to DAGCombiner (so it will eventually look more like D90113) and add more types/cases in pieces with many more tests to verify that we are seeing improvements. Differential Revision: https://reviews.llvm.org/D118644
This commit is contained in:
parent
ccf02cdf17
commit
6592bcecd4
|
@ -48942,6 +48942,83 @@ static SDValue combineFaddCFmul(SDNode *N, SelectionDAG &DAG,
|
|||
return DAG.getBitcast(VT, CFmul);
|
||||
}
|
||||
|
||||
/// This inverts a canonicalization in IR that replaces a variable select arm
|
||||
/// with an identity constant. Codegen improves if we re-use the variable
|
||||
/// operand rather than load a constant. This can also be converted into a
|
||||
/// masked vector operation if the target supports it.
|
||||
static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
|
||||
bool ShouldCommuteOperands) {
|
||||
// Match a select as operand 1. The identity constant that we are looking for
|
||||
// is only valid as operand 1 of a non-commutative binop.
|
||||
SDValue N0 = N->getOperand(0);
|
||||
SDValue N1 = N->getOperand(1);
|
||||
if (ShouldCommuteOperands)
|
||||
std::swap(N0, N1);
|
||||
|
||||
// TODO: Should this apply to scalar select too?
|
||||
if (!N1.hasOneUse() || N1.getOpcode() != ISD::VSELECT)
|
||||
return SDValue();
|
||||
|
||||
unsigned Opcode = N->getOpcode();
|
||||
EVT VT = N->getValueType(0);
|
||||
SDValue Cond = N1.getOperand(0);
|
||||
SDValue TVal = N1.getOperand(1);
|
||||
SDValue FVal = N1.getOperand(2);
|
||||
|
||||
// TODO: This (and possibly the entire function) belongs in a
|
||||
// target-independent location with target hooks.
|
||||
// TODO: The cases should match with IR's ConstantExpr::getBinOpIdentity().
|
||||
// TODO: With fast-math (NSZ), allow the opposite-sign form of zero?
|
||||
auto isIdentityConstantForOpcode = [](unsigned Opcode, SDValue V) {
|
||||
if (ConstantFPSDNode *C = isConstOrConstSplatFP(V)) {
|
||||
switch (Opcode) {
|
||||
case ISD::FADD: // X + -0.0 --> X
|
||||
return C->isZero() && C->isNegative();
|
||||
case ISD::FSUB: // X - 0.0 --> X
|
||||
return C->isZero() && !C->isNegative();
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
// This transform increases uses of N0, so freeze it to be safe.
|
||||
// binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
|
||||
if (isIdentityConstantForOpcode(Opcode, TVal)) {
|
||||
SDValue F0 = DAG.getFreeze(N0);
|
||||
SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags());
|
||||
return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO);
|
||||
}
|
||||
// binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
|
||||
if (isIdentityConstantForOpcode(Opcode, FVal)) {
|
||||
SDValue F0 = DAG.getFreeze(N0);
|
||||
SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags());
|
||||
return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0);
|
||||
}
|
||||
|
||||
return SDValue();
|
||||
}
|
||||
|
||||
static SDValue combineBinopWithSelect(SDNode *N, SelectionDAG &DAG,
|
||||
const X86Subtarget &Subtarget) {
|
||||
// TODO: This is too general. There are cases where pre-AVX512 codegen would
|
||||
// benefit. The transform may also be profitable for scalar code.
|
||||
if (!Subtarget.hasAVX512())
|
||||
return SDValue();
|
||||
|
||||
if (!Subtarget.hasVLX() && !N->getValueType(0).is512BitVector())
|
||||
return SDValue();
|
||||
|
||||
if (SDValue Sel = foldSelectWithIdentityConstant(N, DAG, false))
|
||||
return Sel;
|
||||
|
||||
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
|
||||
if (TLI.isCommutativeBinOp(N->getOpcode()))
|
||||
if (SDValue Sel = foldSelectWithIdentityConstant(N, DAG, true))
|
||||
return Sel;
|
||||
|
||||
return SDValue();
|
||||
}
|
||||
|
||||
/// Do target-specific dag combines on floating-point adds/subs.
|
||||
static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG,
|
||||
const X86Subtarget &Subtarget) {
|
||||
|
@ -48951,6 +49028,9 @@ static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG,
|
|||
if (SDValue COp = combineFaddCFmul(N, DAG, Subtarget))
|
||||
return COp;
|
||||
|
||||
if (SDValue Sel = combineBinopWithSelect(N, DAG, Subtarget))
|
||||
return Sel;
|
||||
|
||||
return SDValue();
|
||||
}
|
||||
|
||||
|
|
|
@ -83,8 +83,8 @@ define <32 x half> @test_int_x86_avx512fp16_maskz_sub_ph_512(<32 x half> %src, <
|
|||
; CHECK: # %bb.0:
|
||||
; CHECK-NEXT: kmovd %edi, %k1
|
||||
; CHECK-NEXT: vsubph %zmm2, %zmm1, %zmm0 {%k1} {z}
|
||||
; CHECK-NEXT: vsubph (%rsi), %zmm1, %zmm1 {%k1} {z}
|
||||
; CHECK-NEXT: vsubph %zmm1, %zmm0, %zmm0
|
||||
; CHECK-NEXT: vsubph (%rsi), %zmm1, %zmm1
|
||||
; CHECK-NEXT: vsubph %zmm1, %zmm0, %zmm0 {%k1}
|
||||
; CHECK-NEXT: retq
|
||||
%mask = bitcast i32 %msk to <32 x i1>
|
||||
%val = load <32 x half>, <32 x half>* %ptr
|
||||
|
|
|
@ -27,9 +27,8 @@ define <4 x float> @fadd_v4f32(<4 x i1> %b, <4 x float> noundef %x, <4 x float>
|
|||
; AVX512VL: # %bb.0:
|
||||
; AVX512VL-NEXT: vpslld $31, %xmm0, %xmm0
|
||||
; AVX512VL-NEXT: vptestmd %xmm0, %xmm0, %k1
|
||||
; AVX512VL-NEXT: vbroadcastss {{.*#+}} xmm0 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
|
||||
; AVX512VL-NEXT: vmovaps %xmm2, %xmm0 {%k1}
|
||||
; AVX512VL-NEXT: vaddps %xmm0, %xmm1, %xmm0
|
||||
; AVX512VL-NEXT: vaddps %xmm2, %xmm1, %xmm1 {%k1}
|
||||
; AVX512VL-NEXT: vmovaps %xmm1, %xmm0
|
||||
; AVX512VL-NEXT: retq
|
||||
%s = select <4 x i1> %b, <4 x float> %y, <4 x float> <float -0.0, float -0.0, float -0.0, float -0.0>
|
||||
%r = fadd <4 x float> %x, %s
|
||||
|
@ -62,9 +61,8 @@ define <8 x float> @fadd_v8f32_commute(<8 x i1> %b, <8 x float> noundef %x, <8 x
|
|||
; AVX512VL-NEXT: vpmovsxwd %xmm0, %ymm0
|
||||
; AVX512VL-NEXT: vpslld $31, %ymm0, %ymm0
|
||||
; AVX512VL-NEXT: vptestmd %ymm0, %ymm0, %k1
|
||||
; AVX512VL-NEXT: vbroadcastss {{.*#+}} ymm0 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
|
||||
; AVX512VL-NEXT: vmovaps %ymm2, %ymm0 {%k1}
|
||||
; AVX512VL-NEXT: vaddps %ymm1, %ymm0, %ymm0
|
||||
; AVX512VL-NEXT: vaddps %ymm2, %ymm1, %ymm1 {%k1}
|
||||
; AVX512VL-NEXT: vmovaps %ymm1, %ymm0
|
||||
; AVX512VL-NEXT: retq
|
||||
%s = select <8 x i1> %b, <8 x float> %y, <8 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>
|
||||
%r = fadd <8 x float> %s, %x
|
||||
|
@ -92,8 +90,8 @@ define <16 x float> @fadd_v16f32_swap(<16 x i1> %b, <16 x float> noundef %x, <16
|
|||
; AVX512-NEXT: vpmovsxbd %xmm0, %zmm0
|
||||
; AVX512-NEXT: vpslld $31, %zmm0, %zmm0
|
||||
; AVX512-NEXT: vptestmd %zmm0, %zmm0, %k1
|
||||
; AVX512-NEXT: vbroadcastss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2 {%k1}
|
||||
; AVX512-NEXT: vaddps %zmm2, %zmm1, %zmm0
|
||||
; AVX512-NEXT: vmovaps %zmm1, %zmm0 {%k1}
|
||||
; AVX512-NEXT: retq
|
||||
%s = select <16 x i1> %b, <16 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>, <16 x float> %y
|
||||
%r = fadd <16 x float> %x, %s
|
||||
|
@ -121,8 +119,8 @@ define <16 x float> @fadd_v16f32_commute_swap(<16 x i1> %b, <16 x float> noundef
|
|||
; AVX512-NEXT: vpmovsxbd %xmm0, %zmm0
|
||||
; AVX512-NEXT: vpslld $31, %zmm0, %zmm0
|
||||
; AVX512-NEXT: vptestmd %zmm0, %zmm0, %k1
|
||||
; AVX512-NEXT: vbroadcastss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2 {%k1}
|
||||
; AVX512-NEXT: vaddps %zmm1, %zmm2, %zmm0
|
||||
; AVX512-NEXT: vaddps %zmm2, %zmm1, %zmm0
|
||||
; AVX512-NEXT: vmovaps %zmm1, %zmm0 {%k1}
|
||||
; AVX512-NEXT: retq
|
||||
%s = select <16 x i1> %b, <16 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>, <16 x float> %y
|
||||
%r = fadd <16 x float> %s, %x
|
||||
|
@ -152,14 +150,16 @@ define <4 x float> @fsub_v4f32(<4 x i1> %b, <4 x float> noundef %x, <4 x float>
|
|||
; AVX512VL: # %bb.0:
|
||||
; AVX512VL-NEXT: vpslld $31, %xmm0, %xmm0
|
||||
; AVX512VL-NEXT: vptestmd %xmm0, %xmm0, %k1
|
||||
; AVX512VL-NEXT: vmovaps %xmm2, %xmm0 {%k1} {z}
|
||||
; AVX512VL-NEXT: vsubps %xmm0, %xmm1, %xmm0
|
||||
; AVX512VL-NEXT: vsubps %xmm2, %xmm1, %xmm1 {%k1}
|
||||
; AVX512VL-NEXT: vmovaps %xmm1, %xmm0
|
||||
; AVX512VL-NEXT: retq
|
||||
%s = select <4 x i1> %b, <4 x float> %y, <4 x float> zeroinitializer
|
||||
%r = fsub <4 x float> %x, %s
|
||||
ret <4 x float> %r
|
||||
}
|
||||
|
||||
; negative test - fsub is not commutative; there is no identity constant for operand 0
|
||||
|
||||
define <8 x float> @fsub_v8f32_commute(<8 x i1> %b, <8 x float> noundef %x, <8 x float> noundef %y) {
|
||||
; AVX2-LABEL: fsub_v8f32_commute:
|
||||
; AVX2: # %bb.0:
|
||||
|
@ -214,15 +214,17 @@ define <16 x float> @fsub_v16f32_swap(<16 x i1> %b, <16 x float> noundef %x, <16
|
|||
; AVX512: # %bb.0:
|
||||
; AVX512-NEXT: vpmovsxbd %xmm0, %zmm0
|
||||
; AVX512-NEXT: vpslld $31, %zmm0, %zmm0
|
||||
; AVX512-NEXT: vptestnmd %zmm0, %zmm0, %k1
|
||||
; AVX512-NEXT: vmovaps %zmm2, %zmm0 {%k1} {z}
|
||||
; AVX512-NEXT: vsubps %zmm0, %zmm1, %zmm0
|
||||
; AVX512-NEXT: vptestmd %zmm0, %zmm0, %k1
|
||||
; AVX512-NEXT: vsubps %zmm2, %zmm1, %zmm0
|
||||
; AVX512-NEXT: vmovaps %zmm1, %zmm0 {%k1}
|
||||
; AVX512-NEXT: retq
|
||||
%s = select <16 x i1> %b, <16 x float> zeroinitializer, <16 x float> %y
|
||||
%r = fsub <16 x float> %x, %s
|
||||
ret <16 x float> %r
|
||||
}
|
||||
|
||||
; negative test - fsub is not commutative; there is no identity constant for operand 0
|
||||
|
||||
define <16 x float> @fsub_v16f32_commute_swap(<16 x i1> %b, <16 x float> noundef %x, <16 x float> noundef %y) {
|
||||
; AVX2-LABEL: fsub_v16f32_commute_swap:
|
||||
; AVX2: # %bb.0:
|
||||
|
@ -570,9 +572,7 @@ define <8 x float> @fadd_v8f32_cast_cond(i8 noundef zeroext %pb, <8 x float> nou
|
|||
; AVX512VL-LABEL: fadd_v8f32_cast_cond:
|
||||
; AVX512VL: # %bb.0:
|
||||
; AVX512VL-NEXT: kmovw %edi, %k1
|
||||
; AVX512VL-NEXT: vbroadcastss {{.*#+}} ymm2 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
|
||||
; AVX512VL-NEXT: vmovaps %ymm1, %ymm2 {%k1}
|
||||
; AVX512VL-NEXT: vaddps %ymm2, %ymm0, %ymm0
|
||||
; AVX512VL-NEXT: vaddps %ymm1, %ymm0, %ymm0 {%k1}
|
||||
; AVX512VL-NEXT: retq
|
||||
%b = bitcast i8 %pb to <8 x i1>
|
||||
%s = select <8 x i1> %b, <8 x float> %y, <8 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>
|
||||
|
@ -636,9 +636,7 @@ define <8 x double> @fadd_v8f64_cast_cond(i8 noundef zeroext %pb, <8 x double> n
|
|||
; AVX512-LABEL: fadd_v8f64_cast_cond:
|
||||
; AVX512: # %bb.0:
|
||||
; AVX512-NEXT: kmovw %edi, %k1
|
||||
; AVX512-NEXT: vbroadcastsd {{.*#+}} zmm2 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
|
||||
; AVX512-NEXT: vmovapd %zmm1, %zmm2 {%k1}
|
||||
; AVX512-NEXT: vaddpd %zmm2, %zmm0, %zmm0
|
||||
; AVX512-NEXT: vaddpd %zmm1, %zmm0, %zmm0 {%k1}
|
||||
; AVX512-NEXT: retq
|
||||
%b = bitcast i8 %pb to <8 x i1>
|
||||
%s = select <8 x i1> %b, <8 x double> %y, <8 x double> <double -0.0, double -0.0, double -0.0, double -0.0, double -0.0, double -0.0, double -0.0, double -0.0>
|
||||
|
@ -709,8 +707,7 @@ define <8 x float> @fsub_v8f32_cast_cond(i8 noundef zeroext %pb, <8 x float> nou
|
|||
; AVX512VL-LABEL: fsub_v8f32_cast_cond:
|
||||
; AVX512VL: # %bb.0:
|
||||
; AVX512VL-NEXT: kmovw %edi, %k1
|
||||
; AVX512VL-NEXT: vmovaps %ymm1, %ymm1 {%k1} {z}
|
||||
; AVX512VL-NEXT: vsubps %ymm1, %ymm0, %ymm0
|
||||
; AVX512VL-NEXT: vsubps %ymm1, %ymm0, %ymm0 {%k1}
|
||||
; AVX512VL-NEXT: retq
|
||||
%b = bitcast i8 %pb to <8 x i1>
|
||||
%s = select <8 x i1> %b, <8 x float> %y, <8 x float> zeroinitializer
|
||||
|
@ -775,8 +772,7 @@ define <8 x double> @fsub_v8f64_cast_cond(i8 noundef zeroext %pb, <8 x double> n
|
|||
; AVX512-LABEL: fsub_v8f64_cast_cond:
|
||||
; AVX512: # %bb.0:
|
||||
; AVX512-NEXT: kmovw %edi, %k1
|
||||
; AVX512-NEXT: vmovapd %zmm1, %zmm1 {%k1} {z}
|
||||
; AVX512-NEXT: vsubpd %zmm1, %zmm0, %zmm0
|
||||
; AVX512-NEXT: vsubpd %zmm1, %zmm0, %zmm0 {%k1}
|
||||
; AVX512-NEXT: retq
|
||||
%b = bitcast i8 %pb to <8 x i1>
|
||||
%s = select <8 x i1> %b, <8 x double> %y, <8 x double> zeroinitializer
|
||||
|
|
Loading…
Reference in New Issue