forked from OSchip/llvm-project
[MLIR] PresburgerSet: support divisions in operations
Add support for intersecting, subtracting, complementing and checking equality of sets having divisions. Reviewed By: bondhugula Differential Revision: https://reviews.llvm.org/D110138
This commit is contained in:
parent
3f89e339bb
commit
4a57f5d1e1
|
@ -67,17 +67,18 @@ public:
|
|||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
||||
/// Return the complement of this set. Computing the complement of a set
|
||||
/// containing divisions is not yet supported.
|
||||
/// Return the complement of this set. All local variables in the set must
|
||||
/// correspond to floor divisions.
|
||||
PresburgerSet complement() const;
|
||||
|
||||
/// Return the set difference of this set and the given set, i.e.,
|
||||
/// return `this \ set`. Subtracting when either set contains divisions is not
|
||||
/// yet supported.
|
||||
/// return `this \ set`. All local variables in `set` must correspond
|
||||
/// to floor divisions, but local variables in `this` need not correspond to
|
||||
/// divisions.
|
||||
PresburgerSet subtract(const PresburgerSet &set) const;
|
||||
|
||||
/// Return true if this set is equal to the given set, and false otherwise.
|
||||
/// Checking equality when either set contains divisions is not yet supported.
|
||||
/// All local variables in both sets must correspond to floor divisions.
|
||||
bool isEqual(const PresburgerSet &set) const;
|
||||
|
||||
/// Return a universe set of the specified type that contains all points.
|
||||
|
|
|
@ -106,16 +106,20 @@ PresburgerSet PresburgerSet::getEmptySet(unsigned nDim, unsigned nSym) {
|
|||
//
|
||||
// We directly compute (S_1 or S_2 ...) and (T_1 or T_2 ...)
|
||||
// as (S_1 and T_1) or (S_1 and T_2) or ...
|
||||
//
|
||||
// If S_i or T_j have local variables, then S_i and T_j contains the local
|
||||
// variables of both.
|
||||
PresburgerSet PresburgerSet::intersect(const PresburgerSet &set) const {
|
||||
assertDimensionsCompatible(set, *this);
|
||||
|
||||
PresburgerSet result(nDim, nSym);
|
||||
for (const FlatAffineConstraints &csA : flatAffineConstraints) {
|
||||
for (const FlatAffineConstraints &csB : set.flatAffineConstraints) {
|
||||
FlatAffineConstraints intersection(csA);
|
||||
intersection.append(csB);
|
||||
if (!intersection.isEmpty())
|
||||
result.unionFACInPlace(std::move(intersection));
|
||||
FlatAffineConstraints csACopy = csA, csBCopy = csB;
|
||||
csACopy.mergeLocalIds(csBCopy);
|
||||
csACopy.append(std::move(csBCopy));
|
||||
if (!csACopy.isEmpty())
|
||||
result.unionFACInPlace(std::move(csACopy));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
|
@ -160,6 +164,17 @@ static SmallVector<int64_t, 8> getComplementIneq(ArrayRef<int64_t> ineq) {
|
|||
/// returning the union of the results. Each equality is handled as a
|
||||
/// conjunction of two inequalities.
|
||||
///
|
||||
/// Note that the same approach works even if an inequality involves a floor
|
||||
/// division. For example, the complement of x <= 7*floor(x/7) is still
|
||||
/// x > 7*floor(x/7). Since b \ s_i contains the inequalities of both b and s_i
|
||||
/// (or the complements of those inequalities), b \ s_i may contain the
|
||||
/// divisions present in both b and s_i. Therefore, we need to add the local
|
||||
/// division variables of both b and s_i to each part in the result. This means
|
||||
/// adding the local variables of both b and s_i, as well as the corresponding
|
||||
/// division inequalities to each part. Since the division inequalities are
|
||||
/// added to each part, we can skip the parts where the complement of any
|
||||
/// division inequality is added, as these parts will become empty anyway.
|
||||
///
|
||||
/// As a heuristic, we try adding all the constraints and check if simplex
|
||||
/// says that the intersection is empty. If it is, then subtracting this FAC is
|
||||
/// a no-op and we just skip it. Also, in the process we find out that some
|
||||
|
@ -174,27 +189,63 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
|
|||
result.unionFACInPlace(b);
|
||||
return;
|
||||
}
|
||||
const FlatAffineConstraints &sI = s.getFlatAffineConstraints(i);
|
||||
assert(sI.getNumLocalIds() == 0 &&
|
||||
"Subtracting sets with divisions is not yet supported!");
|
||||
FlatAffineConstraints sI = s.getFlatAffineConstraints(i);
|
||||
unsigned bInitNumLocals = b.getNumLocalIds();
|
||||
|
||||
// Find out which inequalities of sI correspond to division inequalities for
|
||||
// the local variables of sI.
|
||||
std::vector<llvm::Optional<std::pair<unsigned, unsigned>>> repr(
|
||||
sI.getNumLocalIds());
|
||||
sI.getLocalReprLbUbPairs(repr);
|
||||
|
||||
// Add sI's locals to b, after b's locals. Also add b's locals to sI, before
|
||||
// sI's locals.
|
||||
b.mergeLocalIds(sI);
|
||||
|
||||
// Mark which inequalities of sI are division inequalities and add all such
|
||||
// inequalities to b.
|
||||
llvm::SmallBitVector isDivInequality(sI.getNumInequalities());
|
||||
for (Optional<std::pair<unsigned, unsigned>> &maybePair : repr) {
|
||||
assert(maybePair &&
|
||||
"Subtraction is not supported when a representation of the local "
|
||||
"variables of the subtrahend cannot be found!");
|
||||
|
||||
b.addInequality(sI.getInequality(maybePair->first));
|
||||
b.addInequality(sI.getInequality(maybePair->second));
|
||||
|
||||
assert(maybePair->first != maybePair->second &&
|
||||
"Upper and lower bounds must be different inequalities!");
|
||||
isDivInequality[maybePair->first] = true;
|
||||
isDivInequality[maybePair->second] = true;
|
||||
}
|
||||
|
||||
unsigned initialSnapshot = simplex.getSnapshot();
|
||||
unsigned offset = simplex.getNumConstraints();
|
||||
unsigned numLocalsAdded = b.getNumLocalIds() - bInitNumLocals;
|
||||
simplex.appendVariable(numLocalsAdded);
|
||||
|
||||
unsigned snapshotBeforeIntersect = simplex.getSnapshot();
|
||||
simplex.intersectFlatAffineConstraints(sI);
|
||||
|
||||
if (simplex.isEmpty()) {
|
||||
/// b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
|
||||
simplex.rollback(initialSnapshot);
|
||||
b.removeIdRange(FlatAffineConstraints::IdKind::Local, bInitNumLocals,
|
||||
b.getNumLocalIds());
|
||||
subtractRecursively(b, simplex, s, i + 1, result);
|
||||
return;
|
||||
}
|
||||
|
||||
simplex.detectRedundant();
|
||||
llvm::SmallBitVector isMarkedRedundant;
|
||||
for (unsigned j = 0; j < 2 * sI.getNumEqualities() + sI.getNumInequalities();
|
||||
j++)
|
||||
isMarkedRedundant.push_back(simplex.isMarkedRedundant(offset + j));
|
||||
|
||||
simplex.rollback(initialSnapshot);
|
||||
// Equalities are added to simplex as a pair of inequalities.
|
||||
unsigned totalNewSimplexInequalities =
|
||||
2 * sI.getNumEqualities() + sI.getNumInequalities();
|
||||
llvm::SmallBitVector isMarkedRedundant(totalNewSimplexInequalities);
|
||||
for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
|
||||
isMarkedRedundant[j] = simplex.isMarkedRedundant(offset + j);
|
||||
|
||||
simplex.rollback(snapshotBeforeIntersect);
|
||||
|
||||
// Recurse with the part b ^ ~ineq. Note that b is modified throughout
|
||||
// subtractRecursively. At the time this function is called, the current b is
|
||||
|
@ -223,20 +274,28 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
|
|||
// rollback b to its initial state before returning, which we will do by
|
||||
// removing all constraints beyond the original number of inequalities
|
||||
// and equalities, so we store these counts first.
|
||||
unsigned originalNumIneqs = b.getNumInequalities();
|
||||
unsigned originalNumEqs = b.getNumEqualities();
|
||||
unsigned bInitNumIneqs = b.getNumInequalities();
|
||||
unsigned bInitNumEqs = b.getNumEqualities();
|
||||
|
||||
// Process all the inequalities, ignoring redundant inequalities and division
|
||||
// inequalities. The result is correct whether or not we ignore these, but
|
||||
// ignoring them makes the result simpler.
|
||||
for (unsigned j = 0, e = sI.getNumInequalities(); j < e; j++) {
|
||||
if (isMarkedRedundant[j])
|
||||
continue;
|
||||
if (isDivInequality[j])
|
||||
continue;
|
||||
processInequality(sI.getInequality(j));
|
||||
}
|
||||
|
||||
offset = sI.getNumInequalities();
|
||||
for (unsigned j = 0, e = sI.getNumEqualities(); j < e; ++j) {
|
||||
const ArrayRef<int64_t> &coeffs = sI.getEquality(j);
|
||||
// Same as the above loop for inequalities, done once each for the positive
|
||||
// and negative inequalities that make up this equality.
|
||||
ArrayRef<int64_t> coeffs = sI.getEquality(j);
|
||||
// For each equality, process the positive and negative inequalities that
|
||||
// make up this equality. If Simplex found an inequality to be redundant, we
|
||||
// skip it as above to make the result simpler. Divisions are always
|
||||
// represented in terms of inequalities and not equalities, so we do not
|
||||
// check for division inequalities here.
|
||||
if (!isMarkedRedundant[offset + 2 * j])
|
||||
processInequality(coeffs);
|
||||
if (!isMarkedRedundant[offset + 2 * j + 1])
|
||||
|
@ -244,11 +303,10 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
|
|||
}
|
||||
|
||||
// Rollback b and simplex to their initial states.
|
||||
for (unsigned i = b.getNumInequalities(); i > originalNumIneqs; --i)
|
||||
b.removeInequality(i - 1);
|
||||
|
||||
for (unsigned i = b.getNumEqualities(); i > originalNumEqs; --i)
|
||||
b.removeEquality(i - 1);
|
||||
b.removeIdRange(FlatAffineConstraints::IdKind::Local, bInitNumLocals,
|
||||
b.getNumLocalIds());
|
||||
b.removeInequalityRange(bInitNumIneqs, b.getNumInequalities());
|
||||
b.removeEqualityRange(bInitNumEqs, b.getNumEqualities());
|
||||
|
||||
simplex.rollback(initialSnapshot);
|
||||
}
|
||||
|
@ -261,8 +319,6 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
|
|||
PresburgerSet PresburgerSet::getSetDifference(FlatAffineConstraints fac,
|
||||
const PresburgerSet &set) {
|
||||
assertDimensionsCompatible(fac, set);
|
||||
assert(fac.getNumLocalIds() == 0 &&
|
||||
"Subtracting sets with divisions is not yet supported!");
|
||||
if (fac.isEmptyByGCDTest())
|
||||
return PresburgerSet::getEmptySet(fac.getNumDimIds(),
|
||||
fac.getNumSymbolIds());
|
||||
|
|
|
@ -80,12 +80,17 @@ static void testComplementAtPoints(PresburgerSet s,
|
|||
}
|
||||
|
||||
/// Construct a FlatAffineConstraints from a set of inequality and
|
||||
/// equality constraints.
|
||||
/// equality constraints. `numIds` is the total number of ids, of which
|
||||
/// `numLocals` is the number of local ids.
|
||||
static FlatAffineConstraints
|
||||
makeFACFromConstraints(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs,
|
||||
ArrayRef<SmallVector<int64_t, 4>> eqs) {
|
||||
FlatAffineConstraints fac(ineqs.size(), eqs.size(), dims + 1, dims,
|
||||
/*numSymbols=*/0, /*numLocals=*/0);
|
||||
makeFACFromConstraints(unsigned numIds, ArrayRef<SmallVector<int64_t, 4>> ineqs,
|
||||
ArrayRef<SmallVector<int64_t, 4>> eqs,
|
||||
unsigned numLocals = 0) {
|
||||
FlatAffineConstraints fac(/*numReservedInequalities=*/ineqs.size(),
|
||||
/*numReservedEqualities=*/eqs.size(),
|
||||
/*numReservedCols=*/numIds + 1,
|
||||
/*numDims=*/numIds - numLocals,
|
||||
/*numSymbols=*/0, numLocals);
|
||||
for (const SmallVector<int64_t, 4> &eq : eqs)
|
||||
fac.addEquality(eq);
|
||||
for (const SmallVector<int64_t, 4> &ineq : ineqs)
|
||||
|
@ -93,14 +98,22 @@ makeFACFromConstraints(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs,
|
|||
return fac;
|
||||
}
|
||||
|
||||
/// Construct a FlatAffineConstraints having `numDims` dimensions from the given
|
||||
/// set of inequality constraints. This is a convenience function to be used
|
||||
/// when the FAC to be constructed does not have any local ids and does not have
|
||||
/// equalties.
|
||||
static FlatAffineConstraints
|
||||
makeFACFromIneqs(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs) {
|
||||
return makeFACFromConstraints(dims, ineqs, {});
|
||||
makeFACFromIneqs(unsigned numDims, ArrayRef<SmallVector<int64_t, 4>> ineqs) {
|
||||
return makeFACFromConstraints(numDims, ineqs, /*eqs=*/{});
|
||||
}
|
||||
|
||||
static PresburgerSet makeSetFromFACs(unsigned dims,
|
||||
/// Construct a PresburgerSet having `numDims` dimensions and no symbols from
|
||||
/// the given list of FlatAffineConstraints. Each FAC in `facs` should also have
|
||||
/// `numDims` dimensions and no symbols, although it can have any number of
|
||||
/// local ids.
|
||||
static PresburgerSet makeSetFromFACs(unsigned numDims,
|
||||
ArrayRef<FlatAffineConstraints> facs) {
|
||||
PresburgerSet set = PresburgerSet::getEmptySet(dims);
|
||||
PresburgerSet set = PresburgerSet::getEmptySet(numDims);
|
||||
for (const FlatAffineConstraints &fac : facs)
|
||||
set.unionFACInPlace(fac);
|
||||
return set;
|
||||
|
@ -592,4 +605,37 @@ TEST(SetTest, isEqual) {
|
|||
EXPECT_FALSE(rect.complement().isEqual(square.complement()));
|
||||
}
|
||||
|
||||
void expectEqual(PresburgerSet s, PresburgerSet t) {
|
||||
EXPECT_TRUE(s.isEqual(t));
|
||||
}
|
||||
|
||||
void expectEmpty(PresburgerSet s) { EXPECT_TRUE(s.isIntegerEmpty()); }
|
||||
|
||||
TEST(SetTest, divisions) {
|
||||
// Note: we currently need to add the equalities as inequalities to the FAC
|
||||
// since detecting divisions based on equalities is not yet supported.
|
||||
|
||||
// evens = {x : exists q, x = 2q}.
|
||||
PresburgerSet evens{
|
||||
makeFACFromConstraints(2, {{1, -2, 0}, {-1, 2, 1}}, {{1, -2, 0}}, 1)};
|
||||
// odds = {x : exists q, x = 2q + 1}.
|
||||
PresburgerSet odds{
|
||||
makeFACFromConstraints(2, {{1, -2, 0}, {-1, 2, 1}}, {{1, -2, -1}}, 1)};
|
||||
// multiples6 = {x : exists q, x = 6q}.
|
||||
PresburgerSet multiples3{
|
||||
makeFACFromConstraints(2, {{1, -3, 0}, {-1, 3, 2}}, {{1, -3, 0}}, 1)};
|
||||
// multiples6 = {x : exists q, x = 6q}.
|
||||
PresburgerSet multiples6{
|
||||
makeFACFromConstraints(2, {{1, -6, 0}, {-1, 6, 5}}, {{1, -6, 0}}, 1)};
|
||||
|
||||
// evens /\ odds = empty.
|
||||
expectEmpty(PresburgerSet(evens).intersect(PresburgerSet(odds)));
|
||||
// evens U odds = universe.
|
||||
expectEqual(evens.unionSet(odds), PresburgerSet::getUniverse(1));
|
||||
expectEqual(evens.complement(), odds);
|
||||
expectEqual(odds.complement(), evens);
|
||||
// even multiples of 3 = multiples of 6.
|
||||
expectEqual(multiples3.intersect(evens), multiples6);
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
|
Loading…
Reference in New Issue