forked from OSchip/llvm-project
[MLIR][Presburger] Support computing a representation of a set that only has locals that are divs
This paves the way for integer-exact projection, and for supporting non-division locals in subtraction, complement, and equality checks. Reviewed By: Groverkss Differential Revision: https://reviews.llvm.org/D127463
This commit is contained in:
parent
1452e2e5cb
commit
8a7ead691b
|
@ -26,6 +26,7 @@ namespace presburger {
|
|||
|
||||
class IntegerRelation;
|
||||
class IntegerPolyhedron;
|
||||
class PresburgerSet;
|
||||
|
||||
/// An IntegerRelation represents the set of points from a PresburgerSpace that
|
||||
/// satisfy a list of affine constraints. Affine constraints can be inequalities
|
||||
|
@ -93,6 +94,17 @@ public:
|
|||
/// Returns a reference to the underlying space.
|
||||
const PresburgerSpace &getSpace() const { return space; }
|
||||
|
||||
/// Set the space to `oSpace`, which should have the same number of ids as
|
||||
/// the current space.
|
||||
void setSpace(const PresburgerSpace &oSpace);
|
||||
|
||||
/// Set the space to `oSpace`, which should not have any local ids.
|
||||
/// `oSpace` can have fewer ids than the current space; in that case, the
|
||||
/// the extra ids in `this` that are not accounted for by `oSpace` will be
|
||||
/// considered as local ids. `oSpace` should not have more ids than the
|
||||
/// current space; this will result in an assert failure.
|
||||
void setSpaceExceptLocals(const PresburgerSpace &oSpace);
|
||||
|
||||
/// Returns a copy of the space without locals.
|
||||
PresburgerSpace getSpaceWithoutLocals() const {
|
||||
return PresburgerSpace::getRelationSpace(space.getNumDomainIds(),
|
||||
|
@ -497,6 +509,9 @@ public:
|
|||
/// locals that have been added to `this`.
|
||||
unsigned mergeLocalIds(IntegerRelation &other);
|
||||
|
||||
/// Check whether all local ids have a division representation.
|
||||
bool hasOnlyDivLocals() const;
|
||||
|
||||
/// Changes the partition between dimensions and symbols. Depending on the new
|
||||
/// symbol count, either a chunk of dimensional identifiers immediately before
|
||||
/// the split become symbols, or some of the symbols immediately after the
|
||||
|
@ -739,6 +754,12 @@ public:
|
|||
/// first added identifier.
|
||||
unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
|
||||
|
||||
/// Compute an equivalent representation of the same set, such that all local
|
||||
/// ids have division representations. This representation may involve
|
||||
/// local ids that correspond to divisions, and may also be a union of convex
|
||||
/// disjuncts.
|
||||
PresburgerSet computeReprWithOnlyDivLocals() const;
|
||||
|
||||
/// Compute the symbolic integer lexmin of the polyhedron.
|
||||
/// This finds, for every assignment to the symbols, the lexicographically
|
||||
/// minimum value attained by the dimensions. For example, the symbolic lexmin
|
||||
|
|
|
@ -55,6 +55,14 @@ public:
|
|||
|
||||
const PresburgerSpace &getSpace() const { return space; }
|
||||
|
||||
/// Set the space to `oSpace`. `oSpace` should not contain any local ids.
|
||||
/// `oSpace` need not have the same number of ids as the current space;
|
||||
/// it could have more or less. If it has less, the extra ids become
|
||||
/// locals of the disjuncts. It can also have more, in which case the
|
||||
/// disjuncts will have fewer locals. If its total number of ids
|
||||
/// exceeds that of some disjunct, an assert failure will occur.
|
||||
void setSpace(const PresburgerSpace &oSpace);
|
||||
|
||||
/// Return a reference to the list of disjuncts.
|
||||
ArrayRef<IntegerRelation> getAllDisjuncts() const;
|
||||
|
||||
|
@ -117,6 +125,9 @@ public:
|
|||
/// disjuncts in the union.
|
||||
PresburgerRelation coalesce() const;
|
||||
|
||||
/// Check whether all local ids in all disjuncts have a div representation.
|
||||
bool hasOnlyDivLocals() const;
|
||||
|
||||
/// Print the set's internal state.
|
||||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
|
|
@ -572,10 +572,28 @@ public:
|
|||
/// `constraints`, and no other ids.
|
||||
SymbolicLexSimplex(const IntegerPolyhedron &constraints,
|
||||
const IntegerPolyhedron &symbolDomain)
|
||||
: LexSimplexBase(constraints), domainPoly(symbolDomain),
|
||||
domainSimplex(symbolDomain) {
|
||||
assert(domainPoly.getNumIds() == constraints.getNumSymbolIds());
|
||||
assert(domainPoly.getNumDimIds() == constraints.getNumSymbolIds());
|
||||
: SymbolicLexSimplex(constraints,
|
||||
constraints.getIdKindOffset(IdKind::Symbol),
|
||||
symbolDomain) {
|
||||
assert(constraints.getNumSymbolIds() == symbolDomain.getNumIds());
|
||||
}
|
||||
|
||||
/// An overload to select some other subrange of ids as symbols for lexmin.
|
||||
/// The symbol ids are the range of ids with absolute index
|
||||
/// [symbolOffset, symbolOffset + symbolDomain.getNumIds())
|
||||
/// symbolDomain should only have dim ids.
|
||||
SymbolicLexSimplex(const IntegerPolyhedron &constraints,
|
||||
unsigned symbolOffset,
|
||||
const IntegerPolyhedron &symbolDomain)
|
||||
: LexSimplexBase(/*nVar=*/constraints.getNumIds(), symbolOffset,
|
||||
symbolDomain.getNumIds()),
|
||||
domainPoly(symbolDomain), domainSimplex(symbolDomain) {
|
||||
// TODO consider supporting this case. It amounts
|
||||
// to just returning the input constraints.
|
||||
assert(domainPoly.getNumIds() > 0 &&
|
||||
"there must be some non-symbols to optimize!");
|
||||
assert(domainPoly.getNumIds() == domainPoly.getNumDimIds());
|
||||
intersectIntegerRelation(constraints);
|
||||
}
|
||||
|
||||
/// The lexmin will be stored as a function `lexmin` from symbols to
|
||||
|
@ -583,6 +601,9 @@ public:
|
|||
///
|
||||
/// For some values of the symbols, the lexmin may be unbounded.
|
||||
/// These parts of the symbol domain will be stored in `unboundedDomain`.
|
||||
///
|
||||
/// The spaces of the sets in the result are compatible with the symbolDomain
|
||||
/// passed in the SymbolicLexSimplex constructor.
|
||||
SymbolicLexMin computeSymbolicIntegerLexMin();
|
||||
|
||||
private:
|
||||
|
|
|
@ -38,6 +38,19 @@ std::unique_ptr<IntegerPolyhedron> IntegerPolyhedron::clone() const {
|
|||
return std::make_unique<IntegerPolyhedron>(*this);
|
||||
}
|
||||
|
||||
void IntegerRelation::setSpace(const PresburgerSpace &oSpace) {
|
||||
assert(space.getNumIds() == oSpace.getNumIds() && "invalid space!");
|
||||
space = oSpace;
|
||||
}
|
||||
|
||||
void IntegerRelation::setSpaceExceptLocals(const PresburgerSpace &oSpace) {
|
||||
assert(oSpace.getNumLocalIds() == 0 && "no locals should be present!");
|
||||
assert(oSpace.getNumIds() <= getNumIds() && "invalid space!");
|
||||
unsigned newNumLocals = getNumIds() - oSpace.getNumIds();
|
||||
space = oSpace;
|
||||
space.insertId(IdKind::Local, 0, newNumLocals);
|
||||
}
|
||||
|
||||
void IntegerRelation::append(const IntegerRelation &other) {
|
||||
assert(space.isEqual(other.getSpace()) && "Spaces must be equal.");
|
||||
|
||||
|
@ -152,6 +165,67 @@ void IntegerRelation::truncate(const CountsSnapshot &counts) {
|
|||
removeEqualityRange(counts.getNumEqs(), getNumEqualities());
|
||||
}
|
||||
|
||||
PresburgerSet IntegerPolyhedron::computeReprWithOnlyDivLocals() const {
|
||||
// If there are no locals, we're done.
|
||||
if (getNumLocalIds() == 0)
|
||||
return PresburgerSet(*this);
|
||||
|
||||
// Move all the non-div locals to the end, as the current API to
|
||||
// SymbolicLexMin requires these to form a contiguous range.
|
||||
//
|
||||
// Take a copy so we can perform mutations.
|
||||
IntegerPolyhedron copy = *this;
|
||||
std::vector<MaybeLocalRepr> reprs;
|
||||
copy.getLocalReprs(reprs);
|
||||
|
||||
// Iterate through all the locals. The last `numNonDivLocals` are the locals
|
||||
// that have been scanned already and do not have division representations.
|
||||
unsigned numNonDivLocals = 0;
|
||||
unsigned offset = copy.getIdKindOffset(IdKind::Local);
|
||||
for (unsigned i = 0, e = copy.getNumLocalIds(); i < e - numNonDivLocals;) {
|
||||
if (!reprs[i]) {
|
||||
// Whenever we come across a local that does not have a division
|
||||
// representation, we swap it to the `numNonDivLocals`-th last position
|
||||
// and increment `numNonDivLocal`s. `reprs` also needs to be swapped.
|
||||
copy.swapId(offset + i, offset + e - numNonDivLocals - 1);
|
||||
std::swap(reprs[i], reprs[e - numNonDivLocals - 1]);
|
||||
++numNonDivLocals;
|
||||
continue;
|
||||
}
|
||||
++i;
|
||||
}
|
||||
|
||||
// If there are no non-div locals, we're done.
|
||||
if (numNonDivLocals == 0)
|
||||
return PresburgerSet(*this);
|
||||
|
||||
// We computeSymbolicIntegerLexMin by considering the non-div locals as
|
||||
// "non-symbols" and considering everything else as "symbols". This will
|
||||
// compute a function mapping assignments to "symbols" to the
|
||||
// lexicographically minimal valid assignment of "non-symbols", when a
|
||||
// satisfying assignment exists. It separately returns the set of assignments
|
||||
// to the "symbols" such that a satisfying assignment to the "non-symbols"
|
||||
// exists but the lexmin is unbounded. We basically want to find the set of
|
||||
// values of the "symbols" such that an assignment to the "non-symbols"
|
||||
// exists, which is the union of the domain of the returned lexmin function
|
||||
// and the returned set of assignments to the "symbols" that makes the lexmin
|
||||
// unbounded.
|
||||
SymbolicLexMin lexminResult =
|
||||
SymbolicLexSimplex(copy, /*symbolOffset*/ 0,
|
||||
IntegerPolyhedron(PresburgerSpace::getSetSpace(
|
||||
/*numDims=*/copy.getNumIds() - numNonDivLocals)))
|
||||
.computeSymbolicIntegerLexMin();
|
||||
PresburgerSet result =
|
||||
lexminResult.lexmin.getDomain().unionSet(lexminResult.unboundedDomain);
|
||||
|
||||
// The result set might lie in the wrong space -- all its ids are dims.
|
||||
// Set it to the desired space and return.
|
||||
PresburgerSpace space = getSpace();
|
||||
space.removeIdRange(IdKind::Local, 0, getNumLocalIds());
|
||||
result.setSpace(space);
|
||||
return result;
|
||||
}
|
||||
|
||||
SymbolicLexMin IntegerPolyhedron::findSymbolicIntegerLexMin() const {
|
||||
// Compute the symbolic lexmin of the dims and locals, with the symbols being
|
||||
// the actual symbols of this set.
|
||||
|
@ -1120,6 +1194,13 @@ unsigned IntegerRelation::mergeLocalIds(IntegerRelation &other) {
|
|||
return relA.getNumLocalIds() - oldALocals;
|
||||
}
|
||||
|
||||
bool IntegerRelation::hasOnlyDivLocals() const {
|
||||
std::vector<MaybeLocalRepr> reprs;
|
||||
getLocalReprs(reprs);
|
||||
return llvm::all_of(reprs,
|
||||
[](const MaybeLocalRepr &repr) { return bool(repr); });
|
||||
}
|
||||
|
||||
void IntegerRelation::removeDuplicateDivs() {
|
||||
std::vector<SmallVector<int64_t, 8>> divs;
|
||||
SmallVector<unsigned, 4> denoms;
|
||||
|
|
|
@ -21,6 +21,13 @@ PresburgerRelation::PresburgerRelation(const IntegerRelation &disjunct)
|
|||
unionInPlace(disjunct);
|
||||
}
|
||||
|
||||
void PresburgerRelation::setSpace(const PresburgerSpace &oSpace) {
|
||||
assert(space.getNumLocalIds() == 0 && "no locals should be present");
|
||||
space = oSpace;
|
||||
for (IntegerRelation &disjunct : disjuncts)
|
||||
disjunct.setSpaceExceptLocals(space);
|
||||
}
|
||||
|
||||
unsigned PresburgerRelation::getNumDisjuncts() const {
|
||||
return disjuncts.size();
|
||||
}
|
||||
|
@ -770,6 +777,12 @@ PresburgerRelation PresburgerRelation::coalesce() const {
|
|||
return SetCoalescer(*this).coalesce();
|
||||
}
|
||||
|
||||
bool PresburgerRelation::hasOnlyDivLocals() const {
|
||||
return llvm::all_of(disjuncts, [](const IntegerRelation &rel) {
|
||||
return rel.hasOnlyDivLocals();
|
||||
});
|
||||
}
|
||||
|
||||
void PresburgerRelation::print(raw_ostream &os) const {
|
||||
os << "Number of Disjuncts: " << getNumDisjuncts() << "\n";
|
||||
for (const IntegerRelation &disjunct : disjuncts) {
|
||||
|
|
|
@ -751,6 +751,54 @@ TEST(SetTest, computeVolume) {
|
|||
/*resultBound=*/{});
|
||||
}
|
||||
|
||||
// The last `numToProject` dims will be projected out, i.e., converted to
|
||||
// locals.
|
||||
void testComputeReprAtPoints(IntegerPolyhedron poly,
|
||||
ArrayRef<SmallVector<int64_t, 4>> points,
|
||||
unsigned numToProject) {
|
||||
poly.convertIdKind(IdKind::SetDim, poly.getNumDimIds() - numToProject,
|
||||
poly.getNumDimIds(), IdKind::Local);
|
||||
PresburgerSet repr = poly.computeReprWithOnlyDivLocals();
|
||||
EXPECT_TRUE(repr.hasOnlyDivLocals());
|
||||
EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace()));
|
||||
for (const SmallVector<int64_t, 4> &point : points) {
|
||||
EXPECT_EQ(poly.containsPointNoLocal(point).hasValue(),
|
||||
repr.containsPoint(point));
|
||||
}
|
||||
}
|
||||
|
||||
void testComputeRepr(IntegerPolyhedron poly, const PresburgerSet &expected,
|
||||
unsigned numToProject) {
|
||||
poly.convertIdKind(IdKind::SetDim, poly.getNumDimIds() - numToProject,
|
||||
poly.getNumDimIds(), IdKind::Local);
|
||||
PresburgerSet repr = poly.computeReprWithOnlyDivLocals();
|
||||
EXPECT_TRUE(repr.hasOnlyDivLocals());
|
||||
EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace()));
|
||||
EXPECT_TRUE(repr.isEqual(expected));
|
||||
}
|
||||
|
||||
TEST(SetTest, computeReprWithOnlyDivLocals) {
|
||||
testComputeReprAtPoints(parsePoly("(x, y) : (x - 2*y == 0)"),
|
||||
{{1, 0}, {2, 1}, {3, 0}, {4, 2}, {5, 3}},
|
||||
/*numToProject=*/0);
|
||||
testComputeReprAtPoints(parsePoly("(x, e) : (x - 2*e == 0)"),
|
||||
{{1}, {2}, {3}, {4}, {5}}, /*numToProject=*/1);
|
||||
|
||||
// Tests to check that the space is preserved.
|
||||
testComputeReprAtPoints(parsePoly("(x, y)[z, w] : ()"), {},
|
||||
/*numToProject=*/1);
|
||||
testComputeReprAtPoints(parsePoly("(x, y)[z, w] : (z - (w floordiv 2) == 0)"),
|
||||
{},
|
||||
/*numToProject=*/1);
|
||||
|
||||
// Bezout's lemma: if a, b are constants,
|
||||
// the set of values that ax + by can take is all multiples of gcd(a, b).
|
||||
testComputeRepr(
|
||||
parsePoly("(x, e, f) : (x - 15*e - 21*f == 0)"),
|
||||
PresburgerSet(parsePoly({"(x) : (x - 3*(x floordiv 3) == 0)"})),
|
||||
/*numToProject=*/2);
|
||||
}
|
||||
|
||||
TEST(SetTest, subtractOutputSizeRegression) {
|
||||
PresburgerSet set1 =
|
||||
parsePresburgerSetFromPolyStrings(1, {"(i) : (i >= 0, 10 - i >= 0)"});
|
||||
|
|
Loading…
Reference in New Issue