forked from OSchip/llvm-project
[X86] Try to avoid casts around logical vector ops recursively.
Currently PromoteMaskArithemtic only looks at a single operation to skip casts. This means we miss cases where we combine multiple masks. This patch updates PromoteMaskArithemtic to try to recursively promote AND/XOR/AND nodes that terminate in truncates of the right size or constant vectors. Reviewers: craig.topper, RKSimon, spatel Reviewed By: RKSimon Differential Revision: https://reviews.llvm.org/D72524
This commit is contained in:
parent
886d2c2ca7
commit
0ee1db2d1d
|
@ -39898,6 +39898,65 @@ static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) {
|
||||||
return DAG.getNode(X86ISD::ANDNP, SDLoc(N), VT, X, Y);
|
return DAG.getNode(X86ISD::ANDNP, SDLoc(N), VT, X, Y);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try to widen AND, OR and XOR nodes to VT in order to remove casts around
|
||||||
|
// logical operations, like in the example below.
|
||||||
|
// or (and (truncate x, truncate y)),
|
||||||
|
// (xor (truncate z, build_vector (constants)))
|
||||||
|
// Given a target type \p VT, we generate
|
||||||
|
// or (and x, y), (xor z, zext(build_vector (constants)))
|
||||||
|
// given x, y and z are of type \p VT. We can do so, if operands are either
|
||||||
|
// truncates from VT types, the second operand is a vector of constants or can
|
||||||
|
// be recursively promoted.
|
||||||
|
static SDValue PromoteMaskArithmetic(SDNode *N, EVT VT, SelectionDAG &DAG,
|
||||||
|
unsigned Depth) {
|
||||||
|
// Limit recursion to avoid excessive compile times.
|
||||||
|
if (Depth >= SelectionDAG::MaxRecursionDepth)
|
||||||
|
return SDValue();
|
||||||
|
|
||||||
|
if (N->getOpcode() != ISD::XOR && N->getOpcode() != ISD::AND &&
|
||||||
|
N->getOpcode() != ISD::OR)
|
||||||
|
return SDValue();
|
||||||
|
|
||||||
|
SDValue N0 = N->getOperand(0);
|
||||||
|
SDValue N1 = N->getOperand(1);
|
||||||
|
SDLoc DL(N);
|
||||||
|
|
||||||
|
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
|
||||||
|
if (!TLI.isOperationLegalOrPromote(N->getOpcode(), VT))
|
||||||
|
return SDValue();
|
||||||
|
|
||||||
|
if (SDValue NN0 = PromoteMaskArithmetic(N0.getNode(), VT, DAG, Depth + 1))
|
||||||
|
N0 = NN0;
|
||||||
|
else {
|
||||||
|
// The Left side has to be a trunc.
|
||||||
|
if (N0.getOpcode() != ISD::TRUNCATE)
|
||||||
|
return SDValue();
|
||||||
|
|
||||||
|
// The type of the truncated inputs.
|
||||||
|
if (N0.getOperand(0).getValueType() != VT)
|
||||||
|
return SDValue();
|
||||||
|
|
||||||
|
N0 = N0.getOperand(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (SDValue NN1 = PromoteMaskArithmetic(N1.getNode(), VT, DAG, Depth + 1))
|
||||||
|
N1 = NN1;
|
||||||
|
else {
|
||||||
|
// The right side has to be a 'trunc' or a constant vector.
|
||||||
|
bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE &&
|
||||||
|
N1.getOperand(0).getValueType() == VT;
|
||||||
|
if (!RHSTrunc && !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()))
|
||||||
|
return SDValue();
|
||||||
|
|
||||||
|
if (RHSTrunc)
|
||||||
|
N1 = N1.getOperand(0);
|
||||||
|
else
|
||||||
|
N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N1);
|
||||||
|
}
|
||||||
|
|
||||||
|
return DAG.getNode(N->getOpcode(), DL, VT, N0, N1);
|
||||||
|
}
|
||||||
|
|
||||||
// On AVX/AVX2 the type v8i1 is legalized to v8i16, which is an XMM sized
|
// On AVX/AVX2 the type v8i1 is legalized to v8i16, which is an XMM sized
|
||||||
// register. In most cases we actually compare or select YMM-sized registers
|
// register. In most cases we actually compare or select YMM-sized registers
|
||||||
// and mixing the two types creates horrible code. This method optimizes
|
// and mixing the two types creates horrible code. This method optimizes
|
||||||
|
@ -39909,6 +39968,7 @@ static SDValue PromoteMaskArithmetic(SDNode *N, SelectionDAG &DAG,
|
||||||
EVT VT = N->getValueType(0);
|
EVT VT = N->getValueType(0);
|
||||||
assert(VT.isVector() && "Expected vector type");
|
assert(VT.isVector() && "Expected vector type");
|
||||||
|
|
||||||
|
SDLoc DL(N);
|
||||||
assert((N->getOpcode() == ISD::ANY_EXTEND ||
|
assert((N->getOpcode() == ISD::ANY_EXTEND ||
|
||||||
N->getOpcode() == ISD::ZERO_EXTEND ||
|
N->getOpcode() == ISD::ZERO_EXTEND ||
|
||||||
N->getOpcode() == ISD::SIGN_EXTEND) && "Invalid Node");
|
N->getOpcode() == ISD::SIGN_EXTEND) && "Invalid Node");
|
||||||
|
@ -39916,46 +39976,11 @@ static SDValue PromoteMaskArithmetic(SDNode *N, SelectionDAG &DAG,
|
||||||
SDValue Narrow = N->getOperand(0);
|
SDValue Narrow = N->getOperand(0);
|
||||||
EVT NarrowVT = Narrow.getValueType();
|
EVT NarrowVT = Narrow.getValueType();
|
||||||
|
|
||||||
if (Narrow->getOpcode() != ISD::XOR &&
|
|
||||||
Narrow->getOpcode() != ISD::AND &&
|
|
||||||
Narrow->getOpcode() != ISD::OR)
|
|
||||||
return SDValue();
|
|
||||||
|
|
||||||
SDValue N0 = Narrow->getOperand(0);
|
|
||||||
SDValue N1 = Narrow->getOperand(1);
|
|
||||||
SDLoc DL(Narrow);
|
|
||||||
|
|
||||||
// The Left side has to be a trunc.
|
|
||||||
if (N0.getOpcode() != ISD::TRUNCATE)
|
|
||||||
return SDValue();
|
|
||||||
|
|
||||||
// The type of the truncated inputs.
|
|
||||||
if (N0.getOperand(0).getValueType() != VT)
|
|
||||||
return SDValue();
|
|
||||||
|
|
||||||
// The right side has to be a 'trunc' or a constant vector.
|
|
||||||
bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE &&
|
|
||||||
N1.getOperand(0).getValueType() == VT;
|
|
||||||
if (!RHSTrunc &&
|
|
||||||
!ISD::isBuildVectorOfConstantSDNodes(N1.getNode()))
|
|
||||||
return SDValue();
|
|
||||||
|
|
||||||
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
|
|
||||||
|
|
||||||
if (!TLI.isOperationLegalOrPromote(Narrow->getOpcode(), VT))
|
|
||||||
return SDValue();
|
|
||||||
|
|
||||||
// Set N0 and N1 to hold the inputs to the new wide operation.
|
|
||||||
N0 = N0.getOperand(0);
|
|
||||||
if (RHSTrunc)
|
|
||||||
N1 = N1.getOperand(0);
|
|
||||||
else
|
|
||||||
N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N1);
|
|
||||||
|
|
||||||
// Generate the wide operation.
|
// Generate the wide operation.
|
||||||
SDValue Op = DAG.getNode(Narrow->getOpcode(), DL, VT, N0, N1);
|
SDValue Op = PromoteMaskArithmetic(Narrow.getNode(), VT, DAG, 0);
|
||||||
unsigned Opcode = N->getOpcode();
|
if (!Op)
|
||||||
switch (Opcode) {
|
return SDValue();
|
||||||
|
switch (N->getOpcode()) {
|
||||||
default: llvm_unreachable("Unexpected opcode");
|
default: llvm_unreachable("Unexpected opcode");
|
||||||
case ISD::ANY_EXTEND:
|
case ISD::ANY_EXTEND:
|
||||||
return Op;
|
return Op;
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue