forked from OSchip/llvm-project
[MLIR][Presburger] Add support for piece-wise multi-affine functions
Add the class MultiAffineFunction which represents functions whose domain is an IntegerPolyhedron and which produce an output given by a tuple of affine expressions in the IntegerPolyhedron's ids. Also add support for piece-wise MultiAffineFunctions, which are defined on a union of IntegerPolyhedrons, and may have different output affine expressions on each IntegerPolyhedron. Thus the function is affine on each individual IntegerPolyhedron piece in the domain. This is part of a series of patches leading up to parametric integer programming. Depends on D118778. Reviewed By: Groverkss Differential Revision: https://reviews.llvm.org/D118779
This commit is contained in:
parent
570471199b
commit
d5a2944219
|
@ -56,6 +56,7 @@ public:
|
|||
enum class Kind {
|
||||
FlatAffineConstraints,
|
||||
FlatAffineValueConstraints,
|
||||
MultiAffineFunction,
|
||||
IntegerPolyhedron
|
||||
};
|
||||
|
||||
|
@ -194,6 +195,11 @@ public:
|
|||
/// Adds an equality from the coefficients specified in `eq`.
|
||||
void addEquality(ArrayRef<int64_t> eq);
|
||||
|
||||
/// Eliminate the `posB^th` local identifier, replacing every instance of it
|
||||
/// with the `posA^th` local identifier. This should be used when the two
|
||||
/// local variables are known to always take the same values.
|
||||
virtual void eliminateRedundantLocalId(unsigned posA, unsigned posB);
|
||||
|
||||
/// Removes identifiers of the specified kind with the specified pos (or
|
||||
/// within the specified range) from the system. The specified location is
|
||||
/// relative to the first identifier of the specified kind.
|
||||
|
@ -273,6 +279,9 @@ public:
|
|||
|
||||
/// Returns true if the given point satisfies the constraints, or false
|
||||
/// otherwise.
|
||||
///
|
||||
/// Note: currently, if the polyhedron contains local ids, the values of
|
||||
/// the local ids must also be provided.
|
||||
bool containsPoint(ArrayRef<int64_t> point) const;
|
||||
|
||||
/// Find equality and pairs of inequality contraints identified by their
|
||||
|
|
|
@ -0,0 +1,195 @@
|
|||
//===- PWMAFunction.h - MLIR PWMAFunction Class------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Support for piece-wise multi-affine functions. These are functions that are
|
||||
// defined on a domain that is a union of IntegerPolyhedrons, and on each domain
|
||||
// the value of the function is a tuple of integers, with each value in the
|
||||
// tuple being an affine expression in the ids of the IntegerPolyhedron.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
|
||||
#define MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
|
||||
|
||||
#include "mlir/Analysis/Presburger/IntegerPolyhedron.h"
|
||||
#include "mlir/Analysis/Presburger/PresburgerSet.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// This class represents a multi-affine function whose domain is given by an
|
||||
/// IntegerPolyhedron. This can be thought of as an IntegerPolyhedron with a
|
||||
/// tuple of integer values attached to every point in the polyhedron, with the
|
||||
/// value of each element of the tuple given by an affine expression in the ids
|
||||
/// of the polyhedron. For example we could have the domain
|
||||
///
|
||||
/// (x, y) : (x >= 5, y >= x)
|
||||
///
|
||||
/// and a tuple of three integers defined at every point in the polyhedron:
|
||||
///
|
||||
/// (x, y) -> (x + 2, 2*x - 3y + 5, 2*x + y).
|
||||
///
|
||||
/// In this way every point in the polyhedron has a tuple of integers associated
|
||||
/// with it. If the integer polyhedron has local ids, then the output
|
||||
/// expressions can use them as well. The output expressions are represented as
|
||||
/// a matrix with one row for every element in the output vector one column for
|
||||
/// each id, and an extra column at the end for the constant term.
|
||||
///
|
||||
/// Checking equality of two such functions is supported, as well as finding the
|
||||
/// value of the function at a specified point. Note that local ids in the
|
||||
/// domain are not yet supported for finding the value at a point.
|
||||
class MultiAffineFunction : protected IntegerPolyhedron {
|
||||
public:
|
||||
/// We use protected inheritance to avoid inheriting the whole public
|
||||
/// interface of IntegerPolyhedron. These using declarations explicitly make
|
||||
/// only the relevant functions part of the public interface.
|
||||
using IntegerPolyhedron::getNumDimAndSymbolIds;
|
||||
using IntegerPolyhedron::getNumDimIds;
|
||||
using IntegerPolyhedron::getNumIds;
|
||||
using IntegerPolyhedron::getNumLocalIds;
|
||||
using IntegerPolyhedron::getNumSymbolIds;
|
||||
|
||||
MultiAffineFunction(const IntegerPolyhedron &domain, const Matrix &output)
|
||||
: IntegerPolyhedron(domain), output(output) {}
|
||||
MultiAffineFunction(const Matrix &output, unsigned numDims,
|
||||
unsigned numSymbols = 0, unsigned numLocals = 0)
|
||||
: IntegerPolyhedron(numDims, numSymbols, numLocals), output(output) {}
|
||||
|
||||
~MultiAffineFunction() override = default;
|
||||
Kind getKind() const override { return Kind::MultiAffineFunction; }
|
||||
bool classof(const IntegerPolyhedron *poly) const {
|
||||
return poly->getKind() == Kind::MultiAffineFunction;
|
||||
}
|
||||
|
||||
unsigned getNumInputs() const { return getNumDimAndSymbolIds(); }
|
||||
unsigned getNumOutputs() const { return output.getNumRows(); }
|
||||
bool isConsistent() const { return output.getNumColumns() == numIds + 1; }
|
||||
const IntegerPolyhedron &getDomain() const { return *this; }
|
||||
|
||||
bool hasCompatibleDimensions(const MultiAffineFunction &f) const;
|
||||
|
||||
/// Insert `num` identifiers of the specified kind at position `pos`.
|
||||
/// Positions are relative to the kind of identifier. The coefficient columns
|
||||
/// corresponding to the added identifiers are initialized to zero. Return the
|
||||
/// absolute column position (i.e., not relative to the kind of identifier)
|
||||
/// of the first added identifier.
|
||||
unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
|
||||
|
||||
/// Swap the posA^th identifier with the posB^th identifier.
|
||||
void swapId(unsigned posA, unsigned posB) override;
|
||||
|
||||
/// Remove the specified range of ids.
|
||||
void removeIdRange(unsigned idStart, unsigned idLimit) override;
|
||||
|
||||
/// Eliminate the `posB^th` local identifier, replacing every instance of it
|
||||
/// with the `posA^th` local identifier. This should be used when the two
|
||||
/// local variables are known to always take the same values.
|
||||
void eliminateRedundantLocalId(unsigned posA, unsigned posB) override;
|
||||
|
||||
/// Return whether the outputs of `this` and `other` agree wherever both
|
||||
/// functions are defined, i.e., the outputs should be equal for all points in
|
||||
/// the intersection of the domains.
|
||||
bool isEqualWhereDomainsOverlap(MultiAffineFunction other) const;
|
||||
|
||||
/// Return whether the `this` and `other` are equal. This is the case if
|
||||
/// they lie in the same space, i.e. have the same dimensions, and their
|
||||
/// domains are identical and their outputs are equal on their domain.
|
||||
bool isEqual(const MultiAffineFunction &other) const;
|
||||
|
||||
/// Get the value of the function at the specified point. If the point lies
|
||||
/// outside the domain, an empty optional is returned.
|
||||
///
|
||||
/// Note: domains with local ids are not yet supported, and will assert-fail.
|
||||
Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
|
||||
|
||||
void print(raw_ostream &os) const;
|
||||
|
||||
void dump() const;
|
||||
|
||||
private:
|
||||
/// The function's output is a tuple of integers, with the ith element of the
|
||||
/// tuple defined by the affine expression given by the ith row of this output
|
||||
/// matrix.
|
||||
Matrix output;
|
||||
};
|
||||
|
||||
/// This class represents a piece-wise MultiAffineFunction. This can be thought
|
||||
/// of as a list of MultiAffineFunction with disjoint domains, with each having
|
||||
/// their own affine expressions for their output tuples. For example, we could
|
||||
/// have a function with two input variables (x, y), defined as
|
||||
///
|
||||
/// f(x, y) = (2*x + y, y - 4) if x >= 0, y >= 0
|
||||
/// = (-2*x + y, y + 4) if x < 0, y < 0
|
||||
/// = (4, 1) if x < 0, y >= 0
|
||||
///
|
||||
/// Note that the domains all have to be *disjoint*. Otherwise, the behaviour of
|
||||
/// this class is undefined. The domains need not cover all possible points;
|
||||
/// this represents a partial function and so could be undefined at some points.
|
||||
///
|
||||
/// As in PresburgerSets, the input ids are partitioned into dimension ids and
|
||||
/// symbolic ids.
|
||||
///
|
||||
/// Support is provided to compare equality of two such functions as well as
|
||||
/// finding the value of the function at a point. Note that local ids in the
|
||||
/// piece are not supported for the latter.
|
||||
class PWMAFunction {
|
||||
public:
|
||||
PWMAFunction(unsigned numDims, unsigned numSymbols, unsigned numOutputs)
|
||||
: numDims(numDims), numSymbols(numSymbols), numOutputs(numOutputs) {
|
||||
assert(numOutputs >= 1 && "The function must output something!");
|
||||
}
|
||||
|
||||
void addPiece(const MultiAffineFunction &piece);
|
||||
void addPiece(const IntegerPolyhedron &domain, const Matrix &output);
|
||||
|
||||
const MultiAffineFunction &getPiece(unsigned i) const { return pieces[i]; }
|
||||
unsigned getNumPieces() const { return pieces.size(); }
|
||||
unsigned getNumOutputs() const { return numOutputs; }
|
||||
unsigned getNumInputs() const { return numDims + numSymbols; }
|
||||
unsigned getNumDimIds() const { return numDims; }
|
||||
unsigned getNumSymbolIds() const { return numSymbols; }
|
||||
MultiAffineFunction &getPiece(unsigned i) { return pieces[i]; }
|
||||
|
||||
/// Return the domain of this piece-wise MultiAffineFunction. This is the
|
||||
/// union of the domains of all the pieces.
|
||||
PresburgerSet getDomain() const;
|
||||
|
||||
/// Check whether the `this` and the given function have compatible
|
||||
/// dimensions, i.e., the same number of dimension inputs, symbol inputs, and
|
||||
/// outputs.
|
||||
bool hasCompatibleDimensions(const MultiAffineFunction &f) const;
|
||||
bool hasCompatibleDimensions(const PWMAFunction &f) const;
|
||||
|
||||
/// Return the value at the specified point and an empty optional if the
|
||||
/// point does not lie in the domain.
|
||||
///
|
||||
/// Note: domains with local ids are not yet supported, and will assert-fail.
|
||||
Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
|
||||
|
||||
/// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether
|
||||
/// they have the same dimensions, the same domain and they take the same
|
||||
/// value at every point in the domain.
|
||||
bool isEqual(const PWMAFunction &other) const;
|
||||
|
||||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
||||
private:
|
||||
/// The list of pieces in this piece-wise MultiAffineFunction.
|
||||
SmallVector<MultiAffineFunction, 4> pieces;
|
||||
|
||||
/// The number of dimensions ids in the domains.
|
||||
unsigned numDims;
|
||||
/// The number of symbol ids in the domains.
|
||||
unsigned numSymbols;
|
||||
/// The number of output ids.
|
||||
unsigned numOutputs;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
|
|
@ -3,6 +3,7 @@ add_mlir_library(MLIRPresburger
|
|||
LinearTransform.cpp
|
||||
Matrix.cpp
|
||||
PresburgerSet.cpp
|
||||
PWMAFunction.cpp
|
||||
Simplex.cpp
|
||||
Utils.cpp
|
||||
|
||||
|
|
|
@ -1065,24 +1065,17 @@ void IntegerPolyhedron::removeRedundantConstraints() {
|
|||
equalities.resizeVertically(pos);
|
||||
}
|
||||
|
||||
/// Eliminate `pos2^th` local identifier, replacing its every instance with
|
||||
/// `pos1^th` local identifier. This function is intended to be used to remove
|
||||
/// redundancy when local variables at position `pos1` and `pos2` are restricted
|
||||
/// to have the same value.
|
||||
static void eliminateRedundantLocalId(IntegerPolyhedron &poly, unsigned pos1,
|
||||
unsigned pos2) {
|
||||
void IntegerPolyhedron::eliminateRedundantLocalId(unsigned posA,
|
||||
unsigned posB) {
|
||||
assert(posA < getNumLocalIds() && "Invalid local id position");
|
||||
assert(posB < getNumLocalIds() && "Invalid local id position");
|
||||
|
||||
assert(pos1 < poly.getNumLocalIds() && "Invalid local id position");
|
||||
assert(pos2 < poly.getNumLocalIds() && "Invalid local id position");
|
||||
|
||||
unsigned localOffset = poly.getNumDimAndSymbolIds();
|
||||
pos1 += localOffset;
|
||||
pos2 += localOffset;
|
||||
for (unsigned i = 0, e = poly.getNumInequalities(); i < e; ++i)
|
||||
poly.atIneq(i, pos1) += poly.atIneq(i, pos2);
|
||||
for (unsigned i = 0, e = poly.getNumEqualities(); i < e; ++i)
|
||||
poly.atEq(i, pos1) += poly.atEq(i, pos2);
|
||||
poly.removeId(pos2);
|
||||
unsigned localOffset = getIdKindOffset(IdKind::Local);
|
||||
posA += localOffset;
|
||||
posB += localOffset;
|
||||
inequalities.addToColumn(posB, posA, 1);
|
||||
equalities.addToColumn(posB, posA, 1);
|
||||
removeId(posB);
|
||||
}
|
||||
|
||||
/// Adds additional local ids to the sets such that they both have the union
|
||||
|
@ -1129,8 +1122,8 @@ void IntegerPolyhedron::mergeLocalIds(IntegerPolyhedron &other) {
|
|||
// Merge function that merges the local variables in both sets by treating
|
||||
// them as the same identifier.
|
||||
auto merge = [&polyA, &polyB](unsigned i, unsigned j) -> bool {
|
||||
eliminateRedundantLocalId(polyA, i, j);
|
||||
eliminateRedundantLocalId(polyB, i, j);
|
||||
polyA.eliminateRedundantLocalId(i, j);
|
||||
polyB.eliminateRedundantLocalId(i, j);
|
||||
return true;
|
||||
};
|
||||
|
||||
|
|
|
@ -0,0 +1,198 @@
|
|||
//===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Analysis/Presburger/PWMAFunction.h"
|
||||
#include "mlir/Analysis/Presburger/Simplex.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
// Return the result of subtracting the two given vectors pointwise.
|
||||
// The vectors must be of the same size.
|
||||
// e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
|
||||
static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
|
||||
ArrayRef<int64_t> vecB) {
|
||||
assert(vecA.size() == vecB.size() &&
|
||||
"Cannot subtract vectors of differing lengths!");
|
||||
SmallVector<int64_t, 8> result;
|
||||
result.reserve(vecA.size());
|
||||
for (unsigned i = 0, e = vecA.size(); i < e; ++i)
|
||||
result.push_back(vecA[i] - vecB[i]);
|
||||
return result;
|
||||
}
|
||||
|
||||
PresburgerSet PWMAFunction::getDomain() const {
|
||||
PresburgerSet domain =
|
||||
PresburgerSet::getEmptySet(getNumDimIds(), getNumSymbolIds());
|
||||
for (const MultiAffineFunction &piece : pieces)
|
||||
domain.unionPolyInPlace(piece.getDomain());
|
||||
return domain;
|
||||
}
|
||||
|
||||
Optional<SmallVector<int64_t, 8>>
|
||||
MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
|
||||
assert(getNumLocalIds() == 0 && "Local ids are not yet supported!");
|
||||
assert(point.size() == getNumIds() && "Point has incorrect dimensionality!");
|
||||
|
||||
if (!getDomain().containsPoint(point))
|
||||
return {};
|
||||
|
||||
// The point lies in the domain, so we need to compute the output value.
|
||||
// The matrix `output` has an affine expression in the ith row, corresponding
|
||||
// to the expression for the ith value in the output vector. The last column
|
||||
// of the matrix contains the constant term. Let v be the input point with
|
||||
// a 1 appended at the end. We can see that output * v gives the desired
|
||||
// output vector.
|
||||
SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
|
||||
pointHomogenous.push_back(1);
|
||||
SmallVector<int64_t, 8> result =
|
||||
output.postMultiplyWithColumn(pointHomogenous);
|
||||
assert(result.size() == getNumOutputs());
|
||||
return result;
|
||||
}
|
||||
|
||||
Optional<SmallVector<int64_t, 8>>
|
||||
PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
|
||||
assert(point.size() == getNumInputs() &&
|
||||
"Point has incorrect dimensionality!");
|
||||
for (const MultiAffineFunction &piece : pieces)
|
||||
if (Optional<SmallVector<int64_t, 8>> output = piece.valueAt(point))
|
||||
return output;
|
||||
return {};
|
||||
}
|
||||
|
||||
void MultiAffineFunction::print(raw_ostream &os) const {
|
||||
os << "Domain:";
|
||||
IntegerPolyhedron::print(os);
|
||||
os << "Output:\n";
|
||||
output.print(os);
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
void MultiAffineFunction::dump() const { print(llvm::errs()); }
|
||||
|
||||
bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
|
||||
return hasCompatibleDimensions(other) &&
|
||||
getDomain().isEqual(other.getDomain()) &&
|
||||
isEqualWhereDomainsOverlap(other);
|
||||
}
|
||||
|
||||
unsigned MultiAffineFunction::insertId(IdKind kind, unsigned pos,
|
||||
unsigned num) {
|
||||
unsigned absolutePos = getIdKindOffset(kind) + pos;
|
||||
output.insertColumns(absolutePos, num);
|
||||
return IntegerPolyhedron::insertId(kind, pos, num);
|
||||
}
|
||||
|
||||
void MultiAffineFunction::swapId(unsigned posA, unsigned posB) {
|
||||
output.swapColumns(posA, posB);
|
||||
IntegerPolyhedron::swapId(posA, posB);
|
||||
}
|
||||
|
||||
void MultiAffineFunction::removeIdRange(unsigned idStart, unsigned idLimit) {
|
||||
output.removeColumns(idStart, idLimit - idStart);
|
||||
IntegerPolyhedron::removeIdRange(idStart, idLimit);
|
||||
}
|
||||
|
||||
void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA,
|
||||
unsigned posB) {
|
||||
output.addToColumn(posB, posA, /*scale=*/1);
|
||||
IntegerPolyhedron::eliminateRedundantLocalId(posA, posB);
|
||||
}
|
||||
|
||||
bool MultiAffineFunction::isEqualWhereDomainsOverlap(
|
||||
MultiAffineFunction other) const {
|
||||
if (!hasCompatibleDimensions(other))
|
||||
return false;
|
||||
|
||||
// `commonFunc` has the same output as `this`.
|
||||
MultiAffineFunction commonFunc = *this;
|
||||
// After this merge, `commonFunc` and `other` have the same local ids; they
|
||||
// are merged.
|
||||
commonFunc.mergeLocalIds(other);
|
||||
// After this, the domain of `commonFunc` will be the intersection of the
|
||||
// domains of `this` and `other`.
|
||||
commonFunc.IntegerPolyhedron::append(other);
|
||||
|
||||
// `commonDomainMatching` contains the subset of the common domain
|
||||
// where the outputs of `this` and `other` match.
|
||||
//
|
||||
// We want to add constraints equating the outputs of `this` and `other`.
|
||||
// However, `this` may have difference local ids from `other`, whereas we
|
||||
// need both to have the same locals. Accordingly, we use `commonFunc.output`
|
||||
// in place of `this->output`, since `commonFunc` has the same output but also
|
||||
// has its locals merged.
|
||||
IntegerPolyhedron commonDomainMatching = commonFunc.getDomain();
|
||||
for (unsigned row = 0, e = getNumOutputs(); row < e; ++row)
|
||||
commonDomainMatching.addEquality(
|
||||
subtract(commonFunc.output.getRow(row), other.output.getRow(row)));
|
||||
|
||||
// If the whole common domain is a subset of commonDomainMatching, then they
|
||||
// are equal and the two functions match on the whole common domain.
|
||||
return commonFunc.getDomain().isSubsetOf(commonDomainMatching);
|
||||
}
|
||||
|
||||
/// Two PWMAFunctions are equal if they have the same dimensionalities,
|
||||
/// the same domain, and take the same value at every point in the domain.
|
||||
bool PWMAFunction::isEqual(const PWMAFunction &other) const {
|
||||
if (!hasCompatibleDimensions(other))
|
||||
return false;
|
||||
|
||||
if (!this->getDomain().isEqual(other.getDomain()))
|
||||
return false;
|
||||
|
||||
// Check if, whenever the domains of a piece of `this` and a piece of `other`
|
||||
// overlap, they take the same output value. If `this` and `other` have the
|
||||
// same domain (checked above), then this check passes iff the two functions
|
||||
// have the same output at every point in the domain.
|
||||
for (const MultiAffineFunction &aPiece : this->pieces)
|
||||
for (const MultiAffineFunction &bPiece : other.pieces)
|
||||
if (!aPiece.isEqualWhereDomainsOverlap(bPiece))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
void PWMAFunction::addPiece(const MultiAffineFunction &piece) {
|
||||
assert(hasCompatibleDimensions(piece) &&
|
||||
"Piece to be added is not compatible with this PWMAFunction!");
|
||||
assert(piece.isConsistent() && "Piece is internally inconsistent!");
|
||||
assert(this->getDomain()
|
||||
.intersect(PresburgerSet(piece.getDomain()))
|
||||
.isIntegerEmpty() &&
|
||||
"New piece's domain overlaps with that of existing pieces!");
|
||||
pieces.push_back(piece);
|
||||
}
|
||||
|
||||
void PWMAFunction::addPiece(const IntegerPolyhedron &domain,
|
||||
const Matrix &output) {
|
||||
addPiece(MultiAffineFunction(domain, output));
|
||||
}
|
||||
|
||||
void PWMAFunction::print(raw_ostream &os) const {
|
||||
os << pieces.size() << " pieces:\n";
|
||||
for (const MultiAffineFunction &piece : pieces)
|
||||
piece.print(os);
|
||||
}
|
||||
|
||||
/// The hasCompatibleDimensions functions don't check the number of local ids;
|
||||
/// functions are still compatible if they have differing number of locals.
|
||||
bool MultiAffineFunction::hasCompatibleDimensions(
|
||||
const MultiAffineFunction &f) const {
|
||||
return getNumDimIds() == f.getNumDimIds() &&
|
||||
getNumSymbolIds() == f.getNumSymbolIds() &&
|
||||
getNumOutputs() == f.getNumOutputs();
|
||||
}
|
||||
bool PWMAFunction::hasCompatibleDimensions(const MultiAffineFunction &f) const {
|
||||
return getNumDimIds() == f.getNumDimIds() &&
|
||||
getNumSymbolIds() == f.getNumSymbolIds() &&
|
||||
getNumOutputs() == f.getNumOutputs();
|
||||
}
|
||||
bool PWMAFunction::hasCompatibleDimensions(const PWMAFunction &f) const {
|
||||
return getNumDimIds() == f.getNumDimIds() &&
|
||||
getNumSymbolIds() == f.getNumSymbolIds() &&
|
||||
getNumOutputs() == f.getNumOutputs();
|
||||
}
|
|
@ -3,6 +3,7 @@ add_mlir_unittest(MLIRPresburgerTests
|
|||
LinearTransformTest.cpp
|
||||
MatrixTest.cpp
|
||||
PresburgerSetTest.cpp
|
||||
PWMAFunctionTest.cpp
|
||||
SimplexTest.cpp
|
||||
../../Dialect/Affine/Analysis/AffineStructuresParser.cpp
|
||||
)
|
||||
|
|
|
@ -0,0 +1,183 @@
|
|||
//===- PWMAFunctionTest.cpp - Tests for PWMAFunction ----------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains tests for PWMAFunction.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Analysis/Presburger/PWMAFunction.h"
|
||||
#include "../../Dialect/Affine/Analysis/AffineStructuresParser.h"
|
||||
#include "mlir/Analysis/Presburger/PresburgerSet.h"
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace mlir {
|
||||
using testing::ElementsAre;
|
||||
|
||||
/// Parses an IntegerPolyhedron from a StringRef. It is expected that the
|
||||
/// string represents a valid IntegerSet, otherwise it will violate a gtest
|
||||
/// assertion.
|
||||
static IntegerPolyhedron parsePoly(StringRef str, MLIRContext *context) {
|
||||
FailureOr<IntegerPolyhedron> poly = parseIntegerSetToFAC(str, context);
|
||||
EXPECT_TRUE(succeeded(poly));
|
||||
return *poly;
|
||||
}
|
||||
|
||||
static Matrix makeMatrix(unsigned numRow, unsigned numColumns,
|
||||
ArrayRef<SmallVector<int64_t, 8>> matrix) {
|
||||
Matrix results(numRow, numColumns);
|
||||
assert(matrix.size() == numRow);
|
||||
for (unsigned i = 0; i < numRow; ++i) {
|
||||
assert(matrix[i].size() == numColumns &&
|
||||
"Output expression has incorrect dimensionality!");
|
||||
for (unsigned j = 0; j < numColumns; ++j)
|
||||
results(i, j) = matrix[i][j];
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
/// Construct a PWMAFunction given the dimensionalities and an array describing
|
||||
/// the list of pieces. Each piece is given by a string describing the domain
|
||||
/// and a 2D array that represents the output.
|
||||
static PWMAFunction parsePWMAF(
|
||||
unsigned numInputs, unsigned numOutputs,
|
||||
ArrayRef<std::pair<StringRef, SmallVector<SmallVector<int64_t, 8>, 8>>>
|
||||
data,
|
||||
unsigned numSymbols = 0) {
|
||||
static MLIRContext context;
|
||||
|
||||
PWMAFunction result(numInputs - numSymbols, numSymbols, numOutputs);
|
||||
for (const auto &pair : data) {
|
||||
IntegerPolyhedron domain = parsePoly(pair.first, &context);
|
||||
result.addPiece(
|
||||
domain, makeMatrix(numOutputs, domain.getNumIds() + 1, pair.second));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
TEST(PWAFunctionTest, isEqual) {
|
||||
MLIRContext context;
|
||||
|
||||
// The output expressions are different but it doesn't matter because they are
|
||||
// equal in this domain.
|
||||
PWMAFunction idAtZeros = parsePWMAF(
|
||||
/*numInputs=*/2, /*numOutputs=*/2,
|
||||
{
|
||||
{"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
|
||||
{"(x, y) : (y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
|
||||
{"(x, y) : (-y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 1, 0}}} // (x, y).
|
||||
});
|
||||
PWMAFunction idAtZeros2 = parsePWMAF(
|
||||
/*numInputs=*/2, /*numOutputs=*/2,
|
||||
{
|
||||
{"(x, y) : (y == 0)", {{1, 0, 0}, {0, 20, 0}}}, // (x, 20y).
|
||||
{"(x, y) : (y - 1 >= 0, x == 0)", {{30, 0, 0}, {0, 1, 0}}}, //(30x, y)
|
||||
{"(x, y) : (-y - 1 > =0, x == 0)", {{30, 0, 0}, {0, 1, 0}}} //(30x, y)
|
||||
});
|
||||
EXPECT_TRUE(idAtZeros.isEqual(idAtZeros2));
|
||||
|
||||
PWMAFunction notIdAtZeros = parsePWMAF(
|
||||
/*numInputs=*/2, /*numOutputs=*/2,
|
||||
{
|
||||
{"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
|
||||
{"(x, y) : (y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 2, 0}}}, // (x, 2y)
|
||||
{"(x, y) : (-y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 2, 0}}}, // (x, 2y)
|
||||
});
|
||||
EXPECT_FALSE(idAtZeros.isEqual(notIdAtZeros));
|
||||
|
||||
// These match at their intersection but one has a bigger domain.
|
||||
PWMAFunction idNoNegNegQuadrant = parsePWMAF(
|
||||
/*numInputs=*/2, /*numOutputs=*/2,
|
||||
{
|
||||
{"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
|
||||
{"(x, y) : (-x - 1 >= 0, y >= 0)", {{1, 0, 0}, {0, 1, 0}}} // (x, y).
|
||||
});
|
||||
PWMAFunction idOnlyPosX =
|
||||
parsePWMAF(/*numInputs=*/2, /*numOutputs=*/2,
|
||||
{
|
||||
{"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
|
||||
});
|
||||
EXPECT_FALSE(idNoNegNegQuadrant.isEqual(idOnlyPosX));
|
||||
|
||||
// Different representations of the same domain.
|
||||
PWMAFunction sumPlusOne = parsePWMAF(
|
||||
/*numInputs=*/2, /*numOutputs=*/1,
|
||||
{
|
||||
{"(x, y) : (x >= 0)", {{1, 1, 1}}}, // x + y + 1.
|
||||
{"(x, y) : (-x - 1 >= 0, -y - 1 >= 0)", {{1, 1, 1}}}, // x + y + 1.
|
||||
{"(x, y) : (-x - 1 >= 0, y >= 0)", {{1, 1, 1}}} // x + y + 1.
|
||||
});
|
||||
PWMAFunction sumPlusOne2 =
|
||||
parsePWMAF(/*numInputs=*/2, /*numOutputs=*/1,
|
||||
{
|
||||
{"(x, y) : ()", {{1, 1, 1}}}, // x + y + 1.
|
||||
});
|
||||
EXPECT_TRUE(sumPlusOne.isEqual(sumPlusOne2));
|
||||
|
||||
// Functions with zero input dimensions.
|
||||
PWMAFunction noInputs1 = parsePWMAF(/*numInputs=*/0, /*numOutputs=*/1,
|
||||
{
|
||||
{"() : ()", {{1}}}, // 1.
|
||||
});
|
||||
PWMAFunction noInputs2 = parsePWMAF(/*numInputs=*/0, /*numOutputs=*/1,
|
||||
{
|
||||
{"() : ()", {{2}}}, // 1.
|
||||
});
|
||||
EXPECT_TRUE(noInputs1.isEqual(noInputs1));
|
||||
EXPECT_FALSE(noInputs1.isEqual(noInputs2));
|
||||
|
||||
// Mismatched dimensionalities.
|
||||
EXPECT_FALSE(noInputs1.isEqual(sumPlusOne));
|
||||
EXPECT_FALSE(idOnlyPosX.isEqual(sumPlusOne));
|
||||
|
||||
// Divisions.
|
||||
// Domain is only multiples of 6; x = 6k for some k.
|
||||
// x + 4(x/2) + 4(x/3) == 26k.
|
||||
PWMAFunction mul2AndMul3 = parsePWMAF(
|
||||
/*numInputs=*/1, /*numOutputs=*/1,
|
||||
{
|
||||
{"(x) : (x - 2*(x floordiv 2) == 0, x - 3*(x floordiv 3) == 0)",
|
||||
{{1, 4, 4, 0}}}, // x + 4(x/2) + 4(x/3).
|
||||
});
|
||||
PWMAFunction mul6 = parsePWMAF(
|
||||
/*numInputs=*/1, /*numOutputs=*/1,
|
||||
{
|
||||
{"(x) : (x - 6*(x floordiv 6) == 0)", {{0, 26, 0}}}, // 26(x/6).
|
||||
});
|
||||
EXPECT_TRUE(mul2AndMul3.isEqual(mul6));
|
||||
|
||||
PWMAFunction mul6diff = parsePWMAF(
|
||||
/*numInputs=*/1, /*numOutputs=*/1,
|
||||
{
|
||||
{"(x) : (x - 5*(x floordiv 5) == 0)", {{0, 52, 0}}}, // 52(x/6).
|
||||
});
|
||||
EXPECT_FALSE(mul2AndMul3.isEqual(mul6diff));
|
||||
|
||||
PWMAFunction mul5 = parsePWMAF(
|
||||
/*numInputs=*/1, /*numOutputs=*/1,
|
||||
{
|
||||
{"(x) : (x - 5*(x floordiv 5) == 0)", {{0, 26, 0}}}, // 26(x/5).
|
||||
});
|
||||
EXPECT_FALSE(mul2AndMul3.isEqual(mul5));
|
||||
}
|
||||
|
||||
TEST(PWMAFunction, valueAt) {
|
||||
PWMAFunction nonNegPWAF = parsePWMAF(
|
||||
/*numInputs=*/2, /*numOutputs=*/2,
|
||||
{
|
||||
{"(x, y) : (x >= 0)", {{1, 2, 3}, {3, 4, 5}}}, // (x, y).
|
||||
{"(x, y) : (y >= 0, -x - 1 >= 0)", {{-1, 2, 3}, {-3, 4, 5}}} // (x, y)
|
||||
});
|
||||
EXPECT_THAT(*nonNegPWAF.valueAt({2, 3}), ElementsAre(11, 23));
|
||||
EXPECT_THAT(*nonNegPWAF.valueAt({-2, 3}), ElementsAre(11, 23));
|
||||
EXPECT_THAT(*nonNegPWAF.valueAt({2, -3}), ElementsAre(-1, -1));
|
||||
EXPECT_FALSE(nonNegPWAF.valueAt({-2, -3}).hasValue());
|
||||
}
|
||||
|
||||
} // namespace mlir
|
Loading…
Reference in New Issue