forked from OSchip/llvm-project
[Refactor] Use nicer print callback function in IslAst
llvm-svn: 214447
This commit is contained in:
parent
ef940aaf07
commit
0eefb0258f
|
@ -46,12 +46,16 @@ public:
|
|||
struct IslAstUserPayload {
|
||||
/// @brief Construct and initialize the payload.
|
||||
IslAstUserPayload()
|
||||
: IsInnermostParallel(false), IsOutermostParallel(false),
|
||||
IsReductionParallel(false), Build(nullptr) {}
|
||||
: IsInnermost(false), IsInnermostParallel(false),
|
||||
IsOutermostParallel(false), IsReductionParallel(false),
|
||||
Build(nullptr) {}
|
||||
|
||||
/// @brief Cleanup all isl structs on destruction.
|
||||
~IslAstUserPayload();
|
||||
|
||||
/// @brief Flag to mark innermost loops.
|
||||
bool IsInnermost;
|
||||
|
||||
/// @brief Flag to mark innermost parallel loops.
|
||||
bool IsInnermostParallel;
|
||||
|
||||
|
@ -97,6 +101,9 @@ public:
|
|||
/// @brief Get the complete payload attached to @p Node.
|
||||
static IslAstUserPayload *getNodePayload(__isl_keep isl_ast_node *Node);
|
||||
|
||||
/// @brief Is this loop an innermost loop?
|
||||
static bool isInnermost(__isl_keep isl_ast_node *Node);
|
||||
|
||||
/// @brief Is this loop a parallel loop?
|
||||
static bool isParallel(__isl_keep isl_ast_node *Node);
|
||||
|
||||
|
|
|
@ -100,43 +100,31 @@ struct AstBuildUserInfo {
|
|||
isl_id *LastForNodeId;
|
||||
};
|
||||
|
||||
// Print a loop annotated with OpenMP or vector pragmas.
|
||||
static __isl_give isl_printer *
|
||||
printParallelFor(__isl_keep isl_ast_node *Node, __isl_take isl_printer *Printer,
|
||||
__isl_take isl_ast_print_options *PrintOptions,
|
||||
IslAstUserPayload *Info) {
|
||||
if (Info) {
|
||||
if (Info->IsInnermostParallel) {
|
||||
Printer = isl_printer_start_line(Printer);
|
||||
Printer = isl_printer_print_str(Printer, "#pragma simd");
|
||||
if (Info->IsReductionParallel)
|
||||
Printer = isl_printer_print_str(Printer, " reduction");
|
||||
Printer = isl_printer_end_line(Printer);
|
||||
}
|
||||
if (Info->IsOutermostParallel) {
|
||||
Printer = isl_printer_start_line(Printer);
|
||||
Printer = isl_printer_print_str(Printer, "#pragma omp parallel for");
|
||||
if (Info->IsReductionParallel)
|
||||
Printer = isl_printer_print_str(Printer, " reduction");
|
||||
Printer = isl_printer_end_line(Printer);
|
||||
}
|
||||
}
|
||||
return isl_ast_node_for_print(Node, Printer, PrintOptions);
|
||||
/// @brief Print a string @p str in a single line using @p Printer.
|
||||
static isl_printer *printLine(__isl_take isl_printer *Printer,
|
||||
const std::string &str) {
|
||||
Printer = isl_printer_start_line(Printer);
|
||||
Printer = isl_printer_print_str(Printer, str.c_str());
|
||||
return isl_printer_end_line(Printer);
|
||||
}
|
||||
|
||||
// Print an isl_ast_for.
|
||||
static __isl_give isl_printer *
|
||||
printFor(__isl_take isl_printer *Printer,
|
||||
__isl_take isl_ast_print_options *PrintOptions,
|
||||
__isl_keep isl_ast_node *Node, void *User) {
|
||||
isl_id *Id = isl_ast_node_get_annotation(Node);
|
||||
if (!Id)
|
||||
return isl_ast_node_for_print(Node, Printer, PrintOptions);
|
||||
/// @brief Callback executed for each for node in the ast in order to print it.
|
||||
static isl_printer *cbPrintFor(__isl_take isl_printer *Printer,
|
||||
__isl_take isl_ast_print_options *Options,
|
||||
__isl_keep isl_ast_node *Node, void *) {
|
||||
if (IslAstInfo::isInnermostParallel(Node))
|
||||
Printer = printLine(Printer, "#pragma simd");
|
||||
|
||||
IslAstUserPayload *Info = (IslAstUserPayload *)isl_id_get_user(Id);
|
||||
Printer = printParallelFor(Node, Printer, PrintOptions, Info);
|
||||
isl_id_free(Id);
|
||||
return Printer;
|
||||
if (IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node))
|
||||
Printer = printLine(Printer, "#pragma simd reduction");
|
||||
|
||||
if (IslAstInfo::isOuterParallel(Node))
|
||||
Printer = printLine(Printer, "#pragma omp parallel for");
|
||||
|
||||
if (!IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node))
|
||||
Printer = printLine(Printer, "#pragma omp parallel for reduction");
|
||||
|
||||
return isl_ast_node_for_print(Node, Printer, Options);
|
||||
}
|
||||
|
||||
/// @brief Check if the current scheduling dimension is parallel
|
||||
|
@ -219,18 +207,16 @@ astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build,
|
|||
IslAstUserPayload *Info = (IslAstUserPayload *)isl_id_get_user(Id);
|
||||
AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User;
|
||||
|
||||
bool IsInnermost = (Id == BuildInfo->LastForNodeId);
|
||||
Info->IsInnermost = (Id == BuildInfo->LastForNodeId);
|
||||
|
||||
if (Info) {
|
||||
if (Info->IsOutermostParallel)
|
||||
BuildInfo->InParallelFor = 0;
|
||||
if (IsInnermost)
|
||||
if (astScheduleDimIsParallel(Build, BuildInfo->Deps,
|
||||
Info->IsReductionParallel))
|
||||
Info->IsInnermostParallel = 1;
|
||||
if (!Info->Build)
|
||||
Info->Build = isl_ast_build_copy(Build);
|
||||
}
|
||||
if (Info->IsOutermostParallel)
|
||||
BuildInfo->InParallelFor = 0;
|
||||
if (Info->IsInnermost)
|
||||
if (astScheduleDimIsParallel(Build, BuildInfo->Deps,
|
||||
Info->IsReductionParallel))
|
||||
Info->IsInnermostParallel = 1;
|
||||
if (!Info->Build)
|
||||
Info->Build = isl_ast_build_copy(Build);
|
||||
|
||||
isl_id_free(Id);
|
||||
return Node;
|
||||
|
@ -356,6 +342,11 @@ IslAstUserPayload *IslAstInfo::getNodePayload(__isl_keep isl_ast_node *Node) {
|
|||
return Payload;
|
||||
}
|
||||
|
||||
bool IslAstInfo::isInnermost(__isl_keep isl_ast_node *Node) {
|
||||
IslAstUserPayload *Payload = getNodePayload(Node);
|
||||
return Payload && Payload->IsInnermost;
|
||||
}
|
||||
|
||||
bool IslAstInfo::isParallel(__isl_keep isl_ast_node *Node) {
|
||||
return (isInnermostParallel(Node) || isOuterParallel(Node)) &&
|
||||
!isReductionParallel(Node);
|
||||
|
@ -391,7 +382,7 @@ void IslAstInfo::printScop(raw_ostream &OS) const {
|
|||
|
||||
Scop &S = getCurScop();
|
||||
Options = isl_ast_print_options_alloc(S.getIslCtx());
|
||||
Options = isl_ast_print_options_set_print_for(Options, printFor, nullptr);
|
||||
Options = isl_ast_print_options_set_print_for(Options, cbPrintFor, nullptr);
|
||||
|
||||
isl_printer *P = isl_printer_to_str(S.getIslCtx());
|
||||
P = isl_printer_print_ast_expr(P, RunCondition);
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
; RUN: opt %loadPolly -polly-ast -polly-ast-detect-parallel -analyze < %s | FileCheck %s
|
||||
;
|
||||
; CHECK: pragma simd reduction
|
||||
; CHECK: pragma omp parallel for reduction
|
||||
;
|
||||
; int prod;
|
||||
; void f() {
|
||||
|
|
Loading…
Reference in New Issue