[MLIR][Presburger] Support computing a representation of a set that only has locals that are divs

This paves the way for integer-exact projection, and for supporting
non-division locals in subtraction, complement, and equality checks.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D127463
This commit is contained in:
Arjun P 2022-06-21 06:30:11 +02:00
parent 1452e2e5cb
commit 8a7ead691b
6 changed files with 199 additions and 4 deletions

View File

@ -26,6 +26,7 @@ namespace presburger {
class IntegerRelation;
class IntegerPolyhedron;
class PresburgerSet;
/// An IntegerRelation represents the set of points from a PresburgerSpace that
/// satisfy a list of affine constraints. Affine constraints can be inequalities
@ -93,6 +94,17 @@ public:
/// Returns a reference to the underlying space.
const PresburgerSpace &getSpace() const { return space; }
/// Set the space to `oSpace`, which should have the same number of ids as
/// the current space.
void setSpace(const PresburgerSpace &oSpace);
/// Set the space to `oSpace`, which should not have any local ids.
/// `oSpace` can have fewer ids than the current space; in that case, the
/// the extra ids in `this` that are not accounted for by `oSpace` will be
/// considered as local ids. `oSpace` should not have more ids than the
/// current space; this will result in an assert failure.
void setSpaceExceptLocals(const PresburgerSpace &oSpace);
/// Returns a copy of the space without locals.
PresburgerSpace getSpaceWithoutLocals() const {
return PresburgerSpace::getRelationSpace(space.getNumDomainIds(),
@ -497,6 +509,9 @@ public:
/// locals that have been added to `this`.
unsigned mergeLocalIds(IntegerRelation &other);
/// Check whether all local ids have a division representation.
bool hasOnlyDivLocals() const;
/// Changes the partition between dimensions and symbols. Depending on the new
/// symbol count, either a chunk of dimensional identifiers immediately before
/// the split become symbols, or some of the symbols immediately after the
@ -739,6 +754,12 @@ public:
/// first added identifier.
unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
/// Compute an equivalent representation of the same set, such that all local
/// ids have division representations. This representation may involve
/// local ids that correspond to divisions, and may also be a union of convex
/// disjuncts.
PresburgerSet computeReprWithOnlyDivLocals() const;
/// Compute the symbolic integer lexmin of the polyhedron.
/// This finds, for every assignment to the symbols, the lexicographically
/// minimum value attained by the dimensions. For example, the symbolic lexmin

View File

@ -55,6 +55,14 @@ public:
const PresburgerSpace &getSpace() const { return space; }
/// Set the space to `oSpace`. `oSpace` should not contain any local ids.
/// `oSpace` need not have the same number of ids as the current space;
/// it could have more or less. If it has less, the extra ids become
/// locals of the disjuncts. It can also have more, in which case the
/// disjuncts will have fewer locals. If its total number of ids
/// exceeds that of some disjunct, an assert failure will occur.
void setSpace(const PresburgerSpace &oSpace);
/// Return a reference to the list of disjuncts.
ArrayRef<IntegerRelation> getAllDisjuncts() const;
@ -117,6 +125,9 @@ public:
/// disjuncts in the union.
PresburgerRelation coalesce() const;
/// Check whether all local ids in all disjuncts have a div representation.
bool hasOnlyDivLocals() const;
/// Print the set's internal state.
void print(raw_ostream &os) const;
void dump() const;

View File

@ -572,10 +572,28 @@ public:
/// `constraints`, and no other ids.
SymbolicLexSimplex(const IntegerPolyhedron &constraints,
const IntegerPolyhedron &symbolDomain)
: LexSimplexBase(constraints), domainPoly(symbolDomain),
domainSimplex(symbolDomain) {
assert(domainPoly.getNumIds() == constraints.getNumSymbolIds());
assert(domainPoly.getNumDimIds() == constraints.getNumSymbolIds());
: SymbolicLexSimplex(constraints,
constraints.getIdKindOffset(IdKind::Symbol),
symbolDomain) {
assert(constraints.getNumSymbolIds() == symbolDomain.getNumIds());
}
/// An overload to select some other subrange of ids as symbols for lexmin.
/// The symbol ids are the range of ids with absolute index
/// [symbolOffset, symbolOffset + symbolDomain.getNumIds())
/// symbolDomain should only have dim ids.
SymbolicLexSimplex(const IntegerPolyhedron &constraints,
unsigned symbolOffset,
const IntegerPolyhedron &symbolDomain)
: LexSimplexBase(/*nVar=*/constraints.getNumIds(), symbolOffset,
symbolDomain.getNumIds()),
domainPoly(symbolDomain), domainSimplex(symbolDomain) {
// TODO consider supporting this case. It amounts
// to just returning the input constraints.
assert(domainPoly.getNumIds() > 0 &&
"there must be some non-symbols to optimize!");
assert(domainPoly.getNumIds() == domainPoly.getNumDimIds());
intersectIntegerRelation(constraints);
}
/// The lexmin will be stored as a function `lexmin` from symbols to
@ -583,6 +601,9 @@ public:
///
/// For some values of the symbols, the lexmin may be unbounded.
/// These parts of the symbol domain will be stored in `unboundedDomain`.
///
/// The spaces of the sets in the result are compatible with the symbolDomain
/// passed in the SymbolicLexSimplex constructor.
SymbolicLexMin computeSymbolicIntegerLexMin();
private:

View File

@ -38,6 +38,19 @@ std::unique_ptr<IntegerPolyhedron> IntegerPolyhedron::clone() const {
return std::make_unique<IntegerPolyhedron>(*this);
}
void IntegerRelation::setSpace(const PresburgerSpace &oSpace) {
assert(space.getNumIds() == oSpace.getNumIds() && "invalid space!");
space = oSpace;
}
void IntegerRelation::setSpaceExceptLocals(const PresburgerSpace &oSpace) {
assert(oSpace.getNumLocalIds() == 0 && "no locals should be present!");
assert(oSpace.getNumIds() <= getNumIds() && "invalid space!");
unsigned newNumLocals = getNumIds() - oSpace.getNumIds();
space = oSpace;
space.insertId(IdKind::Local, 0, newNumLocals);
}
void IntegerRelation::append(const IntegerRelation &other) {
assert(space.isEqual(other.getSpace()) && "Spaces must be equal.");
@ -152,6 +165,67 @@ void IntegerRelation::truncate(const CountsSnapshot &counts) {
removeEqualityRange(counts.getNumEqs(), getNumEqualities());
}
PresburgerSet IntegerPolyhedron::computeReprWithOnlyDivLocals() const {
// If there are no locals, we're done.
if (getNumLocalIds() == 0)
return PresburgerSet(*this);
// Move all the non-div locals to the end, as the current API to
// SymbolicLexMin requires these to form a contiguous range.
//
// Take a copy so we can perform mutations.
IntegerPolyhedron copy = *this;
std::vector<MaybeLocalRepr> reprs;
copy.getLocalReprs(reprs);
// Iterate through all the locals. The last `numNonDivLocals` are the locals
// that have been scanned already and do not have division representations.
unsigned numNonDivLocals = 0;
unsigned offset = copy.getIdKindOffset(IdKind::Local);
for (unsigned i = 0, e = copy.getNumLocalIds(); i < e - numNonDivLocals;) {
if (!reprs[i]) {
// Whenever we come across a local that does not have a division
// representation, we swap it to the `numNonDivLocals`-th last position
// and increment `numNonDivLocal`s. `reprs` also needs to be swapped.
copy.swapId(offset + i, offset + e - numNonDivLocals - 1);
std::swap(reprs[i], reprs[e - numNonDivLocals - 1]);
++numNonDivLocals;
continue;
}
++i;
}
// If there are no non-div locals, we're done.
if (numNonDivLocals == 0)
return PresburgerSet(*this);
// We computeSymbolicIntegerLexMin by considering the non-div locals as
// "non-symbols" and considering everything else as "symbols". This will
// compute a function mapping assignments to "symbols" to the
// lexicographically minimal valid assignment of "non-symbols", when a
// satisfying assignment exists. It separately returns the set of assignments
// to the "symbols" such that a satisfying assignment to the "non-symbols"
// exists but the lexmin is unbounded. We basically want to find the set of
// values of the "symbols" such that an assignment to the "non-symbols"
// exists, which is the union of the domain of the returned lexmin function
// and the returned set of assignments to the "symbols" that makes the lexmin
// unbounded.
SymbolicLexMin lexminResult =
SymbolicLexSimplex(copy, /*symbolOffset*/ 0,
IntegerPolyhedron(PresburgerSpace::getSetSpace(
/*numDims=*/copy.getNumIds() - numNonDivLocals)))
.computeSymbolicIntegerLexMin();
PresburgerSet result =
lexminResult.lexmin.getDomain().unionSet(lexminResult.unboundedDomain);
// The result set might lie in the wrong space -- all its ids are dims.
// Set it to the desired space and return.
PresburgerSpace space = getSpace();
space.removeIdRange(IdKind::Local, 0, getNumLocalIds());
result.setSpace(space);
return result;
}
SymbolicLexMin IntegerPolyhedron::findSymbolicIntegerLexMin() const {
// Compute the symbolic lexmin of the dims and locals, with the symbols being
// the actual symbols of this set.
@ -1120,6 +1194,13 @@ unsigned IntegerRelation::mergeLocalIds(IntegerRelation &other) {
return relA.getNumLocalIds() - oldALocals;
}
bool IntegerRelation::hasOnlyDivLocals() const {
std::vector<MaybeLocalRepr> reprs;
getLocalReprs(reprs);
return llvm::all_of(reprs,
[](const MaybeLocalRepr &repr) { return bool(repr); });
}
void IntegerRelation::removeDuplicateDivs() {
std::vector<SmallVector<int64_t, 8>> divs;
SmallVector<unsigned, 4> denoms;

View File

@ -21,6 +21,13 @@ PresburgerRelation::PresburgerRelation(const IntegerRelation &disjunct)
unionInPlace(disjunct);
}
void PresburgerRelation::setSpace(const PresburgerSpace &oSpace) {
assert(space.getNumLocalIds() == 0 && "no locals should be present");
space = oSpace;
for (IntegerRelation &disjunct : disjuncts)
disjunct.setSpaceExceptLocals(space);
}
unsigned PresburgerRelation::getNumDisjuncts() const {
return disjuncts.size();
}
@ -770,6 +777,12 @@ PresburgerRelation PresburgerRelation::coalesce() const {
return SetCoalescer(*this).coalesce();
}
bool PresburgerRelation::hasOnlyDivLocals() const {
return llvm::all_of(disjuncts, [](const IntegerRelation &rel) {
return rel.hasOnlyDivLocals();
});
}
void PresburgerRelation::print(raw_ostream &os) const {
os << "Number of Disjuncts: " << getNumDisjuncts() << "\n";
for (const IntegerRelation &disjunct : disjuncts) {

View File

@ -751,6 +751,54 @@ TEST(SetTest, computeVolume) {
/*resultBound=*/{});
}
// The last `numToProject` dims will be projected out, i.e., converted to
// locals.
void testComputeReprAtPoints(IntegerPolyhedron poly,
ArrayRef<SmallVector<int64_t, 4>> points,
unsigned numToProject) {
poly.convertIdKind(IdKind::SetDim, poly.getNumDimIds() - numToProject,
poly.getNumDimIds(), IdKind::Local);
PresburgerSet repr = poly.computeReprWithOnlyDivLocals();
EXPECT_TRUE(repr.hasOnlyDivLocals());
EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace()));
for (const SmallVector<int64_t, 4> &point : points) {
EXPECT_EQ(poly.containsPointNoLocal(point).hasValue(),
repr.containsPoint(point));
}
}
void testComputeRepr(IntegerPolyhedron poly, const PresburgerSet &expected,
unsigned numToProject) {
poly.convertIdKind(IdKind::SetDim, poly.getNumDimIds() - numToProject,
poly.getNumDimIds(), IdKind::Local);
PresburgerSet repr = poly.computeReprWithOnlyDivLocals();
EXPECT_TRUE(repr.hasOnlyDivLocals());
EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace()));
EXPECT_TRUE(repr.isEqual(expected));
}
TEST(SetTest, computeReprWithOnlyDivLocals) {
testComputeReprAtPoints(parsePoly("(x, y) : (x - 2*y == 0)"),
{{1, 0}, {2, 1}, {3, 0}, {4, 2}, {5, 3}},
/*numToProject=*/0);
testComputeReprAtPoints(parsePoly("(x, e) : (x - 2*e == 0)"),
{{1}, {2}, {3}, {4}, {5}}, /*numToProject=*/1);
// Tests to check that the space is preserved.
testComputeReprAtPoints(parsePoly("(x, y)[z, w] : ()"), {},
/*numToProject=*/1);
testComputeReprAtPoints(parsePoly("(x, y)[z, w] : (z - (w floordiv 2) == 0)"),
{},
/*numToProject=*/1);
// Bezout's lemma: if a, b are constants,
// the set of values that ax + by can take is all multiples of gcd(a, b).
testComputeRepr(
parsePoly("(x, e, f) : (x - 15*e - 21*f == 0)"),
PresburgerSet(parsePoly({"(x) : (x - 3*(x floordiv 3) == 0)"})),
/*numToProject=*/2);
}
TEST(SetTest, subtractOutputSizeRegression) {
PresburgerSet set1 =
parsePresburgerSetFromPolyStrings(1, {"(i) : (i >= 0, 10 - i >= 0)"});