LoopFusion improvements:

*) Adds support for fusing into consumer loop nests with multiple loads from the same memref.
*) Adds support for reducing slice loop trip count by projecting out destination loop IVs at depths greater than the destination loop depth.
*) Removes dependence on src loop depth and simplifies cost model computation.

PiperOrigin-RevId: 229575126
Authored by MLIR Team on 2019-01-16 09:55:02 -08:00; committed by jpienaar
parent 9d4bb57189
commit 27d067e164
5 changed files with 442 additions and 213 deletions


@@ -173,23 +173,23 @@ struct ComputationSliceState {
/// Returns true on success, false otherwise.
bool getBackwardComputationSliceState(const MemRefAccess &srcAccess,
const MemRefAccess &dstAccess,
unsigned dstLoopDepth,
ComputationSliceState *sliceState);
/// Creates a clone of the computation contained in the loop nest surrounding
/// 'srcAccess', slices the iteration space of the first 'srcLoopDepth' src loop
/// IVs, and inserts the computation slice at the beginning of the instruction
/// block of the loop at 'dstLoopDepth' in the loop nest surrounding
/// 'dstAccess'. Returns the top-level loop of the computation slice on
/// 'srcOpInst', slices the iteration space of src loop based on slice bounds
/// in 'sliceState', and inserts the computation slice at the beginning of the
/// instruction block of the loop at 'dstLoopDepth' in the loop nest surrounding
/// 'dstOpInst'. Returns the top-level loop of the computation slice on
/// success, returns nullptr otherwise.
// Loop depth is a crucial optimization choice that determines where to
// materialize the results of the backward slice - presenting a trade-off b/w
// storage and redundant computation in several cases.
// TODO(andydavis) Support computation slices with common surrounding loops.
ForInst *insertBackwardComputationSlice(MemRefAccess *srcAccess,
MemRefAccess *dstAccess,
ComputationSliceState *sliceState,
unsigned srcLoopDepth,
unsigned dstLoopDepth);
ForInst *insertBackwardComputationSlice(OperationInst *srcOpInst,
OperationInst *dstOpInst,
unsigned dstLoopDepth,
ComputationSliceState *sliceState);
} // end namespace mlir
#endif // MLIR_ANALYSIS_UTILS_H
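
Taken together, the two declarations above form the new slicing API: compute the slice state for a given dst loop depth, then materialize the slice. A minimal usage sketch follows; the helper name and the include path are assumptions, while the signatures follow the header above:

```cpp
#include "mlir/Analysis/Utils.h" // assumed path for the header above

using namespace mlir;

// Hypothetical driver showing how the two utilities chain together, as
// LoopFusion does: derive src-IV bounds as functions of the outermost
// 'dstLoopDepth' dst IVs, then clone and re-bound the src nest in place.
static ForInst *fuseBackwardSlice(OperationInst *srcOpInst,
                                  OperationInst *dstOpInst,
                                  unsigned dstLoopDepth) {
  ComputationSliceState sliceState;
  MemRefAccess srcAccess(srcOpInst);
  MemRefAccess dstAccess(dstOpInst);
  if (!getBackwardComputationSliceState(srcAccess, dstAccess, dstLoopDepth,
                                        &sliceState))
    return nullptr; // no dependence, or slice bounds not computable
  // Returns the top-level ForInst of the inserted slice, or nullptr.
  return insertBackwardComputationSlice(srcOpInst, dstOpInst, dstLoopDepth,
                                        &sliceState);
}
```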


@@ -1101,8 +1101,21 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context,
(*lbMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr, {});
(*ubMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {});
} else {
(*lbMaps)[pos] = AffineMap::Null();
(*ubMaps)[pos] = AffineMap::Null();
// TODO(andydavis, bondhugula) Add support for computing slice bounds
// symbolic in the identifiers [num, numIds).
auto lbConst = getConstantLowerBound(pos);
auto ubConst = getConstantUpperBound(pos);
if (lbConst.hasValue() && ubConst.hasValue()) {
(*lbMaps)[pos] = AffineMap::get(
numMapDims, numMapSymbols,
getAffineConstantExpr(lbConst.getValue(), context), {});
(*ubMaps)[pos] = AffineMap::get(
numMapDims, numMapSymbols,
getAffineConstantExpr(ubConst.getValue() + 1, context), {});
} else {
(*lbMaps)[pos] = AffineMap::Null();
(*ubMaps)[pos] = AffineMap::Null();
}
}
LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos) << ", expr: ");
LLVM_DEBUG(expr.dump(););
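
For intuition about the fallback added above: the constraint system reports inclusive constant bounds, while affine loop bound maps are half-open, hence the `ubConst.getValue() + 1`. A standalone toy illustration of just that arithmetic (values hypothetical):

```cpp
#include <cstdio>
#include <optional>

int main() {
  // Stand-ins for getConstantLowerBound(pos) / getConstantUpperBound(pos),
  // which return inclusive bounds from the constraint system.
  std::optional<long> lbConst = 0, ubConst = 15;
  if (lbConst && ubConst) {
    // Half-open affine bounds: lb map () -> (0), ub map () -> (16),
    // i.e. a slice loop 'for %i = 0 to 16' with trip count 16.
    std::printf("lb = %ld, ub = %ld\n", *lbConst, *ubConst + 1);
  }
  return 0;
}
```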


@@ -346,12 +346,13 @@ static Instruction *getInstAtPosition(ArrayRef<unsigned> positions,
return nullptr;
}
// Computes memref dependence between 'srcAccess' and 'dstAccess' and uses the
// dependence constraint system to create AffineMaps with which to adjust the
// loop bounds of the inserted computation slice so that they are functions of the
// loop IVs and symbols of the loops surrounding 'dstAccess'.
// Computes memref dependence between 'srcAccess' and 'dstAccess', projects
// out any dst loop IVs at depth greater than 'dstLoopDepth', and computes slice
// bounds in 'sliceState' which represent the src IVs in terms of the dst IVs,
// symbols and constants.
bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
const MemRefAccess &dstAccess,
unsigned dstLoopDepth,
ComputationSliceState *sliceState) {
FlatAffineConstraints dependenceConstraints;
if (!checkMemrefAccessDependence(srcAccess, dstAccess, /*loopDepth=*/1,
@@ -364,6 +365,19 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
getLoopIVs(*srcAccess.opInst, &srcLoopIVs);
unsigned numSrcLoopIVs = srcLoopIVs.size();
// Get loop nest surrounding dst operation.
SmallVector<ForInst *, 4> dstLoopIVs;
getLoopIVs(*dstAccess.opInst, &dstLoopIVs);
unsigned numDstLoopIVs = dstLoopIVs.size();
if (dstLoopDepth > numDstLoopIVs) {
dstAccess.opInst->emitError("invalid destination loop depth");
return false;
}
// Project out dst loop IV dimensions at depths greater than 'dstLoopDepth'.
dependenceConstraints.projectOut(numSrcLoopIVs + dstLoopDepth,
numDstLoopIVs - dstLoopDepth);
// Set up lower/upper bound affine maps for the slice.
sliceState->lbs.resize(numSrcLoopIVs, AffineMap::Null());
sliceState->ubs.resize(numSrcLoopIVs, AffineMap::Null());
@@ -385,12 +399,10 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
return true;
}
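
To make the projection above concrete, here is the identifier arithmetic on a small hypothetical example (2 src IVs, 3 dst IVs, dstLoopDepth = 1); identifiers are laid out as [src IVs | dst IVs | symbols]:

```cpp
#include <cstdio>

int main() {
  unsigned numSrcLoopIVs = 2, numDstLoopIVs = 3, dstLoopDepth = 1;
  // First id to eliminate, and how many: the dst IVs past 'dstLoopDepth'.
  unsigned pos = numSrcLoopIVs + dstLoopDepth; // id 3 (second dst IV)
  unsigned num = numDstLoopIVs - dstLoopDepth; // ids 3 and 4
  // After projectOut(pos, num), slice bounds for the src IVs can only
  // reference the outermost dst IV, symbols, and constants.
  std::printf("projectOut(pos=%u, num=%u)\n", pos, num);
  return 0;
}
```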
/// Creates a computation slice of the loop nest surrounding 'srcAccess'
/// utilizing slice loop bounds in 'sliceState' (for src loops up to
/// 'srcLoopDepth'), and inserts this slice into loop nest surrounding
/// 'dstAccess' at loop depth 'dstLoopDepth'. For all loops at loop depth
/// greater than 'srcLoopDepth' their full loop bounds will be used in the
/// slice.
/// Creates a computation slice of the loop nest surrounding 'srcOpInst',
/// updates the slice loop bounds with any non-null bound maps specified in
/// 'sliceState', and inserts this slice into the loop nest surrounding
/// 'dstOpInst' at loop depth 'dstLoopDepth'.
// TODO(andydavis,bondhugula): extend the slicing utility to compute slices that
// aren't necessarily a one-to-one relation b/w the source and destination. The
// relation between the source and destination could be many-to-many in general.
@@ -401,33 +413,27 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
// solution.
// TODO(andydavis) Remove dependence on 'srcLoopDepth' here. Instead project
// out loop IVs we don't care about and produce smaller slice.
ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
MemRefAccess *dstAccess,
ComputationSliceState *sliceState,
unsigned srcLoopDepth,
unsigned dstLoopDepth) {
ForInst *mlir::insertBackwardComputationSlice(
OperationInst *srcOpInst, OperationInst *dstOpInst, unsigned dstLoopDepth,
ComputationSliceState *sliceState) {
// Get loop nest surrounding src operation.
SmallVector<ForInst *, 4> srcLoopIVs;
getLoopIVs(*srcAccess->opInst, &srcLoopIVs);
getLoopIVs(*srcOpInst, &srcLoopIVs);
unsigned numSrcLoopIVs = srcLoopIVs.size();
if (srcLoopDepth > numSrcLoopIVs) {
srcAccess->opInst->emitError("invalid source loop depth");
return nullptr;
}
// Get loop nest surrounding dst operation.
SmallVector<ForInst *, 4> dstLoopIVs;
getLoopIVs(*dstAccess->opInst, &dstLoopIVs);
getLoopIVs(*dstOpInst, &dstLoopIVs);
unsigned dstLoopIVsSize = dstLoopIVs.size();
if (dstLoopDepth > dstLoopIVsSize) {
dstAccess->opInst->emitError("invalid destination loop depth");
dstOpInst->emitError("invalid destination loop depth");
return nullptr;
}
// Find the inst block positions of 'srcAccess->opInst' within 'srcLoopIVs'.
// Find the inst block positions of 'srcOpInst' within 'srcLoopIVs'.
SmallVector<unsigned, 4> positions;
// TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d.
findInstPosition(srcAccess->opInst, srcLoopIVs[0]->getBlock(), &positions);
findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);
// Clone src loop nest and insert it at the beginning of the instruction block
// of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
@@ -451,7 +457,7 @@ ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
assert(sliceLoopLimit >= sliceSurroundingLoopsSize);
// Update loop bounds for loops in 'sliceLoopNest'.
for (unsigned i = 0; i < srcLoopDepth; ++i) {
for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
auto *forInst = sliceSurroundingLoops[dstLoopDepth + i];
if (AffineMap lbMap = sliceState->lbs[i])
forInst->setLowerBound(sliceState->lbOperands[i], lbMap);


@@ -69,17 +69,6 @@ char LoopFusion::passID = 0;
FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }
// FusionCandidate encapsulates source and destination memref access within
// loop nests which are candidates for loop fusion.
struct FusionCandidate {
// Load or store access within src loop nest to be fused into dst loop nest.
MemRefAccess srcAccess;
// Load or store access within dst loop nest.
MemRefAccess dstAccess;
explicit FusionCandidate(OperationInst *src, OperationInst *dst)
: srcAccess(MemRefAccess(src)), dstAccess(MemRefAccess(dst)) {}
};
namespace {
// LoopNestStateCollector walks loop nests and collects load and store
@@ -172,10 +161,27 @@ public:
return &it->second;
}
// Returns true iff there is an edge from node 'srcId' to node 'dstId' for
// 'memref'. Returns false otherwise.
bool hasEdge(unsigned srcId, unsigned dstId, Value *memref) {
if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
return false;
}
bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
return edge.id == dstId && edge.memref == memref;
});
bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
return edge.id == srcId && edge.memref == memref;
});
return hasOutEdge && hasInEdge;
}
// Adds an edge from node 'srcId' to node 'dstId' for 'memref'.
void addEdge(unsigned srcId, unsigned dstId, Value *memref) {
outEdges[srcId].push_back({dstId, memref});
inEdges[dstId].push_back({srcId, memref});
if (!hasEdge(srcId, dstId, memref)) {
outEdges[srcId].push_back({dstId, memref});
inEdges[dstId].push_back({srcId, memref});
}
}
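
The hasEdge() check is what makes addEdge() idempotent, which matters now that a dst nest may contain multiple loads from the same memref. A standalone STL analogue of the pair (types hypothetical; ints stand in for memrefs):

```cpp
#include <algorithm>
#include <cstdio>
#include <map>
#include <vector>

struct Edge { unsigned id; int memref; };
static std::map<unsigned, std::vector<Edge>> inEdges, outEdges;

static bool hasEdge(unsigned srcId, unsigned dstId, int memref) {
  auto contains = [&](std::vector<Edge> &edges, unsigned id) {
    return std::any_of(edges.begin(), edges.end(), [&](const Edge &e) {
      return e.id == id && e.memref == memref;
    });
  };
  return outEdges.count(srcId) && inEdges.count(dstId) &&
         contains(outEdges[srcId], dstId) && contains(inEdges[dstId], srcId);
}

static void addEdge(unsigned srcId, unsigned dstId, int memref) {
  if (!hasEdge(srcId, dstId, memref)) { // skip duplicate (src, dst, memref)
    outEdges[srcId].push_back({dstId, memref});
    inEdges[dstId].push_back({srcId, memref});
  }
}

int main() {
  addEdge(0, 1, /*memref=*/7);
  addEdge(0, 1, /*memref=*/7); // second load from same memref: no new edge
  std::printf("out edges from 0: %zu\n", outEdges[0].size()); // prints 1
  return 0;
}
```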
// Removes an edge from node 'srcId' to node 'dstId' for 'memref'.
@@ -425,10 +431,10 @@ public:
// inserting a sliced loop nest of known cost into the loop's body.
// NOTE: this is used to compute the cost of fusing a slice of some loop nest
// within another loop.
static uint64_t
getComputeCost(ForInst *forInst, LoopNestStats *stats,
DenseMap<ForInst *, uint64_t> *tripCountOverrideMap,
DenseMap<ForInst *, uint64_t> *computeCostMap) {
static uint64_t getComputeCost(
ForInst *forInst, LoopNestStats *stats,
llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountOverrideMap,
DenseMap<ForInst *, uint64_t> *computeCostMap) {
// 'opCount' is the total number of operations in one iteration of the 'forInst' body.
uint64_t opCount = stats->opCountMap[forInst];
if (stats->loopMap.count(forInst) > 0) {
@@ -458,17 +464,33 @@ getComputeCost(ForInst *forInst, LoopNestStats *stats,
} // end anonymous namespace
static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
assert(lbMap.getNumResults() == 1);
assert(ubMap.getNumResults() == 1);
assert(lbMap.getNumDims() == ubMap.getNumDims());
assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
// TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
// ub_expr - lb_expr
AffineExpr lbExpr(lbMap.getResult(0));
AffineExpr ubExpr(ubMap.getResult(0));
auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
lbMap.getNumSymbols());
auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
if (!cExpr)
return None;
return cExpr.getValue();
}
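
The helper above yields a constant trip count only when the IV terms of the two bound maps cancel. A toy analogue over single-result bounds of the form `coef * d0 + offset` (representation hypothetical):

```cpp
#include <cstdio>
#include <optional>

// lb = lbCoef * d0 + lbOff, ub = ubCoef * d0 + ubOff. Their difference is
// constant only if the d0 terms cancel, mirroring the dyn_cast above.
static std::optional<long> constDifference(long lbCoef, long lbOff,
                                           long ubCoef, long ubOff) {
  if (lbCoef != ubCoef)
    return std::nullopt; // residual IV term: non-constant loop span
  return ubOff - lbOff;
}

int main() {
  // lb: (d0) -> (d0), ub: (d0) -> (d0 + 16) gives trip count 16.
  if (auto tc = constDifference(1, 0, 1, 16))
    std::printf("trip count = %ld\n", *tc);
  return 0;
}
```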
// Builds a map 'tripCountMap' from ForInst to constant trip count for the loop
// nest surrounding 'srcOpInst', utilizing slice loop bounds in 'sliceState'.
// Returns true on success, false otherwise (if a non-constant trip count
// was encountered).
// TODO(andydavis) Make this work with non-unit step loops.
static bool
buildSliceTripCountMap(MemRefAccess *srcAccess,
ComputationSliceState *sliceState,
DenseMap<ForInst *, uint64_t> *tripCountMap) {
static bool buildSliceTripCountMap(
OperationInst *srcOpInst, ComputationSliceState *sliceState,
llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountMap) {
SmallVector<ForInst *, 4> srcLoopIVs;
getLoopIVs(*srcAccess->opInst, &srcLoopIVs);
getLoopIVs(*srcOpInst, &srcLoopIVs);
unsigned numSrcLoopIVs = srcLoopIVs.size();
// Populate map from ForInst -> trip count
for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
@@ -485,109 +507,166 @@ buildSliceTripCountMap(MemRefAccess *srcAccess,
}
return false;
}
// TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
// ub_expr - lb_expr
AffineExpr lbExpr(lbMap.getResult(0));
AffineExpr ubExpr(ubMap.getResult(0));
auto loopSpanExpr = simplifyAffineExpr(
ubExpr - lbExpr, std::max(lbMap.getNumDims(), ubMap.getNumDims()),
std::max(lbMap.getNumSymbols(), ubMap.getNumSymbols()));
auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
if (!cExpr)
Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
if (!tripCount.hasValue())
return false;
(*tripCountMap)[srcLoopIVs[i]] = cExpr.getValue();
(*tripCountMap)[srcLoopIVs[i]] = tripCount.getValue();
}
return true;
}
// Returns the maximum loop depth within the source loop nest at which a
// sliced loop bound is detected in 'sliceState'.
static unsigned getMaxSrcLoopDepth(unsigned srcLoopDepthLimit,
ComputationSliceState *sliceState) {
unsigned maxSrcPos = 0;
for (unsigned i = 0; i < srcLoopDepthLimit; ++i) {
if (sliceState->lbs[i] != AffineMap::Null() &&
sliceState->ubs[i] != AffineMap::Null()) {
maxSrcPos = std::max(maxSrcPos, i);
}
// Removes load operations from 'srcLoads' which operate on 'memref', and
// adds them to 'dstLoads'.
static void
moveLoadsAccessingMemrefTo(Value *memref,
SmallVectorImpl<OperationInst *> *srcLoads,
SmallVectorImpl<OperationInst *> *dstLoads) {
dstLoads->clear();
SmallVector<OperationInst *, 4> srcLoadsToKeep;
for (auto *load : *srcLoads) {
if (load->cast<LoadOp>()->getMemRef() == memref)
dstLoads->push_back(load);
else
srcLoadsToKeep.push_back(load);
}
return maxSrcPos + 1;
srcLoads->swap(srcLoadsToKeep);
}
// Returns the minimum loop depth within the destination loop nest at which the
// computation slice can be inserted (based on the destination loop IVs that
// the source slice actually depends on / is a function of).
static unsigned getMinDstLoopDepth(unsigned srcLoopDepth,
ComputationSliceState *sliceState) {
// Record in 'maxDstLoopDepth' the largest position (+1) of a dst loop nest
// IV, which is used in a sliced loop bound in the src loop nest.
unsigned maxDstLoopDepth = 0;
for (unsigned i = 0; i < srcLoopDepth; ++i) {
if (AffineMap lbMap = sliceState->lbs[i]) {
lbMap.walkExprs([&](AffineExpr expr) {
if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
maxDstLoopDepth =
std::max(maxDstLoopDepth, dimExpr.getPosition() + 1);
}
});
}
if (AffineMap ubMap = sliceState->ubs[i]) {
ubMap.walkExprs([&](AffineExpr expr) {
if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
maxDstLoopDepth =
std::max(maxDstLoopDepth, dimExpr.getPosition() + 1);
}
});
}
// Returns the innermost common loop depth for the set of operations in 'ops'.
static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
unsigned numOps = ops.size();
assert(numOps > 0);
std::vector<SmallVector<ForInst *, 4>> loops(numOps);
unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
for (unsigned i = 0; i < numOps; ++i) {
getLoopIVs(*ops[i], &loops[i]);
loopDepthLimit =
std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
}
return maxDstLoopDepth;
unsigned loopDepth = 0;
for (unsigned d = 0; d < loopDepthLimit; ++d) {
unsigned i;
for (i = 1; i < numOps; ++i) {
if (loops[i - 1][d] != loops[i][d]) {
break;
}
}
if (i != numOps)
break;
++loopDepth;
}
return loopDepth;
}
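
Equivalently, the depth returned above is the length of the shared prefix of the ops' surrounding loop nests, outermost first. A runnable standalone version over plain loop ids (data hypothetical):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <vector>

static unsigned commonLoopDepth(const std::vector<std::vector<int>> &loops) {
  assert(!loops.empty());
  size_t limit = loops[0].size();
  for (const auto &nest : loops)
    limit = std::min(limit, nest.size());
  unsigned depth = 0;
  for (size_t d = 0; d < limit; ++d) {
    for (size_t i = 1; i < loops.size(); ++i)
      if (loops[i - 1][d] != loops[i][d])
        return depth; // first level where the nests diverge
    ++depth;
  }
  return depth;
}

int main() {
  // Two loads nested as (i0, i1) and (i0, i2): shared prefix is (i0).
  std::printf("%u\n", commonLoopDepth({{0, 1}, {0, 2}})); // prints 1
  return 0;
}
```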
// Checks the profitability of fusion candidate 'candidate'. Returns true if it
// profitable to fuse the candidate loop nests. Returns false otherwise.
// Returns true if 'map' is a single-result constant, or a single-result
// dim expr whose corresponding loop IV in 'operands' has a constant lower
// bound of zero.
static bool hasZeroMinValue(AffineMap map, ArrayRef<Value *> operands) {
if (map.isSingleConstant() && map.getSingleConstantResult() == 0)
return true;
if (map.getNumResults() != 1 || !map.getResult(0).isa<AffineDimExpr>())
return false;
// Get operand position of single dim expr result.
unsigned pos = map.getResult(0).cast<AffineDimExpr>().getPosition();
// Check if loop IV at 'pos' has zero constant lower bound.
auto *operand = operands[pos];
assert(isa<ForInst>(operand));
auto *forInst = cast<ForInst>(operand);
return forInst->hasConstantLowerBound() &&
forInst->getConstantLowerBound() == 0;
}
// Computes the union of the slice bounds of 'sliceStateA' and 'sliceStateB',
// storing the result in 'sliceStateB'.
// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA'
// and 'sliceStateB' are aligned.
// Specifically, when taking the union of overlapping intervals, it assumes
// that both intervals start at zero. Support needs to be added to take into
// account interval start offset when computing the union.
// TODO(andydavis) Move this function to an analysis library.
static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
ComputationSliceState *sliceStateB) {
assert(sliceStateA.lbs.size() == sliceStateB->lbs.size());
assert(sliceStateA.ubs.size() == sliceStateB->ubs.size());
for (unsigned i = 0, e = sliceStateA.lbs.size(); i < e; ++i) {
AffineMap lbMapA = sliceStateA.lbs[i];
AffineMap ubMapA = sliceStateA.ubs[i];
if (lbMapA == AffineMap::Null()) {
assert(ubMapA == AffineMap::Null());
continue;
}
assert(ubMapA != AffineMap::Null());
// Validate that constant lower bounds are aligned at zero.
if (!hasZeroMinValue(lbMapA, sliceStateA.lbOperands[i]))
return false;
AffineMap lbMapB = sliceStateB->lbs[i];
AffineMap ubMapB = sliceStateB->ubs[i];
if (lbMapB == AffineMap::Null()) {
assert(ubMapB == AffineMap::Null());
// Union 'sliceStateB' does not have a bound for 'i' so copy from A.
sliceStateB->lbs[i] = lbMapA;
sliceStateB->ubs[i] = ubMapA;
continue;
}
// Validate that constant lower bounds are aligned at zero.
if (!hasZeroMinValue(lbMapB, sliceStateB->lbOperands[i]))
return false;
// Add bound with the largest trip count to union.
Optional<uint64_t> tripCountA = getConstDifference(lbMapA, ubMapA);
Optional<uint64_t> tripCountB = getConstDifference(lbMapB, ubMapB);
if (!tripCountA.hasValue() || !tripCountB.hasValue())
return false;
// TODO(andydavis) Change this code to take the min across all lower bounds
// and the max across all upper bounds for each dimension. This code can fail
// for cases where a unique min or max cannot be statically determined.
if (tripCountA.getValue() > tripCountB.getValue()) {
sliceStateB->lbs[i] = lbMapA;
sliceStateB->ubs[i] = ubMapA;
}
}
return true;
}
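
Under the zero-aligned-lower-bound assumption stated in the TODO above, the union degenerates to "keep the interval with the larger trip count, per dimension". A standalone sketch with half-open intervals [0, ub) (types hypothetical):

```cpp
#include <algorithm>
#include <cstdio>
#include <optional>

// Per-dimension slice bound as a zero-aligned half-open interval [0, ub);
// nullopt means the dimension has no sliced bound (full loop bounds).
using Bound = std::optional<long>;

static void unionInto(const Bound &a, Bound &b) {
  if (!a)
    return; // A has no bound at this dim: B unchanged
  if (!b) {
    b = a; // B has no bound: copy from A
    return;
  }
  b = std::max(*a, *b); // both start at 0: keep the larger trip count
}

int main() {
  Bound a = 16, b = 10;
  unionInto(a, b);
  std::printf("union ub = %ld\n", *b); // prints 16
  return 0;
}
```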
// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstOpInsts'.
// Returns true if it is profitable to fuse the candidate loop nests. Returns
// false otherwise.
// The profitability model executes the following steps:
// *) Computes the backward computation slice at 'candidate.srcAccess'. This
// computation slice of the loop nest surrounding 'candidate.srcAccess' is
// *) Computes the backward computation slice at 'srcOpInst'. This
// computation slice of the loop nest surrounding 'srcOpInst' is
// represented by modified src loop bounds in 'sliceState', which are
// functions of loop IVs in the loop nest surrounding 'candidate.dstAccess'.
// functions of loop IVs in the loop nest surrounding ops in 'dstOpInsts'.
// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
// loop nest is the total number of dynamic operation instances in the loop
// nest).
// *) Computes the cost of fusing a slice of the src loop nest into the dst
// loop nest at various values of src/dst loop depth, attempting to fuse
// the biggest computation slice (max src loop depth) at the maximal dst loop
// depth (closest to the load) to minimize reuse distance and opportunity for
// subsequent load/store forwarding.
// NOTE: 'srcLoopDepth' refers to the loop depth within the source loop nest
// at which we slice the loops bounds (all src loops below this depth will
// utilize full loop bounds).
// loop nest at various values of dst loop depth, attempting to fuse
// the largest computation slice at the maximal dst loop depth (closest to the
// load) to minimize reuse distance and potentially enable subsequent
// load/store forwarding.
// NOTE: If the dst loop nest includes multiple loads in 'dstOpInsts' for
// the same memref as is written by 'srcOpInst', then the union of slice
// loop bounds is used to compute the slice and associated slice cost.
// NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
// nest, at which the src computation slice is inserted/fused.
// NOTE: We attempt to maximize the source loop depth, but there are cases
// where a particular setting for 'dstLoopNest' might fuse an unsliced
// NOTE: We attempt to maximize the dst loop depth, but there are cases
// where a particular setting for 'dstLoopDepth' might fuse an unsliced
// loop (within the src computation slice) at a depth which results in
// excessive recomputation (see unit tests for examples).
// *) Compares the total cost of the unfused loop nests to the min cost fused
// loop nest computed in the previous step, and returns true if the latter
// is lower.
static bool isFusionProfitable(FusionCandidate *candidate,
static bool isFusionProfitable(OperationInst *srcOpInst,
ArrayRef<OperationInst *> dstOpInsts,
ComputationSliceState *sliceState,
unsigned *srcLoopDepth, unsigned *dstLoopDepth) {
// Compute backward computation slice state: src IV bounds w.r.t dst IVs, etc.
if (!mlir::getBackwardComputationSliceState(
candidate->srcAccess, candidate->dstAccess, sliceState)) {
return false;
}
// Build trip count map for src loops with sliced loop bounds in 'sliceState'.
DenseMap<ForInst *, uint64_t> sliceTripCountMap;
if (!buildSliceTripCountMap(&candidate->srcAccess, sliceState,
&sliceTripCountMap))
return false;
unsigned *dstLoopDepth) {
// Compute cost of sliced and unsliced src loop nest.
SmallVector<ForInst *, 4> srcLoopIVs;
getLoopIVs(*candidate->srcAccess.opInst, &srcLoopIVs);
getLoopIVs(*srcOpInst, &srcLoopIVs);
unsigned numSrcLoopIVs = srcLoopIVs.size();
// Walk src loop nest and collect stats.
@@ -600,8 +679,7 @@ static bool isFusionProfitable(FusionCandidate *candidate,
// Compute cost of dst loop nest.
SmallVector<ForInst *, 4> dstLoopIVs;
getLoopIVs(*candidate->dstAccess.opInst, &dstLoopIVs);
unsigned numDstLoopIVs = dstLoopIVs.size();
getLoopIVs(*dstOpInsts[0], &dstLoopIVs);
LoopNestStats dstLoopNestStats;
LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
@@ -610,51 +688,60 @@ static bool isFusionProfitable(FusionCandidate *candidate,
if (dstStatsCollector.hasLoopWithNonConstTripCount)
return false;
// Search for min cost values for 'srcLoopDepth' and 'dstLoopDepth'.
// This search is O(n^2) where 'n' is very small (e.g., six).
// TODO(andydavis) Consider a solution where we just iterate through
// dstLoopDepth possibilities and project out IVs we do not need (remove
// dependence on 'srcLoopDepth').
DenseMap<ForInst *, uint64_t> tripCountMap;
DenseMap<ForInst *, uint64_t> computeCostMap;
unsigned maxSrcLoopDepth = getMaxSrcLoopDepth(numSrcLoopIVs, sliceState);
// Compute the innermost common loop depth for ops in 'dstOpInsts'.
unsigned maxDstLoopDepth = getInnermostCommonLoopDepth(dstOpInsts);
if (maxDstLoopDepth == 0)
return false;
// Search for min cost value for 'dstLoopDepth'. At each value of
// 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
// bounds between 'srcOpInst' and each op in 'dstOpInsts' (taking the union
// of these bounds). Next, the union of the slice bounds is used to calculate
// the cost of the slice and the cost of the slice inserted into the dst
// loop nest at 'dstLoopDepth'.
unsigned minFusedLoopNestComputeCost = std::numeric_limits<unsigned>::max();
unsigned bestSrcLoopDepth;
unsigned bestDstLoopDepth;
for (unsigned i = maxSrcLoopDepth; i >= 1; --i) {
// Compute minDstLoopDepth based on dst loop IVs used in slice loop bounds.
unsigned minDstLoopDepth = getMinDstLoopDepth(i, sliceState);
assert(minDstLoopDepth <= numDstLoopIVs);
if (minDstLoopDepth == 0) {
// TODO(andydavis) Support inserting computation slices at top-level.
continue;
}
// Copy elements from slice trip count map up to src loop depth 'i'.
tripCountMap.clear();
for (unsigned k = 0; k < i; ++k) {
auto *forInst = srcLoopIVs[k];
auto it = sliceTripCountMap.find(forInst);
if (it != sliceTripCountMap.end()) {
tripCountMap[forInst] = it->second;
}
SmallVector<ComputationSliceState, 4> sliceStates;
sliceStates.resize(maxDstLoopDepth);
llvm::SmallDenseMap<ForInst *, uint64_t, 8> sliceTripCountMap;
DenseMap<ForInst *, uint64_t> computeCostMap;
for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
MemRefAccess srcAccess(srcOpInst);
// Handle the common case of one dst load without a copy.
if (!mlir::getBackwardComputationSliceState(
srcAccess, MemRefAccess(dstOpInsts[0]), i, &sliceStates[i - 1]))
return false;
// Compute the union of slice bound of all ops in 'dstOpInsts'.
for (int j = 1, e = dstOpInsts.size(); j < e; ++j) {
MemRefAccess dstAccess(dstOpInsts[j]);
ComputationSliceState tmpSliceState;
if (!mlir::getBackwardComputationSliceState(srcAccess, dstAccess, i,
&tmpSliceState))
return false;
// Compute the slice bound union of 'tmpSliceState' and 'sliceStates[i - 1]'.
getSliceBoundUnion(tmpSliceState, &sliceStates[i - 1]);
}
// Build trip count map for computation slice.
sliceTripCountMap.clear();
if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
&sliceTripCountMap))
return false;
// Compute op instance count for the src loop nest with iteration slicing.
uint64_t sliceComputeCost =
getComputeCost(srcLoopIVs[0], &srcLoopNestStats, &tripCountMap,
getComputeCost(srcLoopIVs[0], &srcLoopNestStats, &sliceTripCountMap,
/*computeCostMap=*/nullptr);
for (unsigned j = numDstLoopIVs; j >= minDstLoopDepth; --j) {
// Compute cost of fusion for these values of 'i' and 'j'.
computeCostMap.clear();
computeCostMap[dstLoopIVs[j - 1]] = sliceComputeCost;
uint64_t fusedLoopNestComputeCost =
getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
/*tripCountOverrideMap=*/nullptr, &computeCostMap);
if (fusedLoopNestComputeCost < minFusedLoopNestComputeCost) {
minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
bestSrcLoopDepth = i;
bestDstLoopDepth = j;
}
// Compute cost of fusion for this value of 'i' (the dst loop depth).
computeCostMap.clear();
computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost;
uint64_t fusedLoopNestComputeCost =
getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
/*tripCountOverrideMap=*/nullptr, &computeCostMap);
if (fusedLoopNestComputeCost < minFusedLoopNestComputeCost) {
minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
bestDstLoopDepth = i;
}
}
@@ -668,7 +755,6 @@ static bool isFusionProfitable(FusionCandidate *candidate,
/*computeCostMap=*/nullptr);
LLVM_DEBUG(llvm::dbgs() << "LoopFusion statistics "
<< " bestSrcLoopDepth: " << bestSrcLoopDepth
<< " bestDstLoopDepth: " << bestDstLoopDepth
<< " srcLoopNestCost: " << srcLoopNestCost
<< " dstLoopNestCost: " << dstLoopNestCost
@@ -680,25 +766,23 @@ static bool isFusionProfitable(FusionCandidate *candidate,
// for load/store forwarding in cost model.
if (minFusedLoopNestComputeCost > srcLoopNestCost + dstLoopNestCost)
return false;
// Set src/dstLoopDepth based on best values from search.
*srcLoopDepth = bestSrcLoopDepth;
// Update return parameter 'sliceState' with 'bestSliceState'.
ComputationSliceState *bestSliceState = &sliceStates[bestDstLoopDepth - 1];
sliceState->lbs = bestSliceState->lbs;
sliceState->ubs = bestSliceState->ubs;
sliceState->lbOperands = bestSliceState->lbOperands;
sliceState->ubOperands = bestSliceState->ubOperands;
// Set dstLoopDepth based on best values from search.
*dstLoopDepth = bestDstLoopDepth;
// Update 'sliceState' bounds based on computed 'srcLoopDepth':
// *) Canonicalize affine map now that 'srcLoopDepth' has been chosen.
// *) Replace slice bound maps at depth > 'srcLoopDepth' with AffineMap::Null().
// Canonicalize slice bound affine maps.
for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
if (i < bestSrcLoopDepth) {
if (sliceState->lbs[i] != AffineMap::Null()) {
canonicalizeMapAndOperands(&sliceState->lbs[i],
&sliceState->lbOperands[i]);
}
if (sliceState->ubs[i] != AffineMap::Null()) {
canonicalizeMapAndOperands(&sliceState->ubs[i],
&sliceState->ubOperands[i]);
}
} else {
sliceState->lbs[i] = AffineMap::Null();
sliceState->ubs[i] = AffineMap::Null();
if (sliceState->lbs[i] != AffineMap::Null()) {
canonicalizeMapAndOperands(&sliceState->lbs[i],
&sliceState->lbOperands[i]);
}
if (sliceState->ubs[i] != AffineMap::Null()) {
canonicalizeMapAndOperands(&sliceState->ubs[i],
&sliceState->ubOperands[i]);
}
}
return true;
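
For a concrete feel of the final profitability comparison, here is the arithmetic on hypothetical nest sizes, loosely shaped like the @should_fuse_at_depth1_and_reduce_slice_trip_count test later in this commit:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Unfused: src nest is 4 x 256 stores, dst nest is 4 x 16 loads.
  uint64_t srcLoopNestCost = 4 * 256;
  uint64_t dstLoopNestCost = 4 * 16;
  // Fused at dst depth 1: the sliced src loop's trip count drops to 16,
  // and the slice executes once per outer dst iteration (4 times).
  uint64_t sliceComputeCost = 16;
  uint64_t fusedLoopNestComputeCost = 4 * (sliceComputeCost + 16);
  std::printf("unfused = %llu, fused = %llu -> %s\n",
              (unsigned long long)(srcLoopNestCost + dstLoopNestCost),
              (unsigned long long)fusedLoopNestComputeCost,
              fusedLoopNestComputeCost <= srcLoopNestCost + dstLoopNestCost
                  ? "fuse"
                  : "do not fuse");
  return 0;
}
```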
@@ -767,12 +851,12 @@ public:
continue;
SmallVector<OperationInst *, 4> loads = dstNode->loads;
SmallVector<OperationInst *, 4> dstLoadOpInsts;
while (!loads.empty()) {
auto *dstLoadOpInst = loads.pop_back_val();
auto *memref = dstLoadOpInst->cast<LoadOp>()->getMemRef();
// Skip 'dstLoadOpInst' if multiple loads to 'memref' in 'dstNode'.
if (dstNode->getLoadOpCount(memref) != 1)
continue;
// Get memref of load on top of the stack.
auto *memref = loads.back()->cast<LoadOp>()->getMemRef();
// Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
// Skip if no input edges along which to fuse.
if (mdg->inEdges.count(dstId) == 0)
continue;
@@ -801,19 +885,15 @@ public:
continue;
// Get unique 'srcNode' store op.
auto *srcStoreOpInst = srcNode->stores.front();
// Build fusion candidate out of 'srcStoreOpInst' and 'dstLoadOpInst'.
FusionCandidate candidate(srcStoreOpInst, dstLoadOpInst);
// Check if fusion would be profitable.
unsigned srcLoopDepth;
unsigned dstLoopDepth;
mlir::ComputationSliceState sliceState;
if (!isFusionProfitable(&candidate, &sliceState, &srcLoopDepth,
if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, &sliceState,
&dstLoopDepth))
continue;
// Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
&candidate.srcAccess, &candidate.dstAccess, &sliceState,
srcLoopDepth, dstLoopDepth);
srcStoreOpInst, dstLoadOpInsts[0], dstLoopDepth, &sliceState);
if (sliceLoopNest != nullptr) {
// Remove edges between 'srcNode' and 'dstNode' and remove 'srcNode'
mdg->updateEdgesAndRemoveSrcNode(srcNode->id, dstNode->id);


@@ -95,13 +95,18 @@ func @should_fuse_loop_nests_with_shifts() {
}
}
// Fusing the src loop nest at dst loop depth 1 is less expensive than fusing
// at dst loop depth 2 because, at depth 1, we can reduce the trip count of
// the %i1 loop by one (the dst loop never reads the last element written by
// the src loop).
// CHECK: for %i0 = 0 to 10 {
// CHECK-NEXT: for %i1 = 0 to 10 {
// CHECK-NEXT: %1 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i0)
// CHECK-NEXT: %2 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i1)
// CHECK-NEXT: %3 = affine_apply [[MAP_SHIFT_BY_ONE]](%1, %2)
// CHECK-NEXT: store %cst, %0[%3#0, %3#1] : memref<10x10xf32>
// CHECK-NEXT: %4 = load %0[%i0, %i1] : memref<10x10xf32>
// CHECK-NEXT: %1 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i0)
// CHECK-NEXT: for %i1 = 0 to 9 {
// CHECK-NEXT: %2 = affine_apply [[MAP_SHIFT_BY_ONE]](%1, %i1)
// CHECK-NEXT: store %cst, %0[%2#0, %2#1] : memref<10x10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: for %i2 = 0 to 10 {
// CHECK-NEXT: %3 = load %0[%i0, %i2] : memref<10x10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
@@ -849,6 +854,7 @@ func @should_fuse_src_depth1_at_dst_depth2() {
}
// -----
// CHECK: #map0 = ()[s0] -> (s0)
// CHECK-LABEL: func @fusion_at_depth0_not_currently_supported
func @fusion_at_depth0_not_currently_supported() {
@@ -862,10 +868,9 @@ func @fusion_at_depth0_not_currently_supported() {
%1 = load %0[%c0] : memref<10xf32>
}
// CHECK: for %i0 = 0 to 10 {
// CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: for %i1 = 0 to 10 {
// CHECK-NEXT: %1 = load %0[%c0] : memref<10xf32>
// CHECK-NEXT: %1 = affine_apply #map0()[%c0]
// CHECK-NEXT: store %cst, %0[%1] : memref<10xf32>
// CHECK-NEXT: %2 = load %0[%c0] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return
return
@@ -977,3 +982,128 @@ func @should_fuse_deep_loop_nests() {
// CHECK-NEXT: return
return
}
// -----
// CHECK: #map0 = (d0) -> (d0)
// CHECK-LABEL: func @should_fuse_at_depth1_and_reduce_slice_trip_count
func @should_fuse_at_depth1_and_reduce_slice_trip_count() {
%a = alloc() : memref<4x256xf32>
%b = alloc() : memref<4x256xf32>
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
for %i0 = 0 to 4 {
for %i1 = 0 to 256 {
%v0 = load %b[%i0, %i1] : memref<4x256xf32>
}
for %i2 = 0 to 256 {
store %cf0, %a[%i0, %i2] : memref<4x256xf32>
}
}
for %d0 = 0 to 4 {
for %d1 = 0 to 16 {
%v1 = load %a[%d0, %d1] : memref<4x256xf32>
}
}
// The cost of fusing at depth 2 is greater than the cost of fusing at depth 1
// for two reasons:
// 1) Inserting the unsliceable src loop %i1 at the shallower depth 1 avoids
//    redundant computation.
// 2) Even inserted at depth 1, the sliceable src loop %i2 still has its trip
//    count reduced to 16 (from 256), reducing cost.
// CHECK: for %i0 = 0 to 4 {
// CHECK-NEXT: %2 = affine_apply #map0(%i0)
// CHECK-NEXT: for %i1 = 0 to 256 {
// CHECK-NEXT: %3 = load %1[%2, %i1] : memref<4x256xf32>
// CHECK-NEXT: }
// CHECK-NEXT: for %i2 = 0 to 16 {
// CHECK-NEXT: store %cst, %0[%2, %i2] : memref<4x256xf32>
// CHECK-NEXT: }
// CHECK-NEXT: for %i3 = 0 to 16 {
// CHECK-NEXT: %4 = load %0[%i0, %i3] : memref<4x256xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
return
}
// -----
// CHECK-LABEL: func @should_fuse_at_depth1_with_trip_count_20
func @should_fuse_at_depth1_with_trip_count_20() {
%a = alloc() : memref<100xf32>
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
for %i0 = 0 to 100 {
store %cf0, %a[%i0]: memref<100xf32>
}
for %i1 = 0 to 5 {
for %i2 = 0 to 10 {
%v0 = load %a[%i2]: memref<100xf32>
}
for %i3 = 0 to 10 {
for %i4 = 0 to 20 {
%v1 = load %a[%i4]: memref<100xf32>
}
}
}
// CHECK: for %i0 = 0 to 5 {
// CHECK-NEXT: for %i1 = 0 to 20 {
// CHECK-NEXT: store %cst, %0[%i1] : memref<100xf32>
// CHECK-NEXT: }
// CHECK-NEXT: for %i2 = 0 to 10 {
// CHECK-NEXT: %1 = load %0[%i2] : memref<100xf32>
// CHECK-NEXT: }
// CHECK-NEXT: for %i3 = 0 to 10 {
// CHECK-NEXT: for %i4 = 0 to 20 {
// CHECK-NEXT: %2 = load %0[%i4] : memref<100xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
return
}
// -----
// CHECK-LABEL: func @should_fuse_at_depth1_with_trip_count_19
func @should_fuse_at_depth1_with_trip_count_19() {
%a = alloc() : memref<100xf32>
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
for %i0 = 0 to 100 {
store %cf0, %a[%i0]: memref<100xf32>
}
for %i1 = 0 to 5 {
for %i2 = 0 to 19 {
%v0 = load %a[%i2]: memref<100xf32>
}
for %i3 = 0 to 10 {
for %i4 = 0 to 10 {
%v1 = load %a[%i4]: memref<100xf32>
}
}
}
// CHECK: for %i0 = 0 to 5 {
// CHECK-NEXT: for %i1 = 0 to 19 {
// CHECK-NEXT: store %cst, %0[%i1] : memref<100xf32>
// CHECK-NEXT: }
// CHECK-NEXT: for %i2 = 0 to 19 {
// CHECK-NEXT: %1 = load %0[%i2] : memref<100xf32>
// CHECK-NEXT: }
// CHECK-NEXT: for %i3 = 0 to 10 {
// CHECK-NEXT: for %i4 = 0 to 10 {
// CHECK-NEXT: %2 = load %0[%i4] : memref<100xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
return
}