[SLP] reduce opcode API dependency in reduction cost calc; NFC

The icmp opcode is now hard-coded in the cost model call.
This will make it easier to eventually remove all opcode
queries for min/max patterns as we transition to intrinsics.
This commit is contained in:
Sanjay Patel 2021-01-18 08:57:09 -05:00
parent 2040c1110b
commit d1c4e859ce
1 changed file with 17 additions and 17 deletions

View File

@@ -7058,12 +7058,10 @@ private:
int getReductionCost(TargetTransformInfo *TTI, Value *FirstReducedVal,
unsigned ReduxWidth) {
Type *ScalarTy = FirstReducedVal->getType();
auto *VecTy = FixedVectorType::get(ScalarTy, ReduxWidth);
FixedVectorType *VectorTy = FixedVectorType::get(ScalarTy, ReduxWidth);
RecurKind Kind = RdxTreeInst.getKind();
unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
int SplittingRdxCost;
int ScalarReduxCost;
int VectorCost, ScalarCost;
switch (Kind) {
case RecurKind::Add:
case RecurKind::Mul:
@@ -7071,22 +7069,24 @@ private:
case RecurKind::And:
case RecurKind::Xor:
case RecurKind::FAdd:
case RecurKind::FMul:
SplittingRdxCost = TTI->getArithmeticReductionCost(
RdxOpcode, VecTy, /*IsPairwiseForm=*/false);
ScalarReduxCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy);
case RecurKind::FMul: {
unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
/*IsPairwiseForm=*/false);
ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy);
break;
}
case RecurKind::SMax:
case RecurKind::SMin:
case RecurKind::UMax:
case RecurKind::UMin: {
auto *VecCondTy = cast<VectorType>(CmpInst::makeCmpResultType(VecTy));
auto *VecCondTy = cast<VectorType>(CmpInst::makeCmpResultType(VectorTy));
bool IsUnsigned = Kind == RecurKind::UMax || Kind == RecurKind::UMin;
SplittingRdxCost =
TTI->getMinMaxReductionCost(VecTy, VecCondTy,
VectorCost =
TTI->getMinMaxReductionCost(VectorTy, VecCondTy,
/*IsPairwiseForm=*/false, IsUnsigned);
ScalarReduxCost =
TTI->getCmpSelInstrCost(RdxOpcode, ScalarTy) +
ScalarCost =
TTI->getCmpSelInstrCost(Instruction::ICmp, ScalarTy) +
TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy,
CmpInst::makeCmpResultType(ScalarTy));
break;
@@ -7095,12 +7095,12 @@ private:
llvm_unreachable("Expected arithmetic or min/max reduction operation");
}
ScalarReduxCost *= (ReduxWidth - 1);
LLVM_DEBUG(dbgs() << "SLP: Adding cost "
<< SplittingRdxCost - ScalarReduxCost
// Scalar cost is repeated for N-1 elements.
ScalarCost *= (ReduxWidth - 1);
LLVM_DEBUG(dbgs() << "SLP: Adding cost " << VectorCost - ScalarCost
<< " for reduction that starts with " << *FirstReducedVal
<< " (It is a splitting reduction)\n");
return SplittingRdxCost - ScalarReduxCost;
return VectorCost - ScalarCost;
}
/// Emit a horizontal reduction of the vectorized value.