[LoopUtils] reduce FMF and min/max complexity when forming reductions

I don't know whether this changes what the vectorizers may produce
for reductions, but I have added test coverage in commits 3567908
and 5ced712 to show that both passes already have bugs in this
area. Hopefully, this does not make things worse before we can
properly fix those bugs.
This commit is contained in:
Sanjay Patel 2020-12-30 15:22:26 -05:00
parent 5ced712e98
commit 8ca60db40b
4 changed files with 85 additions and 87 deletions

View File

@ -365,24 +365,21 @@ Value *getShuffleReduction(IRBuilderBase &Builder, Value *Src, unsigned Op,
/// Create a target reduction of the given vector. The reduction operation
/// is described by the \p Opcode parameter. min/max reductions require
/// additional information supplied in \p Flags.
/// additional information supplied in \p MinMaxKind.
/// The target is queried to determine if intrinsics or shuffle sequences are
/// required to implement the reduction.
/// Fast-math-flags are propagated using the IRBuilder's setting.
Value *createSimpleTargetReduction(IRBuilderBase &B,
const TargetTransformInfo *TTI,
unsigned Opcode, Value *Src,
TargetTransformInfo::ReductionFlags Flags =
TargetTransformInfo::ReductionFlags(),
ArrayRef<Value *> RedOps = None);
Value *createSimpleTargetReduction(
IRBuilderBase &B, const TargetTransformInfo *TTI, unsigned Opcode,
Value *Src, RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind,
ArrayRef<Value *> RedOps = None);
/// Create a generic target reduction using a recurrence descriptor \p Desc.
/// The target is queried to determine if intrinsics or shuffle sequences are
/// required to implement the reduction.
/// Fast-math-flags are propagated using the RecurrenceDescriptor.
Value *createTargetReduction(IRBuilderBase &B, const TargetTransformInfo *TTI,
RecurrenceDescriptor &Desc, Value *Src,
bool NoNaN = false);
RecurrenceDescriptor &Desc, Value *Src);
/// Get the intersection (logical and) of all of the potential IR flags
/// of each scalar operation (VL) that will be converted into a vector (I).

View File

