[MLIR][Presburger] Support finding integer lexmin in IntegerPolyhedron

Note: this does not yet support PrebsurgerSets.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D120239
This commit is contained in:
Arjun P 2022-02-21 19:04:12 +00:00
parent df0c16ce00
commit 9f8cb68570
6 changed files with 176 additions and 14 deletions

View File

@ -25,7 +25,7 @@ namespace mlir {
/// representable by 64-bit integers.
struct Fraction {
/// Default constructor initializes the represented rational number to zero.
Fraction() {}
Fraction() = default;
/// Construct a Fraction from a numerator and denominator.
Fraction(int64_t oNum, int64_t oDen) : num(oNum), den(oDen) {
@ -35,6 +35,13 @@ struct Fraction {
}
}
// Return the value of the fraction as an integer. This should only be called
// when the fraction's value is really an integer.
int64_t getAsInteger() const {
assert(num % den == 0 && "Get as integer called on non-integral fraction!");
return num / den;
}
/// The numerator and denominator, respectively. The denominator is always
/// positive.
int64_t num{0}, den{1};

View File

@ -212,6 +212,13 @@ public:
presburger_utils::MaybeOptimum<SmallVector<Fraction, 8>>
getRationalLexMin() const;
/// Same as above, but returns lexicographically minimal integer point.
/// Note: this should be used only when the lexmin is really required.
/// For a generic integer sampling operation, findIntegerSample is more
/// robust and should be preferred.
presburger_utils::MaybeOptimum<SmallVector<int64_t, 8>>
getIntegerLexMin() const;
/// Swap the posA^th identifier with the posB^th identifier.
virtual void swapId(unsigned posA, unsigned posB);

View File

@ -265,6 +265,10 @@ protected:
/// Returns the unknown associated with row.
Unknown &unknownFromRow(unsigned row);
/// Add a new row to the tableau and the associated data structures. The row
/// is initialized to zero.
unsigned addZeroRow(bool makeRestricted = false);
/// Add a new row to the tableau and the associated data structures.
/// The new row is considered to be a constraint; the new Unknown lives in
/// con.
@ -436,6 +440,12 @@ public:
/// Return the lexicographically minimum rational solution to the constraints.
presburger_utils::MaybeOptimum<SmallVector<Fraction, 8>> getRationalLexMin();
/// Return the lexicographically minimum integer solution to the constraints.
///
/// Note: this should be used only when the lexmin is really needed. To obtain
/// any integer sample, use Simplex::findIntegerSample as that is more robust.
presburger_utils::MaybeOptimum<SmallVector<int64_t, 8>> getIntegerLexMin();
protected:
/// Returns the current sample point, which may contain non-integer (rational)
/// coordinates. Returns an empty optimum when the tableau is empty.
@ -446,6 +456,15 @@ protected:
presburger_utils::MaybeOptimum<SmallVector<Fraction, 8>>
getRationalSample() const;
/// Given a row that has a non-integer sample value, add an inequality such
/// that this fractional sample value is cut away from the polytope. The added
/// inequality will be such that no integer points are removed.
///
/// Returns whether the cut constraint could be enforced, i.e. failure if the
/// cut made the polytope empty, and success if it didn't. Failure status
/// indicates that the polytope didn't have any integer points.
LogicalResult addCut(unsigned row);
/// Undo the addition of the last constraint. This is only called while
/// rolling back.
void undoLastConstraint() final;
@ -460,6 +479,10 @@ protected:
/// Otherwise, return an empty optional.
Optional<unsigned> maybeGetViolatedRow() const;
/// Get a row corresponding to a var that has a non-integral sample value, if
/// one exists. Otherwise, return an empty optional.
Optional<unsigned> maybeGetNonIntegeralVarRow() const;
/// Given two potential pivot columns for a row, return the one that results
/// in the lexicographically smallest sample vector.
unsigned getLexMinPivotColumn(unsigned row, unsigned colA,

View File

@ -92,6 +92,26 @@ IntegerPolyhedron::getRationalLexMin() const {
return maybeLexMin;
}
MaybeOptimum<SmallVector<int64_t, 8>>
IntegerPolyhedron::getIntegerLexMin() const {
assert(getNumSymbolIds() == 0 && "Symbols are not supported!");
MaybeOptimum<SmallVector<int64_t, 8>> maybeLexMin =
LexSimplex(*this).getIntegerLexMin();
if (!maybeLexMin.isBounded())
return maybeLexMin.getKind();
// The Simplex returns the lexmin over all the variables including locals. But
// locals are not actually part of the space and should not be returned in the
// result. Since the locals are placed last in the list of identifiers, they
// will be minimized last in the lexmin. So simply truncating out the locals
// from the end of the answer gives the desired lexmin over the dimensions.
assert(maybeLexMin->size() == getNumIds() &&
"Incorrect number of vars in lexMin!");
maybeLexMin->resize(getNumDimAndSymbolIds());
return maybeLexMin;
}
unsigned IntegerPolyhedron::insertDimId(unsigned pos, unsigned num) {
return insertId(IdKind::SetDim, pos, num);
}

View File

@ -59,13 +59,7 @@ Simplex::Unknown &SimplexBase::unknownFromRow(unsigned row) {
return unknownFromIndex(rowUnknown[row]);
}
/// Add a new row to the tableau corresponding to the given constant term and
/// list of coefficients. The coefficients are specified as a vector of
/// (variable index, coefficient) pairs.
unsigned SimplexBase::addRow(ArrayRef<int64_t> coeffs, bool makeRestricted) {
assert(coeffs.size() == var.size() + 1 &&
"Incorrect number of coefficients!");
unsigned SimplexBase::addZeroRow(bool makeRestricted) {
++nRow;
// If the tableau is not big enough to accomodate the extra row, we extend it.
if (nRow >= tableau.getNumRows())
@ -77,6 +71,17 @@ unsigned SimplexBase::addRow(ArrayRef<int64_t> coeffs, bool makeRestricted) {
tableau.fillRow(nRow - 1, 0);
tableau(nRow - 1, 0) = 1;
return con.size() - 1;
}
/// Add a new row to the tableau corresponding to the given constant term and
/// list of coefficients. The coefficients are specified as a vector of
/// (variable index, coefficient) pairs.
unsigned SimplexBase::addRow(ArrayRef<int64_t> coeffs, bool makeRestricted) {
assert(coeffs.size() == var.size() + 1 &&
"Incorrect number of coefficients!");
addZeroRow(makeRestricted);
tableau(nRow - 1, 1) = coeffs.back();
if (usingBigM) {
// When the lexicographic pivot rule is used, instead of the variables
@ -164,6 +169,56 @@ MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::getRationalLexMin() {
return getRationalSample();
}
LogicalResult LexSimplex::addCut(unsigned row) {
int64_t denom = tableau(row, 0);
addZeroRow(/*makeRestricted=*/true);
tableau(nRow - 1, 0) = denom;
tableau(nRow - 1, 1) = -mod(-tableau(row, 1), denom);
tableau(nRow - 1, 2) = 0; // M has all factors in it.
for (unsigned col = 3; col < nCol; ++col)
tableau(nRow - 1, col) = mod(tableau(row, col), denom);
return moveRowUnknownToColumn(nRow - 1);
}
Optional<unsigned> LexSimplex::maybeGetNonIntegeralVarRow() const {
for (const Unknown &u : var) {
if (u.orientation == Orientation::Column)
continue;
// If the sample value is of the form (a/d)M + b/d, we need b to be
// divisible by d. We assume M is very large and contains all possible
// factors and is divisible by everything.
unsigned row = u.pos;
if (tableau(row, 1) % tableau(row, 0) != 0)
return row;
}
return {};
}
MaybeOptimum<SmallVector<int64_t, 8>> LexSimplex::getIntegerLexMin() {
while (!empty) {
restoreRationalConsistency();
if (empty)
return OptimumKind::Empty;
if (Optional<unsigned> maybeRow = maybeGetNonIntegeralVarRow()) {
// Failure occurs when the polytope is integer empty.
if (failed(addCut(*maybeRow)))
return OptimumKind::Empty;
continue;
}
MaybeOptimum<SmallVector<Fraction, 8>> sample = getRationalSample();
assert(!sample.isEmpty() && "If we reached here the sample should exist!");
if (sample.isUnbounded())
return OptimumKind::Unbounded;
return llvm::to_vector<8>(llvm::map_range(
*sample, [](const Fraction &f) { return f.getAsInteger(); }));
}
// Polytope is integer empty.
return OptimumKind::Empty;
}
bool LexSimplex::rowIsViolated(unsigned row) const {
if (tableau(row, 2) < 0)
return true;

View File

@ -8,6 +8,7 @@
#include "mlir/Analysis/Presburger/IntegerPolyhedron.h"
#include "./Utils.h"
#include "mlir/Analysis/Presburger/Simplex.h"
#include "mlir/IR/MLIRContext.h"
#include <gmock/gmock.h>
@ -36,29 +37,53 @@ makeSetFromConstraints(unsigned ids, ArrayRef<SmallVector<int64_t, 4>> ineqs,
return set;
}
static void dump(ArrayRef<int64_t> vec) {
for (int64_t x : vec)
llvm::errs() << x << ' ';
llvm::errs() << '\n';
}
/// If fn is TestFunction::Sample (default):
/// If hasSample is true, check that findIntegerSample returns a valid sample
/// for the IntegerPolyhedron poly.
/// If hasSample is false, check that findIntegerSample returns None.
///
/// If hasSample is true, check that findIntegerSample returns a valid sample
/// for the IntegerPolyhedron poly. Also check that getIntegerLexmin finds a
/// non-empty lexmin.
///
/// If hasSample is false, check that findIntegerSample returns None and
/// getIntegerLexMin returns Empty.
///
/// If fn is TestFunction::Empty, check that isIntegerEmpty returns the
/// opposite of hasSample.
static void checkSample(bool hasSample, const IntegerPolyhedron &poly,
TestFunction fn = TestFunction::Sample) {
Optional<SmallVector<int64_t, 8>> maybeSample;
MaybeOptimum<SmallVector<int64_t, 8>> maybeLexMin;
switch (fn) {
case TestFunction::Sample:
maybeSample = poly.findIntegerSample();
maybeLexMin = poly.getIntegerLexMin();
if (!hasSample) {
EXPECT_FALSE(maybeSample.hasValue());
if (maybeSample.hasValue()) {
for (auto x : *maybeSample)
llvm::errs() << x << ' ';
llvm::errs() << '\n';
llvm::errs() << "findIntegerSample gave sample: ";
dump(*maybeSample);
}
EXPECT_TRUE(maybeLexMin.isEmpty());
if (maybeLexMin.isBounded()) {
llvm::errs() << "getIntegerLexMin gave sample: ";
dump(*maybeLexMin);
}
} else {
ASSERT_TRUE(maybeSample.hasValue());
EXPECT_TRUE(poly.containsPoint(*maybeSample));
ASSERT_FALSE(maybeLexMin.isEmpty());
if (maybeLexMin.isUnbounded())
EXPECT_TRUE(Simplex(poly).isUnbounded());
if (maybeLexMin.isBounded())
EXPECT_TRUE(poly.containsPoint(*maybeLexMin));
}
break;
case TestFunction::Empty:
@ -1138,6 +1163,31 @@ TEST(IntegerPolyhedronTest, getRationalLexMin) {
parsePoly("(x) : (2*x >= 0, -x - 1 >= 0)", &context));
}
void expectIntegerLexMin(const IntegerPolyhedron &poly, ArrayRef<int64_t> min) {
auto lexMin = poly.getIntegerLexMin();
ASSERT_TRUE(lexMin.isBounded());
EXPECT_EQ(ArrayRef<int64_t>(*lexMin), min);
}
void expectNoIntegerLexMin(OptimumKind kind, const IntegerPolyhedron &poly) {
ASSERT_NE(kind, OptimumKind::Bounded)
<< "Use expectRationalLexMin for bounded min";
EXPECT_EQ(poly.getRationalLexMin().getKind(), kind);
}
TEST(IntegerPolyhedronTest, getIntegerLexMin) {
MLIRContext context;
expectIntegerLexMin(parsePoly("(x, y, z) : (2*x + 13 >= 0, 4*y - 3*x - 2 >= "
"0, 11*z + 5*y - 3*x + 7 >= 0)",
&context),
{-6, -4, 0});
// Similar to above but no lower bound on z.
expectNoIntegerLexMin(OptimumKind::Unbounded,
parsePoly("(x, y, z) : (2*x + 13 >= 0, 4*y - 3*x - 2 "
">= 0, -11*z + 5*y - 3*x + 7 >= 0)",
&context));
}
static void
expectComputedVolumeIsValidOverapprox(const IntegerPolyhedron &poly,
Optional<uint64_t> trueVolume,