diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index 297cc8e55fb4..b7bcc34b3b3b 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -876,6 +876,10 @@ public: static ReductionKind matchVectorSplittingReduction( const ExtractElementInst *ReduxRoot, unsigned &Opcode, VectorType *&Ty); + static ReductionKind matchVectorReduction(const ExtractElementInst *ReduxRoot, + unsigned &Opcode, VectorType *&Ty, + bool &IsPairwise); + /// Additional information about an operand's possible values. enum OperandValueKind { OK_AnyValue, // Operand can have any value. diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index ebd1beb6e39e..22708d073a1b 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -1004,41 +1004,23 @@ public: if (CI) Idx = CI->getZExtValue(); - // Try to match a reduction sequence (series of shufflevector and - // vector adds followed by a extractelement). - unsigned ReduxOpCode; - VectorType *ReduxType; - - switch (TTI::matchVectorSplittingReduction(EEI, ReduxOpCode, - ReduxType)) { + // Try to match a reduction (a series of shufflevector and vector ops + // followed by an extractelement). + unsigned RdxOpcode; + VectorType *RdxType; + bool IsPairwise; + switch (TTI::matchVectorReduction(EEI, RdxOpcode, RdxType, IsPairwise)) { case TTI::RK_Arithmetic: - return TargetTTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, - /*IsPairwiseForm=*/false, - CostKind); + return TargetTTI->getArithmeticReductionCost(RdxOpcode, RdxType, + IsPairwise, CostKind); case TTI::RK_MinMax: return TargetTTI->getMinMaxReductionCost( - ReduxType, cast(CmpInst::makeCmpResultType(ReduxType)), - /*IsPairwiseForm=*/false, /*IsUnsigned=*/false, CostKind); + RdxType, cast(CmpInst::makeCmpResultType(RdxType)), + IsPairwise, /*IsUnsigned=*/false, CostKind); case TTI::RK_UnsignedMinMax: return TargetTTI->getMinMaxReductionCost( - ReduxType, cast(CmpInst::makeCmpResultType(ReduxType)), - /*IsPairwiseForm=*/false, /*IsUnsigned=*/true, CostKind); - case TTI::RK_None: - break; - } - - switch (TTI::matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) { - case TTI::RK_Arithmetic: - return TargetTTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, - /*IsPairwiseForm=*/true, CostKind); - case TTI::RK_MinMax: - return TargetTTI->getMinMaxReductionCost( - ReduxType, cast(CmpInst::makeCmpResultType(ReduxType)), - /*IsPairwiseForm=*/true, /*IsUnsigned=*/false, CostKind); - case TTI::RK_UnsignedMinMax: - return TargetTTI->getMinMaxReductionCost( - ReduxType, cast(CmpInst::makeCmpResultType(ReduxType)), - /*IsPairwiseForm=*/true, /*IsUnsigned=*/true, CostKind); + RdxType, cast(CmpInst::makeCmpResultType(RdxType)), + IsPairwise, /*IsUnsigned=*/true, CostKind); case TTI::RK_None: break; } diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index efecb4501853..4836d80ddb2d 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -1308,6 +1308,18 @@ TTI::ReductionKind TTI::matchVectorSplittingReduction( return RD->Kind; } +TTI::ReductionKind +TTI::matchVectorReduction(const ExtractElementInst *Root, unsigned &Opcode, + VectorType *&Ty, bool &IsPairwise) { + TTI::ReductionKind RdxKind = matchVectorSplittingReduction(Root, Opcode, Ty); + if (RdxKind != TTI::ReductionKind::RK_None) { + IsPairwise = false; + return RdxKind; + } + IsPairwise = true; + return matchPairwiseReduction(Root, Opcode, Ty); +} + int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const { TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;