[DAGCombiner] Fold step_vector with add/mul/shl

This patch implements some DAG combines for STEP_VECTOR:
add step_vector(C1), step_vector(C2) -> step_vector(C1+C2)
add (add X step_vector(C1)), step_vector(C2) -> add X step_vector(C1+C2)
mul step_vector(C1), C2 -> step_vector(C1*C2)
shl step_vector(C1), C2 -> step_vector(C1<<C2)

TestPlan: check-llvm

Differential Revision: https://reviews.llvm.org/D100088
This commit is contained in:
Jun Ma 2021-04-08 13:09:24 +08:00
parent ea14df695e
commit 7e1422c1e4
2 changed files with 100 additions and 0 deletions

View File

@ -2503,6 +2503,31 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), VS);
}
// Fold (add step_vector(c1), step_vector(c2) to step_vector(c1+c2))
if (N0.getOpcode() == ISD::STEP_VECTOR &&
N1.getOpcode() == ISD::STEP_VECTOR) {
const APInt &C0 = N0->getConstantOperandAPInt(0);
const APInt &C1 = N1->getConstantOperandAPInt(0);
EVT SVT = N0.getOperand(0).getValueType();
SDValue NewStep = DAG.getConstant(C0 + C1, DL, SVT);
return DAG.getStepVector(DL, VT, NewStep);
}
// Fold a + step_vector(c1) + step_vector(c2) to a + step_vector(c1+c2)
if ((N0.getOpcode() == ISD::ADD) &&
(N0.getOperand(1).getOpcode() == ISD::STEP_VECTOR) &&
(N1.getOpcode() == ISD::STEP_VECTOR)) {
const APInt &SV0 = N0.getOperand(1)->getConstantOperandAPInt(0);
const APInt &SV1 = N1->getConstantOperandAPInt(0);
EVT SVT = N1.getOperand(0).getValueType();
assert(N1.getOperand(0).getValueType() ==
N0.getOperand(1)->getOperand(0).getValueType() &&
"Different operand types of STEP_VECTOR.");
SDValue NewStep = DAG.getConstant(SV0 + SV1, DL, SVT);
SDValue SV = DAG.getStepVector(DL, VT, NewStep);
return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), SV);
}
return SDValue();
}
@ -3893,6 +3918,17 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
return DAG.getVScale(SDLoc(N), VT, C0 * C1);
}
// Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
APInt MulVal;
if (N0.getOpcode() == ISD::STEP_VECTOR)
if (ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
const APInt &C0 = N0.getConstantOperandAPInt(0);
EVT SVT = N0.getOperand(0).getValueType();
SDValue NewStep = DAG.getConstant(
C0 * MulVal.sextOrTrunc(SVT.getSizeInBits()), SDLoc(N), SVT);
return DAG.getStepVector(SDLoc(N), VT, NewStep);
}
// Fold ((mul x, 0/undef) -> 0,
// (mul x, 1) -> x) -> x)
// -> and(x, mask)
@ -8381,6 +8417,17 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
return DAG.getVScale(SDLoc(N), VT, C0 << C1);
}
// Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
APInt ShlVal;
if (N0.getOpcode() == ISD::STEP_VECTOR)
if (ISD::isConstantSplatVector(N1.getNode(), ShlVal)) {
const APInt &C0 = N0.getConstantOperandAPInt(0);
EVT SVT = N0.getOperand(0).getValueType();
SDValue NewStep = DAG.getConstant(
C0 << ShlVal.sextOrTrunc(SVT.getSizeInBits()), SDLoc(N), SVT);
return DAG.getStepVector(SDLoc(N), VT, NewStep);
}
return SDValue();
}

View File

@ -105,6 +105,59 @@ entry:
ret <vscale x 8 x i8> %0
}
define <vscale x 8 x i8> @add_stepvector_nxv8i8() {
; CHECK-LABEL: add_stepvector_nxv8i8:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: index z0.h, #0, #2
; CHECK-NEXT: ret
entry:
%0 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
%1 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
%2 = add <vscale x 8 x i8> %0, %1
ret <vscale x 8 x i8> %2
}
define <vscale x 8 x i8> @add_stepvector_nxv8i8_1(<vscale x 8 x i8> %p) {
; CHECK-LABEL: add_stepvector_nxv8i8_1:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: index z1.h, #0, #2
; CHECK-NEXT: add z0.h, z0.h, z1.h
; CHECK-NEXT: ret
entry:
%0 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
%1 = add <vscale x 8 x i8> %p, %0
%2 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
%3 = add <vscale x 8 x i8> %1, %2
ret <vscale x 8 x i8> %3
}
define <vscale x 8 x i8> @mul_stepvector_nxv8i8() {
; CHECK-LABEL: mul_stepvector_nxv8i8:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: index z0.h, #0, #2
; CHECK-NEXT: ret
entry:
%0 = insertelement <vscale x 8 x i8> poison, i8 2, i32 0
%1 = shufflevector <vscale x 8 x i8> %0, <vscale x 8 x i8> poison, <vscale x 8 x i32> zeroinitializer
%2 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
%3 = mul <vscale x 8 x i8> %2, %1
ret <vscale x 8 x i8> %3
}
define <vscale x 8 x i8> @shl_stepvector_nxv8i8() {
; CHECK-LABEL: shl_stepvector_nxv8i8:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: index z0.h, #0, #4
; CHECK-NEXT: ret
entry:
%0 = insertelement <vscale x 8 x i8> poison, i8 2, i32 0
%1 = shufflevector <vscale x 8 x i8> %0, <vscale x 8 x i8> poison, <vscale x 8 x i32> zeroinitializer
%2 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
%3 = shl <vscale x 8 x i8> %2, %1
ret <vscale x 8 x i8> %3
}
declare <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
declare <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
declare <vscale x 8 x i16> @llvm.experimental.stepvector.nxv8i16()