From 8a7ead691bad29b86017d9e42fa63a57c8c0d629 Mon Sep 17 00:00:00 2001 From: Arjun P Date: Tue, 21 Jun 2022 06:30:11 +0200 Subject: [PATCH] [MLIR][Presburger] Support computing a representation of a set that only has locals that are divs This paves the way for integer-exact projection, and for supporting non-division locals in subtraction, complement, and equality checks. Reviewed By: Groverkss Differential Revision: https://reviews.llvm.org/D127463 --- .../Analysis/Presburger/IntegerRelation.h | 21 +++++ .../Analysis/Presburger/PresburgerRelation.h | 11 +++ .../mlir/Analysis/Presburger/Simplex.h | 29 ++++++- .../Analysis/Presburger/IntegerRelation.cpp | 81 +++++++++++++++++++ .../Presburger/PresburgerRelation.cpp | 13 +++ .../Analysis/Presburger/PresburgerSetTest.cpp | 48 +++++++++++ 6 files changed, 199 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h index 4a866c17dd3b..935307d4bb88 100644 --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -26,6 +26,7 @@ namespace presburger { class IntegerRelation; class IntegerPolyhedron; +class PresburgerSet; /// An IntegerRelation represents the set of points from a PresburgerSpace that /// satisfy a list of affine constraints. Affine constraints can be inequalities @@ -93,6 +94,17 @@ public: /// Returns a reference to the underlying space. const PresburgerSpace &getSpace() const { return space; } + /// Set the space to `oSpace`, which should have the same number of ids as + /// the current space. + void setSpace(const PresburgerSpace &oSpace); + + /// Set the space to `oSpace`, which should not have any local ids. + /// `oSpace` can have fewer ids than the current space; in that case, the + /// the extra ids in `this` that are not accounted for by `oSpace` will be + /// considered as local ids. `oSpace` should not have more ids than the + /// current space; this will result in an assert failure. + void setSpaceExceptLocals(const PresburgerSpace &oSpace); + /// Returns a copy of the space without locals. PresburgerSpace getSpaceWithoutLocals() const { return PresburgerSpace::getRelationSpace(space.getNumDomainIds(), @@ -497,6 +509,9 @@ public: /// locals that have been added to `this`. unsigned mergeLocalIds(IntegerRelation &other); + /// Check whether all local ids have a division representation. + bool hasOnlyDivLocals() const; + /// Changes the partition between dimensions and symbols. Depending on the new /// symbol count, either a chunk of dimensional identifiers immediately before /// the split become symbols, or some of the symbols immediately after the @@ -739,6 +754,12 @@ public: /// first added identifier. unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override; + /// Compute an equivalent representation of the same set, such that all local + /// ids have division representations. This representation may involve + /// local ids that correspond to divisions, and may also be a union of convex + /// disjuncts. + PresburgerSet computeReprWithOnlyDivLocals() const; + /// Compute the symbolic integer lexmin of the polyhedron. /// This finds, for every assignment to the symbols, the lexicographically /// minimum value attained by the dimensions. For example, the symbolic lexmin diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h index e4aa36599537..89a3deb30e68 100644 --- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h @@ -55,6 +55,14 @@ public: const PresburgerSpace &getSpace() const { return space; } + /// Set the space to `oSpace`. `oSpace` should not contain any local ids. + /// `oSpace` need not have the same number of ids as the current space; + /// it could have more or less. If it has less, the extra ids become + /// locals of the disjuncts. It can also have more, in which case the + /// disjuncts will have fewer locals. If its total number of ids + /// exceeds that of some disjunct, an assert failure will occur. + void setSpace(const PresburgerSpace &oSpace); + /// Return a reference to the list of disjuncts. ArrayRef getAllDisjuncts() const; @@ -117,6 +125,9 @@ public: /// disjuncts in the union. PresburgerRelation coalesce() const; + /// Check whether all local ids in all disjuncts have a div representation. + bool hasOnlyDivLocals() const; + /// Print the set's internal state. void print(raw_ostream &os) const; void dump() const; diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h index 4caaf78bd471..f583c59ef24b 100644 --- a/mlir/include/mlir/Analysis/Presburger/Simplex.h +++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h @@ -572,10 +572,28 @@ public: /// `constraints`, and no other ids. SymbolicLexSimplex(const IntegerPolyhedron &constraints, const IntegerPolyhedron &symbolDomain) - : LexSimplexBase(constraints), domainPoly(symbolDomain), - domainSimplex(symbolDomain) { - assert(domainPoly.getNumIds() == constraints.getNumSymbolIds()); - assert(domainPoly.getNumDimIds() == constraints.getNumSymbolIds()); + : SymbolicLexSimplex(constraints, + constraints.getIdKindOffset(IdKind::Symbol), + symbolDomain) { + assert(constraints.getNumSymbolIds() == symbolDomain.getNumIds()); + } + + /// An overload to select some other subrange of ids as symbols for lexmin. + /// The symbol ids are the range of ids with absolute index + /// [symbolOffset, symbolOffset + symbolDomain.getNumIds()) + /// symbolDomain should only have dim ids. + SymbolicLexSimplex(const IntegerPolyhedron &constraints, + unsigned symbolOffset, + const IntegerPolyhedron &symbolDomain) + : LexSimplexBase(/*nVar=*/constraints.getNumIds(), symbolOffset, + symbolDomain.getNumIds()), + domainPoly(symbolDomain), domainSimplex(symbolDomain) { + // TODO consider supporting this case. It amounts + // to just returning the input constraints. + assert(domainPoly.getNumIds() > 0 && + "there must be some non-symbols to optimize!"); + assert(domainPoly.getNumIds() == domainPoly.getNumDimIds()); + intersectIntegerRelation(constraints); } /// The lexmin will be stored as a function `lexmin` from symbols to @@ -583,6 +601,9 @@ public: /// /// For some values of the symbols, the lexmin may be unbounded. /// These parts of the symbol domain will be stored in `unboundedDomain`. + /// + /// The spaces of the sets in the result are compatible with the symbolDomain + /// passed in the SymbolicLexSimplex constructor. SymbolicLexMin computeSymbolicIntegerLexMin(); private: diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 25d89f93d93d..7376747663e6 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -38,6 +38,19 @@ std::unique_ptr IntegerPolyhedron::clone() const { return std::make_unique(*this); } +void IntegerRelation::setSpace(const PresburgerSpace &oSpace) { + assert(space.getNumIds() == oSpace.getNumIds() && "invalid space!"); + space = oSpace; +} + +void IntegerRelation::setSpaceExceptLocals(const PresburgerSpace &oSpace) { + assert(oSpace.getNumLocalIds() == 0 && "no locals should be present!"); + assert(oSpace.getNumIds() <= getNumIds() && "invalid space!"); + unsigned newNumLocals = getNumIds() - oSpace.getNumIds(); + space = oSpace; + space.insertId(IdKind::Local, 0, newNumLocals); +} + void IntegerRelation::append(const IntegerRelation &other) { assert(space.isEqual(other.getSpace()) && "Spaces must be equal."); @@ -152,6 +165,67 @@ void IntegerRelation::truncate(const CountsSnapshot &counts) { removeEqualityRange(counts.getNumEqs(), getNumEqualities()); } +PresburgerSet IntegerPolyhedron::computeReprWithOnlyDivLocals() const { + // If there are no locals, we're done. + if (getNumLocalIds() == 0) + return PresburgerSet(*this); + + // Move all the non-div locals to the end, as the current API to + // SymbolicLexMin requires these to form a contiguous range. + // + // Take a copy so we can perform mutations. + IntegerPolyhedron copy = *this; + std::vector reprs; + copy.getLocalReprs(reprs); + + // Iterate through all the locals. The last `numNonDivLocals` are the locals + // that have been scanned already and do not have division representations. + unsigned numNonDivLocals = 0; + unsigned offset = copy.getIdKindOffset(IdKind::Local); + for (unsigned i = 0, e = copy.getNumLocalIds(); i < e - numNonDivLocals;) { + if (!reprs[i]) { + // Whenever we come across a local that does not have a division + // representation, we swap it to the `numNonDivLocals`-th last position + // and increment `numNonDivLocal`s. `reprs` also needs to be swapped. + copy.swapId(offset + i, offset + e - numNonDivLocals - 1); + std::swap(reprs[i], reprs[e - numNonDivLocals - 1]); + ++numNonDivLocals; + continue; + } + ++i; + } + + // If there are no non-div locals, we're done. + if (numNonDivLocals == 0) + return PresburgerSet(*this); + + // We computeSymbolicIntegerLexMin by considering the non-div locals as + // "non-symbols" and considering everything else as "symbols". This will + // compute a function mapping assignments to "symbols" to the + // lexicographically minimal valid assignment of "non-symbols", when a + // satisfying assignment exists. It separately returns the set of assignments + // to the "symbols" such that a satisfying assignment to the "non-symbols" + // exists but the lexmin is unbounded. We basically want to find the set of + // values of the "symbols" such that an assignment to the "non-symbols" + // exists, which is the union of the domain of the returned lexmin function + // and the returned set of assignments to the "symbols" that makes the lexmin + // unbounded. + SymbolicLexMin lexminResult = + SymbolicLexSimplex(copy, /*symbolOffset*/ 0, + IntegerPolyhedron(PresburgerSpace::getSetSpace( + /*numDims=*/copy.getNumIds() - numNonDivLocals))) + .computeSymbolicIntegerLexMin(); + PresburgerSet result = + lexminResult.lexmin.getDomain().unionSet(lexminResult.unboundedDomain); + + // The result set might lie in the wrong space -- all its ids are dims. + // Set it to the desired space and return. + PresburgerSpace space = getSpace(); + space.removeIdRange(IdKind::Local, 0, getNumLocalIds()); + result.setSpace(space); + return result; +} + SymbolicLexMin IntegerPolyhedron::findSymbolicIntegerLexMin() const { // Compute the symbolic lexmin of the dims and locals, with the symbols being // the actual symbols of this set. @@ -1120,6 +1194,13 @@ unsigned IntegerRelation::mergeLocalIds(IntegerRelation &other) { return relA.getNumLocalIds() - oldALocals; } +bool IntegerRelation::hasOnlyDivLocals() const { + std::vector reprs; + getLocalReprs(reprs); + return llvm::all_of(reprs, + [](const MaybeLocalRepr &repr) { return bool(repr); }); +} + void IntegerRelation::removeDuplicateDivs() { std::vector> divs; SmallVector denoms; diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp index 9ce59d769d43..1b5d48b9cf89 100644 --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -21,6 +21,13 @@ PresburgerRelation::PresburgerRelation(const IntegerRelation &disjunct) unionInPlace(disjunct); } +void PresburgerRelation::setSpace(const PresburgerSpace &oSpace) { + assert(space.getNumLocalIds() == 0 && "no locals should be present"); + space = oSpace; + for (IntegerRelation &disjunct : disjuncts) + disjunct.setSpaceExceptLocals(space); +} + unsigned PresburgerRelation::getNumDisjuncts() const { return disjuncts.size(); } @@ -770,6 +777,12 @@ PresburgerRelation PresburgerRelation::coalesce() const { return SetCoalescer(*this).coalesce(); } +bool PresburgerRelation::hasOnlyDivLocals() const { + return llvm::all_of(disjuncts, [](const IntegerRelation &rel) { + return rel.hasOnlyDivLocals(); + }); +} + void PresburgerRelation::print(raw_ostream &os) const { os << "Number of Disjuncts: " << getNumDisjuncts() << "\n"; for (const IntegerRelation &disjunct : disjuncts) { diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp index ba3a0024f973..0c98f488ea07 100644 --- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp @@ -751,6 +751,54 @@ TEST(SetTest, computeVolume) { /*resultBound=*/{}); } +// The last `numToProject` dims will be projected out, i.e., converted to +// locals. +void testComputeReprAtPoints(IntegerPolyhedron poly, + ArrayRef> points, + unsigned numToProject) { + poly.convertIdKind(IdKind::SetDim, poly.getNumDimIds() - numToProject, + poly.getNumDimIds(), IdKind::Local); + PresburgerSet repr = poly.computeReprWithOnlyDivLocals(); + EXPECT_TRUE(repr.hasOnlyDivLocals()); + EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace())); + for (const SmallVector &point : points) { + EXPECT_EQ(poly.containsPointNoLocal(point).hasValue(), + repr.containsPoint(point)); + } +} + +void testComputeRepr(IntegerPolyhedron poly, const PresburgerSet &expected, + unsigned numToProject) { + poly.convertIdKind(IdKind::SetDim, poly.getNumDimIds() - numToProject, + poly.getNumDimIds(), IdKind::Local); + PresburgerSet repr = poly.computeReprWithOnlyDivLocals(); + EXPECT_TRUE(repr.hasOnlyDivLocals()); + EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace())); + EXPECT_TRUE(repr.isEqual(expected)); +} + +TEST(SetTest, computeReprWithOnlyDivLocals) { + testComputeReprAtPoints(parsePoly("(x, y) : (x - 2*y == 0)"), + {{1, 0}, {2, 1}, {3, 0}, {4, 2}, {5, 3}}, + /*numToProject=*/0); + testComputeReprAtPoints(parsePoly("(x, e) : (x - 2*e == 0)"), + {{1}, {2}, {3}, {4}, {5}}, /*numToProject=*/1); + + // Tests to check that the space is preserved. + testComputeReprAtPoints(parsePoly("(x, y)[z, w] : ()"), {}, + /*numToProject=*/1); + testComputeReprAtPoints(parsePoly("(x, y)[z, w] : (z - (w floordiv 2) == 0)"), + {}, + /*numToProject=*/1); + + // Bezout's lemma: if a, b are constants, + // the set of values that ax + by can take is all multiples of gcd(a, b). + testComputeRepr( + parsePoly("(x, e, f) : (x - 15*e - 21*f == 0)"), + PresburgerSet(parsePoly({"(x) : (x - 3*(x floordiv 3) == 0)"})), + /*numToProject=*/2); +} + TEST(SetTest, subtractOutputSizeRegression) { PresburgerSet set1 = parsePresburgerSetFromPolyStrings(1, {"(i) : (i >= 0, 10 - i >= 0)"});