llvm-project/polly/lib/Transform/ScheduleTreeTransform.cpp

634 lines
24 KiB
C++

//===- polly/ScheduleTreeTransform.cpp --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Make changes to isl's schedule tree data structure.
//
//===----------------------------------------------------------------------===//
#include "polly/ScheduleTreeTransform.h"
#include "polly/Support/ISLTools.h"
#include "polly/Support/ScopHelper.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"
using namespace polly;
using namespace llvm;
namespace {
/// Recursively visit all nodes of a schedule tree while allowing changes.
///
/// The visit methods return an isl::schedule_node that is used to continue
/// visiting the tree. Structural changes such as returning a different node
/// will confuse the visitor.
///
/// The extra arguments @p args are handed to the visit of every child. They
/// are therefore passed on as lvalues instead of being forwarded: forwarding
/// the same pack once per child would move from value-typed arguments on the
/// first child and hand moved-from objects to all following siblings.
template <typename Derived, typename... Args>
struct ScheduleNodeRewriter
    : public RecursiveScheduleTreeVisitor<Derived, isl::schedule_node,
                                          Args...> {
  Derived &getDerived() { return *static_cast<Derived *>(this); }
  const Derived &getDerived() const {
    return *static_cast<const Derived *>(this);
  }

  /// Visit all children of @p Node in order. The (possibly modified) node is
  /// returned by walking back up from the last visited child.
  isl::schedule_node visitNode(const isl::schedule_node &Node, Args... args) {
    if (!Node.has_children())
      return Node;

    isl::schedule_node It = Node.first_child();
    while (true) {
      // Same arguments for every sibling: do not forward (possibly move) them.
      It = getDerived().visit(It, args...);
      if (!It.has_next_sibling())
        break;
      It = It.next_sibling();
    }
    return It.parent();
  }
};
/// Rewrite a schedule tree by reconstructing it bottom-up.
///
/// By default, the original schedule tree is reconstructed. To build a
/// different tree, redefine visitor methods in a derived class (CRTP).
///
/// Note that AST build options are not applied; Setting the isolate[] option
/// makes the schedule tree 'anchored' and cannot be modified afterwards. Hence,
/// AST build options must be set after the tree has been constructed.
///
/// Visitors of nodes with a single child forward the extra arguments @p args
/// to that child. Sequence and set nodes visit several children with the same
/// arguments and therefore pass them on as lvalues: forwarding the pack once
/// per child would move from value-typed arguments on the first visit.
template <typename Derived, typename... Args>
struct ScheduleTreeRewriter
    : public RecursiveScheduleTreeVisitor<Derived, isl::schedule, Args...> {
  Derived &getDerived() { return *static_cast<Derived *>(this); }
  const Derived &getDerived() const {
    return *static_cast<const Derived *>(this);
  }

  isl::schedule visitDomain(const isl::schedule_node &Node, Args... args) {
    // Every schedule_tree already has a domain node, no need to add one.
    return getDerived().visit(Node.first_child(), std::forward<Args>(args)...);
  }

  /// Rebuild the band on top of the rewritten child, copying the partial
  /// schedule, the permutable flag and per-member coincidence.
  isl::schedule visitBand(const isl::schedule_node &Band, Args... args) {
    isl::multi_union_pw_aff PartialSched =
        isl::manage(isl_schedule_node_band_get_partial_schedule(Band.get()));
    isl::schedule NewChild =
        getDerived().visit(Band.child(0), std::forward<Args>(args)...);
    isl::schedule_node NewNode =
        NewChild.insert_partial_schedule(PartialSched).get_root().get_child(0);

    // Reapply permutability and coincidence attributes.
    NewNode = isl::manage(isl_schedule_node_band_set_permutable(
        NewNode.release(), isl_schedule_node_band_get_permutable(Band.get())));
    unsigned BandDims = isl_schedule_node_band_n_member(Band.get());
    for (unsigned i = 0; i < BandDims; i += 1)
      NewNode = isl::manage(isl_schedule_node_band_member_set_coincident(
          NewNode.release(), i,
          isl_schedule_node_band_member_get_coincident(Band.get(), i)));

    return NewNode.get_schedule();
  }

  /// Rebuild a sequence by chaining the rewritten children in order.
  isl::schedule visitSequence(const isl::schedule_node &Sequence,
                              Args... args) {
    int NumChildren = isl_schedule_node_n_children(Sequence.get());
    // All children receive the same arguments; do not forward (move) them.
    isl::schedule Result = getDerived().visit(Sequence.child(0), args...);
    for (int i = 1; i < NumChildren; i += 1)
      Result = Result.sequence(getDerived().visit(Sequence.child(i), args...));
    return Result;
  }

  /// Rebuild a set node by combining the rewritten children with
  /// isl_schedule_set.
  isl::schedule visitSet(const isl::schedule_node &Set, Args... args) {
    int NumChildren = isl_schedule_node_n_children(Set.get());
    // All children receive the same arguments; do not forward (move) them.
    isl::schedule Result = getDerived().visit(Set.child(0), args...);
    for (int i = 1; i < NumChildren; i += 1)
      Result = isl::manage(isl_schedule_set(
          Result.release(),
          getDerived().visit(Set.child(i), args...).release()));
    return Result;
  }

  isl::schedule visitLeaf(const isl::schedule_node &Leaf, Args... args) {
    return isl::schedule::from_domain(Leaf.get_domain());
  }

  isl::schedule visitMark(const isl::schedule_node &Mark, Args... args) {
    isl::id TheMark = Mark.mark_get_id();
    isl::schedule_node NewChild =
        getDerived()
            .visit(Mark.first_child(), std::forward<Args>(args)...)
            .get_root()
            .first_child();
    return NewChild.insert_mark(TheMark).get_schedule();
  }

  isl::schedule visitExtension(const isl::schedule_node &Extension,
                               Args... args) {
    isl::union_map TheExtension = Extension.extension_get_extension();
    isl::schedule_node NewChild =
        getDerived()
            .visit(Extension.child(0), std::forward<Args>(args)...)
            .get_root()
            .first_child();
    isl::schedule_node NewExtension =
        isl::schedule_node::from_extension(TheExtension);
    return NewChild.graft_before(NewExtension).get_schedule();
  }

  isl::schedule visitFilter(const isl::schedule_node &Filter, Args... args) {
    isl::union_set FilterDomain = Filter.filter_get_filter();
    isl::schedule NewSchedule =
        getDerived().visit(Filter.child(0), std::forward<Args>(args)...);
    return NewSchedule.intersect_domain(FilterDomain);
  }

  isl::schedule visitNode(const isl::schedule_node &Node, Args... args) {
    llvm_unreachable("Not implemented");
  }
};
/// Rewrite a schedule tree to an equivalent one without extension nodes.
///
/// Each visit method takes two additional arguments:
///
/// * The new domain of the node, which is the inherited domain plus any
/// domains added by enclosing extension nodes.
///
/// * An output map, assigned by the visit, holding the extension maps of all
/// children; it is required by band nodes to schedule the additional domains
/// at the same position as the extension node would.
///
/// Domain and mark nodes are not overridden here; they use the reconstruction
/// of ScheduleTreeRewriter, which passes both arguments through unchanged.
///
struct ExtensionNodeRewriter
: public ScheduleTreeRewriter<ExtensionNodeRewriter, const isl::union_set &,
isl::union_map &> {
using BaseTy = ScheduleTreeRewriter<ExtensionNodeRewriter,
const isl::union_set &, isl::union_map &>;
BaseTy &getBase() { return *this; }
const BaseTy &getBase() const { return *this; }
/// Rewrite @p Schedule, starting at its root with the schedule's domain.
/// Every extension must have been absorbed by some band on the way back up,
/// hence no extensions may be left over at the root.
isl::schedule visitSchedule(const isl::schedule &Schedule) {
isl::union_map Extensions;
isl::schedule Result =
visit(Schedule.get_root(), Schedule.get_domain(), Extensions);
assert(Extensions && Extensions.is_empty());
return Result;
}
/// Rebuild a sequence from its rewritten children. The first child's
/// extensions are written to @p Extensions directly; those of the remaining
/// children are united into it.
isl::schedule visitSequence(const isl::schedule_node &Sequence,
const isl::union_set &Domain,
isl::union_map &Extensions) {
int NumChildren = isl_schedule_node_n_children(Sequence.get());
isl::schedule NewNode = visit(Sequence.first_child(), Domain, Extensions);
for (int i = 1; i < NumChildren; i += 1) {
isl::schedule_node OldChild = Sequence.child(i);
isl::union_map NewChildExtensions;
isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions);
NewNode = NewNode.sequence(NewChildNode);
Extensions = Extensions.unite(NewChildExtensions);
}
return NewNode;
}
/// Same as visitSequence, but combines the children with isl_schedule_set.
isl::schedule visitSet(const isl::schedule_node &Set,
const isl::union_set &Domain,
isl::union_map &Extensions) {
int NumChildren = isl_schedule_node_n_children(Set.get());
isl::schedule NewNode = visit(Set.first_child(), Domain, Extensions);
for (int i = 1; i < NumChildren; i += 1) {
isl::schedule_node OldChild = Set.child(i);
isl::union_map NewChildExtensions;
isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions);
NewNode = isl::manage(
isl_schedule_set(NewNode.release(), NewChildNode.release()));
Extensions = Extensions.unite(NewChildExtensions);
}
return NewNode;
}
/// A leaf reports no extensions. The rebuilt subtree covers @p Domain, which
/// already includes the instances added by enclosing extension nodes.
isl::schedule visitLeaf(const isl::schedule_node &Leaf,
const isl::union_set &Domain,
isl::union_map &Extensions) {
isl::ctx Ctx = Leaf.get_ctx();
Extensions = isl::union_map::empty(isl::space::params_alloc(Ctx, 0));
return isl::schedule::from_domain(Domain);
}
/// Rebuild a band and schedule the instances added by extension nodes below
/// it.
///
/// Each child extension map has ExtDims >= BandDims input dimensions. Its last
/// BandDims input dimensions are this band's schedule for the extended
/// instances and are added to the band's partial schedule. The leading
/// OuterDims (if any) are left for enclosing bands and reported through
/// @p OuterExtensions. Permutability and coincidence are copied from
/// @p OldNode.
isl::schedule visitBand(const isl::schedule_node &OldNode,
const isl::union_set &Domain,
isl::union_map &OuterExtensions) {
isl::schedule_node OldChild = OldNode.first_child();
isl::multi_union_pw_aff PartialSched =
isl::manage(isl_schedule_node_band_get_partial_schedule(OldNode.get()));
isl::union_map NewChildExtensions;
isl::schedule NewChild = visit(OldChild, Domain, NewChildExtensions);
// Add the extensions to the partial schedule.
OuterExtensions = isl::union_map::empty(NewChildExtensions.get_space());
isl::union_map NewPartialSchedMap = isl::union_map::from(PartialSched);
unsigned BandDims = isl_schedule_node_band_n_member(OldNode.get());
for (isl::map Ext : NewChildExtensions.get_map_list()) {
unsigned ExtDims = Ext.dim(isl::dim::in);
assert(ExtDims >= BandDims);
unsigned OuterDims = ExtDims - BandDims;
// Keep only this band's (trailing) dimensions and flip the map so it
// goes from the extended instances to their band schedule.
isl::map BandSched =
Ext.project_out(isl::dim::in, 0, OuterDims).reverse();
NewPartialSchedMap = NewPartialSchedMap.unite(BandSched);
// There might be more outer bands that have to schedule the extensions.
if (OuterDims > 0) {
isl::map OuterSched =
Ext.project_out(isl::dim::in, OuterDims, BandDims);
OuterExtensions = OuterExtensions.add_map(OuterSched);
}
}
isl::multi_union_pw_aff NewPartialSchedAsAsMultiUnionPwAff =
isl::multi_union_pw_aff::from_union_map(NewPartialSchedMap);
isl::schedule_node NewNode =
NewChild.insert_partial_schedule(NewPartialSchedAsAsMultiUnionPwAff)
.get_root()
.get_child(0);
// Reapply permutability and coincidence attributes.
NewNode = isl::manage(isl_schedule_node_band_set_permutable(
NewNode.release(),
isl_schedule_node_band_get_permutable(OldNode.get())));
for (unsigned i = 0; i < BandDims; i += 1) {
NewNode = isl::manage(isl_schedule_node_band_member_set_coincident(
NewNode.release(), i,
isl_schedule_node_band_member_get_coincident(OldNode.get(), i)));
}
return NewNode.get_schedule();
}
/// Drop the filter node itself; instead restrict the domain passed to the
/// child to the filter.
isl::schedule visitFilter(const isl::schedule_node &Filter,
const isl::union_set &Domain,
isl::union_map &Extensions) {
isl::union_set FilterDomain = Filter.filter_get_filter();
isl::union_set NewDomain = Domain.intersect(FilterDomain);
// A filter is added implicitly if necessary when joining schedule trees.
return visit(Filter.first_child(), NewDomain, Extensions);
}
/// Drop the extension node. The extended instances (range of the extension
/// map) are added to the child's domain. The extension map, united with the
/// child's extensions, is reported through @p Extensions so that an enclosing
/// band can schedule those instances.
isl::schedule visitExtension(const isl::schedule_node &Extension,
const isl::union_set &Domain,
isl::union_map &Extensions) {
isl::union_map ExtDomain = Extension.extension_get_extension();
isl::union_set NewDomain = Domain.unite(ExtDomain.range());
isl::union_map ChildExtensions;
isl::schedule NewChild =
visit(Extension.first_child(), NewDomain, ChildExtensions);
Extensions = ChildExtensions.unite(ExtDomain);
return NewChild;
}
};
/// Record the AST build options of every band of a schedule tree.
///
/// ScheduleTreeRewriter does not carry AST build options over to the bands it
/// creates. Visiting a tree with this collector stores the options of each
/// band in ASTBuildOptions, in the order the bands are visited, so that they
/// can be reattached afterwards (see ApplyASTBuildOptions).
struct CollectASTBuildOptions
    : public RecursiveScheduleTreeVisitor<CollectASTBuildOptions> {
  using BaseTy = RecursiveScheduleTreeVisitor<CollectASTBuildOptions>;
  BaseTy &getBase() { return *this; }
  const BaseTy &getBase() const { return *this; }

  /// One entry per visited band, in visitation order.
  llvm::SmallVector<isl::union_set, 8> ASTBuildOptions;

  void visitBand(const isl::schedule_node &Band) {
    isl::union_set BandOptions =
        isl::manage(isl_schedule_node_band_get_ast_build_options(Band.get()));
    ASTBuildOptions.push_back(BandOptions);

    // Continue into the band's subtree so nested bands are recorded as well.
    getBase().visitBand(Band);
  }
};
/// Apply AST build options to the bands in a schedule tree.
///
/// This rewrites a schedule tree with the AST build options applied. We assume
/// that the band nodes are visited in the same order as they were when the
/// build options were collected, typically by CollectASTBuildOptions.
struct ApplyASTBuildOptions
    : public ScheduleNodeRewriter<ApplyASTBuildOptions> {
  using BaseTy = ScheduleNodeRewriter<ApplyASTBuildOptions>;
  BaseTy &getBase() { return *this; }
  const BaseTy &getBase() const { return *this; }

  /// Index of the next entry of ASTBuildOptions to apply.
  size_t Pos = 0;

  /// Options to apply, one per band in visitation order. This is a
  /// non-owning view; the referenced storage must outlive this object.
  llvm::ArrayRef<isl::union_set> ASTBuildOptions;

  explicit ApplyASTBuildOptions(llvm::ArrayRef<isl::union_set> ASTBuildOptions)
      : ASTBuildOptions(ASTBuildOptions) {}

  /// Attach the stored options to the bands of @p Schedule and return the
  /// resulting schedule. All options must be consumed.
  isl::schedule visitSchedule(const isl::schedule &Schedule) {
    Pos = 0;
    isl::schedule Result = visit(Schedule).get_schedule();
    assert(Pos == ASTBuildOptions.size() &&
           "AST build options must match to band nodes");
    return Result;
  }

  isl::schedule_node visitBand(const isl::schedule_node &Band) {
    assert(Pos < ASTBuildOptions.size() &&
           "More band nodes than collected AST build options");
    isl::schedule_node Result =
        Band.band_set_ast_build_options(ASTBuildOptions[Pos]);
    Pos += 1;
    return getBase().visitBand(Result);
  }
};
/// Return whether the schedule contains an extension node.
static bool containsExtensionNode(isl::schedule Schedule) {
  assert(!Schedule.is_null());

  // Record the finding explicitly instead of inferring it from the traversal's
  // return value. Otherwise a traversal that fails for any other reason would
  // be mistaken for having found an extension node.
  bool Found = false;
  auto Callback = [](__isl_keep isl_schedule_node *Node,
                     void *User) -> isl_bool {
    if (isl_schedule_node_get_type(Node) == isl_schedule_node_extension) {
      *static_cast<bool *>(User) = true;
      // Stop walking the schedule tree.
      return isl_bool_error;
    }

    // Continue searching the subtree.
    return isl_bool_true;
  };
  // The traversal reports an error when the callback aborts it; that is the
  // expected way to stop early, so the result is not needed here.
  (void)isl_schedule_foreach_schedule_node_top_down(Schedule.get(), Callback,
                                                    &Found);
  return Found;
}
/// Look up the operand named @p Name in the loop metadata @p LoopMD.
///
/// Returns the operand if it exists and is an MDNode, nullptr otherwise.
static MDNode *findOptionalNodeOperand(MDNode *LoopMD, StringRef Name) {
  auto MaybeOperand = findMetadataOperand(LoopMD, Name);
  auto *Operand = MaybeOperand.getValueOr(nullptr);
  return dyn_cast_or_null<MDNode>(Operand);
}
/// Return true iff @p Node is a mark node.
static bool isMark(const isl::schedule_node &Node) {
  const isl_schedule_node_type Kind = isl_schedule_node_get_type(Node.get());
  return Kind == isl_schedule_node_mark;
}
#ifndef NDEBUG
// These helpers are only used inside assertions, so they are compiled only
// when assertions are enabled.

