[MLIR] AffineStructures::removeIdRange: support specifying a range within an IdKind

Reviewed By: Groverkss, grosser

Differential Revision: https://reviews.llvm.org/D109896
This commit is contained in:
Arjun P 2021-09-17 00:52:20 +05:30
parent f263ea1571
commit 6607bd9fd8
3 changed files with 67 additions and 14 deletions

View File

@ -282,6 +282,12 @@ public:
void projectOut(unsigned pos, unsigned num);
inline void projectOut(unsigned pos) { return projectOut(pos, 1); }
/// Removes identifiers of the specified kind with the specified pos (or
/// within the specified range) from the system. The specified location is
/// relative to the first identifier of the specified kind.
void removeId(IdKind kind, unsigned pos);
void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit);
/// Removes the specified identifier from the system.
void removeId(unsigned pos);
@ -423,6 +429,12 @@ public:
void dump() const;
protected:
/// Return the index at which the specified kind of id starts.
unsigned getIdKindOffset(IdKind kind) const;
/// Assert that `value` is at most the number of ids of the specified kind.
void assertAtMostNumIdKind(unsigned value, IdKind kind) const;
/// Returns false if the fields corresponding to various identifier counts, or
/// equality/inequality buffer sizes aren't consistent; true otherwise. This
/// is meant to be used within an assert internally.

View File

@ -311,23 +311,13 @@ unsigned FlatAffineConstraints::insertLocalId(unsigned pos, unsigned num) {
unsigned FlatAffineConstraints::insertId(IdKind kind, unsigned pos,
unsigned num) {
if (kind == IdKind::Dimension)
assert(pos <= getNumDimIds());
else if (kind == IdKind::Symbol)
assert(pos <= getNumSymbolIds());
else
assert(pos <= getNumLocalIds());
assertAtMostNumIdKind(pos, kind);
unsigned absolutePos;
if (kind == IdKind::Dimension) {
absolutePos = pos;
unsigned absolutePos = getIdKindOffset(kind) + pos;
if (kind == IdKind::Dimension)
numDims += num;
} else if (kind == IdKind::Symbol) {
absolutePos = pos + getNumDimIds();
else if (kind == IdKind::Symbol)
numSymbols += num;
} else {
absolutePos = pos + getNumDimIds() + getNumSymbolIds();
}
numIds += num;
inequalities.insertColumns(absolutePos, num);
@ -336,6 +326,28 @@ unsigned FlatAffineConstraints::insertId(IdKind kind, unsigned pos,
return absolutePos;
}
void FlatAffineConstraints::assertAtMostNumIdKind(unsigned val,
IdKind kind) const {
if (kind == IdKind::Dimension)
assert(val <= getNumDimIds());
else if (kind == IdKind::Symbol)
assert(val <= getNumSymbolIds());
else if (kind == IdKind::Local)
assert(val <= getNumLocalIds());
else
llvm_unreachable("IdKind expected to be Dimension, Symbol or Local!");
}
unsigned FlatAffineConstraints::getIdKindOffset(IdKind kind) const {
if (kind == IdKind::Dimension)
return 0;
if (kind == IdKind::Symbol)
return getNumDimIds();
if (kind == IdKind::Local)
return getNumDimAndSymbolIds();
llvm_unreachable("IdKind expected to be Dimension, Symbol or Local!");
}
unsigned FlatAffineValueConstraints::insertId(IdKind kind, unsigned pos,
unsigned num) {
unsigned absolutePos = FlatAffineConstraints::insertId(kind, pos, num);
@ -365,6 +377,17 @@ bool FlatAffineValueConstraints::hasValues() const {
}) != values.end();
}
void FlatAffineConstraints::removeId(IdKind kind, unsigned pos) {
removeIdRange(kind, pos, pos + 1);
}
void FlatAffineConstraints::removeIdRange(IdKind kind, unsigned idStart,
unsigned idLimit) {
assertAtMostNumIdKind(idLimit, kind);
removeIdRange(getIdKindOffset(kind) + idStart,
getIdKindOffset(kind) + idLimit);
}
/// Checks if two constraint systems are in the same space, i.e., if they are
/// associated with the same set of identifiers, appearing in the same order.
static bool areIdsAligned(const FlatAffineValueConstraints &a,

View File

@ -711,4 +711,22 @@ TEST(FlatAffineConstraintsTest, computeLocalReprRecursive) {
checkDivisionRepresentation(fac, divisions, denoms);
}
TEST(FlatAffineConstraintsTest, removeIdRange) {
FlatAffineConstraints fac(3, 2, 1);
fac.addInequality({10, 11, 12, 20, 21, 30, 40});
fac.removeId(FlatAffineConstraints::IdKind::Symbol, 1);
EXPECT_THAT(fac.getInequality(0),
testing::ElementsAre(10, 11, 12, 20, 30, 40));
fac.removeIdRange(FlatAffineConstraints::IdKind::Dimension, 0, 2);
EXPECT_THAT(fac.getInequality(0), testing::ElementsAre(12, 20, 30, 40));
fac.removeIdRange(FlatAffineConstraints::IdKind::Local, 1, 1);
EXPECT_THAT(fac.getInequality(0), testing::ElementsAre(12, 20, 30, 40));
fac.removeIdRange(FlatAffineConstraints::IdKind::Local, 0, 1);
EXPECT_THAT(fac.getInequality(0), testing::ElementsAre(12, 20, 40));
}
} // namespace mlir