From 6607bd9fd819de1a5872dce47ce1a67bbb9a12e8 Mon Sep 17 00:00:00 2001 From: Arjun P Date: Fri, 17 Sep 2021 00:52:20 +0530 Subject: [PATCH] [MLIR] AffineStructures::removeIdRange: support specifying a range within an IdKind Reviewed By: Groverkss, grosser Differential Revision: https://reviews.llvm.org/D109896 --- mlir/include/mlir/Analysis/AffineStructures.h | 12 +++++ mlir/lib/Analysis/AffineStructures.cpp | 51 ++++++++++++++----- .../Analysis/AffineStructuresTest.cpp | 18 +++++++ 3 files changed, 67 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index c2dbc5c89da7..b5676186d87d 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -282,6 +282,12 @@ public: void projectOut(unsigned pos, unsigned num); inline void projectOut(unsigned pos) { return projectOut(pos, 1); } + /// Removes identifiers of the specified kind with the specified pos (or + /// within the specified range) from the system. The specified location is + /// relative to the first identifier of the specified kind. + void removeId(IdKind kind, unsigned pos); + void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit); + /// Removes the specified identifier from the system. void removeId(unsigned pos); @@ -423,6 +429,12 @@ public: void dump() const; protected: + /// Return the index at which the specified kind of id starts. + unsigned getIdKindOffset(IdKind kind) const; + + /// Assert that `value` is at most the number of ids of the specified kind. + void assertAtMostNumIdKind(unsigned value, IdKind kind) const; + /// Returns false if the fields corresponding to various identifier counts, or /// equality/inequality buffer sizes aren't consistent; true otherwise. This /// is meant to be used within an assert internally. diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 43de94407db2..4453e8531884 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -311,23 +311,13 @@ unsigned FlatAffineConstraints::insertLocalId(unsigned pos, unsigned num) { unsigned FlatAffineConstraints::insertId(IdKind kind, unsigned pos, unsigned num) { - if (kind == IdKind::Dimension) - assert(pos <= getNumDimIds()); - else if (kind == IdKind::Symbol) - assert(pos <= getNumSymbolIds()); - else - assert(pos <= getNumLocalIds()); + assertAtMostNumIdKind(pos, kind); - unsigned absolutePos; - if (kind == IdKind::Dimension) { - absolutePos = pos; + unsigned absolutePos = getIdKindOffset(kind) + pos; + if (kind == IdKind::Dimension) numDims += num; - } else if (kind == IdKind::Symbol) { - absolutePos = pos + getNumDimIds(); + else if (kind == IdKind::Symbol) numSymbols += num; - } else { - absolutePos = pos + getNumDimIds() + getNumSymbolIds(); - } numIds += num; inequalities.insertColumns(absolutePos, num); @@ -336,6 +326,28 @@ unsigned FlatAffineConstraints::insertId(IdKind kind, unsigned pos, return absolutePos; } +void FlatAffineConstraints::assertAtMostNumIdKind(unsigned val, + IdKind kind) const { + if (kind == IdKind::Dimension) + assert(val <= getNumDimIds()); + else if (kind == IdKind::Symbol) + assert(val <= getNumSymbolIds()); + else if (kind == IdKind::Local) + assert(val <= getNumLocalIds()); + else + llvm_unreachable("IdKind expected to be Dimension, Symbol or Local!"); +} + +unsigned FlatAffineConstraints::getIdKindOffset(IdKind kind) const { + if (kind == IdKind::Dimension) + return 0; + if (kind == IdKind::Symbol) + return getNumDimIds(); + if (kind == IdKind::Local) + return getNumDimAndSymbolIds(); + llvm_unreachable("IdKind expected to be Dimension, Symbol or Local!"); +} + unsigned FlatAffineValueConstraints::insertId(IdKind kind, unsigned pos, unsigned num) { unsigned absolutePos = FlatAffineConstraints::insertId(kind, pos, num); @@ -365,6 +377,17 @@ bool FlatAffineValueConstraints::hasValues() const { }) != values.end(); } +void FlatAffineConstraints::removeId(IdKind kind, unsigned pos) { + removeIdRange(kind, pos, pos + 1); +} + +void FlatAffineConstraints::removeIdRange(IdKind kind, unsigned idStart, + unsigned idLimit) { + assertAtMostNumIdKind(idLimit, kind); + removeIdRange(getIdKindOffset(kind) + idStart, + getIdKindOffset(kind) + idLimit); +} + /// Checks if two constraint systems are in the same space, i.e., if they are /// associated with the same set of identifiers, appearing in the same order. static bool areIdsAligned(const FlatAffineValueConstraints &a, diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp index 56f0dbdc950c..d5a88b684b9e 100644 --- a/mlir/unittests/Analysis/AffineStructuresTest.cpp +++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp @@ -711,4 +711,22 @@ TEST(FlatAffineConstraintsTest, computeLocalReprRecursive) { checkDivisionRepresentation(fac, divisions, denoms); } +TEST(FlatAffineConstraintsTest, removeIdRange) { + FlatAffineConstraints fac(3, 2, 1); + + fac.addInequality({10, 11, 12, 20, 21, 30, 40}); + fac.removeId(FlatAffineConstraints::IdKind::Symbol, 1); + EXPECT_THAT(fac.getInequality(0), + testing::ElementsAre(10, 11, 12, 20, 30, 40)); + + fac.removeIdRange(FlatAffineConstraints::IdKind::Dimension, 0, 2); + EXPECT_THAT(fac.getInequality(0), testing::ElementsAre(12, 20, 30, 40)); + + fac.removeIdRange(FlatAffineConstraints::IdKind::Local, 1, 1); + EXPECT_THAT(fac.getInequality(0), testing::ElementsAre(12, 20, 30, 40)); + + fac.removeIdRange(FlatAffineConstraints::IdKind::Local, 0, 1); + EXPECT_THAT(fac.getInequality(0), testing::ElementsAre(12, 20, 40)); +} + } // namespace mlir