[PSE] Remove assumption that top level predicate is union from public interface [NFC*]

Note that this doesn't actually cause the top level predicate to become a non-union just yet.

The * above comes from a case in the LoopVectorizer where a predicate which is later proven no longer blocks vectorization due to a change from checking if predicates exists to whether the predicate is possibly false.
This commit is contained in:
Philip Reames 2022-02-10 15:52:13 -08:00
parent ecbcefd693
commit 5ba115031d
9 changed files with 18 additions and 18 deletions

View File

@ -2199,7 +2199,7 @@ class PredicatedScalarEvolution {
public: public:
PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L); PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L);
const SCEVUnionPredicate &getUnionPredicate() const; const SCEVPredicate &getPredicate() const;
/// Returns the SCEV expression of V, in the context of the current SCEV /// Returns the SCEV expression of V, in the context of the current SCEV
/// predicate. The order of transformations applied on the expression of V /// predicate. The order of transformations applied on the expression of V

View File

@ -123,7 +123,7 @@ private:
SmallVector<RuntimePointerCheck, 4> AliasChecks; SmallVector<RuntimePointerCheck, 4> AliasChecks;
/// The set of SCEV checks that we are versioning for. /// The set of SCEV checks that we are versioning for.
const SCEVUnionPredicate &Preds; const SCEVPredicate &Preds;
/// Maps a pointer to the pointer checking group that the pointer /// Maps a pointer to the pointer checking group that the pointer
/// belongs to. /// belongs to.

View File

@ -2342,7 +2342,7 @@ void LoopAccessInfo::print(raw_ostream &OS, unsigned Depth) const {
<< "found in loop.\n"; << "found in loop.\n";
OS.indent(Depth) << "SCEV assumptions:\n"; OS.indent(Depth) << "SCEV assumptions:\n";
PSE->getUnionPredicate().print(OS, Depth); PSE->getPredicate().print(OS, Depth);
OS << "\n"; OS << "\n";

View File

@ -13975,7 +13975,7 @@ void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
updateGeneration(); updateGeneration();
} }
const SCEVUnionPredicate &PredicatedScalarEvolution::getUnionPredicate() const { const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
return *Preds; return *Preds;
} }

View File

