forked from OSchip/llvm-project
LoopFusion improvements:
*) Adds support for fusing into consumer loop nests with multiple loads from the same memref.
*) Adds support for reducing slice loop trip count by projecting out destination loop IVs at depths greater than the destination loop depth.
*) Removes the dependence on src loop depth and simplifies cost model computation.

PiperOrigin-RevId: 229575126
This commit is contained in:
parent
9d4bb57189
commit
27d067e164
@@ -173,23 +173,23 @@ struct ComputationSliceState {
/// Returns true on success, false otherwise.
bool getBackwardComputationSliceState(const MemRefAccess &srcAccess,
                                      const MemRefAccess &dstAccess,
                                      unsigned dstLoopDepth,
                                      ComputationSliceState *sliceState);

/// Creates a clone of the computation contained in the loop nest surrounding
/// 'srcAccess', slices the iteration space of the first 'srcLoopDepth' src loop
/// IVs, and inserts the computation slice at the beginning of the instruction
/// block of the loop at 'dstLoopDepth' in the loop nest surrounding
/// 'dstAccess'. Returns the top-level loop of the computation slice on
/// 'srcOpInst', slices the iteration space of src loop based on slice bounds
/// in 'sliceState', and inserts the computation slice at the beginning of the
/// instruction block of the loop at 'dstLoopDepth' in the loop nest surrounding
/// 'dstOpInst'. Returns the top-level loop of the computation slice on
/// success, returns nullptr otherwise.
// Loop depth is a crucial optimization choice that determines where to
// materialize the results of the backward slice - presenting a trade-off b/w
// storage and redundant computation in several cases.
// TODO(andydavis) Support computation slices with common surrounding loops.
ForInst *insertBackwardComputationSlice(MemRefAccess *srcAccess,
                                        MemRefAccess *dstAccess,
                                        ComputationSliceState *sliceState,
                                        unsigned srcLoopDepth,
                                        unsigned dstLoopDepth);
ForInst *insertBackwardComputationSlice(OperationInst *srcOpInst,
                                        OperationInst *dstOpInst,
                                        unsigned dstLoopDepth,
                                        ComputationSliceState *sliceState);
} // end namespace mlir

#endif // MLIR_ANALYSIS_UTILS_H
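
[Editor's note] For orientation, a minimal sketch of how a pass might chain the two new-signature utilities; 'srcStoreOpInst', 'dstLoadOpInst', and 'dstLoopDepth' are hypothetical placeholders, and error handling is elided:

// Sketch only, assuming 'srcStoreOpInst'/'dstLoadOpInst' are the src store
// and dst load OperationInsts and 'dstLoopDepth' has already been chosen.
ComputationSliceState sliceState;
if (mlir::getBackwardComputationSliceState(MemRefAccess(srcStoreOpInst),
                                           MemRefAccess(dstLoadOpInst),
                                           dstLoopDepth, &sliceState)) {
  // Clone the src nest into the dst nest, applying the sliced bounds.
  ForInst *sliceLoopNest = mlir::insertBackwardComputationSlice(
      srcStoreOpInst, dstLoadOpInst, dstLoopDepth, &sliceState);
  // 'sliceLoopNest' is nullptr if insertion failed.
}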
@@ -1101,8 +1101,21 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context,
      (*lbMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr, {});
      (*ubMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {});
    } else {
      (*lbMaps)[pos] = AffineMap::Null();
      (*ubMaps)[pos] = AffineMap::Null();
      // TODO(andydavis, bondhugula) Add support for computing slice bounds
      // symbolic in the identifiers [num, numIds).
      auto lbConst = getConstantLowerBound(pos);
      auto ubConst = getConstantUpperBound(pos);
      if (lbConst.hasValue() && ubConst.hasValue()) {
        (*lbMaps)[pos] = AffineMap::get(
            numMapDims, numMapSymbols,
            getAffineConstantExpr(lbConst.getValue(), context), {});
        (*ubMaps)[pos] = AffineMap::get(
            numMapDims, numMapSymbols,
            getAffineConstantExpr(ubConst.getValue() + 1, context), {});
      } else {
        (*lbMaps)[pos] = AffineMap::Null();
        (*ubMaps)[pos] = AffineMap::Null();
      }
    }
    LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos) << ", expr: ");
    LLVM_DEBUG(expr.dump(););
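
[Editor's note] For intuition, under the new constant-bound fallback a hypothetical identifier with bounds 0 <= d0 <= 15 would produce the maps below; the upper-bound map is exclusive, hence the '+ 1' above:

#lb = () -> (0)
#ub = () -> (16)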
@@ -346,12 +346,13 @@ static Instruction *getInstAtPosition(ArrayRef<unsigned> positions,
  return nullptr;
}

// Computes memref dependence between 'srcAccess' and 'dstAccess' and uses the
// dependence constraint system to create AffineMaps with which to adjust the
// loop bounds of the inserted computation slice so that they are functions of
// the loop IVs and symbols of the loops surrounding 'dstAccess'.
// Computes memref dependence between 'srcAccess' and 'dstAccess', projects
// out any dst loop IVs at depth greater than 'dstLoopDepth', and computes slice
// bounds in 'sliceState' which represent the src IVs in terms of the dst IVs,
// symbols and constants.
bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
                                            const MemRefAccess &dstAccess,
                                            unsigned dstLoopDepth,
                                            ComputationSliceState *sliceState) {
  FlatAffineConstraints dependenceConstraints;
  if (!checkMemrefAccessDependence(srcAccess, dstAccess, /*loopDepth=*/1,
@@ -364,6 +365,19 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
  getLoopIVs(*srcAccess.opInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Get loop nest surrounding dst operation.
  SmallVector<ForInst *, 4> dstLoopIVs;
  getLoopIVs(*dstAccess.opInst, &dstLoopIVs);
  unsigned numDstLoopIVs = dstLoopIVs.size();
  if (dstLoopDepth > numDstLoopIVs) {
    dstAccess.opInst->emitError("invalid destination loop depth");
    return false;
  }

  // Project out dimensions other than those up to 'dstLoopDepth'.
  dependenceConstraints.projectOut(numSrcLoopIVs + dstLoopDepth,
                                   numDstLoopIVs - dstLoopDepth);

  // Set up lower/upper bound affine maps for the slice.
  sliceState->lbs.resize(numSrcLoopIVs, AffineMap::Null());
  sliceState->ubs.resize(numSrcLoopIVs, AffineMap::Null());
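
[Editor's note] A concrete illustration (values assumed for the example): with numSrcLoopIVs = 2, numDstLoopIVs = 3, and dstLoopDepth = 1, the dependence system orders its identifiers as [src0, src1, dst0, dst1, dst2], so projectOut(2 + 1, 3 - 1) eliminates dst1 and dst2; the resulting slice bounds can then reference only dst0, i.e. the dst IVs up to the fusion depth.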
@@ -385,12 +399,10 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
  return true;
}

/// Creates a computation slice of the loop nest surrounding 'srcAccess'
/// utilizing slice loop bounds in 'sliceState' (for src loops up to
/// 'srcLoopDepth'), and inserts this slice into loop nest surrounding
/// 'dstAccess' at loop depth 'dstLoopDepth'. For all loops at loop depth
/// greater than 'srcLoopDepth' their full loop bounds will be used in the
/// slice.
/// Creates a computation slice of the loop nest surrounding 'srcOpInst',
/// updates the slice loop bounds with any non-null bound maps specified in
/// 'sliceState', and inserts this slice into the loop nest surrounding
/// 'dstOpInst' at loop depth 'dstLoopDepth'.
// TODO(andydavis,bondhugula): extend the slicing utility to compute slices that
// aren't necessarily a one-to-one relation b/w the source and destination. The
// relation between the source and destination could be many-to-many in general.
@@ -401,33 +413,27 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
// solution.
// TODO(andydavis) Remove dependence on 'srcLoopDepth' here. Instead project
// out loop IVs we don't care about and produce a smaller slice.
ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
                                              MemRefAccess *dstAccess,
                                              ComputationSliceState *sliceState,
                                              unsigned srcLoopDepth,
                                              unsigned dstLoopDepth) {
ForInst *mlir::insertBackwardComputationSlice(
    OperationInst *srcOpInst, OperationInst *dstOpInst, unsigned dstLoopDepth,
    ComputationSliceState *sliceState) {
  // Get loop nest surrounding src operation.
  SmallVector<ForInst *, 4> srcLoopIVs;
  getLoopIVs(*srcAccess->opInst, &srcLoopIVs);
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();
  if (srcLoopDepth > numSrcLoopIVs) {
    srcAccess->opInst->emitError("invalid source loop depth");
    return nullptr;
  }

  // Get loop nest surrounding dst operation.
  SmallVector<ForInst *, 4> dstLoopIVs;
  getLoopIVs(*dstAccess->opInst, &dstLoopIVs);
  getLoopIVs(*dstOpInst, &dstLoopIVs);
  unsigned dstLoopIVsSize = dstLoopIVs.size();
  if (dstLoopDepth > dstLoopIVsSize) {
    dstAccess->opInst->emitError("invalid destination loop depth");
    dstOpInst->emitError("invalid destination loop depth");
    return nullptr;
  }

  // Find the inst block positions of 'srcAccess->opInst' within 'srcLoopIVs'.
  // Find the inst block positions of 'srcOpInst' within 'srcLoopIVs'.
  SmallVector<unsigned, 4> positions;
  // TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d.
  findInstPosition(srcAccess->opInst, srcLoopIVs[0]->getBlock(), &positions);
  findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);

  // Clone the src loop nest and insert it at the beginning of the instruction
  // block of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
@@ -451,7 +457,7 @@ ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
  assert(sliceLoopLimit >= sliceSurroundingLoopsSize);

  // Update loop bounds for loops in 'sliceLoopNest'.
  for (unsigned i = 0; i < srcLoopDepth; ++i) {
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    auto *forInst = sliceSurroundingLoops[dstLoopDepth + i];
    if (AffineMap lbMap = sliceState->lbs[i])
      forInst->setLowerBound(sliceState->lbOperands[i], lbMap);
@@ -69,17 +69,6 @@ char LoopFusion::passID = 0;

FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }

// FusionCandidate encapsulates source and destination memref access within
// loop nests which are candidates for loop fusion.
struct FusionCandidate {
  // Load or store access within src loop nest to be fused into dst loop nest.
  MemRefAccess srcAccess;
  // Load or store access within dst loop nest.
  MemRefAccess dstAccess;
  explicit FusionCandidate(OperationInst *src, OperationInst *dst)
      : srcAccess(MemRefAccess(src)), dstAccess(MemRefAccess(dst)) {}
};

namespace {

// LoopNestStateCollector walks loop nests and collects load and store
@@ -172,10 +161,27 @@ public:
    return &it->second;
  }

  // Returns true iff there is an edge from node 'srcId' to node 'dstId' for
  // 'memref'. Returns false otherwise.
  bool hasEdge(unsigned srcId, unsigned dstId, Value *memref) {
    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
      return false;
    }
    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
      return edge.id == dstId && edge.memref == memref;
    });
    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
      return edge.id == srcId && edge.memref == memref;
    });
    return hasOutEdge && hasInEdge;
  }

  // Adds an edge from node 'srcId' to node 'dstId' for 'memref'.
  void addEdge(unsigned srcId, unsigned dstId, Value *memref) {
    outEdges[srcId].push_back({dstId, memref});
    inEdges[dstId].push_back({srcId, memref});
    if (!hasEdge(srcId, dstId, memref)) {
      outEdges[srcId].push_back({dstId, memref});
      inEdges[dstId].push_back({srcId, memref});
    }
  }
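
[Editor's note] A brief illustrative sketch (hypothetical 'graph' instance and 'memrefA' value) of why the new hasEdge guard matters now that a dst nest may carry multiple loads from the same memref:

// With the guarded addEdge, repeated insertions of the same
// (src, dst, memref) triple are idempotent rather than creating
// parallel edges in both adjacency lists.
graph.addEdge(/*srcId=*/0, /*dstId=*/1, memrefA);
graph.addEdge(0, 1, memrefA); // no-op: the edge already exists
assert(graph.hasEdge(0, 1, memrefA));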

  // Removes an edge from node 'srcId' to node 'dstId' for 'memref'.
@@ -425,10 +431,10 @@ public:
// inserting a sliced loop nest of known cost into the loop's body.
// NOTE: this is used to compute the cost of fusing a slice of some loop nest
// within another loop.
static uint64_t
getComputeCost(ForInst *forInst, LoopNestStats *stats,
               DenseMap<ForInst *, uint64_t> *tripCountOverrideMap,
               DenseMap<ForInst *, uint64_t> *computeCostMap) {
static uint64_t getComputeCost(
    ForInst *forInst, LoopNestStats *stats,
    llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountOverrideMap,
    DenseMap<ForInst *, uint64_t> *computeCostMap) {
  // 'opCount' is the total number of operations in one iteration of the body
  // of 'forInst'.
  uint64_t opCount = stats->opCountMap[forInst];
  if (stats->loopMap.count(forInst) > 0) {
@@ -458,17 +464,33 @@ getComputeCost(ForInst *forInst, LoopNestStats *stats,

} // end anonymous namespace

static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
  assert(lbMap.getNumResults() == 1);
  assert(ubMap.getNumResults() == 1);
  assert(lbMap.getNumDims() == ubMap.getNumDims());
  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
  // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
  // ub_expr - lb_expr
  AffineExpr lbExpr(lbMap.getResult(0));
  AffineExpr ubExpr(ubMap.getResult(0));
  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
                                         lbMap.getNumSymbols());
  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
  if (!cExpr)
    return None;
  return cExpr.getValue();
}
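
[Editor's note] For example (maps assumed for illustration): with lbMap = (d0) -> (d0) and ubMap = (d0) -> (d0 + 16), ubExpr - lbExpr simplifies to the constant 16, so 16 is returned; with lbMap = (d0) -> (0) and ubMap = (d0) -> (d0), the span d0 is not constant and None is returned.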

// Builds a map 'tripCountMap' from ForInst to constant trip count for loop
// nest surrounding 'srcAccess' utilizing slice loop bounds in 'sliceState'.
// Returns true on success, false otherwise (if a non-constant trip count
// was encountered).
// TODO(andydavis) Make this work with non-unit step loops.
static bool
buildSliceTripCountMap(MemRefAccess *srcAccess,
                       ComputationSliceState *sliceState,
                       DenseMap<ForInst *, uint64_t> *tripCountMap) {
static bool buildSliceTripCountMap(
    OperationInst *srcOpInst, ComputationSliceState *sliceState,
    llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountMap) {
  SmallVector<ForInst *, 4> srcLoopIVs;
  getLoopIVs(*srcAccess->opInst, &srcLoopIVs);
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();
  // Populate map from ForInst -> trip count
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
@@ -485,109 +507,166 @@ buildSliceTripCountMap(MemRefAccess *srcAccess,
    }
    return false;
  }
  // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
  // ub_expr - lb_expr
  AffineExpr lbExpr(lbMap.getResult(0));
  AffineExpr ubExpr(ubMap.getResult(0));
  auto loopSpanExpr = simplifyAffineExpr(
      ubExpr - lbExpr, std::max(lbMap.getNumDims(), ubMap.getNumDims()),
      std::max(lbMap.getNumSymbols(), ubMap.getNumSymbols()));
  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
  if (!cExpr)
    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
    if (!tripCount.hasValue())
      return false;
    (*tripCountMap)[srcLoopIVs[i]] = cExpr.getValue();
    (*tripCountMap)[srcLoopIVs[i]] = tripCount.getValue();
  }
  return true;
}

// Returns the maximum loop depth within the source loop nest at which a
// sliced loop bound is detected in 'sliceState'.
static unsigned getMaxSrcLoopDepth(unsigned srcLoopDepthLimit,
                                   ComputationSliceState *sliceState) {
  unsigned maxSrcPos = 0;
  for (unsigned i = 0; i < srcLoopDepthLimit; ++i) {
    if (sliceState->lbs[i] != AffineMap::Null() &&
        sliceState->ubs[i] != AffineMap::Null()) {
      maxSrcPos = std::max(maxSrcPos, i);
    }
// Removes load operations from 'srcLoads' which operate on 'memref', and
// adds them to 'dstLoads'.
static void
moveLoadsAccessingMemrefTo(Value *memref,
                           SmallVectorImpl<OperationInst *> *srcLoads,
                           SmallVectorImpl<OperationInst *> *dstLoads) {
  dstLoads->clear();
  SmallVector<OperationInst *, 4> srcLoadsToKeep;
  for (auto *load : *srcLoads) {
    if (load->cast<LoadOp>()->getMemRef() == memref)
      dstLoads->push_back(load);
    else
      srcLoadsToKeep.push_back(load);
  }
  return maxSrcPos + 1;
  srcLoads->swap(srcLoadsToKeep);
}

// Returns the minimum loop depth within the destination loop nest at which the
// computation slice can be inserted (based on the destination loop IVs that
// the source slice actually depends on / is a function of).
static unsigned getMinDstLoopDepth(unsigned srcLoopDepth,
                                   ComputationSliceState *sliceState) {
  // Record in 'maxDstLoopDepth' the largest position (+1) of a dst loop nest
  // IV, which is used in a sliced loop bound in the src loop nest.
  unsigned maxDstLoopDepth = 0;
  for (unsigned i = 0; i < srcLoopDepth; ++i) {
    if (AffineMap lbMap = sliceState->lbs[i]) {
      lbMap.walkExprs([&](AffineExpr expr) {
        if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
          maxDstLoopDepth =
              std::max(maxDstLoopDepth, dimExpr.getPosition() + 1);
        }
      });
    }
    if (AffineMap ubMap = sliceState->ubs[i]) {
      ubMap.walkExprs([&](AffineExpr expr) {
        if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
          maxDstLoopDepth =
              std::max(maxDstLoopDepth, dimExpr.getPosition() + 1);
        }
      });
    }
// Returns the innermost common loop depth for the set of operations in 'ops'.
static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
  unsigned numOps = ops.size();
  assert(numOps > 0);

  std::vector<SmallVector<ForInst *, 4>> loops(numOps);
  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
  for (unsigned i = 0; i < numOps; ++i) {
    getLoopIVs(*ops[i], &loops[i]);
    loopDepthLimit =
        std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
  }
  return maxDstLoopDepth;

  unsigned loopDepth = 0;
  for (unsigned d = 0; d < loopDepthLimit; ++d) {
    unsigned i;
    for (i = 1; i < numOps; ++i) {
      if (loops[i - 1][d] != loops[i][d]) {
        break;
      }
    }
    if (i != numOps)
      break;
    ++loopDepth;
  }
  return loopDepth;
}
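
[Editor's note] For example (a hypothetical pair of loads): if ops[0] is nested under loops (%i0, %i1) and ops[1] under (%i0, %i2), the depth-0 loops match but the depth-1 loops differ, so the function returns 1, and the profitability search below considers dst loop depths up to 1.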

// Checks the profitability of fusion candidate 'candidate'. Returns true if it
// is profitable to fuse the candidate loop nests. Returns false otherwise.
// Returns true if 'map' is a single result constant or single result
// dim expr where its corresponding loop IV in 'operands' has zero constant
// lower bound.
static bool hasZeroMinValue(AffineMap map, ArrayRef<Value *> operands) {
  if (map.isSingleConstant() && map.getSingleConstantResult() == 0)
    return true;
  if (map.getNumResults() != 1 || !map.getResult(0).isa<AffineDimExpr>())
    return false;
  // Get operand position of single dim expr result.
  unsigned pos = map.getResult(0).cast<AffineDimExpr>().getPosition();
  // Check if loop IV at 'pos' has zero constant lower bound.
  auto *operand = operands[pos];
  assert(isa<ForInst>(operand));
  auto *forInst = cast<ForInst>(operand);
  return forInst->hasConstantLowerBound() &&
         forInst->getConstantLowerBound() == 0;
}
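
[Editor's note] For instance (hypothetical bounds): a lower bound map (d0) -> (d0) whose operand is the IV of 'for %i0 = 0 to 10' returns true, while the same map over an IV from 'for %i0 = 2 to 10' returns false, since that interval does not start at zero.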
// Returns the slice bound union of 'sliceStateA' and 'sliceStateB' in
// 'sliceStateB'.
// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA'
// and 'sliceStateB' are aligned.
// Specifically, when taking the union of overlapping intervals, it assumes
// that both intervals start at zero. Support needs to be added to take into
// account interval start offset when computing the union.
// TODO(andydavis) Move this function to an analysis library.
static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
                               ComputationSliceState *sliceStateB) {
  assert(sliceStateA.lbs.size() == sliceStateB->lbs.size());
  assert(sliceStateA.ubs.size() == sliceStateB->ubs.size());

  for (unsigned i = 0, e = sliceStateA.lbs.size(); i < e; ++i) {
    AffineMap lbMapA = sliceStateA.lbs[i];
    AffineMap ubMapA = sliceStateA.ubs[i];
    if (lbMapA == AffineMap::Null()) {
      assert(ubMapA == AffineMap::Null());
      continue;
    }
    assert(ubMapA != AffineMap::Null());
    // Validate that constant lower bounds are aligned at zero.
    if (!hasZeroMinValue(lbMapA, sliceStateA.lbOperands[i]))
      return false;

    AffineMap lbMapB = sliceStateB->lbs[i];
    AffineMap ubMapB = sliceStateB->ubs[i];
    if (lbMapB == AffineMap::Null()) {
      assert(ubMapB == AffineMap::Null());
      // Union 'sliceStateB' does not have a bound for 'i' so copy from A.
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
      continue;
    }
    // Validate that constant lower bounds are aligned at zero.
    if (!hasZeroMinValue(lbMapB, sliceStateB->lbOperands[i]))
      return false;

    // Add the bound with the largest trip count to the union.
    Optional<uint64_t> tripCountA = getConstDifference(lbMapA, ubMapA);
    Optional<uint64_t> tripCountB = getConstDifference(lbMapB, ubMapB);
    if (!tripCountA.hasValue() || !tripCountB.hasValue())
      return false;
    // TODO(andydavis) Change this code to take the min across all lower bounds
    // and max across all upper bounds for each dimension. This code can fail
    // for cases where a unique min or max could not be statically determined.
    if (tripCountA.getValue() > tripCountB.getValue()) {
      sliceStateB->lbs[i] = lbMapA;
      sliceStateB->ubs[i] = ubMapA;
    }
  }
  return true;
}
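
[Editor's note] A worked illustration (bounds assumed, patterned on the trip-count tests below): if slice A bounds a src IV to [0, 16) and slice B bounds it to [0, 20), both lower bounds are zero-aligned, the constant trip counts are 16 and 20, and the union keeps the [0, 20) bounds, so the fused slice covers every iteration any dst load needs.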

// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstOpInsts'.
// Returns true if it is profitable to fuse the candidate loop nests. Returns
// false otherwise.
// The profitability model executes the following steps:
// *) Computes the backward computation slice at 'candidate.srcAccess'. This
//    computation slice of the loop nest surrounding 'candidate.srcAccess' is
// *) Computes the backward computation slice at 'srcOpInst'. This
//    computation slice of the loop nest surrounding 'srcOpInst' is
//    represented by modified src loop bounds in 'sliceState', which are
//    functions of loop IVs in the loop nest surrounding 'candidate.dstAccess'.
//    functions of loop IVs in the loop nest surrounding 'srcOpInst'.
// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
//    loop nest is the total number of dynamic operation instances in the loop
//    nest).
// *) Computes the cost of fusing a slice of the src loop nest into the dst
//    loop nest at various values of src/dst loop depth, attempting to fuse
//    the biggest computation slice (max src loop depth) at the maximal dst loop
//    depth (closest to the load) to minimize reuse distance and opportunity for
//    subsequent load/store forwarding.
//    NOTE: 'srcLoopDepth' refers to the loop depth within the source loop nest
//    at which we slice the loop bounds (all src loops below this depth will
//    utilize full loop bounds).
//    loop nest at various values of dst loop depth, attempting to fuse
//    the largest computation slice at the maximal dst loop depth (closest to
//    the load) to minimize reuse distance and potentially enable subsequent
//    load/store forwarding.
//    NOTE: If the dst loop nest includes multiple loads in 'dstOpInsts' for
//    the same memref as is written by 'srcOpInst', then the union of slice
//    loop bounds is used to compute the slice and associated slice cost.
//    NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
//    nest, at which the src computation slice is inserted/fused.
//    NOTE: We attempt to maximize the source loop depth, but there are cases
//    where a particular setting for 'dstLoopNest' might fused an unsliced
//    NOTE: We attempt to maximize the dst loop depth, but there are cases
//    where a particular setting for 'dstLoopNest' might fuse an unsliced
//    loop (within the src computation slice) at a depth which results in
//    excessive recomputation (see unit tests for examples).
// *) Compares the total cost of the unfused loop nests to the min cost fused
//    loop nest computed in the previous step, and returns true if the latter
//    is lower.
static bool isFusionProfitable(FusionCandidate *candidate,
static bool isFusionProfitable(OperationInst *srcOpInst,
                               ArrayRef<OperationInst *> dstOpInsts,
                               ComputationSliceState *sliceState,
                               unsigned *srcLoopDepth, unsigned *dstLoopDepth) {
  // Compute backward computation slice state: src IV bounds w.r.t dst IVs, etc.
  if (!mlir::getBackwardComputationSliceState(
          candidate->srcAccess, candidate->dstAccess, sliceState)) {
    return false;
  }

  // Build trip count map for src loops with sliced loop bounds in 'sliceState'.
  DenseMap<ForInst *, uint64_t> sliceTripCountMap;
  if (!buildSliceTripCountMap(&candidate->srcAccess, sliceState,
                              &sliceTripCountMap))
    return false;

                               unsigned *dstLoopDepth) {
  // Compute cost of sliced and unsliced src loop nest.
  SmallVector<ForInst *, 4> srcLoopIVs;
  getLoopIVs(*candidate->srcAccess.opInst, &srcLoopIVs);
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Walk src loop nest and collect stats.
@@ -600,8 +679,7 @@ static bool isFusionProfitable(FusionCandidate *candidate,

  // Compute cost of dst loop nest.
  SmallVector<ForInst *, 4> dstLoopIVs;
  getLoopIVs(*candidate->dstAccess.opInst, &dstLoopIVs);
  unsigned numDstLoopIVs = dstLoopIVs.size();
  getLoopIVs(*dstOpInsts[0], &dstLoopIVs);

  LoopNestStats dstLoopNestStats;
  LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
@@ -610,51 +688,60 @@ static bool isFusionProfitable(FusionCandidate *candidate,
  if (dstStatsCollector.hasLoopWithNonConstTripCount)
    return false;

  // Search for min cost values for 'srcLoopDepth' and 'dstLoopDepth'.
  // This search is O(n^2) where 'n' is very small (e.g. six).
  // TODO(andydavis) Consider a solution where we just iterate through
  // dstLoopDepth possibilities and project out IVs we do not need (removing
  // the dependence on 'srcLoopDepth').
  DenseMap<ForInst *, uint64_t> tripCountMap;
  DenseMap<ForInst *, uint64_t> computeCostMap;
  unsigned maxSrcLoopDepth = getMaxSrcLoopDepth(numSrcLoopIVs, sliceState);
  // Compute the innermost common loop depth for ops in 'dstOpInsts'.
  unsigned maxDstLoopDepth = getInnermostCommonLoopDepth(dstOpInsts);
  if (maxDstLoopDepth == 0)
    return false;

  // Search for min cost value for 'dstLoopDepth'. At each value of
  // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
  // bounds between 'srcOpInst' and each op in 'dstOpInsts' (taking the union
  // of these bounds). Next the union slice bounds are used to calculate
  // the cost of the slice and the cost of the slice inserted into the dst
  // loop nest at 'dstLoopDepth'.
  unsigned minFusedLoopNestComputeCost = std::numeric_limits<unsigned>::max();
  unsigned bestSrcLoopDepth;
  unsigned bestDstLoopDepth;
  for (unsigned i = maxSrcLoopDepth; i >= 1; --i) {
    // Compute minDstLoopDepth based on dst loop IVs used in slice loop bounds.
    unsigned minDstLoopDepth = getMinDstLoopDepth(i, sliceState);
    assert(minDstLoopDepth <= numDstLoopIVs);
    if (minDstLoopDepth == 0) {
      // TODO(andydavis) Support inserting computation slices at top-level.
      continue;
    }
    // Copy elements from slice trip count map up to src loop depth 'i'.
    tripCountMap.clear();
    for (unsigned k = 0; k < i; ++k) {
      auto *forInst = srcLoopIVs[k];
      auto it = sliceTripCountMap.find(forInst);
      if (it != sliceTripCountMap.end()) {
        tripCountMap[forInst] = it->second;
      }
  SmallVector<ComputationSliceState, 4> sliceStates;
  sliceStates.resize(maxDstLoopDepth);

  llvm::SmallDenseMap<ForInst *, uint64_t, 8> sliceTripCountMap;
  DenseMap<ForInst *, uint64_t> computeCostMap;
  for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
    MemRefAccess srcAccess(srcOpInst);
    // Handle the common case of one dst load without a copy.
    if (!mlir::getBackwardComputationSliceState(
            srcAccess, MemRefAccess(dstOpInsts[0]), i, &sliceStates[i - 1]))
      return false;
    // Compute the union of slice bounds of all ops in 'dstOpInsts'.
    for (int j = 1, e = dstOpInsts.size(); j < e; ++j) {
      MemRefAccess dstAccess(dstOpInsts[j]);
      ComputationSliceState tmpSliceState;
      if (!mlir::getBackwardComputationSliceState(srcAccess, dstAccess, i,
                                                  &tmpSliceState))
        return false;
      // Compute slice bound union of 'tmpSliceState' and 'sliceStates[i - 1]'.
      getSliceBoundUnion(tmpSliceState, &sliceStates[i - 1]);
    }
    // Build trip count map for computation slice.
    sliceTripCountMap.clear();
    if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
                                &sliceTripCountMap))
      return false;

    // Compute op instance count for the src loop nest with iteration slicing.
    uint64_t sliceComputeCost =
        getComputeCost(srcLoopIVs[0], &srcLoopNestStats, &tripCountMap,
        getComputeCost(srcLoopIVs[0], &srcLoopNestStats, &sliceTripCountMap,
                       /*computeCostMap=*/nullptr);

    for (unsigned j = numDstLoopIVs; j >= minDstLoopDepth; --j) {
      // Compute cost of fusion for these values of 'i' and 'j'.
      computeCostMap.clear();
      computeCostMap[dstLoopIVs[j - 1]] = sliceComputeCost;
      uint64_t fusedLoopNestComputeCost =
          getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
                         /*tripCountOverrideMap=*/nullptr, &computeCostMap);
      if (fusedLoopNestComputeCost < minFusedLoopNestComputeCost) {
        minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
        bestSrcLoopDepth = i;
        bestDstLoopDepth = j;
      }
    // Compute cost of fusion for this value of 'i'.
    computeCostMap.clear();
    computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost;
    uint64_t fusedLoopNestComputeCost =
        getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
                       /*tripCountOverrideMap=*/nullptr, &computeCostMap);
    if (fusedLoopNestComputeCost < minFusedLoopNestComputeCost) {
      minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
      bestDstLoopDepth = i;
    }
  }
@@ -668,7 +755,6 @@ static bool isFusionProfitable(FusionCandidate *candidate,
                     /*computeCostMap=*/nullptr);

  LLVM_DEBUG(llvm::dbgs() << "LoopFusion statistics "
                          << " bestSrcLoopDepth: " << bestSrcLoopDepth
                          << " bestDstLoopDepth: " << bestDstLoopDepth
                          << " srcLoopNestCost: " << srcLoopNestCost
                          << " dstLoopNestCost: " << dstLoopNestCost
@@ -680,25 +766,23 @@ static bool isFusionProfitable(FusionCandidate *candidate,
  // for load/store forwarding in cost model.
  if (minFusedLoopNestComputeCost > srcLoopNestCost + dstLoopNestCost)
    return false;
  // Set src/dstLoopDepth based on best values from search.
  *srcLoopDepth = bestSrcLoopDepth;
  // Update return parameter 'sliceState' with 'bestSliceState'.
  ComputationSliceState *bestSliceState = &sliceStates[bestDstLoopDepth - 1];
  sliceState->lbs = bestSliceState->lbs;
  sliceState->ubs = bestSliceState->ubs;
  sliceState->lbOperands = bestSliceState->lbOperands;
  sliceState->ubOperands = bestSliceState->ubOperands;
  // Set dstLoopDepth based on best values from search.
  *dstLoopDepth = bestDstLoopDepth;
  // Update 'sliceState' bounds based on computed 'srcLoopDepth':
  // *) Canonicalize affine map now that 'srcLoopDepth' has been chosen.
  // *) Replace slice bound maps at depth > 'srcLoopDepth' with AffineMap::Null()
  // Canonicalize slice bound affine maps.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    if (i < bestSrcLoopDepth) {
      if (sliceState->lbs[i] != AffineMap::Null()) {
        canonicalizeMapAndOperands(&sliceState->lbs[i],
                                   &sliceState->lbOperands[i]);
      }
      if (sliceState->ubs[i] != AffineMap::Null()) {
        canonicalizeMapAndOperands(&sliceState->ubs[i],
                                   &sliceState->ubOperands[i]);
      }
    } else {
      sliceState->lbs[i] = AffineMap::Null();
      sliceState->ubs[i] = AffineMap::Null();
    if (sliceState->lbs[i] != AffineMap::Null()) {
      canonicalizeMapAndOperands(&sliceState->lbs[i],
                                 &sliceState->lbOperands[i]);
    }
    if (sliceState->ubs[i] != AffineMap::Null()) {
      canonicalizeMapAndOperands(&sliceState->ubs[i],
                                 &sliceState->ubOperands[i]);
    }
  }
  return true;
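
[Editor's note] To ground the final comparison, a hypothetical walk-through (all stats assumed, patterned on the trip-count-20 test further below, and ignoring the cost of the loop instructions themselves): a src nest 'for 0..100 {store}' costs 100; a dst nest 'for 0..5 { for 0..10 {load}; for 0..10 { for 0..20 {load} } }' costs 5 * (10 + 10 * 20) = 1050. Fusing the slice at dst depth 1 with a union trip count of 20 gives sliceComputeCost = 20 and a fused cost of 5 * (20 + 10 + 200) = 1150, which is not greater than the unfused total of 100 + 1050 = 1150, so fusion is accepted.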
@@ -767,12 +851,12 @@ public:
        continue;

      SmallVector<OperationInst *, 4> loads = dstNode->loads;
      SmallVector<OperationInst *, 4> dstLoadOpInsts;
      while (!loads.empty()) {
        auto *dstLoadOpInst = loads.pop_back_val();
        auto *memref = dstLoadOpInst->cast<LoadOp>()->getMemRef();
        // Skip 'dstLoadOpInst' if multiple loads to 'memref' in 'dstNode'.
        if (dstNode->getLoadOpCount(memref) != 1)
          continue;
        // Get memref of load on top of the stack.
        auto *memref = loads.back()->cast<LoadOp>()->getMemRef();
        // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
        moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
        // Skip if no input edges along which to fuse.
        if (mdg->inEdges.count(dstId) == 0)
          continue;
@@ -801,19 +885,15 @@ public:
          continue;
        // Get unique 'srcNode' store op.
        auto *srcStoreOpInst = srcNode->stores.front();
        // Build fusion candidate out of 'srcStoreOpInst' and 'dstLoadOpInst'.
        FusionCandidate candidate(srcStoreOpInst, dstLoadOpInst);
        // Check if fusion would be profitable.
        unsigned srcLoopDepth;
        unsigned dstLoopDepth;
        mlir::ComputationSliceState sliceState;
        if (!isFusionProfitable(&candidate, &sliceState, &srcLoopDepth,
        if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, &sliceState,
                                &dstLoopDepth))
          continue;
        // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
        auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
            &candidate.srcAccess, &candidate.dstAccess, &sliceState,
            srcLoopDepth, dstLoopDepth);
            srcStoreOpInst, dstLoadOpInsts[0], dstLoopDepth, &sliceState);
        if (sliceLoopNest != nullptr) {
          // Remove edges between 'srcNode' and 'dstNode' and remove 'srcNode'
          mdg->updateEdgesAndRemoveSrcNode(srcNode->id, dstNode->id);
@@ -95,13 +95,18 @@ func @should_fuse_loop_nests_with_shifts() {
  }
}

// The cost of fusing the src loop nest at dst loop depth 1 is less expensive
// than fusing at dst loop depth 2, because at dst loop depth 1, we are
// able to reduce the trip count around the %i1 loop by one (because the
// dst loop never reads the last element written by the src loop).
// CHECK:      for %i0 = 0 to 10 {
// CHECK-NEXT:   for %i1 = 0 to 10 {
// CHECK-NEXT:     %1 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i0)
// CHECK-NEXT:     %2 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i1)
// CHECK-NEXT:     %3 = affine_apply [[MAP_SHIFT_BY_ONE]](%1, %2)
// CHECK-NEXT:     store %cst, %0[%3#0, %3#1] : memref<10x10xf32>
// CHECK-NEXT:     %4 = load %0[%i0, %i1] : memref<10x10xf32>
// CHECK-NEXT:   %1 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i0)
// CHECK-NEXT:   for %i1 = 0 to 9 {
// CHECK-NEXT:     %2 = affine_apply [[MAP_SHIFT_BY_ONE]](%1, %i1)
// CHECK-NEXT:     store %cst, %0[%2#0, %2#1] : memref<10x10xf32>
// CHECK-NEXT:   }
// CHECK-NEXT:   for %i2 = 0 to 10 {
// CHECK-NEXT:     %3 = load %0[%i0, %i2] : memref<10x10xf32>
// CHECK-NEXT:   }
// CHECK-NEXT: }
// CHECK-NEXT: return
@@ -849,6 +854,7 @@ func @should_fuse_src_depth1_at_dst_depth2() {
}

// -----
// CHECK: #map0 = ()[s0] -> (s0)

// CHECK-LABEL: func @fusion_at_depth0_not_currently_supported
func @fusion_at_depth0_not_currently_supported() {
@@ -862,10 +868,9 @@ func @fusion_at_depth0_not_currently_supported() {
    %1 = load %0[%c0] : memref<10xf32>
  }
  // CHECK:      for %i0 = 0 to 10 {
  // CHECK-NEXT:   store %cst, %0[%i0] : memref<10xf32>
  // CHECK-NEXT: }
  // CHECK-NEXT: for %i1 = 0 to 10 {
  // CHECK-NEXT:   %1 = load %0[%c0] : memref<10xf32>
  // CHECK-NEXT:   %1 = affine_apply #map0()[%c0]
  // CHECK-NEXT:   store %cst, %0[%1] : memref<10xf32>
  // CHECK-NEXT:   %2 = load %0[%c0] : memref<10xf32>
  // CHECK-NEXT: }
  // CHECK-NEXT: return
  return
@@ -977,3 +982,128 @@ func @should_fuse_deep_loop_nests() {
  // CHECK-NEXT: return
  return
}

// -----
// CHECK: #map0 = (d0) -> (d0)

// CHECK-LABEL: func @should_fuse_at_depth1_and_reduce_slice_trip_count
func @should_fuse_at_depth1_and_reduce_slice_trip_count() {
  %a = alloc() : memref<4x256xf32>
  %b = alloc() : memref<4x256xf32>

  %c0 = constant 0 : index
  %cf0 = constant 0.0 : f32

  for %i0 = 0 to 4 {
    for %i1 = 0 to 256 {
      %v0 = load %b[%i0, %i1] : memref<4x256xf32>
    }
    for %i2 = 0 to 256 {
      store %cf0, %a[%i0, %i2] : memref<4x256xf32>
    }
  }

  for %d0 = 0 to 4 {
    for %d1 = 0 to 16 {
      %v1 = load %a[%d0, %d1] : memref<4x256xf32>
    }
  }
  // The cost of fusing at depth 2 is greater than the cost of fusing at depth 1
  // for two reasons:
  // 1) Inserting the unsliceable src loop %i1 at a higher depth removes
  //    redundant computation and reduces costs.
  // 2) When inserting the sliceable src loop %i2 at depth 1, we can still
  //    reduce its trip count to 16 (from 256), reducing costs.
  // CHECK:      for %i0 = 0 to 4 {
  // CHECK-NEXT:   %2 = affine_apply #map0(%i0)
  // CHECK-NEXT:   for %i1 = 0 to 256 {
  // CHECK-NEXT:     %3 = load %1[%2, %i1] : memref<4x256xf32>
  // CHECK-NEXT:   }
  // CHECK-NEXT:   for %i2 = 0 to 16 {
  // CHECK-NEXT:     store %cst, %0[%2, %i2] : memref<4x256xf32>
  // CHECK-NEXT:   }
  // CHECK-NEXT:   for %i3 = 0 to 16 {
  // CHECK-NEXT:     %4 = load %0[%i0, %i3] : memref<4x256xf32>
  // CHECK-NEXT:   }
  // CHECK-NEXT: }
  // CHECK-NEXT: return
  return
}

// -----

// CHECK-LABEL: func @should_fuse_at_depth1_with_trip_count_20
func @should_fuse_at_depth1_with_trip_count_20() {
  %a = alloc() : memref<100xf32>
  %c0 = constant 0 : index
  %cf0 = constant 0.0 : f32

  for %i0 = 0 to 100 {
    store %cf0, %a[%i0]: memref<100xf32>
  }

  for %i1 = 0 to 5 {
    for %i2 = 0 to 10 {
      %v0 = load %a[%i2]: memref<100xf32>
    }
    for %i3 = 0 to 10 {
      for %i4 = 0 to 20 {
        %v1 = load %a[%i4]: memref<100xf32>
      }
    }
  }
  // CHECK:      for %i0 = 0 to 5 {
  // CHECK-NEXT:   for %i1 = 0 to 20 {
  // CHECK-NEXT:     store %cst, %0[%i1] : memref<100xf32>
  // CHECK-NEXT:   }
  // CHECK-NEXT:   for %i2 = 0 to 10 {
  // CHECK-NEXT:     %1 = load %0[%i2] : memref<100xf32>
  // CHECK-NEXT:   }
  // CHECK-NEXT:   for %i3 = 0 to 10 {
  // CHECK-NEXT:     for %i4 = 0 to 20 {
  // CHECK-NEXT:       %2 = load %0[%i4] : memref<100xf32>
  // CHECK-NEXT:     }
  // CHECK-NEXT:   }
  // CHECK-NEXT: }
  // CHECK-NEXT: return
  return
}

// -----

// CHECK-LABEL: func @should_fuse_at_depth1_with_trip_count_19
func @should_fuse_at_depth1_with_trip_count_19() {
  %a = alloc() : memref<100xf32>
  %c0 = constant 0 : index
  %cf0 = constant 0.0 : f32

  for %i0 = 0 to 100 {
    store %cf0, %a[%i0]: memref<100xf32>
  }

  for %i1 = 0 to 5 {
    for %i2 = 0 to 19 {
      %v0 = load %a[%i2]: memref<100xf32>
    }
    for %i3 = 0 to 10 {
      for %i4 = 0 to 10 {
        %v1 = load %a[%i4]: memref<100xf32>
      }
    }
  }
  // CHECK:      for %i0 = 0 to 5 {
  // CHECK-NEXT:   for %i1 = 0 to 19 {
  // CHECK-NEXT:     store %cst, %0[%i1] : memref<100xf32>
  // CHECK-NEXT:   }
  // CHECK-NEXT:   for %i2 = 0 to 19 {
  // CHECK-NEXT:     %1 = load %0[%i2] : memref<100xf32>
  // CHECK-NEXT:   }
  // CHECK-NEXT:   for %i3 = 0 to 10 {
  // CHECK-NEXT:     for %i4 = 0 to 10 {
  // CHECK-NEXT:       %2 = load %0[%i4] : memref<100xf32>
  // CHECK-NEXT:     }
  // CHECK-NEXT:   }
  // CHECK-NEXT: }
  // CHECK-NEXT: return
  return
}