forked from OSchip/llvm-project
[mlir][Analysis][NFC] Remove code duplication around getFlattenedAffineExprs
Remove code duplication in `addLowerOrUpperBound` and `composeMatchingMap`. Differential Revision: https://reviews.llvm.org/D107814
This commit is contained in:
parent
9e6e08149c
commit
4b56e2ee1d
|
@ -602,6 +602,19 @@ private:
|
||||||
/// must already have a corresponding dim/symbol in this constraint system.
|
/// must already have a corresponding dim/symbol in this constraint system.
|
||||||
AffineMap computeAlignedMap(AffineMap map, ValueRange operands) const;
|
AffineMap computeAlignedMap(AffineMap map, ValueRange operands) const;
|
||||||
|
|
||||||
|
/// Given an affine map that is aligned with this constraint system:
|
||||||
|
/// * Flatten the map.
|
||||||
|
/// * Add newly introduced local columns at the beginning of this constraint
|
||||||
|
/// system (local column pos 0).
|
||||||
|
/// * Add equalities that define the new local columns to this constraint
|
||||||
|
/// system.
|
||||||
|
/// * Return the flattened expressions via `flattenedExprs`.
|
||||||
|
///
|
||||||
|
/// Note: This is a shared helper function of `addLowerOrUpperBound` and
|
||||||
|
/// `composeMatchingMap`.
|
||||||
|
LogicalResult flattenAlignedMapAndMergeLocals(
|
||||||
|
AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs);
|
||||||
|
|
||||||
// Eliminates a single identifier at 'position' from equality and inequality
|
// Eliminates a single identifier at 'position' from equality and inequality
|
||||||
// constraints. Returns 'success' if the identifier was eliminated, and
|
// constraints. Returns 'success' if the identifier was eliminated, and
|
||||||
// 'failure' otherwise.
|
// 'failure' otherwise.
|
||||||
|
|
|
@ -400,28 +400,10 @@ LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
|
||||||
assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
|
assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
|
||||||
|
|
||||||
std::vector<SmallVector<int64_t, 8>> flatExprs;
|
std::vector<SmallVector<int64_t, 8>> flatExprs;
|
||||||
FlatAffineConstraints localCst;
|
if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs)))
|
||||||
if (failed(getFlattenedAffineExprs(other, &flatExprs, &localCst))) {
|
|
||||||
LLVM_DEBUG(llvm::dbgs()
|
|
||||||
<< "composition unimplemented for semi-affine maps\n");
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
|
||||||
assert(flatExprs.size() == other.getNumResults());
|
assert(flatExprs.size() == other.getNumResults());
|
||||||
|
|
||||||
// Add localCst information.
|
|
||||||
if (localCst.getNumLocalIds() > 0) {
|
|
||||||
unsigned numLocalIds = getNumLocalIds();
|
|
||||||
// Insert local dims of localCst at the beginning.
|
|
||||||
for (unsigned l = 0, e = localCst.getNumLocalIds(); l < e; ++l)
|
|
||||||
addLocalId(0);
|
|
||||||
// Insert local dims of `this` at the end of localCst.
|
|
||||||
for (unsigned l = 0; l < numLocalIds; ++l)
|
|
||||||
localCst.addLocalId(localCst.getNumLocalIds());
|
|
||||||
// Dimensions of localCst and this constraint set match. Append localCst to
|
|
||||||
// this constraint set.
|
|
||||||
append(localCst);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add dimensions corresponding to the map's results.
|
// Add dimensions corresponding to the map's results.
|
||||||
for (unsigned t = 0, e = other.getNumResults(); t < e; t++) {
|
for (unsigned t = 0, e = other.getNumResults(); t < e; t++) {
|
||||||
addDimId(0);
|
addDimId(0);
|
||||||
|
@ -429,25 +411,24 @@ LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
|
||||||
|
|
||||||
// We add one equality for each result connecting the result dim of the map to
|
// We add one equality for each result connecting the result dim of the map to
|
||||||
// the other identifiers.
|
// the other identifiers.
|
||||||
// For eg: if the expression is 16*i0 + i1, and this is the r^th
|
// E.g.: if the expression is 16*i0 + i1, and this is the r^th
|
||||||
// iteration/result of the value map, we are adding the equality:
|
// iteration/result of the value map, we are adding the equality:
|
||||||
// d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we
|
// d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we
|
||||||
// add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
|
// add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
|
||||||
for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
|
for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
|
||||||
const auto &flatExpr = flatExprs[r];
|
const auto &flatExpr = flatExprs[r];
|
||||||
assert(flatExpr.size() >= other.getNumInputs() + 1);
|
assert(flatExpr.size() >= other.getNumInputs() + 1);
|
||||||
|
|
||||||
// eqToAdd is the equality corresponding to the flattened affine expression.
|
|
||||||
SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
|
SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
|
||||||
// Set the coefficient for this result to one.
|
// Set the coefficient for this result to one.
|
||||||
eqToAdd[r] = 1;
|
eqToAdd[r] = 1;
|
||||||
|
|
||||||
// Dims and symbols.
|
// Dims and symbols.
|
||||||
for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
|
for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
|
||||||
// Negate 'eq[r]' since the newly added dimension will be set to this one.
|
// Negate `eq[r]` since the newly added dimension will be set to this one.
|
||||||
eqToAdd[e + i] = -flatExpr[i];
|
eqToAdd[e + i] = -flatExpr[i];
|
||||||
}
|
}
|
||||||
// Local vars common to eq and localCst are at the beginning.
|
// Local columns of `eq` are at the beginning.
|
||||||
unsigned j = getNumDimIds() + getNumSymbolIds();
|
unsigned j = getNumDimIds() + getNumSymbolIds();
|
||||||
unsigned end = flatExpr.size() - 1;
|
unsigned end = flatExpr.size() - 1;
|
||||||
for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
|
for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
|
||||||
|
@ -1872,27 +1853,14 @@ void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
|
LogicalResult FlatAffineConstraints::flattenAlignedMapAndMergeLocals(
|
||||||
AffineMap boundMap,
|
AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs) {
|
||||||
bool eq, bool lower) {
|
|
||||||
assert(boundMap.getNumDims() == getNumDimIds() && "dim mismatch");
|
|
||||||
assert(boundMap.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
|
|
||||||
assert(pos < getNumDimAndSymbolIds() && "invalid position");
|
|
||||||
|
|
||||||
// Equality follows the logic of lower bound except that we add an equality
|
|
||||||
// instead of an inequality.
|
|
||||||
assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
|
|
||||||
if (eq)
|
|
||||||
lower = true;
|
|
||||||
|
|
||||||
std::vector<SmallVector<int64_t, 8>> flatExprs;
|
|
||||||
FlatAffineConstraints localCst;
|
FlatAffineConstraints localCst;
|
||||||
if (failed(getFlattenedAffineExprs(boundMap, &flatExprs, &localCst))) {
|
if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst))) {
|
||||||
LLVM_DEBUG(llvm::dbgs()
|
LLVM_DEBUG(llvm::dbgs()
|
||||||
<< "composition unimplemented for semi-affine maps\n");
|
<< "composition unimplemented for semi-affine maps\n");
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
assert(flatExprs.size() == boundMap.getNumResults());
|
|
||||||
|
|
||||||
// Add localCst information.
|
// Add localCst information.
|
||||||
if (localCst.getNumLocalIds() > 0) {
|
if (localCst.getNumLocalIds() > 0) {
|
||||||
|
@ -1908,6 +1876,27 @@ LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
|
||||||
append(localCst);
|
append(localCst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
|
||||||
|
AffineMap boundMap,
|
||||||
|
bool eq, bool lower) {
|
||||||
|
assert(boundMap.getNumDims() == getNumDimIds() && "dim mismatch");
|
||||||
|
assert(boundMap.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
|
||||||
|
assert(pos < getNumDimAndSymbolIds() && "invalid position");
|
||||||
|
|
||||||
|
// Equality follows the logic of lower bound except that we add an equality
|
||||||
|
// instead of an inequality.
|
||||||
|
assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
|
||||||
|
if (eq)
|
||||||
|
lower = true;
|
||||||
|
|
||||||
|
std::vector<SmallVector<int64_t, 8>> flatExprs;
|
||||||
|
if (failed(flattenAlignedMapAndMergeLocals(boundMap, &flatExprs)))
|
||||||
|
return failure();
|
||||||
|
assert(flatExprs.size() == boundMap.getNumResults());
|
||||||
|
|
||||||
// Add one (in)equality for each result.
|
// Add one (in)equality for each result.
|
||||||
for (const auto &flatExpr : flatExprs) {
|
for (const auto &flatExpr : flatExprs) {
|
||||||
SmallVector<int64_t> ineq(getNumCols(), 0);
|
SmallVector<int64_t> ineq(getNumCols(), 0);
|
||||||
|
@ -1921,7 +1910,7 @@ LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
|
||||||
if (ineq[pos] != 0)
|
if (ineq[pos] != 0)
|
||||||
continue;
|
continue;
|
||||||
ineq[pos] = lower ? 1 : -1;
|
ineq[pos] = lower ? 1 : -1;
|
||||||
// Local vars common to eq and localCst are at the beginning.
|
// Local columns of `ineq` are at the beginning.
|
||||||
unsigned j = getNumDimIds() + getNumSymbolIds();
|
unsigned j = getNumDimIds() + getNumSymbolIds();
|
||||||
unsigned end = flatExpr.size() - 1;
|
unsigned end = flatExpr.size() - 1;
|
||||||
for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) {
|
for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) {
|
||||||
|
|
Loading…
Reference in New Issue