forked from OSchip/llvm-project
Adds Gaussian Elimination to FlatAffineConstraints.
- Adds FlatAffineConstraints::isEmpty method to test if there are no solutions to the system. - Adds GCD test check if equality constraints have no solution. - Adds unit test cases. PiperOrigin-RevId: 218546319
This commit is contained in:
parent
52a0e58bdb
commit
5413239350
|
@ -49,6 +49,14 @@ void getReachableAffineApplyOps(
|
|||
llvm::ArrayRef<MLValue *> operands,
|
||||
llvm::SmallVectorImpl<OperationStmt *> &affineApplyOps);
|
||||
|
||||
/// Flattens 'expr' into 'flattenedExpr'. Returns true on success or false
|
||||
/// if 'expr' was unable to be flattened (i.e. because it was not pure affine,
|
||||
/// or because it contained mod's and div's that could not be eliminated
|
||||
/// without introducing local variables).
|
||||
bool getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
|
||||
unsigned numSymbols,
|
||||
llvm::SmallVectorImpl<int64_t> *flattenedExpr);
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_ANALYSIS_AFFINE_ANALYSIS_H
|
||||
|
|
|
@ -248,6 +248,9 @@ public:
|
|||
explicit FlatAffineConstraints(const AffineValueMap &avm);
|
||||
explicit FlatAffineConstraints(ArrayRef<const AffineValueMap *> avmRef);
|
||||
|
||||
/// Creates an affine constraint system from an IntegerSet.
|
||||
explicit FlatAffineConstraints(IntegerSet set);
|
||||
|
||||
/// Create an affine constraint system from an IntegerValueSet.
|
||||
// TODO(bondhugula)
|
||||
explicit FlatAffineConstraints(const IntegerValueSet &set);
|
||||
|
@ -259,6 +262,24 @@ public:
|
|||
|
||||
~FlatAffineConstraints() {}
|
||||
|
||||
// Checks for emptiness by performing variable elimination on all identifiers,
|
||||
// running the GCD test on each equality constraint, and checking for invalid
|
||||
// constraints.
|
||||
// Returns true if the GCD test fails for any equality, or if any invalid
|
||||
// constraints are discovered on any row. Returns false otherwise.
|
||||
// TODO(andydavis) Change this method to operate on cloned constraints.
|
||||
bool isEmpty();
|
||||
|
||||
// Eliminates a single identifier at 'position' from equality and inequality
|
||||
// constraints. Returns 'true' if the identifier was eliminated.
|
||||
// Returns 'false' otherwise.
|
||||
bool eliminateIdentifier(unsigned position);
|
||||
|
||||
// Eliminates identifiers from equality and inequality constraints
|
||||
// in column range [posStart, posLimit).
|
||||
// Returns the number of variables eliminated.
|
||||
unsigned eliminateIdentifiers(unsigned posStart, unsigned posLimit);
|
||||
|
||||
inline int64_t atEq(unsigned i, unsigned j) const {
|
||||
return equalities[i * (numIds + 1) + j];
|
||||
}
|
||||
|
@ -267,6 +288,14 @@ public:
|
|||
return equalities[i * (numIds + 1) + j];
|
||||
}
|
||||
|
||||
inline int64_t atEqIdx(unsigned linearIndex) const {
|
||||
return equalities[linearIndex];
|
||||
}
|
||||
|
||||
inline int64_t &atEqIdx(unsigned linearIndex) {
|
||||
return equalities[linearIndex];
|
||||
}
|
||||
|
||||
inline int64_t atIneq(unsigned i, unsigned j) const {
|
||||
return inequalities[i * (numIds + 1) + j];
|
||||
}
|
||||
|
@ -275,6 +304,14 @@ public:
|
|||
return inequalities[i * (numIds + 1) + j];
|
||||
}
|
||||
|
||||
inline int64_t atIneqIdx(unsigned linearIndex) const {
|
||||
return inequalities[linearIndex];
|
||||
}
|
||||
|
||||
inline int64_t &atIneqIdx(unsigned linearIndex) {
|
||||
return inequalities[linearIndex];
|
||||
}
|
||||
|
||||
inline unsigned getNumCols() const { return numIds + 1; }
|
||||
|
||||
inline unsigned getNumEqualities() const {
|
||||
|
@ -323,6 +360,11 @@ public:
|
|||
void dump() const;
|
||||
|
||||
private:
|
||||
// Removes coefficients in column range [colStart, colLimit),and copies any
|
||||
// remaining valid data into place, updates member variables, and resizes
|
||||
// arrays as needed.
|
||||
void removeColumnRange(unsigned colStart, unsigned colLimit);
|
||||
|
||||
/// Coefficients of affine equalities (in == 0 form).
|
||||
SmallVector<int64_t, 64> equalities;
|
||||
|
||||
|
|
|
@ -59,6 +59,14 @@ public:
|
|||
ArrayRef<AffineExpr> constraints,
|
||||
ArrayRef<bool> eqFlags, MLIRContext *context);
|
||||
|
||||
// Returns a canonical empty IntegerSet (i.e. a set with no integer points).
|
||||
static IntegerSet getEmptySet(unsigned numDims, unsigned numSymbols,
|
||||
MLIRContext *context) {
|
||||
auto one = getAffineConstantExpr(1, context);
|
||||
/* 1 == 0 */
|
||||
return get(numDims, numSymbols, one, true, context);
|
||||
}
|
||||
|
||||
explicit operator bool() { return set; }
|
||||
bool operator==(IntegerSet other) const { return set == other.set; }
|
||||
|
||||
|
@ -66,6 +74,8 @@ public:
|
|||
unsigned getNumSymbols() const;
|
||||
unsigned getNumOperands() const;
|
||||
unsigned getNumConstraints() const;
|
||||
unsigned getNumEqualities() const;
|
||||
unsigned getNumInequalities() const;
|
||||
|
||||
ArrayRef<AffineExpr> getConstraints() const;
|
||||
|
||||
|
@ -79,6 +89,8 @@ public:
|
|||
/// inequality.
|
||||
bool isEq(unsigned idx) const;
|
||||
|
||||
MLIRContext *getContext() const;
|
||||
|
||||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
||||
|
|
|
@ -465,6 +465,10 @@ public:
|
|||
const AffineCondition getCondition() const;
|
||||
|
||||
IntegerSet getIntegerSet() const { return set; }
|
||||
void setIntegerSet(IntegerSet newSet) {
|
||||
assert(newSet.getNumOperands() == operands.size());
|
||||
set = newSet;
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Operands
|
||||
|
|
|
@ -51,6 +51,14 @@ inline int64_t mod(int64_t lhs, int64_t rhs) {
|
|||
return lhs % rhs < 0 ? lhs % rhs + rhs : lhs % rhs;
|
||||
}
|
||||
|
||||
/// Returns the least common multiple of 'a' and 'b'.
|
||||
inline int64_t lcm(int64_t a, int64_t b) {
|
||||
uint64_t x = std::abs(a);
|
||||
uint64_t y = std::abs(b);
|
||||
int64_t lcm = (x * y) / llvm::GreatestCommonDivisor64(x, y);
|
||||
assert((lcm >= a && lcm >= b) && "LCM overflow");
|
||||
return lcm;
|
||||
}
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_SUPPORT_MATHEXTRAS_H_
|
||||
|
|
|
@ -298,6 +298,30 @@ AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
|
|||
return simplifiedExpr;
|
||||
}
|
||||
|
||||
// Flattens 'expr' into 'flattenedExpr'. Returns true on success or false
|
||||
// if 'expr' was unable to be flattened (i.e. because it was not pur affine,
|
||||
// or because it contained mod's and div's that could not be eliminated
|
||||
// without introducing local variables).
|
||||
bool mlir::getFlattenedAffineExpr(
|
||||
AffineExpr expr, unsigned numDims, unsigned numSymbols,
|
||||
llvm::SmallVectorImpl<int64_t> *flattenedExpr) {
|
||||
// TODO(bondhugula): only pure affine for now. The simplification here can be
|
||||
// extended to semi-affine maps in the future.
|
||||
if (!expr.isPureAffine())
|
||||
return false;
|
||||
|
||||
AffineExprFlattener flattener(numDims, numSymbols, expr.getContext());
|
||||
flattener.walkPostOrder(expr);
|
||||
// TODO(andydavis) Support local exprs.
|
||||
if (flattener.numLocals > 0) {
|
||||
return false;
|
||||
}
|
||||
for (auto v : flattener.operandExprStack.back()) {
|
||||
flattenedExpr->push_back(v);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Returns the sequence of AffineApplyOp OperationStmts operation in
|
||||
/// 'affineApplyOps', which are reachable via a search starting from 'operands',
|
||||
/// and ending at operands which are not defined by AffineApplyOps.
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/MLValue.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
@ -438,6 +439,286 @@ AffineMap AffineValueMap::getAffineMap() { return map.getAffineMap(); }
|
|||
|
||||
AffineValueMap::~AffineValueMap() {}
|
||||
|
||||
FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
|
||||
: numReservedEqualities(0), numReservedInequalities(0), numReservedIds(0),
|
||||
numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()),
|
||||
numSymbols(set.getNumSymbols()) {
|
||||
unsigned numConstraints = set.getNumConstraints();
|
||||
for (unsigned i = 0; i < numConstraints; ++i) {
|
||||
AffineExpr expr = set.getConstraint(i);
|
||||
SmallVector<int64_t, 4> flattenedExpr;
|
||||
getFlattenedAffineExpr(expr, set.getNumDims(), set.getNumSymbols(),
|
||||
&flattenedExpr);
|
||||
assert(flattenedExpr.size() == getNumCols());
|
||||
if (set.getEqFlags()[i]) {
|
||||
addEquality(flattenedExpr);
|
||||
} else {
|
||||
addInequality(flattenedExpr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Searches for a constraint with a non-zero coefficient at 'colIdx' in
|
||||
// equality (isEq=true) or inequality (isEq=false) constraints.
|
||||
// Returns true and sets row found in search in 'rowIdx'.
|
||||
// Returns false otherwise.
|
||||
static bool
|
||||
findConstraintWithNonZeroAt(const FlatAffineConstraints &constraints,
|
||||
unsigned colIdx, bool isEq, unsigned &rowIdx) {
|
||||
auto at = [&](unsigned rowIdx) -> int64_t {
|
||||
return isEq ? constraints.atEq(rowIdx, colIdx)
|
||||
: constraints.atIneq(rowIdx, colIdx);
|
||||
};
|
||||
unsigned e =
|
||||
isEq ? constraints.getNumEqualities() : constraints.getNumInequalities();
|
||||
for (rowIdx = 0; rowIdx < e; ++rowIdx) {
|
||||
if (at(rowIdx) != 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Normalizes the coefficient values across all columns in 'rowIDx' by their
|
||||
// GCD in equality or inequality contraints as specified by 'isEq'.
|
||||
static void normalizeConstraintByGCD(FlatAffineConstraints *constraints,
|
||||
unsigned rowIdx, bool isEq) {
|
||||
auto at = [&](unsigned colIdx) -> int64_t {
|
||||
return isEq ? constraints->atEq(rowIdx, colIdx)
|
||||
: constraints->atIneq(rowIdx, colIdx);
|
||||
};
|
||||
uint64_t gcd = std::abs(at(0));
|
||||
for (unsigned j = 1; j < constraints->getNumCols(); ++j) {
|
||||
gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(at(j)));
|
||||
}
|
||||
if (gcd > 0 && gcd != 1) {
|
||||
for (unsigned j = 0; j < constraints->getNumCols(); ++j) {
|
||||
int64_t v = at(j) / static_cast<int64_t>(gcd);
|
||||
isEq ? constraints->atEq(rowIdx, j) = v
|
||||
: constraints->atIneq(rowIdx, j) = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Runs the GCD test on all equality constraints. Returns 'true' if this test
|
||||
// fails on any equality. Returns 'false' otherwise.
|
||||
// This test can be used to disprove the existence of a solution. If it returns
|
||||
// true, no integer solution to the equality constraints can exist.
|
||||
//
|
||||
// GCD test definition:
|
||||
//
|
||||
// The equality constraint:
|
||||
//
|
||||
// c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0
|
||||
//
|
||||
// has an integer solution iff:
|
||||
//
|
||||
// GCD of c_1, c_2, ..., c_n divides c_0.
|
||||
//
|
||||
static bool isEmptyByGCDTest(const FlatAffineConstraints &constraints) {
|
||||
unsigned numCols = constraints.getNumCols();
|
||||
for (unsigned i = 0, e = constraints.getNumEqualities(); i < e; ++i) {
|
||||
uint64_t gcd = std::abs(constraints.atEq(i, 0));
|
||||
for (unsigned j = 1; j < numCols - 1; ++j) {
|
||||
gcd =
|
||||
llvm::GreatestCommonDivisor64(gcd, std::abs(constraints.atEq(i, j)));
|
||||
}
|
||||
int64_t v = std::abs(constraints.atEq(i, numCols - 1));
|
||||
if (gcd > 0 && (v % gcd != 0)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Checks all rows of equality/inequality constraints for contradictions
|
||||
// (i.e. 1 == 0), which may have surfaced after elimination.
|
||||
// Returns 'true' if a valid constraint is detected. Returns 'false' otherwise.
|
||||
static bool hasInvalidConstraint(const FlatAffineConstraints &constraints) {
|
||||
auto check = [constraints](bool isEq) -> bool {
|
||||
unsigned numCols = constraints.getNumCols();
|
||||
unsigned numRows = isEq ? constraints.getNumEqualities()
|
||||
: constraints.getNumInequalities();
|
||||
for (unsigned i = 0, e = numRows; i < e; ++i) {
|
||||
unsigned j;
|
||||
for (j = 0; j < numCols - 1; ++j) {
|
||||
int64_t v = isEq ? constraints.atEq(i, j) : constraints.atIneq(i, j);
|
||||
// Skip rows with non-zero variable coefficients.
|
||||
if (v != 0)
|
||||
break;
|
||||
}
|
||||
if (j < numCols - 1) {
|
||||
continue;
|
||||
}
|
||||
// Check validity of constant term at 'numCols - 1' w.r.t 'isEq'.
|
||||
// Example invalid constraints include: '1 == 0' or '-1 >= 0'
|
||||
int64_t v = isEq ? constraints.atEq(i, numCols - 1)
|
||||
: constraints.atIneq(i, numCols - 1);
|
||||
if ((isEq && v != 0) || (!isEq && v < 0)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
if (check(/*isEq=*/true))
|
||||
return true;
|
||||
return check(/*isEq=*/false);
|
||||
}
|
||||
|
||||
// Eliminate identifier from constraint at 'rowIdx' based on coefficient at
|
||||
// pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be
|
||||
// updated as they have already been eliminated.
|
||||
static void eliminateFromConstraint(FlatAffineConstraints *constraints,
|
||||
unsigned rowIdx, unsigned pivotRow,
|
||||
unsigned pivotCol, unsigned elimColStart,
|
||||
bool isEq) {
|
||||
// Skip if equality 'rowIdx' if same as 'pivotRow'.
|
||||
if (isEq && rowIdx == pivotRow)
|
||||
return;
|
||||
auto at = [&](unsigned i, unsigned j) -> int64_t {
|
||||
return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j);
|
||||
};
|
||||
int64_t leadCoeff = at(rowIdx, pivotCol);
|
||||
// Skip if leading coefficient at 'rowIdx' is already zero.
|
||||
if (leadCoeff == 0)
|
||||
return;
|
||||
int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol);
|
||||
int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1;
|
||||
int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff);
|
||||
int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff));
|
||||
int64_t rowMultiplier = lcm / std::abs(leadCoeff);
|
||||
|
||||
unsigned numCols = constraints->getNumCols();
|
||||
for (unsigned j = 0; j < numCols; ++j) {
|
||||
// Skip updating column 'j' if it was just eliminated.
|
||||
if (j >= elimColStart && j < pivotCol)
|
||||
continue;
|
||||
int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) +
|
||||
rowMultiplier * at(rowIdx, j);
|
||||
isEq ? constraints->atEq(rowIdx, j) = v
|
||||
: constraints->atIneq(rowIdx, j) = v;
|
||||
}
|
||||
}
|
||||
|
||||
// Remove coefficients in column range [colStart, colLimit) in place.
|
||||
// This removes in data in the specified column range, and copies any
|
||||
// remaining valid data into place.
|
||||
static void removeColumns(FlatAffineConstraints *constraints, unsigned colStart,
|
||||
unsigned colLimit, bool isEq) {
|
||||
unsigned numCols = constraints->getNumCols();
|
||||
unsigned newNumCols = numCols - (colLimit - colStart);
|
||||
unsigned numRows = isEq ? constraints->getNumEqualities()
|
||||
: constraints->getNumInequalities();
|
||||
for (unsigned i = 0, e = numRows; i < e; ++i) {
|
||||
for (unsigned j = 0; j < numCols; ++j) {
|
||||
if (j >= colStart && j < colLimit)
|
||||
continue;
|
||||
unsigned inputIndex = i * numCols + j;
|
||||
unsigned outputOffset = j >= colLimit ? j - (colLimit - colStart) : j;
|
||||
unsigned outputIndex = i * newNumCols + outputOffset;
|
||||
assert(outputIndex <= inputIndex);
|
||||
if (isEq) {
|
||||
constraints->atEqIdx(outputIndex) = constraints->atEqIdx(inputIndex);
|
||||
} else {
|
||||
constraints->atIneqIdx(outputIndex) =
|
||||
constraints->atIneqIdx(inputIndex);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Removes coefficients in column range [colStart, colLimit),and copies any
|
||||
// remaining valid data into place, updates member variables, and resizes
|
||||
// arrays as needed.
|
||||
void FlatAffineConstraints::removeColumnRange(unsigned colStart,
|
||||
unsigned colLimit) {
|
||||
// TODO(andydavis) Make 'removeColumns' a lambda called from here.
|
||||
// Remove eliminated columns from equalities.
|
||||
removeColumns(this, colStart, colLimit, /*isEq=*/true);
|
||||
// Remove eliminated columns from inequalities.
|
||||
removeColumns(this, colStart, colLimit, /*isEq=*/false);
|
||||
// Update members numDims, numSymbols and numIds.
|
||||
unsigned numDimsEliminated = 0;
|
||||
if (colStart < numDims) {
|
||||
numDimsEliminated = std::min(numDims, colLimit) - colStart;
|
||||
}
|
||||
unsigned numEqualities = getNumEqualities();
|
||||
unsigned numInequalities = getNumInequalities();
|
||||
unsigned numColsEliminated = colLimit - colStart;
|
||||
unsigned numSymbolsEliminated =
|
||||
std::min(numSymbols, numColsEliminated - numDimsEliminated);
|
||||
numDims -= numDimsEliminated;
|
||||
numSymbols -= numSymbolsEliminated;
|
||||
numIds = numIds - numColsEliminated;
|
||||
equalities.resize(numEqualities * getNumCols());
|
||||
inequalities.resize(numInequalities * getNumCols());
|
||||
}
|
||||
|
||||
// Performs variable elimination on all identifiers, runs the GCD test on
|
||||
// all equality constraint rows, and checks the constraint validity.
|
||||
// Returns 'true' if the GCD test fails on any row, or if any invalid
|
||||
// constraint is detected. Returns 'false' otherwise.
|
||||
bool FlatAffineConstraints::isEmpty() {
|
||||
if (eliminateIdentifiers(0, numIds) == 0)
|
||||
return false;
|
||||
if (isEmptyByGCDTest(*this))
|
||||
return true;
|
||||
if (hasInvalidConstraint(*this))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Eliminates a single identifier at 'position' from equality and inequality
|
||||
// constraints. Returns 'true' if the identifier was eliminated.
|
||||
// Returns 'false' otherwise.
|
||||
bool FlatAffineConstraints::eliminateIdentifier(unsigned position) {
|
||||
return eliminateIdentifiers(position, position + 1) == 1;
|
||||
}
|
||||
|
||||
// Eliminates all identifer variables in column range [posStart, posLimit).
|
||||
// Returns the number of variables eliminated.
|
||||
unsigned FlatAffineConstraints::eliminateIdentifiers(unsigned posStart,
|
||||
unsigned posLimit) {
|
||||
// Return if identifier positions to eliminate are out of range.
|
||||
if (posStart >= posLimit || posLimit > numIds)
|
||||
return 0;
|
||||
unsigned pivotCol = 0;
|
||||
for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) {
|
||||
// Find a row which has a non-zero coefficient in column 'j'.
|
||||
unsigned pivotRow;
|
||||
if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/true,
|
||||
pivotRow)) {
|
||||
// No pivot row in equalities with non-zero at 'pivotCol'.
|
||||
if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/false,
|
||||
pivotRow)) {
|
||||
// If inequalities are also non-zero in 'pivotCol' it can be eliminated.
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// Eliminate identifier at 'pivotCol' from each equality row.
|
||||
for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
|
||||
eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
|
||||
/*isEq=*/true);
|
||||
normalizeConstraintByGCD(this, i, /*isEq=*/true);
|
||||
}
|
||||
|
||||
// Eliminate identifier at 'pivotCol' from each inequality row.
|
||||
for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
|
||||
eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
|
||||
/*isEq=*/false);
|
||||
normalizeConstraintByGCD(this, i, /*isEq=*/false);
|
||||
}
|
||||
removeEquality(pivotRow);
|
||||
}
|
||||
// Update position limit based on number eliminated.
|
||||
posLimit = pivotCol;
|
||||
// Remove eliminated columns from all constraints.
|
||||
removeColumnRange(posStart, posLimit);
|
||||
return posLimit - posStart;
|
||||
}
|
||||
|
||||
void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) {
|
||||
assert(eq.size() == getNumCols());
|
||||
unsigned offset = equalities.size();
|
||||
|
@ -446,3 +727,44 @@ void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) {
|
|||
equalities[offset + i] = eq[i];
|
||||
}
|
||||
}
|
||||
|
||||
void FlatAffineConstraints::removeEquality(unsigned pos) {
|
||||
unsigned numEqualities = getNumEqualities();
|
||||
assert(pos < numEqualities);
|
||||
unsigned numCols = getNumCols();
|
||||
unsigned outputIndex = pos * numCols;
|
||||
unsigned inputIndex = (pos + 1) * numCols;
|
||||
unsigned numElemsToCopy = (numEqualities - pos - 1) * numCols;
|
||||
for (unsigned i = 0; i < numElemsToCopy; ++i) {
|
||||
equalities[outputIndex + i] = equalities[inputIndex + i];
|
||||
}
|
||||
equalities.resize(equalities.size() - numCols);
|
||||
}
|
||||
|
||||
void FlatAffineConstraints::addInequality(ArrayRef<int64_t> inEq) {
|
||||
assert(inEq.size() == getNumCols());
|
||||
unsigned offset = inequalities.size();
|
||||
inequalities.resize(inequalities.size() + inEq.size());
|
||||
for (unsigned i = 0, e = inEq.size(); i < e; i++) {
|
||||
inequalities[offset + i] = inEq[i];
|
||||
}
|
||||
}
|
||||
|
||||
void FlatAffineConstraints::print(raw_ostream &os) const {
|
||||
os << "\nConstraints:\n";
|
||||
for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
|
||||
for (unsigned j = 0; j < getNumCols(); ++j) {
|
||||
os << atEq(i, j) << " ";
|
||||
}
|
||||
os << "= 0\n";
|
||||
}
|
||||
for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
|
||||
for (unsigned j = 0; j < getNumCols(); ++j) {
|
||||
os << atIneq(i, j) << " ";
|
||||
}
|
||||
os << ">= 0\n";
|
||||
}
|
||||
os << '\n';
|
||||
}
|
||||
|
||||
void FlatAffineConstraints::dump() const { print(llvm::errs()); }
|
||||
|
|
|
@ -29,6 +29,18 @@ unsigned IntegerSet::getNumOperands() const {
|
|||
}
|
||||
unsigned IntegerSet::getNumConstraints() const { return set->numConstraints; }
|
||||
|
||||
unsigned IntegerSet::getNumEqualities() const {
|
||||
unsigned numEqualities = 0;
|
||||
for (unsigned i = 0, e = getNumConstraints(); i < e; i++)
|
||||
if (isEq(i))
|
||||
++numEqualities;
|
||||
return numEqualities;
|
||||
}
|
||||
|
||||
unsigned IntegerSet::getNumInequalities() const {
|
||||
return getNumConstraints() - getNumEqualities();
|
||||
}
|
||||
|
||||
ArrayRef<AffineExpr> IntegerSet::getConstraints() const {
|
||||
return set->constraints;
|
||||
}
|
||||
|
@ -44,3 +56,7 @@ ArrayRef<bool> IntegerSet::getEqFlags() const { return set->eqFlags; }
|
|||
/// Returns true if the idx^th constraint is an equality, false if it is an
|
||||
/// inequality.
|
||||
bool IntegerSet::isEq(unsigned idx) const { return getEqFlags()[idx]; }
|
||||
|
||||
MLIRContext *IntegerSet::getContext() const {
|
||||
return getConstraint(0).getContext();
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/MLFunction.h"
|
||||
#include "mlir/IR/Statements.h"
|
||||
#include "mlir/IR/StmtVisitor.h"
|
||||
#include "mlir/Transforms/Pass.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
|
@ -36,32 +37,51 @@ namespace {
|
|||
/// the MLFunction. This is mainly to test the simplifyAffineExpr method.
|
||||
// TODO(someone): Gradually, extend this to all affine map references found in
|
||||
// ML functions and CFG functions.
|
||||
struct SimplifyAffineExpr : public FunctionPass {
|
||||
explicit SimplifyAffineExpr() {}
|
||||
struct SimplifyAffineStructures : public FunctionPass,
|
||||
StmtWalker<SimplifyAffineStructures> {
|
||||
explicit SimplifyAffineStructures() {}
|
||||
|
||||
PassResult runOnMLFunction(MLFunction *f);
|
||||
// Does nothing on CFG functions for now. No reusable walkers/visitors exist
|
||||
// for this yet? TODO(someone).
|
||||
PassResult runOnCFGFunction(CFGFunction *f) { return success(); }
|
||||
|
||||
void visitOperationStmt(OperationStmt *stmt);
|
||||
void visitIfStmt(IfStmt *ifStmt);
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
FunctionPass *mlir::createSimplifyAffineExprPass() {
|
||||
return new SimplifyAffineExpr();
|
||||
return new SimplifyAffineStructures();
|
||||
}
|
||||
|
||||
PassResult SimplifyAffineExpr::runOnMLFunction(MLFunction *f) {
|
||||
f->walkPostOrder([&](OperationStmt *opStmt) {
|
||||
for (auto attr : opStmt->getAttrs()) {
|
||||
if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr.second)) {
|
||||
MutableAffineMap mMap(mapAttr->getValue());
|
||||
mMap.simplify();
|
||||
auto map = mMap.getAffineMap();
|
||||
opStmt->setAttr(attr.first, AffineMapAttr::get(map));
|
||||
}
|
||||
}
|
||||
});
|
||||
static IntegerSet simplifyIntegerSet(IntegerSet set) {
|
||||
FlatAffineConstraints fac(set);
|
||||
if (fac.isEmpty())
|
||||
return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
|
||||
set.getContext());
|
||||
return set;
|
||||
}
|
||||
|
||||
void SimplifyAffineStructures::visitIfStmt(IfStmt *ifStmt) {
|
||||
auto set = ifStmt->getCondition().getSet();
|
||||
IntegerSet simplified = simplifyIntegerSet(set);
|
||||
ifStmt->setIntegerSet(simplified);
|
||||
}
|
||||
|
||||
void SimplifyAffineStructures::visitOperationStmt(OperationStmt *opStmt) {
|
||||
for (auto attr : opStmt->getAttrs()) {
|
||||
if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr.second)) {
|
||||
MutableAffineMap mMap(mapAttr->getValue());
|
||||
mMap.simplify();
|
||||
auto map = mMap.getAffineMap();
|
||||
opStmt->setAttr(attr.first, AffineMapAttr::get(map));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PassResult SimplifyAffineStructures::runOnMLFunction(MLFunction *f) {
|
||||
walk(f);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -22,6 +22,38 @@
|
|||
// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0 - (d0 floordiv 8) * 8, (d1 floordiv 8) * 8)
|
||||
#map6 = (d0, d1) -> (d0 mod 8, d1 - d1 mod 8)
|
||||
|
||||
// Set for test case: test_gaussian_elimination_empty_set0
|
||||
// CHECK: @@set0 = (d0, d1) : (1 == 0)
|
||||
@@set0 = (d0, d1) : (2 == 0)
|
||||
|
||||
// Set for test case: test_gaussian_elimination_empty_set1
|
||||
// CHECK: @@set1 = (d0, d1) : (1 == 0)
|
||||
@@set1 = (d0, d1) : (1 >= 0, -1 >= 0)
|
||||
|
||||
// Set for test case: test_gaussian_elimination_non_empty_set2
|
||||
// CHECK: @@set2 = (d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, d0 * -1 + 100 >= 0, d1 >= 0, d1 + 101 >= 0)
|
||||
@@set2 = (d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, -d0 + 100 >= 0, d1 >= 0, d1 + 101 >= 0)
|
||||
|
||||
// Set for test case: test_gaussian_elimination_empty_set3
|
||||
// CHECK: @@set3 = (d0, d1)[s0, s1] : (1 == 0)
|
||||
@@set3 = (d0, d1)[s0, s1] : (d0 - s0 == 0, d0 + s0 == 0, s0 - 1 == 0)
|
||||
|
||||
// Set for test case: test_gaussian_elimination_non_empty_set4
|
||||
// CHECK: @@set4 = (d0, d1)[s0, s1] : (d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0, d0 * 5 - d1 * 11 + s0 * 7 + s1 == 0, d0 * 11 + d1 * 7 - s0 * 5 + s1 == 0, d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0)
|
||||
@@set4 = (d0, d1)[s0, s1] : (d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0,
|
||||
d0 * 5 - d1 * 11 + s0 * 7 + s1 == 0,
|
||||
d0 * 11 + d1 * 7 - s0 * 5 + s1 == 0,
|
||||
d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0)
|
||||
|
||||
// Add invalide constraints to previous non-empty set to make it empty.
|
||||
// Set for test case: test_gaussian_elimination_empty_set5
|
||||
// CHECK: @@set5 = (d0, d1)[s0, s1] : (1 == 0)
|
||||
@@set5 = (d0, d1)[s0, s1] : (d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0,
|
||||
d0 * 5 - d1 * 11 + s0 * 7 + s1 == 0,
|
||||
d0 * 11 + d1 * 7 - s0 * 5 + s1 == 0,
|
||||
d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0,
|
||||
d0 - 1 == 0, d0 + 2 == 0)
|
||||
|
||||
mlfunc @test() {
|
||||
for %n0 = 0 to 127 {
|
||||
for %n1 = 0 to 7 {
|
||||
|
@ -37,3 +69,80 @@ mlfunc @test() {
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @test_gaussian_elimination_empty_set0() {
|
||||
mlfunc @test_gaussian_elimination_empty_set0() {
|
||||
for %i0 = 1 to 10 {
|
||||
for %i1 = 1 to 100 {
|
||||
// CHECK: @@set0(%i0, %i1)
|
||||
if @@set0(%i0, %i1) {
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @test_gaussian_elimination_empty_set1() {
|
||||
mlfunc @test_gaussian_elimination_empty_set1() {
|
||||
for %i0 = 1 to 10 {
|
||||
for %i1 = 1 to 100 {
|
||||
// CHECK: @@set1(%i0, %i1)
|
||||
if @@set1(%i0, %i1) {
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @test_gaussian_elimination_non_empty_set2() {
|
||||
mlfunc @test_gaussian_elimination_non_empty_set2() {
|
||||
for %i0 = 1 to 10 {
|
||||
for %i1 = 1 to 100 {
|
||||
// CHECK: @@set2(%i0, %i1)
|
||||
if @@set2(%i0, %i1) {
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @test_gaussian_elimination_empty_set3() {
|
||||
mlfunc @test_gaussian_elimination_empty_set3() {
|
||||
%c7 = constant 7 : index
|
||||
%c11 = constant 11 : index
|
||||
for %i0 = 1 to 10 {
|
||||
for %i1 = 1 to 100 {
|
||||
// CHECK: @@set3(%i0, %i1)[%c7, %c11]
|
||||
if @@set3(%i0, %i1)[%c7, %c11] {
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @test_gaussian_elimination_non_empty_set4() {
|
||||
mlfunc @test_gaussian_elimination_non_empty_set4() {
|
||||
%c7 = constant 7 : index
|
||||
%c11 = constant 11 : index
|
||||
for %i0 = 1 to 10 {
|
||||
for %i1 = 1 to 100 {
|
||||
// CHECK: @@set4(%i0, %i1)[%c7, %c11]
|
||||
if @@set4(%i0, %i1)[%c7, %c11] {
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @test_gaussian_elimination_empty_set5() {
|
||||
mlfunc @test_gaussian_elimination_empty_set5() {
|
||||
%c7 = constant 7 : index
|
||||
%c11 = constant 11 : index
|
||||
for %i0 = 1 to 10 {
|
||||
for %i1 = 1 to 100 {
|
||||
// CHECK: @@set5(%i0, %i1)[%c7, %c11]
|
||||
if @@set5(%i0, %i1)[%c7, %c11] {
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
Loading…
Reference in New Issue