Introduce subtraction for FlatAffineConstraints

Subtraction is a foundational arithmetic operation that is often used when computing, for example, data transfer sets or cache hits. Since the result of subtraction need not be a convex polytope, a new class `PresburgerSet` is introduced to represent unions of convex polytopes.

Reviewed By: ftynse, bondhugula

Differential Revision: https://reviews.llvm.org/D87068
This commit is contained in:
Arjun P 2020-10-07 17:16:11 +02:00 committed by Alex Zinenko
parent bcd8422d75
commit 63dead2096
12 changed files with 1015 additions and 23 deletions

View File

@ -97,6 +97,13 @@ public:
ids.append(idArgs.begin(), idArgs.end());
}
/// Return a system with no constraints, i.e., one which is satisfied by all
/// points.
static FlatAffineConstraints getUniverse(unsigned numDims = 0,
unsigned numSymbols = 0) {
return FlatAffineConstraints(numDims, numSymbols);
}
/// Create a flat affine constraint system from an AffineValueMap or a list of
/// these. The constructed system will only include equalities.
explicit FlatAffineConstraints(const AffineValueMap &avm);
@ -153,6 +160,10 @@ public:
/// Returns such a point if one exists, or an empty Optional otherwise.
Optional<SmallVector<int64_t, 8>> findIntegerSample() const;
/// Returns true if the given point satisfies the constraints, or false
/// otherwise.
bool containsPoint(ArrayRef<int64_t> point) const;
// Clones this object.
std::unique_ptr<FlatAffineConstraints> clone() const;

View File

@ -169,6 +169,9 @@ public:
/// Rollback to a snapshot. This invalidates all later snapshots.
void rollback(unsigned snapshot);
/// Add all the constraints from the given FlatAffineConstraints.
void intersectFlatAffineConstraints(const FlatAffineConstraints &fac);
/// Compute the maximum or minimum value of the given row, depending on
/// direction. The specified row is never pivoted.
///

View File

@ -0,0 +1,112 @@
//===- Set.h - MLIR PresburgerSet Class -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A class to represent unions of FlatAffineConstraints.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_PRESBURGERSET_H
#define MLIR_ANALYSIS_PRESBURGERSET_H
#include "mlir/Analysis/AffineStructures.h"
namespace mlir {
/// This class can represent a union of FlatAffineConstraints, with support for
/// union, intersection, subtraction and complement operations, as well as
/// sampling.
///
/// The FlatAffineConstraints (FACs) are stored in a vector, and the set
/// represents the union of these FACs. An empty list corresponds to the empty
/// set.
///
/// Note that there are no invariants guaranteed on the list of FACs other than
/// that they are all in the same space, i.e., they all have the same number of
/// dimensions and symbols. For example, the FACs may overlap each other.
class PresburgerSet {
public:
explicit PresburgerSet(const FlatAffineConstraints &fac);
/// Return the number of FACs in the union.
unsigned getNumFACs() const;
/// Return the number of real dimensions.
unsigned getNumDims() const;
/// Return the number of symbolic dimensions.
unsigned getNumSyms() const;
/// Return a reference to the list of FlatAffineConstraints.
ArrayRef<FlatAffineConstraints> getAllFlatAffineConstraints() const;
/// Return the FlatAffineConstraints at the specified index.
const FlatAffineConstraints &getFlatAffineConstraints(unsigned index) const;
/// Mutate this set, turning it into the union of this set and the given
/// FlatAffineConstraints.
void unionFACInPlace(const FlatAffineConstraints &fac);
/// Mutate this set, turning it into the union of this set and the given set.
void unionSetInPlace(const PresburgerSet &set);
/// Return the union of this set and the given set.
PresburgerSet unionSet(const PresburgerSet &set) const;
/// Return the intersection of this set and the given set.
PresburgerSet intersect(const PresburgerSet &set) const;
/// Return true if the set contains the given point, or false otherwise.
bool containsPoint(ArrayRef<int64_t> point) const;
/// Print the set's internal state.
void print(raw_ostream &os) const;
void dump() const;
/// Return the complement of this set.
PresburgerSet complement() const;
/// Return the set difference of this set and the given set, i.e.,
/// return `this \ set`.
PresburgerSet subtract(const PresburgerSet &set) const;
/// Return a universe set of the specified type that contains all points.
static PresburgerSet getUniverse(unsigned nDim = 0, unsigned nSym = 0);
/// Return an empty set of the specified type that contains no points.
static PresburgerSet getEmptySet(unsigned nDim = 0, unsigned nSym = 0);
/// Return true if all the sets in the union are known to be integer empty
/// false otherwise.
bool isIntegerEmpty() const;
/// Find an integer sample from the given set. This should not be called if
/// any of the FACs in the union are unbounded.
bool findIntegerSample(SmallVectorImpl<int64_t> &sample);
private:
/// Construct an empty PresburgerSet.
PresburgerSet(unsigned nDim = 0, unsigned nSym = 0)
: nDim(nDim), nSym(nSym) {}
/// Return the set difference fac \ set.
static PresburgerSet getSetDifference(FlatAffineConstraints fac,
const PresburgerSet &set);
/// Number of identifiers corresponding to real dimensions.
unsigned nDim;
/// Number of symbolic dimensions, unknown but constant for analysis, as in
/// FlatAffineConstraints.
unsigned nSym;
/// The list of flatAffineConstraints that this set is the union of.
SmallVector<FlatAffineConstraints, 2> flatAffineConstraints;
};
} // namespace mlir
#endif // MLIR_ANALYSIS_PRESBURGERSET_H

