[MLIR][Presburger] IntegerPolyhedron: add support for symbolic integer lexmin

Add support for computing the symbolic integer lexmin of a polyhedron.
This finds, for every assignment to the symbols, the lexicographically
minimum value attained by the dimensions. For example, the symbolic lexmin
of the set

`(x, y)[a, b, c] : (a <= x, b <= x, x <= c)`

can be written as

```
x = a if b <= a, a <= c
x = b if a <  b, b <= c
```

This also finds the set of assignments to the symbols that make the lexmin unbounded.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D122985
This commit is contained in:
Arjun P 2022-04-04 23:31:28 +01:00
parent 3b9833597e
commit da92f92621
9 changed files with 970 additions and 128 deletions

View File

@ -536,6 +536,7 @@ protected:
Matrix inequalities; Matrix inequalities;
}; };
struct SymbolicLexMin;
/// An IntegerPolyhedron is a PresburgerSpace subject to affine /// An IntegerPolyhedron is a PresburgerSpace subject to affine
/// constraints. Affine constraints can be inequalities or equalities in the /// constraints. Affine constraints can be inequalities or equalities in the
/// form: /// form:
@ -593,6 +594,28 @@ public:
/// column position (i.e., not relative to the kind of identifier) of the /// column position (i.e., not relative to the kind of identifier) of the
/// first added identifier. /// first added identifier.
unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override; unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
/// Compute the symbolic integer lexmin of the polyhedron.
/// This finds, for every assignment to the symbols, the lexicographically
/// minimum value attained by the dimensions. For example, the symbolic lexmin
/// of the set
///
/// (x, y)[a, b, c] : (a <= x, b <= x, x <= c)
///
/// can be written as
///
/// x = a if b <= a, a <= c
/// x = b if a < b, b <= c
///
/// This function is stored in the `lexmin` function in the result.
/// Some assignments to the symbols might make the set empty.
/// Such points are not part of the function's domain.
/// In the above example, this happens when max(a, b) > c.
///
/// For some values of the symbols, the lexmin may be unbounded.
/// `SymbolicLexMin` stores these parts of the symbolic domain in a separate
/// `PresburgerSet`, `unboundedDomain`.
SymbolicLexMin findSymbolicIntegerLexMin() const;
}; };
} // namespace presburger } // namespace presburger

View File

@ -151,6 +151,9 @@ public:
/// Add an extra row at the bottom of the matrix and return its position. /// Add an extra row at the bottom of the matrix and return its position.
unsigned appendExtraRow(); unsigned appendExtraRow();
/// Same as above, but copy the given elements into the row. The length of
/// `elems` must be equal to the number of columns.
unsigned appendExtraRow(ArrayRef<int64_t> elems);
/// Print the matrix. /// Print the matrix.
void print(raw_ostream &os) const; void print(raw_ostream &os) const;

View File

@ -106,6 +106,11 @@ public:
/// outside the domain, an empty optional is returned. /// outside the domain, an empty optional is returned.
Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const; Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
/// Truncate the output dimensions to the first `count` dimensions.
///
/// TODO: refactor so that this can be accomplished through removeIdRange.
void truncateOutput(unsigned count);
void print(raw_ostream &os) const; void print(raw_ostream &os) const;
void dump() const; void dump() const;
@ -165,6 +170,11 @@ public:
/// value at every point in the domain. /// value at every point in the domain.
bool isEqual(const PWMAFunction &other) const; bool isEqual(const PWMAFunction &other) const;
/// Truncate the output dimensions to the first `count` dimensions.
///
/// TODO: refactor so that this can be accomplished through removeIdRange.
void truncateOutput(unsigned count);
void print(raw_ostream &os) const; void print(raw_ostream &os) const;
void dump() const; void dump() const;

View File

