[AssumeBundles] adapt Assumption cache to assume bundles

Summary: change assumption cache to store an assume along with an index to the operand bundle containing the knowledge.

Reviewers: jdoerfert, hfinkel

Reviewed By: jdoerfert

Subscribers: hiraditya, mgrang, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D77402
This commit is contained in:
Tyker 2020-04-13 11:27:27 +02:00
parent 29bb046fe9
commit 813f438baa
7 changed files with 164 additions and 39 deletions

View File

@ -39,6 +39,21 @@ class Value;
/// register any new \@llvm.assume calls that they create. Deletions of
/// \@llvm.assume calls do not require special handling.
class AssumptionCache {
public:
/// Value of ResultElem::Index indicating that the argument to the call of the
/// llvm.assume.
enum : unsigned { ExprResultIdx = std::numeric_limits<unsigned>::max() };
struct ResultElem {
WeakTrackingVH Assume;
/// contains either ExprResultIdx or the index of the operand bundle
/// containing the knowledge.
unsigned Index;
operator Value *() const { return Assume; }
};
private:
/// The function for which this cache is handling assumptions.
///
/// We track this to lazily populate our assumptions.
@ -46,7 +61,7 @@ class AssumptionCache {
/// Vector of weak value handles to calls of the \@llvm.assume
/// intrinsic.
SmallVector<WeakTrackingVH, 4> AssumeHandles;
SmallVector<ResultElem, 4> AssumeHandles;
class AffectedValueCallbackVH final : public CallbackVH {
AssumptionCache *AC;
@ -66,12 +81,12 @@ class AssumptionCache {
/// A map of values about which an assumption might be providing
/// information to the relevant set of assumptions.
using AffectedValuesMap =
DenseMap<AffectedValueCallbackVH, SmallVector<WeakTrackingVH, 1>,
DenseMap<AffectedValueCallbackVH, SmallVector<ResultElem, 1>,
AffectedValueCallbackVH::DMI>;
AffectedValuesMap AffectedValues;
/// Get the vector of assumptions which affect a value from the cache.
SmallVector<WeakTrackingVH, 1> &getOrInsertAffectedValues(Value *V);
SmallVector<ResultElem, 1> &getOrInsertAffectedValues(Value *V);
/// Move affected values in the cache for OV to be affected values for NV.
void transferAffectedValuesInCache(Value *OV, Value *NV);
@ -128,20 +143,20 @@ public:
/// FIXME: We should replace this with pointee_iterator<filter_iterator<...>>
/// when we can write that to filter out the null values. Then caller code
/// will become simpler.
MutableArrayRef<WeakTrackingVH> assumptions() {
MutableArrayRef<ResultElem> assumptions() {
if (!Scanned)
scanFunction();
return AssumeHandles;
}
/// Access the list of assumptions which affect this value.
MutableArrayRef<WeakTrackingVH> assumptionsFor(const Value *V) {
MutableArrayRef<ResultElem> assumptionsFor(const Value *V) {
if (!Scanned)
scanFunction();
auto AVI = AffectedValues.find_as(const_cast<Value *>(V));
if (AVI == AffectedValues.end())
return MutableArrayRef<WeakTrackingVH>();
return MutableArrayRef<ResultElem>();
return AVI->second;
}
@ -234,6 +249,21 @@ public:
static char ID; // Pass identification, replacement for typeid
};
template<> struct simplify_type<AssumptionCache::ResultElem> {
using SimpleType = Value *;
static SimpleType getSimplifiedValue(AssumptionCache::ResultElem &Val) {
return Val;
}
};
template<> struct simplify_type<const AssumptionCache::ResultElem> {
using SimpleType = /*const*/ Value *;
static SimpleType getSimplifiedValue(const AssumptionCache::ResultElem &Val) {
return Val;
}
};
} // end namespace llvm
#endif // LLVM_ANALYSIS_ASSUMPTIONCACHE_H

View File

@ -22,6 +22,7 @@
namespace llvm {
class IntrinsicInst;
class AssumptionCache;
/// Build a call to llvm.assume to preserve informations that can be derived
/// from the given instruction.
@ -32,7 +33,7 @@ IntrinsicInst *buildAssumeFromInst(Instruction *I);
/// Calls BuildAssumeFromInst and if the resulting llvm.assume is valid insert
/// if before I. This is usually what need to be done to salvage the knowledge
/// contained in the instruction I.
void salvageKnowledge(Instruction *I);
void salvageKnowledge(Instruction *I, AssumptionCache *AC = nullptr);
/// This pass will try to build an llvm.assume for every instruction in the
/// function. Its main purpose is testing.

View File

@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/AssumeBundleQueries.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
@ -41,7 +42,7 @@ static cl::opt<bool>
cl::desc("Enable verification of assumption cache"),
cl::init(false));
SmallVector<WeakTrackingVH, 1> &
SmallVector<AssumptionCache::ResultElem, 1> &
AssumptionCache::getOrInsertAffectedValues(Value *V) {
// Try using find_as first to avoid creating extra value handles just for the
// purpose of doing the lookup.
@ -50,32 +51,39 @@ AssumptionCache::getOrInsertAffectedValues(Value *V) {
return AVI->second;
auto AVIP = AffectedValues.insert(
{AffectedValueCallbackVH(V, this), SmallVector<WeakTrackingVH, 1>()});
{AffectedValueCallbackVH(V, this), SmallVector<ResultElem, 1>()});
return AVIP.first->second;
}
static void findAffectedValues(CallInst *CI,
SmallVectorImpl<Value *> &Affected) {
static void
findAffectedValues(CallInst *CI,
SmallVectorImpl<AssumptionCache::ResultElem> &Affected) {
// Note: This code must be kept in-sync with the code in
// computeKnownBitsFromAssume in ValueTracking.
auto AddAffected = [&Affected](Value *V) {
auto AddAffected = [&Affected](Value *V, unsigned Idx =
AssumptionCache::ExprResultIdx) {
if (isa<Argument>(V)) {
Affected.push_back(V);
Affected.push_back({V, Idx});
} else if (auto *I = dyn_cast<Instruction>(V)) {
Affected.push_back(I);
Affected.push_back({I, Idx});
// Peek through unary operators to find the source of the condition.
Value *Op;
if (match(I, m_BitCast(m_Value(Op))) ||
match(I, m_PtrToInt(m_Value(Op))) ||
match(I, m_Not(m_Value(Op)))) {
match(I, m_PtrToInt(m_Value(Op))) || match(I, m_Not(m_Value(Op)))) {
if (isa<Instruction>(Op) || isa<Argument>(Op))
Affected.push_back(Op);
Affected.push_back({Op, Idx});
}
}
};
for (unsigned Idx = 0; Idx != CI->getNumOperandBundles(); Idx++) {
if (CI->getOperandBundleAt(Idx).Inputs.size() > ABA_WasOn &&
CI->getOperandBundleAt(Idx).getTagName() != "ignore")
AddAffected(CI->getOperandBundleAt(Idx).Inputs[ABA_WasOn], Idx);
}
Value *Cond = CI->getArgOperand(0), *A, *B;
AddAffected(Cond);
@ -112,28 +120,44 @@ static void findAffectedValues(CallInst *CI,
}
void AssumptionCache::updateAffectedValues(CallInst *CI) {
SmallVector<Value *, 16> Affected;
SmallVector<AssumptionCache::ResultElem, 16> Affected;
findAffectedValues(CI, Affected);
for (auto &AV : Affected) {
auto &AVV = getOrInsertAffectedValues(AV);
if (std::find(AVV.begin(), AVV.end(), CI) == AVV.end())
AVV.push_back(CI);
auto &AVV = getOrInsertAffectedValues(AV.Assume);
if (std::find_if(AVV.begin(), AVV.end(), [&](ResultElem &Elem) {
return Elem.Assume == CI && Elem.Index == AV.Index;
}) == AVV.end())
AVV.push_back({CI, AV.Index});
}
}
void AssumptionCache::unregisterAssumption(CallInst *CI) {
SmallVector<Value *, 16> Affected;
SmallVector<AssumptionCache::ResultElem, 16> Affected;
findAffectedValues(CI, Affected);
for (auto &AV : Affected) {
auto AVI = AffectedValues.find_as(AV);
if (AVI != AffectedValues.end())
auto AVI = AffectedValues.find_as(AV.Assume);
if (AVI == AffectedValues.end())
continue;
bool Found = false;
bool HasNonnull = false;
for (ResultElem &Elem : AVI->second) {
if (Elem.Assume == CI) {
Found = true;
Elem.Assume = nullptr;
}
HasNonnull |= !!Elem.Assume;
if (HasNonnull && Found)
break;
}
assert(Found && "already unregistered or incorrect cache state");
if (!HasNonnull)
AffectedValues.erase(AVI);
}
AssumeHandles.erase(
remove_if(AssumeHandles, [CI](WeakTrackingVH &VH) { return CI == VH; }),
remove_if(AssumeHandles, [CI](ResultElem &RE) { return CI == RE; }),
AssumeHandles.end());
}
@ -177,7 +201,7 @@ void AssumptionCache::scanFunction() {
for (BasicBlock &B : F)
for (Instruction &II : B)
if (match(&II, m_Intrinsic<Intrinsic::assume>()))
AssumeHandles.push_back(&II);
AssumeHandles.push_back({&II, ExprResultIdx});
// Mark the scan as complete.
Scanned = true;
@ -196,7 +220,7 @@ void AssumptionCache::registerAssumption(CallInst *CI) {
if (!Scanned)
return;
AssumeHandles.push_back(CI);
AssumeHandles.push_back({CI, ExprResultIdx});
#ifndef NDEBUG
assert(CI->getParent() &&

View File

@ -948,7 +948,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
continue;
}
salvageKnowledge(&Inst);
salvageKnowledge(&Inst, &AC);
salvageDebugInfoOrMarkUndef(Inst);
removeMSSA(Inst);
Inst.eraseFromParent();
@ -1015,7 +1015,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
cast<ConstantInt>(KnownCond)->isOne()) {
LLVM_DEBUG(dbgs()
<< "EarlyCSE removing guard: " << Inst << '\n');
salvageKnowledge(&Inst);
salvageKnowledge(&Inst, &AC);
removeMSSA(Inst);
Inst.eraseFromParent();
Changed = true;
@ -1051,7 +1051,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
Changed = true;
}
if (isInstructionTriviallyDead(&Inst, &TLI)) {
salvageKnowledge(&Inst);
salvageKnowledge(&Inst, &AC);
removeMSSA(Inst);
Inst.eraseFromParent();
Changed = true;
@ -1077,7 +1077,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
if (auto *I = dyn_cast<Instruction>(V))
I->andIRFlags(&Inst);
Inst.replaceAllUsesWith(V);
salvageKnowledge(&Inst);
salvageKnowledge(&Inst, &AC);
removeMSSA(Inst);
Inst.eraseFromParent();
Changed = true;
@ -1138,7 +1138,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
}
if (!Inst.use_empty())
Inst.replaceAllUsesWith(Op);
salvageKnowledge(&Inst);
salvageKnowledge(&Inst, &AC);
removeMSSA(Inst);
Inst.eraseFromParent();
Changed = true;
@ -1182,7 +1182,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
}
if (!Inst.use_empty())
Inst.replaceAllUsesWith(InVal.first);
salvageKnowledge(&Inst);
salvageKnowledge(&Inst, &AC);
removeMSSA(Inst);
Inst.eraseFromParent();
Changed = true;
@ -1235,7 +1235,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
continue;
}
salvageKnowledge(&Inst);
salvageKnowledge(&Inst, &AC);
removeMSSA(Inst);
Inst.eraseFromParent();
Changed = true;
@ -1271,7 +1271,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
if (!DebugCounter::shouldExecute(CSECounter)) {
LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
} else {
salvageKnowledge(&Inst);
salvageKnowledge(&Inst, &AC);
removeMSSA(*LastStore);
LastStore->eraseFromParent();
Changed = true;

View File

@ -8,6 +8,7 @@
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Analysis/AssumeBundleQueries.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
@ -222,9 +223,12 @@ IntrinsicInst *llvm::buildAssumeFromInst(Instruction *I) {
return Builder.build();
}
void llvm::salvageKnowledge(Instruction *I) {
if (Instruction *Intr = buildAssumeFromInst(I))
void llvm::salvageKnowledge(Instruction *I, AssumptionCache *AC) {
if (IntrinsicInst *Intr = buildAssumeFromInst(I)) {
Intr->insertBefore(I);
if (AC)
AC->registerAssumption(Intr);
}
}
PreservedAnalyses AssumeBuilderPass::run(Function &F,

View File

@ -1837,9 +1837,11 @@ llvm::InlineResult llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI,
// check what will be known at the start of the inlined code.
AddAlignmentAssumptions(CS, IFI);
AssumptionCache *AC =
IFI.GetAssumptionCache ? &(*IFI.GetAssumptionCache)(*Caller) : nullptr;
/// Preserve all attributes on of the call and its parameters.
if (Instruction *Assume = buildAssumeFromInst(CS.getInstruction()))
Assume->insertBefore(CS.getInstruction());
salvageKnowledge(CS.getInstruction(), AC);
// We want the inliner to prune the code as it copies. We would LOVE to
// have no dead or constant instructions leftover after inlining occurs

View File

@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/AssumeBundleQueries.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/CallSite.h"
@ -510,3 +511,66 @@ TEST(AssumeQueryAPI, getKnowledgeFromUseInAssume) {
// large.
RunRandTest(9876789, 100000, -0, 7, 100);
}
TEST(AssumeQueryAPI, AssumptionCache) {
LLVMContext C;
SMDiagnostic Err;
std::unique_ptr<Module> Mod = parseAssemblyString(
"declare void @llvm.assume(i1)\n"
"define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3, i1 %B) {\n"
"call void @llvm.assume(i1 true) [\"nonnull\"(i32* %P), \"align\"(i32* "
"%P2, i32 4), \"align\"(i32* %P, i32 8)]\n"
"call void @llvm.assume(i1 %B) [\"test\"(i32* %P1), "
"\"dereferenceable\"(i32* %P, i32 4)]\n"
"ret void\n}\n",
Err, C);
if (!Mod)
Err.print("AssumeQueryAPI", errs());
Function *F = Mod->getFunction("test");
BasicBlock::iterator First = F->begin()->begin();
BasicBlock::iterator Second = F->begin()->begin();
Second++;
AssumptionCacheTracker ACT;
AssumptionCache &AC = ACT.getAssumptionCache(*F);
auto AR = AC.assumptionsFor(F->getArg(3));
ASSERT_EQ(AR.size(), 0u);
AR = AC.assumptionsFor(F->getArg(1));
ASSERT_EQ(AR.size(), 1u);
ASSERT_EQ(AR[0].Index, 0u);
ASSERT_EQ(AR[0].Assume, &*Second);
AR = AC.assumptionsFor(F->getArg(2));
ASSERT_EQ(AR.size(), 1u);
ASSERT_EQ(AR[0].Index, 1u);
ASSERT_EQ(AR[0].Assume, &*First);
AR = AC.assumptionsFor(F->getArg(0));
ASSERT_EQ(AR.size(), 3u);
llvm::sort(AR,
[](const auto &L, const auto &R) { return L.Index < R.Index; });
ASSERT_EQ(AR[0].Assume, &*First);
ASSERT_EQ(AR[0].Index, 0u);
ASSERT_EQ(AR[1].Assume, &*Second);
ASSERT_EQ(AR[1].Index, 1u);
ASSERT_EQ(AR[2].Assume, &*First);
ASSERT_EQ(AR[2].Index, 2u);
AR = AC.assumptionsFor(F->getArg(4));
ASSERT_EQ(AR.size(), 1u);
ASSERT_EQ(AR[0].Assume, &*Second);
ASSERT_EQ(AR[0].Index, AssumptionCache::ExprResultIdx);
AC.unregisterAssumption(cast<CallInst>(&*Second));
AR = AC.assumptionsFor(F->getArg(1));
ASSERT_EQ(AR.size(), 0u);
AR = AC.assumptionsFor(F->getArg(0));
ASSERT_EQ(AR.size(), 3u);
llvm::sort(AR,
[](const auto &L, const auto &R) { return L.Index < R.Index; });
ASSERT_EQ(AR[0].Assume, &*First);
ASSERT_EQ(AR[0].Index, 0u);
ASSERT_EQ(AR[1].Assume, nullptr);
ASSERT_EQ(AR[1].Index, 1u);
ASSERT_EQ(AR[2].Assume, &*First);
ASSERT_EQ(AR[2].Index, 2u);
AR = AC.assumptionsFor(F->getArg(2));
ASSERT_EQ(AR.size(), 1u);
ASSERT_EQ(AR[0].Index, 1u);
ASSERT_EQ(AR[0].Assume, &*First);
}