[mlir][Analysis][NFC] Reimplement FlatAffineConstraints::addLowerOrUpperBound

Reimplement this function in terms of the function variant without Value semantics.

Differential Revision: https://reviews.llvm.org/D107729
This commit is contained in:
Matthias Springer 2021-08-11 15:15:30 +09:00
parent 389dc94d4b
commit 98e30a9b47
1 changed files with 20 additions and 77 deletions

View File

@ -527,15 +527,6 @@ LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
return success();
}
// Turn a dimension into a symbol.
static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value id) {
unsigned pos;
if (cst->findId(id, &pos) && pos < cst->getNumDimIds()) {
cst->swapId(pos, cst->getNumDimIds() - 1);
cst->setDimSymbolSeparation(cst->getNumSymbolIds() + 1);
}
}
// Turn a symbol into a dimension.
static void turnSymbolIntoDim(FlatAffineConstraints *cst, Value id) {
unsigned pos;
@ -2014,13 +2005,6 @@ LogicalResult
FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
ValueRange boundOperands, bool eq,
bool lower) {
assert(pos < getNumDimAndSymbolIds() && "invalid position");
// Equality follows the logic of lower bound except that we add an equality
// instead of an inequality.
assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
if (eq)
lower = true;
// Fully compose map and operands; canonicalize and simplify so that we
// transitively get to terminal symbols or loop IVs.
auto map = boundMap;
@ -2031,70 +2015,29 @@ FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
for (auto operand : operands)
addInductionVarOrTerminalSymbol(operand);
FlatAffineConstraints localVarCst;
std::vector<SmallVector<int64_t, 8>> flatExprs;
if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst))) {
LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n");
return failure();
}
SmallVector<Value> dims, syms;
#ifndef NDEBUG
SmallVector<Value> newSyms;
SmallVector<Value> *newSymsPtr = &newSyms;
#else
SmallVector<Value> *newSymsPtr = nullptr;
#endif // NDEBUG
// Merge and align with localVarCst.
if (localVarCst.getNumLocalIds() > 0) {
// Set values for localVarCst.
localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), operands);
for (auto operand : operands) {
unsigned pos;
if (findId(operand, &pos)) {
if (pos >= getNumDimIds() && pos < getNumDimAndSymbolIds()) {
// If the local var cst has this as a dim, turn it into its symbol.
turnDimIntoSymbol(&localVarCst, operand);
} else if (pos < getNumDimIds()) {
// Or vice versa.
turnSymbolIntoDim(&localVarCst, operand);
}
}
}
mergeAndAlignIds(/*offset=*/0, this, &localVarCst);
append(localVarCst);
}
dims.reserve(numDims);
syms.reserve(numSymbols);
for (unsigned i = 0; i < numDims; ++i)
dims.push_back(ids[i] ? *ids[i] : Value());
for (unsigned i = numDims, e = numDims + numSymbols; i < e; ++i)
syms.push_back(ids[i] ? *ids[i] : Value());
// Record positions of the operands in the constraint system. Need to do
// this here since the constraint system changes after a bound is added.
SmallVector<unsigned, 8> positions;
unsigned numOperands = operands.size();
for (auto operand : operands) {
unsigned pos;
if (!findId(operand, &pos))
assert(0 && "expected to be found");
positions.push_back(pos);
}
AffineMap alignedMap =
alignAffineMapWithValues(map, operands, dims, syms, newSymsPtr);
// All symbols are already part of this FlatAffineConstraints.
assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols");
assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) &&
"unexpected new/missing symbols");
for (const auto &flatExpr : flatExprs) {
// Invalid bound: pos appears among the operands.
if (llvm::find(positions, pos) != positions.end())
continue;
SmallVector<int64_t, 4> ineq(getNumCols(), 0);
ineq[pos] = lower ? 1 : -1;
// Dims and symbols.
for (unsigned j = 0, e = map.getNumInputs(); j < e; j++) {
ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j];
}
// Copy over the local id coefficients.
unsigned numLocalIds = flatExpr.size() - 1 - numOperands;
for (unsigned jj = 0, j = getNumIds() - numLocalIds; jj < numLocalIds;
jj++, j++) {
ineq[j] =
lower ? -flatExpr[numOperands + jj] : flatExpr[numOperands + jj];
}
// Constant term.
ineq[getNumCols() - 1] =
lower ? -flatExpr[flatExpr.size() - 1]
// Upper bound in flattenedExpr is an exclusive one.
: flatExpr[flatExpr.size() - 1] - 1;
eq ? addEquality(ineq) : addInequality(ineq);
}
return success();
return addLowerOrUpperBound(pos, alignedMap, eq, lower);
}
// Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper