[mlir][Analysis][NFC] Reimplement FlatAffineConstraints::composeMap

Reimplement this function in terms of `composeMatchingMap`.

Also fix a bug in `composeMatchingMap` where local dims of `this` could be missing in `localCst`.

Differential Revision: https://reviews.llvm.org/D107813
This commit is contained in:
Matthias Springer 2021-08-11 15:48:21 +09:00
parent 61526b1262
commit 9e6e08149c
2 changed files with 43 additions and 98 deletions

View File

@ -350,16 +350,15 @@ public:
/// and with the dimensions set to the equalities specified by the value map.
/// Returns failure if the composition fails (when vMap is a semi-affine map).
/// The vMap's operand Value's are used to look up the right positions in
/// the FlatAffineConstraints with which to associate. The dimensional and
/// symbolic operands of vMap should match 1:1 (in the same order) with those
/// of this constraint system, but the latter could have additional trailing
/// operands.
/// the FlatAffineConstraints with which to associate. Every operand of vMap
/// should have a matching dim/symbol column in this constraint system (with
/// the same associated Value).
LogicalResult composeMap(const AffineValueMap *vMap);
/// Composes an affine map whose dimensions match one to one to the
/// dimensions of this FlatAffineConstraints. The results of the map 'other'
/// are added as the leading dimensions of this constraint system. Returns
/// failure if 'other' is a semi-affine map.
/// Composes an affine map whose dimensions and symbols match one to one with
/// the dimensions and symbols of this FlatAffineConstraints. The results of
/// the map `other` are added as the leading dimensions of this constraint
/// system. Returns failure if `other` is a semi-affine map.
LogicalResult composeMatchingMap(AffineMap other);
/// Projects out (aka eliminates) 'num' identifiers starting at position
@ -599,6 +598,10 @@ private:
template <bool isLower>
Optional<int64_t> computeConstantLowerOrUpperBound(unsigned pos);
/// Align `map` with this constraint system based on `operands`. Each operand
/// must already have a corresponding dim/symbol in this constraint system.
AffineMap computeAlignedMap(AffineMap map, ValueRange operands) const;
// Eliminates a single identifier at 'position' from equality and inequality
// constraints. Returns 'success' if the identifier was eliminated, and
// 'failure' otherwise.

View File

