[InstCombine] Add more complex folds for extractelement + stepvector

I have updated cheapToScalarize to also consider the case where we are
extracting lanes from a stepvector intrinsic. This required removing
the existing 'bool IsConstantExtractIndex' argument and passing in the
actual index as a Value instead. We do this because we need to know
whether the index is less than the known minimum number of elements
returned by the stepvector intrinsic. Effectively, when extracting lane
X from a stepvector we know the value returned is also X, provided X is
below that known minimum.
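
For illustration, here is a rough sketch of the rewrite this enables, using
hypothetical IR value names (it mirrors the new add test below):

  ; Sketch only, not one of the committed tests:
  ;   %add = add <vscale x 2 x i64> %splat_of_i, %stepvec
  ;   %res = extractelement <vscale x 2 x i64> %add, i32 0
  ; Since the add is now considered cheap to scalarize, the extract becomes
  ;   %e0 = extractelement <vscale x 2 x i64> %splat_of_i, i32 0  ; folds to %i
  ;   %e1 = extractelement <vscale x 2 x i64> %stepvec, i32 0     ; folds to 0
  ;   %res = add i64 %e0, %e1                                     ; folds to %i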

New tests added here:

  Transforms/InstCombine/vscale_extractelement.ll

Differential Revision: https://reviews.llvm.org/D106358
Author: David Sherwood
Date:   2021-07-19 15:18:05 +01:00
Commit: ce394161cb (parent 0831f8bf79)

2 changed files with 50 additions and 13 deletions

--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp

@@ -52,20 +52,29 @@ STATISTIC(NumAggregateReconstructionsSimplified,
           "original aggregate");
 
 /// Return true if the value is cheaper to scalarize than it is to leave as a
-/// vector operation. IsConstantExtractIndex indicates whether we are extracting
-/// one known element from a vector constant.
+/// vector operation. If the extract index \p EI is a constant integer then
+/// some operations may be cheap to scalarize.
 ///
 /// FIXME: It's possible to create more instructions than previously existed.
-static bool cheapToScalarize(Value *V, bool IsConstantExtractIndex) {
+static bool cheapToScalarize(Value *V, Value *EI) {
+  ConstantInt *CEI = dyn_cast<ConstantInt>(EI);
+
   // If we can pick a scalar constant value out of a vector, that is free.
   if (auto *C = dyn_cast<Constant>(V))
-    return IsConstantExtractIndex || C->getSplatValue();
+    return CEI || C->getSplatValue();
+
+  if (CEI && match(V, m_Intrinsic<Intrinsic::experimental_stepvector>())) {
+    ElementCount EC = cast<VectorType>(V->getType())->getElementCount();
+    // Index needs to be lower than the minimum size of the vector, because
+    // for scalable vector, the vector size is known at run time.
+    return CEI->getValue().ult(EC.getKnownMinValue());
+  }
 
   // An insertelement to the same constant index as our extract will simplify
   // to the scalar inserted element. An insertelement to a different constant
   // index is irrelevant to our extract.
   if (match(V, m_InsertElt(m_Value(), m_Value(), m_ConstantInt())))
-    return IsConstantExtractIndex;
+    return CEI;
 
   if (match(V, m_OneUse(m_Load(m_Value()))))
     return true;
@@ -75,14 +84,12 @@ static bool cheapToScalarize(Value *V, bool IsConstantExtractIndex) {
 
   Value *V0, *V1;
   if (match(V, m_OneUse(m_BinOp(m_Value(V0), m_Value(V1)))))
-    if (cheapToScalarize(V0, IsConstantExtractIndex) ||
-        cheapToScalarize(V1, IsConstantExtractIndex))
+    if (cheapToScalarize(V0, EI) || cheapToScalarize(V1, EI))
       return true;
 
   CmpInst::Predicate UnusedPred;
   if (match(V, m_OneUse(m_Cmp(UnusedPred, m_Value(V0), m_Value(V1)))))
-    if (cheapToScalarize(V0, IsConstantExtractIndex) ||
-        cheapToScalarize(V1, IsConstantExtractIndex))
+    if (cheapToScalarize(V0, EI) || cheapToScalarize(V1, EI))
       return true;
 
   return false;
@@ -119,7 +126,8 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI,
   // and that it is a binary operation which is cheap to scalarize.
   // otherwise return nullptr.
   if (!PHIUser->hasOneUse() || !(PHIUser->user_back() == PN) ||
-      !(isa<BinaryOperator>(PHIUser)) || !cheapToScalarize(PHIUser, true))
+      !(isa<BinaryOperator>(PHIUser)) ||
+      !cheapToScalarize(PHIUser, EI.getIndexOperand()))
     return nullptr;
 
   // Create a scalar PHI node that will replace the vector PHI node
@@ -415,7 +423,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
   // TODO come up with a n-ary matcher that subsumes both unary and
   // binary matchers.
   UnaryOperator *UO;
-  if (match(SrcVec, m_UnOp(UO)) && cheapToScalarize(SrcVec, IndexC)) {
+  if (match(SrcVec, m_UnOp(UO)) && cheapToScalarize(SrcVec, Index)) {
     // extelt (unop X), Index --> unop (extelt X, Index)
     Value *X = UO->getOperand(0);
     Value *E = Builder.CreateExtractElement(X, Index);
@@ -423,7 +431,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
   }
 
   BinaryOperator *BO;
-  if (match(SrcVec, m_BinOp(BO)) && cheapToScalarize(SrcVec, IndexC)) {
+  if (match(SrcVec, m_BinOp(BO)) && cheapToScalarize(SrcVec, Index)) {
     // extelt (binop X, Y), Index --> binop (extelt X, Index), (extelt Y, Index)
     Value *X = BO->getOperand(0), *Y = BO->getOperand(1);
     Value *E0 = Builder.CreateExtractElement(X, Index);
@@ -434,7 +442,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
   Value *X, *Y;
   CmpInst::Predicate Pred;
   if (match(SrcVec, m_Cmp(Pred, m_Value(X), m_Value(Y))) &&
-      cheapToScalarize(SrcVec, IndexC)) {
+      cheapToScalarize(SrcVec, Index)) {
     // extelt (cmp X, Y), Index --> cmp (extelt X, Index), (extelt Y, Index)
     Value *E0 = Builder.CreateExtractElement(X, Index);
     Value *E1 = Builder.CreateExtractElement(Y, Index);

--- a/llvm/test/Transforms/InstCombine/vscale_extractelement.ll
+++ b/llvm/test/Transforms/InstCombine/vscale_extractelement.ll

@@ -243,6 +243,35 @@ entry:
   ret i8 %1
 }
+
+; Check that we can extract more complex cases where the stepvector is
+; involved in a binary operation prior to the lane being extracted.
+define i64 @ext_lane0_from_add_with_stepvec(i64 %i) {
+; CHECK-LABEL: @ext_lane0_from_add_with_stepvec(
+; CHECK-NEXT:    ret i64 [[I:%.*]]
+;
+  %tmp = insertelement <vscale x 2 x i64> poison, i64 %i, i32 0
+  %splatofi = shufflevector <vscale x 2 x i64> %tmp, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+  %stepvec = call <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
+  %add = add <vscale x 2 x i64> %splatofi, %stepvec
+  %res = extractelement <vscale x 2 x i64> %add, i32 0
+  ret i64 %res
+}
+
+define i1 @ext_lane1_from_cmp_with_stepvec(i64 %i) {
+; CHECK-LABEL: @ext_lane1_from_cmp_with_stepvec(
+; CHECK-NEXT:    [[RES:%.*]] = icmp eq i64 [[I:%.*]], 1
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+  %tmp = insertelement <vscale x 2 x i64> poison, i64 %i, i32 0
+  %splatofi = shufflevector <vscale x 2 x i64> %tmp, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+  %stepvec = call <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
+  %cmp = icmp eq <vscale x 2 x i64> %splatofi, %stepvec
+  %res = extractelement <vscale x 2 x i1> %cmp, i32 1
+  ret i1 %res
+}
+
+declare <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
 
 declare <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
 declare <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
 declare <vscale x 512 x i8> @llvm.experimental.stepvector.nxv512i8()
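
For reference, the new tests are exercised by the RUN line at the top of
vscale_extractelement.ll (outside this hunk), which at the time looked
roughly like:

  ; RUN: opt < %s -instcombine -S | FileCheck %s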