[ScopDetect] Use SCEVRewriteVisitor to simplify SCEVRemoveSMax rewriter

ScalarEvolution got at some pointer a SCEVRewriteVisitor. Use it to simplify
our SCEVRemoveSMax visitor.

llvm-svn: 285491
This commit is contained in:
Tobias Grosser 2016-10-29 06:19:34 +00:00
parent c88ba36eab
commit ebb626e4b7
1 changed files with 8 additions and 56 deletions

View File

@ -587,29 +587,16 @@ bool ScopDetection::isInvariant(const Value &Val, const Region &Reg) const {
/// always add and verify the assumption that for all subscript expressions
/// 'exp' the inequality 0 <= exp < size holds. Hence, we will also verify
/// that 0 <= size, which means smax(0, size) == size.
struct SCEVRemoveMax : public SCEVVisitor<SCEVRemoveMax, const SCEV *> {
class SCEVRemoveMax : public SCEVRewriteVisitor<SCEVRemoveMax> {
public:
static const SCEV *remove(ScalarEvolution &SE, const SCEV *Expr,
std::vector<const SCEV *> *Terms = nullptr) {
SCEVRemoveMax D(SE, Terms);
return D.visit(Expr);
static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
std::vector<const SCEV *> *Terms = nullptr) {
SCEVRemoveMax Rewriter(SE, Terms);
return Rewriter.visit(Scev);
}
SCEVRemoveMax(ScalarEvolution &SE, std::vector<const SCEV *> *Terms)
: SE(SE), Terms(Terms) {}
const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
return Expr;
}
const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
return SE.getSignExtendExpr(visit(Expr->getOperand()), Expr->getType());
}
const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
: SCEVRewriteVisitor(SE), Terms(Terms) {}
const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
if ((Expr->getNumOperands() == 2) && Expr->getOperand(0)->isZero()) {
@ -622,42 +609,7 @@ public:
return Expr;
}
const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { return Expr; }
const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; }
const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
return Expr;
}
const SCEV *visitConstant(const SCEVConstant *Expr) { return Expr; }
const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
SmallVector<const SCEV *, 5> NewOps;
for (const SCEV *Op : Expr->operands())
NewOps.push_back(visit(Op));
return SE.getAddRecExpr(NewOps, Expr->getLoop(), Expr->getNoWrapFlags());
}
const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
SmallVector<const SCEV *, 5> NewOps;
for (const SCEV *Op : Expr->operands())
NewOps.push_back(visit(Op));
return SE.getAddExpr(NewOps);
}
const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
SmallVector<const SCEV *, 5> NewOps;
for (const SCEV *Op : Expr->operands())
NewOps.push_back(visit(Op));
return SE.getMulExpr(NewOps);
}
private:
ScalarEvolution &SE;
std::vector<const SCEV *> *Terms;
};
@ -667,7 +619,7 @@ ScopDetection::getDelinearizationTerms(DetectionContext &Context,
SmallVector<const SCEV *, 4> Terms;
for (const auto &Pair : Context.Accesses[BasePointer]) {
std::vector<const SCEV *> MaxTerms;
SCEVRemoveMax::remove(*SE, Pair.second, &MaxTerms);
SCEVRemoveMax::rewrite(Pair.second, *SE, &MaxTerms);
if (MaxTerms.size() > 0) {
Terms.insert(Terms.begin(), MaxTerms.begin(), MaxTerms.end());
continue;
@ -773,7 +725,7 @@ bool ScopDetection::computeAccessFunctions(
for (const auto &Pair : Context.Accesses[BasePointer]) {
const Instruction *Insn = Pair.first;
auto *AF = Pair.second;
AF = SCEVRemoveMax::remove(*SE, AF);
AF = SCEVRemoveMax::rewrite(AF, *SE);
bool IsNonAffine = false;
TempMemoryAccesses.insert(std::make_pair(Insn, MemAcc(Insn, Shape)));
MemAcc *Acc = &TempMemoryAccesses.find(Insn)->second;