@ -387,81 +387,14 @@ void FlatAffineConstraints::mergeAndAlignIdsWithOther(
mergeAndAlignIds(offset, this, other);
}
// This routine may add additional local variables if the flattened expression
// corresponding to the map has such variables due to mod's, ceildiv's, and
// floordiv's in it.
LogicalResult FlatAffineConstraints::composeMap(const AffineValueMap *vMap) {
std::vector<SmallVector<int64_t, 8>> flatExprs;
FlatAffineConstraints localCst;
if (failed(getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs,
&localCst))) {
LLVM_DEBUG(llvm::dbgs()
<< "composition unimplemented for semi-affine maps\n");
return failure();
}
assert(flatExprs.size() == vMap->getNumResults());
// Add localCst information.
if (localCst.getNumLocalIds() > 0) {
localCst.setIdValues(0, /*end=*/localCst.getNumDimAndSymbolIds(),
/*values=*/vMap->getOperands());
// Align localCst and this.
mergeAndAlignIds(/*offset=*/0, &localCst, this);
// Finally, append localCst to this constraint set.
append(localCst);
}
// Add dimensions corresponding to the map's results.
for (unsigned t = 0, e = vMap->getNumResults(); t < e; t++) {
// TODO: Consider using a batched version to add a range of IDs.
addDimId(0);
}
// We add one equality for each result connecting the result dim of the map to
// the other identifiers.
// For eg: if the expression is 16*i0 + i1, and this is the r^th
// iteration/result of the value map, we are adding the equality:
// d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we
// add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
const auto &flatExpr = flatExprs[r];
assert(flatExpr.size() >= vMap->getNumOperands() + 1);
// eqToAdd is the equality corresponding to the flattened affine expression.
SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
// Set the coefficient for this result to one.
eqToAdd[r] = 1;
// Dims and symbols.
for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) {
unsigned loc;
bool ret = findId(vMap->getOperand(i), &loc);
assert(ret && "value map's id can't be found");
(void)ret;
// Negate 'eq[r]' since the newly added dimension will be set to this one.
eqToAdd[loc] = -flatExpr[i];
}
// Local vars common to eq and localCst are at the beginning.
unsigned j = getNumDimIds() + getNumSymbolIds();
unsigned end = flatExpr.size() - 1;
for (unsigned i = vMap->getNumOperands(); i < end; i++, j++) {
eqToAdd[j] = -flatExpr[i];
}
// Constant term.
eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
// Add the equality connecting the result of the map to this constraint set.
addEquality(eqToAdd);
}
return success();
return composeMatchingMap(
computeAlignedMap(vMap->getAffineMap(), vMap->getOperands()));
}
// Similar to composeMap except that no Value's need be associated with the
// constraint system nor are they looked at -- since the dimensions and
// symbols of 'other' are expected to correspond 1:1 to 'this' system. It
// is thus not convenient to share code with composeMap.
// Similar to `composeMap` except that no Values need be associated with the
// constraint system nor are they looked at -- the dimensions and symbols of
// `other` are expected to correspond 1:1 to `this` system.
LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
assert(other.getNumDims() == getNumDimIds() && "dim mismatch");
assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
@ -477,11 +410,15 @@ LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
// Add localCst information.
if (localCst.getNumLocalIds() > 0) {
// Place local id's of A after local id's of B.
for (unsigned l = 0, e = localCst.getNumLocalIds(); l < e; l++) {
unsigned numLocalIds = getNumLocalIds();
// Insert local dims of localCst at the beginning.
for (unsigned l = 0, e = localCst.getNumLocalIds(); l < e; ++l)
addLocalId(0);
}
// Finally, append localCst to this constraint set.
// Insert local dims of `this` at the end of localCst.
for (unsigned l = 0; l < numLocalIds; ++l)
localCst.addLocalId(localCst.getNumLocalIds());
// Dimensions of localCst and this constraint set match. Append localCst to
// this constraint set.
append(localCst);
}
@ -2001,19 +1938,9 @@ LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
return success();
}
LogicalResult
FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
ValueRange boundOperands, bool eq,
bool lower) {
// Fully compose map and operands; canonicalize and simplify so that we
// transitively get to terminal symbols or loop IVs.
auto map = boundMap;
SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end());
fullyComposeAffineMapAndOperands(&map, &operands);
map = simplifyAffineMap(map);
canonicalizeMapAndOperands(&map, &operands);
for (auto operand : operands)
addInductionVarOrTerminalSymbol(operand);
AffineMap FlatAffineConstraints::computeAlignedMap(AffineMap map,
ValueRange operands) const {
assert(map.getNumInputs() == operands.size() && "number of inputs mismatch");
SmallVector<Value> dims, syms;
#ifndef NDEBUG
@ -2036,8 +1963,23 @@ FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols");
assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) &&
"unexpected new/missing symbols");
return alignedMap;
}
return addLowerOrUpperBound(pos, alignedMap, eq, lower);
LogicalResult
FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
ValueRange boundOperands, bool eq,
bool lower) {
// Fully compose map and operands; canonicalize and simplify so that we
// transitively get to terminal symbols or loop IVs.
auto map = boundMap;
SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end());
fullyComposeAffineMapAndOperands(&map, &operands);
map = simplifyAffineMap(map);
canonicalizeMapAndOperands(&map, &operands);
for (auto operand : operands)
addInductionVarOrTerminalSymbol(operand);
return addLowerOrUpperBound(pos, computeAlignedMap(map, operands), eq, lower);
}
// Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper