[Refactor] Use nicer print callback function in IslAst

llvm-svn: 214447
This commit is contained in:
Johannes Doerfert 2014-07-31 21:33:49 +00:00
parent ef940aaf07
commit 0eefb0258f
3 changed files with 46 additions and 49 deletions

View File

@ -46,12 +46,16 @@ public:
struct IslAstUserPayload {
/// @brief Construct and initialize the payload.
IslAstUserPayload()
: IsInnermostParallel(false), IsOutermostParallel(false),
IsReductionParallel(false), Build(nullptr) {}
: IsInnermost(false), IsInnermostParallel(false),
IsOutermostParallel(false), IsReductionParallel(false),
Build(nullptr) {}
/// @brief Cleanup all isl structs on destruction.
~IslAstUserPayload();
/// @brief Flag to mark innermost loops.
bool IsInnermost;
/// @brief Flag to mark innermost parallel loops.
bool IsInnermostParallel;
@ -97,6 +101,9 @@ public:
/// @brief Get the complete payload attached to @p Node.
static IslAstUserPayload *getNodePayload(__isl_keep isl_ast_node *Node);
/// @brief Is this loop an innermost loop?
static bool isInnermost(__isl_keep isl_ast_node *Node);
/// @brief Is this loop a parallel loop?
static bool isParallel(__isl_keep isl_ast_node *Node);

View File

@ -100,43 +100,31 @@ struct AstBuildUserInfo {
isl_id *LastForNodeId;
};
// Print a loop annotated with OpenMP or vector pragmas.
static __isl_give isl_printer *
printParallelFor(__isl_keep isl_ast_node *Node, __isl_take isl_printer *Printer,
__isl_take isl_ast_print_options *PrintOptions,
IslAstUserPayload *Info) {
if (Info) {
if (Info->IsInnermostParallel) {
Printer = isl_printer_start_line(Printer);
Printer = isl_printer_print_str(Printer, "#pragma simd");
if (Info->IsReductionParallel)
Printer = isl_printer_print_str(Printer, " reduction");
Printer = isl_printer_end_line(Printer);
}
if (Info->IsOutermostParallel) {
Printer = isl_printer_start_line(Printer);
Printer = isl_printer_print_str(Printer, "#pragma omp parallel for");
if (Info->IsReductionParallel)
Printer = isl_printer_print_str(Printer, " reduction");
Printer = isl_printer_end_line(Printer);
}
}
return isl_ast_node_for_print(Node, Printer, PrintOptions);
/// @brief Print a string @p str in a single line using @p Printer.
static isl_printer *printLine(__isl_take isl_printer *Printer,
const std::string &str) {
Printer = isl_printer_start_line(Printer);
Printer = isl_printer_print_str(Printer, str.c_str());
return isl_printer_end_line(Printer);
}
// Print an isl_ast_for.
static __isl_give isl_printer *
printFor(__isl_take isl_printer *Printer,
__isl_take isl_ast_print_options *PrintOptions,
__isl_keep isl_ast_node *Node, void *User) {
isl_id *Id = isl_ast_node_get_annotation(Node);
if (!Id)
return isl_ast_node_for_print(Node, Printer, PrintOptions);
/// @brief Callback executed for each for node in the ast in order to print it.
static isl_printer *cbPrintFor(__isl_take isl_printer *Printer,
__isl_take isl_ast_print_options *Options,
__isl_keep isl_ast_node *Node, void *) {
if (IslAstInfo::isInnermostParallel(Node))
Printer = printLine(Printer, "#pragma simd");
IslAstUserPayload *Info = (IslAstUserPayload *)isl_id_get_user(Id);
Printer = printParallelFor(Node, Printer, PrintOptions, Info);
isl_id_free(Id);
return Printer;
if (IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node))
Printer = printLine(Printer, "#pragma simd reduction");
if (IslAstInfo::isOuterParallel(Node))
Printer = printLine(Printer, "#pragma omp parallel for");
if (!IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node))
Printer = printLine(Printer, "#pragma omp parallel for reduction");
return isl_ast_node_for_print(Node, Printer, Options);
}
/// @brief Check if the current scheduling dimension is parallel
@ -219,18 +207,16 @@ astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build,
IslAstUserPayload *Info = (IslAstUserPayload *)isl_id_get_user(Id);
AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User;
bool IsInnermost = (Id == BuildInfo->LastForNodeId);
Info->IsInnermost = (Id == BuildInfo->LastForNodeId);
if (Info) {
if (Info->IsOutermostParallel)
BuildInfo->InParallelFor = 0;
if (IsInnermost)
if (astScheduleDimIsParallel(Build, BuildInfo->Deps,
Info->IsReductionParallel))
Info->IsInnermostParallel = 1;
if (!Info->Build)
Info->Build = isl_ast_build_copy(Build);
}
if (Info->IsOutermostParallel)
BuildInfo->InParallelFor = 0;
if (Info->IsInnermost)
if (astScheduleDimIsParallel(Build, BuildInfo->Deps,
Info->IsReductionParallel))
Info->IsInnermostParallel = 1;
if (!Info->Build)
Info->Build = isl_ast_build_copy(Build);
isl_id_free(Id);
return Node;
@ -356,6 +342,11 @@ IslAstUserPayload *IslAstInfo::getNodePayload(__isl_keep isl_ast_node *Node) {
return Payload;
}
bool IslAstInfo::isInnermost(__isl_keep isl_ast_node *Node) {
IslAstUserPayload *Payload = getNodePayload(Node);
return Payload && Payload->IsInnermost;
}
bool IslAstInfo::isParallel(__isl_keep isl_ast_node *Node) {
return (isInnermostParallel(Node) || isOuterParallel(Node)) &&
!isReductionParallel(Node);
@ -391,7 +382,7 @@ void IslAstInfo::printScop(raw_ostream &OS) const {
Scop &S = getCurScop();
Options = isl_ast_print_options_alloc(S.getIslCtx());
Options = isl_ast_print_options_set_print_for(Options, printFor, nullptr);
Options = isl_ast_print_options_set_print_for(Options, cbPrintFor, nullptr);
isl_printer *P = isl_printer_to_str(S.getIslCtx());
P = isl_printer_print_ast_expr(P, RunCondition);

View File

@ -1,7 +1,6 @@
; RUN: opt %loadPolly -polly-ast -polly-ast-detect-parallel -analyze < %s | FileCheck %s
;
; CHECK: pragma simd reduction
; CHECK: pragma omp parallel for reduction
;
; int prod;
; void f() {