View File

@ -1056,6 +1056,33 @@ FlatAffineConstraints::findIntegerSample() const {
return Simplex(*this).findIntegerSample();
}
/// Helper to evaluate an affine expression at a point.
/// The expression is a list of coefficients for the dimensions followed by the
/// constant term.
static int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) {
assert(expr.size() == 1 + point.size() &&
"Dimensionalities of point and expresion don't match!");
int64_t value = expr.back();
for (unsigned i = 0; i < point.size(); ++i)
value += expr[i] * point[i];
return value;
}
/// A point satisfies an equality iff the value of the equality at the
/// expression is zero, and it satisfies an inequality iff the value of the
/// inequality at that point is non-negative.
bool FlatAffineConstraints::containsPoint(ArrayRef<int64_t> point) const {
for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
if (valueAt(getEquality(i), point) != 0)
return false;
}
for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
if (valueAt(getInequality(i), point) < 0)
return false;
}
return true;
}
/// Tightens inequalities given that we are dealing with integer spaces. This is
/// analogous to the GCD test but applied to inequalities. The constant term can
/// be reduced to the preceding multiple of the GCD of the coefficients, i.e.,

View File

@ -5,6 +5,7 @@ set(LLVM_OPTIONAL_SOURCES
Liveness.cpp
LoopAnalysis.cpp
NestedMatcher.cpp
PresburgerSet.cpp
SliceAnalysis.cpp
Utils.cpp
)
@ -25,7 +26,6 @@ add_mlir_library(MLIRAnalysis
MLIRCallInterfaces
MLIRControlFlowInterfaces
MLIRInferTypeOpInterface
MLIRPresburger
MLIRSCF
)
@ -34,6 +34,7 @@ add_mlir_library(MLIRLoopAnalysis
AffineStructures.cpp
LoopAnalysis.cpp
NestedMatcher.cpp
PresburgerSet.cpp
Utils.cpp
ADDITIONAL_HEADER_DIRS
@ -51,4 +52,4 @@ add_mlir_library(MLIRLoopAnalysis
MLIRSCF
)
add_subdirectory(Presburger)
add_subdirectory(Presburger)

View File

@ -1,4 +1,4 @@
add_mlir_library(MLIRPresburger
Simplex.cpp
Matrix.cpp
)
)

View File

