[MLIR][Presburger] LexSimplex::addEquality: add equalities as fixed columns

In LexSimplex, instead of adding equalities as a pair of inequalities,
add them as a single row, move them into the basis, and keep them there.

There will always be a valid basis involving all non-redundant equalities. Such
equalities will then be ignored in some other operations, such as when looking
for pivot columns. This speeds them up a little bit.

More importantly, this is an important precursor patch to adding support for
symbolic integer lexmin, as this heuristic can sometimes make a big difference there.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D122165
This commit is contained in:
Arjun P 2022-03-23 23:04:13 +00:00
parent 08543a5a47
commit 5630143af3
3 changed files with 76 additions and 18 deletions

View File

@ -166,21 +166,21 @@ public:
/// false otherwise.
bool isEmpty() const;
/// Add an inequality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
/// is the current number of variables, then the corresponding inequality is
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} >= 0.
virtual void addInequality(ArrayRef<int64_t> coeffs) = 0;
/// Returns the number of variables in the tableau.
unsigned getNumVariables() const;
/// Returns the number of constraints in the tableau.
unsigned getNumConstraints() const;
/// Add an inequality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
/// is the current number of variables, then the corresponding inequality is
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} >= 0.
virtual void addInequality(ArrayRef<int64_t> coeffs) = 0;
/// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
/// is the current number of variables, then the corresponding equality is
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
void addEquality(ArrayRef<int64_t> coeffs);
virtual void addEquality(ArrayRef<int64_t> coeffs) = 0;
/// Add new variables to the end of the list of variables.
void appendVariable(unsigned count = 1);
@ -249,6 +249,14 @@ protected:
/// coefficient for it.
Optional<unsigned> findAnyPivotRow(unsigned col);
/// Return any column that this row can be pivoted with, ignoring tableau
/// consistency. Equality rows are not considered.
///
/// Returns an empty optional if no pivot is possible, which happens only when
/// the column unknown is a variable and no constraint has a non-zero
/// coefficient for it.
Optional<unsigned> findAnyPivotCol(unsigned row);
/// Swap the row with the column in the tableau's data structures but not the
/// tableau itself. This is used by pivot.
void swapRowWithCol(unsigned row, unsigned col);
@ -295,6 +303,7 @@ protected:
RemoveLastVariable,
UnmarkEmpty,
UnmarkLastRedundant,
UnmarkLastEquality,
RestoreBasis
};
@ -308,13 +317,14 @@ protected:
/// Undo the operation represented by the log entry.
void undo(UndoLogEntry entry);
/// Return the number of fixed columns, as described in the constructor above,
/// this is the number of columns beyond those for the variables in var.
unsigned getNumFixedCols() const { return usingBigM ? 3u : 2u; }
unsigned getNumFixedCols() const { return numFixedCols; }
/// Stores whether or not a big M column is present in the tableau.
bool usingBigM;
/// denom + const + maybe M + equality columns
unsigned numFixedCols;
/// The number of rows in the tableau.
unsigned nRow;
@ -435,9 +445,12 @@ public:
///
/// This just adds the inequality to the tableau and does not try to create a
/// consistent tableau configuration.
void addInequality(ArrayRef<int64_t> coeffs) final {
addRow(coeffs, /*makeRestricted=*/true);
}
void addInequality(ArrayRef<int64_t> coeffs) final;
/// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
/// is the current number of variables, then the corresponding equality is
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
void addEquality(ArrayRef<int64_t> coeffs) final;
/// Get a snapshot of the current state. This is used for rolling back.
unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); }
@ -533,6 +546,11 @@ public:
/// state and marks the Simplex empty if this is not possible.
void addInequality(ArrayRef<int64_t> coeffs) final;
/// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
/// is the current number of variables, then the corresponding equality is
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
void addEquality(ArrayRef<int64_t> coeffs) final;
/// Compute the maximum or minimum value of the given row, depending on
/// direction. The specified row is never pivoted. On return, the row may
/// have a negative sample value if the direction is down.

View File

