forked from OSchip/llvm-project
[MLIR][FlatAffineConstraints] Remove duplicate divisions while merging local ids
This patch implements detecting duplicate local identifiers by extracting their division representation while merging local identifiers. For example, given the FACs A, B: ``` A: (x, y)[s0] : (exists d0 = [x / 4], d1 = [y / 4]: d0 <= s0, d1 <= s0, x + y >= 2) B: (x, y)[s0] : (exists d0 = [x / 4], d1 = [y / 4]: d0 <= s0, d1 <= s0, x + y >= 5) ``` The intersection of A and B without this patch would lead to the following FAC: ``` (x, y)[s0] : (exists d0 = [x / 4], d1 = [y / 4], d2 = [x / 4], d3 = [x / 4]: d0 <= s0, d1 <= s0, d2 <= s0, d3 <= s0, x + y >= 2, x + y >= 5) ``` after this patch, merging of local ids will detect that `d0 = d2` and `d1 = d3`, and the intersection of these two FACs will be (after removing duplicate constraints): ``` (x, y)[s0] : (exists d0 = [x / 4], d1 = [y / 4] : d0 <= s0, d1 <= s0, x + y >= 2, x + y >= 5) ``` This reduces the number of constraints by 2 (constraints) + 4 (2 constraints for each extra division) for this case. This is used to reduce the output size representation of operations like PresburgerSet::subtract, PresburgerSet::intersect which require merging local variables. Reviewed By: arjunp, bondhugula Differential Revision: https://reviews.llvm.org/D112867
This commit is contained in:
parent
cff427ee20
commit
d257f7c1bf
|
@ -441,10 +441,16 @@ public:
|
|||
/// variables.
|
||||
void convertDimToLocal(unsigned dimStart, unsigned dimLimit);
|
||||
|
||||
/// Merge local ids of `this` and `other`. This is done by appending local ids
|
||||
/// of `other` to `this` and inserting local ids of `this` to `other` at start
|
||||
/// of its local ids. Number of dimension and symbol ids should match in
|
||||
/// `this` and `other`.
|
||||
/// Adds additional local ids to the sets such that they both have the union
|
||||
/// of the local ids in each set, without changing the set of points that
|
||||
/// lie in `this` and `other`. The ordering of the local ids in the
|
||||
/// sets may also be changed. After merging, if the `i^th` local variable in
|
||||
/// one set has a known division representation, then the `i^th` local
|
||||
/// variable in the other set either has the same division representation or
|
||||
/// no known division representation.
|
||||
///
|
||||
/// The number of dimensions and symbol ids in `this` and `other` should
|
||||
/// match.
|
||||
void mergeLocalIds(FlatAffineConstraints &other);
|
||||
|
||||
/// Removes all equalities and inequalities.
|
||||
|
@ -819,8 +825,8 @@ public:
|
|||
/// constraint systems are updated so that they have the union of all
|
||||
/// identifiers, with `this`'s original identifiers appearing first followed
|
||||
/// by any of `other`'s identifiers that didn't appear in `this`. Local
|
||||
/// identifiers of each system are by design separate/local and are placed
|
||||
/// one after other (`this`'s followed by `other`'s).
|
||||
/// identifiers in `other` that have the same division representation as local
|
||||
/// identifiers in `this` are merged into one.
|
||||
// E.g.: Input: `this` has (%i, %j) [%M, %N]
|
||||
// `other` has (%k, %j) [%P, %N, %M]
|
||||
// Output: both `this`, `other` have (%i, %j, %k) [%M, %N, %P]
|
||||
|
|
|
@ -493,8 +493,8 @@ static bool LLVM_ATTRIBUTE_UNUSED areIdsUnique(
|
|||
/// dimension-wise and symbol-wise unique; both constraint systems are updated
|
||||
/// so that they have the union of all identifiers, with A's original
|
||||
/// identifiers appearing first followed by any of B's identifiers that didn't
|
||||
/// appear in A. Local identifiers of each system are by design separate/local
|
||||
/// and are placed one after other (A's followed by B's).
|
||||
/// appear in A. Local identifiers in B that have the same division
|
||||
/// representation as local identifiers in A are merged into one.
|
||||
// E.g.: Input: A has ((%i, %j) [%M, %N]) and B has (%k, %j) [%P, %N, %M])
|
||||
// Output: both A, B have (%i, %j, %k) [%M, %N, %P]
|
||||
static void mergeAndAlignIds(unsigned offset, FlatAffineValueConstraints *a,
|
||||
|
@ -1918,18 +1918,108 @@ void FlatAffineConstraints::removeRedundantConstraints() {
|
|||
equalities.resizeVertically(pos);
|
||||
}
|
||||
|
||||
/// Merge local ids of `this` and `other`. This is done by appending local ids
|
||||
/// of `other` to `this` and inserting local ids of `this` to `other` at start
|
||||
/// of its local ids. Number of dimension and symbol ids should match in
|
||||
/// `this` and `other`.
|
||||
/// Eliminate `pos2^th` local identifier, replacing its every instance with
|
||||
/// `pos1^th` local identifier. This function is intended to be used to remove
|
||||
/// redundancy when local variables at position `pos1` and `pos2` are restricted
|
||||
/// to have the same value.
|
||||
static void eliminateRedundantLocalId(FlatAffineConstraints &fac, unsigned pos1,
|
||||
unsigned pos2) {
|
||||
|
||||
assert(pos1 < fac.getNumLocalIds() && "Invalid local id position");
|
||||
assert(pos2 < fac.getNumLocalIds() && "Invalid local id position");
|
||||
|
||||
unsigned localOffset = fac.getNumDimAndSymbolIds();
|
||||
pos1 += localOffset;
|
||||
pos2 += localOffset;
|
||||
for (unsigned i = 0, e = fac.getNumInequalities(); i < e; ++i)
|
||||
fac.atIneq(i, pos1) += fac.atIneq(i, pos2);
|
||||
for (unsigned i = 0, e = fac.getNumEqualities(); i < e; ++i)
|
||||
fac.atEq(i, pos1) += fac.atEq(i, pos2);
|
||||
fac.removeId(pos2);
|
||||
}
|
||||
|
||||
/// Adds additional local ids to the sets such that they both have the union
|
||||
/// of the local ids in each set, without changing the set of points that
|
||||
/// lie in `this` and `other`.
|
||||
///
|
||||
/// To detect local ids that always take the same in both sets, each local id is
|
||||
/// represented as a floordiv with constant denominator in terms of other ids.
|
||||
/// After extracting these divisions, local ids with the same division
|
||||
/// representation are considered duplicate and are merged. It is possible that
|
||||
/// division representation for some local id cannot be obtained, and thus these
|
||||
/// local ids are not considered for detecting duplicates.
|
||||
void FlatAffineConstraints::mergeLocalIds(FlatAffineConstraints &other) {
|
||||
assert(getNumDimIds() == other.getNumDimIds() &&
|
||||
"Number of dimension ids should match");
|
||||
assert(getNumSymbolIds() == other.getNumSymbolIds() &&
|
||||
"Number of symbol ids should match");
|
||||
unsigned initLocals = getNumLocalIds();
|
||||
insertLocalId(getNumLocalIds(), other.getNumLocalIds());
|
||||
other.insertLocalId(0, initLocals);
|
||||
|
||||
FlatAffineConstraints &facA = *this;
|
||||
FlatAffineConstraints &facB = other;
|
||||
|
||||
// Merge local ids of facA and facB without using division information,
|
||||
// i.e. append local ids of `facB` to `facA` and insert local ids of `facA`
|
||||
// to `facB` at start of its local ids.
|
||||
unsigned initLocals = facA.getNumLocalIds();
|
||||
insertLocalId(facA.getNumLocalIds(), facB.getNumLocalIds());
|
||||
facB.insertLocalId(0, initLocals);
|
||||
|
||||
// Get division representations from each FAC.
|
||||
std::vector<SmallVector<int64_t, 8>> divsA, divsB;
|
||||
SmallVector<unsigned, 4> denomsA, denomsB;
|
||||
facA.getLocalReprs(divsA, denomsA);
|
||||
facB.getLocalReprs(divsB, denomsB);
|
||||
|
||||
// Copy division information for facB into `divsA` and `denomsA`, so that
|
||||
// these have the combined division information of both FACs. Since newly
|
||||
// added local variables in facA and facB have no constraints, they will not
|
||||
// have any division representation.
|
||||
std::copy(divsB.begin() + initLocals, divsB.end(),
|
||||
divsA.begin() + initLocals);
|
||||
std::copy(denomsB.begin() + initLocals, denomsB.end(),
|
||||
denomsA.begin() + initLocals);
|
||||
|
||||
// Find and merge duplicate divisions.
|
||||
// TODO: Add division normalization to support divisions that differ by
|
||||
// a constant.
|
||||
// TODO: Add division ordering such that a division representation for local
|
||||
// identifier at position `i` only depends on local identifiers at position <
|
||||
// `i`. This would make sure that all divisions depending on other local
|
||||
// variables that can be merged, are merged.
|
||||
unsigned localOffset = getIdKindOffset(IdKind::Local);
|
||||
for (unsigned i = 0; i < divsA.size(); ++i) {
|
||||
// Check if a division representation exists for the `i^th` local id.
|
||||
if (denomsA[i] == 0)
|
||||
continue;
|
||||
// Check if a division exists which is a duplicate of the division at `i`.
|
||||
for (unsigned j = i + 1; j < divsA.size(); ++j) {
|
||||
// Check if a division representation exists for the `j^th` local id.
|
||||
if (denomsA[j] == 0)
|
||||
continue;
|
||||
// Check if the denominators match.
|
||||
if (denomsA[i] != denomsA[j])
|
||||
continue;
|
||||
// Check if the representations are equal.
|
||||
if (divsA[i] != divsA[j])
|
||||
continue;
|
||||
|
||||
// Merge divisions at position `j` into division at position `i`.
|
||||
eliminateRedundantLocalId(facA, i, j);
|
||||
eliminateRedundantLocalId(facB, i, j);
|
||||
for (unsigned k = 0, g = divsA.size(); k < g; ++k) {
|
||||
SmallVector<int64_t, 8> &div = divsA[k];
|
||||
if (denomsA[k] != 0) {
|
||||
div[localOffset + i] += div[localOffset + j];
|
||||
div.erase(div.begin() + localOffset + j);
|
||||
}
|
||||
}
|
||||
|
||||
divsA.erase(divsA.begin() + j);
|
||||
denomsA.erase(denomsA.begin() + j);
|
||||
// Since `j` can never be zero, we do not need to worry about overflows.
|
||||
--j;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes local variables using equalities. Each equality is checked if it
|
||||
|
|
|
@ -809,4 +809,127 @@ TEST(FlatAffineConstraintsTest, simplifyLocalsTest) {
|
|||
EXPECT_TRUE(fac3.isEmpty());
|
||||
}
|
||||
|
||||
TEST(FlatAffineConstraintsTest, mergeDivisionsSimple) {
|
||||
{
|
||||
// (x) : (exists z, y = [x / 2] : x = 3y and x + z + 1 >= 0).
|
||||
FlatAffineConstraints fac1(1, 0, 1);
|
||||
fac1.addLocalFloorDiv({1, 0, 0}, 2); // y = [x / 2].
|
||||
fac1.addEquality({1, 0, -3, 0}); // x = 3y.
|
||||
fac1.addInequality({1, 1, 0, 1}); // x + z + 1 >= 0.
|
||||
|
||||
// (x) : (exists y = [x / 2], z : x = 5y).
|
||||
FlatAffineConstraints fac2(1);
|
||||
fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2].
|
||||
fac2.addEquality({1, -5, 0}); // x = 5y.
|
||||
fac2.appendLocalId(); // Add local id z.
|
||||
|
||||
fac1.mergeLocalIds(fac2);
|
||||
|
||||
// Local space should be same.
|
||||
EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
|
||||
|
||||
// 1 division should be matched + 2 unmatched local ids.
|
||||
EXPECT_EQ(fac1.getNumLocalIds(), 3u);
|
||||
EXPECT_EQ(fac2.getNumLocalIds(), 3u);
|
||||
}
|
||||
|
||||
{
|
||||
// (x) : (exists z = [x / 5], y = [x / 2] : x = 3y).
|
||||
FlatAffineConstraints fac1(1);
|
||||
fac1.addLocalFloorDiv({1, 0}, 5); // z = [x / 5].
|
||||
fac1.addLocalFloorDiv({1, 0, 0}, 2); // y = [x / 2].
|
||||
fac1.addEquality({1, 0, -3, 0}); // x = 3y.
|
||||
|
||||
// (x) : (exists y = [x / 2], z = [x / 5]: x = 5z).
|
||||
FlatAffineConstraints fac2(1);
|
||||
fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2].
|
||||
fac2.addLocalFloorDiv({1, 0, 0}, 5); // z = [x / 5].
|
||||
fac2.addEquality({1, 0, -5, 0}); // x = 5z.
|
||||
|
||||
fac1.mergeLocalIds(fac2);
|
||||
|
||||
// Local space should be same.
|
||||
EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
|
||||
|
||||
// 2 divisions should be matched.
|
||||
EXPECT_EQ(fac1.getNumLocalIds(), 2u);
|
||||
EXPECT_EQ(fac2.getNumLocalIds(), 2u);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FlatAffineConstraintsTest, mergeDivisionsNestedDivsions) {
|
||||
{
|
||||
// (x) : (exists y = [x / 2], z = [x + y / 3]: y + z >= x).
|
||||
FlatAffineConstraints fac1(1);
|
||||
fac1.addLocalFloorDiv({1, 0}, 2); // y = [x / 2].
|
||||
fac1.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3].
|
||||
fac1.addInequality({-1, 1, 1, 0}); // y + z >= x.
|
||||
|
||||
// (x) : (exists y = [x / 2], z = [x + y / 3]: y + z <= x).
|
||||
FlatAffineConstraints fac2(1);
|
||||
fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2].
|
||||
fac2.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3].
|
||||
fac2.addInequality({1, -1, -1, 0}); // y + z <= x.
|
||||
|
||||
fac1.mergeLocalIds(fac2);
|
||||
|
||||
// Local space should be same.
|
||||
EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
|
||||
|
||||
// 2 divisions should be matched.
|
||||
EXPECT_EQ(fac1.getNumLocalIds(), 2u);
|
||||
EXPECT_EQ(fac2.getNumLocalIds(), 2u);
|
||||
}
|
||||
|
||||
{
|
||||
// (x) : (exists y = [x / 2], z = [x + y / 3], w = [z + 1 / 5]: y + z >= x).
|
||||
FlatAffineConstraints fac1(1);
|
||||
fac1.addLocalFloorDiv({1, 0}, 2); // y = [x / 2].
|
||||
fac1.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3].
|
||||
fac1.addLocalFloorDiv({0, 0, 1, 1}, 5); // w = [z + 1 / 5].
|
||||
fac1.addInequality({-1, 1, 1, 0, 0}); // y + z >= x.
|
||||
|
||||
// (x) : (exists y = [x / 2], z = [x + y / 3], w = [z + 1 / 5]: y + z <= x).
|
||||
FlatAffineConstraints fac2(1);
|
||||
fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2].
|
||||
fac2.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3].
|
||||
fac2.addLocalFloorDiv({0, 0, 1, 1}, 5); // w = [z + 1 / 5].
|
||||
fac2.addInequality({1, -1, -1, 0, 0}); // y + z <= x.
|
||||
|
||||
fac1.mergeLocalIds(fac2);
|
||||
|
||||
// Local space should be same.
|
||||
EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
|
||||
|
||||
// 3 divisions should be matched.
|
||||
EXPECT_EQ(fac1.getNumLocalIds(), 3u);
|
||||
EXPECT_EQ(fac2.getNumLocalIds(), 3u);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FlatAffineConstraintsTest, mergeDivisionsConstants) {
|
||||
{
|
||||
// (x) : (exists y = [x + 1 / 3], z = [x + 2 / 3]: y + z >= x).
|
||||
FlatAffineConstraints fac1(1);
|
||||
fac1.addLocalFloorDiv({1, 1}, 2); // y = [x + 1 / 2].
|
||||
fac1.addLocalFloorDiv({1, 0, 2}, 3); // z = [x + 2 / 3].
|
||||
fac1.addInequality({-1, 1, 1, 0}); // y + z >= x.
|
||||
|
||||
// (x) : (exists y = [x + 1 / 3], z = [x + 2 / 3]: y + z <= x).
|
||||
FlatAffineConstraints fac2(1);
|
||||
fac2.addLocalFloorDiv({1, 1}, 2); // y = [x + 1 / 2].
|
||||
fac2.addLocalFloorDiv({1, 0, 2}, 3); // z = [x + 2 / 3].
|
||||
fac2.addInequality({1, -1, -1, 0}); // y + z <= x.
|
||||
|
||||
fac1.mergeLocalIds(fac2);
|
||||
|
||||
// Local space should be same.
|
||||
EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
|
||||
|
||||
// 2 divisions should be matched.
|
||||
EXPECT_EQ(fac1.getNumLocalIds(), 2u);
|
||||
EXPECT_EQ(fac2.getNumLocalIds(), 2u);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
|
Loading…
Reference in New Issue