[MLIR][Presburger] Refactor MultiAffineFunction to be defined over universe

This patch refactors MAF to be defined over the universe in a given space
instead of being defined over a restricted domain.

The reasoning for this refactor is to store division representation for local
variables explicitly for the function outputs. This change is required for
unionLexMax/Min to support local variables which will be upstreamed after this
patch. Another reason for this refactor is to have a flattened form of
AffineMap as MultiAffineFunction.

Reviewed By: arjunp

Differential Revision: https://reviews.llvm.org/D131864
This commit is contained in:
Groverkss 2022-09-11 01:02:52 +01:00
parent c1807c6b9f
commit bb2226ac53
14 changed files with 480 additions and 287 deletions

View File

@ -128,6 +128,8 @@ public:
/// Add `scale` multiples of the source row to the target row.
void addToRow(unsigned sourceRow, unsigned targetRow, int64_t scale);
/// Add `scale` multiples of the rowVec row to the specified row.
void addToRow(unsigned row, ArrayRef<int64_t> rowVec, int64_t scale);
/// Add `scale` multiples of the source column to the target column.
void addToColumn(unsigned sourceColumn, unsigned targetColumn, int64_t scale);

View File

@ -22,94 +22,93 @@
namespace mlir {
namespace presburger {
/// This class represents a multi-affine function whose domain is given by an
/// IntegerPolyhedron. This can be thought of as an IntegerPolyhedron with a
/// tuple of integer values attached to every point in the polyhedron, with the
/// value of each element of the tuple given by an affine expression in the vars
/// of the polyhedron. For example we could have the domain
///
/// (x, y) : (x >= 5, y >= x)
///
/// and a tuple of three integers defined at every point in the polyhedron:
/// This class represents a multi-affine function with the domain as Z^d, where
/// `d` is the number of domain variables of the function. For example:
///
/// (x, y) -> (x + 2, 2*x - 3y + 5, 2*x + y).
///
/// In this way every point in the polyhedron has a tuple of integers associated
/// with it. If the integer polyhedron has local vars, then the output
/// expressions can use them as well. The output expressions are represented as
/// a matrix with one row for every element in the output vector one column for
/// each var, and an extra column at the end for the constant term.
/// The output expressions are represented as a matrix with one row for every
/// output, one column for each var including division variables, and an extra
/// column at the end for the constant term.
///
/// Checking equality of two such functions is supported, as well as finding the
/// value of the function at a specified point.
class MultiAffineFunction {
public:
MultiAffineFunction(const IntegerPolyhedron &domain, const Matrix &output)
: domainSet(domain), output(output) {}
MultiAffineFunction(const Matrix &output, const PresburgerSpace &space)
: domainSet(space), output(output) {}
unsigned getNumInputs() const { return domainSet.getNumDimAndSymbolVars(); }
unsigned getNumOutputs() const { return output.getNumRows(); }
bool isConsistent() const {
return output.getNumColumns() == domainSet.getNumVars() + 1;
MultiAffineFunction(const PresburgerSpace &space, const Matrix &output)
: space(space), output(output),
divs(space.getNumVars() - space.getNumRangeVars()) {
assertIsConsistent();
}
/// Get the space of the input domain of this function.
const PresburgerSpace &getDomainSpace() const { return domainSet.getSpace(); }
/// Get the input domain of this function.
const IntegerPolyhedron &getDomain() const { return domainSet; }
MultiAffineFunction(const PresburgerSpace &space, const Matrix &output,
const DivisionRepr &divs)
: space(space), output(output), divs(divs) {
assertIsConsistent();
}
unsigned getNumDomainVars() const { return space.getNumDomainVars(); }
unsigned getNumSymbolVars() const { return space.getNumSymbolVars(); }
unsigned getNumOutputs() const { return space.getNumRangeVars(); }
unsigned getNumDivs() const { return space.getNumLocalVars(); }
/// Get the space of this function.
const PresburgerSpace &getSpace() const { return space; }
/// Get the domain/output space of the function. The returned space is a set
/// space.
PresburgerSpace getDomainSpace() const { return space.getDomainSpace(); }
PresburgerSpace getOutputSpace() const { return space.getRangeSpace(); }
/// Get a matrix with each row representing row^th output expression.
const Matrix &getOutputMatrix() const { return output; }
/// Get the `i^th` output expression.
ArrayRef<int64_t> getOutputExpr(unsigned i) const { return output.getRow(i); }
/// Insert `num` variables of the specified kind at position `pos`.
/// Positions are relative to the kind of variable. The coefficient columns
/// corresponding to the added variables are initialized to zero. Return the
/// absolute column position (i.e., not relative to the kind of variable)
/// of the first added variable.
unsigned insertVar(VarKind kind, unsigned pos, unsigned num = 1);
// Remove the specified range of outputs.
void removeOutputs(unsigned start, unsigned end);
/// Remove the specified range of vars.
void removeVarRange(VarKind kind, unsigned varStart, unsigned varLimit);
/// Given a MAF `other`, merges division variables such that both functions
/// have the union of the division vars that exist in the functions.
void mergeDivs(MultiAffineFunction &other);
/// Given a MAF `other`, merges local variables such that both funcitons
/// have union of local vars, without changing the set of points in domain or
/// the output.
void mergeLocalVars(MultiAffineFunction &other);
/// Return the output of the function at the given point.
SmallVector<int64_t, 8> valueAt(ArrayRef<int64_t> point) const;
/// Return whether the outputs of `this` and `other` agree wherever both
/// functions are defined, i.e., the outputs should be equal for all points in
/// the intersection of the domains.
bool isEqualWhereDomainsOverlap(MultiAffineFunction other) const;
/// Return whether the `this` and `other` are equal. This is the case if
/// they lie in the same space, i.e. have the same dimensions, and their
/// domains are identical and their outputs are equal on their domain.
/// Return whether the `this` and `other` are equal when the domain is
/// restricted to `domain`. This is the case if they lie in the same space,
/// and their outputs are equal for every point in `domain`.
bool isEqual(const MultiAffineFunction &other) const;
bool isEqual(const MultiAffineFunction &other,
const IntegerPolyhedron &domain) const;
bool isEqual(const MultiAffineFunction &other,
const PresburgerSet &domain) const;
/// Get the value of the function at the specified point. If the point lies
/// outside the domain, an empty optional is returned.
Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
void subtract(const MultiAffineFunction &other);
/// Truncate the output dimensions to the first `count` dimensions.
///
/// TODO: refactor so that this can be accomplished through removeVarRange.
void truncateOutput(unsigned count);
/// Get this function as a relation.
IntegerRelation getAsRelation() const;
void print(raw_ostream &os) const;
void dump() const;
private:
/// The IntegerPolyhedron representing the domain over which the function is
/// defined.
IntegerPolyhedron domainSet;
/// Assert that the MAF is consistent.
void assertIsConsistent() const;
/// The space of this function. The domain variables are considered as the
/// input variables of the function. The range variables are considered as
/// the outputs. The symbols parametrize the function and locals are used to
/// represent divisions. Each local variable has a corressponding division
/// representation stored in `divs`.
PresburgerSpace space;
/// The function's output is a tuple of integers, with the ith element of the
/// tuple defined by the affine expression given by the ith row of this output
/// matrix.
Matrix output;
/// Storage for division representation for each local variable in space.
DivisionRepr divs;
};
/// This class represents a piece-wise MultiAffineFunction. This can be thought
@ -132,33 +131,47 @@ private:
/// finding the value of the function at a point.
class PWMAFunction {
public:
PWMAFunction(const PresburgerSpace &space, unsigned numOutputs)
: space(space), numOutputs(numOutputs) {
assert(space.getNumDomainVars() == 0 &&
"Set type space should have zero domain vars.");
struct Piece {
PresburgerSet domain;
MultiAffineFunction output;
bool isConsistent() const {
return domain.getSpace().isCompatible(output.getDomainSpace());
}
};
PWMAFunction(const PresburgerSpace &space) : space(space) {
assert(space.getNumLocalVars() == 0 &&
"PWMAFunction cannot have local vars.");
assert(numOutputs >= 1 && "The function must output something!");
}
// Get the space of this function.
const PresburgerSpace &getSpace() const { return space; }
void addPiece(const MultiAffineFunction &piece);
void addPiece(const IntegerPolyhedron &domain, const Matrix &output);
void addPiece(const PresburgerSet &domain, const Matrix &output);
// Add a piece ([domain, output] pair) to this function.
void addPiece(const Piece &piece);
const MultiAffineFunction &getPiece(unsigned i) const { return pieces[i]; }
unsigned getNumPieces() const { return pieces.size(); }
unsigned getNumOutputs() const { return numOutputs; }
unsigned getNumInputs() const { return space.getNumVars(); }
MultiAffineFunction &getPiece(unsigned i) { return pieces[i]; }
unsigned getNumVarKind(VarKind kind) const {
return space.getNumVarKind(kind);
}
unsigned getNumDomainVars() const { return space.getNumDomainVars(); }
unsigned getNumOutputs() const { return space.getNumRangeVars(); }
unsigned getNumSymbolVars() const { return space.getNumSymbolVars(); }
/// Remove the specified range of outputs.
void removeOutputs(unsigned start, unsigned end);
/// Get the domain/output space of the function. The returned space is a set
/// space.
PresburgerSpace getDomainSpace() const { return space.getDomainSpace(); }
PresburgerSpace getOutputSpace() const { return space.getDomainSpace(); }
/// Return the domain of this piece-wise MultiAffineFunction. This is the
/// union of the domains of all the pieces.
PresburgerSet getDomain() const;
/// Return the value at the specified point and an empty optional if the
/// point does not lie in the domain.
/// Return the output of the function at the given point.
Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
/// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether
@ -166,11 +179,6 @@ public:
/// value at every point in the domain.
bool isEqual(const PWMAFunction &other) const;
/// Truncate the output dimensions to the first `count` dimensions.
///
/// TODO: refactor so that this can be accomplished through removeVarRange.
void truncateOutput(unsigned count);
/// Return a function defined on the union of the domains of this and func,
/// such that when only one of the functions is defined, it outputs the same
/// as that function, and if both are defined, it outputs the lexmax/lexmin of
@ -178,8 +186,8 @@ public:
/// function is not defined either.
///
/// Currently this does not support PWMAFunctions which have pieces containing
/// local variables.
/// TODO: Support local variables in peices.
/// divisions.
/// TODO: Support division in pieces.
PWMAFunction unionLexMin(const PWMAFunction &func);
PWMAFunction unionLexMax(const PWMAFunction &func);
@ -200,19 +208,17 @@ private:
///
/// The PresburgerSet returned by `tiebreak` should be disjoint.
/// TODO: Remove this constraint of returning disjoint set.
PWMAFunction
unionFunction(const PWMAFunction &func,
llvm::function_ref<PresburgerSet(MultiAffineFunction mafA,
MultiAffineFunction mafB)>
tiebreak) const;
PWMAFunction unionFunction(
const PWMAFunction &func,
llvm::function_ref<PresburgerSet(Piece mafA, Piece mafB)> tiebreak) const;
/// The space of this function. The domain variables are considered as the
/// input variables of the function. The range variables are considered as
/// the outputs. The symbols paramterize the function.
PresburgerSpace space;
/// The list of pieces in this piece-wise MultiAffineFunction.
SmallVector<MultiAffineFunction, 4> pieces;
/// The number of output vars.
unsigned numOutputs;
// The pieces of the PWMAFunction.
SmallVector<Piece, 4> pieces;
};
} // namespace presburger

View File

@ -90,6 +90,11 @@ public:
numLocals);
}
// Get the domain/range space of this space. The returned space is a set
// space.
PresburgerSpace getDomainSpace() const;
PresburgerSpace getRangeSpace() const;
unsigned getNumDomainVars() const { return numDomain; }
unsigned getNumRangeVars() const { return numRange; }
unsigned getNumSetDimVars() const { return numRange; }

View File

@ -529,9 +529,9 @@ private:
/// Represents the result of a symbolic lexicographic minimization computation.
struct SymbolicLexMin {
SymbolicLexMin(const PresburgerSpace &domainSpace, unsigned numOutputs)
: lexmin(domainSpace, numOutputs),
unboundedDomain(PresburgerSet::getEmpty(domainSpace)) {}
SymbolicLexMin(const PresburgerSpace &space)
: lexmin(space),
unboundedDomain(PresburgerSet::getEmpty(space.getDomainSpace())) {}
/// This maps assignments of symbols to the corresponding lexmin.
/// Takes no value when no integer sample exists for the assignment or if the

View File

@ -118,7 +118,7 @@ public:
DivisionRepr(unsigned numVars, unsigned numDivs)
: dividends(numDivs, numVars + 1), denoms(numDivs, 0) {}
DivisionRepr(unsigned numVars) : dividends(numVars + 1, 0) {}
DivisionRepr(unsigned numVars) : dividends(0, numVars + 1) {}
unsigned getNumVars() const { return dividends.getNumColumns() - 1; }
unsigned getNumDivs() const { return dividends.getNumRows(); }
@ -142,16 +142,25 @@ public:
return dividends.getRow(i);
}
// For a given point containing values for each variable other than the
// division variables, try to find the values for each division variable from
// their division representation.
SmallVector<Optional<int64_t>, 4> divValuesAt(ArrayRef<int64_t> point) const;
// Get the `i^th` denominator.
unsigned &getDenom(unsigned i) { return denoms[i]; }
unsigned getDenom(unsigned i) const { return denoms[i]; }
ArrayRef<unsigned> getDenoms() const { return denoms; }
void setDividend(unsigned i, ArrayRef<int64_t> dividend) {
void setDiv(unsigned i, ArrayRef<int64_t> dividend, unsigned divisor) {
dividends.setRow(i, dividend);
denoms[i] = divisor;
}
void insertDiv(unsigned pos, ArrayRef<int64_t> dividend, unsigned divisor);
void insertDiv(unsigned pos, unsigned num = 1);
/// Removes duplicate divisions. On every possible duplicate division found,
/// `merge(i, j)`, where `i`, `j` are current index of the duplicate
/// divisions, is called and division at index `j` is merged into division at

View File

@ -238,6 +238,7 @@ SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const {
getVarKindEnd(VarKind::Domain));
// Compute the symbolic lexmin of the dims and locals, with the symbols being
// the actual symbols of this set.
// The resultant space of lexmin is the space of the relation itself.
SymbolicLexMin result =
SymbolicLexSimplex(*this,
IntegerPolyhedron(PresburgerSpace::getSetSpace(
@ -248,8 +249,8 @@ SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const {
// We want to return only the lexmin over the dims, so strip the locals from
// the computed lexmin.
result.lexmin.truncateOutput(result.lexmin.getNumOutputs() -
getNumLocalVars());
result.lexmin.removeOutputs(result.lexmin.getNumOutputs() - getNumLocalVars(),
result.lexmin.getNumOutputs());
return result;
}

View File

@ -192,10 +192,14 @@ void Matrix::fillRow(unsigned row, int64_t value) {
}
void Matrix::addToRow(unsigned sourceRow, unsigned targetRow, int64_t scale) {
addToRow(targetRow, getRow(sourceRow), scale);
}
void Matrix::addToRow(unsigned row, ArrayRef<int64_t> rowVec, int64_t scale) {
if (scale == 0)
return;
for (unsigned col = 0; col < nColumns; ++col)
at(targetRow, col) += scale * at(sourceRow, col);
at(row, col) += scale * rowVec[col];
}
void Matrix::addToColumn(unsigned sourceColumn, unsigned targetColumn,

View File

@ -12,11 +12,25 @@
using namespace mlir;
using namespace presburger;
void MultiAffineFunction::assertIsConsistent() const {
assert(space.getNumVars() - space.getNumRangeVars() + 1 ==
output.getNumColumns() &&
"Inconsistent number of output columns");
assert(space.getNumDomainVars() + space.getNumSymbolVars() ==
divs.getNumNonDivs() &&
"Inconsistent number of non-division variables in divs");
assert(space.getNumRangeVars() == output.getNumRows() &&
"Inconsistent number of output rows");
assert(space.getNumLocalVars() == divs.getNumDivs() &&
"Inconsistent number of divisions.");
assert(divs.hasAllReprs() && "All divisions should have a representation");
}
// Return the result of subtracting the two given vectors pointwise.
// The vectors must be of the same size.
// e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
ArrayRef<int64_t> vecB) {
static SmallVector<int64_t, 8> subtractExprs(ArrayRef<int64_t> vecA,
ArrayRef<int64_t> vecB) {
assert(vecA.size() == vecB.size() &&
"Cannot subtract vectors of differing lengths!");
SmallVector<int64_t, 8> result;
@ -27,152 +41,135 @@ static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
}
PresburgerSet PWMAFunction::getDomain() const {
PresburgerSet domain = PresburgerSet::getEmpty(getSpace());
for (const MultiAffineFunction &piece : pieces)
domain.unionInPlace(piece.getDomain());
PresburgerSet domain = PresburgerSet::getEmpty(getDomainSpace());
for (const Piece &piece : pieces)
domain.unionInPlace(piece.domain);
return domain;
}
Optional<SmallVector<int64_t, 8>>
void MultiAffineFunction::print(raw_ostream &os) const {
space.print(os);
os << "Division Representation:\n";
divs.print(os);
os << "Output:\n";
output.print(os);
}
SmallVector<int64_t, 8>
MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
assert(point.size() == domainSet.getNumDimAndSymbolVars() &&
assert(point.size() == getNumDomainVars() + getNumSymbolVars() &&
"Point has incorrect dimensionality!");
Optional<SmallVector<int64_t, 8>> maybeLocalValues =
getDomain().containsPointNoLocal(point);
if (!maybeLocalValues)
return {};
// The point lies in the domain, so we need to compute the output value.
SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
// The given point didn't include the values of locals which the output is a
// function of; we have computed one possible set of values and use them
// here. The function is not allowed to have local vars that take more than
// one possible value.
pointHomogenous.append(*maybeLocalValues);
// Get the division values at this point.
SmallVector<Optional<int64_t>, 8> divValues = divs.divValuesAt(point);
// The given point didn't include the values of the divs which the output is a
// function of; we have computed one possible set of values and use them here.
pointHomogenous.reserve(pointHomogenous.size() + divValues.size());
for (const Optional<int64_t> &divVal : divValues)
pointHomogenous.push_back(*divVal);
// The matrix `output` has an affine expression in the ith row, corresponding
// to the expression for the ith value in the output vector. The last column
// of the matrix contains the constant term. Let v be the input point with
// a 1 appended at the end. We can see that output * v gives the desired
// output vector.
pointHomogenous.emplace_back(1);
pointHomogenous.push_back(1);
SmallVector<int64_t, 8> result =
output.postMultiplyWithColumn(pointHomogenous);
assert(result.size() == getNumOutputs());
return result;
}
Optional<SmallVector<int64_t, 8>>
PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
assert(point.size() == getNumInputs() &&
"Point has incorrect dimensionality!");
for (const MultiAffineFunction &piece : pieces)
if (Optional<SmallVector<int64_t, 8>> output = piece.valueAt(point))
return output;
return {};
}
void MultiAffineFunction::print(raw_ostream &os) const {
os << "Domain:";
domainSet.print(os);
os << "Output:\n";
output.print(os);
os << "\n";
}
void MultiAffineFunction::dump() const { print(llvm::errs()); }
bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
return getDomainSpace().isCompatible(other.getDomainSpace()) &&
getDomain().isEqual(other.getDomain()) &&
isEqualWhereDomainsOverlap(other);
assert(space.isCompatible(other.space) &&
"Spaces should be compatible for equality check.");
return getAsRelation().isEqual(other.getAsRelation());
}
unsigned MultiAffineFunction::insertVar(VarKind kind, unsigned pos,
unsigned num) {
assert(kind != VarKind::Domain && "Domain has to be zero in a set");
unsigned absolutePos = domainSet.getVarKindOffset(kind) + pos;
output.insertColumns(absolutePos, num);
return domainSet.insertVar(kind, pos, num);
bool MultiAffineFunction::isEqual(const MultiAffineFunction &other,
const IntegerPolyhedron &domain) const {
assert(space.isCompatible(other.space) &&
"Spaces should be compatible for equality check.");
IntegerRelation restrictedThis = getAsRelation();
restrictedThis.intersectDomain(domain);
IntegerRelation restrictedOther = other.getAsRelation();
restrictedOther.intersectDomain(domain);
return restrictedThis.isEqual(restrictedOther);
}
void MultiAffineFunction::removeVarRange(VarKind kind, unsigned varStart,
unsigned varLimit) {
output.removeColumns(varStart + domainSet.getVarKindOffset(kind),
varLimit - varStart);
domainSet.removeVarRange(kind, varStart, varLimit);
bool MultiAffineFunction::isEqual(const MultiAffineFunction &other,
const PresburgerSet &domain) const {
assert(space.isCompatible(other.space) &&
"Spaces should be compatible for equality check.");
return llvm::all_of(domain.getAllDisjuncts(),
[&](const IntegerRelation &disjunct) {
return isEqual(other, IntegerPolyhedron(disjunct));
});
}
void MultiAffineFunction::truncateOutput(unsigned count) {
assert(count <= output.getNumRows());
output.resizeVertically(count);
void MultiAffineFunction::removeOutputs(unsigned start, unsigned end) {
assert(end <= getNumOutputs() && "Invalid range");
if (start >= end)
return;
space.removeVarRange(VarKind::Range, start, end);
output.removeRows(start, end - start);
}
void PWMAFunction::truncateOutput(unsigned count) {
assert(count <= numOutputs);
for (MultiAffineFunction &piece : pieces)
piece.truncateOutput(count);
numOutputs = count;
}
void MultiAffineFunction::mergeDivs(MultiAffineFunction &other) {
assert(space.isCompatible(other.space) && "Functions should be compatible");
void MultiAffineFunction::mergeLocalVars(MultiAffineFunction &other) {
// Merge output local vars of both functions without using division
// information i.e. append local vars of `other` to `this` and insert
// local vars of `this` to `other` at the start of it's local vars.
output.insertColumns(domainSet.getVarKindEnd(VarKind::Local),
other.domainSet.getNumLocalVars());
other.output.insertColumns(other.domainSet.getVarKindOffset(VarKind::Local),
domainSet.getNumLocalVars());
unsigned nDivs = getNumDivs();
unsigned divOffset = divs.getDivOffset();
auto merge = [this, &other](unsigned i, unsigned j) -> bool {
// Merge local at position j into local at position i in function domain.
domainSet.eliminateRedundantLocalVar(i, j);
other.domainSet.eliminateRedundantLocalVar(i, j);
other.divs.insertDiv(0, nDivs);
unsigned localOffset = domainSet.getVarKindOffset(VarKind::Local);
SmallVector<int64_t, 8> div(other.divs.getNumVars() + 1);
for (unsigned i = 0; i < nDivs; ++i) {
// Zero fill.
std::fill(div.begin(), div.end(), 0);
// Fill div with dividend from `divs`. Do not fill the constant.
std::copy(divs.getDividend(i).begin(), divs.getDividend(i).end() - 1,
div.begin());
// Fill constant.
div.back() = divs.getDividend(i).back();
other.divs.setDiv(i, div, divs.getDenom(i));
}
// Merge local at position j into local at position i in output domain.
output.addToColumn(localOffset + j, localOffset + i, 1);
output.removeColumn(localOffset + j);
other.output.addToColumn(localOffset + j, localOffset + i, 1);
other.output.removeColumn(localOffset + j);
other.space.insertVar(VarKind::Local, 0, nDivs);
other.output.insertColumns(divOffset, nDivs);
auto merge = [&](unsigned i, unsigned j) {
// We only merge from local at pos j to local at pos i, where j > i.
if (i >= j)
return false;
// If i < nDivs, we are trying to merge duplicate divs in `this`. Since we
// do not want to merge duplicates in `this`, we ignore this call.
if (j < nDivs)
return false;
// Merge things in space and output.
other.space.removeVarRange(VarKind::Local, j, j + 1);
other.output.addToColumn(divOffset + i, divOffset + j, 1);
other.output.removeColumn(divOffset + j);
return true;
};
presburger::mergeLocalVars(domainSet, other.domainSet, merge);
}
other.divs.removeDuplicateDivs(merge);
bool MultiAffineFunction::isEqualWhereDomainsOverlap(
MultiAffineFunction other) const {
if (!getDomainSpace().isCompatible(other.getDomainSpace()))
return false;
unsigned newDivs = other.divs.getNumDivs() - nDivs;
// `commonFunc` has the same output as `this`.
MultiAffineFunction commonFunc = *this;
// After this merge, `commonFunc` and `other` have the same local vars; they
// are merged.
commonFunc.mergeLocalVars(other);
// After this, the domain of `commonFunc` will be the intersection of the
// domains of `this` and `other`.
commonFunc.domainSet.append(other.domainSet);
space.insertVar(VarKind::Local, nDivs, newDivs);
output.insertColumns(divOffset + nDivs, newDivs);
divs = other.divs;
// `commonDomainMatching` contains the subset of the common domain
// where the outputs of `this` and `other` match.
//
// We want to add constraints equating the outputs of `this` and `other`.
// However, `this` may have difference local vars from `other`, whereas we
// need both to have the same locals. Accordingly, we use `commonFunc.output`
// in place of `this->output`, since `commonFunc` has the same output but also
// has its locals merged.
IntegerPolyhedron commonDomainMatching = commonFunc.getDomain();
for (unsigned row = 0, e = getNumOutputs(); row < e; ++row)
commonDomainMatching.addEquality(
subtract(commonFunc.output.getRow(row), other.output.getRow(row)));
// If the whole common domain is a subset of commonDomainMatching, then they
// are equal and the two functions match on the whole common domain.
return commonFunc.getDomain().isSubsetOf(commonDomainMatching);
// Check consistency.
assertIsConsistent();
other.assertIsConsistent();
}
/// Two PWMAFunctions are equal if they have the same dimensionalities,
@ -188,89 +185,79 @@ bool PWMAFunction::isEqual(const PWMAFunction &other) const {
// overlap, they take the same output value. If `this` and `other` have the
// same domain (checked above), then this check passes iff the two functions
// have the same output at every point in the domain.
for (const MultiAffineFunction &aPiece : this->pieces)
for (const MultiAffineFunction &bPiece : other.pieces)
if (!aPiece.isEqualWhereDomainsOverlap(bPiece))
return false;
return true;
return llvm::all_of(this->pieces, [&other](const Piece &pieceA) {
return llvm::all_of(other.pieces, [&pieceA](const Piece &pieceB) {
PresburgerSet commonDomain = pieceA.domain.intersect(pieceB.domain);
return pieceA.output.isEqual(pieceB.output, commonDomain);
});
});
}
void PWMAFunction::addPiece(const MultiAffineFunction &piece) {
assert(space.isCompatible(piece.getDomainSpace()) &&
"Piece to be added is not compatible with this PWMAFunction!");
assert(piece.isConsistent() && "Piece is internally inconsistent!");
assert(this->getDomain()
.intersect(PresburgerSet(piece.getDomain()))
.isIntegerEmpty() &&
"New piece's domain overlaps with that of existing pieces!");
void PWMAFunction::addPiece(const Piece &piece) {
assert(piece.isConsistent() && "Piece should be consistent");
pieces.push_back(piece);
}
void PWMAFunction::addPiece(const IntegerPolyhedron &domain,
const Matrix &output) {
addPiece(MultiAffineFunction(domain, output));
}
void PWMAFunction::addPiece(const PresburgerSet &domain, const Matrix &output) {
for (const IntegerRelation &newDom : domain.getAllDisjuncts())
addPiece(IntegerPolyhedron(newDom), output);
}
void PWMAFunction::print(raw_ostream &os) const {
os << pieces.size() << " pieces:\n";
for (const MultiAffineFunction &piece : pieces)
piece.print(os);
space.print(os);
os << getNumPieces() << " pieces:\n";
for (const Piece &piece : pieces) {
os << "Domain of piece:\n";
piece.domain.print(os);
os << "Output of piece\n";
piece.output.print(os);
}
}
void PWMAFunction::dump() const { print(llvm::errs()); }
PWMAFunction PWMAFunction::unionFunction(
const PWMAFunction &func,
llvm::function_ref<PresburgerSet(MultiAffineFunction maf1,
MultiAffineFunction maf2)>
tiebreak) const {
llvm::function_ref<PresburgerSet(Piece maf1, Piece maf2)> tiebreak) const {
assert(getNumOutputs() == func.getNumOutputs() &&
"Number of outputs of functions should be same.");
"Ranges of functions should be same.");
assert(getSpace().isCompatible(func.getSpace()) &&
"Space is not compatible.");
// The algorithm used here is as follows:
// - Add the output of funcB for the part of the domain where both funcA and
// funcB are defined, and `tiebreak` chooses the output of funcB.
// - Add the output of funcA, where funcB is not defined or `tiebreak` chooses
// funcA over funcB.
// - Add the output of funcB, where funcA is not defined.
// - Add the output of pieceB for the part of the domain where both pieceA and
// pieceB are defined, and `tiebreak` chooses the output of pieceB.
// - Add the output of pieceA, where pieceB is not defined or `tiebreak`
// chooses
// pieceA over pieceB.
// - Add the output of pieceB, where pieceA is not defined.
// Add parts of the common domain where funcB's output is used. Also
// add all the parts where funcA's output is used, both common and non-common.
PWMAFunction result(getSpace(), getNumOutputs());
for (const MultiAffineFunction &funcA : pieces) {
PresburgerSet dom(funcA.getDomain());
for (const MultiAffineFunction &funcB : func.pieces) {
PresburgerSet better = tiebreak(funcB, funcA);
// Add the output of funcB, where it is better than output of funcA.
// Add parts of the common domain where pieceB's output is used. Also
// add all the parts where pieceA's output is used, both common and
// non-common.
PWMAFunction result(getSpace());
for (const Piece &pieceA : pieces) {
PresburgerSet dom(pieceA.domain);
for (const Piece &pieceB : func.pieces) {
PresburgerSet better = tiebreak(pieceB, pieceA);
// Add the output of pieceB, where it is better than output of pieceA.
// The disjuncts in "better" will be disjoint as tiebreak should gurantee
// that.
result.addPiece(better, funcB.getOutputMatrix());
result.addPiece({better, pieceB.output});
dom = dom.subtract(better);
}
// Add output of funcA, where it is better than funcB, or funcB is not
// Add output of pieceA, where it is better than pieceB, or pieceB is not
// defined.
//
// `dom` here is guranteed to be disjoint from already added pieces
// because because the pieces added before are either:
// - Subsets of the domain of other MAFs in `this`, which are guranteed
// to be disjoint from `dom`, or
// - They are one of the pieces added for `funcB`, and we have been
// - They are one of the pieces added for `pieceB`, and we have been
// subtracting all such pieces from `dom`, so `dom` is disjoint from those
// pieces as well.
result.addPiece(dom, funcA.getOutputMatrix());
result.addPiece({dom, pieceA.output});
}
// Add parts of funcB which are not shared with funcA.
// Add parts of pieceB which are not shared with pieceA.
PresburgerSet dom = getDomain();
for (const MultiAffineFunction &funcB : func.pieces)
result.addPiece(funcB.getDomain().subtract(dom), funcB.getOutputMatrix());
for (const Piece &pieceB : func.pieces)
result.addPiece({pieceB.domain.subtract(dom), pieceB.output});
return result;
}
@ -280,21 +267,19 @@ PWMAFunction PWMAFunction::unionFunction(
/// taking the lexicographically smaller output and otherwise, by taking the
/// lexicographically larger output.
template <bool lexMin>
static PresburgerSet tiebreakLex(const MultiAffineFunction &mafA,
const MultiAffineFunction &mafB) {
static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA,
const PWMAFunction::Piece &pieceB) {
// TODO: Support local variables here.
assert(mafA.getDomainSpace().isCompatible(mafB.getDomainSpace()) &&
"Domain spaces should be compatible.");
assert(mafA.getNumOutputs() == mafB.getNumOutputs() &&
"Number of outputs of both functions should be same.");
assert(mafA.getDomain().getNumLocalVars() == 0 &&
assert(pieceA.output.getSpace().isCompatible(pieceB.output.getSpace()) &&
"Pieces should be compatible");
assert(pieceA.domain.getSpace().getNumLocalVars() == 0 &&
"Local variables are not supported yet.");
PresburgerSpace compatibleSpace = mafA.getDomain().getSpaceWithoutLocals();
const PresburgerSpace &space = mafA.getDomain().getSpace();
PresburgerSpace compatibleSpace = pieceA.domain.getSpace();
const PresburgerSpace &space = pieceA.domain.getSpace();
// We first create the set `result`, corresponding to the set where output
// of mafA is lexicographically larger/smaller than mafB. This is done by
// of pieceA is lexicographically larger/smaller than pieceB. This is done by
// creating a PresburgerSet with the following constraints:
//
// (outA[0] > outB[0]) U
@ -312,14 +297,15 @@ static PresburgerSet tiebreakLex(const MultiAffineFunction &mafA,
// ...
// (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
PresburgerSet result = PresburgerSet::getEmpty(compatibleSpace);
IntegerPolyhedron levelSet(/*numReservedInequalities=*/1,
/*numReservedEqualities=*/mafA.getNumOutputs(),
/*numReservedCols=*/space.getNumVars() + 1, space);
for (unsigned level = 0; level < mafA.getNumOutputs(); ++level) {
IntegerPolyhedron levelSet(
/*numReservedInequalities=*/1,
/*numReservedEqualities=*/pieceA.output.getNumOutputs(),
/*numReservedCols=*/space.getNumVars() + 1, space);
for (unsigned level = 0; level < pieceA.output.getNumOutputs(); ++level) {
// Create the expression `outA - outB` for this level.
SmallVector<int64_t, 8> subExpr =
subtract(mafA.getOutputExpr(level), mafB.getOutputExpr(level));
SmallVector<int64_t, 8> subExpr = subtractExprs(
pieceA.output.getOutputExpr(level), pieceB.output.getOutputExpr(level));
if (lexMin) {
// For lexMin, we add an upper bound of -1:
@ -343,10 +329,9 @@ static PresburgerSet tiebreakLex(const MultiAffineFunction &mafA,
levelSet.addEquality(subExpr);
}
// We then intersect `result` with the domain of mafA and mafB, to only
// We then intersect `result` with the domain of pieceA and pieceB, to only
// tiebreak on the domain where both are defined.
result = result.intersect(PresburgerSet(mafA.getDomain()))
.intersect(PresburgerSet(mafB.getDomain()));
result = result.intersect(pieceA.domain).intersect(pieceB.domain);
return result;
}
@ -358,3 +343,93 @@ PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) {
PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) {
return unionFunction(func, tiebreakLex</*lexMin=*/false>);
}
void MultiAffineFunction::subtract(const MultiAffineFunction &other) {
assert(space.isCompatible(other.space) &&
"Spaces should be compatible for subtraction.");
MultiAffineFunction copyOther = other;
mergeDivs(copyOther);
for (unsigned i = 0, e = getNumOutputs(); i < e; ++i)
output.addToRow(i, copyOther.getOutputExpr(i), -1);
// Check consistency.
assertIsConsistent();
}
/// Adds division constraints corresponding to local variables, given a
/// relation and division representations of the local variables in the
/// relation.
static void addDivisionConstraints(IntegerRelation &rel,
const DivisionRepr &divs) {
assert(divs.hasAllReprs() &&
"All divisions in divs should have a representation");
assert(rel.getNumVars() == divs.getNumVars() &&
"Relation and divs should have the same number of vars");
assert(rel.getNumLocalVars() == divs.getNumDivs() &&
"Relation and divs should have the same number of local vars");
for (unsigned i = 0, e = divs.getNumDivs(); i < e; ++i) {
rel.addInequality(getDivUpperBound(divs.getDividend(i), divs.getDenom(i),
divs.getDivOffset() + i));
rel.addInequality(getDivLowerBound(divs.getDividend(i), divs.getDenom(i),
divs.getDivOffset() + i));
}
}
IntegerRelation MultiAffineFunction::getAsRelation() const {
// Create a relation corressponding to the input space plus the divisions
// used in outputs.
IntegerRelation result(PresburgerSpace::getRelationSpace(
space.getNumDomainVars(), 0, space.getNumSymbolVars(),
space.getNumLocalVars()));
// Add division constraints corresponding to divisions used in outputs.
addDivisionConstraints(result, divs);
// The outputs are represented as range variables in the relation. We add
// range variables for the outputs.
result.insertVar(VarKind::Range, 0, getNumOutputs());
// Add equalities such that the i^th range variable is equal to the i^th
// output expression.
SmallVector<int64_t, 8> eq(result.getNumCols());
for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) {
// TODO: Add functions to get VarKind offsets in output in MAF and use them
// here.
// The output expression does not contain range variables, while the
// equality does. So, we need to copy all variables and mark all range
// variables as 0 in the equality.
ArrayRef<int64_t> expr = getOutputExpr(i);
// Copy domain variables in `expr` to domain variables in `eq`.
std::copy(expr.begin(), expr.begin() + getNumDomainVars(), eq.begin());
// Fill the range variables in `eq` as zero.
std::fill(eq.begin() + result.getVarKindOffset(VarKind::Range),
eq.begin() + result.getVarKindEnd(VarKind::Range), 0);
// Copy remaining variables in `expr` to the remaining variables in `eq`.
std::copy(expr.begin() + getNumDomainVars(), expr.end(),
eq.begin() + result.getVarKindEnd(VarKind::Range));
// Set the i^th range var to -1 in `eq` to equate the output expression to
// this range var.
eq[result.getVarKindOffset(VarKind::Range) + i] = -1;
// Add the equality `rangeVar_i = output[i]`.
result.addEquality(eq);
}
return result;
}
void PWMAFunction::removeOutputs(unsigned start, unsigned end) {
space.removeVarRange(VarKind::Range, start, end);
for (Piece &piece : pieces)
piece.output.removeOutputs(start, end);
}
Optional<SmallVector<int64_t, 8>>
PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
assert(point.size() == getNumDomainVars() + getNumSymbolVars());
for (const Piece &piece : pieces)
if (piece.domain.containsPoint(point))
return piece.output.valueAt(point);
return None;
}

View File

@ -13,6 +13,15 @@
using namespace mlir;
using namespace presburger;
PresburgerSpace PresburgerSpace::getDomainSpace() const {
// TODO: Preserve identifiers here.
return PresburgerSpace::getSetSpace(numDomain, numSymbols, numLocals);
}
PresburgerSpace PresburgerSpace::getRangeSpace() const {
return PresburgerSpace::getSetSpace(numRange, numSymbols, numLocals);
}
unsigned PresburgerSpace::getNumVarKind(VarKind kind) const {
if (kind == VarKind::Domain)
return getNumDomainVars();

View File

@ -466,7 +466,14 @@ void SymbolicLexSimplex::recordOutput(SymbolicLexMin &result) const {
}
output.appendExtraRow(sample);
}
result.lexmin.addPiece(domainPoly, output);
// Store the output in a MultiAffineFunction and add it the result.
PresburgerSpace funcSpace = result.lexmin.getSpace();
funcSpace.insertVar(VarKind::Local, 0, domainPoly.getNumLocalVars());
result.lexmin.addPiece(
{PresburgerSet(domainPoly),
MultiAffineFunction(funcSpace, output, domainPoly.getLocalReprs())});
}
Optional<unsigned> SymbolicLexSimplex::maybeGetAlwaysViolatedRow() {
@ -508,7 +515,10 @@ LogicalResult SymbolicLexSimplex::doNonBranchingPivots() {
}
SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
SymbolicLexMin result(domainPoly.getSpace(), var.size() - nSymbol);
SymbolicLexMin result(PresburgerSpace::getRelationSpace(
/*numDomain=*/domainPoly.getNumDimVars(),
/*numRange=*/var.size() - nSymbol,
/*numSymbols=*/domainPoly.getNumSymbolVars()));
/// The algorithm is more naturally expressed recursively, but we implement
/// it iteratively here to avoid potential issues with stack overflows in the

View File

@ -16,6 +16,8 @@
#include "mlir/Support/MathExtras.h"
#include <numeric>
#include <numeric>
using namespace mlir;
using namespace presburger;
@ -280,10 +282,8 @@ void presburger::mergeLocalVars(
DivisionRepr divsA = relA.getLocalReprs();
DivisionRepr divsB = relB.getLocalReprs();
for (unsigned i = initLocals, e = divsB.getNumDivs(); i < e; ++i) {
divsA.setDividend(i, divsB.getDividend(i));
divsA.getDenom(i) = divsB.getDenom(i);
}
for (unsigned i = initLocals, e = divsB.getNumDivs(); i < e; ++i)
divsA.setDiv(i, divsB.getDividend(i), divsB.getDenom(i));
// Remove duplicate divisions from divsA. The removing duplicate divisions
// call, calls `merge` to effectively merge divisions in relA and relB.
@ -357,6 +357,55 @@ SmallVector<int64_t, 8> presburger::getComplementIneq(ArrayRef<int64_t> ineq) {
return coeffs;
}
SmallVector<Optional<int64_t>, 4>
DivisionRepr::divValuesAt(ArrayRef<int64_t> point) const {
assert(point.size() == getNumNonDivs() && "Incorrect point size");
SmallVector<Optional<int64_t>, 4> divValues(getNumDivs(), None);
bool changed = true;
while (changed) {
changed = false;
for (unsigned i = 0, e = getNumDivs(); i < e; ++i) {
// If division value is found, continue;
if (divValues[i])
continue;
ArrayRef<int64_t> dividend = getDividend(i);
int64_t divVal = 0;
// Check if we have all the division values required for this division.
unsigned j, f;
for (j = 0, f = getNumDivs(); j < f; ++j) {
if (dividend[getDivOffset() + j] == 0)
continue;
// Division value required, but not found yet.
if (!divValues[j])
break;
divVal += dividend[getDivOffset() + j] * divValues[j].value();
}
// We have some division values that are still not found, but are required
// to find the value of this division.
if (j < f)
continue;
// Fill remaining values.
divVal = std::inner_product(point.begin(), point.end(), dividend.begin(),
divVal);
// Add constant.
divVal += dividend.back();
// Take floor division with denominator.
divVal = floorDiv(divVal, denoms[i]);
// Set div value and continue.
divValues[i] = divVal;
changed = true;
}
}
return divValues;
}
void DivisionRepr::removeDuplicateDivs(
llvm::function_ref<bool(unsigned i, unsigned j)> merge) {
@ -402,6 +451,23 @@ void DivisionRepr::removeDuplicateDivs(
}
}
void DivisionRepr::insertDiv(unsigned pos, ArrayRef<int64_t> dividend,
unsigned divisor) {
assert(pos <= getNumDivs() && "Invalid insertion position");
assert(dividend.size() == getNumVars() + 1 && "Incorrect dividend size");
dividends.appendExtraRow(dividend);
denoms.insert(denoms.begin() + pos, divisor);
dividends.insertColumn(getDivOffset() + pos);
}
void DivisionRepr::insertDiv(unsigned pos, unsigned num) {
assert(pos <= getNumDivs() && "Invalid insertion position");
dividends.insertColumns(getDivOffset() + pos, num);
dividends.insertRows(pos, num);
denoms.insert(denoms.begin() + pos, num, 0);
}
void DivisionRepr::print(raw_ostream &os) const {
os << "Dividends:\n";
dividends.print(os);

View File

@ -1171,7 +1171,7 @@ void expectSymbolicIntegerLexMin(
ASSERT_NE(poly.getNumSymbolVars(), 0u);
PWMAFunction expectedLexmin =
parsePWMAF(/*numInputs=*/poly.getNumSymbolVars(),
parsePWMAF(/*numInputs=*/0,
/*numOutputs=*/poly.getNumDimVars(), expectedLexminRepr,
/*numSymbols=*/poly.getNumSymbolVars());

View File

@ -130,7 +130,7 @@ TEST(IntegerRelationTest, symbolicLexmin) {
.findSymbolicIntegerLexMin();
PWMAFunction expectedLexmin =
parsePWMAF(/*numInputs=*/2,
parsePWMAF(/*numInputs=*/1,
/*numOutputs=*/1,
{
{"(a)[b] : (a - b >= 0)", {{1, 0, 0}}}, // a

View File

@ -73,14 +73,20 @@ inline PWMAFunction parsePWMAF(
unsigned numSymbols = 0) {
static MLIRContext context;
PWMAFunction result(PresburgerSpace::getSetSpace(
/*numDims=*/numInputs - numSymbols, numSymbols),
numOutputs);
PWMAFunction result(
PresburgerSpace::getRelationSpace(numInputs, numOutputs, numSymbols));
for (const auto &pair : data) {
IntegerPolyhedron domain = parsePoly(pair.first);
PresburgerSpace funcSpace = result.getSpace();
funcSpace.insertVar(VarKind::Local, 0, domain.getNumLocalVars());
result.addPiece(
domain, makeMatrix(numOutputs, domain.getNumVars() + 1, pair.second));
{PresburgerSet(domain),
MultiAffineFunction(
funcSpace,
makeMatrix(numOutputs, domain.getNumVars() + 1, pair.second),
domain.getLocalReprs())});
}
return result;
}