From da92f92621e28a56fe8ad79d82eb60e436bf1d39 Mon Sep 17 00:00:00 2001 From: Arjun P Date: Mon, 4 Apr 2022 23:31:28 +0100 Subject: [PATCH] [MLIR][Presburger] IntegerPolyhedron: add support for symbolic integer lexmin Add support for computing the symbolic integer lexmin of a polyhedron. This finds, for every assignment to the symbols, the lexicographically minimum value attained by the dimensions. For example, the symbolic lexmin of the set `(x, y)[a, b, c] : (a <= x, b <= x, x <= c)` can be written as ``` x = a if b <= a, a <= c x = b if a < b, b <= c ``` This also finds the set of assignments to the symbols that make the lexmin unbounded. Reviewed By: Groverkss Differential Revision: https://reviews.llvm.org/D122985 --- .../Analysis/Presburger/IntegerRelation.h | 23 + .../include/mlir/Analysis/Presburger/Matrix.h | 3 + .../mlir/Analysis/Presburger/PWMAFunction.h | 10 + .../mlir/Analysis/Presburger/Simplex.h | 294 +++++++--- .../Analysis/Presburger/IntegerRelation.cpp | 16 + mlir/lib/Analysis/Presburger/Matrix.cpp | 8 + mlir/lib/Analysis/Presburger/PWMAFunction.cpp | 12 + mlir/lib/Analysis/Presburger/Simplex.cpp | 508 ++++++++++++++++-- .../Presburger/IntegerPolyhedronTest.cpp | 224 ++++++++ 9 files changed, 970 insertions(+), 128 deletions(-) diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h index 6add6e8aa9b2..709e4f843835 100644 --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -536,6 +536,7 @@ protected: Matrix inequalities; }; +struct SymbolicLexMin; /// An IntegerPolyhedron is a PresburgerSpace subject to affine /// constraints. Affine constraints can be inequalities or equalities in the /// form: @@ -593,6 +594,28 @@ public: /// column position (i.e., not relative to the kind of identifier) of the /// first added identifier. 
unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override; + + /// Compute the symbolic integer lexmin of the polyhedron. + /// This finds, for every assignment to the symbols, the lexicographically + /// minimum value attained by the dimensions. For example, the symbolic lexmin + /// of the set + /// + /// (x, y)[a, b, c] : (a <= x, b <= x, x <= c) + /// + /// can be written as + /// + /// x = a if b <= a, a <= c + /// x = b if a < b, b <= c + /// + /// This function is stored in the `lexmin` function in the result. + /// Some assignments to the symbols might make the set empty. + /// Such points are not part of the function's domain. + /// In the above example, this happens when max(a, b) > c. + /// + /// For some values of the symbols, the lexmin may be unbounded. + /// `SymbolicLexMin` stores these parts of the symbolic domain in a separate + /// `PresburgerSet`, `unboundedDomain`. + SymbolicLexMin findSymbolicIntegerLexMin() const; }; } // namespace presburger diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h index 940b88d8148f..e2ad543070a4 100644 --- a/mlir/include/mlir/Analysis/Presburger/Matrix.h +++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h @@ -151,6 +151,9 @@ public: /// Add an extra row at the bottom of the matrix and return its position. unsigned appendExtraRow(); + /// Same as above, but copy the given elements into the row. The length of + /// `elems` must be equal to the number of columns. + unsigned appendExtraRow(ArrayRef elems); /// Print the matrix. void print(raw_ostream &os) const; diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h index ce0d77da9bc2..f4bffe5b4e7a 100644 --- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h +++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h @@ -106,6 +106,11 @@ public: /// outside the domain, an empty optional is returned. 
Optional> valueAt(ArrayRef point) const; + /// Truncate the output dimensions to the first `count` dimensions. + /// + /// TODO: refactor so that this can be accomplished through removeIdRange. + void truncateOutput(unsigned count); + void print(raw_ostream &os) const; void dump() const; @@ -165,6 +170,11 @@ public: /// value at every point in the domain. bool isEqual(const PWMAFunction &other) const; + /// Truncate the output dimensions to the first `count` dimensions. + /// + /// TODO: refactor so that this can be accomplished through removeIdRange. + void truncateOutput(unsigned count); + void print(raw_ostream &os) const; void dump() const; diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h index 66d408dbf8b6..67a4b5f68e20 100644 --- a/mlir/include/mlir/Analysis/Presburger/Simplex.h +++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h @@ -18,6 +18,7 @@ #include "mlir/Analysis/Presburger/Fraction.h" #include "mlir/Analysis/Presburger/IntegerRelation.h" #include "mlir/Analysis/Presburger/Matrix.h" +#include "mlir/Analysis/Presburger/PWMAFunction.h" #include "mlir/Analysis/Presburger/Utils.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/ArrayRef.h" @@ -41,8 +42,9 @@ class GBRSimplex; /// these constraints that are redundant, i.e. a subset of constraints that /// doesn't constrain the affine set further after adding the non-redundant /// constraints. The LexSimplex class provides support for computing the -/// lexicographical minimum of an IntegerRelation. Both these classes can be -/// constructed from an IntegerRelation, and both inherit common +/// lexicographic minimum of an IntegerRelation. The SymbolicLexMin class +/// provides support for computing symbolic lexicographic minimums. All of these +/// classes can be constructed from an IntegerRelation, and all inherit common /// functionality from SimplexBase. 
/// /// The implementations of the Simplex and SimplexBase classes, other than the @@ -72,19 +74,22 @@ class GBRSimplex; /// respectively. As described above, the first column is the common /// denominator. The second column represents the constant term, explained in /// more detail below. These two are _fixed columns_; they always retain their -/// position as the first and second columns. Additionally, LexSimplex stores -/// a so-call big M parameter (explained below) in the third column, so -/// LexSimplex has three fixed columns. +/// position as the first and second columns. Additionally, LexSimplexBase +/// stores a so-called big M parameter (explained below) in the third column, so +/// LexSimplexBase has three fixed columns. Finally, SymbolicLexSimplex has +/// `nSymbol` variables designated as symbols. These occupy the next `nSymbol` +/// columns, viz. the columns [3, 3 + nSymbol). For more information on symbols, +/// see LexSimplexBase and SymbolicLexSimplex. /// -/// LexSimplex does not directly support variables which can be negative, so we -/// introduce the so-called big M parameter, an artificial variable that is +/// LexSimplexBase does not directly support variables which can be negative, so +/// we introduce the so-called big M parameter, an artificial variable that is /// considered to have an arbitrarily large value. We then transform the /// variables, say x, y, z, ... to M, M + x, M + y, M + z. Since M has been /// added to these variables, they are now known to have non-negative values. -/// For more details, see the documentation for LexSimplex. The big M parameter -/// is not considered a real unknown and is not stored in the `var` data -/// structure; rather the tableau just has an extra fixed column for it just -/// like the constant term. +/// For more details, see the documentation for LexSimplexBase.
The big M +/// parameter is not considered a real unknown and is not stored in the `var` +/// data structure; rather the tableau just has an extra fixed column for it +/// just like the constant term. /// /// The vectors var and con store information about the variables and /// constraints respectively, namely, whether they are in row or column @@ -146,8 +151,8 @@ class GBRSimplex; /// operation from the end until we reach the snapshot's location. SimplexBase /// also supports taking a snapshot including the exact set of basis unknowns; /// if this functionality is used, then on rolling back the exact basis will -/// also be restored. This is used by LexSimplex because its algorithm, unlike -/// Simplex, is sensitive to the exact basis used at a point. +/// also be restored. This is used by LexSimplexBase because the lex algorithm, +/// unlike `Simplex`, is sensitive to the exact basis used at a point. class SimplexBase { public: SimplexBase() = delete; @@ -211,7 +216,8 @@ protected: /// constant term, whereas LexSimplex has an extra fixed column for the /// so-called big M parameter. For more information see the documentation for /// LexSimplex. - SimplexBase(unsigned nVar, bool mustUseBigM); + SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset, + unsigned nSymbol); enum class Orientation { Row, Column }; @@ -223,11 +229,14 @@ protected: /// always be non-negative and if it cannot be made non-negative without /// violating other constraints, the tableau is empty. 
struct Unknown { - Unknown(Orientation oOrientation, bool oRestricted, unsigned oPos) - : pos(oPos), orientation(oOrientation), restricted(oRestricted) {} + Unknown(Orientation oOrientation, bool oRestricted, unsigned oPos, + bool oIsSymbol = false) + : pos(oPos), orientation(oOrientation), restricted(oRestricted), + isSymbol(oIsSymbol) {} unsigned pos; Orientation orientation; bool restricted : 1; + bool isSymbol : 1; void print(raw_ostream &os) const { os << (orientation == Orientation::Row ? "r" : "c"); @@ -326,6 +335,10 @@ protected: /// nRedundant rows. unsigned nRedundant; + /// The number of parameters. This must be consistent with the number of + /// Unknowns in `var` below that have `isSymbol` set to true. + unsigned nSymbol; + /// The matrix representing the tableau. Matrix tableau; @@ -363,62 +376,45 @@ protected: /// introduce an artifical variable M that is considered to have a value of /// +infinity and instead of the variables x, y, z, we internally use variables /// M + x, M + y, M + z, which are now guaranteed to be non-negative. See the -/// documentation for Simplex for more details. The whole algorithm is performed -/// without having to fix a "big enough" value of the big M parameter; it is -/// just considered to be infinite throughout and it never appears in the final -/// outputs. We will deal with sample values throughout that may in general be -/// some linear expression involving M like pM + q or aM + b. We can compare -/// these with each other. They have a total order: -/// aM + b < pM + q iff a < p or (a == p and b < q). +/// documentation for SimplexBase for more details. M is also considered to be +/// an integer that is divisible by everything. +/// +/// The whole algorithm is performed with M treated as a symbol; +/// it is just considered to be infinite throughout and it never appears in the +/// final outputs. 
We will deal with sample values throughout that may in +/// general be some affine expression involving M, like pM + q or aM + b. We can +/// compare these with each other. They have a total order: +/// +/// aM + b < pM + q iff a < p or (a == p and b < q). /// In particular, aM + b < 0 iff a < 0 or (a == 0 and b < 0). /// +/// When performing symbolic optimization, sample values will be affine +/// expressions in M and the symbols. For example, we could have sample values +/// aM + bS + c and pM + qS + r, where S is a symbol. Now we have +/// aM + bS + c < pM + qS + r iff (a < p) or (a == p and bS + c < qS + r). +/// bS + c < qS + r can be always true, always false, or neither, +/// depending on the set of values S can take. The symbols are always stored +/// in columns [3, 3 + nSymbols). For more details, see the +/// documentation for SymbolicLexSimplex. +/// /// Initially all the constraints to be added are added as rows, with no attempt /// to keep the tableau consistent. Pivots are only performed when some query /// is made, such as a call to getRationalLexMin. Care is taken to always /// maintain a lexicopositive basis transform, explained below. /// -/// Let the variables be x = (x_1, ... x_n). Let the basis unknowns at a -/// particular point be y = (y_1, ... y_n). We know that x = A*y + b for some -/// n x n matrix A and n x 1 column vector b. We want every column in A to be -/// lexicopositive, i.e., have at least one non-zero element, with the first -/// such element being positive. This property is preserved throughout the -/// operation of LexSimplex. Note that on construction, the basis transform A is -/// the indentity matrix and so every column is lexicopositive. Note that for -/// LexSimplex, for the tableau to be consistent we must have non-negative -/// sample values not only for the constraints but also for the variables. 
-/// So if the tableau is consistent then x >= 0 and y >= 0, by which we mean -/// every element in these vectors is non-negative. (note that this is a -/// different concept from lexicopositivity!) -/// -/// When we arrive at a basis such the basis transform is lexicopositive and the -/// tableau is consistent, the sample point is the lexiographically minimum -/// point in the polytope. We will show that A*y is zero or lexicopositive when -/// y >= 0. Adding a lexicopositive vector to b will make it lexicographically -/// bigger, so A*y + b is lexicographically bigger than b for any y >= 0 except -/// y = 0. This shows that no point lexicographically smaller than x = b can be -/// obtained. Since we already know that x = b is valid point in the space, this -/// shows that x = b is the lexicographic minimum. -/// -/// Proof that A*y is lexicopositive or zero when y > 0. Recall that every -/// column of A is lexicopositive. Begin by considering A_1, the first row of A. -/// If this row is all zeros, then (A*y)_1 = (A_1)*y = 0; proceed to the next -/// row. If we run out of rows, A*y is zero and we are done; otherwise, we -/// encounter some row A_i that has a non-zero element. Every column is -/// lexicopositive and so has some positive element before any negative elements -/// occur, so the element in this row for any column, if non-zero, must be -/// positive. Consider (A*y)_i = (A_i)*y. All the elements in both vectors are -/// non-negative, so if this is non-zero then it must be positive. Then the -/// first non-zero element of A*y is positive so A*y is lexicopositive. -/// -/// Otherwise, if (A_i)*y is zero, then for every column j that had a non-zero -/// element in A_i, y_j is zero. Thus these columns have no contribution to A*y -/// and we can completely ignore these columns of A. We now continue downwards, -/// looking for rows of A that have a non-zero element other than in the ignored -/// columns. 
If we find one, say A_k, once again these elements must be positive -/// since they are the first non-zero element in each of these columns, so if -/// (A_k)*y is not zero then we have that A*y is lexicopositive and if not we -/// ignore more columns; eventually if all these dot products become zero then -/// A*y is zero and we are done. +/// Let the variables be x = (x_1, ... x_n). +/// Let the symbols be s = (s_1, ... s_m). Let the basis unknowns at a +/// particular point be y = (y_1, ... y_n). We know that x = A*y + T*s + b for +/// some n x n matrix A, n x m matrix T, and n x 1 column vector b. We want +/// every column in A to be lexicopositive, i.e., have at least one non-zero +/// element, with the first such element being positive. This property is +/// preserved throughout the operation of LexSimplexBase. Note that on +/// construction, the basis transform A is the identity matrix and so every +/// column is lexicopositive. Note that for LexSimplexBase, for the tableau to +/// be consistent we must have non-negative sample values not only for the +/// constraints but also for the variables. So if the tableau is consistent then +/// x >= 0 and y >= 0, by which we mean every element in these vectors is +/// non-negative. (note that this is a different concept from lexicopositivity!)
class LexSimplexBase : public SimplexBase { public: ~LexSimplexBase() override = default; @@ -435,25 +431,37 @@ public: unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); } protected: - LexSimplexBase(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/true) {} + LexSimplexBase(unsigned nVar, unsigned symbolOffset, unsigned nSymbol) + : SimplexBase(nVar, /*mustUseBigM=*/true, symbolOffset, nSymbol) {} explicit LexSimplexBase(const IntegerRelation &constraints) - : LexSimplexBase(constraints.getNumIds()) { + : LexSimplexBase(constraints.getNumIds(), + constraints.getIdKindOffset(IdKind::Symbol), + constraints.getNumSymbolIds()) { intersectIntegerRelation(constraints); } + /// Add new symbolic variables to the end of the list of variables. + void appendSymbol(); + /// Try to move the specified row to column orientation while preserving the - /// lexicopositivity of the basis transform. If this is not possible, return - /// failure. This only occurs when the constraints have no solution; the - /// tableau will be marked empty in such a case. + /// lexicopositivity of the basis transform. The row must have a negative + /// sample value. If this is not possible, return failure. This only occurs + /// when the constraints have no solution; the tableau will be marked empty in + /// such a case. LogicalResult moveRowUnknownToColumn(unsigned row); - /// Given a row that has a non-integer sample value, add an inequality such - /// that this fractional sample value is cut away from the polytope. The added - /// inequality will be such that no integer points are removed. + /// Given a row that has a non-integer sample value, add an inequality to cut + /// away this fractional sample value from the polytope without removing any + /// integer points. The integer lexmin, if one existed, remains the same on + /// return. /// - /// Returns whether the cut constraint could be enforced, i.e. failure if the - /// cut made the polytope empty, and success if it didn't. 
Failure status - /// indicates that the polytope didn't have any integer points. + /// This assumes that the symbolic part of the sample is integral, + /// i.e., if the symbolic sample is (c + aM + b_1*s_1 + ... + b_n*s_n)/d, + /// where s_1, ... s_n are symbols, this assumes that + /// (b_1*s_1 + ... + b_n*s_n)/d is integral. + /// + /// Return failure if the tableau became empty, and success if it didn't. + /// Failure status indicates that the polytope was integer empty. LogicalResult addCut(unsigned row); /// Undo the addition of the last constraint. This is only called while @@ -461,14 +469,19 @@ protected: void undoLastConstraint() final; /// Given two potential pivot columns for a row, return the one that results - /// in the lexicographically smallest sample vector. + /// in the lexicographically smallest sample vector. The row's sample value + /// must be negative. If symbols are involved, the sample value must be + /// negative for all possible assignments to the symbols. unsigned getLexMinPivotColumn(unsigned row, unsigned colA, unsigned colB) const; }; +/// A class for lexicographic optimization without any symbols. This also +/// provides support for integer-exact redundancy and separateness checks. class LexSimplex : public LexSimplexBase { public: - explicit LexSimplex(unsigned nVar) : LexSimplexBase(nVar) {} + explicit LexSimplex(unsigned nVar) + : LexSimplexBase(nVar, /*symbolOffset=*/0, /*nSymbol=*/0) {} explicit LexSimplex(const IntegerRelation &constraints) : LexSimplexBase(constraints) { assert(constraints.getNumSymbolIds() == 0 && @@ -502,7 +515,7 @@ private: MaybeOptimum> getRationalSample() const; /// Make the tableau configuration consistent. - void restoreRationalConsistency(); + LogicalResult restoreRationalConsistency(); /// Return whether the specified row is violated; bool rowIsViolated(unsigned row) const; @@ -514,11 +527,122 @@ private: /// Get a row corresponding to a var that has a non-integral sample value, if /// one exists.
Otherwise, return an empty optional. Optional maybeGetNonIntegralVarRow() const; +}; - /// Given two potential pivot columns for a row, return the one that results - /// in the lexicographically smallest sample vector. - unsigned getLexMinPivotColumn(unsigned row, unsigned colA, - unsigned colB) const; +/// Represents the result of a symbolic lexicographic minimization computation. +struct SymbolicLexMin { + SymbolicLexMin(unsigned nSymbols, unsigned nNonSymbols) + : lexmin(PresburgerSpace::getSetSpace(nSymbols), nNonSymbols), + unboundedDomain( + PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(nSymbols))) {} + + /// This maps assignments of symbols to the corresponding lexmin. + /// Takes no value when no integer sample exists for the assignment or if the + /// lexmin is unbounded. + PWMAFunction lexmin; + /// Contains all assignments to the symbols that made the lexmin unbounded. + /// Note that the symbols of the input set to the symbolic lexmin are dims + /// of this PresburgerSet. + PresburgerSet unboundedDomain; +}; + +/// A class to perform symbolic lexicographic optimization, +/// i.e., to find, for every assignment to the symbols in the specified +/// `symbolDomain`, the lexicographically minimum integer value attained +/// by the non-symbol variables. +/// +/// The input is a set parametrized by some symbols, i.e., the constant terms +/// of the constraints in the set are affine expressions in the symbols, and +/// every assignment to the symbols defines a non-symbolic set. +/// +/// Accordingly, the sample values of the rows in our tableau will be affine +/// expressions in the symbols, and every assignment to the symbols will define +/// a non-symbolic LexSimplex. We then run the algorithm of +/// LexSimplex::findIntegerLexMin simultaneously for every value of the symbols +/// in the domain. +/// +/// Often, the pivot to be performed is the same for all values of the symbols, +/// in which case we just do it.
For example, if the symbolic sample of a row is +/// negative for all values in the symbol domain, the row needs to be pivoted +/// irrespective of the precise value of the symbols. To answer queries like +/// "Is this symbolic sample always negative in the symbol domain?", we maintain +/// a `LexSimplex domainSimplex` corresponding to the symbol domain. +/// +/// In other cases, it may be that the symbolic sample is violated at some +/// values in the symbol domain and not violated at others. In this case, +/// the pivot to be performed does depend on the value of the symbols. We +/// handle this by splitting the symbol domain. We run the algorithm for the +/// case where the row isn't violated, and then come back and run the case +/// where it is. +class SymbolicLexSimplex : public LexSimplexBase { +public: + /// `constraints` is the set for which the symbolic lexmin will be computed. + /// `symbolDomain` is the set of values of the symbols for which the lexmin + /// will be computed. `symbolDomain` should have a dim id for every symbol in + /// `constraints`, and no other ids. + SymbolicLexSimplex(const IntegerPolyhedron &constraints, + const IntegerPolyhedron &symbolDomain) + : LexSimplexBase(constraints), domainPoly(symbolDomain), + domainSimplex(symbolDomain) { + assert(domainPoly.getNumIds() == constraints.getNumSymbolIds()); + assert(domainPoly.getNumDimIds() == constraints.getNumSymbolIds()); + } + + /// The lexmin will be stored as a function `lexmin` from symbols to + /// non-symbols in the result. + /// + /// For some values of the symbols, the lexmin may be unbounded. + /// These parts of the symbol domain will be stored in `unboundedDomain`. + SymbolicLexMin computeSymbolicIntegerLexMin(); + +private: + /// Perform all pivots that do not require branching. + /// + /// Return failure if the tableau became empty, indicating that the polytope + /// is always integer empty in the current symbol domain. + /// Return success otherwise.
+ LogicalResult doNonBranchingPivots(); + + /// Get a row that is always violated in the current domain, if one exists. + Optional maybeGetAlwaysViolatedRow(); + + /// Get a row corresponding to a variable with non-integral sample value, if + /// one exists. + Optional maybeGetNonIntegralVarRow(); + + /// Given a row that has a non-integer sample value, cut away this fractional + /// sample value without removing any integer points, i.e., the integer + /// lexmin, if it exists, remains the same after a call to this function. This + /// may add constraints or local variables to the tableau, as well as to the + /// domain. + /// + /// Returns whether the cut constraint could be enforced, i.e. failure if the + /// cut made the polytope empty, and success if it didn't. Failure status + /// indicates that the polytope is always integer empty in the symbol domain + /// at the time of the call. (This function may modify the symbol domain, but + /// failure status indicates that the polytope was empty for all symbol values + /// in the initial domain.) + LogicalResult addSymbolicCut(unsigned row); + + /// Get the numerator of the symbolic sample of the specific row. + /// This is an affine expression in the symbols with integer coefficients. + /// The last element is the constant term. This ignores the big M coefficient. + SmallVector getSymbolicSampleNumerator(unsigned row) const; + + /// Return whether all the coefficients of the symbolic sample are integers. + /// + /// This does not consult the domain to check if the specified expression + /// is always integral despite coefficients being fractional. + bool isSymbolicSampleIntegral(unsigned row) const; + + /// Record a lexmin. The tableau must be consistent with all variables + /// having symbolic samples with integer coefficients. + void recordOutput(SymbolicLexMin &result) const; + + /// The symbol domain. + IntegerPolyhedron domainPoly; + /// Simplex corresponding to the symbol domain.
+ LexSimplex domainSimplex; }; /// The Simplex class uses the Normal pivot rule and supports integer emptiness @@ -540,7 +664,9 @@ public: enum class Direction { Up, Down }; Simplex() = delete; - explicit Simplex(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/false) {} + explicit Simplex(unsigned nVar) + : SimplexBase(nVar, /*mustUseBigM=*/false, /*symbolOffset=*/0, + /*nSymbol=*/0) {} explicit Simplex(const IntegerRelation &constraints) : Simplex(constraints.getNumIds()) { intersectIntegerRelation(constraints); diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index bfa9a6539077..5e527b5467f5 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -14,6 +14,7 @@ #include "mlir/Analysis/Presburger/IntegerRelation.h" #include "mlir/Analysis/Presburger/LinearTransform.h" +#include "mlir/Analysis/Presburger/PWMAFunction.h" #include "mlir/Analysis/Presburger/PresburgerRelation.h" #include "mlir/Analysis/Presburger/Simplex.h" #include "mlir/Analysis/Presburger/Utils.h" @@ -145,6 +146,21 @@ void IntegerRelation::truncate(const CountsSnapshot &counts) { removeEqualityRange(counts.getNumEqs(), getNumEqualities()); } +SymbolicLexMin IntegerPolyhedron::findSymbolicIntegerLexMin() const { + // Compute the symbolic lexmin of the dims and locals, with the symbols being + // the actual symbols of this set. + SymbolicLexMin result = + SymbolicLexSimplex( + *this, PresburgerSpace::getSetSpace(/*numDims=*/getNumSymbolIds())) + .computeSymbolicIntegerLexMin(); + + // We want to return only the lexmin over the dims, so strip the locals from + // the computed lexmin. 
+ result.lexmin.truncateOutput(result.lexmin.getNumOutputs() - + getNumLocalIds()); + return result; +} + unsigned IntegerRelation::insertId(IdKind kind, unsigned pos, unsigned num) { assert(pos <= getNumIdKind(kind)); diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp index 219d490e7368..680e4509b7cc 100644 --- a/mlir/lib/Analysis/Presburger/Matrix.cpp +++ b/mlir/lib/Analysis/Presburger/Matrix.cpp @@ -66,6 +66,14 @@ unsigned Matrix::appendExtraRow() { return nRows - 1; } +unsigned Matrix::appendExtraRow(ArrayRef elems) { + assert(elems.size() == nColumns && "elems must match row length!"); + unsigned row = appendExtraRow(); + for (unsigned col = 0; col < nColumns; ++col) + at(row, col) = elems[col]; + return row; +} + void Matrix::resizeHorizontally(unsigned newNColumns) { if (newNColumns < nColumns) removeColumns(newNColumns, nColumns - newNColumns); diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp index b995bc00a19c..711e99aab35b 100644 --- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp +++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp @@ -114,6 +114,18 @@ void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA, IntegerPolyhedron::eliminateRedundantLocalId(posA, posB); } +void MultiAffineFunction::truncateOutput(unsigned count) { + assert(count <= output.getNumRows()); + output.resizeVertically(count); +} + +void PWMAFunction::truncateOutput(unsigned count) { + assert(count <= numOutputs); + for (MultiAffineFunction &piece : pieces) + piece.truncateOutput(count); + numOutputs = count; +} + bool MultiAffineFunction::isEqualWhereDomainsOverlap( MultiAffineFunction other) const { if (!isSpaceCompatible(other)) diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp index 57e8f485742d..f3bf42f40b17 100644 --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -18,15 
+18,24 @@ using Direction = Simplex::Direction; const int nullIndex = std::numeric_limits::max(); -SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM) +SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset, + unsigned nSymbol) : usingBigM(mustUseBigM), nRow(0), nCol(getNumFixedCols() + nVar), - nRedundant(0), tableau(0, nCol), empty(false) { + nRedundant(0), nSymbol(nSymbol), tableau(0, nCol), empty(false) { + assert(symbolOffset + nSymbol <= nVar); + colUnknown.insert(colUnknown.begin(), getNumFixedCols(), nullIndex); for (unsigned i = 0; i < nVar; ++i) { var.emplace_back(Orientation::Column, /*restricted=*/false, /*pos=*/getNumFixedCols() + i); colUnknown.push_back(i); } + + // Move the symbols to be in columns [3, 3 + nSymbol). + for (unsigned i = 0; i < nSymbol; ++i) { + var[symbolOffset + i].isSymbol = true; + swapColumns(var[symbolOffset + i].pos, getNumFixedCols() + i); + } } const Simplex::Unknown &SimplexBase::unknownFromIndex(int index) const { @@ -96,9 +105,13 @@ unsigned SimplexBase::addRow(ArrayRef coeffs, bool makeRestricted) { // where M is the big M parameter. As such, when the user tries to add // a row ax + by + cz + d, we express it in terms of our internal variables // as -(a + b + c)M + a(M + x) + b(M + y) + c(M + z) + d. + // + // Symbols don't use the big M parameter since they do not get lex + // optimized. int64_t bigMCoeff = 0; for (unsigned i = 0; i < coeffs.size() - 1; ++i) - bigMCoeff -= coeffs[i]; + if (!var[i].isSymbol) + bigMCoeff -= coeffs[i]; // The coefficient to the big M parameter is stored in column 2. tableau(nRow - 1, 2) = bigMCoeff; } @@ -164,19 +177,97 @@ Direction flippedDirection(Direction direction) { } } // namespace +/// We simply make the tableau consistent while maintaining a lexicopositive +/// basis transform, and then return the sample value. If the tableau becomes +/// empty, we return empty. +/// +/// Let the variables be x = (x_1, ... x_n). 
+/// Let the basis unknowns be y = (y_1, ... y_n). +/// We have that x = A*y + b for some n x n matrix A and n x 1 column vector b. +/// +/// As we will show below, A*y is either zero or lexicopositive. +/// Adding a lexicopositive vector to b will make it lexicographically +/// greater, so A*y + b is always equal to or lexicographically greater than b. +/// Thus, since we can attain x = b, that is the lexicographic minimum. +/// +/// We have that every column in A is lexicopositive, i.e., has at least +/// one non-zero element, with the first such element being positive. Since for +/// the tableau to be consistent we must have non-negative sample values not +/// only for the constraints but also for the variables, we also have x >= 0 and +/// y >= 0, by which we mean every element in these vectors is non-negative. +/// +/// Proof that if every column in A is lexicopositive, and y >= 0, then +/// A*y is zero or lexicopositive. Begin by considering A_1, the first row of A. +/// If this row is all zeros, then (A*y)_1 = (A_1)*y = 0; proceed to the next +/// row. If we run out of rows, A*y is zero and we are done; otherwise, we +/// encounter some row A_i that has a non-zero element. Every column is +/// lexicopositive and so has some positive element before any negative elements +/// occur, so the element in this row for any column, if non-zero, must be +/// positive. Consider (A*y)_i = (A_i)*y. All the elements in both vectors are +/// non-negative, so if this is non-zero then it must be positive. Then the +/// first non-zero element of A*y is positive so A*y is lexicopositive. +/// +/// Otherwise, if (A_i)*y is zero, then for every column j that had a non-zero +/// element in A_i, y_j is zero. Thus these columns have no contribution to A*y +/// and we can completely ignore these columns of A. We now continue downwards, +/// looking for rows of A that have a non-zero element other than in the ignored +/// columns.
If we find one, say A_k, once again these elements must be positive +/// since they are the first non-zero element in each of these columns, so if +/// (A_k)*y is not zero then we have that A*y is lexicopositive and if not we +/// add these to the set of ignored columns and continue to the next row. If we +/// run out of rows, then A*y is zero and we are done. MaybeOptimum> LexSimplex::findRationalLexMin() { - restoreRationalConsistency(); + if (restoreRationalConsistency().failed()) + return OptimumKind::Empty; return getRationalSample(); } +/// Given a row that has a non-integer sample value, add an inequality such +/// that this fractional sample value is cut away from the polytope. The added +/// inequality will be such that no integer points are removed, i.e., the +/// integer lexmin, if it exists, is the same with and without this constraint. +/// +/// Let the row be +/// (c + coeffM*M + a_1*s_1 + ... + a_m*s_m + b_1*y_1 + ... + b_n*y_n)/d, +/// where s_1, ... s_m are the symbols and +/// y_1, ... y_n are the other basis unknowns. +/// +/// For this to be an integer, we want +/// coeffM*M + a_1*s_1 + ... + a_m*s_m + b_1*y_1 + ... + b_n*y_n = -c (mod d) +/// Note that this constraint must always hold, independent of the basis, +/// because the row unknown's value always equals this expression, even if *we* +/// later compute the sample value from a different expression based on a +/// different basis. +/// +/// Let us assume that M has a factor of d in it. Imposing this constraint on M +/// does not in any way hinder us from finding a value of M that is big enough. +/// Moreover, this function is only called when the symbolic part of the sample, +/// a_1*s_1 + ... + a_m*s_m, is known to be an integer. +/// +/// Also, we can safely reduce the coefficients modulo d, so we have: +/// +/// (b_1%d)y_1 + ... + (b_n%d)y_n = (-c%d) + k*d for some integer `k` +/// +/// Note that all coefficient modulos here are non-negative.
Also, all the +/// unknowns are non-negative here as both constraints and variables are +/// non-negative in LexSimplexBase. (We used the big M trick to make the +/// variables non-negative). Therefore, the LHS here is non-negative. +/// Since 0 <= (-c%d) < d, k is the quotient of dividing the LHS by d and +/// is therefore non-negative as well. +/// +/// So we have +/// ((b_1%d)y_1 + ... + (b_n%d)y_n - (-c%d))/d >= 0. +/// +/// The constraint is violated when added (it would be useless otherwise) +/// so we immediately try to move it to a column. LogicalResult LexSimplexBase::addCut(unsigned row) { - int64_t denom = tableau(row, 0); + int64_t d = tableau(row, 0); addZeroRow(/*makeRestricted=*/true); - tableau(nRow - 1, 0) = denom; - tableau(nRow - 1, 1) = -mod(-tableau(row, 1), denom); - tableau(nRow - 1, 2) = 0; // M has all factors in it. - for (unsigned col = 3; col < nCol; ++col) - tableau(nRow - 1, col) = mod(tableau(row, col), denom); + tableau(nRow - 1, 0) = d; + tableau(nRow - 1, 1) = -mod(-tableau(row, 1), d); // -c%d. + tableau(nRow - 1, 2) = 0; + for (unsigned col = 3 + nSymbol; col < nCol; ++col) + tableau(nRow - 1, col) = mod(tableau(row, col), d); // b_i%d. return moveRowUnknownToColumn(nRow - 1); } @@ -185,7 +276,7 @@ Optional LexSimplex::maybeGetNonIntegralVarRow() const { if (u.orientation == Orientation::Column) continue; // If the sample value is of the form (a/d)M + b/d, we need b to be - // divisible by d. We assume M is very large and contains all possible + // divisible by d. We assume M contains all possible // factors and is divisible by everything. unsigned row = u.pos; if (tableau(row, 1) % tableau(row, 0) != 0) @@ -195,28 +286,34 @@ Optional LexSimplex::maybeGetNonIntegralVarRow() const { } MaybeOptimum> LexSimplex::findIntegerLexMin() { - while (!empty) { - restoreRationalConsistency(); - if (empty) + // We first try to make the tableau consistent. 
+ if (restoreRationalConsistency().failed()) + return OptimumKind::Empty; + + // Then, if the sample value is integral, we are done. + while (Optional maybeRow = maybeGetNonIntegralVarRow()) { + // Otherwise, for the variable whose row has a non-integral sample value, + // we add a cut, a constraint that removes this rational point + // while preserving all integer points, thus keeping the lexmin the same. + // We then again try to make the tableau with the new constraint + // consistent. This continues until the tableau becomes empty, in which + // case there is no integer point, or until there are no variables with + // non-integral sample values. + // + // Failure indicates that the tableau became empty, which occurs when the + // polytope is integer empty. + if (addCut(*maybeRow).failed()) + return OptimumKind::Empty; + if (restoreRationalConsistency().failed()) return OptimumKind::Empty; - - if (Optional maybeRow = maybeGetNonIntegralVarRow()) { - // Failure occurs when the polytope is integer empty. - if (failed(addCut(*maybeRow))) - return OptimumKind::Empty; - continue; - } - - MaybeOptimum> sample = getRationalSample(); - assert(!sample.isEmpty() && "If we reached here the sample should exist!"); - if (sample.isUnbounded()) - return OptimumKind::Unbounded; - return llvm::to_vector<8>( - llvm::map_range(*sample, std::mem_fn(&Fraction::getAsInteger))); } - // Polytope is integer empty.
- return OptimumKind::Empty; + MaybeOptimum> sample = getRationalSample(); + assert(!sample.isEmpty() && "If we reached here the sample should exist!"); + if (sample.isUnbounded()) + return OptimumKind::Unbounded; + return llvm::to_vector<8>( + llvm::map_range(*sample, std::mem_fn(&Fraction::getAsInteger))); } bool LexSimplex::isSeparateInequality(ArrayRef coeffs) { @@ -228,6 +325,319 @@ bool LexSimplex::isSeparateInequality(ArrayRef coeffs) { bool LexSimplex::isRedundantInequality(ArrayRef coeffs) { return isSeparateInequality(getComplementIneq(coeffs)); } + +SmallVector +SymbolicLexSimplex::getSymbolicSampleNumerator(unsigned row) const { + SmallVector sample; + sample.reserve(nSymbol + 1); + for (unsigned col = 3; col < 3 + nSymbol; ++col) + sample.push_back(tableau(row, col)); + sample.push_back(tableau(row, 1)); + return sample; +} + +void LexSimplexBase::appendSymbol() { + appendVariable(); + swapColumns(3 + nSymbol, nCol - 1); + var.back().isSymbol = true; + nSymbol++; +} + +static bool isRangeDivisibleBy(ArrayRef range, int64_t divisor) { + assert(divisor > 0 && "divisor must be positive!"); + return llvm::all_of(range, [divisor](int64_t x) { return x % divisor == 0; }); +} + +bool SymbolicLexSimplex::isSymbolicSampleIntegral(unsigned row) const { + int64_t denom = tableau(row, 0); + return tableau(row, 1) % denom == 0 && + isRangeDivisibleBy(tableau.getRow(row).slice(3, nSymbol), denom); +} + +/// This proceeds similarly to LexSimplex::addCut(). We are given a row that has +/// a symbolic sample value with fractional coefficients. +/// +/// Let the row be +/// (c + coeffM*M + sum_i a_i*s_i + sum_j b_j*y_j)/d, +/// where s_1, ... s_m are the symbols and +/// y_1, ... y_n are the other basis unknowns. +/// +/// As in LexSimplex::addCut, for this to be an integer, we want +/// +/// coeffM*M + sum_j b_j*y_j = -c + sum_i (-a_i*s_i) (mod d) +/// +/// This time, a_1*s_1 + ... + a_m*s_m may not be an integer. 
We find that +/// +/// sum_i (b_i%d)y_i = ((-c%d) + sum_i (-a_i%d)s_i)%d + k*d for some integer k +/// +/// where we take a modulo of the whole symbolic expression on the right to +/// bring it into the range [0, d - 1]. Therefore, as in LexSimplex::addCut, +/// k is the quotient on dividing the LHS by d, and since LHS >= 0, we have +/// k >= 0 as well. We realize the modulo of the symbolic expression by adding a +/// division variable +/// +/// q = ((-c%d) + sum_i (-a_i%d)s_i)/d +/// +/// to the symbol domain, so the equality becomes +/// +/// sum_i (b_i%d)y_i = (-c%d) + sum_i (-a_i%d)s_i - q*d + k*d for some integer k +/// +/// So the cut is +/// (sum_i (b_i%d)y_i - (-c%d) - sum_i (-a_i%d)s_i + q*d)/d >= 0 +/// This constraint is violated when added so we immediately try to move it to a +/// column. +LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) { + int64_t d = tableau(row, 0); + + // Add the division variable `q` described above to the symbol domain. + // q = ((-c%d) + sum_i (-a_i%d)s_i)/d. + SmallVector domainDivCoeffs; + domainDivCoeffs.reserve(nSymbol + 1); + for (unsigned col = 3; col < 3 + nSymbol; ++col) + domainDivCoeffs.push_back(mod(-tableau(row, col), d)); // (-a_i%d)s_i + domainDivCoeffs.push_back(mod(-tableau(row, 1), d)); // -c%d. + + domainSimplex.addDivisionVariable(domainDivCoeffs, d); + domainPoly.addLocalFloorDiv(domainDivCoeffs, d); + + // Update `this` to account for the additional symbol we just added. + appendSymbol(); + + // Add the cut (sum_i (b_i%d)y_i - (-c%d) + sum_i -(-a_i%d)s_i + q*d)/d >= 0. + addZeroRow(/*makeRestricted=*/true); + tableau(nRow - 1, 0) = d; + tableau(nRow - 1, 2) = 0; + + tableau(nRow - 1, 1) = -mod(-tableau(row, 1), d); // -(-c%d). + for (unsigned col = 3; col < 3 + nSymbol - 1; ++col) + tableau(nRow - 1, col) = -mod(-tableau(row, col), d); // -(-a_i%d)s_i. + tableau(nRow - 1, 3 + nSymbol - 1) = d; // q*d. 
+ + for (unsigned col = 3 + nSymbol; col < nCol; ++col) + tableau(nRow - 1, col) = mod(tableau(row, col), d); // (b_i%d)y_i. + return moveRowUnknownToColumn(nRow - 1); +} + +void SymbolicLexSimplex::recordOutput(SymbolicLexMin &result) const { + Matrix output(0, domainPoly.getNumIds() + 1); + output.reserveRows(result.lexmin.getNumOutputs()); + for (const Unknown &u : var) { + if (u.isSymbol) + continue; + + if (u.orientation == Orientation::Column) { + // M + u has a sample value of zero so u has a sample value of -M, i.e, + // unbounded. + result.unboundedDomain.unionInPlace(domainPoly); + return; + } + + int64_t denom = tableau(u.pos, 0); + if (tableau(u.pos, 2) < denom) { + // M + u has a sample value of fM + something, where f < 1, so + // u = (f - 1)M + something, which has a negative coefficient for M, + // and so is unbounded. + result.unboundedDomain.unionInPlace(domainPoly); + return; + } + assert(tableau(u.pos, 2) == denom && + "Coefficient of M should not be greater than 1!"); + + SmallVector sample = getSymbolicSampleNumerator(u.pos); + for (int64_t &elem : sample) { + assert(elem % denom == 0 && "coefficients must be integral!"); + elem /= denom; + } + output.appendExtraRow(sample); + } + result.lexmin.addPiece(domainPoly, output); +} + +Optional SymbolicLexSimplex::maybeGetAlwaysViolatedRow() { + // First look for rows that are clearly violated just from the big M + // coefficient, without needing to perform any simplex queries on the domain. + for (unsigned row = 0; row < nRow; ++row) + if (tableau(row, 2) < 0) + return row; + + for (unsigned row = 0; row < nRow; ++row) { + if (tableau(row, 2) > 0) + continue; + if (domainSimplex.isSeparateInequality(getSymbolicSampleNumerator(row))) { + // Sample numerator always takes negative values in the symbol domain. 
+ return row; + } + } + return {}; +} + +Optional SymbolicLexSimplex::maybeGetNonIntegralVarRow() { + for (const Unknown &u : var) { + if (u.orientation == Orientation::Column) + continue; + assert(!u.isSymbol && "Symbol should not be in row orientation!"); + if (!isSymbolicSampleIntegral(u.pos)) + return u.pos; + } + return {}; +} + +/// The non-branching pivots are just the ones moving the rows +/// that are always violated in the symbol domain. +LogicalResult SymbolicLexSimplex::doNonBranchingPivots() { + while (Optional row = maybeGetAlwaysViolatedRow()) + if (moveRowUnknownToColumn(*row).failed()) + return failure(); + return success(); +} + +SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() { + SymbolicLexMin result(nSymbol, var.size() - nSymbol); + + /// The algorithm is more naturally expressed recursively, but we implement + /// it iteratively here to avoid potential issues with stack overflows in the + /// compiler. We explicitly maintain the stack frames in a vector. + /// + /// To "recurse", we store the current "stack frame", i.e., state variables + /// that we will need when we "return", into `stack`, increment `level`, and + /// `continue`. To "tail recurse", we just `continue`. + /// To "return", we decrement `level` and `continue`. + /// + /// When there is no stack frame for the current `level`, this indicates that + /// we have just "recursed" or "tail recursed". When there does exist one, + /// this indicates that we have just "returned" from recursing. There is only + /// one point at which non-tail calls occur so we always "return" there. + unsigned level = 1; + struct StackFrame { + int splitIndex; + unsigned snapshot; + unsigned domainSnapshot; + IntegerRelation::CountsSnapshot domainPolyCounts; + }; + SmallVector stack; + + while (level > 0) { + assert(level >= stack.size()); + if (level > stack.size()) { + if (empty || domainSimplex.findIntegerLexMin().isEmpty()) { + // No integer points; return. 
+ --level; + continue; + } + + if (doNonBranchingPivots().failed()) { + // Could not find pivots for violated constraints; return. + --level; + continue; + } + + unsigned splitRow; + SmallVector symbolicSample; + for (splitRow = 0; splitRow < nRow; ++splitRow) { + if (tableau(splitRow, 2) > 0) + continue; + assert(tableau(splitRow, 2) == 0 && + "Non-branching pivots should have been handled already!"); + + symbolicSample = getSymbolicSampleNumerator(splitRow); + if (domainSimplex.isRedundantInequality(symbolicSample)) + continue; + + // It's neither redundant nor separate, so it takes both positive and + // negative values, and hence constitutes a row for which we need to + // split the domain and separately run each case. + assert(!domainSimplex.isSeparateInequality(symbolicSample) && + "Non-branching pivots should have been handled already!"); + break; + } + + if (splitRow < nRow) { + unsigned domainSnapshot = domainSimplex.getSnapshot(); + IntegerRelation::CountsSnapshot domainPolyCounts = + domainPoly.getCounts(); + + // First, we consider the part of the domain where the row is not + // violated. We don't have to do any pivots for the row in this case, + // but we record the additional constraint that defines this part of + // the domain. + domainSimplex.addInequality(symbolicSample); + domainPoly.addInequality(symbolicSample); + + // Recurse. + // + // On return, the basis as a set is preserved but not the internal + // ordering within rows or columns. Thus, we take note of the index of + // the Unknown that caused the split, which may be in a different + // row when we come back from recursing. We will need this to recurse + // on the other part of the split domain, where the row is violated. + // + // Note that we have to capture the index above and not a reference to + // the Unknown itself, since the array it lives in might get + // reallocated. 
+ int splitIndex = rowUnknown[splitRow]; + unsigned snapshot = getSnapshot(); + stack.push_back( + {splitIndex, snapshot, domainSnapshot, domainPolyCounts}); + ++level; + continue; + } + + // The tableau is rationally consistent for the current domain. + // Now we look for non-integral sample values and add cuts for them. + if (Optional row = maybeGetNonIntegralVarRow()) { + if (addSymbolicCut(*row).failed()) { + // No integral points; return. + --level; + continue; + } + + // Rerun this level with the added cut constraint (tail recurse). + continue; + } + + // Record output and return. + recordOutput(result); + --level; + continue; + } + + if (level == stack.size()) { + // We have "returned" from "recursing". + const StackFrame &frame = stack.back(); + domainPoly.truncate(frame.domainPolyCounts); + domainSimplex.rollback(frame.domainSnapshot); + rollback(frame.snapshot); + const Unknown &u = unknownFromIndex(frame.splitIndex); + + // Drop the frame. We don't need it anymore. + stack.pop_back(); + + // Now we consider the part of the domain where the unknown `splitIndex` + // was negative. + assert(u.orientation == Orientation::Row && + "The split row should have been returned to row orientation!"); + SmallVector splitIneq = + getComplementIneq(getSymbolicSampleNumerator(u.pos)); + if (moveRowUnknownToColumn(u.pos).failed()) { + // The unknown can't be made non-negative; return. + --level; + continue; + } + + // The unknown can be made negative; recurse with the corresponding domain + // constraints. + domainSimplex.addInequality(splitIneq); + domainPoly.addInequality(splitIneq); + + // We are now taking care of the second half of the domain and we don't + // need to do anything else here after returning, so it's a tail recurse. 
+ continue; + } + } + + return result; +} + bool LexSimplex::rowIsViolated(unsigned row) const { if (tableau(row, 2) < 0) return true; @@ -243,19 +653,20 @@ Optional LexSimplex::maybeGetViolatedRow() const { return {}; } -// We simply look for violated rows and keep trying to move them to column -// orientation, which always succeeds unless the constraints have no solution -// in which case we just give up and return. -void LexSimplex::restoreRationalConsistency() { - while (Optional maybeViolatedRow = maybeGetViolatedRow()) { - LogicalResult status = moveRowUnknownToColumn(*maybeViolatedRow); - if (failed(status)) - return; - } +/// We simply look for violated rows and keep trying to move them to column +/// orientation, which always succeeds unless the constraints have no solution +/// in which case we just give up and return. +LogicalResult LexSimplex::restoreRationalConsistency() { + if (empty) + return failure(); + while (Optional maybeViolatedRow = maybeGetViolatedRow()) + if (moveRowUnknownToColumn(*maybeViolatedRow).failed()) + return failure(); + return success(); } // Move the row unknown to column orientation while preserving lexicopositivity -// of the basis transform. +// of the basis transform. The sample value of the row must be negative. // // We only consider pivots where the pivot element is positive. Suppose no such // pivot exists, i.e., some violated row has no positive coefficient for any @@ -318,7 +729,7 @@ void LexSimplex::restoreRationalConsistency() { // minimizes the change in sample value. 
LogicalResult LexSimplexBase::moveRowUnknownToColumn(unsigned row) { Optional maybeColumn; - for (unsigned col = 3; col < nCol; ++col) { + for (unsigned col = 3 + nSymbol; col < nCol; ++col) { if (tableau(row, col) <= 0) continue; maybeColumn = @@ -336,6 +747,7 @@ LogicalResult LexSimplexBase::moveRowUnknownToColumn(unsigned row) { unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA, unsigned colB) const { + // First, let's consider the non-symbolic case. // A pivot causes the following change. (in the diagram the matrix elements // are shown as rationals and there is no common denominator used) // @@ -359,7 +771,7 @@ unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA, // (-p/a)M + (-b/a), i.e. 0 to -(pM + b)/a. Thus the change in the sample // value is -s/a. // - // If the variable is the pivot row, it sampel value goes from s to 0, for a + // If the variable is the pivot row, its sample value goes from s to 0, for a // change of -s. // // If the variable is a non-pivot row, its sample value changes from @@ -373,8 +785,12 @@ unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA, // comparisons involved and can be ignored, since -s is strictly positive. // // Thus we take away this common factor and just return 0, 1/a, 1, or c/a as - // appropriate. This allows us to run the entire algorithm without ever having - // to fix a value of M. + // appropriate. This allows us to run the entire algorithm treating M + // symbolically, as the pivot to be performed does not depend on the value + // of M, so long as the sample value s is negative. Note that this is not + // because of any special feature of M; by the same argument, we ignore the + // symbols too. The caller ensures that the sample value s is negative for + // all possible values of the symbols.
auto getSampleChangeCoeffForVar = [this, row](unsigned col, const Unknown &u) -> Fraction { int64_t a = tableau(row, col); @@ -489,6 +905,7 @@ void SimplexBase::pivot(Pivot pair) { pivot(pair.row, pair.column); } /// element. void SimplexBase::pivot(unsigned pivotRow, unsigned pivotCol) { assert(pivotCol >= getNumFixedCols() && "Refusing to pivot invalid column"); + assert(!unknownFromColumn(pivotCol).isSymbol); swapRowWithCol(pivotRow, pivotCol); std::swap(tableau(pivotRow, 0), tableau(pivotRow, pivotCol)); @@ -778,6 +1195,9 @@ void SimplexBase::undo(UndoLogEntry entry) { assert(var.back().orientation == Orientation::Column && "Variable to be removed must be in column orientation!"); + if (var.back().isSymbol) + nSymbol--; + // Move this variable to the last column and remove the column from the // tableau. swapColumns(var.back().pos, nCol - 1); diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp index 4149d85d8759..2cb6ada89397 100644 --- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp @@ -8,6 +8,7 @@ #include "./Utils.h" #include "mlir/Analysis/Presburger/IntegerRelation.h" +#include "mlir/Analysis/Presburger/PWMAFunction.h" #include "mlir/Analysis/Presburger/Simplex.h" #include @@ -1134,6 +1135,229 @@ TEST(IntegerPolyhedronTest, findIntegerLexMin) { ">= 0, -11*z + 5*y - 3*x + 7 >= 0)")); } +void expectSymbolicIntegerLexMin( + StringRef polyStr, + ArrayRef, 8>>> + expectedLexminRepr, + ArrayRef expectedUnboundedDomainRepr) { + IntegerPolyhedron poly = parsePoly(polyStr); + + ASSERT_NE(poly.getNumDimIds(), 0u); + ASSERT_NE(poly.getNumSymbolIds(), 0u); + + PWMAFunction expectedLexmin = + parsePWMAF(/*numInputs=*/poly.getNumSymbolIds(), + /*numOutputs=*/poly.getNumDimIds(), expectedLexminRepr); + + PresburgerSet expectedUnboundedDomain = parsePresburgerSetFromPolyStrings( + poly.getNumSymbolIds(), 
expectedUnboundedDomainRepr); + + SymbolicLexMin result = poly.findSymbolicIntegerLexMin(); + + EXPECT_TRUE(result.lexmin.isEqual(expectedLexmin)); + if (!result.lexmin.isEqual(expectedLexmin)) { + llvm::errs() << "got:\n"; + result.lexmin.dump(); + llvm::errs() << "expected:\n"; + expectedLexmin.dump(); + } + + EXPECT_TRUE(result.unboundedDomain.isEqual(expectedUnboundedDomain)); + if (!result.unboundedDomain.isEqual(expectedUnboundedDomain)) + result.unboundedDomain.dump(); +} + +void expectSymbolicIntegerLexMin( + StringRef polyStr, + ArrayRef, 8>>> + result) { + expectSymbolicIntegerLexMin(polyStr, result, {}); +} + +TEST(IntegerPolyhedronTest, findSymbolicIntegerLexMin) { + expectSymbolicIntegerLexMin("(x)[a] : (x - a >= 0)", + { + {"(a) : ()", {{1, 0}}}, // a + }); + + expectSymbolicIntegerLexMin( + "(x)[a, b] : (x - a >= 0, x - b >= 0)", + { + {"(a, b) : (a - b >= 0)", {{1, 0, 0}}}, // a + {"(a, b) : (b - a - 1 >= 0)", {{0, 1, 0}}}, // b + }); + + expectSymbolicIntegerLexMin( + "(x)[a, b, c] : (x -a >= 0, x - b >= 0, x - c >= 0)", + { + {"(a, b, c) : (a - b >= 0, a - c >= 0)", {{1, 0, 0, 0}}}, // a + {"(a, b, c) : (b - a - 1 >= 0, b - c >= 0)", {{0, 1, 0, 0}}}, // b + {"(a, b, c) : (c - a - 1 >= 0, c - b - 1 >= 0)", {{0, 0, 1, 0}}}, // c + }); + + expectSymbolicIntegerLexMin("(x, y)[a] : (x - a >= 0, x + y >= 0)", + { + {"(a) : ()", {{1, 0}, {-1, 0}}}, // (a, -a) + }); + + expectSymbolicIntegerLexMin( + "(x, y)[a] : (x - a >= 0, x + y >= 0, y >= 0)", + { + {"(a) : (a >= 0)", {{1, 0}, {0, 0}}}, // (a, 0) + {"(a) : (-a - 1 >= 0)", {{1, 0}, {-1, 0}}}, // (a, -a) + }); + + expectSymbolicIntegerLexMin( + "(x, y)[a, b, c] : (x - a >= 0, y - b >= 0, c - x - y >= 0)", + { + {"(a, b, c) : (c - a - b >= 0)", + {{1, 0, 0, 0}, {0, 1, 0, 0}}}, // (a, b) + }); + + expectSymbolicIntegerLexMin( + "(x, y, z)[a, b, c] : (c - z >= 0, b - y >= 0, x + y + z - a == 0)", + { + {"(a, b, c) : ()", + {{1, -1, -1, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}}}, // (a - b - c, b, c) + }); + + 
expectSymbolicIntegerLexMin( + "(x)[a, b] : (a >= 0, b >= 0, x >= 0, a + b + x - 1 >= 0)", + { + {"(a, b) : (a >= 0, b >= 0, a + b - 1 >= 0)", {{0, 0, 0}}}, // 0 + {"(a, b) : (a == 0, b == 0)", {{0, 0, 1}}}, // 1 + }); + + expectSymbolicIntegerLexMin( + "(x)[a, b] : (1 - a >= 0, a >= 0, 1 - b >= 0, b >= 0, 1 - x >= 0, x >= " + "0, a + b + x - 1 >= 0)", + { + {"(a, b) : (1 - a >= 0, a >= 0, 1 - b >= 0, b >= 0, a + b - 1 >= 0)", + {{0, 0, 0}}}, // 0 + {"(a, b) : (a == 0, b == 0)", {{0, 0, 1}}}, // 1 + }); + + expectSymbolicIntegerLexMin( + "(x, y, z)[a, b] : (x - a == 0, y - b == 0, x >= 0, y >= 0, z >= 0, x + " + "y + z - 1 >= 0)", + { + {"(a, b) : (a >= 0, b >= 0, 1 - a - b >= 0)", + {{1, 0, 0}, {0, 1, 0}, {-1, -1, 1}}}, // (a, b, 1 - a - b) + {"(a, b) : (a >= 0, b >= 0, a + b - 2 >= 0)", + {{1, 0, 0}, {0, 1, 0}, {0, 0, 0}}}, // (a, b, 0) + }); + + expectSymbolicIntegerLexMin("(x)[a, b] : (x - a == 0, x - b >= 0)", + { + {"(a, b) : (a - b >= 0)", {{1, 0, 0}}}, // a + }); + + expectSymbolicIntegerLexMin( + "(q)[a] : (a - 1 - 3*q == 0, q >= 0)", + { + {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)", + {{0, 1, 0}}}, // a floordiv 3 + }); + + expectSymbolicIntegerLexMin( + "(r, q)[a] : (a - r - 3*q == 0, q >= 0, 1 - r >= 0, r >= 0)", + { + {"(a) : (a - 0 - 3*(a floordiv 3) == 0, a >= 0)", + {{0, 0, 0}, {0, 1, 0}}}, // (0, a floordiv 3) + {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)", + {{0, 0, 1}, {0, 1, 0}}}, // (1, a floordiv 3) + }); + + expectSymbolicIntegerLexMin( + "(r, q)[a] : (a - r - 3*q == 0, q >= 0, 2 - r >= 0, r - 1 >= 0)", + { + {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)", + {{0, 0, 1}, {0, 1, 0}}}, // (1, a floordiv 3) + {"(a) : (a - 2 - 3*(a floordiv 3) == 0, a >= 0)", + {{0, 0, 2}, {0, 1, 0}}}, // (2, a floordiv 3) + }); + + expectSymbolicIntegerLexMin( + "(r, q)[a] : (a - r - 3*q == 0, q >= 0, r >= 0)", + { + {"(a) : (a - 3*(a floordiv 3) == 0, a >= 0)", + {{0, 0, 0}, {0, 1, 0}}}, // (0, a floordiv 3) + {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a
>= 0)", + {{0, 0, 1}, {0, 1, 0}}}, // (1, a floordiv 3) + {"(a) : (a - 2 - 3*(a floordiv 3) == 0, a >= 0)", + {{0, 0, 2}, {0, 1, 0}}}, // (2, a floordiv 3) + }); + + expectSymbolicIntegerLexMin( + "(x, y, z, w)[g] : (" + // x, y, z, w are boolean variables. + "1 - x >= 0, x >= 0, 1 - y >= 0, y >= 0," + "1 - z >= 0, z >= 0, 1 - w >= 0, w >= 0," + // We have some constraints on them: + "x + y + z - 1 >= 0," // x or y or z + "x + y + w - 1 >= 0," // x or y or w + "1 - x + 1 - y + 1 - w - 1 >= 0," // ~x or ~y or ~w + // What's the lexmin solution using exactly g true vars? + "g - x - y - z - w == 0)", + { + {"(g) : (g - 1 == 0)", + {{0, 0}, {0, 1}, {0, 0}, {0, 0}}}, // (0, 1, 0, 0) + {"(g) : (g - 2 == 0)", + {{0, 0}, {0, 0}, {0, 1}, {0, 1}}}, // (0, 0, 1, 1) + {"(g) : (g - 3 == 0)", + {{0, 0}, {0, 1}, {0, 1}, {0, 1}}}, // (0, 1, 1, 1) + }); + + // Bezout's lemma: if a, b are constants, + // the set of values that ax + by can take is all multiples of gcd(a, b). + expectSymbolicIntegerLexMin( + // If (x, y) is a solution for a given [a, r], then so is (x - 5, y + 2). + // So the lexmin is unbounded if it exists. + "(x, y)[a, r] : (a >= 0, r - a + 14*x + 35*y == 0)", {}, + // According to Bezout's lemma, 14x + 35y can take on all multiples + // of 7 and no other values. So the solution exists iff r - a is a + // multiple of 7. + {"(a, r) : (a >= 0, r - a - 7*((r - a) floordiv 7) == 0)"}); + + // The lexmins are unbounded. + expectSymbolicIntegerLexMin("(x, y)[a] : (9*x - 4*y - 2*a >= 0)", {}, + {"(a) : ()"}); + + // Test cases adapted from isl. + expectSymbolicIntegerLexMin( + // a = 2b - 2(c - b), c - b >= 0. + // So b is minimized when c = b. + "(b, c)[a] : (a - 4*b + 2*c == 0, c - b >= 0)", + { + {"(a) : (a - 2*(a floordiv 2) == 0)", + {{0, 1, 0}, {0, 1, 0}}}, // (a floordiv 2, a floordiv 2) + }); + + expectSymbolicIntegerLexMin( + // 0 <= b <= 255, 1 <= a - 512b <= 509, + // b + 8 >= 1 + 16*(b + 8 floordiv 16) // i.e. 
b % 16 != 8 + "(b)[a] : (255 - b >= 0, b >= 0, a - 512*b - 1 >= 0, 512*b -a + 509 >= " + "0, b + 7 - 16*((8 + b) floordiv 16) >= 0)", + { + {"(a) : (255 - (a floordiv 512) >= 0, a >= 0, a - 512*(a floordiv " + "512) - 1 >= 0, 512*(a floordiv 512) - a + 509 >= 0, (a floordiv " + "512) + 7 - 16*((8 + (a floordiv 512)) floordiv 16) >= 0)", + {{0, 1, 0, 0}}}, // a floordiv 512 + }); + + expectSymbolicIntegerLexMin( + "(a, b)[K, N, x, y] : (N - K - 2 >= 0, K + 4 - N >= 0, x - 4 >= 0, x + 6 " + "- 2*N >= 0, K+N - x - 1 >= 0, a - N + 1 >= 0, K+N-1-a >= 0,a + 6 - b - " + "N >= 0, 2*N - 4 - a >= 0," + "2*N - 3*K + a - b >= 0, 4*N - K + 1 - 3*b >= 0, b - N >= 0, a - x - 1 " + ">= 0)", + {{ + "(K, N, x, y) : (x + 6 - 2*N >= 0, 2*N - 5 - x >= 0, x + 1 -3*K + N " + ">= 0, N + K - 2 - x >= 0, x - 4 >= 0)", + {{0, 0, 1, 0, 1}, {0, 1, 0, 0, 0}} // (1 + x, N) + }}); +} + static void expectComputedVolumeIsValidOverapprox(const IntegerPolyhedron &poly, Optional trueVolume,