forked from OSchip/llvm-project
[MLIR][Presburger] LexSimplex::addEquality: add equalities as fixed columns
In LexSimplex, instead of adding equalities as a pair of inequalities, add them as a single row, move them into the basis, and keep them there. There will always be a valid basis involving all non-redundant equalities. Such equalities will then be ignored in some other operations, such as when looking for pivot columns. This speeds them up a little bit. More importantly, this is an important precursor patch to adding support for symbolic integer lexmin, as this heuristic can sometimes make a big difference there. Reviewed By: Groverkss Differential Revision: https://reviews.llvm.org/D122165
This commit is contained in:
parent
08543a5a47
commit
5630143af3
|
@ -166,21 +166,21 @@ public:
|
||||||
/// false otherwise.
|
/// false otherwise.
|
||||||
bool isEmpty() const;
|
bool isEmpty() const;
|
||||||
|
|
||||||
/// Add an inequality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
|
|
||||||
/// is the current number of variables, then the corresponding inequality is
|
|
||||||
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} >= 0.
|
|
||||||
virtual void addInequality(ArrayRef<int64_t> coeffs) = 0;
|
|
||||||
|
|
||||||
/// Returns the number of variables in the tableau.
|
/// Returns the number of variables in the tableau.
|
||||||
unsigned getNumVariables() const;
|
unsigned getNumVariables() const;
|
||||||
|
|
||||||
/// Returns the number of constraints in the tableau.
|
/// Returns the number of constraints in the tableau.
|
||||||
unsigned getNumConstraints() const;
|
unsigned getNumConstraints() const;
|
||||||
|
|
||||||
|
/// Add an inequality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
|
||||||
|
/// is the current number of variables, then the corresponding inequality is
|
||||||
|
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} >= 0.
|
||||||
|
virtual void addInequality(ArrayRef<int64_t> coeffs) = 0;
|
||||||
|
|
||||||
/// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
|
/// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
|
||||||
/// is the current number of variables, then the corresponding equality is
|
/// is the current number of variables, then the corresponding equality is
|
||||||
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
|
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
|
||||||
void addEquality(ArrayRef<int64_t> coeffs);
|
virtual void addEquality(ArrayRef<int64_t> coeffs) = 0;
|
||||||
|
|
||||||
/// Add new variables to the end of the list of variables.
|
/// Add new variables to the end of the list of variables.
|
||||||
void appendVariable(unsigned count = 1);
|
void appendVariable(unsigned count = 1);
|
||||||
|
@ -249,6 +249,14 @@ protected:
|
||||||
/// coefficient for it.
|
/// coefficient for it.
|
||||||
Optional<unsigned> findAnyPivotRow(unsigned col);
|
Optional<unsigned> findAnyPivotRow(unsigned col);
|
||||||
|
|
||||||
|
/// Return any column that this row can be pivoted with, ignoring tableau
|
||||||
|
/// consistency. Equality rows are not considered.
|
||||||
|
///
|
||||||
|
/// Returns an empty optional if no pivot is possible, which happens only when
|
||||||
|
/// the column unknown is a variable and no constraint has a non-zero
|
||||||
|
/// coefficient for it.
|
||||||
|
Optional<unsigned> findAnyPivotCol(unsigned row);
|
||||||
|
|
||||||
/// Swap the row with the column in the tableau's data structures but not the
|
/// Swap the row with the column in the tableau's data structures but not the
|
||||||
/// tableau itself. This is used by pivot.
|
/// tableau itself. This is used by pivot.
|
||||||
void swapRowWithCol(unsigned row, unsigned col);
|
void swapRowWithCol(unsigned row, unsigned col);
|
||||||
|
@ -295,6 +303,7 @@ protected:
|
||||||
RemoveLastVariable,
|
RemoveLastVariable,
|
||||||
UnmarkEmpty,
|
UnmarkEmpty,
|
||||||
UnmarkLastRedundant,
|
UnmarkLastRedundant,
|
||||||
|
UnmarkLastEquality,
|
||||||
RestoreBasis
|
RestoreBasis
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -308,13 +317,14 @@ protected:
|
||||||
/// Undo the operation represented by the log entry.
|
/// Undo the operation represented by the log entry.
|
||||||
void undo(UndoLogEntry entry);
|
void undo(UndoLogEntry entry);
|
||||||
|
|
||||||
/// Return the number of fixed columns, as described in the constructor above,
|
unsigned getNumFixedCols() const { return numFixedCols; }
|
||||||
/// this is the number of columns beyond those for the variables in var.
|
|
||||||
unsigned getNumFixedCols() const { return usingBigM ? 3u : 2u; }
|
|
||||||
|
|
||||||
/// Stores whether or not a big M column is present in the tableau.
|
/// Stores whether or not a big M column is present in the tableau.
|
||||||
bool usingBigM;
|
bool usingBigM;
|
||||||
|
|
||||||
|
/// denom + const + maybe M + equality columns
|
||||||
|
unsigned numFixedCols;
|
||||||
|
|
||||||
/// The number of rows in the tableau.
|
/// The number of rows in the tableau.
|
||||||
unsigned nRow;
|
unsigned nRow;
|
||||||
|
|
||||||
|
@ -435,9 +445,12 @@ public:
|
||||||
///
|
///
|
||||||
/// This just adds the inequality to the tableau and does not try to create a
|
/// This just adds the inequality to the tableau and does not try to create a
|
||||||
/// consistent tableau configuration.
|
/// consistent tableau configuration.
|
||||||
void addInequality(ArrayRef<int64_t> coeffs) final {
|
void addInequality(ArrayRef<int64_t> coeffs) final;
|
||||||
addRow(coeffs, /*makeRestricted=*/true);
|
|
||||||
}
|
/// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
|
||||||
|
/// is the current number of variables, then the corresponding equality is
|
||||||
|
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
|
||||||
|
void addEquality(ArrayRef<int64_t> coeffs) final;
|
||||||
|
|
||||||
/// Get a snapshot of the current state. This is used for rolling back.
|
/// Get a snapshot of the current state. This is used for rolling back.
|
||||||
unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); }
|
unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); }
|
||||||
|
@ -533,6 +546,11 @@ public:
|
||||||
/// state and marks the Simplex empty if this is not possible.
|
/// state and marks the Simplex empty if this is not possible.
|
||||||
void addInequality(ArrayRef<int64_t> coeffs) final;
|
void addInequality(ArrayRef<int64_t> coeffs) final;
|
||||||
|
|
||||||
|
/// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
|
||||||
|
/// is the current number of variables, then the corresponding equality is
|
||||||
|
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
|
||||||
|
void addEquality(ArrayRef<int64_t> coeffs) final;
|
||||||
|
|
||||||
/// Compute the maximum or minimum value of the given row, depending on
|
/// Compute the maximum or minimum value of the given row, depending on
|
||||||
/// direction. The specified row is never pivoted. On return, the row may
|
/// direction. The specified row is never pivoted. On return, the row may
|
||||||
/// have a negative sample value if the direction is down.
|
/// have a negative sample value if the direction is down.
|
||||||
|
|
|
@ -19,12 +19,12 @@ using Direction = Simplex::Direction;
|
||||||
const int nullIndex = std::numeric_limits<int>::max();
|
const int nullIndex = std::numeric_limits<int>::max();
|
||||||
|
|
||||||
SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM)
|
SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM)
|
||||||
: usingBigM(mustUseBigM), nRow(0), nCol(getNumFixedCols() + nVar),
|
: usingBigM(mustUseBigM), numFixedCols(mustUseBigM ? 3 : 2), nRow(0),
|
||||||
nRedundant(0), tableau(0, nCol), empty(false) {
|
nCol(numFixedCols + nVar), nRedundant(0), tableau(0, nCol), empty(false) {
|
||||||
colUnknown.insert(colUnknown.begin(), getNumFixedCols(), nullIndex);
|
colUnknown.insert(colUnknown.begin(), numFixedCols, nullIndex);
|
||||||
for (unsigned i = 0; i < nVar; ++i) {
|
for (unsigned i = 0; i < nVar; ++i) {
|
||||||
var.emplace_back(Orientation::Column, /*restricted=*/false,
|
var.emplace_back(Orientation::Column, /*restricted=*/false,
|
||||||
/*pos=*/getNumFixedCols() + i);
|
/*pos=*/numFixedCols + i);
|
||||||
colUnknown.push_back(i);
|
colUnknown.push_back(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -309,7 +309,7 @@ void LexSimplex::restoreRationalConsistency() {
|
||||||
// minimizes the change in sample value.
|
// minimizes the change in sample value.
|
||||||
LogicalResult LexSimplex::moveRowUnknownToColumn(unsigned row) {
|
LogicalResult LexSimplex::moveRowUnknownToColumn(unsigned row) {
|
||||||
Optional<unsigned> maybeColumn;
|
Optional<unsigned> maybeColumn;
|
||||||
for (unsigned col = 3; col < nCol; ++col) {
|
for (unsigned col = getNumFixedCols(); col < nCol; ++col) {
|
||||||
if (tableau(row, col) <= 0)
|
if (tableau(row, col) <= 0)
|
||||||
continue;
|
continue;
|
||||||
maybeColumn =
|
maybeColumn =
|
||||||
|
@ -648,7 +648,7 @@ void Simplex::addInequality(ArrayRef<int64_t> coeffs) {
|
||||||
///
|
///
|
||||||
/// We simply add two opposing inequalities, which force the expression to
|
/// We simply add two opposing inequalities, which force the expression to
|
||||||
/// be zero.
|
/// be zero.
|
||||||
void SimplexBase::addEquality(ArrayRef<int64_t> coeffs) {
|
void Simplex::addEquality(ArrayRef<int64_t> coeffs) {
|
||||||
addInequality(coeffs);
|
addInequality(coeffs);
|
||||||
SmallVector<int64_t, 8> negatedCoeffs;
|
SmallVector<int64_t, 8> negatedCoeffs;
|
||||||
for (int64_t coeff : coeffs)
|
for (int64_t coeff : coeffs)
|
||||||
|
@ -705,6 +705,15 @@ Optional<unsigned> SimplexBase::findAnyPivotRow(unsigned col) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This doesn't find a pivot column only if the row has zero coefficients for
|
||||||
|
// every column not marked as an equality.
|
||||||
|
Optional<unsigned> SimplexBase::findAnyPivotCol(unsigned row) {
|
||||||
|
for (unsigned col = getNumFixedCols(); col < nCol; ++col)
|
||||||
|
if (tableau(row, col) != 0)
|
||||||
|
return col;
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
// It's not valid to remove the constraint by deleting the column since this
|
// It's not valid to remove the constraint by deleting the column since this
|
||||||
// would result in an invalid basis.
|
// would result in an invalid basis.
|
||||||
void Simplex::undoLastConstraint() {
|
void Simplex::undoLastConstraint() {
|
||||||
|
@ -780,6 +789,10 @@ void SimplexBase::undo(UndoLogEntry entry) {
|
||||||
empty = false;
|
empty = false;
|
||||||
} else if (entry == UndoLogEntry::UnmarkLastRedundant) {
|
} else if (entry == UndoLogEntry::UnmarkLastRedundant) {
|
||||||
nRedundant--;
|
nRedundant--;
|
||||||
|
} else if (entry == UndoLogEntry::UnmarkLastEquality) {
|
||||||
|
numFixedCols--;
|
||||||
|
assert(getNumFixedCols() >= 2 + usingBigM &&
|
||||||
|
"The denominator, constant, big M and symbols are always fixed!");
|
||||||
} else if (entry == UndoLogEntry::RestoreBasis) {
|
} else if (entry == UndoLogEntry::RestoreBasis) {
|
||||||
assert(!savedBases.empty() && "No bases saved!");
|
assert(!savedBases.empty() && "No bases saved!");
|
||||||
|
|
||||||
|
@ -1110,6 +1123,26 @@ Optional<SmallVector<Fraction, 8>> Simplex::getRationalSample() const {
|
||||||
return sample;
|
return sample;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void LexSimplex::addInequality(ArrayRef<int64_t> coeffs) {
|
||||||
|
addRow(coeffs, /*makeRestricted=*/true);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to make the equality a fixed column by finding any pivot and performing
|
||||||
|
/// it. The only time this is not possible is when the given equality's
|
||||||
|
/// direction is already in the span of the existing fixed column equalities. In
|
||||||
|
/// that case, we just leave it in row position.
|
||||||
|
void LexSimplex::addEquality(ArrayRef<int64_t> coeffs) {
|
||||||
|
const Unknown &u = con[addRow(coeffs, /*makeRestricted=*/true)];
|
||||||
|
Optional<unsigned> pivotCol = findAnyPivotCol(u.pos);
|
||||||
|
if (!pivotCol)
|
||||||
|
return;
|
||||||
|
|
||||||
|
pivot(u.pos, *pivotCol);
|
||||||
|
swapColumns(*pivotCol, getNumFixedCols());
|
||||||
|
numFixedCols++;
|
||||||
|
undoLog.push_back(UndoLogEntry::UnmarkLastEquality);
|
||||||
|
}
|
||||||
|
|
||||||
MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::getRationalSample() const {
|
MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::getRationalSample() const {
|
||||||
if (empty)
|
if (empty)
|
||||||
return OptimumKind::Empty;
|
return OptimumKind::Empty;
|
||||||
|
|
|
@ -548,3 +548,10 @@ TEST(SimplexTest, addDivisionVariable) {
|
||||||
ASSERT_TRUE(sample.hasValue());
|
ASSERT_TRUE(sample.hasValue());
|
||||||
EXPECT_EQ((*sample)[0] / 2, (*sample)[1]);
|
EXPECT_EQ((*sample)[0] / 2, (*sample)[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(LexSimplexTest, addEquality) {
|
||||||
|
IntegerRelation rel(/*numDomain=*/0, /*numRange=*/1);
|
||||||
|
rel.addEquality({1, 0});
|
||||||
|
LexSimplex simplex(rel);
|
||||||
|
EXPECT_EQ(simplex.getNumConstraints(), 1u);
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue