[MLIR] PresburgerSet: support divisions in operations

Add support for intersecting, subtracting, complementing and checking equality of sets having divisions.

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D110138
This commit is contained in:
Arjun P 2021-09-24 15:34:06 +05:30
parent 3f89e339bb
commit 4a57f5d1e1
3 changed files with 141 additions and 38 deletions

View File

@ -67,17 +67,18 @@ public:
void print(raw_ostream &os) const;
void dump() const;
/// Return the complement of this set. Computing the complement of a set
/// containing divisions is not yet supported.
/// Return the complement of this set. All local variables in the set must
/// correspond to floor divisions.
PresburgerSet complement() const;
/// Return the set difference of this set and the given set, i.e.,
/// return `this \ set`. Subtracting when either set contains divisions is not
/// yet supported.
/// return `this \ set`. All local variables in `set` must correspond
/// to floor divisions, but local variables in `this` need not correspond to
/// divisions.
PresburgerSet subtract(const PresburgerSet &set) const;
/// Return true if this set is equal to the given set, and false otherwise.
/// Checking equality when either set contains divisions is not yet supported.
/// All local variables in both sets must correspond to floor divisions.
bool isEqual(const PresburgerSet &set) const;
/// Return a universe set of the specified type that contains all points.

View File

@ -106,16 +106,20 @@ PresburgerSet PresburgerSet::getEmptySet(unsigned nDim, unsigned nSym) {
//
// We directly compute (S_1 or S_2 ...) and (T_1 or T_2 ...)
// as (S_1 and T_1) or (S_1 and T_2) or ...
//
// If S_i or T_j have local variables, then S_i and T_j contains the local
// variables of both.
PresburgerSet PresburgerSet::intersect(const PresburgerSet &set) const {
assertDimensionsCompatible(set, *this);
PresburgerSet result(nDim, nSym);
for (const FlatAffineConstraints &csA : flatAffineConstraints) {
for (const FlatAffineConstraints &csB : set.flatAffineConstraints) {
FlatAffineConstraints intersection(csA);
intersection.append(csB);
if (!intersection.isEmpty())
result.unionFACInPlace(std::move(intersection));
FlatAffineConstraints csACopy = csA, csBCopy = csB;
csACopy.mergeLocalIds(csBCopy);
csACopy.append(std::move(csBCopy));
if (!csACopy.isEmpty())
result.unionFACInPlace(std::move(csACopy));
}
}
return result;
@ -160,6 +164,17 @@ static SmallVector<int64_t, 8> getComplementIneq(ArrayRef<int64_t> ineq) {
/// returning the union of the results. Each equality is handled as a
/// conjunction of two inequalities.
///
/// Note that the same approach works even if an inequality involves a floor
/// division. For example, the complement of x <= 7*floor(x/7) is still
/// x > 7*floor(x/7). Since b \ s_i contains the inequalities of both b and s_i
/// (or the complements of those inequalities), b \ s_i may contain the
/// divisions present in both b and s_i. Therefore, we need to add the local
/// division variables of both b and s_i to each part in the result. This means
/// adding the local variables of both b and s_i, as well as the corresponding
/// division inequalities to each part. Since the division inequalities are
/// added to each part, we can skip the parts where the complement of any
/// division inequality is added, as these parts will become empty anyway.
///
/// As a heuristic, we try adding all the constraints and check if simplex
/// says that the intersection is empty. If it is, then subtracting this FAC is
/// a no-op and we just skip it. Also, in the process we find out that some
@ -174,27 +189,63 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
result.unionFACInPlace(b);
return;
}
const FlatAffineConstraints &sI = s.getFlatAffineConstraints(i);
assert(sI.getNumLocalIds() == 0 &&
"Subtracting sets with divisions is not yet supported!");
FlatAffineConstraints sI = s.getFlatAffineConstraints(i);
unsigned bInitNumLocals = b.getNumLocalIds();
// Find out which inequalities of sI correspond to division inequalities for
// the local variables of sI.
std::vector<llvm::Optional<std::pair<unsigned, unsigned>>> repr(
sI.getNumLocalIds());
sI.getLocalReprLbUbPairs(repr);
// Add sI's locals to b, after b's locals. Also add b's locals to sI, before
// sI's locals.
b.mergeLocalIds(sI);
// Mark which inequalities of sI are division inequalities and add all such
// inequalities to b.
llvm::SmallBitVector isDivInequality(sI.getNumInequalities());
for (Optional<std::pair<unsigned, unsigned>> &maybePair : repr) {
assert(maybePair &&
"Subtraction is not supported when a representation of the local "
"variables of the subtrahend cannot be found!");
b.addInequality(sI.getInequality(maybePair->first));
b.addInequality(sI.getInequality(maybePair->second));
assert(maybePair->first != maybePair->second &&
"Upper and lower bounds must be different inequalities!");
isDivInequality[maybePair->first] = true;
isDivInequality[maybePair->second] = true;
}
unsigned initialSnapshot = simplex.getSnapshot();
unsigned offset = simplex.getNumConstraints();
unsigned numLocalsAdded = b.getNumLocalIds() - bInitNumLocals;
simplex.appendVariable(numLocalsAdded);
unsigned snapshotBeforeIntersect = simplex.getSnapshot();
simplex.intersectFlatAffineConstraints(sI);
if (simplex.isEmpty()) {
/// b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
simplex.rollback(initialSnapshot);
b.removeIdRange(FlatAffineConstraints::IdKind::Local, bInitNumLocals,
b.getNumLocalIds());
subtractRecursively(b, simplex, s, i + 1, result);
return;
}
simplex.detectRedundant();
llvm::SmallBitVector isMarkedRedundant;
for (unsigned j = 0; j < 2 * sI.getNumEqualities() + sI.getNumInequalities();
j++)
isMarkedRedundant.push_back(simplex.isMarkedRedundant(offset + j));
simplex.rollback(initialSnapshot);
// Equalities are added to simplex as a pair of inequalities.
unsigned totalNewSimplexInequalities =
2 * sI.getNumEqualities() + sI.getNumInequalities();
llvm::SmallBitVector isMarkedRedundant(totalNewSimplexInequalities);
for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
isMarkedRedundant[j] = simplex.isMarkedRedundant(offset + j);
simplex.rollback(snapshotBeforeIntersect);
// Recurse with the part b ^ ~ineq. Note that b is modified throughout
// subtractRecursively. At the time this function is called, the current b is
@ -223,20 +274,28 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
// rollback b to its initial state before returning, which we will do by
// removing all constraints beyond the original number of inequalities
// and equalities, so we store these counts first.
unsigned originalNumIneqs = b.getNumInequalities();
unsigned originalNumEqs = b.getNumEqualities();
unsigned bInitNumIneqs = b.getNumInequalities();
unsigned bInitNumEqs = b.getNumEqualities();
// Process all the inequalities, ignoring redundant inequalities and division
// inequalities. The result is correct whether or not we ignore these, but
// ignoring them makes the result simpler.
for (unsigned j = 0, e = sI.getNumInequalities(); j < e; j++) {
if (isMarkedRedundant[j])
continue;
if (isDivInequality[j])
continue;
processInequality(sI.getInequality(j));
}
offset = sI.getNumInequalities();
for (unsigned j = 0, e = sI.getNumEqualities(); j < e; ++j) {
const ArrayRef<int64_t> &coeffs = sI.getEquality(j);
// Same as the above loop for inequalities, done once each for the positive
// and negative inequalities that make up this equality.
ArrayRef<int64_t> coeffs = sI.getEquality(j);
// For each equality, process the positive and negative inequalities that
// make up this equality. If Simplex found an inequality to be redundant, we
// skip it as above to make the result simpler. Divisions are always
// represented in terms of inequalities and not equalities, so we do not
// check for division inequalities here.
if (!isMarkedRedundant[offset + 2 * j])
processInequality(coeffs);
if (!isMarkedRedundant[offset + 2 * j + 1])
@ -244,11 +303,10 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
}
// Rollback b and simplex to their initial states.
for (unsigned i = b.getNumInequalities(); i > originalNumIneqs; --i)
b.removeInequality(i - 1);
for (unsigned i = b.getNumEqualities(); i > originalNumEqs; --i)
b.removeEquality(i - 1);
b.removeIdRange(FlatAffineConstraints::IdKind::Local, bInitNumLocals,
b.getNumLocalIds());
b.removeInequalityRange(bInitNumIneqs, b.getNumInequalities());
b.removeEqualityRange(bInitNumEqs, b.getNumEqualities());
simplex.rollback(initialSnapshot);
}
@ -261,8 +319,6 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
PresburgerSet PresburgerSet::getSetDifference(FlatAffineConstraints fac,
const PresburgerSet &set) {
assertDimensionsCompatible(fac, set);
assert(fac.getNumLocalIds() == 0 &&
"Subtracting sets with divisions is not yet supported!");
if (fac.isEmptyByGCDTest())
return PresburgerSet::getEmptySet(fac.getNumDimIds(),
fac.getNumSymbolIds());

View File

@ -80,12 +80,17 @@ static void testComplementAtPoints(PresburgerSet s,
}
/// Construct a FlatAffineConstraints from a set of inequality and
/// equality constraints.
/// equality constraints. `numIds` is the total number of ids, of which
/// `numLocals` is the number of local ids.
static FlatAffineConstraints
makeFACFromConstraints(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs,
ArrayRef<SmallVector<int64_t, 4>> eqs) {
FlatAffineConstraints fac(ineqs.size(), eqs.size(), dims + 1, dims,
/*numSymbols=*/0, /*numLocals=*/0);
makeFACFromConstraints(unsigned numIds, ArrayRef<SmallVector<int64_t, 4>> ineqs,
ArrayRef<SmallVector<int64_t, 4>> eqs,
unsigned numLocals = 0) {
FlatAffineConstraints fac(/*numReservedInequalities=*/ineqs.size(),
/*numReservedEqualities=*/eqs.size(),
/*numReservedCols=*/numIds + 1,
/*numDims=*/numIds - numLocals,
/*numSymbols=*/0, numLocals);
for (const SmallVector<int64_t, 4> &eq : eqs)
fac.addEquality(eq);
for (const SmallVector<int64_t, 4> &ineq : ineqs)
@ -93,14 +98,22 @@ makeFACFromConstraints(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs,
return fac;
}
/// Construct a FlatAffineConstraints having `numDims` dimensions from the given
/// set of inequality constraints. This is a convenience function to be used
/// when the FAC to be constructed does not have any local ids and does not have
/// equalties.
static FlatAffineConstraints
makeFACFromIneqs(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs) {
return makeFACFromConstraints(dims, ineqs, {});
makeFACFromIneqs(unsigned numDims, ArrayRef<SmallVector<int64_t, 4>> ineqs) {
return makeFACFromConstraints(numDims, ineqs, /*eqs=*/{});
}
static PresburgerSet makeSetFromFACs(unsigned dims,
/// Construct a PresburgerSet having `numDims` dimensions and no symbols from
/// the given list of FlatAffineConstraints. Each FAC in `facs` should also have
/// `numDims` dimensions and no symbols, although it can have any number of
/// local ids.
static PresburgerSet makeSetFromFACs(unsigned numDims,
ArrayRef<FlatAffineConstraints> facs) {
PresburgerSet set = PresburgerSet::getEmptySet(dims);
PresburgerSet set = PresburgerSet::getEmptySet(numDims);
for (const FlatAffineConstraints &fac : facs)
set.unionFACInPlace(fac);
return set;
@ -592,4 +605,37 @@ TEST(SetTest, isEqual) {
EXPECT_FALSE(rect.complement().isEqual(square.complement()));
}
void expectEqual(PresburgerSet s, PresburgerSet t) {
EXPECT_TRUE(s.isEqual(t));
}
void expectEmpty(PresburgerSet s) { EXPECT_TRUE(s.isIntegerEmpty()); }
TEST(SetTest, divisions) {
// Note: we currently need to add the equalities as inequalities to the FAC
// since detecting divisions based on equalities is not yet supported.
// evens = {x : exists q, x = 2q}.
PresburgerSet evens{
makeFACFromConstraints(2, {{1, -2, 0}, {-1, 2, 1}}, {{1, -2, 0}}, 1)};
// odds = {x : exists q, x = 2q + 1}.
PresburgerSet odds{
makeFACFromConstraints(2, {{1, -2, 0}, {-1, 2, 1}}, {{1, -2, -1}}, 1)};
// multiples6 = {x : exists q, x = 6q}.
PresburgerSet multiples3{
makeFACFromConstraints(2, {{1, -3, 0}, {-1, 3, 2}}, {{1, -3, 0}}, 1)};
// multiples6 = {x : exists q, x = 6q}.
PresburgerSet multiples6{
makeFACFromConstraints(2, {{1, -6, 0}, {-1, 6, 5}}, {{1, -6, 0}}, 1)};
// evens /\ odds = empty.
expectEmpty(PresburgerSet(evens).intersect(PresburgerSet(odds)));
// evens U odds = universe.
expectEqual(evens.unionSet(odds), PresburgerSet::getUniverse(1));
expectEqual(evens.complement(), odds);
expectEqual(odds.complement(), evens);
// even multiples of 3 = multiples of 6.
expectEqual(multiples3.intersect(evens), multiples6);
}
} // namespace mlir