[SCEV] Generalize SCEVParameterRewriter to accept SCEV expression as target.

This patch extends SCEVParameterRewriter to support rewriting unknown
epxressions to arbitrary SCEV expressions. It will be used by further
patches.

Reviewed By: reames

Differential Revision: https://reviews.llvm.org/D67176
This commit is contained in:
Florian Hahn 2020-09-18 09:50:01 +01:00
parent c10200536f
commit 4635f6050b
3 changed files with 61 additions and 23 deletions

View File

@ -810,35 +810,30 @@ class Type;
};
using ValueToValueMap = DenseMap<const Value *, Value *>;
using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
/// The SCEVParameterRewriter takes a scalar evolution expression and updates
/// the SCEVUnknown components following the Map (Value -> Value).
/// the SCEVUnknown components following the Map (Value -> SCEV).
class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
public:
static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
ValueToValueMap &Map,
bool InterpretConsts = false) {
SCEVParameterRewriter Rewriter(SE, Map, InterpretConsts);
ValueToSCEVMapTy &Map) {
SCEVParameterRewriter Rewriter(SE, Map);
return Rewriter.visit(Scev);
}
SCEVParameterRewriter(ScalarEvolution &SE, ValueToValueMap &M, bool C)
: SCEVRewriteVisitor(SE), Map(M), InterpretConsts(C) {}
SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
: SCEVRewriteVisitor(SE), Map(M) {}
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
Value *V = Expr->getValue();
if (Map.count(V)) {
Value *NV = Map[V];
if (InterpretConsts && isa<ConstantInt>(NV))
return SE.getConstant(cast<ConstantInt>(NV));
return SE.getUnknown(NV);
}
return Expr;
auto I = Map.find(Expr->getValue());
if (I == Map.end())
return Expr;
return I->second;
}
private:
ValueToValueMap &Map;
bool InterpretConsts;
ValueToSCEVMapTy &Map;
};
using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;

View File

@ -215,16 +215,14 @@ void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) {
return cannotDivide(Numerator);
// The Remainder is obtained by replacing Denominator by 0 in Numerator.
ValueToValueMap RewriteMap;
RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
cast<SCEVConstant>(Zero)->getValue();
Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
ValueToSCEVMapTy RewriteMap;
RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = Zero;
Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
if (Remainder->isZero()) {
// The Quotient is obtained by replacing Denominator by 1 in Numerator.
RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
cast<SCEVConstant>(One)->getValue();
Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = One;
Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
return;
}

View File

@ -149,6 +149,12 @@ static Instruction *getInstructionByName(Function &F, StringRef Name) {
llvm_unreachable("Expected to find instruction!");
}
static Value *getArgByName(Function &F, StringRef Name) {
for (auto &Arg : F.args())
if (Arg.getName() == Name)
return &Arg;
llvm_unreachable("Expected to find instruction!");
}
TEST_F(ScalarEvolutionsTest, CommutativeExprOperandOrder) {
LLVMContext C;
SMDiagnostic Err;
@ -1120,4 +1126,43 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
});
}
TEST_F(ScalarEvolutionsTest, SCEVrewriteUnknowns) {
LLVMContext C;
SMDiagnostic Err;
std::unique_ptr<Module> M = parseAssemblyString(
"define void @foo(i32 %i) { "
"entry: "
" %cmp3 = icmp ult i32 %i, 16 "
" br i1 %cmp3, label %loop.body, label %exit "
"loop.body: "
" %iv = phi i32 [ %iv.next, %loop.body ], [ %i, %entry ] "
" %iv.next = add nsw i32 %iv, 1 "
" %cmp = icmp eq i32 %iv.next, 16 "
" br i1 %cmp, label %exit, label %loop.body "
"exit: "
" ret void "
"} ",
Err, C);
ASSERT_TRUE(M && "Could not parse module?");
ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv")); // {0,+,1}
auto *ScevI = SE.getSCEV(getArgByName(F, "i")); // {0,+,1}
ValueToSCEVMapTy RewriteMap;
RewriteMap[cast<SCEVUnknown>(ScevI)->getValue()] =
SE.getUMinExpr(ScevI, SE.getConstant(ScevI->getType(), 17));
auto *WithUMin = SCEVParameterRewriter::rewrite(ScevIV, SE, RewriteMap);
EXPECT_NE(WithUMin, ScevIV);
auto *AR = dyn_cast<SCEVAddRecExpr>(WithUMin);
EXPECT_TRUE(AR);
EXPECT_EQ(AR->getStart(),
SE.getUMinExpr(ScevI, SE.getConstant(ScevI->getType(), 17)));
EXPECT_EQ(AR->getStepRecurrence(SE),
cast<SCEVAddRecExpr>(ScevIV)->getStepRecurrence(SE));
});
}
} // end namespace llvm