@ -770,7 +770,7 @@ public:
// Don't distribute the loop if we need too many SCEV run-time checks, or // Don't distribute the loop if we need too many SCEV run-time checks, or
// any if it's illegal. // any if it's illegal.
const SCEVUnionPredicate &Pred = LAI->getPSE().getUnionPredicate(); const SCEVPredicate &Pred = LAI->getPSE().getPredicate();
if (LAI->hasConvergentOp() && !Pred.isAlwaysTrue()) { if (LAI->hasConvergentOp() && !Pred.isAlwaysTrue()) {
return fail("RuntimeCheckWithConvergent", return fail("RuntimeCheckWithConvergent",
"may not insert runtime check with convergent operation"); "may not insert runtime check with convergent operation");

View File

@ -529,7 +529,7 @@ public:
return false; return false;
} }
if (LAI.getPSE().getUnionPredicate().getComplexity() > if (LAI.getPSE().getPredicate().getComplexity() >
LoadElimSCEVCheckThreshold) { LoadElimSCEVCheckThreshold) {
LLVM_DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n"); LLVM_DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n");
return false; return false;
@ -540,7 +540,7 @@ public:
return false; return false;
} }
if (!Checks.empty() || !LAI.getPSE().getUnionPredicate().isAlwaysTrue()) { if (!Checks.empty() || !LAI.getPSE().getPredicate().isAlwaysTrue()) {
if (LAI.hasConvergentOp()) { if (LAI.hasConvergentOp()) {
LLVM_DEBUG(dbgs() << "Versioning is needed but not allowed with " LLVM_DEBUG(dbgs() << "Versioning is needed but not allowed with "
"convergent calls\n"); "convergent calls\n");

View File

@ -42,7 +42,7 @@ LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI,
LoopInfo *LI, DominatorTree *DT, LoopInfo *LI, DominatorTree *DT,
ScalarEvolution *SE) ScalarEvolution *SE)
: VersionedLoop(L), AliasChecks(Checks.begin(), Checks.end()), : VersionedLoop(L), AliasChecks(Checks.begin(), Checks.end()),
Preds(LAI.getPSE().getUnionPredicate()), LAI(LAI), LI(LI), DT(DT), Preds(LAI.getPSE().getPredicate()), LAI(LAI), LI(LI), DT(DT),
SE(SE) { SE(SE) {
} }
@ -276,7 +276,7 @@ bool runImpl(LoopInfo *LI, function_ref<const LoopAccessInfo &(Loop &)> GetLAA,
const LoopAccessInfo &LAI = GetLAA(*L); const LoopAccessInfo &LAI = GetLAA(*L);
if (!LAI.hasConvergentOp() && if (!LAI.hasConvergentOp() &&
(LAI.getNumRuntimePointerChecks() || (LAI.getNumRuntimePointerChecks() ||
!LAI.getPSE().getUnionPredicate().isAlwaysTrue())) { !LAI.getPSE().getPredicate().isAlwaysTrue())) {
LoopVersioning LVer(LAI, LAI.getRuntimePointerChecking()->getChecks(), L, LoopVersioning LVer(LAI, LAI.getRuntimePointerChecking()->getChecks(), L,
LI, DT, SE); LI, DT, SE);
LVer.versionLoop(); LVer.versionLoop();

View File

@ -572,7 +572,7 @@ void LoopVectorizationLegality::addInductionPhi(
// on predicates that only hold within the loop, since allowing the exit // on predicates that only hold within the loop, since allowing the exit
// currently means re-using this SCEV outside the loop (see PR33706 for more // currently means re-using this SCEV outside the loop (see PR33706 for more
// details). // details).
if (PSE.getUnionPredicate().isAlwaysTrue()) { if (PSE.getPredicate().isAlwaysTrue()) {
AllowedExit.insert(Phi); AllowedExit.insert(Phi);
AllowedExit.insert(Phi->getIncomingValueForBlock(TheLoop->getLoopLatch())); AllowedExit.insert(Phi->getIncomingValueForBlock(TheLoop->getLoopLatch()));
} }
@ -849,7 +849,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
// used outside the loop only if the SCEV predicates within the loop is // used outside the loop only if the SCEV predicates within the loop is
// same as outside the loop. Allowing the exit means reusing the SCEV // same as outside the loop. Allowing the exit means reusing the SCEV
// outside the loop. // outside the loop.
if (PSE.getUnionPredicate().isAlwaysTrue()) { if (PSE.getPredicate().isAlwaysTrue()) {
AllowedExit.insert(&I); AllowedExit.insert(&I);
continue; continue;
} }
@ -919,7 +919,7 @@ bool LoopVectorizationLegality::canVectorizeMemory() {
} }
Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks()); Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks());
PSE.addPredicate(LAI->getPSE().getUnionPredicate()); PSE.addPredicate(LAI->getPSE().getPredicate());
return true; return true;
} }
@ -1266,7 +1266,7 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) {
if (Hints->getForce() == LoopVectorizeHints::FK_Enabled) if (Hints->getForce() == LoopVectorizeHints::FK_Enabled)
SCEVThreshold = PragmaVectorizeSCEVCheckThreshold; SCEVThreshold = PragmaVectorizeSCEVCheckThreshold;
if (PSE.getUnionPredicate().getComplexity() > SCEVThreshold) { if (PSE.getPredicate().getComplexity() > SCEVThreshold) {
reportVectorizationFailure("Too many SCEV checks needed", reportVectorizationFailure("Too many SCEV checks needed",
"Too many SCEV assumptions need to be made and checked at runtime", "Too many SCEV assumptions need to be made and checked at runtime",
"TooManySCEVRunTimeChecks", ORE, TheLoop); "TooManySCEVRunTimeChecks", ORE, TheLoop);

View File

@ -1998,7 +1998,7 @@ public:
/// there is no vector code generation, the check blocks are removed /// there is no vector code generation, the check blocks are removed
/// completely. /// completely.
void Create(Loop *L, const LoopAccessInfo &LAI, void Create(Loop *L, const LoopAccessInfo &LAI,
const SCEVUnionPredicate &UnionPred) { const SCEVPredicate &Pred) {
BasicBlock *LoopHeader = L->getHeader(); BasicBlock *LoopHeader = L->getHeader();
BasicBlock *Preheader = L->getLoopPreheader(); BasicBlock *Preheader = L->getLoopPreheader();
@ -2007,12 +2007,12 @@ public:
// ensure the blocks are properly added to LoopInfo & DominatorTree. Those // ensure the blocks are properly added to LoopInfo & DominatorTree. Those
// may be used by SCEVExpander. The blocks will be un-linked from their // may be used by SCEVExpander. The blocks will be un-linked from their
// predecessors and removed from LI & DT at the end of the function. // predecessors and removed from LI & DT at the end of the function.
if (!UnionPred.isAlwaysTrue()) { if (!Pred.isAlwaysTrue()) {
SCEVCheckBlock = SplitBlock(Preheader, Preheader->getTerminator(), DT, LI, SCEVCheckBlock = SplitBlock(Preheader, Preheader->getTerminator(), DT, LI,
nullptr, "vector.scevcheck"); nullptr, "vector.scevcheck");
SCEVCheckCond = SCEVExp.expandCodeForPredicate( SCEVCheckCond = SCEVExp.expandCodeForPredicate(
&UnionPred, SCEVCheckBlock->getTerminator()); &Pred, SCEVCheckBlock->getTerminator());
} }
const auto &RtPtrChecking = *LAI.getRuntimePointerChecking(); const auto &RtPtrChecking = *LAI.getRuntimePointerChecking();
@ -5161,7 +5161,7 @@ bool LoopVectorizationCostModel::runtimeChecksRequired() {
return true; return true;
} }
if (!PSE.getUnionPredicate().getPredicates().empty()) { if (!PSE.getPredicate().isAlwaysTrue()) {
reportVectorizationFailure("Runtime SCEV check is required with -Os/-Oz", reportVectorizationFailure("Runtime SCEV check is required with -Os/-Oz",
"runtime SCEV checks needed. Enable vectorization of this " "runtime SCEV checks needed. Enable vectorization of this "
"loop with '#pragma clang loop vectorize(enable)' when " "loop with '#pragma clang loop vectorize(enable)' when "
@ -10557,7 +10557,7 @@ bool LoopVectorizePass::processLoop(Loop *L) {
GeneratedRTChecks Checks(*PSE.getSE(), DT, LI, GeneratedRTChecks Checks(*PSE.getSE(), DT, LI,
F->getParent()->getDataLayout()); F->getParent()->getDataLayout());
if (!VF.Width.isScalar() || IC > 1) if (!VF.Width.isScalar() || IC > 1)
Checks.Create(L, *LVL.getLAI(), PSE.getUnionPredicate()); Checks.Create(L, *LVL.getLAI(), PSE.getPredicate());
using namespace ore; using namespace ore;
if (!VectorizeLoop) { if (!VectorizeLoop) {