forked from OSchip/llvm-project
[MLIR][Presburger] Add support for piece-wise multi-affine functions
Add the class MultiAffineFunction which represents functions whose domain is an IntegerPolyhedron and which produce an output given by a tuple of affine expressions in the IntegerPolyhedron's ids. Also add support for piece-wise MultiAffineFunctions, which are defined on a union of IntegerPolyhedrons, and may have different output affine expressions on each IntegerPolyhedron. Thus the function is affine on each individual IntegerPolyhedron piece in the domain. This is part of a series of patches leading up to parametric integer programming. Depends on D118778. Reviewed By: Groverkss Differential Revision: https://reviews.llvm.org/D118779
This commit is contained in:
parent
570471199b
commit
d5a2944219
|
@ -56,6 +56,7 @@ public:
|
||||||
enum class Kind {
|
enum class Kind {
|
||||||
FlatAffineConstraints,
|
FlatAffineConstraints,
|
||||||
FlatAffineValueConstraints,
|
FlatAffineValueConstraints,
|
||||||
|
MultiAffineFunction,
|
||||||
IntegerPolyhedron
|
IntegerPolyhedron
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -194,6 +195,11 @@ public:
|
||||||
/// Adds an equality from the coefficients specified in `eq`.
|
/// Adds an equality from the coefficients specified in `eq`.
|
||||||
void addEquality(ArrayRef<int64_t> eq);
|
void addEquality(ArrayRef<int64_t> eq);
|
||||||
|
|
||||||
|
/// Eliminate the `posB^th` local identifier, replacing every instance of it
|
||||||
|
/// with the `posA^th` local identifier. This should be used when the two
|
||||||
|
/// local variables are known to always take the same values.
|
||||||
|
virtual void eliminateRedundantLocalId(unsigned posA, unsigned posB);
|
||||||
|
|
||||||
/// Removes identifiers of the specified kind with the specified pos (or
|
/// Removes identifiers of the specified kind with the specified pos (or
|
||||||
/// within the specified range) from the system. The specified location is
|
/// within the specified range) from the system. The specified location is
|
||||||
/// relative to the first identifier of the specified kind.
|
/// relative to the first identifier of the specified kind.
|
||||||
|
@ -273,6 +279,9 @@ public:
|
||||||
|
|
||||||
/// Returns true if the given point satisfies the constraints, or false
|
/// Returns true if the given point satisfies the constraints, or false
|
||||||
/// otherwise.
|
/// otherwise.
|
||||||
|
///
|
||||||
|
/// Note: currently, if the polyhedron contains local ids, the values of
|
||||||
|
/// the local ids must also be provided.
|
||||||
bool containsPoint(ArrayRef<int64_t> point) const;
|
bool containsPoint(ArrayRef<int64_t> point) const;
|
||||||
|
|
||||||
/// Find equality and pairs of inequality contraints identified by their
|
/// Find equality and pairs of inequality contraints identified by their
|
||||||
|
|
|
@ -0,0 +1,195 @@
|
||||||
|
//===- PWMAFunction.h - MLIR PWMAFunction Class------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Support for piece-wise multi-affine functions. These are functions that are
|
||||||
|
// defined on a domain that is a union of IntegerPolyhedrons, and on each domain
|
||||||
|
// the value of the function is a tuple of integers, with each value in the
|
||||||
|
// tuple being an affine expression in the ids of the IntegerPolyhedron.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
|
||||||
|
#define MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
|
||||||
|
|
||||||
|
#include "mlir/Analysis/Presburger/IntegerPolyhedron.h"
|
||||||
|
#include "mlir/Analysis/Presburger/PresburgerSet.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
/// This class represents a multi-affine function whose domain is given by an
|
||||||
|
/// IntegerPolyhedron. This can be thought of as an IntegerPolyhedron with a
|
||||||
|
/// tuple of integer values attached to every point in the polyhedron, with the
|
||||||
|
/// value of each element of the tuple given by an affine expression in the ids
|
||||||
|
/// of the polyhedron. For example we could have the domain
|
||||||
|
///
|
||||||
|
/// (x, y) : (x >= 5, y >= x)
|
||||||
|
///
|
||||||
|
/// and a tuple of three integers defined at every point in the polyhedron:
|
||||||
|
///
|
||||||
|
/// (x, y) -> (x + 2, 2*x - 3y + 5, 2*x + y).
|
||||||
|
///
|
||||||
|
/// In this way every point in the polyhedron has a tuple of integers associated
|
||||||
|
/// with it. If the integer polyhedron has local ids, then the output
|
||||||
|
/// expressions can use them as well. The output expressions are represented as
|
||||||
|
/// a matrix with one row for every element in the output vector one column for
|
||||||
|
/// each id, and an extra column at the end for the constant term.
|
||||||
|
///
|
||||||
|
/// Checking equality of two such functions is supported, as well as finding the
|
||||||
|
/// value of the function at a specified point. Note that local ids in the
|
||||||
|
/// domain are not yet supported for finding the value at a point.
|
||||||
|
class MultiAffineFunction : protected IntegerPolyhedron {
|
||||||
|
public:
|
||||||
|
/// We use protected inheritance to avoid inheriting the whole public
|
||||||
|
/// interface of IntegerPolyhedron. These using declarations explicitly make
|
||||||
|
/// only the relevant functions part of the public interface.
|
||||||
|
using IntegerPolyhedron::getNumDimAndSymbolIds;
|
||||||
|
using IntegerPolyhedron::getNumDimIds;
|
||||||
|
using IntegerPolyhedron::getNumIds;
|
||||||
|
using IntegerPolyhedron::getNumLocalIds;
|
||||||
|
using IntegerPolyhedron::getNumSymbolIds;
|
||||||
|
|
||||||
|
MultiAffineFunction(const IntegerPolyhedron &domain, const Matrix &output)
|
||||||
|
: IntegerPolyhedron(domain), output(output) {}
|
||||||
|
MultiAffineFunction(const Matrix &output, unsigned numDims,
|
||||||
|
unsigned numSymbols = 0, unsigned numLocals = 0)
|
||||||
|
: IntegerPolyhedron(numDims, numSymbols, numLocals), output(output) {}
|
||||||
|
|
||||||
|
~MultiAffineFunction() override = default;
|
||||||
|
Kind getKind() const override { return Kind::MultiAffineFunction; }
|
||||||
|
bool classof(const IntegerPolyhedron *poly) const {
|
||||||
|
return poly->getKind() == Kind::MultiAffineFunction;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned getNumInputs() const { return getNumDimAndSymbolIds(); }
|
||||||
|
unsigned getNumOutputs() const { return output.getNumRows(); }
|
||||||
|
bool isConsistent() const { return output.getNumColumns() == numIds + 1; }
|
||||||
|
const IntegerPolyhedron &getDomain() const { return *this; }
|
||||||
|
|
||||||
|
bool hasCompatibleDimensions(const MultiAffineFunction &f) const;
|
||||||
|
|
||||||
|
/// Insert `num` identifiers of the specified kind at position `pos`.
|
||||||
|
/// Positions are relative to the kind of identifier. The coefficient columns
|
||||||
|
/// corresponding to the added identifiers are initialized to zero. Return the
|
||||||
|
/// absolute column position (i.e., not relative to the kind of identifier)
|
||||||
|
/// of the first added identifier.
|
||||||
|
unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
|
||||||
|
|
||||||
|
/// Swap the posA^th identifier with the posB^th identifier.
|
||||||
|
void swapId(unsigned posA, unsigned posB) override;
|
||||||
|
|
||||||
|
/// Remove the specified range of ids.
|
||||||
|
void removeIdRange(unsigned idStart, unsigned idLimit) override;
|
||||||
|
|
||||||
|
/// Eliminate the `posB^th` local identifier, replacing every instance of it
|
||||||
|
/// with the `posA^th` local identifier. This should be used when the two
|
||||||
|
/// local variables are known to always take the same values.
|
||||||
|
void eliminateRedundantLocalId(unsigned posA, unsigned posB) override;
|
||||||
|
|
||||||
|
/// Return whether the outputs of `this` and `other` agree wherever both
|
||||||
|
/// functions are defined, i.e., the outputs should be equal for all points in
|
||||||
|
/// the intersection of the domains.
|
||||||
|
bool isEqualWhereDomainsOverlap(MultiAffineFunction other) const;
|
||||||
|
|
||||||
|
/// Return whether the `this` and `other` are equal. This is the case if
|
||||||
|
/// they lie in the same space, i.e. have the same dimensions, and their
|
||||||
|
/// domains are identical and their outputs are equal on their domain.
|
||||||
|
bool isEqual(const MultiAffineFunction &other) const;
|
||||||
|
|
||||||
|
/// Get the value of the function at the specified point. If the point lies
|
||||||
|
/// outside the domain, an empty optional is returned.
|
||||||
|
///
|
||||||
|
/// Note: domains with local ids are not yet supported, and will assert-fail.
|
||||||
|
Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
|
||||||
|
|
||||||
|
void print(raw_ostream &os) const;
|
||||||
|
|
||||||
|
void dump() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// The function's output is a tuple of integers, with the ith element of the
|
||||||
|
/// tuple defined by the affine expression given by the ith row of this output
|
||||||
|
/// matrix.
|
||||||
|
Matrix output;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// This class represents a piece-wise MultiAffineFunction. This can be thought
|
||||||
|
/// of as a list of MultiAffineFunction with disjoint domains, with each having
|
||||||
|
/// their own affine expressions for their output tuples. For example, we could
|
||||||
|
/// have a function with two input variables (x, y), defined as
|
||||||
|
///
|
||||||
|
/// f(x, y) = (2*x + y, y - 4) if x >= 0, y >= 0
|
||||||
|
/// = (-2*x + y, y + 4) if x < 0, y < 0
|
||||||
|
/// = (4, 1) if x < 0, y >= 0
|
||||||
|
///
|
||||||
|
/// Note that the domains all have to be *disjoint*. Otherwise, the behaviour of
|
||||||
|
/// this class is undefined. The domains need not cover all possible points;
|
||||||
|
/// this represents a partial function and so could be undefined at some points.
|
||||||
|
///
|
||||||
|
/// As in PresburgerSets, the input ids are partitioned into dimension ids and
|
||||||
|
/// symbolic ids.
|
||||||
|
///
|
||||||
|
/// Support is provided to compare equality of two such functions as well as
|
||||||
|
/// finding the value of the function at a point. Note that local ids in the
|
||||||
|
/// piece are not supported for the latter.
|
||||||
|
class PWMAFunction {
|
||||||
|
public:
|
||||||
|
PWMAFunction(unsigned numDims, unsigned numSymbols, unsigned numOutputs)
|
||||||
|
: numDims(numDims), numSymbols(numSymbols), numOutputs(numOutputs) {
|
||||||
|
assert(numOutputs >= 1 && "The function must output something!");
|
||||||
|
}
|
||||||
|
|
||||||
|
void addPiece(const MultiAffineFunction &piece);
|
||||||
|
void addPiece(const IntegerPolyhedron &domain, const Matrix &output);
|
||||||
|
|
||||||
|
const MultiAffineFunction &getPiece(unsigned i) const { return pieces[i]; }
|
||||||
|
unsigned getNumPieces() const { return pieces.size(); }
|
||||||
|
unsigned getNumOutputs() const { return numOutputs; }
|
||||||
|
unsigned getNumInputs() const { return numDims + numSymbols; }
|
||||||
|
unsigned getNumDimIds() const { return numDims; }
|
||||||
|
unsigned getNumSymbolIds() const { return numSymbols; }
|
||||||
|
MultiAffineFunction &getPiece(unsigned i) { return pieces[i]; }
|
||||||
|
|
||||||
|
/// Return the domain of this piece-wise MultiAffineFunction. This is the
|
||||||
|
/// union of the domains of all the pieces.
|
||||||
|
PresburgerSet getDomain() const;
|
||||||
|
|
||||||
|
/// Check whether the `this` and the given function have compatible
|
||||||
|
/// dimensions, i.e., the same number of dimension inputs, symbol inputs, and
|
||||||
|
/// outputs.
|
||||||
|
bool hasCompatibleDimensions(const MultiAffineFunction &f) const;
|
||||||
|
bool hasCompatibleDimensions(const PWMAFunction &f) const;
|
||||||
|
|
||||||
|
/// Return the value at the specified point and an empty optional if the
|
||||||
|
/// point does not lie in the domain.
|
||||||
|
///
|
||||||
|
/// Note: domains with local ids are not yet supported, and will assert-fail.
|
||||||
|
Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
|
||||||
|
|
||||||
|
/// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether
|
||||||
|
/// they have the same dimensions, the same domain and they take the same
|
||||||
|
/// value at every point in the domain.
|
||||||
|
bool isEqual(const PWMAFunction &other) const;
|
||||||
|
|
||||||
|
void print(raw_ostream &os) const;
|
||||||
|
void dump() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// The list of pieces in this piece-wise MultiAffineFunction.
|
||||||
|
SmallVector<MultiAffineFunction, 4> pieces;
|
||||||
|
|
||||||
|
/// The number of dimensions ids in the domains.
|
||||||
|
unsigned numDims;
|
||||||
|
/// The number of symbol ids in the domains.
|
||||||
|
unsigned numSymbols;
|
||||||
|
/// The number of output ids.
|
||||||
|
unsigned numOutputs;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
|
|
@ -3,6 +3,7 @@ add_mlir_library(MLIRPresburger
|
||||||
LinearTransform.cpp
|
LinearTransform.cpp
|
||||||
Matrix.cpp
|
Matrix.cpp
|
||||||
PresburgerSet.cpp
|
PresburgerSet.cpp
|
||||||
|
PWMAFunction.cpp
|
||||||
Simplex.cpp
|
Simplex.cpp
|
||||||
Utils.cpp
|
Utils.cpp
|
||||||
|
|
||||||
|
|
|
@ -1065,24 +1065,17 @@ void IntegerPolyhedron::removeRedundantConstraints() {
|
||||||
equalities.resizeVertically(pos);
|
equalities.resizeVertically(pos);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Eliminate `pos2^th` local identifier, replacing its every instance with
|
void IntegerPolyhedron::eliminateRedundantLocalId(unsigned posA,
|
||||||
/// `pos1^th` local identifier. This function is intended to be used to remove
|
unsigned posB) {
|
||||||
/// redundancy when local variables at position `pos1` and `pos2` are restricted
|
assert(posA < getNumLocalIds() && "Invalid local id position");
|
||||||
/// to have the same value.
|
assert(posB < getNumLocalIds() && "Invalid local id position");
|
||||||
static void eliminateRedundantLocalId(IntegerPolyhedron &poly, unsigned pos1,
|
|
||||||
unsigned pos2) {
|
|
||||||
|
|
||||||
assert(pos1 < poly.getNumLocalIds() && "Invalid local id position");
|
unsigned localOffset = getIdKindOffset(IdKind::Local);
|
||||||
assert(pos2 < poly.getNumLocalIds() && "Invalid local id position");
|
posA += localOffset;
|
||||||
|
posB += localOffset;
|
||||||
unsigned localOffset = poly.getNumDimAndSymbolIds();
|
inequalities.addToColumn(posB, posA, 1);
|
||||||
pos1 += localOffset;
|
equalities.addToColumn(posB, posA, 1);
|
||||||
pos2 += localOffset;
|
removeId(posB);
|
||||||
for (unsigned i = 0, e = poly.getNumInequalities(); i < e; ++i)
|
|
||||||
poly.atIneq(i, pos1) += poly.atIneq(i, pos2);
|
|
||||||
for (unsigned i = 0, e = poly.getNumEqualities(); i < e; ++i)
|
|
||||||
poly.atEq(i, pos1) += poly.atEq(i, pos2);
|
|
||||||
poly.removeId(pos2);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Adds additional local ids to the sets such that they both have the union
|
/// Adds additional local ids to the sets such that they both have the union
|
||||||
|
@ -1129,8 +1122,8 @@ void IntegerPolyhedron::mergeLocalIds(IntegerPolyhedron &other) {
|
||||||
// Merge function that merges the local variables in both sets by treating
|
// Merge function that merges the local variables in both sets by treating
|
||||||
// them as the same identifier.
|
// them as the same identifier.
|
||||||
auto merge = [&polyA, &polyB](unsigned i, unsigned j) -> bool {
|
auto merge = [&polyA, &polyB](unsigned i, unsigned j) -> bool {
|
||||||
eliminateRedundantLocalId(polyA, i, j);
|
polyA.eliminateRedundantLocalId(i, j);
|
||||||
eliminateRedundantLocalId(polyB, i, j);
|
polyB.eliminateRedundantLocalId(i, j);
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,198 @@
|
||||||
|
//===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Analysis/Presburger/PWMAFunction.h"
|
||||||
|
#include "mlir/Analysis/Presburger/Simplex.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
// Return the result of subtracting the two given vectors pointwise.
|
||||||
|
// The vectors must be of the same size.
|
||||||
|
// e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
|
||||||
|
static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
|
||||||
|
ArrayRef<int64_t> vecB) {
|
||||||
|
assert(vecA.size() == vecB.size() &&
|
||||||
|
"Cannot subtract vectors of differing lengths!");
|
||||||
|
SmallVector<int64_t, 8> result;
|
||||||
|
result.reserve(vecA.size());
|
||||||
|
for (unsigned i = 0, e = vecA.size(); i < e; ++i)
|
||||||
|
result.push_back(vecA[i] - vecB[i]);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
PresburgerSet PWMAFunction::getDomain() const {
|
||||||
|
PresburgerSet domain =
|
||||||
|
PresburgerSet::getEmptySet(getNumDimIds(), getNumSymbolIds());
|
||||||
|
for (const MultiAffineFunction &piece : pieces)
|
||||||
|
domain.unionPolyInPlace(piece.getDomain());
|
||||||
|
return domain;
|
||||||
|
}
|
||||||
|
|
||||||
|
Optional<SmallVector<int64_t, 8>>
|
||||||
|
MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
|
||||||
|
assert(getNumLocalIds() == 0 && "Local ids are not yet supported!");
|
||||||
|
assert(point.size() == getNumIds() && "Point has incorrect dimensionality!");
|
||||||
|
|
||||||
|
if (!getDomain().containsPoint(point))
|
||||||
|
return {};
|
||||||
|
|
||||||
|
// The point lies in the domain, so we need to compute the output value.
|
||||||
|
// The matrix `output` has an affine expression in the ith row, corresponding
|
||||||
|
// to the expression for the ith value in the output vector. The last column
|
||||||
|
// of the matrix contains the constant term. Let v be the input point with
|
||||||
|
// a 1 appended at the end. We can see that output * v gives the desired
|
||||||
|
// output vector.
|
||||||
|
SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
|
||||||
|
pointHomogenous.push_back(1);
|
||||||
|
SmallVector<int64_t, 8> result =
|
||||||
|
output.postMultiplyWithColumn(pointHomogenous);
|
||||||
|
assert(result.size() == getNumOutputs());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
Optional<SmallVector<int64_t, 8>>
|
||||||
|
PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
|
||||||
|
assert(point.size() == getNumInputs() &&
|
||||||
|
"Point has incorrect dimensionality!");
|
||||||
|
for (const MultiAffineFunction &piece : pieces)
|
||||||
|
if (Optional<SmallVector<int64_t, 8>> output = piece.valueAt(point))
|
||||||
|
return output;
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
void MultiAffineFunction::print(raw_ostream &os) const {
|
||||||
|
os << "Domain:";
|
||||||
|
IntegerPolyhedron::print(os);
|
||||||
|
os << "Output:\n";
|
||||||
|
output.print(os);
|
||||||
|
os << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
void MultiAffineFunction::dump() const { print(llvm::errs()); }
|
||||||
|
|
||||||
|
bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
|
||||||
|
return hasCompatibleDimensions(other) &&
|
||||||
|
getDomain().isEqual(other.getDomain()) &&
|
||||||
|
isEqualWhereDomainsOverlap(other);
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned MultiAffineFunction::insertId(IdKind kind, unsigned pos,
|
||||||
|
unsigned num) {
|
||||||
|
unsigned absolutePos = getIdKindOffset(kind) + pos;
|
||||||
|
output.insertColumns(absolutePos, num);
|
||||||
|
return IntegerPolyhedron::insertId(kind, pos, num);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MultiAffineFunction::swapId(unsigned posA, unsigned posB) {
|
||||||
|
output.swapColumns(posA, posB);
|
||||||
|
IntegerPolyhedron::swapId(posA, posB);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MultiAffineFunction::removeIdRange(unsigned idStart, unsigned idLimit) {
|
||||||
|
output.removeColumns(idStart, idLimit - idStart);
|
||||||
|
IntegerPolyhedron::removeIdRange(idStart, idLimit);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA,
|
||||||
|
unsigned posB) {
|
||||||
|
output.addToColumn(posB, posA, /*scale=*/1);
|
||||||
|
IntegerPolyhedron::eliminateRedundantLocalId(posA, posB);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MultiAffineFunction::isEqualWhereDomainsOverlap(
|
||||||
|
MultiAffineFunction other) const {
|
||||||
|
if (!hasCompatibleDimensions(other))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// `commonFunc` has the same output as `this`.
|
||||||
|
MultiAffineFunction commonFunc = *this;
|
||||||
|
// After this merge, `commonFunc` and `other` have the same local ids; they
|
||||||
|
// are merged.
|
||||||
|
commonFunc.mergeLocalIds(other);
|
||||||
|
// After this, the domain of `commonFunc` will be the intersection of the
|
||||||
|
// domains of `this` and `other`.
|
||||||
|
commonFunc.IntegerPolyhedron::append(other);
|
||||||
|
|
||||||
|
// `commonDomainMatching` contains the subset of the common domain
|
||||||
|
// where the outputs of `this` and `other` match.
|
||||||
|
//
|
||||||
|
// We want to add constraints equating the outputs of `this` and `other`.
|
||||||
|
// However, `this` may have difference local ids from `other`, whereas we
|
||||||
|
// need both to have the same locals. Accordingly, we use `commonFunc.output`
|
||||||
|
// in place of `this->output`, since `commonFunc` has the same output but also
|
||||||
|
// has its locals merged.
|
||||||
|
IntegerPolyhedron commonDomainMatching = commonFunc.getDomain();
|
||||||
|
for (unsigned row = 0, e = getNumOutputs(); row < e; ++row)
|
||||||
|
commonDomainMatching.addEquality(
|
||||||
|
subtract(commonFunc.output.getRow(row), other.output.getRow(row)));
|
||||||
|
|
||||||
|
// If the whole common domain is a subset of commonDomainMatching, then they
|
||||||
|
// are equal and the two functions match on the whole common domain.
|
||||||
|
return commonFunc.getDomain().isSubsetOf(commonDomainMatching);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Two PWMAFunctions are equal if they have the same dimensionalities,
|
||||||
|
/// the same domain, and take the same value at every point in the domain.
|
||||||
|
bool PWMAFunction::isEqual(const PWMAFunction &other) const {
|
||||||
|
if (!hasCompatibleDimensions(other))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if (!this->getDomain().isEqual(other.getDomain()))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// Check if, whenever the domains of a piece of `this` and a piece of `other`
|
||||||
|
// overlap, they take the same output value. If `this` and `other` have the
|
||||||
|
// same domain (checked above), then this check passes iff the two functions
|
||||||
|
// have the same output at every point in the domain.
|
||||||
|
for (const MultiAffineFunction &aPiece : this->pieces)
|
||||||
|
for (const MultiAffineFunction &bPiece : other.pieces)
|
||||||
|
if (!aPiece.isEqualWhereDomainsOverlap(bPiece))
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void PWMAFunction::addPiece(const MultiAffineFunction &piece) {
|
||||||
|
assert(hasCompatibleDimensions(piece) &&
|
||||||
|
"Piece to be added is not compatible with this PWMAFunction!");
|
||||||
|
assert(piece.isConsistent() && "Piece is internally inconsistent!");
|
||||||
|
assert(this->getDomain()
|
||||||
|
.intersect(PresburgerSet(piece.getDomain()))
|
||||||
|
.isIntegerEmpty() &&
|
||||||
|
"New piece's domain overlaps with that of existing pieces!");
|
||||||
|
pieces.push_back(piece);
|
||||||
|
}
|
||||||
|
|
||||||
|
void PWMAFunction::addPiece(const IntegerPolyhedron &domain,
|
||||||
|
const Matrix &output) {
|
||||||
|
addPiece(MultiAffineFunction(domain, output));
|
||||||
|
}
|
||||||
|
|
||||||
|
void PWMAFunction::print(raw_ostream &os) const {
|
||||||
|
os << pieces.size() << " pieces:\n";
|
||||||
|
for (const MultiAffineFunction &piece : pieces)
|
||||||
|
piece.print(os);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The hasCompatibleDimensions functions don't check the number of local ids;
|
||||||
|
/// functions are still compatible if they have differing number of locals.
|
||||||
|
bool MultiAffineFunction::hasCompatibleDimensions(
|
||||||
|
const MultiAffineFunction &f) const {
|
||||||
|
return getNumDimIds() == f.getNumDimIds() &&
|
||||||
|
getNumSymbolIds() == f.getNumSymbolIds() &&
|
||||||
|
getNumOutputs() == f.getNumOutputs();
|
||||||
|
}
|
||||||
|
bool PWMAFunction::hasCompatibleDimensions(const MultiAffineFunction &f) const {
|
||||||
|
return getNumDimIds() == f.getNumDimIds() &&
|
||||||
|
getNumSymbolIds() == f.getNumSymbolIds() &&
|
||||||
|
getNumOutputs() == f.getNumOutputs();
|
||||||
|
}
|
||||||
|
bool PWMAFunction::hasCompatibleDimensions(const PWMAFunction &f) const {
|
||||||
|
return getNumDimIds() == f.getNumDimIds() &&
|
||||||
|
getNumSymbolIds() == f.getNumSymbolIds() &&
|
||||||
|
getNumOutputs() == f.getNumOutputs();
|
||||||
|
}
|
|
@ -3,6 +3,7 @@ add_mlir_unittest(MLIRPresburgerTests
|
||||||
LinearTransformTest.cpp
|
LinearTransformTest.cpp
|
||||||
MatrixTest.cpp
|
MatrixTest.cpp
|
||||||
PresburgerSetTest.cpp
|
PresburgerSetTest.cpp
|
||||||
|
PWMAFunctionTest.cpp
|
||||||
SimplexTest.cpp
|
SimplexTest.cpp
|
||||||
../../Dialect/Affine/Analysis/AffineStructuresParser.cpp
|
../../Dialect/Affine/Analysis/AffineStructuresParser.cpp
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,183 @@
|
||||||
|
//===- PWMAFunctionTest.cpp - Tests for PWMAFunction ----------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file contains tests for PWMAFunction.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Analysis/Presburger/PWMAFunction.h"
|
||||||
|
#include "../../Dialect/Affine/Analysis/AffineStructuresParser.h"
|
||||||
|
#include "mlir/Analysis/Presburger/PresburgerSet.h"
|
||||||
|
|
||||||
|
#include <gmock/gmock.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
using testing::ElementsAre;
|
||||||
|
|
||||||
|
/// Parses an IntegerPolyhedron from a StringRef. It is expected that the
|
||||||
|
/// string represents a valid IntegerSet, otherwise it will violate a gtest
|
||||||
|
/// assertion.
|
||||||
|
static IntegerPolyhedron parsePoly(StringRef str, MLIRContext *context) {
|
||||||
|
FailureOr<IntegerPolyhedron> poly = parseIntegerSetToFAC(str, context);
|
||||||
|
EXPECT_TRUE(succeeded(poly));
|
||||||
|
return *poly;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Matrix makeMatrix(unsigned numRow, unsigned numColumns,
|
||||||
|
ArrayRef<SmallVector<int64_t, 8>> matrix) {
|
||||||
|
Matrix results(numRow, numColumns);
|
||||||
|
assert(matrix.size() == numRow);
|
||||||
|
for (unsigned i = 0; i < numRow; ++i) {
|
||||||
|
assert(matrix[i].size() == numColumns &&
|
||||||
|
"Output expression has incorrect dimensionality!");
|
||||||
|
for (unsigned j = 0; j < numColumns; ++j)
|
||||||
|
results(i, j) = matrix[i][j];
|
||||||
|
}
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Construct a PWMAFunction given the dimensionalities and an array describing
|
||||||
|
/// the list of pieces. Each piece is given by a string describing the domain
|
||||||
|
/// and a 2D array that represents the output.
|
||||||
|
static PWMAFunction parsePWMAF(
|
||||||
|
unsigned numInputs, unsigned numOutputs,
|
||||||
|
ArrayRef<std::pair<StringRef, SmallVector<SmallVector<int64_t, 8>, 8>>>
|
||||||
|
data,
|
||||||
|
unsigned numSymbols = 0) {
|
||||||
|
static MLIRContext context;
|
||||||
|
|
||||||
|
PWMAFunction result(numInputs - numSymbols, numSymbols, numOutputs);
|
||||||
|
for (const auto &pair : data) {
|
||||||
|
IntegerPolyhedron domain = parsePoly(pair.first, &context);
|
||||||
|
result.addPiece(
|
||||||
|
domain, makeMatrix(numOutputs, domain.getNumIds() + 1, pair.second));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(PWAFunctionTest, isEqual) {
|
||||||
|
MLIRContext context;
|
||||||
|
|
||||||
|
// The output expressions are different but it doesn't matter because they are
|
||||||
|
// equal in this domain.
|
||||||
|
PWMAFunction idAtZeros = parsePWMAF(
|
||||||
|
/*numInputs=*/2, /*numOutputs=*/2,
|
||||||
|
{
|
||||||
|
{"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
|
||||||
|
{"(x, y) : (y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
|
||||||
|
{"(x, y) : (-y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 1, 0}}} // (x, y).
|
||||||
|
});
|
||||||
|
PWMAFunction idAtZeros2 = parsePWMAF(
|
||||||
|
/*numInputs=*/2, /*numOutputs=*/2,
|
||||||
|
{
|
||||||
|
{"(x, y) : (y == 0)", {{1, 0, 0}, {0, 20, 0}}}, // (x, 20y).
|
||||||
|
{"(x, y) : (y - 1 >= 0, x == 0)", {{30, 0, 0}, {0, 1, 0}}}, //(30x, y)
|
||||||
|
{"(x, y) : (-y - 1 > =0, x == 0)", {{30, 0, 0}, {0, 1, 0}}} //(30x, y)
|
||||||
|
});
|
||||||
|
EXPECT_TRUE(idAtZeros.isEqual(idAtZeros2));
|
||||||
|
|
||||||
|
PWMAFunction notIdAtZeros = parsePWMAF(
|
||||||
|
/*numInputs=*/2, /*numOutputs=*/2,
|
||||||
|
{
|
||||||
|
{"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
|
||||||
|
{"(x, y) : (y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 2, 0}}}, // (x, 2y)
|
||||||
|
{"(x, y) : (-y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 2, 0}}}, // (x, 2y)
|
||||||
|
});
|
||||||
|
EXPECT_FALSE(idAtZeros.isEqual(notIdAtZeros));
|
||||||
|
|
||||||
|
// These match at their intersection but one has a bigger domain.
|
||||||
|
PWMAFunction idNoNegNegQuadrant = parsePWMAF(
|
||||||
|
/*numInputs=*/2, /*numOutputs=*/2,
|
||||||
|
{
|
||||||
|
{"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
|
||||||
|
{"(x, y) : (-x - 1 >= 0, y >= 0)", {{1, 0, 0}, {0, 1, 0}}} // (x, y).
|
||||||
|
});
|
||||||
|
PWMAFunction idOnlyPosX =
|
||||||
|
parsePWMAF(/*numInputs=*/2, /*numOutputs=*/2,
|
||||||
|
{
|
||||||
|
{"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
|
||||||
|
});
|
||||||
|
EXPECT_FALSE(idNoNegNegQuadrant.isEqual(idOnlyPosX));
|
||||||
|
|
||||||
|
// Different representations of the same domain.
|
||||||
|
PWMAFunction sumPlusOne = parsePWMAF(
|
||||||
|
/*numInputs=*/2, /*numOutputs=*/1,
|
||||||
|
{
|
||||||
|
{"(x, y) : (x >= 0)", {{1, 1, 1}}}, // x + y + 1.
|
||||||
|
{"(x, y) : (-x - 1 >= 0, -y - 1 >= 0)", {{1, 1, 1}}}, // x + y + 1.
|
||||||
|
{"(x, y) : (-x - 1 >= 0, y >= 0)", {{1, 1, 1}}} // x + y + 1.
|
||||||
|
});
|
||||||
|
PWMAFunction sumPlusOne2 =
|
||||||
|
parsePWMAF(/*numInputs=*/2, /*numOutputs=*/1,
|
||||||
|
{
|
||||||
|
{"(x, y) : ()", {{1, 1, 1}}}, // x + y + 1.
|
||||||
|
});
|
||||||
|
EXPECT_TRUE(sumPlusOne.isEqual(sumPlusOne2));
|
||||||
|
|
||||||
|
// Functions with zero input dimensions.
|
||||||
|
PWMAFunction noInputs1 = parsePWMAF(/*numInputs=*/0, /*numOutputs=*/1,
|
||||||
|
{
|
||||||
|
{"() : ()", {{1}}}, // 1.
|
||||||
|
});
|
||||||
|
PWMAFunction noInputs2 = parsePWMAF(/*numInputs=*/0, /*numOutputs=*/1,
|
||||||
|
{
|
||||||
|
{"() : ()", {{2}}}, // 1.
|
||||||
|
});
|
||||||
|
EXPECT_TRUE(noInputs1.isEqual(noInputs1));
|
||||||
|
EXPECT_FALSE(noInputs1.isEqual(noInputs2));
|
||||||
|
|
||||||
|
// Mismatched dimensionalities.
|
||||||
|
EXPECT_FALSE(noInputs1.isEqual(sumPlusOne));
|
||||||
|
EXPECT_FALSE(idOnlyPosX.isEqual(sumPlusOne));
|
||||||
|
|
||||||
|
// Divisions.
|
||||||
|
// Domain is only multiples of 6; x = 6k for some k.
|
||||||
|
// x + 4(x/2) + 4(x/3) == 26k.
|
||||||
|
PWMAFunction mul2AndMul3 = parsePWMAF(
|
||||||
|
/*numInputs=*/1, /*numOutputs=*/1,
|
||||||
|
{
|
||||||
|
{"(x) : (x - 2*(x floordiv 2) == 0, x - 3*(x floordiv 3) == 0)",
|
||||||
|
{{1, 4, 4, 0}}}, // x + 4(x/2) + 4(x/3).
|
||||||
|
});
|
||||||
|
PWMAFunction mul6 = parsePWMAF(
|
||||||
|
/*numInputs=*/1, /*numOutputs=*/1,
|
||||||
|
{
|
||||||
|
{"(x) : (x - 6*(x floordiv 6) == 0)", {{0, 26, 0}}}, // 26(x/6).
|
||||||
|
});
|
||||||
|
EXPECT_TRUE(mul2AndMul3.isEqual(mul6));
|
||||||
|
|
||||||
|
PWMAFunction mul6diff = parsePWMAF(
|
||||||
|
/*numInputs=*/1, /*numOutputs=*/1,
|
||||||
|
{
|
||||||
|
{"(x) : (x - 5*(x floordiv 5) == 0)", {{0, 52, 0}}}, // 52(x/6).
|
||||||
|
});
|
||||||
|
EXPECT_FALSE(mul2AndMul3.isEqual(mul6diff));
|
||||||
|
|
||||||
|
PWMAFunction mul5 = parsePWMAF(
|
||||||
|
/*numInputs=*/1, /*numOutputs=*/1,
|
||||||
|
{
|
||||||
|
{"(x) : (x - 5*(x floordiv 5) == 0)", {{0, 26, 0}}}, // 26(x/5).
|
||||||
|
});
|
||||||
|
EXPECT_FALSE(mul2AndMul3.isEqual(mul5));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(PWMAFunction, valueAt) {
|
||||||
|
PWMAFunction nonNegPWAF = parsePWMAF(
|
||||||
|
/*numInputs=*/2, /*numOutputs=*/2,
|
||||||
|
{
|
||||||
|
{"(x, y) : (x >= 0)", {{1, 2, 3}, {3, 4, 5}}}, // (x, y).
|
||||||
|
{"(x, y) : (y >= 0, -x - 1 >= 0)", {{-1, 2, 3}, {-3, 4, 5}}} // (x, y)
|
||||||
|
});
|
||||||
|
EXPECT_THAT(*nonNegPWAF.valueAt({2, 3}), ElementsAre(11, 23));
|
||||||
|
EXPECT_THAT(*nonNegPWAF.valueAt({-2, 3}), ElementsAre(11, 23));
|
||||||
|
EXPECT_THAT(*nonNegPWAF.valueAt({2, -3}), ElementsAre(-1, -1));
|
||||||
|
EXPECT_FALSE(nonNegPWAF.valueAt({-2, -3}).hasValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlir
|
Loading…
Reference in New Issue