@ -19,12 +19,12 @@ using Direction = Simplex::Direction;
const int nullIndex = std::numeric_limits<int>::max();
SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM)
: usingBigM(mustUseBigM), nRow(0), nCol(getNumFixedCols() + nVar),
nRedundant(0), tableau(0, nCol), empty(false) {
colUnknown.insert(colUnknown.begin(), getNumFixedCols(), nullIndex);
: usingBigM(mustUseBigM), numFixedCols(mustUseBigM ? 3 : 2), nRow(0),
nCol(numFixedCols + nVar), nRedundant(0), tableau(0, nCol), empty(false) {
colUnknown.insert(colUnknown.begin(), numFixedCols, nullIndex);
for (unsigned i = 0; i < nVar; ++i) {
var.emplace_back(Orientation::Column, /*restricted=*/false,
/*pos=*/getNumFixedCols() + i);
/*pos=*/numFixedCols + i);
colUnknown.push_back(i);
}
}
@ -309,7 +309,7 @@ void LexSimplex::restoreRationalConsistency() {
// minimizes the change in sample value.
LogicalResult LexSimplex::moveRowUnknownToColumn(unsigned row) {
Optional<unsigned> maybeColumn;
for (unsigned col = 3; col < nCol; ++col) {
for (unsigned col = getNumFixedCols(); col < nCol; ++col) {
if (tableau(row, col) <= 0)
continue;
maybeColumn =
@ -648,7 +648,7 @@ void Simplex::addInequality(ArrayRef<int64_t> coeffs) {
///
/// We simply add two opposing inequalities, which force the expression to
/// be zero.
void SimplexBase::addEquality(ArrayRef<int64_t> coeffs) {
void Simplex::addEquality(ArrayRef<int64_t> coeffs) {
addInequality(coeffs);
SmallVector<int64_t, 8> negatedCoeffs;
for (int64_t coeff : coeffs)
@ -705,6 +705,15 @@ Optional<unsigned> SimplexBase::findAnyPivotRow(unsigned col) {
return {};
}
// This doesn't find a pivot column only if the row has zero coefficients for
// every column not marked as an equality.
Optional<unsigned> SimplexBase::findAnyPivotCol(unsigned row) {
for (unsigned col = getNumFixedCols(); col < nCol; ++col)
if (tableau(row, col) != 0)
return col;
return {};
}
// It's not valid to remove the constraint by deleting the column since this
// would result in an invalid basis.
void Simplex::undoLastConstraint() {
@ -780,6 +789,10 @@ void SimplexBase::undo(UndoLogEntry entry) {
empty = false;
} else if (entry == UndoLogEntry::UnmarkLastRedundant) {
nRedundant--;
} else if (entry == UndoLogEntry::UnmarkLastEquality) {
numFixedCols--;
assert(getNumFixedCols() >= 2 + usingBigM &&
"The denominator, constant, big M and symbols are always fixed!");
} else if (entry == UndoLogEntry::RestoreBasis) {
assert(!savedBases.empty() && "No bases saved!");
@ -1110,6 +1123,26 @@ Optional<SmallVector<Fraction, 8>> Simplex::getRationalSample() const {
return sample;
}
void LexSimplex::addInequality(ArrayRef<int64_t> coeffs) {
addRow(coeffs, /*makeRestricted=*/true);
}
/// Try to make the equality a fixed column by finding any pivot and performing
/// it. The only time this is not possible is when the given equality's
/// direction is already in the span of the existing fixed column equalities. In
/// that case, we just leave it in row position.
void LexSimplex::addEquality(ArrayRef<int64_t> coeffs) {
const Unknown &u = con[addRow(coeffs, /*makeRestricted=*/true)];
Optional<unsigned> pivotCol = findAnyPivotCol(u.pos);
if (!pivotCol)
return;
pivot(u.pos, *pivotCol);
swapColumns(*pivotCol, getNumFixedCols());
numFixedCols++;
undoLog.push_back(UndoLogEntry::UnmarkLastEquality);
}
MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::getRationalSample() const {
if (empty)
return OptimumKind::Empty;

View File

@ -548,3 +548,10 @@ TEST(SimplexTest, addDivisionVariable) {
ASSERT_TRUE(sample.hasValue());
EXPECT_EQ((*sample)[0] / 2, (*sample)[1]);
}
TEST(LexSimplexTest, addEquality) {
IntegerRelation rel(/*numDomain=*/0, /*numRange=*/1);
rel.addEquality({1, 0});
LexSimplex simplex(rel);
EXPECT_EQ(simplex.getNumConstraints(), 1u);
}