Detect Parameters directly on the SCEV.

Instead of using TempScop to find parameters, we detect them directly
on the SCEV. This allows us to remove the TempScop parameter detection
in a subsequent commit.

This fixes a bug reported by Marcello Maggioni <hayarms@gmail.com>

llvm-svn: 144087
This commit is contained in:
Tobias Grosser 2011-11-08 15:41:28 +00:00
parent 65b0058b56
commit 60b54f19e6
7 changed files with 152 additions and 27 deletions

View File

@ -266,7 +266,7 @@ class ScopStmt {
TempScop &tempScop,
const Region &CurRegion);
__isl_give isl_set *addLoopBoundsToDomain(__isl_take isl_set *Domain,
TempScop &tempScop) const;
TempScop &tempScop);
__isl_give isl_set *buildDomain(TempScop &tempScop,
const Region &CurRegion);
void buildScattering(SmallVectorImpl<unsigned> &Scatter);
@ -433,9 +433,6 @@ class Scop {
/// @return True if the basic block is trivial, otherwise false.
static bool isTrivialBB(BasicBlock *BB, TempScop &tempScop);
/// @brief Add the parameters to the internal parameter set.
void initializeParameters(ParamSetType *ParamSet);
/// @brief Build the Context of the Scop.
void buildContext();
@ -470,6 +467,8 @@ public:
/// @return The set containing the parameters used in this Scop.
inline const ParamVecType &getParams() const { return Parameters; }
/// @brief Take a list of parameters and add the new ones to the scop.
void addParams(std::vector<const SCEV*> NewParameters);
/// @brief Return the isl_id that represents a certain parameter.
///

View File

@ -12,6 +12,8 @@
#ifndef POLLY_SCEV_VALIDATOR_H
#define POLLY_SCEV_VALIDATOR_H
#include <vector>
namespace llvm {
class Region;
class SCEV;
@ -22,6 +24,12 @@ namespace llvm {
namespace polly {
bool isAffineExpr(const llvm::Region *R, const llvm::SCEV *Expression,
llvm::ScalarEvolution &SE, llvm::Value **BaseAddress = 0);
std::vector<const llvm::SCEV*> getParamsInAffineExpr(
const llvm::Region *R,
const llvm::SCEV *Expression,
llvm::ScalarEvolution &SE,
llvm::Value **BaseAddress = 0);
}
#endif

View File

@ -61,6 +61,8 @@
#define DEBUG_TYPE "polly-detect"
#include "llvm/Support/Debug.h"
#include <set>
using namespace llvm;
using namespace polly;

View File

@ -23,6 +23,7 @@
#include "polly/LinkAllPasses.h"
#include "polly/Support/GICHelper.h"
#include "polly/Support/ScopHelper.h"
#include "polly/Support/SCEVValidator.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
@ -80,8 +81,18 @@ private:
const Value *baseAddress;
public:
static isl_pw_aff *getPwAff(const ScopStmt *stmt, const SCEV *scev,
static isl_pw_aff *getPwAff(ScopStmt *stmt, const SCEV *scev,
const Value *baseAddress = 0) {
Scop *S = stmt->getParent();
const Region *Reg = &S->getRegion();
if (baseAddress) {
Value *Base;
S->addParams(getParamsInAffineExpr(Reg, scev, *S->getSE(), &Base));
} else {
S->addParams(getParamsInAffineExpr(Reg, scev, *S->getSE()));
}
SCEVAffinator Affinator(stmt, baseAddress);
return Affinator.visit(scev);
}
@ -598,7 +609,6 @@ void ScopStmt::realignParams() {
}
__isl_give isl_set *ScopStmt::buildConditionSet(const Comparison &Comp) {
isl_pw_aff *L = SCEVAffinator::getPwAff(this, Comp.getLHS()->OriginalSCEV, 0);
isl_pw_aff *R = SCEVAffinator::getPwAff(this, Comp.getRHS()->OriginalSCEV, 0);
@ -626,7 +636,7 @@ __isl_give isl_set *ScopStmt::buildConditionSet(const Comparison &Comp) {
}
__isl_give isl_set *ScopStmt::addLoopBoundsToDomain(__isl_take isl_set *Domain,
TempScop &tempScop) const {
TempScop &tempScop) {
isl_space *Space;
isl_local_space *LocalSpace;
@ -852,6 +862,22 @@ void ScopStmt::dump() const { print(dbgs()); }
//===----------------------------------------------------------------------===//
/// Scop class implement
void Scop::addParams(std::vector<const SCEV*> NewParameters) {
for (std::vector<const SCEV*>::iterator PI = NewParameters.begin(),
PE = NewParameters.end(); PI != PE; ++PI) {
const SCEV *Parameter = *PI;
if (ParameterIds.find(Parameter) != ParameterIds.end())
continue;
int dimension = Parameters.size();
Parameters.push_back(Parameter);
ParameterIds[Parameter] = dimension;
}
}
__isl_give isl_id *Scop::getIdForParam(const SCEV *Parameter) const {
ParamIdType::const_iterator IdIter = ParameterIds.find(Parameter);
@ -862,17 +888,6 @@ __isl_give isl_id *Scop::getIdForParam(const SCEV *Parameter) const {
return isl_id_alloc(getIslCtx(), ParameterName.c_str(), (void *) Parameter);
}
void Scop::initializeParameters(ParamSetType *ParamSet) {
int i = 0;
for (ParamSetType::iterator PI = ParamSet->begin(), PE = ParamSet->end();
PI != PE; ++PI) {
const SCEV *Parameter = *PI;
Parameters.push_back(Parameter);
ParameterIds.insert(std::pair<const SCEV*, int>(Parameter, i));
i++;
}
}
void Scop::buildContext() {
isl_space *Space = isl_space_params_alloc(IslCtx, 0);
Context = isl_set_universe (Space);
@ -880,7 +895,7 @@ void Scop::buildContext() {
void Scop::realignParams() {
// Add all parameters into a common model.
isl_space *Space = isl_space_params_alloc(IslCtx, Parameters.size());
isl_space *Space = isl_space_params_alloc(IslCtx, ParameterIds.size());
for (ParamIdType::iterator PI = ParameterIds.begin(), PE = ParameterIds.end();
PI != PE; ++PI) {
@ -901,7 +916,6 @@ Scop::Scop(TempScop &tempScop, LoopInfo &LI, ScalarEvolution &ScalarEvolution,
: SE(&ScalarEvolution), R(tempScop.getMaxRegion()),
MaxLoopDepth(tempScop.getMaxLoopDepth()) {
IslCtx = Context;
initializeParameters(&tempScop.getParamSet());
buildContext();
SmallVector<Loop*, 8> NestLoops;
@ -964,6 +978,14 @@ void Scop::printContext(raw_ostream &OS) const {
}
OS.indent(4) << getContextStr() << "\n";
for (ParamVecType::const_iterator PI = Parameters.begin(),
PE = Parameters.end(); PI != PE; ++PI) {
const SCEV *Parameter = *PI;
int Dim = ParameterIds.find(Parameter)->second;
OS.indent(4) << "p" << Dim << ": " << *Parameter << "\n";
}
}
void Scop::printStatements(raw_ostream &OS) const {

View File

@ -19,6 +19,7 @@
#include "polly/Support/AffineSCEVIterator.h"
#include "polly/Support/GICHelper.h"
#include "polly/Support/ScopHelper.h"
#include "polly/Support/SCEVValidator.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/RegionIterator.h"

View File

@ -5,6 +5,8 @@
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/RegionInfo.h"
#include <vector>
using namespace llvm;
namespace SCEVType {
@ -13,14 +15,19 @@ namespace SCEVType {
struct ValidatorResult {
SCEVType::TYPE type;
std::vector<const SCEV*> Parameters;
ValidatorResult() : type(SCEVType::INVALID) {};
ValidatorResult(const ValidatorResult &vres) {
type = vres.type;
Parameters = vres.Parameters;
};
ValidatorResult(SCEVType::TYPE type) : type(type) {};
ValidatorResult(SCEVType::TYPE type, const SCEV* Expr) : type(type) {
Parameters.push_back(Expr);
};
bool isConstant() {
return type == SCEVType::INT || type == SCEVType::PARAM;
@ -37,6 +44,11 @@ struct ValidatorResult {
bool isINT() {
return type == SCEVType::INT;
}
void addParamsFrom(struct ValidatorResult &Source) {
Parameters.insert(Parameters.end(), Source.Parameters.begin(),
Source.Parameters.end());
}
};
/// Check if a SCEV is valid in a SCoP.
@ -63,7 +75,7 @@ public:
// expression. If it is constant during Scop execution, we treat it as a
// parameter, otherwise we bail out.
if (Op.isConstant())
return ValidatorResult(SCEVType::PARAM);
return ValidatorResult(SCEVType::PARAM, Expr);
return ValidatorResult (SCEVType::INVALID);
}
@ -75,7 +87,7 @@ public:
// expression. If it is constant during Scop execution, we treat it as a
// parameter, otherwise we bail out.
if (Op.isConstant())
return ValidatorResult (SCEVType::PARAM);
return ValidatorResult (SCEVType::PARAM, Expr);
return ValidatorResult(SCEVType::INVALID);
}
@ -98,6 +110,7 @@ public:
return ValidatorResult(SCEVType::INVALID);
Return.type = std::max(Return.type, Op.type);
Return.addParamsFrom(Op);
}
// TODO: Check for NSW and NUW.
@ -117,6 +130,7 @@ public:
return ValidatorResult(SCEVType::INVALID);
Return.type = Op.type;
Return.addParamsFrom(Op);
}
// TODO: Check for NSW and NUW.
@ -131,7 +145,7 @@ public:
// expression. If the division is constant during Scop execution we treat it
// as a parameter, otherwise we bail out.
if (LHS.isConstant() && RHS.isConstant())
return ValidatorResult(SCEVType::PARAM);
return ValidatorResult(SCEVType::PARAM, Expr);
return ValidatorResult(SCEVType::INVALID);
}
@ -151,13 +165,15 @@ public:
if (Start.isIV())
return ValidatorResult(SCEVType::INVALID);
else
return ValidatorResult(SCEVType::PARAM);
return ValidatorResult(SCEVType::PARAM, Expr);
}
if (!Recurrence.isINT())
return ValidatorResult(SCEVType::INVALID);
return ValidatorResult(SCEVType::IV);
ValidatorResult Result(SCEVType::IV);
Result.addParamsFrom(Start);
return Result;
}
struct ValidatorResult visitSMaxExpr(const SCEVSMaxExpr* Expr) {
@ -170,12 +186,15 @@ public:
return ValidatorResult(SCEVType::INVALID);
Return.type = std::max(Return.type, Op.type);
Return.addParamsFrom(Op);
}
return Return;
}
struct ValidatorResult visitUMaxExpr(const SCEVUMaxExpr* Expr) {
ValidatorResult Return(SCEVType::PARAM);
// We do not support unsigned operations. If 'Expr' is constant during Scop
// execution we treat this as a parameter, otherwise we bail out.
for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
@ -183,9 +202,11 @@ public:
if (!Op.isConstant())
return ValidatorResult(SCEVType::INVALID);
Return.addParamsFrom(Op);
}
return ValidatorResult(SCEVType::PARAM);
return Return;
}
ValidatorResult visitUnknown(const SCEVUnknown* Expr) {
@ -208,7 +229,7 @@ public:
if (BaseAddress)
return ValidatorResult(SCEVType::PARAM);
else
return ValidatorResult(SCEVType::PARAM);
return ValidatorResult(SCEVType::PARAM, Expr);
}
};
@ -226,6 +247,22 @@ namespace polly {
return Result.isValid();
}
std::vector<const SCEV*> getParamsInAffineExpr(const Region *R,
const SCEV *Expr,
ScalarEvolution &SE,
Value **BaseAddress) {
if (isa<SCEVCouldNotCompute>(Expr))
return std::vector<const SCEV*>();
if (BaseAddress)
*BaseAddress = NULL;
SCEVValidator Validator(R, SE, BaseAddress);
ValidatorResult Result = Validator.visit(Expr);
return Result.Parameters;
}
}

View File

@ -0,0 +1,56 @@
; RUN: opt %loadPolly %defaultOpts -polly-scops -analyze < %s | FileCheck %s
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"
declare void @foo()
define i32 @main(i8* %A) nounwind uwtable {
entry:
br label %for.cond
for.cond: ; preds = %for.inc5, %entry
%indvar_out = phi i64 [ %indvar_out.next, %for.inc5 ], [ 0, %entry ]
call void @foo()
%tmp = add i64 %indvar_out, 2
%exitcond5 = icmp ne i64 %indvar_out, 1023
br i1 %exitcond5, label %for.body, label %for.end7
for.body: ; preds = %for.cond
br label %for.cond1
for.cond1: ; preds = %for.inc, %for.body
%indvar = phi i64 [ %indvar.next, %for.inc ], [ 0, %for.body ]
%exitcond = icmp ne i64 %indvar, 1023
br i1 %exitcond, label %for.body3, label %for.end
for.body3: ; preds = %for.cond1
%tmp1 = add i64 %tmp, %indvar
%cmp4 = icmp sgt i64 %tmp1, 1000
br i1 %cmp4, label %if.then, label %if.end
if.then: ; preds = %for.body3
%arrayidx = getelementptr i8* %A, i64 %indvar
store i8 5, i8* %arrayidx
br label %if.end
if.end: ; preds = %if.end.single_exit
br label %for.inc
for.inc: ; preds = %if.end
%indvar.next = add i64 %indvar, 1
br label %for.cond1
for.end: ; preds = %for.cond1
br label %for.inc5
for.inc5: ; preds = %for.end
%indvar_out.next = add i64 %indvar_out, 1
br label %for.cond
for.end7: ; preds = %for.cond
ret i32 0
}
; CHECK: Domain :=
; CHECK: [p0] -> { Stmt_if_then[i0] : i0 >= 0 and i0 <= 1022 and i0 >= 1001 - p0 };