@ -451,6 +451,16 @@ void Simplex::rollback(unsigned snapshot) {
}
}
/// Add all the constraints from the given FlatAffineConstraints.
void Simplex::intersectFlatAffineConstraints(const FlatAffineConstraints &fac) {
assert(fac.getNumIds() == numVariables() &&
"FlatAffineConstraints must have same dimensionality as simplex");
for (unsigned i = 0, e = fac.getNumInequalities(); i < e; ++i)
addInequality(fac.getInequality(i));
for (unsigned i = 0, e = fac.getNumEqualities(); i < e; ++i)
addEquality(fac.getEquality(i));
}
Optional<Fraction> Simplex::computeRowOptimum(Direction direction,
unsigned row) {
// Keep trying to find a pivot for the row in the specified direction.

View File

@ -0,0 +1,316 @@
//===- Set.cpp - MLIR PresburgerSet Class ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/PresburgerSet.h"
#include "mlir/Analysis/Presburger/Simplex.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
using namespace mlir;
PresburgerSet::PresburgerSet(const FlatAffineConstraints &fac)
: nDim(fac.getNumDimIds()), nSym(fac.getNumSymbolIds()) {
unionFACInPlace(fac);
}
unsigned PresburgerSet::getNumFACs() const {
return flatAffineConstraints.size();
}
unsigned PresburgerSet::getNumDims() const { return nDim; }
unsigned PresburgerSet::getNumSyms() const { return nSym; }
ArrayRef<FlatAffineConstraints>
PresburgerSet::getAllFlatAffineConstraints() const {
return flatAffineConstraints;
}
const FlatAffineConstraints &
PresburgerSet::getFlatAffineConstraints(unsigned index) const {
assert(index < flatAffineConstraints.size() && "index out of bounds!");
return flatAffineConstraints[index];
}
/// Assert that the FlatAffineConstraints and PresburgerSet live in
/// compatible spaces.
static void assertDimensionsCompatible(const FlatAffineConstraints &fac,
const PresburgerSet &set) {
assert(fac.getNumDimIds() == set.getNumDims() &&
"Number of dimensions of the FlatAffineConstraints and PresburgerSet"
"do not match!");
assert(fac.getNumSymbolIds() == set.getNumSyms() &&
"Number of symbols of the FlatAffineConstraints and PresburgerSet"
"do not match!");
}
/// Assert that the two PresburgerSets live in compatible spaces.
static void assertDimensionsCompatible(const PresburgerSet &setA,
const PresburgerSet &setB) {
assert(setA.getNumDims() == setB.getNumDims() &&
"Number of dimensions of the PresburgerSets do not match!");
assert(setA.getNumSyms() == setB.getNumSyms() &&
"Number of symbols of the PresburgerSets do not match!");
}
/// Mutate this set, turning it into the union of this set and the given
/// FlatAffineConstraints.
void PresburgerSet::unionFACInPlace(const FlatAffineConstraints &fac) {
assertDimensionsCompatible(fac, *this);
flatAffineConstraints.push_back(fac);
}
/// Mutate this set, turning it into the union of this set and the given set.
///
/// This is accomplished by simply adding all the FACs of the given set to this
/// set.
void PresburgerSet::unionSetInPlace(const PresburgerSet &set) {
assertDimensionsCompatible(set, *this);
for (const FlatAffineConstraints &fac : set.flatAffineConstraints)
unionFACInPlace(fac);
}
/// Return the union of this set and the given set.
PresburgerSet PresburgerSet::unionSet(const PresburgerSet &set) const {
assertDimensionsCompatible(set, *this);
PresburgerSet result = *this;
result.unionSetInPlace(set);
return result;
}
/// A point is contained in the union iff any of the parts contain the point.
bool PresburgerSet::containsPoint(ArrayRef<int64_t> point) const {
for (const FlatAffineConstraints &fac : flatAffineConstraints) {
if (fac.containsPoint(point))
return true;
}
return false;
}
PresburgerSet PresburgerSet::getUniverse(unsigned nDim, unsigned nSym) {
PresburgerSet result(nDim, nSym);
result.unionFACInPlace(FlatAffineConstraints::getUniverse(nDim, nSym));
return result;
}
PresburgerSet PresburgerSet::getEmptySet(unsigned nDim, unsigned nSym) {
return PresburgerSet(nDim, nSym);
}
// Return the intersection of this set with the given set.
//
// We directly compute (S_1 or S_2 ...) and (T_1 or T_2 ...)
// as (S_1 and T_1) or (S_1 and T_2) or ...
PresburgerSet PresburgerSet::intersect(const PresburgerSet &set) const {
assertDimensionsCompatible(set, *this);
PresburgerSet result(nDim, nSym);
for (const FlatAffineConstraints &csA : flatAffineConstraints) {
for (const FlatAffineConstraints &csB : set.flatAffineConstraints) {
FlatAffineConstraints intersection(csA);
intersection.append(csB);
if (!intersection.isEmpty())
result.unionFACInPlace(std::move(intersection));
}
}
return result;
}
/// Return `coeffs` with all the elements negated.
static SmallVector<int64_t, 8> getNegatedCoeffs(ArrayRef<int64_t> coeffs) {
SmallVector<int64_t, 8> negatedCoeffs;
negatedCoeffs.reserve(coeffs.size());
for (int64_t coeff : coeffs)
negatedCoeffs.emplace_back(-coeff);
return negatedCoeffs;
}
/// Return the complement of the given inequality.
///
/// The complement of a_1 x_1 + ... + a_n x_ + c >= 0 is
/// a_1 x_1 + ... + a_n x_ + c < 0, i.e., -a_1 x_1 - ... - a_n x_ - c - 1 >= 0.
static SmallVector<int64_t, 8> getComplementIneq(ArrayRef<int64_t> ineq) {
SmallVector<int64_t, 8> coeffs;
coeffs.reserve(ineq.size());
for (int64_t coeff : ineq)
coeffs.emplace_back(-coeff);
--coeffs.back();
return coeffs;
}
/// Return the set difference b \ s and accumulate the result into `result`.
/// `simplex` must correspond to b.
///
/// In the following, V denotes union, ^ denotes intersection, \ denotes set
/// difference and ~ denotes complement.
/// Let b be the FlatAffineConstraints and s = (V_i s_i) be the set. We want
/// b \ (V_i s_i).
///
/// Let s_i = ^_j s_ij, where each s_ij is a single inequality. To compute
/// b \ s_i = b ^ ~s_i, we partition s_i based on the first violated inequality:
/// ~s_i = (~s_i1) V (s_i1 ^ ~s_i2) V (s_i1 ^ s_i2 ^ ~s_i3) V ...
/// And the required result is (b ^ ~s_i1) V (b ^ s_i1 ^ ~s_i2) V ...
/// We recurse by subtracting V_{j > i} S_j from each of these parts and
/// returning the union of the results. Each equality is handled as a
/// conjunction of two inequalities.
///
/// As a heuristic, we try adding all the constraints and check if simplex
/// says that the intersection is empty. Also, in the process we find out that
/// some constraints are redundant. These redundant constraints are ignored.
static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
const PresburgerSet &s, unsigned i,
PresburgerSet &result) {
if (i == s.getNumFACs()) {
result.unionFACInPlace(b);
return;
}
const FlatAffineConstraints &sI = s.getFlatAffineConstraints(i);
unsigned initialSnapshot = simplex.getSnapshot();
unsigned offset = simplex.numConstraints();
simplex.intersectFlatAffineConstraints(sI);
if (simplex.isEmpty()) {
/// b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
simplex.rollback(initialSnapshot);
subtractRecursively(b, simplex, s, i + 1, result);
return;
}
simplex.detectRedundant();
llvm::SmallBitVector isMarkedRedundant;
for (unsigned j = 0; j < 2 * sI.getNumEqualities() + sI.getNumInequalities();
j++)
isMarkedRedundant.push_back(simplex.isMarkedRedundant(offset + j));
simplex.rollback(initialSnapshot);
// Recurse with the part b ^ ~ineq. Note that b is modified throughout
// subtractRecursively. At the time this function is called, the current b is
// actually equal to b ^ s_i1 ^ s_i2 ^ ... ^ s_ij, and ineq is the next
// inequality, s_{i,j+1}. This function recurses into the next level i + 1
// with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}.
auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
size_t snapshot = simplex.getSnapshot();
b.addInequality(ineq);
simplex.addInequality(ineq);
subtractRecursively(b, simplex, s, i + 1, result);
b.removeInequality(b.getNumInequalities() - 1);
simplex.rollback(snapshot);
};
// For each inequality ineq, we first recurse with the part where ineq
// is not satisfied, and then add the ineq to b and simplex because
// ineq must be satisfied by all later parts.
auto processInequality = [&](ArrayRef<int64_t> ineq) {
recurseWithInequality(getComplementIneq(ineq));
b.addInequality(ineq);
simplex.addInequality(ineq);
};
// processInequality appends some additional constraints to b. We want to
// rollback b to its initial state before returning, which we will do by
// removing all constraints beyond the original number of inequalities
// and equalities, so we store these counts first.
unsigned originalNumIneqs = b.getNumInequalities();
unsigned originalNumEqs = b.getNumEqualities();
for (unsigned j = 0, e = sI.getNumInequalities(); j < e; j++) {
if (isMarkedRedundant[j])
continue;
processInequality(sI.getInequality(j));
}
offset = sI.getNumInequalities();
for (unsigned j = 0, e = sI.getNumEqualities(); j < e; ++j) {
const ArrayRef<int64_t> &coeffs = sI.getEquality(j);
// Same as the above loop for inequalities, done once each for the positive
// and negative inequalities that make up this equality.
if (!isMarkedRedundant[offset + 2 * j])
processInequality(coeffs);
if (!isMarkedRedundant[offset + 2 * j + 1])
processInequality(getNegatedCoeffs(coeffs));
}
// Rollback b and simplex to their initial states.
for (unsigned i = b.getNumInequalities(); i > originalNumIneqs; --i)
b.removeInequality(i - 1);
for (unsigned i = b.getNumEqualities(); i > originalNumEqs; --i)
b.removeEquality(i - 1);
simplex.rollback(initialSnapshot);
}
/// Return the set difference fac \ set.
///
/// The FAC here is modified in subtractRecursively, so it cannot be a const
/// reference even though it is restored to its original state before returning
/// from that function.
PresburgerSet PresburgerSet::getSetDifference(FlatAffineConstraints fac,
const PresburgerSet &set) {
assertDimensionsCompatible(fac, set);
if (fac.isEmptyByGCDTest())
return PresburgerSet::getEmptySet(fac.getNumDimIds(),
fac.getNumSymbolIds());
PresburgerSet result(fac.getNumDimIds(), fac.getNumSymbolIds());
Simplex simplex(fac);
subtractRecursively(fac, simplex, set, 0, result);
return result;
}
/// Return the complement of this set.
PresburgerSet PresburgerSet::complement() const {
return getSetDifference(
FlatAffineConstraints::getUniverse(getNumDims(), getNumSyms()), *this);
}
/// Return the result of subtract the given set from this set, i.e.,
/// return `this \ set`.
PresburgerSet PresburgerSet::subtract(const PresburgerSet &set) const {
assertDimensionsCompatible(set, *this);
PresburgerSet result(nDim, nSym);
// We compute (V_i t_i) \ (V_i set_i) as V_i (t_i \ V_i set_i).
for (const FlatAffineConstraints &fac : flatAffineConstraints)
result.unionSetInPlace(getSetDifference(fac, set));
return result;
}
/// Return true if all the sets in the union are known to be integer empty,
/// false otherwise.
bool PresburgerSet::isIntegerEmpty() const {
assert(nSym == 0 && "isIntegerEmpty is intended for non-symbolic sets");
// The set is empty iff all of the disjuncts are empty.
for (const FlatAffineConstraints &fac : flatAffineConstraints) {
if (!fac.isIntegerEmpty())
return false;
}
return true;
}
bool PresburgerSet::findIntegerSample(SmallVectorImpl<int64_t> &sample) {
assert(nSym == 0 && "findIntegerSample is intended for non-symbolic sets");
// A sample exists iff any of the disjuncts contains a sample.
for (const FlatAffineConstraints &fac : flatAffineConstraints) {
if (Optional<SmallVector<int64_t, 8>> opt = fac.findIntegerSample()) {
sample = std::move(*opt);
return true;
}
}
return false;
}
void PresburgerSet::print(raw_ostream &os) const {
os << getNumFACs() << " FlatAffineConstraints:\n";
for (const FlatAffineConstraints &fac : flatAffineConstraints) {
fac.print(os);
os << '\n';
}
}
void PresburgerSet::dump() const { print(llvm::errs()); }

