[MLIR][Presburger] Add applyDomain/Range to IntegerRelation

This patch adds support for applying a relation on domain/range of a relation.

Reviewed By: arjunp, ftynse

Differential Revision: https://reviews.llvm.org/D126339
This commit is contained in:
Groverkss 2022-05-29 02:06:11 +05:30
parent 6abce17fc2
commit dac27da7b9
3 changed files with 103 additions and 6 deletions

View File

@ -459,10 +459,16 @@ public:
void removeDuplicateDivs();
/// Converts identifiers of kind srcKind in the range [idStart, idLimit) to
/// variables of kind dstKind and placed after all the other variables of kind
/// dstKind. The internal ordering among the moved variables is preserved.
/// variables of kind dstKind. If `pos` is given, the variables are placed at
/// position `pos` of dstKind, otherwise they are placed after all the other
/// variables of kind dstKind. The internal ordering among the moved variables
/// is preserved.
void convertIdKind(IdKind srcKind, unsigned idStart, unsigned idLimit,
IdKind dstKind);
IdKind dstKind, unsigned pos);
void convertIdKind(IdKind srcKind, unsigned idStart, unsigned idLimit,
IdKind dstKind) {
convertIdKind(srcKind, idStart, idLimit, dstKind, getNumIdKind(dstKind));
}
void convertToLocal(IdKind kind, unsigned idStart, unsigned idLimit) {
convertIdKind(kind, idStart, idLimit, IdKind::Local);
}
@ -523,6 +529,32 @@ public:
/// modifies R to be B -> A.
void inverse();
/// Let the relation `this` be R1, and the relation `rel` be R2. Modifies R1
/// to be the composition of R1 and R2: R1;R2.
///
/// Formally, if R1: A -> B, and R2: B -> C, then this function returns a
/// relation R3: A -> C such that a point (a, c) belongs to R3 iff there
/// exists b such that (a, b) is in R1 and, (b, c) is in R2.
void compose(const IntegerRelation &rel);
/// Given a relation `rel`, apply the relation to the domain of this relation.
///
/// R1: i -> j : (0 <= i < 2, j = i)
/// R2: i -> k : (k = i floordiv 2)
/// R3: k -> j : (0 <= k < 1, 2k <= j <= 2k + 1)
///
/// R1 = {(0, 0), (1, 1)}. R2 maps both 0 and 1 to 0.
/// So R3 = {(0, 0), (0, 1)}.
///
/// Formally, R1.applyDomain(R2) = R2.inverse().compose(R1).
void applyDomain(const IntegerRelation &rel);
/// Given a relation `rel`, apply the relation to the range of this relation.
///
/// Formally, R1.applyRange(R2) is the same as R1.compose(R2) but we provide
/// this for uniformity with `applyDomain`.
void applyRange(const IntegerRelation &rel);
void print(raw_ostream &os) const;
void dump() const;

View File

@ -1184,16 +1184,16 @@ void IntegerRelation::removeRedundantLocalVars() {
}
void IntegerRelation::convertIdKind(IdKind srcKind, unsigned idStart,
unsigned idLimit, IdKind dstKind) {
unsigned idLimit, IdKind dstKind,
unsigned pos) {
assert(idLimit <= getNumIdKind(srcKind) && "Invalid id range");
if (idStart >= idLimit)
return;
// Append new local variables corresponding to the dimensions to be converted.
unsigned newIdsBegin = getIdKindEnd(dstKind);
unsigned convertCount = idLimit - idStart;
appendId(dstKind, convertCount);
unsigned newIdsBegin = insertId(dstKind, pos, convertCount);
// Swap the new local variables with dimensions.
//
@ -2137,6 +2137,40 @@ void IntegerRelation::inverse() {
convertIdKind(IdKind::Range, 0, numRangeIds, IdKind::Domain);
}
void IntegerRelation::compose(const IntegerRelation &rel) {
assert(getRangeSet().getSpace().isCompatible(rel.getDomainSet().getSpace()) &&
"Range of `this` should be compatible with Domain of `rel`");
IntegerRelation copyRel = rel;
// Let relation `this` be R1: A -> B, and `rel` be R2: B -> C.
// We convert R1 to A -> (B X C), and R2 to B X C then intersect the range of
// R1 with R2. After this, we get R1: A -> C, by projecting out B.
// TODO: Using nested spaces here would help, since we could directly
// intersect the range with another relation.
unsigned numBIds = getNumRangeIds();
// Convert R1 from A -> B to A -> (B X C).
appendId(IdKind::Range, copyRel.getNumRangeIds());
// Convert R2 to B X C.
copyRel.convertIdKind(IdKind::Domain, 0, numBIds, IdKind::Range, 0);
// Intersect R2 to range of R1.
intersectRange(IntegerPolyhedron(copyRel));
// Project out B in R1.
convertIdKind(IdKind::Range, 0, numBIds, IdKind::Local);
}
void IntegerRelation::applyDomain(const IntegerRelation &rel) {
inverse();
compose(rel);
inverse();
}
void IntegerRelation::applyRange(const IntegerRelation &rel) { compose(rel); }
void IntegerRelation::printSpace(raw_ostream &os) const {
space.print(os);
os << getNumConstraints() << " constraints\n";

View File

@ -91,3 +91,34 @@ TEST(IntegerRelationTest, intersectDomainAndRange) {
EXPECT_TRUE(copyRel.isEqual(expectedRel));
}
}
TEST(IntegerRelationTest, applyDomainAndRange) {
{
IntegerRelation map1 = parseRelationFromSet(
"(x, y, a, b)[N] : (a - x - N == 0, b - y + N == 0)", 2);
IntegerRelation map2 =
parseRelationFromSet("(x, y, a)[N] : (a - x - y == 0)", 2);
map1.applyRange(map2);
IntegerRelation map3 =
parseRelationFromSet("(x, y, a)[N] : (a - x - y == 0)", 2);
EXPECT_TRUE(map1.isEqual(map3));
}
{
IntegerRelation map1 = parseRelationFromSet(
"(x, y, a, b)[N] : (a - x + N == 0, b - y - N == 0)", 2);
IntegerRelation map2 =
parseRelationFromSet("(x, y, a, b)[N] : (a - N == 0, b - N == 0)", 2);
IntegerRelation map3 =
parseRelationFromSet("(x, y, a, b)[N] : (x - N == 0, y - N == 0)", 2);
map1.applyDomain(map2);
EXPECT_TRUE(map1.isEqual(map3));
}
}