forked from OSchip/llvm-project
[mlir][Analysis][NFC] Reimplement FlatAffineConstraints::addLowerOrUpperBound
Reimplement this function in terms of the function variant without Value semantics. Differential Revision: https://reviews.llvm.org/D107729
This commit is contained in:
parent
389dc94d4b
commit
98e30a9b47
|
@ -527,15 +527,6 @@ LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
|
|||
return success();
|
||||
}
|
||||
|
||||
// Turn a dimension into a symbol.
|
||||
static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value id) {
|
||||
unsigned pos;
|
||||
if (cst->findId(id, &pos) && pos < cst->getNumDimIds()) {
|
||||
cst->swapId(pos, cst->getNumDimIds() - 1);
|
||||
cst->setDimSymbolSeparation(cst->getNumSymbolIds() + 1);
|
||||
}
|
||||
}
|
||||
|
||||
// Turn a symbol into a dimension.
|
||||
static void turnSymbolIntoDim(FlatAffineConstraints *cst, Value id) {
|
||||
unsigned pos;
|
||||
|
@ -2014,13 +2005,6 @@ LogicalResult
|
|||
FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
|
||||
ValueRange boundOperands, bool eq,
|
||||
bool lower) {
|
||||
assert(pos < getNumDimAndSymbolIds() && "invalid position");
|
||||
// Equality follows the logic of lower bound except that we add an equality
|
||||
// instead of an inequality.
|
||||
assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
|
||||
if (eq)
|
||||
lower = true;
|
||||
|
||||
// Fully compose map and operands; canonicalize and simplify so that we
|
||||
// transitively get to terminal symbols or loop IVs.
|
||||
auto map = boundMap;
|
||||
|
@ -2031,70 +2015,29 @@ FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
|
|||
for (auto operand : operands)
|
||||
addInductionVarOrTerminalSymbol(operand);
|
||||
|
||||
FlatAffineConstraints localVarCst;
|
||||
std::vector<SmallVector<int64_t, 8>> flatExprs;
|
||||
if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst))) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n");
|
||||
return failure();
|
||||
}
|
||||
SmallVector<Value> dims, syms;
|
||||
#ifndef NDEBUG
|
||||
SmallVector<Value> newSyms;
|
||||
SmallVector<Value> *newSymsPtr = &newSyms;
|
||||
#else
|
||||
SmallVector<Value> *newSymsPtr = nullptr;
|
||||
#endif // NDEBUG
|
||||
|
||||
// Merge and align with localVarCst.
|
||||
if (localVarCst.getNumLocalIds() > 0) {
|
||||
// Set values for localVarCst.
|
||||
localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), operands);
|
||||
for (auto operand : operands) {
|
||||
unsigned pos;
|
||||
if (findId(operand, &pos)) {
|
||||
if (pos >= getNumDimIds() && pos < getNumDimAndSymbolIds()) {
|
||||
// If the local var cst has this as a dim, turn it into its symbol.
|
||||
turnDimIntoSymbol(&localVarCst, operand);
|
||||
} else if (pos < getNumDimIds()) {
|
||||
// Or vice versa.
|
||||
turnSymbolIntoDim(&localVarCst, operand);
|
||||
}
|
||||
}
|
||||
}
|
||||
mergeAndAlignIds(/*offset=*/0, this, &localVarCst);
|
||||
append(localVarCst);
|
||||
}
|
||||
dims.reserve(numDims);
|
||||
syms.reserve(numSymbols);
|
||||
for (unsigned i = 0; i < numDims; ++i)
|
||||
dims.push_back(ids[i] ? *ids[i] : Value());
|
||||
for (unsigned i = numDims, e = numDims + numSymbols; i < e; ++i)
|
||||
syms.push_back(ids[i] ? *ids[i] : Value());
|
||||
|
||||
// Record positions of the operands in the constraint system. Need to do
|
||||
// this here since the constraint system changes after a bound is added.
|
||||
SmallVector<unsigned, 8> positions;
|
||||
unsigned numOperands = operands.size();
|
||||
for (auto operand : operands) {
|
||||
unsigned pos;
|
||||
if (!findId(operand, &pos))
|
||||
assert(0 && "expected to be found");
|
||||
positions.push_back(pos);
|
||||
}
|
||||
AffineMap alignedMap =
|
||||
alignAffineMapWithValues(map, operands, dims, syms, newSymsPtr);
|
||||
// All symbols are already part of this FlatAffineConstraints.
|
||||
assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols");
|
||||
assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) &&
|
||||
"unexpected new/missing symbols");
|
||||
|
||||
for (const auto &flatExpr : flatExprs) {
|
||||
// Invalid bound: pos appears among the operands.
|
||||
if (llvm::find(positions, pos) != positions.end())
|
||||
continue;
|
||||
|
||||
SmallVector<int64_t, 4> ineq(getNumCols(), 0);
|
||||
ineq[pos] = lower ? 1 : -1;
|
||||
// Dims and symbols.
|
||||
for (unsigned j = 0, e = map.getNumInputs(); j < e; j++) {
|
||||
ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j];
|
||||
}
|
||||
// Copy over the local id coefficients.
|
||||
unsigned numLocalIds = flatExpr.size() - 1 - numOperands;
|
||||
for (unsigned jj = 0, j = getNumIds() - numLocalIds; jj < numLocalIds;
|
||||
jj++, j++) {
|
||||
ineq[j] =
|
||||
lower ? -flatExpr[numOperands + jj] : flatExpr[numOperands + jj];
|
||||
}
|
||||
// Constant term.
|
||||
ineq[getNumCols() - 1] =
|
||||
lower ? -flatExpr[flatExpr.size() - 1]
|
||||
// Upper bound in flattenedExpr is an exclusive one.
|
||||
: flatExpr[flatExpr.size() - 1] - 1;
|
||||
eq ? addEquality(ineq) : addInequality(ineq);
|
||||
}
|
||||
return success();
|
||||
return addLowerOrUpperBound(pos, alignedMap, eq, lower);
|
||||
}
|
||||
|
||||
// Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
|
||||
|
|
Loading…
Reference in New Issue