/// Return true iff @p Node is a band node.
static bool isBand(const isl::schedule_node &Node) {
  const isl_schedule_node_type Kind = isl_schedule_node_get_type(Node.get());
  return Kind == isl_schedule_node_band;
}

/// Return true iff @p Node is a band with exactly one member, i.e. a band that
/// could represent a single loop.
static bool isBandWithSingleLoop(const isl::schedule_node &Node) {
  if (!isBand(Node))
    return false;
  return isl_schedule_node_band_n_member(Node.get()) == 1;
}
#endif
/// Create an isl::id representing the output loop after a transformation.
///
/// Returns a null id when @p FollowupLoopMD is null, i.e. the generated loop
/// does not need to be identified. Otherwise a newly allocated BandAttr whose
/// Metadata is @p FollowupLoopMD is passed to getIslLoopAttr.
/// NOTE(review): presumably getIslLoopAttr takes ownership of the BandAttr
/// (freed together with the id) -- confirm there.
static isl::id createGeneratedLoopAttr(isl::ctx Ctx, MDNode *FollowupLoopMD) {
  // TODO: Append llvm.loop.disable_heuristics metadata unless overridden by
  // user followup-MD
  if (FollowupLoopMD == nullptr)
    return isl::id();

  auto *GeneratedAttr = new BandAttr();
  GeneratedAttr->Metadata = FollowupLoopMD;
  return getIslLoopAttr(Ctx, GeneratedAttr);
}
/// Return the outermost node of a loop.
///
/// A loop consists of a single-dimensional band and an optional loop-attribute
/// mark that wraps it. Starting from either the mark or the band, return the
/// mark if there is one, and the band itself otherwise.
static isl::schedule_node moveToBandMark(isl::schedule_node BandOrMark) {
  if (!isBandMark(BandOrMark)) {
    // Given the band: check whether its parent is the wrapping mark.
    assert(isBandWithSingleLoop(BandOrMark));
    isl::schedule_node Parent = BandOrMark.parent();
    return isBandMark(Parent) ? Parent : BandOrMark;
  }

  // Given the mark already.
  assert(isBandWithSingleLoop(BandOrMark.get_child(0)));
  return BandOrMark;
}
/// Strip the loop-attribute mark (if any) from the loop given by @p MarkOrBand.
///
/// @param MarkOrBand Either the loop's band or the mark wrapping it.
/// @param Attr       [out] The BandAttr of the removed mark, or nullptr if the
///                   loop had no mark.
/// @return The band node representing the loop.
static isl::schedule_node removeMark(isl::schedule_node MarkOrBand,
                                     BandAttr *&Attr) {
  isl::schedule_node Outer = moveToBandMark(MarkOrBand);

  if (!isMark(Outer)) {
    // Unmarked loop: nothing to remove.
    Attr = nullptr;
    assert(isBandWithSingleLoop(Outer));
    return Outer;
  }

  Attr = getLoopAttr(Outer.mark_get_id());
  // Deleting the mark leaves the node pointing at its former child, the band.
  isl::schedule_node Band =
      isl::manage(isl_schedule_node_delete(Outer.release()));
  assert(isBandWithSingleLoop(Band));
  return Band;
}
/// Remove the mark that wraps a loop, discarding its attribute. Return the band
/// representing the loop.
static isl::schedule_node removeMark(isl::schedule_node MarkOrBand) {
  BandAttr *IgnoredAttr = nullptr;
  return removeMark(MarkOrBand, IgnoredAttr);
}
/// Wrap @p Band, which must not already carry a loop mark, in a mark node with
/// id @p Mark. Return the band, which is now the mark's child.
static isl::schedule_node insertMark(isl::schedule_node Band, isl::id Mark) {
  assert(isBand(Band));
  assert(moveToBandMark(Band).is_equal(Band) &&
         "Don't add a two marks for a band");

  isl::schedule_node MarkNode = Band.insert_mark(Mark);
  return MarkNode.get_child(0);
}
/// Return the (one-dimensional) set of integers whose remainder modulo
/// @p Factor is @p Offset.
///
///  isDivisibleBySet(Ctx, 4, 0) = { [i] : i mod 4 = 0 }
///  isDivisibleBySet(Ctx, 4, 1) = { [i] : i mod 4 = 1 }
///
/// @p Factor must be positive (isl's mod requires a positive modulus); an
/// @p Offset outside [0, Factor) yields an empty set.
static isl::basic_set isDivisibleBySet(isl::ctx &Ctx, long Factor,
                                       long Offset) {
  isl::val ValFactor{Ctx, Factor};
  isl::val ValOffset{Ctx, Offset};

  // { [i] }
  isl::space Unispace{Ctx, 0, 1};
  isl::local_space LUnispace{Unispace};

  // { [i] -> [i mod Factor] }
  isl::aff Var = isl::aff::var_on_domain(LUnispace, isl::dim::out, 0);
  isl::aff Remainder = Var.mod(ValFactor);
  isl::basic_map RemainderMap = isl::basic_map::from_aff(Remainder);

  // { [i] -> [Offset] : i mod Factor = Offset }
  isl::basic_map Modulo = RemainderMap.fix_val(isl::dim::out, 0, ValOffset);
  return Modulo.domain();
}
} // namespace
/// Return true iff @p Node is a mark node whose id is a loop attribute.
bool polly::isBandMark(const isl::schedule_node &Node) {
  if (!isMark(Node))
    return false;
  return isLoopAttr(Node.mark_get_id());
}
/// Return the BandAttr of the loop given by @p MarkOrBand (its band or its
/// wrapping mark), or nullptr if the loop carries no mark.
BandAttr *polly::getBandAttr(isl::schedule_node MarkOrBand) {
  isl::schedule_node Outer = moveToBandMark(MarkOrBand);
  return isMark(Outer) ? getLoopAttr(Outer.mark_get_id()) : nullptr;
}
/// Return a schedule equivalent to @p Sched that contains no extension nodes.
/// If @p Sched has none to begin with, it is returned unchanged.
isl::schedule polly::hoistExtensionNodes(isl::schedule Sched) {
  if (!containsExtensionNode(Sched))
    return Sched;

  // AST build options can anchor schedule nodes so that the tree cannot be
  // restructured anymore. Remember them now, rebuild the tree without them and
  // attach them again afterwards.
  CollectASTBuildOptions Collector;
  Collector.visit(Sched);

  ExtensionNodeRewriter Rewriter;
  isl::schedule Hoisted = Rewriter.visitSchedule(Sched);

  // Options are matched to bands by visitation order, so the rewriter must not
  // reorder bands. Other node types are irrelevant here.
  ApplyASTBuildOptions Applicator(Collector.ASTBuildOptions);
  return Applicator.visitSchedule(Hoisted);
}
/// Fully unroll the loop given by @p BandToUnroll (its single-dimensional band
/// or the mark wrapping it).
///
/// The band is replaced by a sequence containing one filter per schedule value
/// the band takes on its domain, ordered by increasing value. The loop's mark,
/// if any, is dropped.
isl::schedule polly::applyFullUnroll(isl::schedule_node BandToUnroll) {
  isl::ctx Ctx = BandToUnroll.get_ctx();

  // Remove the loop's mark, the loop will disappear anyway.
  BandToUnroll = removeMark(BandToUnroll);
  assert(isBandWithSingleLoop(BandToUnroll));

  isl::multi_union_pw_aff PartialSched = isl::manage(
      isl_schedule_node_band_get_partial_schedule(BandToUnroll.get()));
  assert(PartialSched.dim(isl::dim::out) == 1 &&
         "Can only unroll a single dimension");
  isl::union_pw_aff PartialSchedUAff = PartialSched.get_union_pw_aff(0);

  isl::union_set Domain = BandToUnroll.get_domain();
  PartialSchedUAff = PartialSchedUAff.intersect_domain(Domain);
  isl::union_map PartialSchedUMap = isl::union_map(PartialSchedUAff);

  // Enumerate only the scatter elements.
  isl::union_set ScatterList = PartialSchedUMap.range();

  // Enumerate all loop iterations.
  // TODO: Diagnose if not enumerable or depends on a parameter.
  SmallVector<isl::point, 16> Elts;
  ScatterList.foreach_point([&Elts](isl::point P) -> isl::stat {
    Elts.push_back(P);
    return isl::stat::ok();
  });

  // Don't assume that foreach_point returns in execution order. Compare by
  // reference to avoid copying (and reference counting) both points on every
  // comparison.
  llvm::sort(Elts, [](const isl::point &P1, const isl::point &P2) -> bool {
    isl::val C1 = P1.get_coordinate_val(isl::dim::set, 0);
    isl::val C2 = P2.get_coordinate_val(isl::dim::set, 0);
    return C1.lt(C2);
  });

  // Convert the points to a sequence of filters.
  isl::union_set_list List = isl::union_set_list::alloc(Ctx, Elts.size());
  for (const isl::point &P : Elts) {
    // Determine the domains that map this scatter element.
    isl::union_set DomainFilter = PartialSchedUMap.intersect_range(P).domain();
    List = List.add(DomainFilter);
  }

  // Replace original band with unrolled sequence.
  isl::schedule_node Body =
      isl::manage(isl_schedule_node_delete(BandToUnroll.release()));
  Body = Body.insert_sequence(List);
  return Body.get_schedule();
}
/// Partially unroll the loop given by @p BandToUnroll (its single-dimensional
/// band or the mark wrapping it) by @p Factor.
///
/// The band with schedule x is replaced by a band with schedule
/// floor(x / Factor) * Factor. Its body is a sequence of @p Factor filters;
/// the i-th selects the instances with x mod Factor = i. If the original mark
/// carries followup-unrolled metadata, the new band is marked with it.
isl::schedule polly::applyPartialUnroll(isl::schedule_node BandToUnroll,
                                        int Factor) {
  assert(Factor > 0 && "Positive unroll factor required");
  isl::ctx Ctx = BandToUnroll.get_ctx();

  // Remove the mark, save the attribute for later use.
  BandAttr *Attr = nullptr;
  BandToUnroll = removeMark(BandToUnroll, Attr);
  assert(isBandWithSingleLoop(BandToUnroll));

  isl::multi_union_pw_aff PartialSched = isl::manage(
      isl_schedule_node_band_get_partial_schedule(BandToUnroll.get()));

  // { Stmt[] -> [x] }
  isl::union_pw_aff PartialSchedUAff = PartialSched.get_union_pw_aff(0);

  // Here we assume the schedule stride is one and starts with 0, which is not
  // necessarily the case.
  isl::union_pw_aff StridedPartialSchedUAff =
      isl::union_pw_aff::empty(PartialSchedUAff.get_space());
  isl::val ValFactor{Ctx, Factor};
  PartialSchedUAff.foreach_pw_aff([&StridedPartialSchedUAff,
                                   &ValFactor](isl::pw_aff PwAff) -> isl::stat {
    isl::space Space = PwAff.get_space();
    isl::set Universe = isl::set::universe(Space.domain());
    isl::pw_aff AffFactor{Universe, ValFactor};
    isl::pw_aff DivSchedAff = PwAff.div(AffFactor).floor().mul(AffFactor);
    StridedPartialSchedUAff = StridedPartialSchedUAff.union_add(DivSchedAff);
    return isl::stat::ok();
  });

  // { Stmt[] -> [x] }
  // Loop-invariant: the same map is used for every filter, so build it once.
  isl::union_map UMap{PartialSchedUAff};

  isl::union_set_list List = isl::union_set_list::alloc(Ctx, Factor);
  for (auto i : seq<int>(0, Factor)) {
    // { [x] : x mod Factor = i }
    isl::basic_set Divisible = isDivisibleBySet(Ctx, Factor, i);

    // { Stmt[] }
    isl::union_set UnrolledDomain = UMap.intersect_range(Divisible).domain();
    List = List.add(UnrolledDomain);
  }

  // BandToUnroll is not used anymore; hand it over instead of copying it.
  isl::schedule_node Body =
      isl::manage(isl_schedule_node_delete(BandToUnroll.release()));
  Body = Body.insert_sequence(List);
  isl::schedule_node NewLoop =
      Body.insert_partial_schedule(StridedPartialSchedUAff);

  MDNode *FollowupMD = nullptr;
  if (Attr && Attr->Metadata)
    FollowupMD =
        findOptionalNodeOperand(Attr->Metadata, LLVMLoopUnrollFollowupUnrolled);

  isl::id NewBandId = createGeneratedLoopAttr(Ctx, FollowupMD);
  if (NewBandId)
    NewLoop = insertMark(NewLoop, NewBandId);

  return NewLoop.get_schedule();
}