forked from OSchip/llvm-project
[SLP] reduce opcode API dependency in reduction cost calc; NFC
The icmp opcode is now hard-coded in the cost model call. This will make it easier to eventually remove all opcode queries for min/max patterns as we transition to intrinsics.
This commit is contained in:
parent
2040c1110b
commit
d1c4e859ce
|
@ -7058,12 +7058,10 @@ private:
|
|||
int getReductionCost(TargetTransformInfo *TTI, Value *FirstReducedVal,
|
||||
unsigned ReduxWidth) {
|
||||
Type *ScalarTy = FirstReducedVal->getType();
|
||||
auto *VecTy = FixedVectorType::get(ScalarTy, ReduxWidth);
|
||||
FixedVectorType *VectorTy = FixedVectorType::get(ScalarTy, ReduxWidth);
|
||||
|
||||
RecurKind Kind = RdxTreeInst.getKind();
|
||||
unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
|
||||
int SplittingRdxCost;
|
||||
int ScalarReduxCost;
|
||||
int VectorCost, ScalarCost;
|
||||
switch (Kind) {
|
||||
case RecurKind::Add:
|
||||
case RecurKind::Mul:
|
||||
|
@ -7071,22 +7069,24 @@ private:
|
|||
case RecurKind::And:
|
||||
case RecurKind::Xor:
|
||||
case RecurKind::FAdd:
|
||||
case RecurKind::FMul:
|
||||
SplittingRdxCost = TTI->getArithmeticReductionCost(
|
||||
RdxOpcode, VecTy, /*IsPairwiseForm=*/false);
|
||||
ScalarReduxCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy);
|
||||
case RecurKind::FMul: {
|
||||
unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
|
||||
VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
|
||||
/*IsPairwiseForm=*/false);
|
||||
ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy);
|
||||
break;
|
||||
}
|
||||
case RecurKind::SMax:
|
||||
case RecurKind::SMin:
|
||||
case RecurKind::UMax:
|
||||
case RecurKind::UMin: {
|
||||
auto *VecCondTy = cast<VectorType>(CmpInst::makeCmpResultType(VecTy));
|
||||
auto *VecCondTy = cast<VectorType>(CmpInst::makeCmpResultType(VectorTy));
|
||||
bool IsUnsigned = Kind == RecurKind::UMax || Kind == RecurKind::UMin;
|
||||
SplittingRdxCost =
|
||||
TTI->getMinMaxReductionCost(VecTy, VecCondTy,
|
||||
VectorCost =
|
||||
TTI->getMinMaxReductionCost(VectorTy, VecCondTy,
|
||||
/*IsPairwiseForm=*/false, IsUnsigned);
|
||||
ScalarReduxCost =
|
||||
TTI->getCmpSelInstrCost(RdxOpcode, ScalarTy) +
|
||||
ScalarCost =
|
||||
TTI->getCmpSelInstrCost(Instruction::ICmp, ScalarTy) +
|
||||
TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy,
|
||||
CmpInst::makeCmpResultType(ScalarTy));
|
||||
break;
|
||||
|
@ -7095,12 +7095,12 @@ private:
|
|||
llvm_unreachable("Expected arithmetic or min/max reduction operation");
|
||||
}
|
||||
|
||||
ScalarReduxCost *= (ReduxWidth - 1);
|
||||
LLVM_DEBUG(dbgs() << "SLP: Adding cost "
|
||||
<< SplittingRdxCost - ScalarReduxCost
|
||||
// Scalar cost is repeated for N-1 elements.
|
||||
ScalarCost *= (ReduxWidth - 1);
|
||||
LLVM_DEBUG(dbgs() << "SLP: Adding cost " << VectorCost - ScalarCost
|
||||
<< " for reduction that starts with " << *FirstReducedVal
|
||||
<< " (It is a splitting reduction)\n");
|
||||
return SplittingRdxCost - ScalarReduxCost;
|
||||
return VectorCost - ScalarCost;
|
||||
}
|
||||
|
||||
/// Emit a horizontal reduction of the vectorized value.
|
||||
|
|
Loading…
Reference in New Issue