[Refactor] IslAst and payload struct

+ Renamed context into build when it's the isl_ast_build
  + Use the IslAstInfo functions to extract the schedule of a node
  + Use the IslAstInfo functions to extract the build/context of a node
  + Move the payload struct into the IslAstInfo class
  + Use a constructor and destructor (also new and delete) to
    allocate/initialize the payload struct

llvm-svn: 213792
This commit is contained in:
Johannes Doerfert 2014-07-23 20:17:28 +00:00
parent 16ea3268b2
commit c4968e508b
3 changed files with 79 additions and 89 deletions

View File

@ -40,20 +40,32 @@ namespace polly {
class Scop;
class IslAst;
// Information about an ast node.
struct IslAstUserPayload {
struct isl_ast_build *Context;
// The node is the outermost parallel loop.
int IsOutermostParallel;
// The node is the innermost parallel loop.
int IsInnermostParallel;
// The node is only parallel because of reductions
bool IsReductionParallel;
};
class IslAstInfo : public ScopPass {
public:
/// @brief Payload information used to annoate an ast node.
struct IslAstUserPayload {
/// @brief Construct and initialize the payload.
IslAstUserPayload()
: IsInnermostParallel(false), IsOutermostParallel(false),
IsReductionParallel(false), Build(nullptr) {}
/// @brief Cleanup all isl structs on destruction.
~IslAstUserPayload();
/// @brief Flag to mark innermost parallel loops.
bool IsInnermostParallel;
/// @brief Flag to mark outermost parallel loops.
bool IsOutermostParallel;
/// @brief Flag to mark reduction parallel loops.
bool IsReductionParallel;
/// @brief The build environment at the time this node was constructed.
isl_ast_build *Build;
};
private:
Scop *S;
IslAst *Ast;
@ -97,6 +109,9 @@ public:
/// @brief Is this loop a reduction parallel loop?
static bool isReductionParallel(__isl_keep isl_ast_node *Node);
/// @brief Get the nodes schedule or a nullptr if not available.
static __isl_give isl_union_map *getSchedule(__isl_keep isl_ast_node *Node);
///}
virtual void getAnalysisUsage(AnalysisUsage &AU) const;

View File

@ -34,10 +34,12 @@
#include "isl/map.h"
#include "isl/aff.h"
#define DEBUG_TYPE "polly-ast"
using namespace llvm;
using namespace polly;
#define DEBUG_TYPE "polly-ast"
using IslAstUserPayload = IslAstInfo::IslAstUserPayload;
static cl::opt<bool> UseContext("polly-ast-use-context",
cl::desc("Use context"), cl::Hidden,
@ -69,10 +71,19 @@ private:
isl_ast_node *Root;
isl_ast_expr *RunCondition;
void buildRunCondition(__isl_keep isl_ast_build *Context);
void buildRunCondition(__isl_keep isl_ast_build *Build);
};
} // End namespace polly.
/// @brief Free an IslAstUserPayload object pointed to by @p Ptr
static void freeIslAstUserPayload(void *Ptr) {
delete ((IslAstInfo::IslAstUserPayload *)Ptr);
}
IslAstInfo::IslAstUserPayload::~IslAstUserPayload() {
isl_ast_build_free(Build);
}
// Temporary information used when building the ast.
struct AstBuildUserInfo {
// The dependence information.
@ -115,32 +126,12 @@ printFor(__isl_take isl_printer *Printer,
if (!Id)
return isl_ast_node_for_print(Node, Printer, PrintOptions);
struct IslAstUserPayload *Info =
(struct IslAstUserPayload *)isl_id_get_user(Id);
IslAstUserPayload *Info = (IslAstUserPayload *)isl_id_get_user(Id);
Printer = printParallelFor(Node, Printer, PrintOptions, Info);
isl_id_free(Id);
return Printer;
}
// Allocate an AstNodeInfo structure and initialize it with default values.
static struct IslAstUserPayload *allocateIslAstUser() {
struct IslAstUserPayload *NodeInfo;
NodeInfo =
(struct IslAstUserPayload *)malloc(sizeof(struct IslAstUserPayload));
NodeInfo->Context = 0;
NodeInfo->IsOutermostParallel = 0;
NodeInfo->IsInnermostParallel = 0;
NodeInfo->IsReductionParallel = false;
return NodeInfo;
}
// Free the AstNodeInfo structure.
static void freeIslAstUser(void *Ptr) {
struct IslAstUserPayload *UserStruct = (struct IslAstUserPayload *)Ptr;
isl_ast_build_free(UserStruct->Context);
free(UserStruct);
}
// Check if the current scheduling dimension is parallel.
//
// We check for parallelism by verifying that the loop does not carry any
@ -221,8 +212,8 @@ static bool astScheduleDimIsParallel(__isl_keep isl_ast_build *Build,
// Mark a for node openmp parallel, if it is the outermost parallel for node.
static void markOpenmpParallel(__isl_keep isl_ast_build *Build,
struct AstBuildUserInfo *BuildInfo,
struct IslAstUserPayload *NodeInfo) {
AstBuildUserInfo *BuildInfo,
IslAstUserPayload *NodeInfo) {
if (BuildInfo->InParallelFor)
return;
@ -242,10 +233,10 @@ static void markOpenmpParallel(__isl_keep isl_ast_build *Build,
//
static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build,
void *User) {
struct AstBuildUserInfo *BuildInfo = (struct AstBuildUserInfo *)User;
struct IslAstUserPayload *NodeInfo = allocateIslAstUser();
AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User;
IslAstUserPayload *NodeInfo = new IslAstUserPayload();
isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", NodeInfo);
Id = isl_id_set_free_user(Id, freeIslAstUser);
Id = isl_id_set_free_user(Id, freeIslAstUserPayload);
markOpenmpParallel(Build, BuildInfo, NodeInfo);
@ -305,9 +296,8 @@ astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build,
isl_id *Id = isl_ast_node_get_annotation(Node);
if (!Id)
return Node;
struct IslAstUserPayload *Info =
(struct IslAstUserPayload *)isl_id_get_user(Id);
struct AstBuildUserInfo *BuildInfo = (struct AstBuildUserInfo *)User;
IslAstUserPayload *Info = (IslAstUserPayload *)isl_id_get_user(Id);
AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User;
if (Info) {
if (Info->IsOutermostParallel)
@ -316,8 +306,8 @@ astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build,
if (astScheduleDimIsParallel(Build, BuildInfo->Deps,
Info->IsReductionParallel))
Info->IsInnermostParallel = 1;
if (!Info->Context)
Info->Context = isl_ast_build_copy(Build);
if (!Info->Build)
Info->Build = isl_ast_build_copy(Build);
}
isl_id_free(Id);
@ -325,29 +315,29 @@ astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build,
}
static __isl_give isl_ast_node *AtEachDomain(__isl_take isl_ast_node *Node,
__isl_keep isl_ast_build *Context,
__isl_keep isl_ast_build *Build,
void *User) {
struct IslAstUserPayload *Info = nullptr;
IslAstUserPayload *Info = nullptr;
isl_id *Id = isl_ast_node_get_annotation(Node);
if (Id)
Info = (struct IslAstUserPayload *)isl_id_get_user(Id);
Info = (IslAstUserPayload *)isl_id_get_user(Id);
if (!Info) {
// Allocate annotations once: parallel for detection might have already
// allocated the annotations for this node.
Info = allocateIslAstUser();
Info = new IslAstUserPayload();
Id = isl_id_alloc(isl_ast_node_get_ctx(Node), nullptr, Info);
Id = isl_id_set_free_user(Id, &freeIslAstUser);
Id = isl_id_set_free_user(Id, freeIslAstUserPayload);
}
if (!Info->Context)
Info->Context = isl_ast_build_copy(Context);
if (!Info->Build)
Info->Build = isl_ast_build_copy(Build);
return isl_ast_node_set_annotation(Node, Id);
}
void IslAst::buildRunCondition(__isl_keep isl_ast_build *Context) {
void IslAst::buildRunCondition(__isl_keep isl_ast_build *Build) {
// The conditions that need to be checked at run-time for this scop are
// available as an isl_set in the AssumedContext. We generate code for this
// check as follows. First, we generate an isl_pw_aff that is 1, if a certain
@ -373,21 +363,21 @@ void IslAst::buildRunCondition(__isl_keep isl_ast_build *Context) {
isl_pw_aff *Cond = isl_pw_aff_union_max(PwOne, PwZero);
RunCondition = isl_ast_build_expr_from_pw_aff(Context, Cond);
RunCondition = isl_ast_build_expr_from_pw_aff(Build, Cond);
}
IslAst::IslAst(Scop *Scop, Dependences &D) : S(Scop) {
isl_ctx *Ctx = S->getIslCtx();
isl_options_set_ast_build_atomic_upper_bound(Ctx, true);
isl_ast_build *Context;
struct AstBuildUserInfo BuildInfo;
isl_ast_build *Build;
AstBuildUserInfo BuildInfo;
if (UseContext)
Context = isl_ast_build_from_context(S->getContext());
Build = isl_ast_build_from_context(S->getContext());
else
Context = isl_ast_build_from_context(isl_set_universe(S->getParamSpace()));
Build = isl_ast_build_from_context(isl_set_universe(S->getParamSpace()));
Context = isl_ast_build_set_at_each_domain(Context, AtEachDomain, nullptr);
Build = isl_ast_build_set_at_each_domain(Build, AtEachDomain, nullptr);
isl_union_map *Schedule =
isl_union_map_intersect_domain(S->getSchedule(), S->getDomains());
@ -396,17 +386,17 @@ IslAst::IslAst(Scop *Scop, Dependences &D) : S(Scop) {
BuildInfo.Deps = &D;
BuildInfo.InParallelFor = 0;
Context = isl_ast_build_set_before_each_for(Context, &astBuildBeforeFor,
&BuildInfo);
Context = isl_ast_build_set_after_each_for(Context, &astBuildAfterFor,
&BuildInfo);
Build = isl_ast_build_set_before_each_for(Build, &astBuildBeforeFor,
&BuildInfo);
Build =
isl_ast_build_set_after_each_for(Build, &astBuildAfterFor, &BuildInfo);
}
buildRunCondition(Context);
buildRunCondition(Build);
Root = isl_ast_build_ast_from_schedule(Context, Schedule);
Root = isl_ast_build_ast_from_schedule(Build, Schedule);
isl_ast_build_free(Context);
isl_ast_build_free(Build);
}
IslAst::~IslAst() {
@ -476,6 +466,11 @@ bool IslAstInfo::isReductionParallel(__isl_keep isl_ast_node *Node) {
return Payload && Payload->IsReductionParallel;
}
isl_union_map *IslAstInfo::getSchedule(__isl_keep isl_ast_node *Node) {
IslAstUserPayload *Payload = getNodePayload(Node);
return Payload ? isl_ast_build_get_schedule(Payload->Build) : nullptr;
}
void IslAstInfo::printScop(raw_ostream &OS) const {
isl_ast_print_options *Options;
isl_ast_node *RootNode = getAst();

View File

@ -778,20 +778,8 @@ IslNodeBuilder::getUpperBound(__isl_keep isl_ast_node *For,
}
unsigned IslNodeBuilder::getNumberOfIterations(__isl_keep isl_ast_node *For) {
isl_id *Annotation = isl_ast_node_get_annotation(For);
if (!Annotation)
return -1;
struct IslAstUserPayload *Info =
(struct IslAstUserPayload *)isl_id_get_user(Annotation);
if (!Info) {
isl_id_free(Annotation);
return -1;
}
isl_union_map *Schedule = isl_ast_build_get_schedule(Info->Context);
isl_union_map *Schedule = IslAstInfo::getSchedule(Build);
isl_set *LoopDomain = isl_set_from_union_set(isl_union_map_range(Schedule));
isl_id_free(Annotation);
int NumberOfIterations = polly::getNumberOfIterations(LoopDomain);
if (NumberOfIterations == -1)
return -1;
@ -848,14 +836,7 @@ void IslNodeBuilder::createForVector(__isl_take isl_ast_node *For,
for (int i = 1; i < VectorWidth; i++)
IVS[i] = Builder.CreateAdd(IVS[i - 1], ValueInc, "p_vector_iv");
isl_id *Annotation = isl_ast_node_get_annotation(For);
assert(Annotation && "For statement is not annotated");
struct IslAstUserPayload *Info =
(struct IslAstUserPayload *)isl_id_get_user(Annotation);
assert(Info && "For statement annotation does not contain info");
isl_union_map *Schedule = isl_ast_build_get_schedule(Info->Context);
isl_union_map *Schedule = IslAstInfo::getSchedule(Build);
assert(Schedule && "For statement annotation does not contain its schedule");
IDToValue[IteratorID] = ValueLB;
@ -883,7 +864,6 @@ void IslNodeBuilder::createForVector(__isl_take isl_ast_node *For,
IDToValue.erase(IteratorID);
isl_id_free(IteratorID);
isl_id_free(Annotation);
isl_union_map_free(Schedule);
isl_ast_node_free(For);