@ -985,14 +985,12 @@ llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src, unsigned Op,
/// flags (if generating min/max reductions).
Value *llvm::createSimpleTargetReduction(
IRBuilderBase &Builder, const TargetTransformInfo *TTI, unsigned Opcode,
Value *Src, TargetTransformInfo::ReductionFlags Flags,
Value *Src, RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind,
ArrayRef<Value *> RedOps) {
auto *SrcVTy = cast<VectorType>(Src->getType());
std::function<Value *()> BuildFunc;
using RD = RecurrenceDescriptor;
RD::MinMaxRecurrenceKind MinMaxKind = RD::MRK_Invalid;
switch (Opcode) {
case Instruction::Add:
BuildFunc = [&]() { return Builder.CreateAddReduce(Src); };
@ -1024,33 +1022,42 @@ Value *llvm::createSimpleTargetReduction(
};
break;
case Instruction::ICmp:
if (Flags.IsMaxOp) {
MinMaxKind = Flags.IsSigned ? RD::MRK_SIntMax : RD::MRK_UIntMax;
BuildFunc = [&]() {
return Builder.CreateIntMaxReduce(Src, Flags.IsSigned);
};
} else {
MinMaxKind = Flags.IsSigned ? RD::MRK_SIntMin : RD::MRK_UIntMin;
BuildFunc = [&]() {
return Builder.CreateIntMinReduce(Src, Flags.IsSigned);
};
switch (MinMaxKind) {
case RD::MRK_SIntMax:
BuildFunc = [&]() { return Builder.CreateIntMaxReduce(Src, true); };
break;
case RD::MRK_SIntMin:
BuildFunc = [&]() { return Builder.CreateIntMinReduce(Src, true); };
break;
case RD::MRK_UIntMax:
BuildFunc = [&]() { return Builder.CreateIntMaxReduce(Src, false); };
break;
case RD::MRK_UIntMin:
BuildFunc = [&]() { return Builder.CreateIntMinReduce(Src, false); };
break;
default:
llvm_unreachable("Unexpected min/max reduction type");
}
break;
case Instruction::FCmp:
if (Flags.IsMaxOp) {
MinMaxKind = RD::MRK_FloatMax;
assert((MinMaxKind == RD::MRK_FloatMax || MinMaxKind == RD::MRK_FloatMin) &&
"Unexpected min/max reduction type");
if (MinMaxKind == RD::MRK_FloatMax)
BuildFunc = [&]() { return Builder.CreateFPMaxReduce(Src); };
} else {
MinMaxKind = RD::MRK_FloatMin;
else
BuildFunc = [&]() { return Builder.CreateFPMinReduce(Src); };
}
break;
default:
llvm_unreachable("Unhandled opcode");
break;
}
TargetTransformInfo::ReductionFlags RdxFlags;
RdxFlags.IsMaxOp = MinMaxKind == RD::MRK_SIntMax ||
MinMaxKind == RD::MRK_UIntMax ||
MinMaxKind == RD::MRK_FloatMax;
RdxFlags.IsSigned =
MinMaxKind == RD::MRK_SIntMax || MinMaxKind == RD::MRK_SIntMin;
if (ForceReductionIntrinsic ||
TTI->useReductionIntrinsic(Opcode, Src->getType(), Flags))
TTI->useReductionIntrinsic(Opcode, Src->getType(), RdxFlags))
return BuildFunc();
return getShuffleReduction(Builder, Src, Opcode, MinMaxKind, RedOps);
}
@ -1058,12 +1065,9 @@ Value *llvm::createSimpleTargetReduction(
/// Create a vector reduction using a given recurrence descriptor.
Value *llvm::createTargetReduction(IRBuilderBase &B,
const TargetTransformInfo *TTI,
RecurrenceDescriptor &Desc, Value *Src,
bool NoNaN) {
RecurrenceDescriptor &Desc, Value *Src) {
// TODO: Support in-order reductions based on the recurrence descriptor.
using RD = RecurrenceDescriptor;
TargetTransformInfo::ReductionFlags Flags;
Flags.NoNaN = NoNaN;
// All ops in the reduction inherit fast-math-flags from the recurrence
// descriptor.
@ -1071,11 +1075,8 @@ Value *llvm::createTargetReduction(IRBuilderBase &B,
B.setFastMathFlags(Desc.getFastMathFlags());
RD::MinMaxRecurrenceKind MMKind = Desc.getMinMaxRecurrenceKind();
Flags.IsMaxOp = MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_UIntMax ||
MMKind == RD::MRK_FloatMax;
Flags.IsSigned = MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_SIntMin;
return createSimpleTargetReduction(B, TTI, Desc.getRecurrenceBinOp(), Src,
Flags);
MMKind);
}
void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue) {

View File

@ -4325,9 +4325,8 @@ void InnerLoopVectorizer::fixReduction(PHINode *Phi) {
// Create the reduction after the loop. Note that in-loop reductions create the
// target reduction in the loop using a Reduction recipe.
if (VF.isVector() && !IsInLoopReductionPhi) {
bool NoNaN = Legal->hasFunNoNaNAttr();
ReducedPartRdx =
createTargetReduction(Builder, TTI, RdxDesc, ReducedPartRdx, NoNaN);
createTargetReduction(Builder, TTI, RdxDesc, ReducedPartRdx);
// If the reduction can be performed in a smaller type, we need to extend
// the reduction to the wider type before we branch to the original loop.
if (Phi->getType() != RdxDesc.getRecurrenceType())
@ -8783,7 +8782,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
NewVecOp = Select;
}
Value *NewRed =
createTargetReduction(State.Builder, TTI, *RdxDesc, NewVecOp, NoNaN);
createTargetReduction(State.Builder, TTI, *RdxDesc, NewVecOp);
Value *PrevInChain = State.get(getChainOp(), Part);
Value *NextInChain;
if (Kind == RecurrenceDescriptor::RK_IntegerMinMax ||

View File

@ -6455,7 +6455,7 @@ class HorizontalReduction {
/// Kind of the reduction operation.
RD::RecurrenceKind Kind = RD::RK_NoRecurrence;
TargetTransformInfo::ReductionFlags RdxFlags;
RD::MinMaxRecurrenceKind MMKind = RD::MRK_Invalid;
/// Checks if the reduction operation can be vectorized.
bool isVectorizable() const {
@ -6499,10 +6499,13 @@ class HorizontalReduction {
case RD::RK_IntegerMinMax: {
assert(Opcode == Instruction::ICmp && "Expected integer types.");
ICmpInst::Predicate Pred;
if (RdxFlags.IsMaxOp)
Pred = RdxFlags.IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
else
Pred = RdxFlags.IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
switch (MMKind) {
case RD::MRK_SIntMax: Pred = ICmpInst::ICMP_SGT; break;
case RD::MRK_SIntMin: Pred = ICmpInst::ICMP_SLT; break;
case RD::MRK_UIntMax: Pred = ICmpInst::ICMP_UGT; break;
case RD::MRK_UIntMin: Pred = ICmpInst::ICMP_ULT; break;
default: llvm_unreachable("Unexpected min/max value");
}
Value *Cmp = Builder.CreateICmp(Pred, LHS, RHS, Name);
return Builder.CreateSelect(Cmp, LHS, RHS, Name);
}
@ -6521,9 +6524,9 @@ class HorizontalReduction {
}
/// Constructor for reduction operations with opcode and type.
OperationData(unsigned Opcode, RD::RecurrenceKind Kind,
TargetTransformInfo::ReductionFlags Flags)
: Opcode(Opcode), Kind(Kind), RdxFlags(Flags) {
OperationData(unsigned Opcode, RD::RecurrenceKind RdxKind,
RD::MinMaxRecurrenceKind MinMaxKind)
: Opcode(Opcode), Kind(RdxKind), MMKind(MinMaxKind) {
assert(Kind != RD::RK_NoRecurrence && "Expected reduction operation.");
}
@ -6640,6 +6643,7 @@ class HorizontalReduction {
/// Get kind of reduction data.
RD::RecurrenceKind getKind() const { return Kind; }
RD::MinMaxRecurrenceKind getMinMaxKind() const { return MMKind; }
Value *getLHS(Instruction *I) const {
if (Kind == RD::RK_NoRecurrence)
return nullptr;
@ -6706,8 +6710,6 @@ class HorizontalReduction {
llvm_unreachable("Unknown reduction operation.");
}
}
TargetTransformInfo::ReductionFlags getFlags() const { return RdxFlags; }
};
WeakTrackingVH ReductionRoot;
@ -6749,28 +6751,32 @@ class HorizontalReduction {
TargetTransformInfo::ReductionFlags RdxFlags;
if (match(I, m_Add(m_Value(), m_Value())))
return OperationData(I->getOpcode(), RD::RK_IntegerAdd, RdxFlags);
return OperationData(I->getOpcode(), RD::RK_IntegerAdd, RD::MRK_Invalid);
if (match(I, m_Mul(m_Value(), m_Value())))
return OperationData(I->getOpcode(), RD::RK_IntegerMult, RdxFlags);
return OperationData(I->getOpcode(), RD::RK_IntegerMult, RD::MRK_Invalid);
if (match(I, m_And(m_Value(), m_Value())))
return OperationData(I->getOpcode(), RD::RK_IntegerAnd, RdxFlags);
return OperationData(I->getOpcode(), RD::RK_IntegerAnd, RD::MRK_Invalid);
if (match(I, m_Or(m_Value(), m_Value())))
return OperationData(I->getOpcode(), RD::RK_IntegerOr, RdxFlags);
return OperationData(I->getOpcode(), RD::RK_IntegerOr, RD::MRK_Invalid);
if (match(I, m_Xor(m_Value(), m_Value())))
return OperationData(I->getOpcode(), RD::RK_IntegerXor, RdxFlags);
return OperationData(I->getOpcode(), RD::RK_IntegerXor, RD::MRK_Invalid);
if (match(I, m_FAdd(m_Value(), m_Value())))
return OperationData(I->getOpcode(), RD::RK_FloatAdd, RdxFlags);
return OperationData(I->getOpcode(), RD::RK_FloatAdd, RD::MRK_Invalid);
if (match(I, m_FMul(m_Value(), m_Value())))
return OperationData(I->getOpcode(), RD::RK_FloatMult, RdxFlags);
if (match(I, m_MaxOrMin(m_Value(), m_Value()))) {
RdxFlags.IsMaxOp = match(I, m_UMax(m_Value(), m_Value())) ||
match(I, m_SMax(m_Value(), m_Value()));
RdxFlags.IsSigned = match(I, m_SMin(m_Value(), m_Value())) ||
match(I, m_SMax(m_Value(), m_Value()));
return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax, RdxFlags);
}
return OperationData(I->getOpcode(), RD::RK_FloatMult, RD::MRK_Invalid);
if (match(I, m_SMax(m_Value(), m_Value())))
return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax,
RD::MRK_SIntMax);
if (match(I, m_SMin(m_Value(), m_Value())))
return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax,
RD::MRK_SIntMin);
if (match(I, m_UMax(m_Value(), m_Value())))
return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax,
RD::MRK_UIntMax);
if (match(I, m_UMin(m_Value(), m_Value())))
return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax,
RD::MRK_UIntMin);
if (auto *Select = dyn_cast<SelectInst>(I)) {
// Try harder: look for min/max pattern based on instructions producing
@ -6814,28 +6820,23 @@ class HorizontalReduction {
switch (Pred) {
default:
return OperationData(*I);
case CmpInst::ICMP_ULT:
case CmpInst::ICMP_ULE:
RdxFlags.IsMaxOp = false;
RdxFlags.IsSigned = false;
break;
case CmpInst::ICMP_SLT:
case CmpInst::ICMP_SLE:
RdxFlags.IsMaxOp = false;
RdxFlags.IsSigned = true;
break;
case CmpInst::ICMP_UGT:
case CmpInst::ICMP_UGE:
RdxFlags.IsMaxOp = true;
RdxFlags.IsSigned = false;
break;
case CmpInst::ICMP_SGT:
case CmpInst::ICMP_SGE:
RdxFlags.IsMaxOp = true;
RdxFlags.IsSigned = true;
break;
return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax,
RD::MRK_SIntMax);
case CmpInst::ICMP_SLT:
case CmpInst::ICMP_SLE:
return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax,
RD::MRK_SIntMin);
case CmpInst::ICMP_UGT:
case CmpInst::ICMP_UGE:
return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax,
RD::MRK_UIntMax);
case CmpInst::ICMP_ULT:
case CmpInst::ICMP_ULE:
return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax,
RD::MRK_UIntMin);
}
return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax, RdxFlags);
}
return OperationData(*I);
}
@ -7186,8 +7187,8 @@ private:
break;
case RD::RK_IntegerMinMax: {
auto *VecCondTy = cast<VectorType>(CmpInst::makeCmpResultType(VecTy));
bool IsUnsigned = !RdxTreeInst.getFlags().IsSigned;
RD::MinMaxRecurrenceKind MMKind = RdxTreeInst.getMinMaxKind();
bool IsUnsigned = MMKind == RD::MRK_UIntMax || MMKind == RD::MRK_UIntMin;
PairwiseRdxCost =
TTI->getMinMaxReductionCost(VecTy, VecCondTy,
/*IsPairwiseForm=*/true, IsUnsigned);
@ -7248,7 +7249,7 @@ private:
assert(Builder.getFastMathFlags().isFast() && "Expected 'fast' FMF");
return createSimpleTargetReduction(
Builder, TTI, RdxTreeInst.getOpcode(), VectorizedValue,
RdxTreeInst.getFlags(), ReductionOps.back());
RdxTreeInst.getMinMaxKind(), ReductionOps.back());
}
Value *TmpVec = VectorizedValue;