[X86] Try to avoid casts around logical vector ops recursively.

Currently PromoteMaskArithmetic only looks at a single operation to
skip casts. This means we miss cases where we combine multiple masks.

This patch updates PromoteMaskArithmetic to try to recursively promote
AND/OR/XOR nodes that terminate in truncates of the right size or
constant vectors.

Reviewers: craig.topper, RKSimon, spatel

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D72524
This commit is contained in:
Florian Hahn 2020-01-19 17:11:43 -08:00
parent 886d2c2ca7
commit 0ee1db2d1d
2 changed files with 272 additions and 605 deletions

View File

@ -39898,6 +39898,65 @@ static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(X86ISD::ANDNP, SDLoc(N), VT, X, Y);
}
// Rebuild a tree of AND/OR/XOR nodes directly in the wider type \p VT so that
// the casts surrounding the logic can be dropped. For instance
//   or (and (truncate x, truncate y)),
//      (xor (truncate z, build_vector (constants)))
// becomes, for a target type \p VT,
//   or (and x, y), (xor z, zext(build_vector (constants)))
// provided x, y and z already have type \p VT.
//
// Each operand of \p N must be one of:
//   * another AND/OR/XOR node that can itself be widened recursively,
//   * a truncate whose source value has type \p VT, or
//   * (right-hand operand only) a build_vector of constants, which is
//     zero-extended to \p VT.
//
// \p Depth is the current recursion depth; callers start at 0.
// Returns the widened node, or an empty SDValue if \p N cannot be widened.
static SDValue PromoteMaskArithmetic(SDNode *N, EVT VT, SelectionDAG &DAG,
                                     unsigned Depth) {
  // Bound the recursion to avoid excessive compile times.
  if (Depth >= SelectionDAG::MaxRecursionDepth)
    return SDValue();

  // Only bitwise logic nodes are candidates for widening.
  const unsigned Opc = N->getOpcode();
  switch (Opc) {
  case ISD::AND:
  case ISD::OR:
  case ISD::XOR:
    break;
  default:
    return SDValue();
  }

  SDValue LHS = N->getOperand(0);
  SDValue RHS = N->getOperand(1);
  SDLoc DL(N);

  // The same operation has to be legal (or promotable) in the wide type.
  if (!DAG.getTargetLoweringInfo().isOperationLegalOrPromote(Opc, VT))
    return SDValue();

  // True when V is a truncate of a value that already has type VT.
  auto IsTruncFromVT = [VT](SDValue V) {
    return V.getOpcode() == ISD::TRUNCATE &&
           V.getOperand(0).getValueType() == VT;
  };

  // Left operand: prefer recursive widening, otherwise peel off a truncate
  // from VT. Anything else blocks the transform.
  if (SDValue WideLHS =
          PromoteMaskArithmetic(LHS.getNode(), VT, DAG, Depth + 1))
    LHS = WideLHS;
  else if (IsTruncFromVT(LHS))
    LHS = LHS.getOperand(0);
  else
    return SDValue();

  // Right operand: same as the left, but a constant build_vector is also
  // accepted and gets zero-extended to the wide type.
  if (SDValue WideRHS =
          PromoteMaskArithmetic(RHS.getNode(), VT, DAG, Depth + 1))
    RHS = WideRHS;
  else if (IsTruncFromVT(RHS))
    RHS = RHS.getOperand(0);
  else if (ISD::isBuildVectorOfConstantSDNodes(RHS.getNode()))
    RHS = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, RHS);
  else
    return SDValue();

  // Emit the logic operation on the wide operands.
  return DAG.getNode(Opc, DL, VT, LHS, RHS);
}
// On AVX/AVX2 the type v8i1 is legalized to v8i16, which is an XMM sized
// register. In most cases we actually compare or select YMM-sized registers
// and mixing the two types creates horrible code. This method optimizes
@ -39909,6 +39968,7 @@ static SDValue PromoteMaskArithmetic(SDNode *N, SelectionDAG &DAG,
EVT VT = N->getValueType(0);
assert(VT.isVector() && "Expected vector type");
SDLoc DL(N);
assert((N->getOpcode() == ISD::ANY_EXTEND ||
N->getOpcode() == ISD::ZERO_EXTEND ||
N->getOpcode() == ISD::SIGN_EXTEND) && "Invalid Node");
@ -39916,46 +39976,11 @@ static SDValue PromoteMaskArithmetic(SDNode *N, SelectionDAG &DAG,
SDValue Narrow = N->getOperand(0);
EVT NarrowVT = Narrow.getValueType();
if (Narrow->getOpcode() != ISD::XOR &&
Narrow->getOpcode() != ISD::AND &&
Narrow->getOpcode() != ISD::OR)
return SDValue();
SDValue N0 = Narrow->getOperand(0);
SDValue N1 = Narrow->getOperand(1);
SDLoc DL(Narrow);
// The Left side has to be a trunc.
if (N0.getOpcode() != ISD::TRUNCATE)
return SDValue();
// The type of the truncated inputs.
if (N0.getOperand(0).getValueType() != VT)
return SDValue();
// The right side has to be a 'trunc' or a constant vector.
bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE &&
N1.getOperand(0).getValueType() == VT;
if (!RHSTrunc &&
!ISD::isBuildVectorOfConstantSDNodes(N1.getNode()))
return SDValue();
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (!TLI.isOperationLegalOrPromote(Narrow->getOpcode(), VT))
return SDValue();
// Set N0 and N1 to hold the inputs to the new wide operation.
N0 = N0.getOperand(0);
if (RHSTrunc)
N1 = N1.getOperand(0);
else
N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N1);
// Generate the wide operation.
SDValue Op = DAG.getNode(Narrow->getOpcode(), DL, VT, N0, N1);
unsigned Opcode = N->getOpcode();
switch (Opcode) {
SDValue Op = PromoteMaskArithmetic(Narrow.getNode(), VT, DAG, 0);
if (!Op)
return SDValue();
switch (N->getOpcode()) {
default: llvm_unreachable("Unexpected opcode");
case ISD::ANY_EXTEND:
return Op;

File diff suppressed because it is too large Load Diff