[SCEV] `getSequentialMinMaxExpr()`: rewrite deduplication to be fully recursive

Since we don't merge/expand non-sequential umin exprs into umin_seq exprs,
we may have umin_seq(umin(umin_seq())) chain, and the innermost umin_seq
can have duplicate operands still.
This commit is contained in:
Roman Lebedev 2022-01-14 15:11:20 +03:00
parent cd3ab156a7
commit c86a982d7d
No known key found for this signature in database
GPG Key ID: 083C3EBB4A1689E0
3 changed files with 141 additions and 43 deletions

View File

@ -531,6 +531,20 @@ protected:
public:
Type *getType() const { return getOperand(0)->getType(); }
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) {
assert(isSequentialMinMaxType(Ty));
switch (Ty) {
case scSequentialUMinExpr:
return scUMinExpr;
default:
llvm_unreachable("Not a sequential min/max type.");
}
}
SCEVTypes getEquivalentNonSequentialSCEVType() const {
return getEquivalentNonSequentialSCEVType(getSCEVType());
}
static bool classof(const SCEV *S) {
return isSequentialMinMaxType(S->getSCEVType());
}

View File

@ -3865,6 +3865,127 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
return S;
}
namespace {
class SCEVSequentialMinMaxDeduplicatingVisitor final
: public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
Optional<const SCEV *>> {
using RetVal = Optional<const SCEV *>;
using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
ScalarEvolution &SE;
const SCEVTypes RootKind; // Must be a sequential min/max expression.
const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
SmallPtrSet<const SCEV *, 16> SeenOps;
bool canRecurseInto(SCEVTypes Kind) const {
// We can only recurse into the SCEV expression of the same effective type
// as the type of our root SCEV expression.
return RootKind == Kind || NonSequentialRootKind == Kind;
};
RetVal visitAnyMinMaxExpr(const SCEV *S) {
assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
"Only for min/max expressions.");
SCEVTypes Kind = S->getSCEVType();
if (!canRecurseInto(Kind))
return S;
auto *NAry = cast<SCEVNAryExpr>(S);
SmallVector<const SCEV *> NewOps;
bool Changed =
visit(Kind, makeArrayRef(NAry->op_begin(), NAry->op_end()), NewOps);
if (!Changed)
return S;
if (NewOps.empty())
return None;
return isa<SCEVSequentialMinMaxExpr>(S)
? SE.getSequentialMinMaxExpr(Kind, NewOps)
: SE.getMinMaxExpr(Kind, NewOps);
}
RetVal visit(const SCEV *S) {
// Has the whole operand been seen already?
if (!SeenOps.insert(S).second)
return None;
return Base::visit(S);
}
public:
SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
SCEVTypes RootKind)
: SE(SE), RootKind(RootKind),
NonSequentialRootKind(
SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
RootKind)) {}
bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
SmallVectorImpl<const SCEV *> &NewOps) {
bool Changed = false;
SmallVector<const SCEV *> Ops;
Ops.reserve(OrigOps.size());
for (const SCEV *Op : OrigOps) {
RetVal NewOp = visit(Op);
if (NewOp != Op)
Changed = true;
if (NewOp)
Ops.emplace_back(*NewOp);
}
if (Changed)
NewOps = std::move(Ops);
return Changed;
}
RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
return visitAnyMinMaxExpr(Expr);
}
RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
return visitAnyMinMaxExpr(Expr);
}
RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
return visitAnyMinMaxExpr(Expr);
}
RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
return visitAnyMinMaxExpr(Expr);
}
RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
return visitAnyMinMaxExpr(Expr);
}
RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
};
} // namespace
const SCEV *
ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
SmallVectorImpl<const SCEV *> &Ops) {
@ -3895,45 +4016,8 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
// Keep only the first instance of an operand.
{
SmallPtrSet<const SCEV *, 16> SeenOps;
unsigned Idx = 0;
bool Changed = false;
while (Idx < Ops.size()) {
// Has the whole operand been seen already?
if (!SeenOps.insert(Ops[Idx]).second) {
Ops.erase(Ops.begin() + Idx);
Changed = true;
continue; // Look at operand under this index again.
}
// Look into non-sequential same-typed min/max expressions,
// drop any of it's operands that we have already seen.
// FIXME: once there are other sequential min/max types, generalize.
if (const auto *CommUMinExpr = dyn_cast<SCEVUMinExpr>(Ops[Idx])) {
SmallVector<const SCEV *> InnerOps;
InnerOps.reserve(CommUMinExpr->getNumOperands());
for (const SCEV *InnerOp : CommUMinExpr->operands()) {
if (SeenOps.insert(InnerOp).second) // Operand not seen before?
InnerOps.emplace_back(InnerOp); // Keep this inner operand.
}
// Were any operands of this 'umin' themselves redundant?
if (InnerOps.size() != CommUMinExpr->getNumOperands()) {
Changed = true;
// Was the whole operand effectively redundant? Note that it can
// happen even when the operand itself wasn't redundant as a whole.
if (InnerOps.empty()) {
Ops.erase(Ops.begin() + Idx);
continue; // Look at operand under this index again.
}
// Recreate our operand.
Ops[Idx] = getMinMaxExpr(Ops[Idx]->getSCEVType(), InnerOps);
}
}
// Ok, can't do anything else about this operand, move onto the next one.
++Idx;
}
SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
bool Changed = Deduplicator.visit(Kind, Ops, Ops);
if (Changed)
return getSequentialMinMaxExpr(Kind, Ops);
}

View File

@ -359,9 +359,9 @@ define i32 @logical_or_5ops_redundant_opearand_of_inner_uminseq(i32 %a, i32 %b,
; CHECK-NEXT: %cond_p4 = select i1 %cond_p3, i1 true, i1 %cond_p2
; CHECK-NEXT: --> %cond_p4 U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %first.loop: Variant }
; CHECK-NEXT: %i = phi i32 [ 0, %first.loop.exit ], [ %i.next, %loop ]
; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: (%a umin_seq %b umin_seq ((%e umin_seq %d umin_seq %a) umin %c umin %d)) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: (%a umin_seq %b umin_seq ((%e umin_seq %d) umin %c)) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: %i.next = add i32 %i, 1
; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: (1 + (%a umin_seq %b umin_seq ((%e umin_seq %d umin_seq %a) umin %c umin %d))) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: (1 + (%a umin_seq %b umin_seq ((%e umin_seq %d) umin %c))) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: %umin = call i32 @llvm.umin.i32(i32 %c, i32 %d)
; CHECK-NEXT: --> (%c umin %d) U: full-set S: full-set Exits: (%c umin %d) LoopDispositions: { %loop: Invariant }
; CHECK-NEXT: %umin2 = call i32 @llvm.umin.i32(i32 %umin, i32 %first.i)
@ -371,9 +371,9 @@ define i32 @logical_or_5ops_redundant_opearand_of_inner_uminseq(i32 %a, i32 %b,
; CHECK-NEXT: %cond = select i1 %cond_p8, i1 true, i1 %cond_p7
; CHECK-NEXT: --> %cond U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %loop: Variant }
; CHECK-NEXT: Determining loop execution counts for: @logical_or_5ops_redundant_opearand_of_inner_uminseq
; CHECK-NEXT: Loop %loop: backedge-taken count is (%a umin_seq %b umin_seq ((%e umin_seq %d umin_seq %a) umin %c umin %d))
; CHECK-NEXT: Loop %loop: backedge-taken count is (%a umin_seq %b umin_seq ((%e umin_seq %d) umin %c))
; CHECK-NEXT: Loop %loop: max backedge-taken count is -1
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is (%a umin_seq %b umin_seq ((%e umin_seq %d umin_seq %a) umin %c umin %d))
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is (%a umin_seq %b umin_seq ((%e umin_seq %d) umin %c))
; CHECK-NEXT: Predicates:
; CHECK: Loop %loop: Trip multiple is 1
; CHECK-NEXT: Loop %first.loop: backedge-taken count is (%e umin_seq %d umin_seq %a)