View File

@ -15,22 +15,11 @@
namespace mlir {
/// Evaluate the value of the given affine expression at the specified point.
/// The expression is a list of coefficients for the dimensions followed by the
/// constant term.
int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) {
assert(expr.size() == 1 + point.size());
int64_t value = expr.back();
for (unsigned i = 0; i < point.size(); ++i)
value += expr[i] * point[i];
return value;
}
/// If 'hasValue' is true, check that findIntegerSample returns a valid sample
/// for the FlatAffineConstraints fac.
///
/// If hasValue is false, check that findIntegerSample does not return None.
void checkSample(bool hasValue, const FlatAffineConstraints &fac) {
static void checkSample(bool hasValue, const FlatAffineConstraints &fac) {
Optional<SmallVector<int64_t, 8>> maybeSample = fac.findIntegerSample();
if (!hasValue) {
EXPECT_FALSE(maybeSample.hasValue());
@ -41,16 +30,13 @@ void checkSample(bool hasValue, const FlatAffineConstraints &fac) {
}
} else {
ASSERT_TRUE(maybeSample.hasValue());
for (unsigned i = 0; i < fac.getNumEqualities(); ++i)
EXPECT_EQ(valueAt(fac.getEquality(i), *maybeSample), 0);
for (unsigned i = 0; i < fac.getNumInequalities(); ++i)
EXPECT_GE(valueAt(fac.getInequality(i), *maybeSample), 0);
EXPECT_TRUE(fac.containsPoint(*maybeSample));
}
}
/// Construct a FlatAffineConstraints from a set of inequality and
/// equality constraints.
FlatAffineConstraints
static FlatAffineConstraints
makeFACFromConstraints(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs,
ArrayRef<SmallVector<int64_t, 4>> eqs) {
FlatAffineConstraints fac(ineqs.size(), eqs.size(), dims + 1, dims);
@ -66,9 +52,9 @@ makeFACFromConstraints(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs,
/// orderings may cause the algorithm to proceed differently. At least some of
///.these permutations should make it past the heuristics and test the
/// implementation of the GBR algorithm itself.
void checkPermutationsSample(bool hasValue, unsigned nDim,
ArrayRef<SmallVector<int64_t, 4>> ineqs,
ArrayRef<SmallVector<int64_t, 4>> eqs) {
static void checkPermutationsSample(bool hasValue, unsigned nDim,
ArrayRef<SmallVector<int64_t, 4>> ineqs,
ArrayRef<SmallVector<int64_t, 4>> eqs) {
SmallVector<unsigned, 4> perm(nDim);
std::iota(perm.begin(), perm.end(), 0);
auto permute = [&perm](ArrayRef<int64_t> coeffs) {

View File

@ -1,5 +1,6 @@
add_mlir_unittest(MLIRAnalysisTests
AffineStructuresTest.cpp
PresburgerSetTest.cpp
)
target_link_libraries(MLIRAnalysisTests

View File

@ -5,3 +5,4 @@ add_mlir_unittest(MLIRPresburgerTests
target_link_libraries(MLIRPresburgerTests
PRIVATE MLIRPresburger)

View File

@ -0,0 +1,524 @@
//===- SetTest.cpp - Tests for PresburgerSet ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains tests for PresburgerSet. Each test works by computing
// an operation (union, intersection, subtract, or complement) on two sets
// and checking, for a set of points, that the resulting set contains the point
// iff the result is supposed to contain it.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/PresburgerSet.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
namespace mlir {
/// Compute the union of s and t, and check that each of the given points
/// belongs to the union iff it belongs to at least one of s and t.
static void testUnionAtPoints(PresburgerSet s, PresburgerSet t,
ArrayRef<SmallVector<int64_t, 4>> points) {
PresburgerSet unionSet = s.unionSet(t);
for (const SmallVector<int64_t, 4> &point : points) {
bool inS = s.containsPoint(point);
bool inT = t.containsPoint(point);
bool inUnion = unionSet.containsPoint(point);
EXPECT_EQ(inUnion, inS || inT);
}
}
/// Compute the intersection of s and t, and check that each of the given points
/// belongs to the intersection iff it belongs to both of s and t.
static void testIntersectAtPoints(PresburgerSet s, PresburgerSet t,
ArrayRef<SmallVector<int64_t, 4>> points) {
PresburgerSet intersection = s.intersect(t);
for (const SmallVector<int64_t, 4> &point : points) {
bool inS = s.containsPoint(point);
bool inT = t.containsPoint(point);
bool inIntersection = intersection.containsPoint(point);
EXPECT_EQ(inIntersection, inS && inT);
}
}
/// Compute the set difference s \ t, and check that each of the given points
/// belongs to the difference iff it belongs to s and does not belong to t.
static void testSubtractAtPoints(PresburgerSet s, PresburgerSet t,
ArrayRef<SmallVector<int64_t, 4>> points) {
PresburgerSet diff = s.subtract(t);
for (const SmallVector<int64_t, 4> &point : points) {
bool inS = s.containsPoint(point);
bool inT = t.containsPoint(point);
bool inDiff = diff.containsPoint(point);
if (inT)
EXPECT_FALSE(inDiff);
else
EXPECT_EQ(inDiff, inS);
}
}
/// Compute the complement of s, and check that each of the given points
/// belongs to the complement iff it does not belong to s.
static void testComplementAtPoints(PresburgerSet s,
ArrayRef<SmallVector<int64_t, 4>> points) {
PresburgerSet complement = s.complement();
complement.complement();
for (const SmallVector<int64_t, 4> &point : points) {
bool inS = s.containsPoint(point);
bool inComplement = complement.containsPoint(point);
if (inS)
EXPECT_FALSE(inComplement);
else
EXPECT_TRUE(inComplement);
}
}
/// Construct a FlatAffineConstraints from a set of inequality and
/// equality constraints.
static FlatAffineConstraints
makeFACFromConstraints(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs,
ArrayRef<SmallVector<int64_t, 4>> eqs) {
FlatAffineConstraints fac(ineqs.size(), eqs.size(), dims + 1, dims);
for (const SmallVector<int64_t, 4> &eq : eqs)
fac.addEquality(eq);
for (const SmallVector<int64_t, 4> &ineq : ineqs)
fac.addInequality(ineq);
return fac;
}
static FlatAffineConstraints
makeFACFromIneqs(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs) {
return makeFACFromConstraints(dims, ineqs, {});
}
static PresburgerSet makeSetFromFACs(unsigned dims,
ArrayRef<FlatAffineConstraints> facs) {
PresburgerSet set = PresburgerSet::getEmptySet(dims);
for (const FlatAffineConstraints &fac : facs)
set.unionFACInPlace(fac);
return set;
}
TEST(SetTest, containsPoint) {
PresburgerSet setA =
makeSetFromFACs(1, {
makeFACFromIneqs(1, {{1, -2}, // x >= 2.
{-1, 8}}), // x <= 8.
makeFACFromIneqs(1, {{1, -10}, // x >= 10.
{-1, 20}}), // x <= 20.
});
for (unsigned x = 0; x <= 21; ++x) {
if ((2 <= x && x <= 8) || (10 <= x && x <= 20))
EXPECT_TRUE(setA.containsPoint({x}));
else
EXPECT_FALSE(setA.containsPoint({x}));
}
// A parallelogram with vertices {(3, 1), (10, -6), (24, 8), (17, 15)} union
// a square with opposite corners (2, 2) and (10, 10).
PresburgerSet setB =
makeSetFromFACs(2, {makeFACFromIneqs(2,
{
{1, 1, -2}, // x + y >= 4.
{-1, -1, 30}, // x + y <= 32.
{1, -1, 0}, // x - y >= 2.
{-1, 1, 10}, // x - y <= 16.
}),
makeFACFromIneqs(2, {
{1, 0, -2}, // x >= 2.
{0, 1, -2}, // y >= 2.
{-1, 0, 10}, // x <= 10.
{0, -1, 10} // y <= 10.
})});
for (unsigned x = 1; x <= 25; ++x) {
for (unsigned y = -6; y <= 16; ++y) {
if (4 <= x + y && x + y <= 32 && 2 <= x - y && x - y <= 16)
EXPECT_TRUE(setB.containsPoint({x, y}));
else if (2 <= x && x <= 10 && 2 <= y && y <= 10)
EXPECT_TRUE(setB.containsPoint({x, y}));
else
EXPECT_FALSE(setB.containsPoint({x, y}));
}
}
}
TEST(SetTest, Union) {
PresburgerSet set =
makeSetFromFACs(1, {
makeFACFromIneqs(1, {{1, -2}, // x >= 2.
{-1, 8}}), // x <= 8.
makeFACFromIneqs(1, {{1, -10}, // x >= 10.
{-1, 20}}), // x <= 20.
});
// Universe union set.
testUnionAtPoints(PresburgerSet::getUniverse(1), set,
{{1}, {2}, {8}, {9}, {10}, {20}, {21}});
// empty set union set.
testUnionAtPoints(PresburgerSet::getEmptySet(1), set,
{{1}, {2}, {8}, {9}, {10}, {20}, {21}});
// empty set union Universe.
testUnionAtPoints(PresburgerSet::getEmptySet(1),
PresburgerSet::getUniverse(1), {{1}, {2}, {0}, {-1}});
// Universe union empty set.
testUnionAtPoints(PresburgerSet::getUniverse(1),
PresburgerSet::getEmptySet(1), {{1}, {2}, {0}, {-1}});
// empty set union empty set.
testUnionAtPoints(PresburgerSet::getEmptySet(1),
PresburgerSet::getEmptySet(1), {{1}, {2}, {0}, {-1}});
}
TEST(SetTest, Intersect) {
PresburgerSet set =
makeSetFromFACs(1, {
makeFACFromIneqs(1, {{1, -2}, // x >= 2.
{-1, 8}}), // x <= 8.
makeFACFromIneqs(1, {{1, -10}, // x >= 10.
{-1, 20}}), // x <= 20.
});
// Universe intersection set.
testIntersectAtPoints(PresburgerSet::getUniverse(1), set,
{{1}, {2}, {8}, {9}, {10}, {20}, {21}});
// empty set intersection set.
testIntersectAtPoints(PresburgerSet::getEmptySet(1), set,
{{1}, {2}, {8}, {9}, {10}, {20}, {21}});
// empty set intersection Universe.
testIntersectAtPoints(PresburgerSet::getEmptySet(1),
PresburgerSet::getUniverse(1), {{1}, {2}, {0}, {-1}});
// Universe intersection empty set.
testIntersectAtPoints(PresburgerSet::getUniverse(1),
PresburgerSet::getEmptySet(1), {{1}, {2}, {0}, {-1}});
// Universe intersection Universe.
testIntersectAtPoints(PresburgerSet::getUniverse(1),
PresburgerSet::getUniverse(1), {{1}, {2}, {0}, {-1}});
}
TEST(SetTest, Subtract) {
// The interval [2, 8] minus
// the interval [10, 20].
testSubtractAtPoints(
makeSetFromFACs(1, {makeFACFromIneqs(1, {})}),
makeSetFromFACs(1,
{
makeFACFromIneqs(1, {{1, -2}, // x >= 2.
{-1, 8}}), // x <= 8.
makeFACFromIneqs(1, {{1, -10}, // x >= 10.
{-1, 20}}), // x <= 20.
}),
{{1}, {2}, {8}, {9}, {10}, {20}, {21}});
// ((-infinity, 0] U [3, 4] U [6, 7]) - ([2, 3] U [5, 6])
testSubtractAtPoints(
makeSetFromFACs(1,
{
makeFACFromIneqs(1,
{
{-1, 0} // x <= 0.
}),
makeFACFromIneqs(1,
{
{1, -3}, // x >= 3.
{-1, 4} // x <= 4.
}),
makeFACFromIneqs(1,
{
{1, -6}, // x >= 6.
{-1, 7} // x <= 7.
}),
}),
makeSetFromFACs(1, {makeFACFromIneqs(1,
{
{1, -2}, // x >= 2.
{-1, 3}, // x <= 3.
}),
makeFACFromIneqs(1,
{
{1, -5}, // x >= 5.
{-1, 6} // x <= 6.
})}),
{{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}});
// Expected result is {[x, y] : x > y}, i.e., {[x, y] : x >= y + 1}.
testSubtractAtPoints(
makeSetFromFACs(2, {makeFACFromIneqs(2,
{
{1, -1, 0} // x >= y.
})}),
makeSetFromFACs(2, {makeFACFromIneqs(2,
{
{1, 1, 0} // x >= -y.
})}),
{{0, 1}, {1, 1}, {1, 0}, {1, -1}, {0, -1}});
// A rectangle with corners at (2, 2) and (10, 10), minus
// a rectangle with corners at (5, -10) and (7, 100).
// This splits the former rectangle into two halves, (2, 2) to (5, 10) and
// (7, 2) to (10, 10).
testSubtractAtPoints(
makeSetFromFACs(2, {makeFACFromIneqs(2,
{
{1, 0, -2}, // x >= 2.
{0, 1, -2}, // y >= 2.
{-1, 0, 10}, // x <= 10.
{0, -1, 10} // y <= 10.
})}),
makeSetFromFACs(2, {makeFACFromIneqs(2,
{
{1, 0, -5}, // x >= 5.
{0, 1, 10}, // y >= -10.
{-1, 0, 7}, // x <= 7.
{0, -1, 100}, // y <= 100.
})}),
{{1, 2}, {2, 2}, {4, 2}, {5, 2}, {7, 2}, {8, 2}, {11, 2},
{1, 1}, {2, 1}, {4, 1}, {5, 1}, {7, 1}, {8, 1}, {11, 1},
{1, 10}, {2, 10}, {4, 10}, {5, 10}, {7, 10}, {8, 10}, {11, 10},
{1, 11}, {2, 11}, {4, 11}, {5, 11}, {7, 11}, {8, 11}, {11, 11}});
// A rectangle with corners at (2, 2) and (10, 10), minus
// a rectangle with corners at (5, 4) and (7, 8).
// This creates a hole in the middle of the former rectangle, and the
// resulting set can be represented as a union of four rectangles.
testSubtractAtPoints(
makeSetFromFACs(2, {makeFACFromIneqs(2,
{
{1, 0, -2}, // x >= 2.
{0, 1, -2}, // y >= 2.
{-1, 0, 10}, // x <= 10.
{0, -1, 10} // y <= 10.
})}),
makeSetFromFACs(2, {makeFACFromIneqs(2,
{
{1, 0, -5}, // x >= 5.
{0, 1, -4}, // y >= 4.
{-1, 0, 7}, // x <= 7.
{0, -1, 8}, // y <= 8.
})}),
{{1, 1},
{2, 2},
{10, 10},
{11, 11},
{5, 4},
{7, 4},
{5, 8},
{7, 8},
{4, 4},
{8, 4},
{4, 8},
{8, 8}});
// The second set is a superset of the first one, since on the line x + y = 0,
// y <= 1 is equivalent to x >= -1. So the result is empty.
testSubtractAtPoints(
makeSetFromFACs(2, {makeFACFromConstraints(2,
{
{1, 0, 0} // x >= 0.
},
{
{1, 1, 0} // x + y = 0.
})}),
makeSetFromFACs(2, {makeFACFromConstraints(2,
{
{0, -1, 1} // y <= 1.
},
{
{1, 1, 0} // x + y = 0.
})}),
{{0, 0},
{1, -1},
{2, -2},
{-1, 1},
{-2, 2},
{1, 1},
{-1, -1},
{-1, 1},
{1, -1}});
// The result should be {0} U {2}.
testSubtractAtPoints(
makeSetFromFACs(1,
{
makeFACFromIneqs(1, {{1, 0}, // x >= 0.
{-1, 2}}), // x <= 2.
}),
makeSetFromFACs(1,
{
makeFACFromConstraints(1, {},
{
{1, -1} // x = 1.
}),
}),
{{-1}, {0}, {1}, {2}, {3}});
// Sets with lots of redundant inequalities to test the redundancy heuristic.
// (the heuristic is for the subtrahend, the second set which is the one being
// subtracted)
// A parallelogram with vertices {(3, 1), (10, -6), (24, 8), (17, 15)} minus
// a triangle with vertices {(2, 2), (10, 2), (10, 10)}.
testSubtractAtPoints(
makeSetFromFACs(2, {makeFACFromIneqs(2,
{
{1, 1, -2}, // x + y >= 4.
{-1, -1, 30}, // x + y <= 32.
{1, -1, 0}, // x - y >= 2.
{-1, 1, 10}, // x - y <= 16.
})}),
makeSetFromFACs(
2, {makeFACFromIneqs(2,
{
{1, 0, -2}, // x >= 2. [redundant]
{0, 1, -2}, // y >= 2.
{-1, 0, 10}, // x <= 10.
{0, -1, 10}, // y <= 10. [redundant]
{1, 1, -2}, // x + y >= 2. [redundant]
{-1, -1, 30}, // x + y <= 30. [redundant]
{1, -1, 0}, // x - y >= 0.
{-1, 1, 10}, // x - y <= 10.
})}),
{{1, 2}, {2, 2}, {3, 2}, {4, 2}, {1, 1}, {2, 1}, {3, 1},
{4, 1}, {2, 0}, {3, 0}, {4, 0}, {5, 0}, {10, 2}, {11, 2},
{10, 1}, {10, 10}, {10, 11}, {10, 9}, {11, 10}, {10, -6}, {11, -6},
{24, 8}, {24, 7}, {17, 15}, {16, 15}});
testSubtractAtPoints(
makeSetFromFACs(2, {makeFACFromIneqs(2,
{
{1, 1, -2}, // x + y >= 4.
{-1, -1, 30}, // x + y <= 32.
{1, -1, 0}, // x - y >= 2.
{-1, 1, 10}, // x - y <= 16.
})}),
makeSetFromFACs(
2, {makeFACFromIneqs(2,
{
{1, 0, -2}, // x >= 2. [redundant]
{0, 1, -2}, // y >= 2.
{-1, 0, 10}, // x <= 10.
{0, -1, 10}, // y <= 10. [redundant]
{1, 1, -2}, // x + y >= 2. [redundant]
{-1, -1, 30}, // x + y <= 30. [redundant]
{1, -1, 0}, // x - y >= 0.
{-1, 1, 10}, // x - y <= 10.
})}),
{{1, 2}, {2, 2}, {3, 2}, {4, 2}, {1, 1}, {2, 1}, {3, 1},
{4, 1}, {2, 0}, {3, 0}, {4, 0}, {5, 0}, {10, 2}, {11, 2},
{10, 1}, {10, 10}, {10, 11}, {10, 9}, {11, 10}, {10, -6}, {11, -6},
{24, 8}, {24, 7}, {17, 15}, {16, 15}});
// ((-infinity, -5] U [3, 3] U [4, 4] U [5, 5]) - ([-2, -10] U [3, 4] U [6,
// 7])
testSubtractAtPoints(
makeSetFromFACs(1,
{
makeFACFromIneqs(1,
{
{-1, -5}, // x <= -5.
}),
makeFACFromConstraints(1, {},
{
{1, -3} // x = 3.
}),
makeFACFromConstraints(1, {},
{
{1, -4} // x = 4.
}),
makeFACFromConstraints(1, {},
{
{1, -5} // x = 5.
}),
}),
makeSetFromFACs(
1,
{
makeFACFromIneqs(1,
{
{-1, -2}, // x <= -2.
{1, -10}, // x >= -10.
{-1, 0}, // x <= 0. [redundant]
{-1, 10}, // x <= 10. [redundant]
{1, -100}, // x >= -100. [redundant]
{1, -50} // x >= -50. [redundant]
}),
makeFACFromIneqs(1,
{
{1, -3}, // x >= 3.
{-1, 4}, // x <= 4.
{1, 1}, // x >= -1. [redundant]
{1, 7}, // x >= -7. [redundant]
{-1, 10} // x <= 10. [redundant]
}),
makeFACFromIneqs(1,
{
{1, -6}, // x >= 6.
{-1, 7}, // x <= 7.
{1, 1}, // x >= -1. [redundant]
{1, -3}, // x >= -3. [redundant]
{-1, 5} // x <= 5. [redundant]
}),
}),
{{-6},
{-5},
{-4},
{-9},
{-10},
{-11},
{0},
{1},
{2},
{3},
{4},
{5},
{6},
{7},
{8}});
}
TEST(SetTest, Complement) {
// Complement of universe.
testComplementAtPoints(
PresburgerSet::getUniverse(1),
{{-1}, {-2}, {-8}, {1}, {2}, {8}, {9}, {10}, {20}, {21}});
// Complement of empty set.
testComplementAtPoints(
PresburgerSet::getEmptySet(1),
{{-1}, {-2}, {-8}, {1}, {2}, {8}, {9}, {10}, {20}, {21}});
testComplementAtPoints(
makeSetFromFACs(2, {makeFACFromIneqs(2,
{
{1, 0, -2}, // x >= 2.
{0, 1, -2}, // y >= 2.
{-1, 0, 10}, // x <= 10.
{0, -1, 10} // y <= 10.
})}),
{{1, 1},
{2, 1},
{1, 2},
{2, 2},
{2, 3},
{3, 2},
{10, 10},
{10, 11},
{11, 10},
{2, 10},
{2, 11},
{1, 10}});
}
} // namespace mlir