Extend/improve getSliceBounds() / complete TODO + update unionBoundingBox

- compute slices precisely where the destination iteration depends on multiple source
  iterations (instead of over-approximating to the whole source loop extent)
- update unionBoundingBox to deal with input with non-matching symbols
- reenable disabled backend test case

PiperOrigin-RevId: 234714069
This commit is contained in:
Uday Bondhugula 2019-02-19 18:17:19 -08:00 committed by jpienaar
parent 48ccae2476
commit a1dad3a5d9
6 changed files with 322 additions and 54 deletions

View File

@ -222,6 +222,15 @@ AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context);
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
AffineExpr rhs);
/// Constructs an affine expression from a flat ArrayRef. If there are local
/// identifiers (neither dimensional nor symbolic) that appear in the sum of
/// products expression, 'localExprs' is expected to have the AffineExpr
/// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the
/// format [dims, symbols, locals, constant term].
AffineExpr toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
unsigned numSymbols, ArrayRef<AffineExpr> localExprs,
MLIRContext *context);
raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr);
template <typename U> bool AffineExpr::isa() const {

View File

@ -424,7 +424,8 @@ public:
bool findId(const Value &id, unsigned *pos) const;
// Add identifiers of the specified kind - specified positions are relative to
// the kind of identifier. 'id' is the Value corresponding to the
// the kind of identifier. The coefficient column corresponding to the added
// identifier is initialized to zero. 'id' is the Value corresponding to the
// identifier that can optionally be provided.
void addDimId(unsigned pos, Value *id = nullptr);
void addSymbolId(unsigned pos, Value *id = nullptr);
@ -579,6 +580,17 @@ public:
/// one; None otherwise.
Optional<int64_t> getConstantUpperBound(unsigned pos) const;
/// Gets the lower and upper bound of the pos^th identifier treating
/// [dimStartPos, symbStartPos) as dimensions and [symStartPos,
/// getNumDimAndSymbolIds) as symbols. The returned multi-dimensional maps
/// in the pair represent the max and min of potentially multiple affine
/// expressions. The upper bound is exclusive. 'localExprs' holds pre-computed
/// AffineExpr's for all local identifiers in the system.
std::pair<AffineMap, AffineMap>
getLowerAndUpperBound(unsigned pos, unsigned dimStartPos,
unsigned symStartPos, ArrayRef<AffineExpr> localExprs,
MLIRContext *context);
/// Returns true if the set can be trivially detected as being
/// hyper-rectangular on the specified contiguous set of identifiers.
bool isHyperRectangular(unsigned pos, unsigned num) const;
@ -588,6 +600,9 @@ public:
/// constraint.
void removeTrivialRedundancy();
/// A more expensive check to detect redundant inequalities.
void removeRedundantInequalities();
// Removes all equalities and inequalities.
void clearConstraints();

View File

@ -301,8 +301,7 @@ raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr) {
/// products expression, 'localExprs' is expected to have the AffineExpr
/// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the
/// format [dims, symbols, locals, constant term].
// TODO(bondhugula): refactor getAddMulPureAffineExpr to reuse it from here.
static AffineExpr toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
AffineExpr mlir::toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
unsigned numSymbols,
ArrayRef<AffineExpr> localExprs,
MLIRContext *context) {

View File

@ -809,9 +809,6 @@ unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
if (posStart >= posLimit)
return 0;
LLVM_DEBUG(llvm::dbgs() << "Eliminating by Gaussian [" << posStart << ", "
<< posLimit << ")\n");
GCDTightenInequalities();
unsigned pivotCol = 0;
@ -909,25 +906,36 @@ static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
return false;
}
// Gather lower and upper bounds for the pos^th identifier.
static void getLowerAndUpperBoundIndices(const FlatAffineConstraints &cst,
unsigned pos,
SmallVectorImpl<unsigned> *lbIndices,
SmallVectorImpl<unsigned> *ubIndices) {
assert(pos < cst.getNumIds() && "invalid position");
// Gather all lower bounds and upper bounds of the variable. Since the
// canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
// bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
if (cst.atIneq(r, pos) >= 1) {
// Lower bound.
lbIndices->push_back(r);
} else if (cst.atIneq(r, pos) <= -1) {
// Upper bound.
ubIndices->push_back(r);
}
}
}
// Check if the pos^th identifier can be expressed as a floordiv of an affine
// function of other identifiers (where the divisor is a positive constant).
// For eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4.
bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
SmallVectorImpl<AffineExpr> *memo, MLIRContext *context) {
assert(pos < cst.getNumIds() && "invalid position");
SmallVector<unsigned, 4> lbIndices, ubIndices;
// Gather all lower bounds and upper bound constraints of this identifier.
// Since the canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint
// is a lower bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
if (cst.atIneq(r, pos) >= 1)
// Lower bound.
lbIndices.push_back(r);
else if (cst.atIneq(r, pos) <= -1)
// Upper bound.
ubIndices.push_back(r);
}
SmallVector<unsigned, 4> lbIndices, ubIndices;
getLowerAndUpperBoundIndices(cst, pos, &lbIndices, &ubIndices);
// Check if any lower bound, upper bound pair is of the form:
// divisor * id >= expr - (divisor - 1) <-- Lower bound for 'id'
@ -993,6 +1001,107 @@ bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
return false;
}
// Fills an inequality row with the value 'val'.
static inline void fillInequality(FlatAffineConstraints *cst, unsigned r,
int64_t val) {
for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
cst->atIneq(r, c) = val;
}
}
// Negates an inequality.
static inline void negateInequality(FlatAffineConstraints *cst, unsigned r) {
for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
cst->atIneq(r, c) = -cst->atIneq(r, c);
}
}
// A more complex check to eliminate redundant inequalities.
void FlatAffineConstraints::removeRedundantInequalities() {
SmallVector<bool, 32> redun(getNumInequalities(), false);
// To check if an inequality is redundant, we replace the inequality by its
// complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting
// system is empty. If it is, the inequality is redundant.
FlatAffineConstraints tmpCst(*this);
for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
// Change the inequality to its complement.
negateInequality(&tmpCst, r);
tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--;
if (tmpCst.isEmpty()) {
redun[r] = true;
// Zero fill the redundant inequality.
fillInequality(this, r, /*val=*/0);
fillInequality(&tmpCst, r, /*val=*/0);
} else {
// Reverse the change (to avoid recreating tmpCst each time).
tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++;
negateInequality(&tmpCst, r);
}
}
// Scan to get rid of all rows marked redundant, in-place.
auto copyRow = [&](unsigned src, unsigned dest) {
if (src == dest)
return;
for (unsigned c = 0, e = getNumCols(); c < e; c++) {
atIneq(dest, c) = atIneq(src, c);
}
};
unsigned pos = 0;
for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
if (!redun[r])
copyRow(r, pos++);
}
inequalities.resize(numReservedCols * pos);
}
std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
unsigned pos, unsigned dimStartPos, unsigned symStartPos,
ArrayRef<AffineExpr> localExprs, MLIRContext *context) {
assert(pos < dimStartPos && "invalid dim start pos");
assert(symStartPos >= dimStartPos && "invalid sym start pos");
assert(getNumLocalIds() == localExprs.size() &&
"incorrect local exprs count");
SmallVector<unsigned, 4> lbIndices, ubIndices;
getLowerAndUpperBoundIndices(*this, pos, &lbIndices, &ubIndices);
SmallVector<int64_t, 8> lb, ub;
SmallVector<AffineExpr, 4> exprs;
unsigned dimCount = symStartPos - dimStartPos;
unsigned symCount = getNumDimAndSymbolIds() - symStartPos;
exprs.reserve(lbIndices.size());
// Lower bound expressions.
for (auto idx : lbIndices) {
auto ineq = getInequality(idx);
// Extract the lower bound (in terms of other coeff's + const), i.e., if
// i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
// - 1.
lb.assign(ineq.begin() + dimStartPos, ineq.end());
std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
auto expr = mlir::toAffineExpr(lb, dimCount, symCount, localExprs, context);
exprs.push_back(expr);
}
auto lbMap = exprs.empty() ? AffineMap()
: AffineMap::get(dimCount, symCount, exprs, {});
exprs.clear();
exprs.reserve(ubIndices.size());
// Upper bound expressions.
for (auto idx : ubIndices) {
auto ineq = getInequality(idx);
// Extract the upper bound (in terms of other coeff's + const).
ub.assign(ineq.begin() + dimStartPos, ineq.end());
auto expr = mlir::toAffineExpr(ub, dimCount, symCount, localExprs, context);
// Upper bound is exclusive.
exprs.push_back(expr + 1);
}
auto ubMap = exprs.empty() ? AffineMap()
: AffineMap::get(dimCount, symCount, exprs, {});
return {lbMap, ubMap};
}
/// Computes the lower and upper bounds of the first 'num' dimensional
/// identifiers as affine maps of the remaining identifiers (dimensional and
/// symbolic identifiers). Local identifiers are themselves explicitly computed
@ -1097,6 +1206,7 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context,
// Set the lower and upper bound maps for all the identifiers that were
// computed as affine expressions of the rest as the "detected expr" and
// "detected expr + 1" respectively; set the undetected ones to Null().
Optional<FlatAffineConstraints> tmpClone;
for (unsigned pos = 0; pos < num; pos++) {
unsigned numMapDims = getNumDimIds() - num;
unsigned numMapSymbols = getNumSymbolIds();
@ -1108,24 +1218,49 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context,
(*lbMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr, {});
(*ubMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {});
} else {
// TODO(andydavis, bondhugula) Add support for computing slice bounds
// symbolic in the identifies [num, numIds).
// TODO(bondhugula): Whenever there have local identifiers in the
// dependence constraints, we'll conservatively over-approximate, since we
// don't always explicitly compute them above (in the while loop).
if (getNumLocalIds() == 0) {
// Work on a copy so that we don't update this constraint system.
if (!tmpClone) {
tmpClone.emplace(FlatAffineConstraints(*this));
// Removing redudnant inequalities is necessary so that we don't get
// redundant loop bounds.
tmpClone->removeRedundantInequalities();
}
std::tie((*lbMaps)[pos], (*ubMaps)[pos]) =
tmpClone->getLowerAndUpperBound(pos, num, getNumDimIds(), {},
context);
}
// If the above fails, we'll just use the constant lower bound and the
// constant upper bound (if they exist) as the slice bounds.
if (!(*lbMaps)[pos]) {
LLVM_DEBUG(llvm::dbgs()
<< "WARNING: Potentially over-approximating slice lb\n");
auto lbConst = getConstantLowerBound(pos);
auto ubConst = getConstantUpperBound(pos);
if (lbConst.hasValue() && ubConst.hasValue()) {
if (lbConst.hasValue()) {
(*lbMaps)[pos] = AffineMap::get(
numMapDims, numMapSymbols,
getAffineConstantExpr(lbConst.getValue(), context), {});
}
}
if (!(*ubMaps)[pos]) {
LLVM_DEBUG(llvm::dbgs()
<< "WARNING: Potentially over-approximating slice ub\n");
auto ubConst = getConstantUpperBound(pos);
if (ubConst.hasValue()) {
(*ubMaps)[pos] = AffineMap::get(
numMapDims, numMapSymbols,
getAffineConstantExpr(ubConst.getValue() + 1, context), {});
} else {
(*lbMaps)[pos] = AffineMap();
(*ubMaps)[pos] = AffineMap();
}
}
}
LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos) << ", expr: ");
LLVM_DEBUG(expr.dump(););
LLVM_DEBUG((*lbMaps)[pos].dump(););
LLVM_DEBUG(llvm::dbgs() << "ub map for pos = " << Twine(pos) << ", expr: ");
LLVM_DEBUG((*ubMaps)[pos].dump(););
}
}
@ -1454,6 +1589,7 @@ Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
break;
}
if (c < getNumDimIds())
// Not a pure symbolic bound.
continue;
if (atIneq(r, pos) >= 1)
// Lower bound.
@ -2037,14 +2173,53 @@ static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
}
}; // namespace
// TODO(bondhugula,andydavis): This still doesn't do a comprehensive merge of
// the symbols. Assumes the common symbols appear in the same order (the
// current/common use case).
static void mergeSymbols(FlatAffineConstraints *A, FlatAffineConstraints *B) {
SmallVector<Value *, 4> symbolsA, symbolsB;
A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &symbolsA);
B->getIdValues(B->getNumDimIds(), B->getNumDimAndSymbolIds(), &symbolsB);
// Both symbol list have a handful symbols each typically (3-4); a merge
// quadratic in complexity with a linear search is fine.
for (auto *symbolB : symbolsB) {
if (llvm::is_contained(symbolsA, symbolB)) {
A->addSymbolId(symbolsA.size(), symbolB);
symbolsA.push_back(symbolB);
}
}
// symbolsA now holds the merged symbol list.
symbolsB.reserve(symbolsA.size());
unsigned iB = 0;
for (auto *symbolA : symbolsA) {
assert(iB < symbolsB.size());
if (symbolA != symbolsB[iB]) {
symbolsB.insert(symbolsB.begin() + iB, symbolA);
B->addSymbolId(iB, symbolA);
}
++iB;
}
}
// Compute the bounding box with respect to 'other' by finding the min of the
// lower bounds and the max of the upper bounds along each of the dimensions.
bool FlatAffineConstraints::unionBoundingBox(
const FlatAffineConstraints &other) {
assert(other.getNumDimIds() == numDims);
assert(other.getNumSymbolIds() == getNumSymbolIds());
assert(other.getNumLocalIds() == 0);
assert(getNumLocalIds() == 0);
const FlatAffineConstraints &otherArg) {
assert(otherArg.getNumDimIds() == numDims && "dims mismatch");
Optional<FlatAffineConstraints> copy;
if (!otherArg.getIds().equals(getIds())) {
copy.emplace(FlatAffineConstraints(otherArg));
mergeSymbols(this, &copy.getValue());
assert(getIds().equals(copy->getIds()) && "merge failed");
}
const auto &other = copy ? *copy : otherArg;
assert(other.getNumLocalIds() == 0 && "local ids not eliminated");
assert(getNumLocalIds() == 0 && "local ids not eliminated");
std::vector<SmallVector<int64_t, 8>> boundingLbs;
std::vector<SmallVector<int64_t, 8>> boundingUbs;
boundingLbs.reserve(2 * getNumDimIds());
@ -2082,7 +2257,11 @@ bool FlatAffineConstraints::unionBoundingBox(
minLb = otherLb;
} else {
// Uncomparable.
auto constLb = getConstantLowerBound(d);
auto constOtherLb = other.getConstantLowerBound(d);
if (!constLb.hasValue() || !constOtherLb.hasValue())
return false;
minLb = std::min(constLb.getValue(), constOtherLb.getValue());
}
// Do the same for ub's but max of upper bounds.
@ -2098,7 +2277,11 @@ bool FlatAffineConstraints::unionBoundingBox(
maxUb = otherUb;
} else {
// Uncomparable.
auto constUb = getConstantUpperBound(d);
auto constOtherUb = other.getConstantUpperBound(d);
if (!constUb.hasValue() || !constOtherUb.hasValue())
return false;
maxUb = std::max(constUb.getValue(), constOtherUb.getValue());
}
SmallVector<int64_t, 8> newLb(getNumCols(), 0);

View File

@ -1260,12 +1260,9 @@ static bool isFusionProfitable(Instruction *srcOpInst,
unsigned *dstLoopDepth) {
LLVM_DEBUG({
llvm::dbgs() << "Checking whether fusion is profitable between:\n";
llvm::dbgs() << " ";
srcOpInst->dump();
llvm::dbgs() << " and \n";
llvm::dbgs() << " " << *srcOpInst << " and \n";
for (auto dstOpInst : dstLoadOpInsts) {
llvm::dbgs() << " ";
dstOpInst->dump();
llvm::dbgs() << " " << *dstOpInst << "\n";
};
});
@ -1423,7 +1420,10 @@ static bool isFusionProfitable(Instruction *srcOpInst,
<< 100.0 * additionalComputeFraction << "%\n"
<< " storage reduction factor: " << storageReduction << "x\n"
<< " fused nest cost: " << fusedLoopNestComputeCost << "\n"
<< " slice iteration count: " << sliceIterationCount << "\n";
<< " slice iteration count: " << sliceIterationCount << "\n"
<< " src write region size: " << srcWriteRegionSizeBytes << "\n"
<< " slice write region size: " << sliceWriteRegionSizeBytes
<< "\n";
llvm::dbgs() << msg.str();
});
@ -1450,7 +1450,8 @@ static bool isFusionProfitable(Instruction *srcOpInst,
// -maximal-fusion is set, fuse nevertheless.
if (!clMaximalLoopFusion && !bestDstLoopDepth.hasValue()) {
LLVM_DEBUG(llvm::dbgs()
LLVM_DEBUG(
llvm::dbgs()
<< "All fusion choices involve more than the threshold amount of "
"redundant computation; NOT fusing.\n");
return false;
@ -1694,6 +1695,9 @@ public:
auto sliceLoopNest = mlir::insertBackwardComputationSlice(
srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
if (sliceLoopNest != nullptr) {
LLVM_DEBUG(llvm::dbgs()
<< "\tslice loop nest:\n"
<< *sliceLoopNest->getInstruction() << "\n");
// Move 'dstAffineForOp' before 'insertPointInst' if needed.
auto dstAffineForOp = dstNode->inst->cast<AffineForOp>();
if (insertPointInst != dstAffineForOp->getInstruction()) {

View File

@ -1810,3 +1810,61 @@ func @should_fuse_live_out_writer(%arg0 : memref<10xf32>) -> memref<10xf32> {
// CHECK-NEXT: }
// CHECK-NEXT: return %arg0 : memref<10xf32>
}
// -----
// The fused slice has 16 iterations from along %i0.
// CHECK-DAG: [[MAP_LB:#map[0-9]+]] = (d0) -> (d0 * 16)
// CHECK-DAG: [[MAP_UB:#map[0-9]+]] = (d0) -> (d0 * 16 + 16)
#map = (d0, d1) -> (d0 * 16 + d1)
// CHECK-LABEL: slice_tile
func @slice_tile(%arg1: memref<32x8xf32>, %arg2: memref<32x8xf32>, %0 : f32) -> memref<32x8xf32> {
for %i0 = 0 to 32 {
for %i1 = 0 to 8 {
store %0, %arg2[%i0, %i1] : memref<32x8xf32>
}
}
for %i = 0 to 2 {
for %j = 0 to 8 {
for %k = 0 to 8 {
for %kk = 0 to 16 {
%1 = affine.apply #map(%k, %kk)
%2 = load %arg1[%1, %j] : memref<32x8xf32>
%3 = "foo"(%2) : (f32) -> f32
}
for %ii = 0 to 16 {
%6 = affine.apply #map(%i, %ii)
%7 = load %arg2[%6, %j] : memref<32x8xf32>
%8 = addf %7, %7 : f32
store %8, %arg2[%6, %j] : memref<32x8xf32>
}
}
}
}
return %arg2 : memref<32x8xf32>
}
// CHECK: for %i0 = 0 to 2 {
// CHECK-NEXT: for %i1 = 0 to 8 {
// CHECK-NEXT: for %i2 = [[MAP_LB]](%i0) to [[MAP_UB]](%i0) {
// CHECK-NEXT: store %arg2, %arg1[%i2, %i1] : memref<32x8xf32>
// CHECK-NEXT: }
// CHECK-NEXT: for %i3 = 0 to 8 {
// CHECK-NEXT: for %i4 = 0 to 16 {
// CHECK-NEXT: %0 = affine.apply #map{{[0-9]+}}(%i3, %i4)
// CHECK-NEXT: %1 = load %arg0[%0, %i1] : memref<32x8xf32>
// CHECK-NEXT: %2 = "foo"(%1) : (f32) -> f32
// CHECK-NEXT: }
// CHECK-NEXT: for %i5 = 0 to 16 {
// CHECK-NEXT: %3 = affine.apply #map{{[0-9]+}}(%i0, %i5)
// CHECK-NEXT: %4 = load %arg1[%3, %i1] : memref<32x8xf32>
// CHECK-NEXT: %5 = addf %4, %4 : f32
// CHECK-NEXT: store %5, %arg1[%3, %i1] : memref<32x8xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return %arg1 : memref<32x8xf32>
// CHECK-NEXT:}