forked from OSchip/llvm-project
Extend/improve getSliceBounds() / complete TODO + update unionBoundingBox
- compute slices precisely where the destination iteration depends on multiple source iterations (instead of over-approximating to the whole source loop extent) - update unionBoundingBox to deal with input with non-matching symbols - reenable disabled backend test case PiperOrigin-RevId: 234714069
This commit is contained in:
parent
48ccae2476
commit
a1dad3a5d9
|
@ -222,6 +222,15 @@ AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context);
|
|||
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
|
||||
AffineExpr rhs);
|
||||
|
||||
/// Constructs an affine expression from a flat ArrayRef. If there are local
|
||||
/// identifiers (neither dimensional nor symbolic) that appear in the sum of
|
||||
/// products expression, 'localExprs' is expected to have the AffineExpr
|
||||
/// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the
|
||||
/// format [dims, symbols, locals, constant term].
|
||||
AffineExpr toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
|
||||
unsigned numSymbols, ArrayRef<AffineExpr> localExprs,
|
||||
MLIRContext *context);
|
||||
|
||||
raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr);
|
||||
|
||||
template <typename U> bool AffineExpr::isa() const {
|
||||
|
|
|
@ -424,7 +424,8 @@ public:
|
|||
bool findId(const Value &id, unsigned *pos) const;
|
||||
|
||||
// Add identifiers of the specified kind - specified positions are relative to
|
||||
// the kind of identifier. 'id' is the Value corresponding to the
|
||||
// the kind of identifier. The coefficient column corresponding to the added
|
||||
// identifier is initialized to zero. 'id' is the Value corresponding to the
|
||||
// identifier that can optionally be provided.
|
||||
void addDimId(unsigned pos, Value *id = nullptr);
|
||||
void addSymbolId(unsigned pos, Value *id = nullptr);
|
||||
|
@ -579,6 +580,17 @@ public:
|
|||
/// one; None otherwise.
|
||||
Optional<int64_t> getConstantUpperBound(unsigned pos) const;
|
||||
|
||||
/// Gets the lower and upper bound of the pos^th identifier treating
|
||||
/// [dimStartPos, symbStartPos) as dimensions and [symStartPos,
|
||||
/// getNumDimAndSymbolIds) as symbols. The returned multi-dimensional maps
|
||||
/// in the pair represent the max and min of potentially multiple affine
|
||||
/// expressions. The upper bound is exclusive. 'localExprs' holds pre-computed
|
||||
/// AffineExpr's for all local identifiers in the system.
|
||||
std::pair<AffineMap, AffineMap>
|
||||
getLowerAndUpperBound(unsigned pos, unsigned dimStartPos,
|
||||
unsigned symStartPos, ArrayRef<AffineExpr> localExprs,
|
||||
MLIRContext *context);
|
||||
|
||||
/// Returns true if the set can be trivially detected as being
|
||||
/// hyper-rectangular on the specified contiguous set of identifiers.
|
||||
bool isHyperRectangular(unsigned pos, unsigned num) const;
|
||||
|
@ -588,6 +600,9 @@ public:
|
|||
/// constraint.
|
||||
void removeTrivialRedundancy();
|
||||
|
||||
/// A more expensive check to detect redundant inequalities.
|
||||
void removeRedundantInequalities();
|
||||
|
||||
// Removes all equalities and inequalities.
|
||||
void clearConstraints();
|
||||
|
||||
|
|
|
@ -301,8 +301,7 @@ raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr) {
|
|||
/// products expression, 'localExprs' is expected to have the AffineExpr
|
||||
/// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the
|
||||
/// format [dims, symbols, locals, constant term].
|
||||
// TODO(bondhugula): refactor getAddMulPureAffineExpr to reuse it from here.
|
||||
static AffineExpr toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
|
||||
AffineExpr mlir::toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
|
||||
unsigned numSymbols,
|
||||
ArrayRef<AffineExpr> localExprs,
|
||||
MLIRContext *context) {
|
||||
|
|
|
@ -809,9 +809,6 @@ unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
|
|||
if (posStart >= posLimit)
|
||||
return 0;
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << "Eliminating by Gaussian [" << posStart << ", "
|
||||
<< posLimit << ")\n");
|
||||
|
||||
GCDTightenInequalities();
|
||||
|
||||
unsigned pivotCol = 0;
|
||||
|
@ -909,25 +906,36 @@ static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
|
|||
return false;
|
||||
}
|
||||
|
||||
// Gather lower and upper bounds for the pos^th identifier.
|
||||
static void getLowerAndUpperBoundIndices(const FlatAffineConstraints &cst,
|
||||
unsigned pos,
|
||||
SmallVectorImpl<unsigned> *lbIndices,
|
||||
SmallVectorImpl<unsigned> *ubIndices) {
|
||||
assert(pos < cst.getNumIds() && "invalid position");
|
||||
|
||||
// Gather all lower bounds and upper bounds of the variable. Since the
|
||||
// canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
|
||||
// bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
|
||||
for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
|
||||
if (cst.atIneq(r, pos) >= 1) {
|
||||
// Lower bound.
|
||||
lbIndices->push_back(r);
|
||||
} else if (cst.atIneq(r, pos) <= -1) {
|
||||
// Upper bound.
|
||||
ubIndices->push_back(r);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the pos^th identifier can be expressed as a floordiv of an affine
|
||||
// function of other identifiers (where the divisor is a positive constant).
|
||||
// For eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4.
|
||||
bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
|
||||
SmallVectorImpl<AffineExpr> *memo, MLIRContext *context) {
|
||||
assert(pos < cst.getNumIds() && "invalid position");
|
||||
SmallVector<unsigned, 4> lbIndices, ubIndices;
|
||||
|
||||
// Gather all lower bounds and upper bound constraints of this identifier.
|
||||
// Since the canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint
|
||||
// is a lower bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
|
||||
for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
|
||||
if (cst.atIneq(r, pos) >= 1)
|
||||
// Lower bound.
|
||||
lbIndices.push_back(r);
|
||||
else if (cst.atIneq(r, pos) <= -1)
|
||||
// Upper bound.
|
||||
ubIndices.push_back(r);
|
||||
}
|
||||
SmallVector<unsigned, 4> lbIndices, ubIndices;
|
||||
getLowerAndUpperBoundIndices(cst, pos, &lbIndices, &ubIndices);
|
||||
|
||||
// Check if any lower bound, upper bound pair is of the form:
|
||||
// divisor * id >= expr - (divisor - 1) <-- Lower bound for 'id'
|
||||
|
@ -993,6 +1001,107 @@ bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
|
|||
return false;
|
||||
}
|
||||
|
||||
// Fills an inequality row with the value 'val'.
|
||||
static inline void fillInequality(FlatAffineConstraints *cst, unsigned r,
|
||||
int64_t val) {
|
||||
for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
|
||||
cst->atIneq(r, c) = val;
|
||||
}
|
||||
}
|
||||
|
||||
// Negates an inequality.
|
||||
static inline void negateInequality(FlatAffineConstraints *cst, unsigned r) {
|
||||
for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
|
||||
cst->atIneq(r, c) = -cst->atIneq(r, c);
|
||||
}
|
||||
}
|
||||
|
||||
// A more complex check to eliminate redundant inequalities.
|
||||
void FlatAffineConstraints::removeRedundantInequalities() {
|
||||
SmallVector<bool, 32> redun(getNumInequalities(), false);
|
||||
// To check if an inequality is redundant, we replace the inequality by its
|
||||
// complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting
|
||||
// system is empty. If it is, the inequality is redundant.
|
||||
FlatAffineConstraints tmpCst(*this);
|
||||
for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
|
||||
// Change the inequality to its complement.
|
||||
negateInequality(&tmpCst, r);
|
||||
tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--;
|
||||
if (tmpCst.isEmpty()) {
|
||||
redun[r] = true;
|
||||
// Zero fill the redundant inequality.
|
||||
fillInequality(this, r, /*val=*/0);
|
||||
fillInequality(&tmpCst, r, /*val=*/0);
|
||||
} else {
|
||||
// Reverse the change (to avoid recreating tmpCst each time).
|
||||
tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++;
|
||||
negateInequality(&tmpCst, r);
|
||||
}
|
||||
}
|
||||
|
||||
// Scan to get rid of all rows marked redundant, in-place.
|
||||
auto copyRow = [&](unsigned src, unsigned dest) {
|
||||
if (src == dest)
|
||||
return;
|
||||
for (unsigned c = 0, e = getNumCols(); c < e; c++) {
|
||||
atIneq(dest, c) = atIneq(src, c);
|
||||
}
|
||||
};
|
||||
unsigned pos = 0;
|
||||
for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
|
||||
if (!redun[r])
|
||||
copyRow(r, pos++);
|
||||
}
|
||||
inequalities.resize(numReservedCols * pos);
|
||||
}
|
||||
|
||||
std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
|
||||
unsigned pos, unsigned dimStartPos, unsigned symStartPos,
|
||||
ArrayRef<AffineExpr> localExprs, MLIRContext *context) {
|
||||
assert(pos < dimStartPos && "invalid dim start pos");
|
||||
assert(symStartPos >= dimStartPos && "invalid sym start pos");
|
||||
assert(getNumLocalIds() == localExprs.size() &&
|
||||
"incorrect local exprs count");
|
||||
|
||||
SmallVector<unsigned, 4> lbIndices, ubIndices;
|
||||
getLowerAndUpperBoundIndices(*this, pos, &lbIndices, &ubIndices);
|
||||
|
||||
SmallVector<int64_t, 8> lb, ub;
|
||||
SmallVector<AffineExpr, 4> exprs;
|
||||
unsigned dimCount = symStartPos - dimStartPos;
|
||||
unsigned symCount = getNumDimAndSymbolIds() - symStartPos;
|
||||
exprs.reserve(lbIndices.size());
|
||||
// Lower bound expressions.
|
||||
for (auto idx : lbIndices) {
|
||||
auto ineq = getInequality(idx);
|
||||
// Extract the lower bound (in terms of other coeff's + const), i.e., if
|
||||
// i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
|
||||
// - 1.
|
||||
lb.assign(ineq.begin() + dimStartPos, ineq.end());
|
||||
std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
|
||||
auto expr = mlir::toAffineExpr(lb, dimCount, symCount, localExprs, context);
|
||||
exprs.push_back(expr);
|
||||
}
|
||||
auto lbMap = exprs.empty() ? AffineMap()
|
||||
: AffineMap::get(dimCount, symCount, exprs, {});
|
||||
|
||||
exprs.clear();
|
||||
exprs.reserve(ubIndices.size());
|
||||
// Upper bound expressions.
|
||||
for (auto idx : ubIndices) {
|
||||
auto ineq = getInequality(idx);
|
||||
// Extract the upper bound (in terms of other coeff's + const).
|
||||
ub.assign(ineq.begin() + dimStartPos, ineq.end());
|
||||
auto expr = mlir::toAffineExpr(ub, dimCount, symCount, localExprs, context);
|
||||
// Upper bound is exclusive.
|
||||
exprs.push_back(expr + 1);
|
||||
}
|
||||
auto ubMap = exprs.empty() ? AffineMap()
|
||||
: AffineMap::get(dimCount, symCount, exprs, {});
|
||||
|
||||
return {lbMap, ubMap};
|
||||
}
|
||||
|
||||
/// Computes the lower and upper bounds of the first 'num' dimensional
|
||||
/// identifiers as affine maps of the remaining identifiers (dimensional and
|
||||
/// symbolic identifiers). Local identifiers are themselves explicitly computed
|
||||
|
@ -1097,6 +1206,7 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context,
|
|||
// Set the lower and upper bound maps for all the identifiers that were
|
||||
// computed as affine expressions of the rest as the "detected expr" and
|
||||
// "detected expr + 1" respectively; set the undetected ones to Null().
|
||||
Optional<FlatAffineConstraints> tmpClone;
|
||||
for (unsigned pos = 0; pos < num; pos++) {
|
||||
unsigned numMapDims = getNumDimIds() - num;
|
||||
unsigned numMapSymbols = getNumSymbolIds();
|
||||
|
@ -1108,24 +1218,49 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context,
|
|||
(*lbMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr, {});
|
||||
(*ubMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {});
|
||||
} else {
|
||||
// TODO(andydavis, bondhugula) Add support for computing slice bounds
|
||||
// symbolic in the identifies [num, numIds).
|
||||
// TODO(bondhugula): Whenever there have local identifiers in the
|
||||
// dependence constraints, we'll conservatively over-approximate, since we
|
||||
// don't always explicitly compute them above (in the while loop).
|
||||
if (getNumLocalIds() == 0) {
|
||||
// Work on a copy so that we don't update this constraint system.
|
||||
if (!tmpClone) {
|
||||
tmpClone.emplace(FlatAffineConstraints(*this));
|
||||
// Removing redudnant inequalities is necessary so that we don't get
|
||||
// redundant loop bounds.
|
||||
tmpClone->removeRedundantInequalities();
|
||||
}
|
||||
std::tie((*lbMaps)[pos], (*ubMaps)[pos]) =
|
||||
tmpClone->getLowerAndUpperBound(pos, num, getNumDimIds(), {},
|
||||
context);
|
||||
}
|
||||
|
||||
// If the above fails, we'll just use the constant lower bound and the
|
||||
// constant upper bound (if they exist) as the slice bounds.
|
||||
if (!(*lbMaps)[pos]) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "WARNING: Potentially over-approximating slice lb\n");
|
||||
auto lbConst = getConstantLowerBound(pos);
|
||||
auto ubConst = getConstantUpperBound(pos);
|
||||
if (lbConst.hasValue() && ubConst.hasValue()) {
|
||||
if (lbConst.hasValue()) {
|
||||
(*lbMaps)[pos] = AffineMap::get(
|
||||
numMapDims, numMapSymbols,
|
||||
getAffineConstantExpr(lbConst.getValue(), context), {});
|
||||
}
|
||||
}
|
||||
if (!(*ubMaps)[pos]) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "WARNING: Potentially over-approximating slice ub\n");
|
||||
auto ubConst = getConstantUpperBound(pos);
|
||||
if (ubConst.hasValue()) {
|
||||
(*ubMaps)[pos] = AffineMap::get(
|
||||
numMapDims, numMapSymbols,
|
||||
getAffineConstantExpr(ubConst.getValue() + 1, context), {});
|
||||
} else {
|
||||
(*lbMaps)[pos] = AffineMap();
|
||||
(*ubMaps)[pos] = AffineMap();
|
||||
}
|
||||
}
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos) << ", expr: ");
|
||||
LLVM_DEBUG(expr.dump(););
|
||||
LLVM_DEBUG((*lbMaps)[pos].dump(););
|
||||
LLVM_DEBUG(llvm::dbgs() << "ub map for pos = " << Twine(pos) << ", expr: ");
|
||||
LLVM_DEBUG((*ubMaps)[pos].dump(););
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1454,6 +1589,7 @@ Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
|
|||
break;
|
||||
}
|
||||
if (c < getNumDimIds())
|
||||
// Not a pure symbolic bound.
|
||||
continue;
|
||||
if (atIneq(r, pos) >= 1)
|
||||
// Lower bound.
|
||||
|
@ -2037,14 +2173,53 @@ static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
|
|||
}
|
||||
}; // namespace
|
||||
|
||||
// TODO(bondhugula,andydavis): This still doesn't do a comprehensive merge of
|
||||
// the symbols. Assumes the common symbols appear in the same order (the
|
||||
// current/common use case).
|
||||
static void mergeSymbols(FlatAffineConstraints *A, FlatAffineConstraints *B) {
|
||||
SmallVector<Value *, 4> symbolsA, symbolsB;
|
||||
A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &symbolsA);
|
||||
B->getIdValues(B->getNumDimIds(), B->getNumDimAndSymbolIds(), &symbolsB);
|
||||
|
||||
// Both symbol list have a handful symbols each typically (3-4); a merge
|
||||
// quadratic in complexity with a linear search is fine.
|
||||
for (auto *symbolB : symbolsB) {
|
||||
if (llvm::is_contained(symbolsA, symbolB)) {
|
||||
A->addSymbolId(symbolsA.size(), symbolB);
|
||||
symbolsA.push_back(symbolB);
|
||||
}
|
||||
}
|
||||
// symbolsA now holds the merged symbol list.
|
||||
symbolsB.reserve(symbolsA.size());
|
||||
unsigned iB = 0;
|
||||
for (auto *symbolA : symbolsA) {
|
||||
assert(iB < symbolsB.size());
|
||||
if (symbolA != symbolsB[iB]) {
|
||||
symbolsB.insert(symbolsB.begin() + iB, symbolA);
|
||||
B->addSymbolId(iB, symbolA);
|
||||
}
|
||||
++iB;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute the bounding box with respect to 'other' by finding the min of the
|
||||
// lower bounds and the max of the upper bounds along each of the dimensions.
|
||||
bool FlatAffineConstraints::unionBoundingBox(
|
||||
const FlatAffineConstraints &other) {
|
||||
assert(other.getNumDimIds() == numDims);
|
||||
assert(other.getNumSymbolIds() == getNumSymbolIds());
|
||||
assert(other.getNumLocalIds() == 0);
|
||||
assert(getNumLocalIds() == 0);
|
||||
const FlatAffineConstraints &otherArg) {
|
||||
assert(otherArg.getNumDimIds() == numDims && "dims mismatch");
|
||||
|
||||
Optional<FlatAffineConstraints> copy;
|
||||
if (!otherArg.getIds().equals(getIds())) {
|
||||
copy.emplace(FlatAffineConstraints(otherArg));
|
||||
mergeSymbols(this, ©.getValue());
|
||||
assert(getIds().equals(copy->getIds()) && "merge failed");
|
||||
}
|
||||
|
||||
const auto &other = copy ? *copy : otherArg;
|
||||
|
||||
assert(other.getNumLocalIds() == 0 && "local ids not eliminated");
|
||||
assert(getNumLocalIds() == 0 && "local ids not eliminated");
|
||||
|
||||
std::vector<SmallVector<int64_t, 8>> boundingLbs;
|
||||
std::vector<SmallVector<int64_t, 8>> boundingUbs;
|
||||
boundingLbs.reserve(2 * getNumDimIds());
|
||||
|
@ -2082,7 +2257,11 @@ bool FlatAffineConstraints::unionBoundingBox(
|
|||
minLb = otherLb;
|
||||
} else {
|
||||
// Uncomparable.
|
||||
auto constLb = getConstantLowerBound(d);
|
||||
auto constOtherLb = other.getConstantLowerBound(d);
|
||||
if (!constLb.hasValue() || !constOtherLb.hasValue())
|
||||
return false;
|
||||
minLb = std::min(constLb.getValue(), constOtherLb.getValue());
|
||||
}
|
||||
|
||||
// Do the same for ub's but max of upper bounds.
|
||||
|
@ -2098,7 +2277,11 @@ bool FlatAffineConstraints::unionBoundingBox(
|
|||
maxUb = otherUb;
|
||||
} else {
|
||||
// Uncomparable.
|
||||
auto constUb = getConstantUpperBound(d);
|
||||
auto constOtherUb = other.getConstantUpperBound(d);
|
||||
if (!constUb.hasValue() || !constOtherUb.hasValue())
|
||||
return false;
|
||||
maxUb = std::max(constUb.getValue(), constOtherUb.getValue());
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 8> newLb(getNumCols(), 0);
|
||||
|
|
|
@ -1260,12 +1260,9 @@ static bool isFusionProfitable(Instruction *srcOpInst,
|
|||
unsigned *dstLoopDepth) {
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "Checking whether fusion is profitable between:\n";
|
||||
llvm::dbgs() << " ";
|
||||
srcOpInst->dump();
|
||||
llvm::dbgs() << " and \n";
|
||||
llvm::dbgs() << " " << *srcOpInst << " and \n";
|
||||
for (auto dstOpInst : dstLoadOpInsts) {
|
||||
llvm::dbgs() << " ";
|
||||
dstOpInst->dump();
|
||||
llvm::dbgs() << " " << *dstOpInst << "\n";
|
||||
};
|
||||
});
|
||||
|
||||
|
@ -1423,7 +1420,10 @@ static bool isFusionProfitable(Instruction *srcOpInst,
|
|||
<< 100.0 * additionalComputeFraction << "%\n"
|
||||
<< " storage reduction factor: " << storageReduction << "x\n"
|
||||
<< " fused nest cost: " << fusedLoopNestComputeCost << "\n"
|
||||
<< " slice iteration count: " << sliceIterationCount << "\n";
|
||||
<< " slice iteration count: " << sliceIterationCount << "\n"
|
||||
<< " src write region size: " << srcWriteRegionSizeBytes << "\n"
|
||||
<< " slice write region size: " << sliceWriteRegionSizeBytes
|
||||
<< "\n";
|
||||
llvm::dbgs() << msg.str();
|
||||
});
|
||||
|
||||
|
@ -1450,7 +1450,8 @@ static bool isFusionProfitable(Instruction *srcOpInst,
|
|||
// -maximal-fusion is set, fuse nevertheless.
|
||||
|
||||
if (!clMaximalLoopFusion && !bestDstLoopDepth.hasValue()) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
LLVM_DEBUG(
|
||||
llvm::dbgs()
|
||||
<< "All fusion choices involve more than the threshold amount of "
|
||||
"redundant computation; NOT fusing.\n");
|
||||
return false;
|
||||
|
@ -1694,6 +1695,9 @@ public:
|
|||
auto sliceLoopNest = mlir::insertBackwardComputationSlice(
|
||||
srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
|
||||
if (sliceLoopNest != nullptr) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "\tslice loop nest:\n"
|
||||
<< *sliceLoopNest->getInstruction() << "\n");
|
||||
// Move 'dstAffineForOp' before 'insertPointInst' if needed.
|
||||
auto dstAffineForOp = dstNode->inst->cast<AffineForOp>();
|
||||
if (insertPointInst != dstAffineForOp->getInstruction()) {
|
||||
|
|
|
@ -1810,3 +1810,61 @@ func @should_fuse_live_out_writer(%arg0 : memref<10xf32>) -> memref<10xf32> {
|
|||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %arg0 : memref<10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// The fused slice has 16 iterations from along %i0.
|
||||
|
||||
// CHECK-DAG: [[MAP_LB:#map[0-9]+]] = (d0) -> (d0 * 16)
|
||||
// CHECK-DAG: [[MAP_UB:#map[0-9]+]] = (d0) -> (d0 * 16 + 16)
|
||||
|
||||
#map = (d0, d1) -> (d0 * 16 + d1)
|
||||
|
||||
// CHECK-LABEL: slice_tile
|
||||
func @slice_tile(%arg1: memref<32x8xf32>, %arg2: memref<32x8xf32>, %0 : f32) -> memref<32x8xf32> {
|
||||
for %i0 = 0 to 32 {
|
||||
for %i1 = 0 to 8 {
|
||||
store %0, %arg2[%i0, %i1] : memref<32x8xf32>
|
||||
}
|
||||
}
|
||||
for %i = 0 to 2 {
|
||||
for %j = 0 to 8 {
|
||||
for %k = 0 to 8 {
|
||||
for %kk = 0 to 16 {
|
||||
%1 = affine.apply #map(%k, %kk)
|
||||
%2 = load %arg1[%1, %j] : memref<32x8xf32>
|
||||
%3 = "foo"(%2) : (f32) -> f32
|
||||
}
|
||||
for %ii = 0 to 16 {
|
||||
%6 = affine.apply #map(%i, %ii)
|
||||
%7 = load %arg2[%6, %j] : memref<32x8xf32>
|
||||
%8 = addf %7, %7 : f32
|
||||
store %8, %arg2[%6, %j] : memref<32x8xf32>
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return %arg2 : memref<32x8xf32>
|
||||
}
|
||||
// CHECK: for %i0 = 0 to 2 {
|
||||
// CHECK-NEXT: for %i1 = 0 to 8 {
|
||||
// CHECK-NEXT: for %i2 = [[MAP_LB]](%i0) to [[MAP_UB]](%i0) {
|
||||
// CHECK-NEXT: store %arg2, %arg1[%i2, %i1] : memref<32x8xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: for %i3 = 0 to 8 {
|
||||
// CHECK-NEXT: for %i4 = 0 to 16 {
|
||||
// CHECK-NEXT: %0 = affine.apply #map{{[0-9]+}}(%i3, %i4)
|
||||
// CHECK-NEXT: %1 = load %arg0[%0, %i1] : memref<32x8xf32>
|
||||
// CHECK-NEXT: %2 = "foo"(%1) : (f32) -> f32
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: for %i5 = 0 to 16 {
|
||||
// CHECK-NEXT: %3 = affine.apply #map{{[0-9]+}}(%i0, %i5)
|
||||
// CHECK-NEXT: %4 = load %arg1[%3, %i1] : memref<32x8xf32>
|
||||
// CHECK-NEXT: %5 = addf %4, %4 : f32
|
||||
// CHECK-NEXT: store %5, %arg1[%3, %i1] : memref<32x8xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %arg1 : memref<32x8xf32>
|
||||
// CHECK-NEXT:}
|
||||
|
|
Loading…
Reference in New Issue