forked from OSchip/llvm-project
Update addSliceBounds to deal with loops with floor's/mod's.
- This change only impacts the cost model for fusion, given the way addSliceBounds was being used. It so happens that the output in spite of this CL's fix is the same; however, the assertions added no longer fail. (an invalid/inconsistent memref region was being used earlier). PiperOrigin-RevId: 236405030
This commit is contained in:
parent
f37651c708
commit
62e3e2c57c
|
@ -684,6 +684,16 @@ static void turnDimIntoSymbol(FlatAffineConstraints *cst, const Value &id) {
|
|||
}
|
||||
}
|
||||
|
||||
// Turn a symbol into a dimension.
|
||||
static void turnSymbolIntoDim(FlatAffineConstraints *cst, const Value &id) {
|
||||
unsigned pos;
|
||||
if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() &&
|
||||
pos < cst->getNumDimAndSymbolIds()) {
|
||||
swapId(cst, pos, cst->getNumDimIds());
|
||||
cst->setDimSymbolSeparation(cst->getNumSymbolIds() - 1);
|
||||
}
|
||||
}
|
||||
|
||||
bool FlatAffineConstraints::addAffineForOpDomain(
|
||||
ConstOpPointer<AffineForOp> forOp) {
|
||||
unsigned pos;
|
||||
|
@ -721,7 +731,7 @@ bool FlatAffineConstraints::addAffineForOpDomain(
|
|||
}
|
||||
localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), values);
|
||||
|
||||
for (const auto &operand : boundOperands) {
|
||||
for (const auto *operand : boundOperands) {
|
||||
unsigned pos;
|
||||
if (!findId(*operand, &pos)) {
|
||||
if (isValidSymbol(operand)) {
|
||||
|
@ -754,7 +764,7 @@ bool FlatAffineConstraints::addAffineForOpDomain(
|
|||
|
||||
// Record positions of the operands in the constraint system.
|
||||
SmallVector<unsigned, 8> positions;
|
||||
for (const auto &operand : boundOperands) {
|
||||
for (const auto *operand : boundOperands) {
|
||||
unsigned pos;
|
||||
if (!findId(*operand, &pos))
|
||||
assert(0 && "expected to be found");
|
||||
|
@ -1628,42 +1638,77 @@ bool FlatAffineConstraints::addSliceBounds(ArrayRef<Value *> values,
|
|||
assert(values.size() == lbMaps.size());
|
||||
assert(lbMaps.size() == ubMaps.size());
|
||||
|
||||
// Record positions of the operands in the constraint system.
|
||||
SmallVector<unsigned, 8> positions;
|
||||
for (const auto &operand : operands) {
|
||||
unsigned loc;
|
||||
if (!findId(*operand, &loc))
|
||||
assert(0 && "expected to be found");
|
||||
positions.push_back(loc);
|
||||
}
|
||||
// Adds a lower or upper bound when the bounds aren't constant. If eq is true,
|
||||
// add a single equality equal to the first bound map result expr.
|
||||
// TODO(andydavis,bondhugula): refactor and reuse from addAffineForOpDomain.
|
||||
auto addLowerOrUpperBound = [&](unsigned pos, AffineMap boundMap, bool eq,
|
||||
bool lower = true) -> bool {
|
||||
assert(pos < getNumDimAndSymbolIds() && "invalid position");
|
||||
// Equality follows the logic of lower bound except that we add an equality
|
||||
// instead of an inequality.
|
||||
assert(!eq || boundMap.getNumResults() == 1 && "single result expected");
|
||||
if (eq)
|
||||
lower = true;
|
||||
|
||||
unsigned numOperands = operands.size();
|
||||
|
||||
auto addLowerOrUpperBound = [&](unsigned pos, AffineMap boundMap,
|
||||
bool lower) -> bool {
|
||||
assert(pos < getNumIds());
|
||||
FlatAffineConstraints localVarCst;
|
||||
std::vector<SmallVector<int64_t, 8>> flatExprs;
|
||||
if (!getFlattenedAffineExprs(boundMap, &flatExprs, &localVarCst)) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Merge and align with localVarCst.
|
||||
if (localVarCst.getNumLocalIds() > 0) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "loop bounds with mod/floordiv expr's not yet supported\n");
|
||||
return false;
|
||||
// Set values for localVarCst.
|
||||
localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), operands);
|
||||
for (const auto *operand : operands) {
|
||||
unsigned pos;
|
||||
if (findId(*operand, &pos)) {
|
||||
if (pos >= getNumDimIds() && pos < getNumDimAndSymbolIds()) {
|
||||
// If the local var cst has this as a dim, turn it into its symbol.
|
||||
turnDimIntoSymbol(&localVarCst, *operand);
|
||||
} else if (pos < getNumDimIds()) {
|
||||
// Or vice versa.
|
||||
turnSymbolIntoDim(&localVarCst, *operand);
|
||||
}
|
||||
}
|
||||
}
|
||||
mergeAndAlignIds(this, &localVarCst);
|
||||
append(localVarCst);
|
||||
}
|
||||
|
||||
// Record positions of the operands in the constraint system. Need to do
|
||||
// this here since the constraint system changes after a bound is added.
|
||||
SmallVector<unsigned, 8> positions;
|
||||
for (const auto *operand : operands) {
|
||||
unsigned pos;
|
||||
if (!findId(*operand, &pos))
|
||||
assert(0 && "expected to be found");
|
||||
positions.push_back(pos);
|
||||
}
|
||||
|
||||
for (const auto &flatExpr : flatExprs) {
|
||||
SmallVector<int64_t, 4> ineq(getNumCols(), 0);
|
||||
ineq[pos] = lower ? 1 : -1;
|
||||
// Dims and symbols.
|
||||
for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) {
|
||||
ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j];
|
||||
}
|
||||
// Copy over the local id coefficients.
|
||||
unsigned numLocalIds = flatExpr.size() - 1 - numOperands;
|
||||
for (unsigned jj = 0, j = getNumIds() - numLocalIds; jj < numLocalIds;
|
||||
jj++, j++) {
|
||||
ineq[j] =
|
||||
lower ? -flatExpr[numOperands + jj] : flatExpr[numOperands + jj];
|
||||
}
|
||||
// Constant term.
|
||||
ineq[getNumCols() - 1] =
|
||||
lower ? -flatExpr[flatExpr.size() - 1]
|
||||
// Upper bound in flattenedExpr is an exclusive one.
|
||||
: flatExpr[flatExpr.size() - 1] - 1;
|
||||
addInequality(ineq);
|
||||
eq ? addEquality(ineq) : addInequality(ineq);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
@ -1673,16 +1718,27 @@ bool FlatAffineConstraints::addSliceBounds(ArrayRef<Value *> values,
|
|||
if (!findId(*values[i], &pos))
|
||||
continue;
|
||||
|
||||
if (AffineMap lbMap = lbMaps[i]) {
|
||||
assert(lbMaps[i].getNumInputs() == operands.size());
|
||||
if (!addLowerOrUpperBound(pos, lbMap, /*lower=*/true))
|
||||
return false;
|
||||
}
|
||||
if (AffineMap ubMap = ubMaps[i]) {
|
||||
assert(ubMaps[i].getNumInputs() == operands.size());
|
||||
if (!addLowerOrUpperBound(pos, ubMap, /*lower=*/false))
|
||||
AffineMap lbMap = lbMaps[i];
|
||||
AffineMap ubMap = ubMaps[i];
|
||||
assert(!lbMap || lbMap.getNumInputs() == operands.size());
|
||||
assert(!ubMap || ubMap.getNumInputs() == operands.size());
|
||||
|
||||
// Check if this slice is just an equality along this dimension.
|
||||
if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
|
||||
ubMap.getNumResults() == 1 &&
|
||||
lbMap.getResult(0) + 1 == ubMap.getResult(0)) {
|
||||
if (!addLowerOrUpperBound(pos, lbMap, /*eq=*/true, /*lower=*/true))
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (lbMap &&
|
||||
!addLowerOrUpperBound(pos, lbMap, /*eq=*/false, /*lower=*/true))
|
||||
return false;
|
||||
|
||||
if (ubMap &&
|
||||
!addLowerOrUpperBound(pos, ubMap, /*eq=*/false, /*lower=*/false))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -66,6 +66,8 @@ Optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
|
|||
if (shape)
|
||||
shape->reserve(rank);
|
||||
|
||||
assert(rank == cst.getNumDimIds() && "inconsistent memref region");
|
||||
|
||||
// Find a constant upper bound on the extent of this memref region along each
|
||||
// dimension.
|
||||
int64_t numElements = 1;
|
||||
|
@ -221,9 +223,10 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth,
|
|||
}
|
||||
}
|
||||
// Add upper/lower bounds from 'sliceState' to 'cst'.
|
||||
if (!cst.addSliceBounds(sliceState->ivs, sliceState->lbs, sliceState->ubs,
|
||||
sliceState->lbOperands[0]))
|
||||
return false;
|
||||
bool ret = cst.addSliceBounds(sliceState->ivs, sliceState->lbs,
|
||||
sliceState->ubs, sliceState->lbOperands[0]);
|
||||
assert(ret && "should not fail as we never have semi-affine slice maps");
|
||||
(void)ret;
|
||||
}
|
||||
|
||||
// Add access function equalities to connect loop IVs to data dimensions.
|
||||
|
|
Loading…
Reference in New Issue