@ -18,6 +18,7 @@
#include "mlir/Analysis/Presburger/Fraction.h" #include "mlir/Analysis/Presburger/Fraction.h"
#include "mlir/Analysis/Presburger/IntegerRelation.h" #include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/Presburger/Matrix.h" #include "mlir/Analysis/Presburger/Matrix.h"
#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/Utils.h" #include "mlir/Analysis/Presburger/Utils.h"
#include "mlir/Support/LogicalResult.h" #include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
@ -41,8 +42,9 @@ class GBRSimplex;
/// these constraints that are redundant, i.e. a subset of constraints that /// these constraints that are redundant, i.e. a subset of constraints that
/// doesn't constrain the affine set further after adding the non-redundant /// doesn't constrain the affine set further after adding the non-redundant
/// constraints. The LexSimplex class provides support for computing the /// constraints. The LexSimplex class provides support for computing the
/// lexicographical minimum of an IntegerRelation. Both these classes can be /// lexicographic minimum of an IntegerRelation. The SymbolicLexMin class
/// constructed from an IntegerRelation, and both inherit common /// provides support for computing symbolic lexicographic minimums. All of these
/// classes can be constructed from an IntegerRelation, and all inherit common
/// functionality from SimplexBase. /// functionality from SimplexBase.
/// ///
/// The implementations of the Simplex and SimplexBase classes, other than the /// The implementations of the Simplex and SimplexBase classes, other than the
@ -72,19 +74,22 @@ class GBRSimplex;
/// respectively. As described above, the first column is the common /// respectively. As described above, the first column is the common
/// denominator. The second column represents the constant term, explained in /// denominator. The second column represents the constant term, explained in
/// more detail below. These two are _fixed columns_; they always retain their /// more detail below. These two are _fixed columns_; they always retain their
/// position as the first and second columns. Additionally, LexSimplex stores /// position as the first and second columns. Additionally, LexSimplexBase
/// a so-call big M parameter (explained below) in the third column, so /// stores a so-call big M parameter (explained below) in the third column, so
/// LexSimplex has three fixed columns. /// LexSimplexBase has three fixed columns. Finally, SymbolicLexSimplex has
/// `nSymbol` variables designated as symbols. These occupy the next `nSymbol`
/// columns, viz. the columns [3, 3 + nSymbol). For more information on symbols,
/// see LexSimplexBase and SymbolicLexSimplex.
/// ///
/// LexSimplex does not directly support variables which can be negative, so we /// LexSimplexBase does not directly support variables which can be negative, so
/// introduce the so-called big M parameter, an artificial variable that is /// we introduce the so-called big M parameter, an artificial variable that is
/// considered to have an arbitrarily large value. We then transform the /// considered to have an arbitrarily large value. We then transform the
/// variables, say x, y, z, ... to M, M + x, M + y, M + z. Since M has been /// variables, say x, y, z, ... to M, M + x, M + y, M + z. Since M has been
/// added to these variables, they are now known to have non-negative values. /// added to these variables, they are now known to have non-negative values.
/// For more details, see the documentation for LexSimplex. The big M parameter /// For more details, see the documentation for LexSimplexBase. The big M
/// is not considered a real unknown and is not stored in the `var` data /// parameter is not considered a real unknown and is not stored in the `var`
/// structure; rather the tableau just has an extra fixed column for it just /// data structure; rather the tableau just has an extra fixed column for it
/// like the constant term. /// just like the constant term.
/// ///
/// The vectors var and con store information about the variables and /// The vectors var and con store information about the variables and
/// constraints respectively, namely, whether they are in row or column /// constraints respectively, namely, whether they are in row or column
@ -146,8 +151,8 @@ class GBRSimplex;
/// operation from the end until we reach the snapshot's location. SimplexBase /// operation from the end until we reach the snapshot's location. SimplexBase
/// also supports taking a snapshot including the exact set of basis unknowns; /// also supports taking a snapshot including the exact set of basis unknowns;
/// if this functionality is used, then on rolling back the exact basis will /// if this functionality is used, then on rolling back the exact basis will
/// also be restored. This is used by LexSimplex because its algorithm, unlike /// also be restored. This is used by LexSimplexBase because the lex algorithm,
/// Simplex, is sensitive to the exact basis used at a point. /// unlike `Simplex`, is sensitive to the exact basis used at a point.
class SimplexBase { class SimplexBase {
public: public:
SimplexBase() = delete; SimplexBase() = delete;
@ -211,7 +216,8 @@ protected:
/// constant term, whereas LexSimplex has an extra fixed column for the /// constant term, whereas LexSimplex has an extra fixed column for the
/// so-called big M parameter. For more information see the documentation for /// so-called big M parameter. For more information see the documentation for
/// LexSimplex. /// LexSimplex.
SimplexBase(unsigned nVar, bool mustUseBigM); SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset,
unsigned nSymbol);
enum class Orientation { Row, Column }; enum class Orientation { Row, Column };
@ -223,11 +229,14 @@ protected:
/// always be non-negative and if it cannot be made non-negative without /// always be non-negative and if it cannot be made non-negative without
/// violating other constraints, the tableau is empty. /// violating other constraints, the tableau is empty.
struct Unknown { struct Unknown {
Unknown(Orientation oOrientation, bool oRestricted, unsigned oPos) Unknown(Orientation oOrientation, bool oRestricted, unsigned oPos,
: pos(oPos), orientation(oOrientation), restricted(oRestricted) {} bool oIsSymbol = false)
: pos(oPos), orientation(oOrientation), restricted(oRestricted),
isSymbol(oIsSymbol) {}
unsigned pos; unsigned pos;
Orientation orientation; Orientation orientation;
bool restricted : 1; bool restricted : 1;
bool isSymbol : 1;
void print(raw_ostream &os) const { void print(raw_ostream &os) const {
os << (orientation == Orientation::Row ? "r" : "c"); os << (orientation == Orientation::Row ? "r" : "c");
@ -326,6 +335,10 @@ protected:
/// nRedundant rows. /// nRedundant rows.
unsigned nRedundant; unsigned nRedundant;
/// The number of parameters. This must be consistent with the number of
/// Unknowns in `var` below that have `isSymbol` set to true.
unsigned nSymbol;
/// The matrix representing the tableau. /// The matrix representing the tableau.
Matrix tableau; Matrix tableau;
@ -363,62 +376,45 @@ protected:
/// introduce an artifical variable M that is considered to have a value of /// introduce an artifical variable M that is considered to have a value of
/// +infinity and instead of the variables x, y, z, we internally use variables /// +infinity and instead of the variables x, y, z, we internally use variables
/// M + x, M + y, M + z, which are now guaranteed to be non-negative. See the /// M + x, M + y, M + z, which are now guaranteed to be non-negative. See the
/// documentation for Simplex for more details. The whole algorithm is performed /// documentation for SimplexBase for more details. M is also considered to be
/// without having to fix a "big enough" value of the big M parameter; it is /// an integer that is divisible by everything.
/// just considered to be infinite throughout and it never appears in the final ///
/// outputs. We will deal with sample values throughout that may in general be /// The whole algorithm is performed with M treated as a symbol;
/// some linear expression involving M like pM + q or aM + b. We can compare /// it is just considered to be infinite throughout and it never appears in the
/// these with each other. They have a total order: /// final outputs. We will deal with sample values throughout that may in
/// aM + b < pM + q iff a < p or (a == p and b < q). /// general be some affine expression involving M, like pM + q or aM + b. We can
/// compare these with each other. They have a total order:
///
/// aM + b < pM + q iff a < p or (a == p and b < q).
/// In particular, aM + b < 0 iff a < 0 or (a == 0 and b < 0). /// In particular, aM + b < 0 iff a < 0 or (a == 0 and b < 0).
/// ///
/// When performing symbolic optimization, sample values will be affine
/// expressions in M and the symbols. For example, we could have sample values
/// aM + bS + c and pM + qS + r, where S is a symbol. Now we have
/// aM + bS + c < pM + qS + r iff (a < p) or (a == p and bS + c < qS + r).
/// bS + c < qS + r can be always true, always false, or neither,
/// depending on the set of values S can take. The symbols are always stored
/// in columns [3, 3 + nSymbols). For more details, see the
/// documentation for SymbolicLexSimplex.
///
/// Initially all the constraints to be added are added as rows, with no attempt /// Initially all the constraints to be added are added as rows, with no attempt
/// to keep the tableau consistent. Pivots are only performed when some query /// to keep the tableau consistent. Pivots are only performed when some query
/// is made, such as a call to getRationalLexMin. Care is taken to always /// is made, such as a call to getRationalLexMin. Care is taken to always
/// maintain a lexicopositive basis transform, explained below. /// maintain a lexicopositive basis transform, explained below.
/// ///
/// Let the variables be x = (x_1, ... x_n). Let the basis unknowns at a /// Let the variables be x = (x_1, ... x_n).
/// particular point be y = (y_1, ... y_n). We know that x = A*y + b for some /// Let the symbols be s = (s_1, ... s_m). Let the basis unknowns at a
/// n x n matrix A and n x 1 column vector b. We want every column in A to be /// particular point be y = (y_1, ... y_n). We know that x = A*y + T*s + b for
/// lexicopositive, i.e., have at least one non-zero element, with the first /// some n x n matrix A, n x m matrix s, and n x 1 column vector b. We want
/// such element being positive. This property is preserved throughout the /// every column in A to be lexicopositive, i.e., have at least one non-zero
/// operation of LexSimplex. Note that on construction, the basis transform A is /// element, with the first such element being positive. This property is
/// the indentity matrix and so every column is lexicopositive. Note that for /// preserved throughout the operation of LexSimplexBase. Note that on
/// LexSimplex, for the tableau to be consistent we must have non-negative /// construction, the basis transform A is the identity matrix and so every
/// sample values not only for the constraints but also for the variables. /// column is lexicopositive. Note that for LexSimplexBase, for the tableau to
/// So if the tableau is consistent then x >= 0 and y >= 0, by which we mean /// be consistent we must have non-negative sample values not only for the
/// every element in these vectors is non-negative. (note that this is a /// constraints but also for the variables. So if the tableau is consistent then
/// different concept from lexicopositivity!) /// x >= 0 and y >= 0, by which we mean every element in these vectors is
/// /// non-negative. (note that this is a different concept from lexicopositivity!)
/// When we arrive at a basis such the basis transform is lexicopositive and the
/// tableau is consistent, the sample point is the lexiographically minimum
/// point in the polytope. We will show that A*y is zero or lexicopositive when
/// y >= 0. Adding a lexicopositive vector to b will make it lexicographically
/// bigger, so A*y + b is lexicographically bigger than b for any y >= 0 except
/// y = 0. This shows that no point lexicographically smaller than x = b can be
/// obtained. Since we already know that x = b is valid point in the space, this
/// shows that x = b is the lexicographic minimum.
///
/// Proof that A*y is lexicopositive or zero when y > 0. Recall that every
/// column of A is lexicopositive. Begin by considering A_1, the first row of A.
/// If this row is all zeros, then (A*y)_1 = (A_1)*y = 0; proceed to the next
/// row. If we run out of rows, A*y is zero and we are done; otherwise, we
/// encounter some row A_i that has a non-zero element. Every column is
/// lexicopositive and so has some positive element before any negative elements
/// occur, so the element in this row for any column, if non-zero, must be
/// positive. Consider (A*y)_i = (A_i)*y. All the elements in both vectors are
/// non-negative, so if this is non-zero then it must be positive. Then the
/// first non-zero element of A*y is positive so A*y is lexicopositive.
///
/// Otherwise, if (A_i)*y is zero, then for every column j that had a non-zero
/// element in A_i, y_j is zero. Thus these columns have no contribution to A*y
/// and we can completely ignore these columns of A. We now continue downwards,
/// looking for rows of A that have a non-zero element other than in the ignored
/// columns. If we find one, say A_k, once again these elements must be positive
/// since they are the first non-zero element in each of these columns, so if
/// (A_k)*y is not zero then we have that A*y is lexicopositive and if not we
/// ignore more columns; eventually if all these dot products become zero then
/// A*y is zero and we are done.
class LexSimplexBase : public SimplexBase { class LexSimplexBase : public SimplexBase {
public: public:
~LexSimplexBase() override = default; ~LexSimplexBase() override = default;
@ -435,25 +431,37 @@ public:
unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); } unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); }
protected: protected:
LexSimplexBase(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/true) {} LexSimplexBase(unsigned nVar, unsigned symbolOffset, unsigned nSymbol)
: SimplexBase(nVar, /*mustUseBigM=*/true, symbolOffset, nSymbol) {}
explicit LexSimplexBase(const IntegerRelation &constraints) explicit LexSimplexBase(const IntegerRelation &constraints)
: LexSimplexBase(constraints.getNumIds()) { : LexSimplexBase(constraints.getNumIds(),
constraints.getIdKindOffset(IdKind::Symbol),
constraints.getNumSymbolIds()) {
intersectIntegerRelation(constraints); intersectIntegerRelation(constraints);
} }
/// Add new symbolic variables to the end of the list of variables.
void appendSymbol();
/// Try to move the specified row to column orientation while preserving the /// Try to move the specified row to column orientation while preserving the
/// lexicopositivity of the basis transform. If this is not possible, return /// lexicopositivity of the basis transform. The row must have a negative
/// failure. This only occurs when the constraints have no solution; the /// sample value. If this is not possible, return failure. This only occurs
/// tableau will be marked empty in such a case. /// when the constraints have no solution; the tableau will be marked empty in
/// such a case.
LogicalResult moveRowUnknownToColumn(unsigned row); LogicalResult moveRowUnknownToColumn(unsigned row);
/// Given a row that has a non-integer sample value, add an inequality such /// Given a row that has a non-integer sample value, add an inequality to cut
/// that this fractional sample value is cut away from the polytope. The added /// away this fractional sample value from the polytope without removing any
/// inequality will be such that no integer points are removed. /// integer points. The integer lexmin, if one existed, remains the same on
/// return.
/// ///
/// Returns whether the cut constraint could be enforced, i.e. failure if the /// This assumes that the symbolic part of the sample is integral,
/// cut made the polytope empty, and success if it didn't. Failure status /// i.e., if the symbolic sample is (c + aM + b_1*s_1 + ... b_n*s_n)/d,
/// indicates that the polytope didn't have any integer points. /// where s_1, ... s_n are symbols, this assumes that
/// (b_1*s_1 + ... + b_n*s_n)/s is integral.
///
/// Return failure if the tableau became empty, and success if it didn't.
/// Failure status indicates that the polytope was integer empty.
LogicalResult addCut(unsigned row); LogicalResult addCut(unsigned row);
/// Undo the addition of the last constraint. This is only called while /// Undo the addition of the last constraint. This is only called while
@ -461,14 +469,19 @@ protected:
void undoLastConstraint() final; void undoLastConstraint() final;
/// Given two potential pivot columns for a row, return the one that results /// Given two potential pivot columns for a row, return the one that results
/// in the lexicographically smallest sample vector. /// in the lexicographically smallest sample vector. The row's sample value
/// must be negative. If symbols are involved, the sample value must be
/// negative for all possible assignments to the symbols.
unsigned getLexMinPivotColumn(unsigned row, unsigned colA, unsigned getLexMinPivotColumn(unsigned row, unsigned colA,
unsigned colB) const; unsigned colB) const;
}; };
/// A class for lexicographic optimization without any symbols. This also
/// provides support for integer-exact redundancy and separateness checks.
class LexSimplex : public LexSimplexBase { class LexSimplex : public LexSimplexBase {
public: public:
explicit LexSimplex(unsigned nVar) : LexSimplexBase(nVar) {} explicit LexSimplex(unsigned nVar)
: LexSimplexBase(nVar, /*symbolOffset=*/0, /*nSymbol=*/0) {}
explicit LexSimplex(const IntegerRelation &constraints) explicit LexSimplex(const IntegerRelation &constraints)
: LexSimplexBase(constraints) { : LexSimplexBase(constraints) {
assert(constraints.getNumSymbolIds() == 0 && assert(constraints.getNumSymbolIds() == 0 &&
@ -502,7 +515,7 @@ private:
MaybeOptimum<SmallVector<Fraction, 8>> getRationalSample() const; MaybeOptimum<SmallVector<Fraction, 8>> getRationalSample() const;
/// Make the tableau configuration consistent. /// Make the tableau configuration consistent.
void restoreRationalConsistency(); LogicalResult restoreRationalConsistency();
/// Return whether the specified row is violated; /// Return whether the specified row is violated;
bool rowIsViolated(unsigned row) const; bool rowIsViolated(unsigned row) const;
@ -514,11 +527,122 @@ private:
/// Get a row corresponding to a var that has a non-integral sample value, if /// Get a row corresponding to a var that has a non-integral sample value, if
/// one exists. Otherwise, return an empty optional. /// one exists. Otherwise, return an empty optional.
Optional<unsigned> maybeGetNonIntegralVarRow() const; Optional<unsigned> maybeGetNonIntegralVarRow() const;
};
/// Given two potential pivot columns for a row, return the one that results /// Represents the result of a symbolic lexicographic minimization computation.
/// in the lexicographically smallest sample vector. struct SymbolicLexMin {
unsigned getLexMinPivotColumn(unsigned row, unsigned colA, SymbolicLexMin(unsigned nSymbols, unsigned nNonSymbols)
unsigned colB) const; : lexmin(PresburgerSpace::getSetSpace(nSymbols), nNonSymbols),
unboundedDomain(
PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(nSymbols))) {}
/// This maps assignments of symbols to the corresponding lexmin.
/// Takes no value when no integer sample exists for the assignment or if the
/// lexmin is unbounded.
PWMAFunction lexmin;
/// Contains all assignments to the symbols that made the lexmin unbounded.
/// Note that the symbols of the input set to the symbolic lexmin are dims
/// of this PrebsurgerSet.
PresburgerSet unboundedDomain;
};
/// A class to perform symbolic lexicographic optimization,
/// i.e., to find, for every assignment to the symbols the specified
/// `symbolDomain`, the lexicographically minimum value integer value attained
/// by the non-symbol variables.
///
/// The input is a set parametrized by some symbols, i.e., the constant terms
/// of the constraints in the set are affine expressions in the symbols, and
/// every assignment to the symbols defines a non-symbolic set.
///
/// Accordingly, the sample values of the rows in our tableau will be affine
/// expressions in the symbols, and every assignment to the symbols will define
/// a non-symbolic LexSimplex. We then run the algorithm of
/// LexSimplex::findIntegerLexMin simultaneously for every value of the symbols
/// in the domain.
///
/// Often, the pivot to be performed is the same for all values of the symbols,
/// in which case we just do it. For example, if the symbolic sample of a row is
/// negative for all values in the symbol domain, the row needs to be pivoted
/// irrespective of the precise value of the symbols. To answer queries like
/// "Is this symbolic sample always negative in the symbol domain?", we maintain
/// a `LexSimplex domainSimplex` correponding to the symbol domain.
///
/// In other cases, it may be that the symbolic sample is violated at some
/// values in the symbol domain and not violated at others. In this case,
/// the pivot to be performed does depend on the value of the symbols. We
/// handle this by splitting the symbol domain. We run the algorithm for the
/// case where the row isn't violated, and then come back and run the case
/// where it is.
class SymbolicLexSimplex : public LexSimplexBase {
public:
/// `constraints` is the set for which the symbolic lexmin will be computed.
/// `symbolDomain` is the set of values of the symbols for which the lexmin
/// will be computed. `symbolDomain` should have a dim id for every symbol in
/// `constraints`, and no other ids.
SymbolicLexSimplex(const IntegerPolyhedron &constraints,
const IntegerPolyhedron &symbolDomain)
: LexSimplexBase(constraints), domainPoly(symbolDomain),
domainSimplex(symbolDomain) {
assert(domainPoly.getNumIds() == constraints.getNumSymbolIds());
assert(domainPoly.getNumDimIds() == constraints.getNumSymbolIds());
}
/// The lexmin will be stored as a function `lexmin` from symbols to
/// non-symbols in the result.
///
/// For some values of the symbols, the lexmin may be unbounded.
/// These parts of the symbol domain will be stored in `unboundedDomain`.
SymbolicLexMin computeSymbolicIntegerLexMin();
private:
/// Perform all pivots that do not require branching.
///
/// Return failure if the tableau became empty, indicating that the polytope
/// is always integer empty in the current symbol domain.
/// Return success otherwise.
LogicalResult doNonBranchingPivots();
/// Get a row that is always violated in the current domain, if one exists.
Optional<unsigned> maybeGetAlwaysViolatedRow();
/// Get a row corresponding to a variable with non-integral sample value, if
/// one exists.
Optional<unsigned> maybeGetNonIntegralVarRow();
/// Given a row that has a non-integer sample value, cut away this fractional
/// sample value witahout removing any integer points, i.e., the integer
/// lexmin, if it exists, remains the same after a call to this function. This
/// may add constraints or local variables to the tableau, as well as to the
/// domain.
///
/// Returns whether the cut constraint could be enforced, i.e. failure if the
/// cut made the polytope empty, and success if it didn't. Failure status
/// indicates that the polytope is always integer empty in the symbol domain
/// at the time of the call. (This function may modify the symbol domain, but
/// failure statu indicates that the polytope was empty for all symbol values
/// in the initial domain.)
LogicalResult addSymbolicCut(unsigned row);
/// Get the numerator of the symbolic sample of the specific row.
/// This is an affine expression in the symbols with integer coefficients.
/// The last element is the constant term. This ignores the big M coefficient.
SmallVector<int64_t, 8> getSymbolicSampleNumerator(unsigned row) const;
/// Return whether all the coefficients of the symbolic sample are integers.
///
/// This does not consult the domain to check if the specified expression
/// is always integral despite coefficients being fractional.
bool isSymbolicSampleIntegral(unsigned row) const;
/// Record a lexmin. The tableau must be consistent with all variables
/// having symbolic samples with integer coefficients.
void recordOutput(SymbolicLexMin &result) const;
/// The symbol domain.
IntegerPolyhedron domainPoly;
/// Simplex corresponding to the symbol domain.
LexSimplex domainSimplex;
}; };
/// The Simplex class uses the Normal pivot rule and supports integer emptiness /// The Simplex class uses the Normal pivot rule and supports integer emptiness
@ -540,7 +664,9 @@ public:
enum class Direction { Up, Down }; enum class Direction { Up, Down };
Simplex() = delete; Simplex() = delete;
explicit Simplex(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/false) {} explicit Simplex(unsigned nVar)
: SimplexBase(nVar, /*mustUseBigM=*/false, /*symbolOffset=*/0,
/*nSymbol=*/0) {}
explicit Simplex(const IntegerRelation &constraints) explicit Simplex(const IntegerRelation &constraints)
: Simplex(constraints.getNumIds()) { : Simplex(constraints.getNumIds()) {
intersectIntegerRelation(constraints); intersectIntegerRelation(constraints);

View File

@ -14,6 +14,7 @@
#include "mlir/Analysis/Presburger/IntegerRelation.h" #include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/Presburger/LinearTransform.h" #include "mlir/Analysis/Presburger/LinearTransform.h"
#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/PresburgerRelation.h" #include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Analysis/Presburger/Simplex.h" #include "mlir/Analysis/Presburger/Simplex.h"
#include "mlir/Analysis/Presburger/Utils.h" #include "mlir/Analysis/Presburger/Utils.h"
@ -145,6 +146,21 @@ void IntegerRelation::truncate(const CountsSnapshot &counts) {
removeEqualityRange(counts.getNumEqs(), getNumEqualities()); removeEqualityRange(counts.getNumEqs(), getNumEqualities());
} }
SymbolicLexMin IntegerPolyhedron::findSymbolicIntegerLexMin() const {
// Compute the symbolic lexmin of the dims and locals, with the symbols being
// the actual symbols of this set.
SymbolicLexMin result =
SymbolicLexSimplex(
*this, PresburgerSpace::getSetSpace(/*numDims=*/getNumSymbolIds()))
.computeSymbolicIntegerLexMin();
// We want to return only the lexmin over the dims, so strip the locals from
// the computed lexmin.
result.lexmin.truncateOutput(result.lexmin.getNumOutputs() -
getNumLocalIds());
return result;
}
unsigned IntegerRelation::insertId(IdKind kind, unsigned pos, unsigned num) { unsigned IntegerRelation::insertId(IdKind kind, unsigned pos, unsigned num) {
assert(pos <= getNumIdKind(kind)); assert(pos <= getNumIdKind(kind));

View File

@ -66,6 +66,14 @@ unsigned Matrix::appendExtraRow() {
return nRows - 1; return nRows - 1;
} }
unsigned Matrix::appendExtraRow(ArrayRef<int64_t> elems) {
assert(elems.size() == nColumns && "elems must match row length!");
unsigned row = appendExtraRow();
for (unsigned col = 0; col < nColumns; ++col)
at(row, col) = elems[col];
return row;
}
void Matrix::resizeHorizontally(unsigned newNColumns) { void Matrix::resizeHorizontally(unsigned newNColumns) {
if (newNColumns < nColumns) if (newNColumns < nColumns)
removeColumns(newNColumns, nColumns - newNColumns); removeColumns(newNColumns, nColumns - newNColumns);

View File

@ -114,6 +114,18 @@ void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA,
IntegerPolyhedron::eliminateRedundantLocalId(posA, posB); IntegerPolyhedron::eliminateRedundantLocalId(posA, posB);
} }
void MultiAffineFunction::truncateOutput(unsigned count) {
assert(count <= output.getNumRows());
output.resizeVertically(count);
}
void PWMAFunction::truncateOutput(unsigned count) {
assert(count <= numOutputs);
for (MultiAffineFunction &piece : pieces)
piece.truncateOutput(count);
numOutputs = count;
}
bool MultiAffineFunction::isEqualWhereDomainsOverlap( bool MultiAffineFunction::isEqualWhereDomainsOverlap(
MultiAffineFunction other) const { MultiAffineFunction other) const {
if (!isSpaceCompatible(other)) if (!isSpaceCompatible(other))

View File

@ -18,15 +18,24 @@ using Direction = Simplex::Direction;
const int nullIndex = std::numeric_limits<int>::max(); const int nullIndex = std::numeric_limits<int>::max();
SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM) SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset,
unsigned nSymbol)
: usingBigM(mustUseBigM), nRow(0), nCol(getNumFixedCols() + nVar), : usingBigM(mustUseBigM), nRow(0), nCol(getNumFixedCols() + nVar),
nRedundant(0), tableau(0, nCol), empty(false) { nRedundant(0), nSymbol(nSymbol), tableau(0, nCol), empty(false) {
assert(symbolOffset + nSymbol <= nVar);
colUnknown.insert(colUnknown.begin(), getNumFixedCols(), nullIndex); colUnknown.insert(colUnknown.begin(), getNumFixedCols(), nullIndex);
for (unsigned i = 0; i < nVar; ++i) { for (unsigned i = 0; i < nVar; ++i) {
var.emplace_back(Orientation::Column, /*restricted=*/false, var.emplace_back(Orientation::Column, /*restricted=*/false,
/*pos=*/getNumFixedCols() + i); /*pos=*/getNumFixedCols() + i);
colUnknown.push_back(i); colUnknown.push_back(i);
} }
// Move the symbols to be in columns [3, 3 + nSymbol).
for (unsigned i = 0; i < nSymbol; ++i) {
var[symbolOffset + i].isSymbol = true;
swapColumns(var[symbolOffset + i].pos, getNumFixedCols() + i);
}
} }
const Simplex::Unknown &SimplexBase::unknownFromIndex(int index) const { const Simplex::Unknown &SimplexBase::unknownFromIndex(int index) const {
@ -96,9 +105,13 @@ unsigned SimplexBase::addRow(ArrayRef<int64_t> coeffs, bool makeRestricted) {
// where M is the big M parameter. As such, when the user tries to add // where M is the big M parameter. As such, when the user tries to add
// a row ax + by + cz + d, we express it in terms of our internal variables // a row ax + by + cz + d, we express it in terms of our internal variables
// as -(a + b + c)M + a(M + x) + b(M + y) + c(M + z) + d. // as -(a + b + c)M + a(M + x) + b(M + y) + c(M + z) + d.
//
// Symbols don't use the big M parameter since they do not get lex
// optimized.
int64_t bigMCoeff = 0; int64_t bigMCoeff = 0;
for (unsigned i = 0; i < coeffs.size() - 1; ++i) for (unsigned i = 0; i < coeffs.size() - 1; ++i)
bigMCoeff -= coeffs[i]; if (!var[i].isSymbol)
bigMCoeff -= coeffs[i];
// The coefficient to the big M parameter is stored in column 2. // The coefficient to the big M parameter is stored in column 2.
tableau(nRow - 1, 2) = bigMCoeff; tableau(nRow - 1, 2) = bigMCoeff;
} }
@ -164,19 +177,97 @@ Direction flippedDirection(Direction direction) {
} }
} // namespace } // namespace
/// We simply make the tableau consistent while maintaining a lexicopositive
/// basis transform, and then return the sample value. If the tableau becomes
/// empty, we return empty.
///
/// Let the variables be x = (x_1, ... x_n).
/// Let the basis unknowns be y = (y_1, ... y_n).
/// We have that x = A*y + b for some n x n matrix A and n x 1 column vector b.
///
/// As we will show below, A*y is either zero or lexicopositive.
/// Adding a lexicopositive vector to b will make it lexicographically
/// greater, so A*y + b is always equal to or lexicographically greater than b.
/// Thus, since we can attain x = b, that is the lexicographic minimum.
///
/// We have that that every column in A is lexicopositive, i.e., has at least
/// one non-zero element, with the first such element being positive. Since for
/// the tableau to be consistent we must have non-negative sample values not
/// only for the constraints but also for the variables, we also have x >= 0 and
/// y >= 0, by which we mean every element in these vectors is non-negative.
///
/// Proof that if every column in A is lexicopositive, and y >= 0, then
/// A*y is zero or lexicopositive. Begin by considering A_1, the first row of A.
/// If this row is all zeros, then (A*y)_1 = (A_1)*y = 0; proceed to the next
/// row. If we run out of rows, A*y is zero and we are done; otherwise, we
/// encounter some row A_i that has a non-zero element. Every column is
/// lexicopositive and so has some positive element before any negative elements
/// occur, so the element in this row for any column, if non-zero, must be
/// positive. Consider (A*y)_i = (A_i)*y. All the elements in both vectors are
/// non-negative, so if this is non-zero then it must be positive. Then the
/// first non-zero element of A*y is positive so A*y is lexicopositive.
///
/// Otherwise, if (A_i)*y is zero, then for every column j that had a non-zero
/// element in A_i, y_j is zero. Thus these columns have no contribution to A*y
/// and we can completely ignore these columns of A. We now continue downwards,
/// looking for rows of A that have a non-zero element other than in the ignored
/// columns. If we find one, say A_k, once again these elements must be positive
/// since they are the first non-zero element in each of these columns, so if
/// (A_k)*y is not zero then we have that A*y is lexicopositive and if not we
/// add these to the set of ignored columns and continue to the next row. If we
/// run out of rows, then A*y is zero and we are done.
MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::findRationalLexMin() { MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::findRationalLexMin() {
restoreRationalConsistency(); if (restoreRationalConsistency().failed())
return OptimumKind::Empty;
return getRationalSample(); return getRationalSample();
} }
/// Given a row that has a non-integer sample value, add an inequality such
/// that this fractional sample value is cut away from the polytope. The added
/// inequality will be such that no integer points are removed. i.e., the
/// integer lexmin, if it exists, is the same with and without this constraint.
///
/// Let the row be
/// (c + coeffM*M + a_1*s_1 + ... + a_m*s_m + b_1*y_1 + ... + b_n*y_n)/d,
/// where s_1, ... s_m are the symbols and
/// y_1, ... y_n are the other basis unknowns.
///
/// For this to be an integer, we want
/// coeffM*M + a_1*s_1 + ... + a_m*s_m + b_1*y_1 + ... + b_n*y_n = -c (mod d)
/// Note that this constraint must always hold, independent of the basis,
/// becuse the row unknown's value always equals this expression, even if *we*
/// later compute the sample value from a different expression based on a
/// different basis.
///
/// Let us assume that M has a factor of d in it. Imposing this constraint on M
/// does not in any way hinder us from finding a value of M that is big enough.
/// Moreover, this function is only called when the symbolic part of the sample,
/// a_1*s_1 + ... + a_m*s_m, is known to be an integer.
///
/// Also, we can safely reduce the coefficients modulo d, so we have:
///
/// (b_1%d)y_1 + ... + (b_n%d)y_n = (-c%d) + k*d for some integer `k`
///
/// Note that all coefficient modulos here are non-negative. Also, all the
/// unknowns are non-negative here as both constraints and variables are
/// non-negative in LexSimplexBase. (We used the big M trick to make the
/// variables non-negative). Therefore, the LHS here is non-negative.
/// Since 0 <= (-c%d) < d, k is the quotient of dividing the LHS by d and
/// is therefore non-negative as well.
///
/// So we have
/// ((b_1%d)y_1 + ... + (b_n%d)y_n - (-c%d))/d >= 0.
///
/// The constraint is violated when added (it would be useless otherwise)
/// so we immediately try to move it to a column.
LogicalResult LexSimplexBase::addCut(unsigned row) { LogicalResult LexSimplexBase::addCut(unsigned row) {
int64_t denom = tableau(row, 0); int64_t d = tableau(row, 0);
addZeroRow(/*makeRestricted=*/true); addZeroRow(/*makeRestricted=*/true);
tableau(nRow - 1, 0) = denom; tableau(nRow - 1, 0) = d;
tableau(nRow - 1, 1) = -mod(-tableau(row, 1), denom); tableau(nRow - 1, 1) = -mod(-tableau(row, 1), d); // -c%d.
tableau(nRow - 1, 2) = 0; // M has all factors in it. tableau(nRow - 1, 2) = 0;
for (unsigned col = 3; col < nCol; ++col) for (unsigned col = 3 + nSymbol; col < nCol; ++col)
tableau(nRow - 1, col) = mod(tableau(row, col), denom); tableau(nRow - 1, col) = mod(tableau(row, col), d); // b_i%d.
return moveRowUnknownToColumn(nRow - 1); return moveRowUnknownToColumn(nRow - 1);
} }
@ -185,7 +276,7 @@ Optional<unsigned> LexSimplex::maybeGetNonIntegralVarRow() const {
if (u.orientation == Orientation::Column) if (u.orientation == Orientation::Column)
continue; continue;
// If the sample value is of the form (a/d)M + b/d, we need b to be // If the sample value is of the form (a/d)M + b/d, we need b to be
// divisible by d. We assume M is very large and contains all possible // divisible by d. We assume M contains all possible
// factors and is divisible by everything. // factors and is divisible by everything.
unsigned row = u.pos; unsigned row = u.pos;
if (tableau(row, 1) % tableau(row, 0) != 0) if (tableau(row, 1) % tableau(row, 0) != 0)
@ -195,28 +286,34 @@ Optional<unsigned> LexSimplex::maybeGetNonIntegralVarRow() const {
} }
MaybeOptimum<SmallVector<int64_t, 8>> LexSimplex::findIntegerLexMin() { MaybeOptimum<SmallVector<int64_t, 8>> LexSimplex::findIntegerLexMin() {
while (!empty) { // We first try to make the tableau consistent.
restoreRationalConsistency(); if (restoreRationalConsistency().failed())
if (empty) return OptimumKind::Empty;
// Then, if the sample value is integral, we are done.
while (Optional<unsigned> maybeRow = maybeGetNonIntegralVarRow()) {
// Otherwise, for the variable whose row has a non-integral sample value,
// we add a cut, a constraint that remove this rational point
// while preserving all integer points, thus keeping the lexmin the same.
// We then again try to make the tableau with the new constraint
// consistent. This continues until the tableau becomes empty, in which
// case there is no integer point, or until there are no variables with
// non-integral sample values.
//
// Failure indicates that the tableau became empty, which occurs when the
// polytope is integer empty.
if (addCut(*maybeRow).failed())
return OptimumKind::Empty;
if (restoreRationalConsistency().failed())
return OptimumKind::Empty; return OptimumKind::Empty;
if (Optional<unsigned> maybeRow = maybeGetNonIntegralVarRow()) {
// Failure occurs when the polytope is integer empty.
if (failed(addCut(*maybeRow)))
return OptimumKind::Empty;
continue;
}
MaybeOptimum<SmallVector<Fraction, 8>> sample = getRationalSample();
assert(!sample.isEmpty() && "If we reached here the sample should exist!");
if (sample.isUnbounded())
return OptimumKind::Unbounded;
return llvm::to_vector<8>(
llvm::map_range(*sample, std::mem_fn(&Fraction::getAsInteger)));
} }
// Polytope is integer empty. MaybeOptimum<SmallVector<Fraction, 8>> sample = getRationalSample();
return OptimumKind::Empty; assert(!sample.isEmpty() && "If we reached here the sample should exist!");
if (sample.isUnbounded())
return OptimumKind::Unbounded;
return llvm::to_vector<8>(
llvm::map_range(*sample, std::mem_fn(&Fraction::getAsInteger)));
} }
bool LexSimplex::isSeparateInequality(ArrayRef<int64_t> coeffs) { bool LexSimplex::isSeparateInequality(ArrayRef<int64_t> coeffs) {
@ -228,6 +325,319 @@ bool LexSimplex::isSeparateInequality(ArrayRef<int64_t> coeffs) {
bool LexSimplex::isRedundantInequality(ArrayRef<int64_t> coeffs) { bool LexSimplex::isRedundantInequality(ArrayRef<int64_t> coeffs) {
return isSeparateInequality(getComplementIneq(coeffs)); return isSeparateInequality(getComplementIneq(coeffs));
} }
SmallVector<int64_t, 8>
SymbolicLexSimplex::getSymbolicSampleNumerator(unsigned row) const {
SmallVector<int64_t, 8> sample;
sample.reserve(nSymbol + 1);
for (unsigned col = 3; col < 3 + nSymbol; ++col)
sample.push_back(tableau(row, col));
sample.push_back(tableau(row, 1));
return sample;
}
void LexSimplexBase::appendSymbol() {
appendVariable();
swapColumns(3 + nSymbol, nCol - 1);
var.back().isSymbol = true;
nSymbol++;
}
static bool isRangeDivisibleBy(ArrayRef<int64_t> range, int64_t divisor) {
assert(divisor > 0 && "divisor must be positive!");
return llvm::all_of(range, [divisor](int64_t x) { return x % divisor == 0; });
}
bool SymbolicLexSimplex::isSymbolicSampleIntegral(unsigned row) const {
int64_t denom = tableau(row, 0);
return tableau(row, 1) % denom == 0 &&
isRangeDivisibleBy(tableau.getRow(row).slice(3, nSymbol), denom);
}
/// This proceeds similarly to LexSimplex::addCut(). We are given a row that has
/// a symbolic sample value with fractional coefficients.
///
/// Let the row be
/// (c + coeffM*M + sum_i a_i*s_i + sum_j b_j*y_j)/d,
/// where s_1, ... s_m are the symbols and
/// y_1, ... y_n are the other basis unknowns.
///
/// As in LexSimplex::addCut, for this to be an integer, we want
///
/// coeffM*M + sum_j b_j*y_j = -c + sum_i (-a_i*s_i) (mod d)
///
/// This time, a_1*s_1 + ... + a_m*s_m may not be an integer. We find that
///
/// sum_i (b_i%d)y_i = ((-c%d) + sum_i (-a_i%d)s_i)%d + k*d for some integer k
///
/// where we take a modulo of the whole symbolic expression on the right to
/// bring it into the range [0, d - 1]. Therefore, as in LexSimplex::addCut,
/// k is the quotient on dividing the LHS by d, and since LHS >= 0, we have
/// k >= 0 as well. We realize the modulo of the symbolic expression by adding a
/// division variable
///
/// q = ((-c%d) + sum_i (-a_i%d)s_i)/d
///
/// to the symbol domain, so the equality becomes
///
/// sum_i (b_i%d)y_i = (-c%d) + sum_i (-a_i%d)s_i - q*d + k*d for some integer k
///
/// So the cut is
/// (sum_i (b_i%d)y_i - (-c%d) - sum_i (-a_i%d)s_i + q*d)/d >= 0
/// This constraint is violated when added so we immediately try to move it to a
/// column.
LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) {
int64_t d = tableau(row, 0);
// Add the division variable `q` described above to the symbol domain.
// q = ((-c%d) + sum_i (-a_i%d)s_i)/d.
SmallVector<int64_t, 8> domainDivCoeffs;
domainDivCoeffs.reserve(nSymbol + 1);
for (unsigned col = 3; col < 3 + nSymbol; ++col)
domainDivCoeffs.push_back(mod(-tableau(row, col), d)); // (-a_i%d)s_i
domainDivCoeffs.push_back(mod(-tableau(row, 1), d)); // -c%d.
domainSimplex.addDivisionVariable(domainDivCoeffs, d);
domainPoly.addLocalFloorDiv(domainDivCoeffs, d);
// Update `this` to account for the additional symbol we just added.
appendSymbol();
// Add the cut (sum_i (b_i%d)y_i - (-c%d) + sum_i -(-a_i%d)s_i + q*d)/d >= 0.
addZeroRow(/*makeRestricted=*/true);
tableau(nRow - 1, 0) = d;
tableau(nRow - 1, 2) = 0;
tableau(nRow - 1, 1) = -mod(-tableau(row, 1), d); // -(-c%d).
for (unsigned col = 3; col < 3 + nSymbol - 1; ++col)
tableau(nRow - 1, col) = -mod(-tableau(row, col), d); // -(-a_i%d)s_i.
tableau(nRow - 1, 3 + nSymbol - 1) = d; // q*d.
for (unsigned col = 3 + nSymbol; col < nCol; ++col)
tableau(nRow - 1, col) = mod(tableau(row, col), d); // (b_i%d)y_i.
return moveRowUnknownToColumn(nRow - 1);
}
void SymbolicLexSimplex::recordOutput(SymbolicLexMin &result) const {
Matrix output(0, domainPoly.getNumIds() + 1);
output.reserveRows(result.lexmin.getNumOutputs());
for (const Unknown &u : var) {
if (u.isSymbol)
continue;
if (u.orientation == Orientation::Column) {
// M + u has a sample value of zero so u has a sample value of -M, i.e,
// unbounded.
result.unboundedDomain.unionInPlace(domainPoly);
return;
}
int64_t denom = tableau(u.pos, 0);
if (tableau(u.pos, 2) < denom) {
// M + u has a sample value of fM + something, where f < 1, so
// u = (f - 1)M + something, which has a negative coefficient for M,
// and so is unbounded.
result.unboundedDomain.unionInPlace(domainPoly);
return;
}
assert(tableau(u.pos, 2) == denom &&
"Coefficient of M should not be greater than 1!");
SmallVector<int64_t, 8> sample = getSymbolicSampleNumerator(u.pos);
for (int64_t &elem : sample) {
assert(elem % denom == 0 && "coefficients must be integral!");
elem /= denom;
}
output.appendExtraRow(sample);
}
result.lexmin.addPiece(domainPoly, output);
}
Optional<unsigned> SymbolicLexSimplex::maybeGetAlwaysViolatedRow() {
// First look for rows that are clearly violated just from the big M
// coefficient, without needing to perform any simplex queries on the domain.
for (unsigned row = 0; row < nRow; ++row)
if (tableau(row, 2) < 0)
return row;
for (unsigned row = 0; row < nRow; ++row) {
if (tableau(row, 2) > 0)
continue;
if (domainSimplex.isSeparateInequality(getSymbolicSampleNumerator(row))) {
// Sample numerator always takes negative values in the symbol domain.
return row;
}
}
return {};
}
Optional<unsigned> SymbolicLexSimplex::maybeGetNonIntegralVarRow() {
for (const Unknown &u : var) {
if (u.orientation == Orientation::Column)
continue;
assert(!u.isSymbol && "Symbol should not be in row orientation!");
if (!isSymbolicSampleIntegral(u.pos))
return u.pos;
}
return {};
}
/// The non-branching pivots are just the ones moving the rows
/// that are always violated in the symbol domain.
LogicalResult SymbolicLexSimplex::doNonBranchingPivots() {
while (Optional<unsigned> row = maybeGetAlwaysViolatedRow())
if (moveRowUnknownToColumn(*row).failed())
return failure();
return success();
}
SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
SymbolicLexMin result(nSymbol, var.size() - nSymbol);
/// The algorithm is more naturally expressed recursively, but we implement
/// it iteratively here to avoid potential issues with stack overflows in the
/// compiler. We explicitly maintain the stack frames in a vector.
///
/// To "recurse", we store the current "stack frame", i.e., state variables
/// that we will need when we "return", into `stack`, increment `level`, and
/// `continue`. To "tail recurse", we just `continue`.
/// To "return", we decrement `level` and `continue`.
///
/// When there is no stack frame for the current `level`, this indicates that
/// we have just "recursed" or "tail recursed". When there does exist one,
/// this indicates that we have just "returned" from recursing. There is only
/// one point at which non-tail calls occur so we always "return" there.
unsigned level = 1;
struct StackFrame {
int splitIndex;
unsigned snapshot;
unsigned domainSnapshot;
IntegerRelation::CountsSnapshot domainPolyCounts;
};
SmallVector<StackFrame, 8> stack;
while (level > 0) {
assert(level >= stack.size());
if (level > stack.size()) {
if (empty || domainSimplex.findIntegerLexMin().isEmpty()) {
// No integer points; return.
--level;
continue;
}
if (doNonBranchingPivots().failed()) {
// Could not find pivots for violated constraints; return.
--level;
continue;
}
unsigned splitRow;
SmallVector<int64_t, 8> symbolicSample;
for (splitRow = 0; splitRow < nRow; ++splitRow) {
if (tableau(splitRow, 2) > 0)
continue;
assert(tableau(splitRow, 2) == 0 &&
"Non-branching pivots should have been handled already!");
symbolicSample = getSymbolicSampleNumerator(splitRow);
if (domainSimplex.isRedundantInequality(symbolicSample))
continue;
// It's neither redundant nor separate, so it takes both positive and
// negative values, and hence constitutes a row for which we need to
// split the domain and separately run each case.
assert(!domainSimplex.isSeparateInequality(symbolicSample) &&
"Non-branching pivots should have been handled already!");
break;
}
if (splitRow < nRow) {
unsigned domainSnapshot = domainSimplex.getSnapshot();
IntegerRelation::CountsSnapshot domainPolyCounts =
domainPoly.getCounts();
// First, we consider the part of the domain where the row is not
// violated. We don't have to do any pivots for the row in this case,
// but we record the additional constraint that defines this part of
// the domain.
domainSimplex.addInequality(symbolicSample);
domainPoly.addInequality(symbolicSample);
// Recurse.
//
// On return, the basis as a set is preserved but not the internal
// ordering within rows or columns. Thus, we take note of the index of
// the Unknown that caused the split, which may be in a different
// row when we come back from recursing. We will need this to recurse
// on the other part of the split domain, where the row is violated.
//
// Note that we have to capture the index above and not a reference to
// the Unknown itself, since the array it lives in might get
// reallocated.
int splitIndex = rowUnknown[splitRow];
unsigned snapshot = getSnapshot();
stack.push_back(
{splitIndex, snapshot, domainSnapshot, domainPolyCounts});
++level;
continue;
}
// The tableau is rationally consistent for the current domain.
// Now we look for non-integral sample values and add cuts for them.
if (Optional<unsigned> row = maybeGetNonIntegralVarRow()) {
if (addSymbolicCut(*row).failed()) {
// No integral points; return.
--level;
continue;
}
// Rerun this level with the added cut constraint (tail recurse).
continue;
}
// Record output and return.
recordOutput(result);
--level;
continue;
}
if (level == stack.size()) {
// We have "returned" from "recursing".
const StackFrame &frame = stack.back();
domainPoly.truncate(frame.domainPolyCounts);
domainSimplex.rollback(frame.domainSnapshot);
rollback(frame.snapshot);
const Unknown &u = unknownFromIndex(frame.splitIndex);
// Drop the frame. We don't need it anymore.
stack.pop_back();
// Now we consider the part of the domain where the unknown `splitIndex`
// was negative.
assert(u.orientation == Orientation::Row &&
"The split row should have been returned to row orientation!");
SmallVector<int64_t, 8> splitIneq =
getComplementIneq(getSymbolicSampleNumerator(u.pos));
if (moveRowUnknownToColumn(u.pos).failed()) {
// The unknown can't be made non-negative; return.
--level;
continue;
}
// The unknown can be made negative; recurse with the corresponding domain
// constraints.
domainSimplex.addInequality(splitIneq);
domainPoly.addInequality(splitIneq);
// We are now taking care of the second half of the domain and we don't
// need to do anything else here after returning, so it's a tail recurse.
continue;
}
}
return result;
}
bool LexSimplex::rowIsViolated(unsigned row) const { bool LexSimplex::rowIsViolated(unsigned row) const {
if (tableau(row, 2) < 0) if (tableau(row, 2) < 0)
return true; return true;
@ -243,19 +653,20 @@ Optional<unsigned> LexSimplex::maybeGetViolatedRow() const {
return {}; return {};
} }
// We simply look for violated rows and keep trying to move them to column /// We simply look for violated rows and keep trying to move them to column
// orientation, which always succeeds unless the constraints have no solution /// orientation, which always succeeds unless the constraints have no solution
// in which case we just give up and return. /// in which case we just give up and return.
void LexSimplex::restoreRationalConsistency() { LogicalResult LexSimplex::restoreRationalConsistency() {
while (Optional<unsigned> maybeViolatedRow = maybeGetViolatedRow()) { if (empty)
LogicalResult status = moveRowUnknownToColumn(*maybeViolatedRow); return failure();
if (failed(status)) while (Optional<unsigned> maybeViolatedRow = maybeGetViolatedRow())
return; if (moveRowUnknownToColumn(*maybeViolatedRow).failed())
} return failure();
return success();
} }
// Move the row unknown to column orientation while preserving lexicopositivity // Move the row unknown to column orientation while preserving lexicopositivity
// of the basis transform. // of the basis transform. The sample value of the row must be negative.
// //
// We only consider pivots where the pivot element is positive. Suppose no such // We only consider pivots where the pivot element is positive. Suppose no such
// pivot exists, i.e., some violated row has no positive coefficient for any // pivot exists, i.e., some violated row has no positive coefficient for any
@ -318,7 +729,7 @@ void LexSimplex::restoreRationalConsistency() {
// minimizes the change in sample value. // minimizes the change in sample value.
LogicalResult LexSimplexBase::moveRowUnknownToColumn(unsigned row) { LogicalResult LexSimplexBase::moveRowUnknownToColumn(unsigned row) {
Optional<unsigned> maybeColumn; Optional<unsigned> maybeColumn;
for (unsigned col = 3; col < nCol; ++col) { for (unsigned col = 3 + nSymbol; col < nCol; ++col) {
if (tableau(row, col) <= 0) if (tableau(row, col) <= 0)
continue; continue;
maybeColumn = maybeColumn =
@ -336,6 +747,7 @@ LogicalResult LexSimplexBase::moveRowUnknownToColumn(unsigned row) {
unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA, unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA,
unsigned colB) const { unsigned colB) const {
// First, let's consider the non-symbolic case.
// A pivot causes the following change. (in the diagram the matrix elements // A pivot causes the following change. (in the diagram the matrix elements
// are shown as rationals and there is no common denominator used) // are shown as rationals and there is no common denominator used)
// //
@ -359,7 +771,7 @@ unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA,
// (-p/a)M + (-b/a), i.e. 0 to -(pM + b)/a. Thus the change in the sample // (-p/a)M + (-b/a), i.e. 0 to -(pM + b)/a. Thus the change in the sample
// value is -s/a. // value is -s/a.
// //
// If the variable is the pivot row, it sampel value goes from s to 0, for a // If the variable is the pivot row, its sample value goes from s to 0, for a
// change of -s. // change of -s.
// //
// If the variable is a non-pivot row, its sample value changes from // If the variable is a non-pivot row, its sample value changes from
@ -373,8 +785,12 @@ unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA,
// comparisons involved and can be ignored, since -s is strictly positive. // comparisons involved and can be ignored, since -s is strictly positive.
// //
// Thus we take away this common factor and just return 0, 1/a, 1, or c/a as // Thus we take away this common factor and just return 0, 1/a, 1, or c/a as
// appropriate. This allows us to run the entire algorithm without ever having // appropriate. This allows us to run the entire algorithm treating M
// to fix a value of M. // symbolically, as the pivot to be performed does not depend on the value
// of M, so long as the sample value s is negative. Note that this is not
// because of any special feature of M; by the same argument, we ignore the
// symbols too. The caller ensure that the sample value s is negative for
// all possible values of the symbols.
auto getSampleChangeCoeffForVar = [this, row](unsigned col, auto getSampleChangeCoeffForVar = [this, row](unsigned col,
const Unknown &u) -> Fraction { const Unknown &u) -> Fraction {
int64_t a = tableau(row, col); int64_t a = tableau(row, col);
@ -489,6 +905,7 @@ void SimplexBase::pivot(Pivot pair) { pivot(pair.row, pair.column); }
/// element. /// element.
void SimplexBase::pivot(unsigned pivotRow, unsigned pivotCol) { void SimplexBase::pivot(unsigned pivotRow, unsigned pivotCol) {
assert(pivotCol >= getNumFixedCols() && "Refusing to pivot invalid column"); assert(pivotCol >= getNumFixedCols() && "Refusing to pivot invalid column");
assert(!unknownFromColumn(pivotCol).isSymbol);
swapRowWithCol(pivotRow, pivotCol); swapRowWithCol(pivotRow, pivotCol);
std::swap(tableau(pivotRow, 0), tableau(pivotRow, pivotCol)); std::swap(tableau(pivotRow, 0), tableau(pivotRow, pivotCol));
@ -778,6 +1195,9 @@ void SimplexBase::undo(UndoLogEntry entry) {
assert(var.back().orientation == Orientation::Column && assert(var.back().orientation == Orientation::Column &&
"Variable to be removed must be in column orientation!"); "Variable to be removed must be in column orientation!");
if (var.back().isSymbol)
nSymbol--;
// Move this variable to the last column and remove the column from the // Move this variable to the last column and remove the column from the
// tableau. // tableau.
swapColumns(var.back().pos, nCol - 1); swapColumns(var.back().pos, nCol - 1);

View File

@ -8,6 +8,7 @@
#include "./Utils.h" #include "./Utils.h"
#include "mlir/Analysis/Presburger/IntegerRelation.h" #include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/Simplex.h" #include "mlir/Analysis/Presburger/Simplex.h"
#include <gmock/gmock.h> #include <gmock/gmock.h>
@ -1134,6 +1135,229 @@ TEST(IntegerPolyhedronTest, findIntegerLexMin) {
">= 0, -11*z + 5*y - 3*x + 7 >= 0)")); ">= 0, -11*z + 5*y - 3*x + 7 >= 0)"));
} }
void expectSymbolicIntegerLexMin(
StringRef polyStr,
ArrayRef<std::pair<StringRef, SmallVector<SmallVector<int64_t, 8>, 8>>>
expectedLexminRepr,
ArrayRef<StringRef> expectedUnboundedDomainRepr) {
IntegerPolyhedron poly = parsePoly(polyStr);
ASSERT_NE(poly.getNumDimIds(), 0u);
ASSERT_NE(poly.getNumSymbolIds(), 0u);
PWMAFunction expectedLexmin =
parsePWMAF(/*numInputs=*/poly.getNumSymbolIds(),
/*numOutputs=*/poly.getNumDimIds(), expectedLexminRepr);
PresburgerSet expectedUnboundedDomain = parsePresburgerSetFromPolyStrings(
poly.getNumSymbolIds(), expectedUnboundedDomainRepr);
SymbolicLexMin result = poly.findSymbolicIntegerLexMin();
EXPECT_TRUE(result.lexmin.isEqual(expectedLexmin));
if (!result.lexmin.isEqual(expectedLexmin)) {
llvm::errs() << "got:\n";
result.lexmin.dump();
llvm::errs() << "expected:\n";
expectedLexmin.dump();
}
EXPECT_TRUE(result.unboundedDomain.isEqual(expectedUnboundedDomain));
if (!result.unboundedDomain.isEqual(expectedUnboundedDomain))
result.unboundedDomain.dump();
}
void expectSymbolicIntegerLexMin(
StringRef polyStr,
ArrayRef<std::pair<StringRef, SmallVector<SmallVector<int64_t, 8>, 8>>>
result) {
expectSymbolicIntegerLexMin(polyStr, result, {});
}
TEST(IntegerPolyhedronTest, findSymbolicIntegerLexMin) {
expectSymbolicIntegerLexMin("(x)[a] : (x - a >= 0)",
{
{"(a) : ()", {{1, 0}}}, // a
});
expectSymbolicIntegerLexMin(
"(x)[a, b] : (x - a >= 0, x - b >= 0)",
{
{"(a, b) : (a - b >= 0)", {{1, 0, 0}}}, // a
{"(a, b) : (b - a - 1 >= 0)", {{0, 1, 0}}}, // b
});
expectSymbolicIntegerLexMin(
"(x)[a, b, c] : (x -a >= 0, x - b >= 0, x - c >= 0)",
{
{"(a, b, c) : (a - b >= 0, a - c >= 0)", {{1, 0, 0, 0}}}, // a
{"(a, b, c) : (b - a - 1 >= 0, b - c >= 0)", {{0, 1, 0, 0}}}, // b
{"(a, b, c) : (c - a - 1 >= 0, c - b - 1 >= 0)", {{0, 0, 1, 0}}}, // c
});
expectSymbolicIntegerLexMin("(x, y)[a] : (x - a >= 0, x + y >= 0)",
{
{"(a) : ()", {{1, 0}, {-1, 0}}}, // (a, -a)
});
expectSymbolicIntegerLexMin(
"(x, y)[a] : (x - a >= 0, x + y >= 0, y >= 0)",
{
{"(a) : (a >= 0)", {{1, 0}, {0, 0}}}, // (a, 0)
{"(a) : (-a - 1 >= 0)", {{1, 0}, {-1, 0}}}, // (a, -a)
});
expectSymbolicIntegerLexMin(
"(x, y)[a, b, c] : (x - a >= 0, y - b >= 0, c - x - y >= 0)",
{
{"(a, b, c) : (c - a - b >= 0)",
{{1, 0, 0, 0}, {0, 1, 0, 0}}}, // (a, b)
});
expectSymbolicIntegerLexMin(
"(x, y, z)[a, b, c] : (c - z >= 0, b - y >= 0, x + y + z - a == 0)",
{
{"(a, b, c) : ()",
{{1, -1, -1, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}}}, // (a - b - c, b, c)
});
expectSymbolicIntegerLexMin(
"(x)[a, b] : (a >= 0, b >= 0, x >= 0, a + b + x - 1 >= 0)",
{
{"(a, b) : (a >= 0, b >= 0, a + b - 1 >= 0)", {{0, 0, 0}}}, // 0
{"(a, b) : (a == 0, b == 0)", {{0, 0, 1}}}, // 1
});
expectSymbolicIntegerLexMin(
"(x)[a, b] : (1 - a >= 0, a >= 0, 1 - b >= 0, b >= 0, 1 - x >= 0, x >= "
"0, a + b + x - 1 >= 0)",
{
{"(a, b) : (1 - a >= 0, a >= 0, 1 - b >= 0, b >= 0, a + b - 1 >= 0)",
{{0, 0, 0}}}, // 0
{"(a, b) : (a == 0, b == 0)", {{0, 0, 1}}}, // 1
});
expectSymbolicIntegerLexMin(
"(x, y, z)[a, b] : (x - a == 0, y - b == 0, x >= 0, y >= 0, z >= 0, x + "
"y + z - 1 >= 0)",
{
{"(a, b) : (a >= 0, b >= 0, 1 - a - b >= 0)",
{{1, 0, 0}, {0, 1, 0}, {-1, -1, 1}}}, // (a, b, 1 - a - b)
{"(a, b) : (a >= 0, b >= 0, a + b - 2 >= 0)",
{{1, 0, 0}, {0, 1, 0}, {0, 0, 0}}}, // (a, b, 0)
});
expectSymbolicIntegerLexMin("(x)[a, b] : (x - a == 0, x - b >= 0)",
{
{"(a, b) : (a - b >= 0)", {{1, 0, 0}}}, // a
});
expectSymbolicIntegerLexMin(
"(q)[a] : (a - 1 - 3*q == 0, q >= 0)",
{
{"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
{{0, 1, 0}}}, // a floordiv 3
});
expectSymbolicIntegerLexMin(
"(r, q)[a] : (a - r - 3*q == 0, q >= 0, 1 - r >= 0, r >= 0)",
{
{"(a) : (a - 0 - 3*(a floordiv 3) == 0, a >= 0)",
{{0, 0, 0}, {0, 1, 0}}}, // (0, a floordiv 3)
{"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
{{0, 0, 1}, {0, 1, 0}}}, // (1 a floordiv 3)
});
expectSymbolicIntegerLexMin(
"(r, q)[a] : (a - r - 3*q == 0, q >= 0, 2 - r >= 0, r - 1 >= 0)",
{
{"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
{{0, 0, 1}, {0, 1, 0}}}, // (1, a floordiv 3)
{"(a) : (a - 2 - 3*(a floordiv 3) == 0, a >= 0)",
{{0, 0, 2}, {0, 1, 0}}}, // (2, a floordiv 3)
});
expectSymbolicIntegerLexMin(
"(r, q)[a] : (a - r - 3*q == 0, q >= 0, r >= 0)",
{
{"(a) : (a - 3*(a floordiv 3) == 0, a >= 0)",
{{0, 0, 0}, {0, 1, 0}}}, // (0, a floordiv 3)
{"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
{{0, 0, 1}, {0, 1, 0}}}, // (1, a floordiv 3)
{"(a) : (a - 2 - 3*(a floordiv 3) == 0, a >= 0)",
{{0, 0, 2}, {0, 1, 0}}}, // (2, a floordiv 3)
});
expectSymbolicIntegerLexMin(
"(x, y, z, w)[g] : ("
// x, y, z, w are boolean variables.
"1 - x >= 0, x >= 0, 1 - y >= 0, y >= 0,"
"1 - z >= 0, z >= 0, 1 - w >= 0, w >= 0,"
// We have some constraints on them:
"x + y + z - 1 >= 0," // x or y or z
"x + y + w - 1 >= 0," // x or y or w
"1 - x + 1 - y + 1 - w - 1 >= 0," // ~x or ~y or ~w
// What's the lexmin solution using exactly g true vars?
"g - x - y - z - w == 0)",
{
{"(g) : (g - 1 == 0)",
{{0, 0}, {0, 1}, {0, 0}, {0, 0}}}, // (0, 1, 0, 0)
{"(g) : (g - 2 == 0)",
{{0, 0}, {0, 0}, {0, 1}, {0, 1}}}, // (0, 0, 1, 1)
{"(g) : (g - 3 == 0)",
{{0, 0}, {0, 1}, {0, 1}, {0, 1}}}, // (0, 1, 1, 1)
});
// Bezout's lemma: if a, b are constants,
// the set of values that ax + by can take is all multiples of gcd(a, b).
expectSymbolicIntegerLexMin(
// If (x, y) is a solution for a given [a, r], then so is (x - 5, y + 2).
// So the lexmin is unbounded if it exists.
"(x, y)[a, r] : (a >= 0, r - a + 14*x + 35*y == 0)", {},
// According to Bezout's lemma, 14x + 35y can take on all multiples
// of 7 and no other values. So the solution exists iff r - a is a
// multiple of 7.
{"(a, r) : (a >= 0, r - a - 7*((r - a) floordiv 7) == 0)"});
// The lexmins are unbounded.
expectSymbolicIntegerLexMin("(x, y)[a] : (9*x - 4*y - 2*a >= 0)", {},
{"(a) : ()"});
// Test cases adapted from isl.
expectSymbolicIntegerLexMin(
// a = 2b - 2(c - b), c - b >= 0.
// So b is minimized when c = b.
"(b, c)[a] : (a - 4*b + 2*c == 0, c - b >= 0)",
{
{"(a) : (a - 2*(a floordiv 2) == 0)",
{{0, 1, 0}, {0, 1, 0}}}, // (a floordiv 2, a floordiv 2)
});
expectSymbolicIntegerLexMin(
// 0 <= b <= 255, 1 <= a - 512b <= 509,
// b + 8 >= 1 + 16*(b + 8 floordiv 16) // i.e. b % 16 != 8
"(b)[a] : (255 - b >= 0, b >= 0, a - 512*b - 1 >= 0, 512*b -a + 509 >= "
"0, b + 7 - 16*((8 + b) floordiv 16) >= 0)",
{
{"(a) : (255 - (a floordiv 512) >= 0, a >= 0, a - 512*(a floordiv "
"512) - 1 >= 0, 512*(a floordiv 512) - a + 509 >= 0, (a floordiv "
"512) + 7 - 16*((8 + (a floordiv 512)) floordiv 16) >= 0)",
{{0, 1, 0, 0}}}, // (a floordiv 2, a floordiv 2)
});
expectSymbolicIntegerLexMin(
"(a, b)[K, N, x, y] : (N - K - 2 >= 0, K + 4 - N >= 0, x - 4 >= 0, x + 6 "
"- 2*N >= 0, K+N - x - 1 >= 0, a - N + 1 >= 0, K+N-1-a >= 0,a + 6 - b - "
"N >= 0, 2*N - 4 - a >= 0,"
"2*N - 3*K + a - b >= 0, 4*N - K + 1 - 3*b >= 0, b - N >= 0, a - x - 1 "
">= 0)",
{{
"(K, N, x, y) : (x + 6 - 2*N >= 0, 2*N - 5 - x >= 0, x + 1 -3*K + N "
">= 0, N + K - 2 - x >= 0, x - 4 >= 0)",
{{0, 0, 1, 0, 1}, {0, 1, 0, 0, 0}} // (1 + x, N)
}});
}
static void static void
expectComputedVolumeIsValidOverapprox(const IntegerPolyhedron &poly, expectComputedVolumeIsValidOverapprox(const IntegerPolyhedron &poly,
Optional<uint64_t> trueVolume, Optional<uint64_t> trueVolume,