From 97e41c004ca0e1c9b969daa02a587e0d41166383 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 11 Aug 2021 15:08:02 +0900 Subject: [PATCH] [mlir][Analysis] Add FlatAffineConstraints::addLowerOrUpperBound This function overload is similar to the existing `FlatAffineConstraints::addLowerOrUpperBound`. It constrains a dimension based on an affine map. However, in contrast to the other overloading, it does not attempt to align dimensions/symbols of the affine map with the dimensions/symbols of the constraint set. Instead, dimensions/symbols are expected to already be aligned. Differential Revision: https://reviews.llvm.org/D107727 --- mlir/include/mlir/Analysis/AffineStructures.h | 9 +++ mlir/lib/Analysis/AffineStructures.cpp | 66 +++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index f173c1b3eb60..f0535cb5c36f 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -234,6 +234,15 @@ public: /// the columns in the current one regarding numbers and values. void addAffineIfOpDomain(AffineIfOp ifOp); + /// Adds a lower or an upper bound for the identifier at the specified + /// position with constraints being drawn from the specified bound map. If + /// `eq` is true, add a single equality equal to the bound map's first result + /// expr. + /// Note: The dimensions/symbols of this FlatAffineConstraints must match the + /// dimensions/symbols of the affine map. + LogicalResult addLowerOrUpperBound(unsigned pos, AffineMap boundMap, bool eq, + bool lower = true); + /// Adds a lower or an upper bound for the identifier at the specified /// position with constraints being drawn from the specified bound map and /// operands. If `eq` is true, add a single equality equal to the bound map's diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index c9000f6d6400..92952a42d2dc 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1944,6 +1944,72 @@ void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num, } } +LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, + AffineMap boundMap, + bool eq, bool lower) { + assert(boundMap.getNumDims() == getNumDimIds() && "dim mismatch"); + assert(boundMap.getNumSymbols() == getNumSymbolIds() && "symbol mismatch"); + assert(pos < getNumDimAndSymbolIds() && "invalid position"); + + // Equality follows the logic of lower bound except that we add an equality + // instead of an inequality. + assert((!eq || boundMap.getNumResults() == 1) && "single result expected"); + if (eq) + lower = true; + + std::vector> flatExprs; + FlatAffineConstraints localCst; + if (failed(getFlattenedAffineExprs(boundMap, &flatExprs, &localCst))) { + LLVM_DEBUG(llvm::dbgs() + << "composition unimplemented for semi-affine maps\n"); + return failure(); + } + assert(flatExprs.size() == boundMap.getNumResults()); + + // Add localCst information. + if (localCst.getNumLocalIds() > 0) { + unsigned numLocalIds = getNumLocalIds(); + // Insert local dims of localCst at the beginning. + for (unsigned l = 0, e = localCst.getNumLocalIds(); l < e; ++l) + addLocalId(0); + // Insert local dims of `this` at the end of localCst. + for (unsigned l = 0; l < numLocalIds; ++l) + localCst.addLocalId(localCst.getNumLocalIds()); + // Dimensions of localCst and this constraint set match. Append localCst to + // this constraint set. + append(localCst); + } + + // Add one (in)equality for each result. + for (const auto &flatExpr : flatExprs) { + SmallVector ineq(getNumCols(), 0); + // Dims and symbols. + for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) { + ineq[j] = lower ? -flatExpr[j] : flatExpr[j]; + } + // Invalid bound: pos appears in `boundMap`. + // TODO: This should be an assertion. Fix `addDomainFromSliceMaps` and/or + // its callers to prevent invalid bounds from being added. + if (ineq[pos] != 0) + continue; + ineq[pos] = lower ? 1 : -1; + // Local vars common to eq and localCst are at the beginning. + unsigned j = getNumDimIds() + getNumSymbolIds(); + unsigned end = flatExpr.size() - 1; + for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) { + ineq[j] = lower ? -flatExpr[i] : flatExpr[i]; + } + // Constant term. + ineq[getNumCols() - 1] = + lower ? -flatExpr[flatExpr.size() - 1] + // Upper bound in flattenedExpr is an exclusive one. + : flatExpr[flatExpr.size() - 1] - 1; + eq ? addEquality(ineq) : addInequality(ineq); + } + + return success(); +} + LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap, ValueRange boundOperands, bool eq,