forked from OSchip/llvm-project
[MLIR][Presburger] LexSimplex::addEquality: add equalities as fixed columns
In LexSimplex, instead of adding equalities as a pair of inequalities, add them as a single row, move them into the basis, and keep them there. There will always be a valid basis involving all non-redundant equalities. Such equalities will then be ignored in some other operations, such as when looking for pivot columns. This speeds them up a little bit. More importantly, this is an important precursor patch to adding support for symbolic integer lexmin, as this heuristic can sometimes make a big difference there. Reviewed By: Groverkss Differential Revision: https://reviews.llvm.org/D122165
This commit is contained in:
parent
08543a5a47
commit
5630143af3
|
@ -166,21 +166,21 @@ public:
|
|||
/// false otherwise.
|
||||
bool isEmpty() const;
|
||||
|
||||
/// Add an inequality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
|
||||
/// is the current number of variables, then the corresponding inequality is
|
||||
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} >= 0.
|
||||
virtual void addInequality(ArrayRef<int64_t> coeffs) = 0;
|
||||
|
||||
/// Returns the number of variables in the tableau.
|
||||
unsigned getNumVariables() const;
|
||||
|
||||
/// Returns the number of constraints in the tableau.
|
||||
unsigned getNumConstraints() const;
|
||||
|
||||
/// Add an inequality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
|
||||
/// is the current number of variables, then the corresponding inequality is
|
||||
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} >= 0.
|
||||
virtual void addInequality(ArrayRef<int64_t> coeffs) = 0;
|
||||
|
||||
/// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
|
||||
/// is the current number of variables, then the corresponding equality is
|
||||
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
|
||||
void addEquality(ArrayRef<int64_t> coeffs);
|
||||
virtual void addEquality(ArrayRef<int64_t> coeffs) = 0;
|
||||
|
||||
/// Add new variables to the end of the list of variables.
|
||||
void appendVariable(unsigned count = 1);
|
||||
|
@ -249,6 +249,14 @@ protected:
|
|||
/// coefficient for it.
|
||||
Optional<unsigned> findAnyPivotRow(unsigned col);
|
||||
|
||||
/// Return any column that this row can be pivoted with, ignoring tableau
|
||||
/// consistency. Equality rows are not considered.
|
||||
///
|
||||
/// Returns an empty optional if no pivot is possible, which happens only when
|
||||
/// the column unknown is a variable and no constraint has a non-zero
|
||||
/// coefficient for it.
|
||||
Optional<unsigned> findAnyPivotCol(unsigned row);
|
||||
|
||||
/// Swap the row with the column in the tableau's data structures but not the
|
||||
/// tableau itself. This is used by pivot.
|
||||
void swapRowWithCol(unsigned row, unsigned col);
|
||||
|
@ -295,6 +303,7 @@ protected:
|
|||
RemoveLastVariable,
|
||||
UnmarkEmpty,
|
||||
UnmarkLastRedundant,
|
||||
UnmarkLastEquality,
|
||||
RestoreBasis
|
||||
};
|
||||
|
||||
|
@ -308,13 +317,14 @@ protected:
|
|||
/// Undo the operation represented by the log entry.
|
||||
void undo(UndoLogEntry entry);
|
||||
|
||||
/// Return the number of fixed columns, as described in the constructor above,
|
||||
/// this is the number of columns beyond those for the variables in var.
|
||||
unsigned getNumFixedCols() const { return usingBigM ? 3u : 2u; }
|
||||
unsigned getNumFixedCols() const { return numFixedCols; }
|
||||
|
||||
/// Stores whether or not a big M column is present in the tableau.
|
||||
bool usingBigM;
|
||||
|
||||
/// denom + const + maybe M + equality columns
|
||||
unsigned numFixedCols;
|
||||
|
||||
/// The number of rows in the tableau.
|
||||
unsigned nRow;
|
||||
|
||||
|
@ -435,9 +445,12 @@ public:
|
|||
///
|
||||
/// This just adds the inequality to the tableau and does not try to create a
|
||||
/// consistent tableau configuration.
|
||||
void addInequality(ArrayRef<int64_t> coeffs) final {
|
||||
addRow(coeffs, /*makeRestricted=*/true);
|
||||
}
|
||||
void addInequality(ArrayRef<int64_t> coeffs) final;
|
||||
|
||||
/// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
|
||||
/// is the current number of variables, then the corresponding equality is
|
||||
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
|
||||
void addEquality(ArrayRef<int64_t> coeffs) final;
|
||||
|
||||
/// Get a snapshot of the current state. This is used for rolling back.
|
||||
unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); }
|
||||
|
@ -533,6 +546,11 @@ public:
|
|||
/// state and marks the Simplex empty if this is not possible.
|
||||
void addInequality(ArrayRef<int64_t> coeffs) final;
|
||||
|
||||
/// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
|
||||
/// is the current number of variables, then the corresponding equality is
|
||||
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
|
||||
void addEquality(ArrayRef<int64_t> coeffs) final;
|
||||
|
||||
/// Compute the maximum or minimum value of the given row, depending on
|
||||
/// direction. The specified row is never pivoted. On return, the row may
|
||||
/// have a negative sample value if the direction is down.
|
||||
|
|
|
@ -19,12 +19,12 @@ using Direction = Simplex::Direction;
|
|||
const int nullIndex = std::numeric_limits<int>::max();
|
||||
|
||||
SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM)
|
||||
: usingBigM(mustUseBigM), nRow(0), nCol(getNumFixedCols() + nVar),
|
||||
nRedundant(0), tableau(0, nCol), empty(false) {
|
||||
colUnknown.insert(colUnknown.begin(), getNumFixedCols(), nullIndex);
|
||||
: usingBigM(mustUseBigM), numFixedCols(mustUseBigM ? 3 : 2), nRow(0),
|
||||
nCol(numFixedCols + nVar), nRedundant(0), tableau(0, nCol), empty(false) {
|
||||
colUnknown.insert(colUnknown.begin(), numFixedCols, nullIndex);
|
||||
for (unsigned i = 0; i < nVar; ++i) {
|
||||
var.emplace_back(Orientation::Column, /*restricted=*/false,
|
||||
/*pos=*/getNumFixedCols() + i);
|
||||
/*pos=*/numFixedCols + i);
|
||||
colUnknown.push_back(i);
|
||||
}
|
||||
}
|
||||
|
@ -309,7 +309,7 @@ void LexSimplex::restoreRationalConsistency() {
|
|||
// minimizes the change in sample value.
|
||||
LogicalResult LexSimplex::moveRowUnknownToColumn(unsigned row) {
|
||||
Optional<unsigned> maybeColumn;
|
||||
for (unsigned col = 3; col < nCol; ++col) {
|
||||
for (unsigned col = getNumFixedCols(); col < nCol; ++col) {
|
||||
if (tableau(row, col) <= 0)
|
||||
continue;
|
||||
maybeColumn =
|
||||
|
@ -648,7 +648,7 @@ void Simplex::addInequality(ArrayRef<int64_t> coeffs) {
|
|||
///
|
||||
/// We simply add two opposing inequalities, which force the expression to
|
||||
/// be zero.
|
||||
void SimplexBase::addEquality(ArrayRef<int64_t> coeffs) {
|
||||
void Simplex::addEquality(ArrayRef<int64_t> coeffs) {
|
||||
addInequality(coeffs);
|
||||
SmallVector<int64_t, 8> negatedCoeffs;
|
||||
for (int64_t coeff : coeffs)
|
||||
|
@ -705,6 +705,15 @@ Optional<unsigned> SimplexBase::findAnyPivotRow(unsigned col) {
|
|||
return {};
|
||||
}
|
||||
|
||||
// This doesn't find a pivot column only if the row has zero coefficients for
|
||||
// every column not marked as an equality.
|
||||
Optional<unsigned> SimplexBase::findAnyPivotCol(unsigned row) {
|
||||
for (unsigned col = getNumFixedCols(); col < nCol; ++col)
|
||||
if (tableau(row, col) != 0)
|
||||
return col;
|
||||
return {};
|
||||
}
|
||||
|
||||
// It's not valid to remove the constraint by deleting the column since this
|
||||
// would result in an invalid basis.
|
||||
void Simplex::undoLastConstraint() {
|
||||
|
@ -780,6 +789,10 @@ void SimplexBase::undo(UndoLogEntry entry) {
|
|||
empty = false;
|
||||
} else if (entry == UndoLogEntry::UnmarkLastRedundant) {
|
||||
nRedundant--;
|
||||
} else if (entry == UndoLogEntry::UnmarkLastEquality) {
|
||||
numFixedCols--;
|
||||
assert(getNumFixedCols() >= 2 + usingBigM &&
|
||||
"The denominator, constant, big M and symbols are always fixed!");
|
||||
} else if (entry == UndoLogEntry::RestoreBasis) {
|
||||
assert(!savedBases.empty() && "No bases saved!");
|
||||
|
||||
|
@ -1110,6 +1123,26 @@ Optional<SmallVector<Fraction, 8>> Simplex::getRationalSample() const {
|
|||
return sample;
|
||||
}
|
||||
|
||||
void LexSimplex::addInequality(ArrayRef<int64_t> coeffs) {
|
||||
addRow(coeffs, /*makeRestricted=*/true);
|
||||
}
|
||||
|
||||
/// Try to make the equality a fixed column by finding any pivot and performing
|
||||
/// it. The only time this is not possible is when the given equality's
|
||||
/// direction is already in the span of the existing fixed column equalities. In
|
||||
/// that case, we just leave it in row position.
|
||||
void LexSimplex::addEquality(ArrayRef<int64_t> coeffs) {
|
||||
const Unknown &u = con[addRow(coeffs, /*makeRestricted=*/true)];
|
||||
Optional<unsigned> pivotCol = findAnyPivotCol(u.pos);
|
||||
if (!pivotCol)
|
||||
return;
|
||||
|
||||
pivot(u.pos, *pivotCol);
|
||||
swapColumns(*pivotCol, getNumFixedCols());
|
||||
numFixedCols++;
|
||||
undoLog.push_back(UndoLogEntry::UnmarkLastEquality);
|
||||
}
|
||||
|
||||
MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::getRationalSample() const {
|
||||
if (empty)
|
||||
return OptimumKind::Empty;
|
||||
|
|
|
@ -548,3 +548,10 @@ TEST(SimplexTest, addDivisionVariable) {
|
|||
ASSERT_TRUE(sample.hasValue());
|
||||
EXPECT_EQ((*sample)[0] / 2, (*sample)[1]);
|
||||
}
|
||||
|
||||
TEST(LexSimplexTest, addEquality) {
|
||||
IntegerRelation rel(/*numDomain=*/0, /*numRange=*/1);
|
||||
rel.addEquality({1, 0});
|
||||
LexSimplex simplex(rel);
|
||||
EXPECT_EQ(simplex.getNumConstraints(), 1u);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue