Loop fusion for input reuse.

*) Breaks the fusion pass into multiple sub-passes over nodes in the data dependence graph:
- first pass fuses single-use producers into their unique consumer.
- second pass enables input reuse by fusing sibling nodes which read from the same memref but which share no dependence edges (a schematic before/after example follows this list).
- third pass fuses remaining producers into their consumers (note that the sibling fusion pass may have transformed a producer with multiple uses into a single-use producer).
*) Fusion for input reuse is enabled by computing a sibling node slice from the load/load accesses to the shared memref; fusion safety is guaranteed by checking that the sibling node's write region (to a different memref) is preserved.
*) Enables the output vector and output matrix computations from the KFAC patches-second-moment operation to fuse into a single loop nest and reuse input from the image patches operation.
*) Adds a generic loop utility for finding all sequential loops in a loop nest.
*) Adds and updates unit tests.
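
Schematically, the sibling (input-reuse) pass rewrites two loop nests that read the same memref but share no dependence edges into a single nest in which each input element is loaded once per iteration for both uses. The snippet below is a hypothetical before/after sketch in the same affine syntax as the tests added in this change; %in, %out0, and %out1 are assumed to be defined elsewhere (e.g. as function arguments). Note that the committed pass currently discovers siblings through a common producer node that writes the shared memref, so in the actual tests %in is itself produced by an earlier loop nest.

// Before: two sibling loop nests read the same input %in.
for %i0 = 0 to 100 {
  %v0 = load %in[%i0] : memref<100xf32>
  store %v0, %out0[%i0] : memref<100xf32>
}
for %i1 = 0 to 100 {
  %v1 = load %in[%i1] : memref<100xf32>
  store %v1, %out1[%i1] : memref<100xf32>
}

// After sibling fusion (schematic): one loop nest, and each element of %in
// is reused by both computations in the same iteration.
for %i0 = 0 to 100 {
  %v0 = load %in[%i0] : memref<100xf32>
  store %v0, %out0[%i0] : memref<100xf32>
  %v1 = load %in[%i0] : memref<100xf32>
  store %v1, %out1[%i0] : memref<100xf32>
}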

PiperOrigin-RevId: 236350987
Authored by MLIR Team on 2019-03-01 11:50:25 -08:00; committed by jpienaar
parent 269c872ee8
commit d038e34735
7 changed files with 651 additions and 59 deletions


@ -94,13 +94,16 @@ struct DependenceComponent {
/// the operation instruction, indices and memref associated with the access.
/// Returns 'false' if it can be determined conclusively that the accesses do
/// not access the same memref element. Returns 'true' otherwise.
/// If 'allowRAR' is true, will consider read-after-read dependences (typically
/// used by applications trying to optimize input reuse).
// TODO(andydavis) Wrap 'dependenceConstraints' and 'dependenceComponents' into
// a single struct.
// TODO(andydavis) Make 'dependenceConstraints' optional arg.
bool checkMemrefAccessDependence(
const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
llvm::SmallVector<DependenceComponent, 2> *dependenceComponents);
llvm::SmallVector<DependenceComponent, 2> *dependenceComponents,
bool allowRAR = false);
} // end namespace mlir
#endif // MLIR_ANALYSIS_AFFINE_ANALYSIS_H
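
As a minimal illustration of the new 'allowRAR' mode (hypothetical IR in the affine syntax used by the tests in this change; %A and %B are assumed to be defined elsewhere): the two loads of %A below form a read-after-read pair, so without 'allowRAR' the check returns false because neither access is a StoreOp; with allowRAR=true the analysis proceeds and reports that the two loads access the same element of %A in every iteration, which is what lets sibling fusion compute a slice between two loads.

for %i = 0 to 100 {
  // Read-after-read pair on %A: the same element is loaded twice per iteration.
  %u = load %A[%i] : memref<100xf32>
  %v = load %A[%i] : memref<100xf32>
  %w = addf %u, %v : f32
  store %w, %B[%i] : memref<100xf32>
}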


@ -62,6 +62,11 @@ void getLoopIVs(const Instruction &inst,
/// surrounding this instruction.
unsigned getNestingDepth(const Instruction &stmt);
/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
/// at 'forOp'.
void getSequentialLoops(OpPointer<AffineForOp> forOp,
llvm::SmallDenseSet<Value *, 8> *sequentialLoops);
/// ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their
/// associated operands for a set of loops within a loop nest (typically the
/// set of loops surrounding a store operation). Loop bound AffineMaps which


@ -768,7 +768,8 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
bool mlir::checkMemrefAccessDependence(
const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
llvm::SmallVector<DependenceComponent, 2> *dependenceComponents) {
llvm::SmallVector<DependenceComponent, 2> *dependenceComponents,
bool allowRAR) {
LLVM_DEBUG(llvm::dbgs() << "Checking for dependence at depth: "
<< Twine(loopDepth) << " between:\n";);
LLVM_DEBUG(srcAccess.opInst->dump(););
@ -778,7 +779,8 @@ bool mlir::checkMemrefAccessDependence(
if (srcAccess.memref != dstAccess.memref)
return false;
// Return 'false' if one of these accesses is not a StoreOp.
if (!srcAccess.opInst->isa<StoreOp>() && !dstAccess.opInst->isa<StoreOp>())
if (!allowRAR && !srcAccess.opInst->isa<StoreOp>() &&
!dstAccess.opInst->isa<StoreOp>())
return false;
// Get composed access function for 'srcAccess'.
@ -802,9 +804,11 @@ bool mlir::checkMemrefAccessDependence(
// Return 'false' if loopDepth > numCommonLoops and if the ancestor operation
// instruction of 'srcAccess' does not properly dominate the ancestor
// operation instruction of 'dstAccess' in the same common instruction block.
// Note: this check is skipped if 'allowRAR' is true, because RAR
// deps can exist irrespective of lexicographic ordering b/w src and dst.
unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
assert(loopDepth <= numCommonLoops + 1);
if (loopDepth > numCommonLoops &&
if (!allowRAR && loopDepth > numCommonLoops &&
!srcAppearsBeforeDstInAncestralBlock(srcAccess, dstAccess, srcDomain,
numCommonLoops)) {
return false;


@ -1669,19 +1669,20 @@ bool FlatAffineConstraints::addSliceBounds(ArrayRef<Value *> values,
};
for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
assert(lbMaps[i].getNumInputs() == operands.size());
assert(ubMaps[i].getNumInputs() == operands.size());
unsigned pos;
if (!findId(*values[i], &pos))
continue;
if (AffineMap lbMap = lbMaps[i])
if (AffineMap lbMap = lbMaps[i]) {
assert(lbMaps[i].getNumInputs() == operands.size());
if (!addLowerOrUpperBound(pos, lbMap, /*lower=*/true))
return false;
if (AffineMap ubMap = ubMaps[i])
}
if (AffineMap ubMap = ubMaps[i]) {
assert(ubMaps[i].getNumInputs() == operands.size());
if (!addLowerOrUpperBound(pos, ubMap, /*lower=*/false))
return false;
}
}
return true;
}


@ -173,7 +173,6 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth,
}
}
}
// We'll first associate the dims and symbols of the access map to the dims
// and symbols resp. of cst. This will change below once cst is
// fully constructed out.
@ -236,7 +235,7 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth,
}
// Set all identifiers appearing after the first 'rank' identifiers as
// symbolic identifiers - so that the ones correspoding to the memref
// symbolic identifiers - so that the ones corresponding to the memref
// dimensions are the dimensional identifiers for the memref region.
cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - rank);
@ -442,10 +441,12 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
const MemRefAccess &dstAccess,
unsigned dstLoopDepth,
ComputationSliceState *sliceState) {
bool readReadAccesses =
srcAccess.opInst->isa<LoadOp>() && dstAccess.opInst->isa<LoadOp>();
FlatAffineConstraints dependenceConstraints;
if (!checkMemrefAccessDependence(srcAccess, dstAccess, /*loopDepth=*/1,
&dependenceConstraints,
/*dependenceComponents=*/nullptr)) {
if (!checkMemrefAccessDependence(
srcAccess, dstAccess, /*loopDepth=*/1, &dependenceConstraints,
/*dependenceComponents=*/nullptr, /*allowRAR=*/readReadAccesses)) {
return false;
}
// Get loop nest surrounding src operation.
@ -487,6 +488,25 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
// canonicalization.
sliceState->lbOperands.resize(numSrcLoopIVs, sliceBoundOperands);
sliceState->ubOperands.resize(numSrcLoopIVs, sliceBoundOperands);
// For read-read access pairs, clear any slice bounds on sequential loops.
if (readReadAccesses) {
// Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
llvm::SmallDenseSet<Value *, 8> sequentialLoops;
getSequentialLoops(srcLoopIVs[0], &sequentialLoops);
// Clear all sliced loop bounds beginning at the first sequential loop.
for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
Value *iv = srcLoopIVs[i]->getInductionVar();
if (sequentialLoops.count(iv) == 0)
continue;
for (unsigned j = i; j < numSrcLoopIVs; ++j) {
sliceState->lbs[j] = AffineMap();
sliceState->ubs[j] = AffineMap();
}
break;
}
}
return true;
}
@ -675,3 +695,43 @@ mlir::getMemoryFootprintBytes(ConstOpPointer<AffineForOp> forOp,
*forInst->getBlock(), Block::const_iterator(forInst),
std::next(Block::const_iterator(forInst)), memorySpace);
}
/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
/// at 'forOp'.
void mlir::getSequentialLoops(
OpPointer<AffineForOp> forOp,
llvm::SmallDenseSet<Value *, 8> *sequentialLoops) {
// Collect all load and store ops in loop nest rooted at 'forOp'.
SmallVector<Instruction *, 4> loadAndStoreOpInsts;
forOp->getInstruction()->walk([&](Instruction *opInst) {
if (opInst->isa<LoadOp>() || opInst->isa<StoreOp>())
loadAndStoreOpInsts.push_back(opInst);
});
// Check dependences on all pairs of ops in 'loadAndStoreOpInsts' and record
// loops which carry dependences in 'sequentialLoops'.
for (unsigned i = 0, e = loadAndStoreOpInsts.size(); i < e; ++i) {
auto *srcOpInst = loadAndStoreOpInsts[i];
MemRefAccess srcAccess(srcOpInst);
SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
getLoopIVs(*srcOpInst, &srcLoopIVs);
for (auto *dstOpInst : loadAndStoreOpInsts) {
MemRefAccess dstAccess(dstOpInst);
unsigned numCommonLoops =
getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
for (unsigned d = 1; d <= numCommonLoops; ++d) {
auto *iv = srcLoopIVs[d - 1]->getInductionVar();
if (sequentialLoops->count(iv) > 0)
continue;
FlatAffineConstraints dependenceConstraints;
if (checkMemrefAccessDependence(srcAccess, dstAccess, d,
&dependenceConstraints,
/*dependenceComponents=*/nullptr)) {
// Record loop with carried dependence between srcAccess/dstAccess.
sequentialLoops->insert(iv);
}
}
}
}
}
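
As a concrete example of what getSequentialLoops reports (hypothetical IR in the same affine syntax as the tests; %A and %sum are assumed to be defined elsewhere): in the reduction nest below, loop %i carries no dependence, while every iteration of %j loads and then stores the same element %sum[%i], so only the induction variable of %j would be returned in 'sequentialLoops'.

for %i = 0 to 10 {
  // %j is sequential: it carries a load/store dependence on %sum[%i].
  for %j = 0 to 10 {
    %a = load %A[%i, %j] : memref<10x10xf32>
    %s = load %sum[%i] : memref<10xf32>
    %t = addf %s, %a : f32
    store %t, %sum[%i] : memref<10xf32>
  }
}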


@ -141,6 +141,7 @@ static bool isMemRefDereferencingOp(const Instruction &op) {
return true;
return false;
}
// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level instructions in a Function which contain load/store ops, and edges
// are memref dependences between the nodes.
@ -182,7 +183,7 @@ public:
return storeOpCount;
}
// Returns all store ups in 'storeOps' which access 'memref'.
// Returns all store ops in 'storeOps' which access 'memref'.
void getStoreOpsForMemref(Value *memref,
SmallVectorImpl<Instruction *> *storeOps) {
for (auto *storeOpInst : stores) {
@ -190,6 +191,29 @@ public:
storeOps->push_back(storeOpInst);
}
}
// Returns all load ops in 'loadOps' which access 'memref'.
void getLoadOpsForMemref(Value *memref,
SmallVectorImpl<Instruction *> *loadOps) {
for (auto *loadOpInst : loads) {
if (memref == loadOpInst->cast<LoadOp>()->getMemRef())
loadOps->push_back(loadOpInst);
}
}
// Returns all memrefs in 'loadAndStoreMemrefSet' for which this node
// has at least one load and store operation.
void getLoadAndStoreMemrefSet(DenseSet<Value *> *loadAndStoreMemrefSet) {
llvm::SmallDenseSet<Value *, 2> loadMemrefs;
for (auto *loadOpInst : loads) {
loadMemrefs.insert(loadOpInst->cast<LoadOp>()->getMemRef());
}
for (auto *storeOpInst : stores) {
auto *memref = storeOpInst->cast<StoreOp>()->getMemRef();
if (loadMemrefs.count(memref) > 0)
loadAndStoreMemrefSet->insert(memref);
}
}
};
// Edge represents a data dependence between nodes in the graph.
@ -300,17 +324,18 @@ public:
return true;
}
// Returns true iff there is an edge from node 'srcId' to node 'dstId' for
// 'value'. Returns false otherwise.
bool hasEdge(unsigned srcId, unsigned dstId, Value *value) {
// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
bool hasEdge(unsigned srcId, unsigned dstId, Value *value = nullptr) {
if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
return false;
}
bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
return edge.id == dstId && edge.value == value;
return edge.id == dstId && (!value || edge.value == value);
});
bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
return edge.id == srcId && edge.value == value;
return edge.id == srcId && (!value || edge.value == value);
});
return hasOutEdge && hasInEdge;
}
@ -349,8 +374,37 @@ public:
}
}
// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise.
bool hasDependencePath(unsigned srcId, unsigned dstId) {
// Worklist state is: <node-id, next-output-edge-index-to-visit>
SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
worklist.push_back({srcId, 0});
// Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
while (!worklist.empty()) {
auto &idAndIndex = worklist.back();
// Return true if we have reached 'dstId'.
if (idAndIndex.first == dstId)
return true;
// Pop and continue if node has no out edges, or if all out edges have
// already been visited.
if (outEdges.count(idAndIndex.first) == 0 ||
idAndIndex.second == outEdges[idAndIndex.first].size()) {
worklist.pop_back();
continue;
}
// Get graph edge to traverse.
Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
// Increment next output edge index for 'idAndIndex'.
++idAndIndex.second;
// Add node at 'edge.id' to worklist.
worklist.push_back({edge.id, 0});
}
return false;
}
// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref'.
// which access 'memref' with a store operation.
unsigned getIncomingMemRefAccesses(unsigned id, Value *memref) {
unsigned inEdgeCount = 0;
if (inEdges.count(id) > 0)
@ -358,19 +412,19 @@ public:
if (inEdge.value == memref) {
Node *srcNode = getNode(inEdge.id);
// Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
if (srcNode->getLoadOpCount(memref) > 0 ||
srcNode->getStoreOpCount(memref) > 0)
if (srcNode->getStoreOpCount(memref) > 0)
++inEdgeCount;
}
return inEdgeCount;
}
// Returns the output edge count for node 'id' and 'memref'.
unsigned getOutEdgeCount(unsigned id, Value *memref) {
// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
unsigned getOutEdgeCount(unsigned id, Value *memref = nullptr) {
unsigned outEdgeCount = 0;
if (outEdges.count(id) > 0)
for (auto &outEdge : outEdges[id])
if (outEdge.value == memref)
if (!memref || outEdge.value == memref)
++outEdgeCount;
return outEdgeCount;
}
@ -469,6 +523,32 @@ public:
}
}
// Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
// of sibling node 'sibId' into node 'dstId'.
void updateEdges(unsigned sibId, unsigned dstId) {
// For each edge in 'inEdges[sibId]':
// *) Add new edge from source node 'inEdge.id' to 'dstNode'.
// *) Remove edge from source node 'inEdge.id' to 'sibNode'.
if (inEdges.count(sibId) > 0) {
SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
for (auto &inEdge : oldInEdges) {
addEdge(inEdge.id, dstId, inEdge.value);
removeEdge(inEdge.id, sibId, inEdge.value);
}
}
// For each edge in 'outEdges[sibId]' to node 'id'
// *) Add new edge from 'dstId' to 'outEdge.id'.
// *) Remove edge from 'sibId' to 'outEdge.id'.
if (outEdges.count(sibId) > 0) {
SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
for (auto &outEdge : oldOutEdges) {
addEdge(dstId, outEdge.id, outEdge.value);
removeEdge(sibId, outEdge.id, outEdge.value);
}
}
}
// Adds ops in 'loads' and 'stores' to node at 'id'.
void addToNode(unsigned id, const SmallVectorImpl<Instruction *> &loads,
const SmallVectorImpl<Instruction *> &stores) {
@ -485,6 +565,37 @@ public:
node->stores.clear();
}
// Calls 'callback' for each input edge incident to node 'id' which carries a
// memref dependence.
void forEachMemRefInputEdge(unsigned id,
const std::function<void(Edge)> &callback) {
if (inEdges.count(id) > 0)
forEachMemRefEdge(inEdges[id], callback);
}
// Calls 'callback' for each output edge from node 'id' which carries a
// memref dependence.
void forEachMemRefOutputEdge(unsigned id,
const std::function<void(Edge)> &callback) {
if (outEdges.count(id) > 0)
forEachMemRefEdge(outEdges[id], callback);
}
// Calls 'callback' for each edge in 'edges' which carries a memref
// dependence.
void forEachMemRefEdge(ArrayRef<Edge> edges,
const std::function<void(Edge)> &callback) {
for (auto &edge : edges) {
// Skip if 'edge' is not a memref dependence edge.
if (!edge.value->getType().isa<MemRefType>())
continue;
assert(nodes.count(edge.id) > 0);
// Skip if 'edge.id' is not a loop nest.
if (!getNode(edge.id)->inst->isa<AffineForOp>())
continue;
// Visit current input edge 'edge'.
callback(edge);
}
}
void print(raw_ostream &os) const {
os << "\nMemRefDependenceGraph\n";
os << "\nNodes:\n";
@ -1228,6 +1339,14 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
// the memref being produced and consumed, which is an input to the cost model.
// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
// 'srcOpInst', as we are slicing w.r.t. that producer.
// For input-reuse fusion, 'srcOpInst' will be the src loop nest LoadOp which
// reads from the same memref as dst loop nest load ops, and 'srcStoreOpInst'
// will be the unique store op in the src node, which will be used to check
// that the write region is the same after input-reuse fusion.
// Returns true if it is profitable to fuse the candidate loop nests. Returns
// false otherwise. `dstLoopDepth` is set to the most profitable depth at which
// to materialize the source loop nest slice.
@ -1257,6 +1376,7 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
// loop nest computed in the previous step, and returns true if the latter
// is lower.
static bool isFusionProfitable(Instruction *srcOpInst,
Instruction *srcStoreOpInst,
ArrayRef<Instruction *> dstLoadOpInsts,
ArrayRef<Instruction *> dstStoreOpInsts,
ComputationSliceState *sliceState,
@ -1294,8 +1414,11 @@ static bool isFusionProfitable(Instruction *srcOpInst,
return false;
// Compute the maximum loop depth at which we can insert the src slice
// and still satisfy dest loop nest dependences.
unsigned maxDstLoopDepth = getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts);
// and still satisfy dest loop nest dependences, for producer-consumer fusion.
unsigned maxDstLoopDepth =
(srcOpInst == srcStoreOpInst)
? getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts)
: dstLoopIVs.size();
if (maxDstLoopDepth == 0)
return false;
@ -1306,7 +1429,7 @@ static bool isFusionProfitable(Instruction *srcOpInst,
// the cost of the slice and the cost of the slice inserted into the dst
// loop nest at 'dstLoopDepth'.
uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
uint64_t maxStorageReduction = 0;
double maxStorageReduction = 0.0;
Optional<uint64_t> sliceMemEstimate = None;
SmallVector<ComputationSliceState, 4> sliceStates;
@ -1321,8 +1444,8 @@ static bool isFusionProfitable(Instruction *srcOpInst,
/*computeCostMap=*/nullptr);
// Compute src loop nest write region size.
MemRefRegion srcWriteRegion(srcOpInst->getLoc());
srcWriteRegion.compute(srcOpInst, /*loopDepth=*/0);
MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0);
Optional<int64_t> maybeSrcWriteRegionSizeBytes =
srcWriteRegion.getRegionSize();
if (!maybeSrcWriteRegionSizeBytes.hasValue())
@ -1345,6 +1468,7 @@ static bool isFusionProfitable(Instruction *srcOpInst,
if (!mlir::getBackwardComputationSliceState(
srcAccess, MemRefAccess(dstLoadOpInsts[0]), i, &sliceStates[i - 1]))
return false;
// Compute the union of the slice bounds of all ops in 'dstLoadOpInsts'.
for (int j = 1, e = dstLoadOpInsts.size(); j < e; ++j) {
MemRefAccess dstAccess(dstLoadOpInsts[j]);
@ -1372,6 +1496,7 @@ static bool isFusionProfitable(Instruction *srcOpInst,
computeCostMap.clear();
// The store and loads to this memref will disappear.
// TODO(andydavis) Add load coalescing to memref data flow opt pass.
if (storeLoadFwdGuaranteed) {
// A single store disappears: -1 for that.
computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1;
@ -1403,8 +1528,9 @@ static bool isFusionProfitable(Instruction *srcOpInst,
// Compute what the slice write MemRefRegion would be, if the src loop
// nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop
// nest at loop depth 'i'
MemRefRegion sliceWriteRegion(srcOpInst->getLoc());
sliceWriteRegion.compute(srcOpInst, /*loopDepth=*/0, &sliceStates[i - 1]);
MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
&sliceStates[i - 1]);
Optional<int64_t> maybeSliceWriteRegionSizeBytes =
sliceWriteRegion.getRegionSize();
if (!maybeSliceWriteRegionSizeBytes.hasValue() ||
@ -1413,6 +1539,14 @@ static bool isFusionProfitable(Instruction *srcOpInst,
int64_t sliceWriteRegionSizeBytes =
maybeSliceWriteRegionSizeBytes.getValue();
// If we are fusing for reuse, check that write regions remain the same.
// TODO(andydavis) Write region check should check sizes and offsets in
// each dimension, so that we are sure they are covering the same memref
// region. Also, move this out to a isMemRefRegionSuperSet helper function.
if (srcOpInst != srcStoreOpInst &&
sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
continue;
double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
static_cast<double>(sliceWriteRegionSizeBytes);
@ -1547,12 +1681,10 @@ static bool isFusionProfitable(Instruction *srcOpInst,
return true;
}
// GreedyFusion greedily fuses loop nests which have a producer/consumer
// relationship on a memref, with the goal of improving locality. Currently,
// this the producer/consumer relationship is required to be unique in the
// Function (there are TODOs to relax this constraint in the future).
// GreedyFusion greedily fuses loop nests which have a producer/consumer or
// input-reuse relationship on a memref, with the goal of improving locality.
//
// The steps of the algorithm are as follows:
// The steps of the producer-consumer fusion algorithm are as follows:
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
@ -1560,20 +1692,32 @@ static bool isFusionProfitable(Instruction *srcOpInst,
// candidate destination AffineForOp into which fusion will be attempted.
// *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
// *) For each LoadOp in 'dstLoadOps' do:
// *) Lookup dependent loop nests at earlier positions in the Function
// which have a single store op to the same memref.
// *) Check if dependences would be violated by the fusion. For example,
// the src loop nest may load from memrefs which are different than
// the producer-consumer memref between src and dest loop nests.
// *) Lookup dependent loop nests which have a single store op to the same
// memref.
// *) Check if dependences would be violated by the fusion.
// *) Get a computation slice of 'srcLoopNest', which adjusts its loop
// bounds to be functions of 'dstLoopNest' IVs and symbols.
// *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
// just before the dst load op user.
// at a loop depth determined by the cost model in 'isFusionProfitable'.
// *) Add the newly fused load/store operation instructions to the state,
// and also add newly fused load ops to 'dstLoopOps' to be considered
// as fusion dst load ops in another iteration.
// *) Remove old src loop nest and its associated state.
//
// The steps of the input-reuse fusion algorithm are as follows:
//
// *) Initialize 'worklist' with node ids from the dependence graph.
// *) For each 'dstNode' in the worklist:
// *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
// loads from the same memref, but which has no dependence paths to/from.
// *) Get a computation slice of 'sibLoopNest', which adjusts its loop
// bounds to be functions of 'dstLoopNest' IVs and symbols.
// *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
// at a loop depth determined by the cost model in 'isFusionProfitable'.
// This function also checks that the memref write region of 'sibLoopNest'
// is preserved in the fused loop nest.
// *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
//
// Given a graph where top-level instructions are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
@ -1582,25 +1726,54 @@ static bool isFusionProfitable(Instruction *srcOpInst,
// fusing along single producer consumer edges, but there is a TODO to fix this.
//
// TODO(andydavis) Experiment with other fusion policies.
// TODO(andydavis) Add support for fusing for input reuse (perhaps by
// constructing a graph with edges which represent loads from the same memref
// in two different loop nests.
struct GreedyFusion {
public:
// The data dependence graph to traverse during fusion.
MemRefDependenceGraph *mdg;
// Worklist of graph nodes visited during the fusion pass.
SmallVector<unsigned, 8> worklist;
// Set of graph nodes which are present on the worklist.
llvm::SmallDenseSet<unsigned, 16> worklistSet;
// Parameter for local buffer size threshold.
unsigned localBufSizeThreshold;
// Parameter for fast memory space.
Optional<unsigned> fastMemorySpace;
GreedyFusion(MemRefDependenceGraph *mdg) : mdg(mdg) {
// Initialize worklist with nodes from 'mdg'.
using Node = MemRefDependenceGraph::Node;
GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
Optional<unsigned> fastMemorySpace)
: mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
fastMemorySpace(fastMemorySpace) {}
// Initializes 'worklist' with nodes from 'mdg'
void init() {
// TODO(andydavis) Add a priority queue for prioritizing nodes by different
// metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
worklist.resize(mdg->nodes.size());
std::iota(worklist.begin(), worklist.end(), 0);
worklistSet.insert(worklist.begin(), worklist.end());
worklist.clear();
worklistSet.clear();
for (auto &idAndNode : mdg->nodes) {
const Node &node = idAndNode.second;
worklist.push_back(node.id);
worklistSet.insert(node.id);
}
}
void run(unsigned localBufSizeThreshold, Optional<unsigned> fastMemorySpace) {
// Run the GreedyFusion pass.
// *) First pass through the nodes fuses single-use producer nodes into their
// unique consumer.
// *) Second pass fuses sibling nodes which share no dependence edges.
// *) Third pass fuses any remaining producer nodes into their users.
void run() {
fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
fuseSiblingNodes();
fuseProducerConsumerNodes(
/*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
eraseUnusedMemRefAllocations();
}
void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
init();
while (!worklist.empty()) {
unsigned dstId = worklist.back();
worklist.pop_back();
@ -1672,6 +1845,10 @@ public:
!canFuseSrcWhichWritesToLiveOut(srcId, dstId, memref, mdg))
continue;
// Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'.
if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount)
continue;
// Compute an instruction list insertion point for the fused loop
// nest which preserves dependences.
Instruction *insertPointInst =
@ -1690,8 +1867,8 @@ public:
unsigned bestDstLoopDepth;
mlir::ComputationSliceState sliceState;
// Check if fusion would be profitable.
if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts,
dstStoreOpInsts, &sliceState,
if (!isFusionProfitable(srcStoreOpInst, srcStoreOpInst,
dstLoadOpInsts, dstStoreOpInsts, &sliceState,
&bestDstLoopDepth))
continue;
@ -1782,7 +1959,202 @@ public:
}
}
}
// Clean up any allocs with no users.
}
// Visits each node in the graph, and for each node, attempts to fuse it with
// its sibling nodes (nodes which share a parent, but no dependence edges).
void fuseSiblingNodes() {
init();
while (!worklist.empty()) {
unsigned dstId = worklist.back();
worklist.pop_back();
worklistSet.erase(dstId);
// Skip if this node was removed (fused into another node).
if (mdg->nodes.count(dstId) == 0)
continue;
// Get 'dstNode' into which to attempt fusion.
auto *dstNode = mdg->getNode(dstId);
// Skip if 'dstNode' is not a loop nest.
if (!dstNode->inst->isa<AffineForOp>())
continue;
// Attempt to fuse 'dstNode' with its sibling nodes in the graph.
fuseWithSiblingNodes(dstNode);
}
}
// Attempt to fuse 'dstNode' with sibling nodes in the graph.
void fuseWithSiblingNodes(Node *dstNode) {
DenseSet<unsigned> visitedSibNodeIds;
std::pair<unsigned, Value *> idAndMemref;
while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
unsigned sibId = idAndMemref.first;
Value *memref = idAndMemref.second;
// TODO(andydavis) Check that 'sibStoreOpInst' post-dominates all other
// stores to the same memref in 'sibNode' loop nest.
auto *sibNode = mdg->getNode(sibId);
// Compute an instruction list insertion point for the fused loop
// nest which preserves dependences.
assert(sibNode->inst->getBlock() == dstNode->inst->getBlock());
Instruction *insertPointInst =
sibNode->inst->isBeforeInBlock(dstNode->inst)
? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
: mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
if (insertPointInst == nullptr)
continue;
// Check if fusion would be profitable and at what depth.
// Get unique 'sibNode' load op to 'memref'.
SmallVector<Instruction *, 2> sibLoadOpInsts;
sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
// Currently findSiblingNodeToFuse searches for siblings with one load.
assert(sibLoadOpInsts.size() == 1);
Instruction *sibLoadOpInst = sibLoadOpInsts[0];
assert(!sibNode->stores.empty());
// TODO(andydavis) Choose the store which postdominates all other stores.
auto *sibStoreOpInst = sibNode->stores.back();
// Gather 'dstNode' load ops to 'memref'.
SmallVector<Instruction *, 2> dstLoadOpInsts;
dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
// Gather 'dstNode' store ops to 'memref'.
SmallVector<Instruction *, 2> dstStoreOpInsts;
dstNode->getStoreOpsForMemref(memref, &dstStoreOpInsts);
unsigned bestDstLoopDepth;
mlir::ComputationSliceState sliceState;
// Check if fusion would be profitable.
if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts,
dstStoreOpInsts, &sliceState, &bestDstLoopDepth))
continue;
// Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
auto sliceLoopNest = mlir::insertBackwardComputationSlice(
sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
if (sliceLoopNest != nullptr) {
auto dstForInst = dstNode->inst->cast<AffineForOp>();
// Update instruction position of fused loop nest (if needed).
if (insertPointInst != dstForInst->getInstruction()) {
dstForInst->getInstruction()->moveBefore(insertPointInst);
}
// Update data dependence graph state post fusion.
updateStateAfterSiblingFusion(sliceLoopNest, sibNode, dstNode);
}
}
}
// Searches the graph from 'dstNode' looking for a fusion candidate sibling
// node which shares no dependences with 'dstNode' but which loads from the
// same memref. Returns true and sets 'idAndMemrefToFuse' on success. Returns
// false otherwise.
bool findSiblingNodeToFuse(Node *dstNode,
DenseSet<unsigned> *visitedSibNodeIds,
std::pair<unsigned, Value *> *idAndMemrefToFuse) {
// TODO(andydavis) Currently we discover siblings by following edges
// through an intermediate src node. We should also consider siblings
// which load from the same memref, but which do not necessarily share
// a src node parent (e.g. loading from a memref which is a function arg).
// Collect candidate 'dstNode' input edges in 'inEdges'.
SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
mdg->forEachMemRefInputEdge(
dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
// Add 'inEdge' if it is a read-after-write dependence.
if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
inEdges.push_back(inEdge);
});
// Search for sibling nodes to fuse by visiting output edges from each input
// edge in 'inEdges'.
for (auto &inEdge : inEdges) {
// Collect candidate output edges from each node 'inEdge.id' in 'inEdges'.
SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
mdg->forEachMemRefOutputEdge(
inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
unsigned sibNodeId = outEdge.id;
if (visitedSibNodeIds->count(sibNodeId) > 0)
return;
// Skip output edge if not a sibling using the same memref.
if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
return;
auto *sibNode = mdg->getNode(sibNodeId);
if (!sibNode->inst->isa<AffineForOp>())
return;
// Skip if 'outEdge' is not a read-after-write dependence.
// TODO(andydavis) Remove the restriction to a single load op.
if (sibNode->getLoadOpCount(inEdge.value) != 1)
return;
// Skip if there exists a path of dependent edges between
// 'sibNode' and 'dstNode'.
if (mdg->hasDependencePath(sibNodeId, dstNode->id) ||
mdg->hasDependencePath(dstNode->id, sibNodeId))
return;
// Skip sib node if it loads from (and stores to) the same memref on
// which it also has an input dependence edge.
DenseSet<Value *> loadAndStoreMemrefSet;
sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) {
return mdg->getIncomingMemRefAccesses(sibNode->id, memref) >
0;
}))
return;
// Check that all stores are to the same memref.
DenseSet<Value *> storeMemrefs;
for (auto *storeOpInst : sibNode->stores) {
storeMemrefs.insert(storeOpInst->cast<StoreOp>()->getMemRef());
}
if (storeMemrefs.size() != 1)
return;
// Add candidate 'outEdge' to sibling node.
outEdges.push_back(outEdge);
});
// Add first candidate if any were returned.
if (!outEdges.empty()) {
visitedSibNodeIds->insert(outEdges[0].id);
idAndMemrefToFuse->first = outEdges[0].id;
idAndMemrefToFuse->second = outEdges[0].value;
return true;
}
}
return false;
}
void updateStateAfterSiblingFusion(OpPointer<AffineForOp> sliceLoopNest,
Node *sibNode, Node *dstNode) {
// Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
mdg->updateEdges(sibNode->id, dstNode->id);
// Collect slice loop stats.
LoopNestStateCollector sliceCollector;
sliceCollector.collect(sliceLoopNest->getInstruction());
// Promote single iteration slice loops to single IV value.
for (auto forOp : sliceCollector.forOps) {
promoteIfSingleIteration(forOp);
}
// Collect dst loop stats after memref privatization transformation.
auto dstForInst = dstNode->inst->cast<AffineForOp>();
LoopNestStateCollector dstLoopCollector;
dstLoopCollector.collect(dstForInst->getInstruction());
// Clear and add back loads and stores
mdg->clearNodeLoadAndStores(dstNode->id);
mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
dstLoopCollector.storeOpInsts);
// Remove old sibling loop nest if it no longer has outgoing dependence
// edges, and it does not write to a memref which escapes the
// function.
if (mdg->getOutEdgeCount(sibNode->id) == 0) {
mdg->removeNode(sibNode->id);
sibNode->inst->cast<AffineForOp>()->erase();
}
}
// Clean up any allocs with no users.
void eraseUnusedMemRefAllocations() {
for (auto &pair : mdg->memrefEdgeCount) {
if (pair.second > 0)
continue;
@ -1813,7 +2185,7 @@ void LoopFusion::runOnFunction() {
MemRefDependenceGraph g;
if (g.init(&getFunction()))
GreedyFusion(&g).run(localBufSizeThreshold, fastMemorySpace);
GreedyFusion(&g, localBufSizeThreshold, fastMemorySpace).run();
}
static PassRegistration<LoopFusion> pass("loop-fusion", "Fuse loop nests");


@ -1228,13 +1228,13 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() {
// by loops %i1 and %i2.
// CHECK-DAG: %0 = alloc() : memref<1xf32>
// CHECK-DAG: %1 = alloc() : memref<1xf32>
// CHECK: for %i0 = 0 to 82 {
// CHECK: for %i0 = 0 to 17 {
// CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0)
// CHECK-NEXT: store %cst, %1[%2] : memref<1xf32>
// CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0)
// CHECK-NEXT: %4 = load %1[%3] : memref<1xf32>
// CHECK-NEXT: }
// CHECK-NEXT: for %i1 = 0 to 17 {
// CHECK-NEXT: for %i1 = 0 to 82 {
// CHECK-NEXT: %5 = affine.apply [[MAP0]](%i1, %i1)
// CHECK-NEXT: store %cst, %0[%5] : memref<1xf32>
// CHECK-NEXT: %6 = affine.apply [[MAP0]](%i1, %i1)
@ -1915,3 +1915,150 @@ func @test_add_slice_bounds() {
// CHECK-NEXT: }
return
}
// -----
// CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d0 + d2)
// CHECK-DAG: [[MAP1:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d1 + d3)
func @should_fuse_init_loops_siblings_then_shared_producer(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>) {
%0 = alloc() : memref<10x10xf32>
%cst = constant 0.000000e+00 : f32
%cst_0 = constant 1.000000e+00 : f32
%cst_1 = constant 7.000000e+00 : f32
for %i0 = 0 to 10 {
for %i1 = 0 to 10 {
store %cst_1, %0[%i0, %i1] : memref<10x10xf32>
}
}
for %i2 = 0 to 3 {
for %i3 = 0 to 3 {
store %cst, %arg0[%i2, %i3] : memref<10x10xf32>
}
}
for %i4 = 0 to 3 {
for %i5 = 0 to 3 {
%1 = load %0[%i4, %i5] : memref<10x10xf32>
%2 = load %arg0[%i4, %i5] : memref<10x10xf32>
%3 = mulf %1, %2 : f32
store %3, %arg0[%i4, %i5] : memref<10x10xf32>
}
}
for %i6 = 0 to 3 {
for %i7 = 0 to 3 {
store %cst_0, %arg1[%i6, %i7] : memref<10x10xf32>
}
}
for %i8 = 0 to 3 {
for %i9 = 0 to 3 {
%4 = load %0[%i8, %i9] : memref<10x10xf32>
%5 = load %arg1[%i8, %i9] : memref<10x10xf32>
%6 = addf %4, %5 : f32
store %6, %arg1[%i8, %i9] : memref<10x10xf32>
}
}
// Pass 1: should fuse single-use producer loop nests into their unique user,
// so '%i2' will fuse into '%i4' and '%i6' will fuse into '%i8'.
// Pass 2: should fuse sibling loop nests which share no dependence edges,
// so should fuse '%i4' into '%i8'.
// Pass 3: should fuse single-use producer loop nest '%i0' into '%i8'. Note
// that loop nest '%i0' now has a single user after Pass 2 fused its
// two users together.
// CHECK: for %i0 = 0 to 3 {
// CHECK-NEXT: for %i1 = 0 to 3 {
// CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1)
// CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1)
// CHECK-NEXT: store %cst_1, %0[%1, %2] : memref<1x1xf32>
// CHECK-NEXT: store %cst, %arg0[%i0, %i1] : memref<10x10xf32>
// CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1)
// CHECK-NEXT: %4 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1)
// CHECK-NEXT: %5 = load %0[%3, %4] : memref<1x1xf32>
// CHECK-NEXT: %6 = load %arg0[%i0, %i1] : memref<10x10xf32>
// CHECK-NEXT: %7 = mulf %5, %6 : f32
// CHECK-NEXT: store %7, %arg0[%i0, %i1] : memref<10x10xf32>
// CHECK-NEXT: store %cst_0, %arg1[%i0, %i1] : memref<10x10xf32>
// CHECK-NEXT: %8 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1)
// CHECK-NEXT: %9 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1)
// CHECK-NEXT: %10 = load %0[%8, %9] : memref<1x1xf32>
// CHECK-NEXT: %11 = load %arg1[%i0, %i1] : memref<10x10xf32>
// CHECK-NEXT: %12 = addf %10, %11 : f32
// CHECK-NEXT: store %12, %arg1[%i0, %i1] : memref<10x10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
return
}
// -----
// CHECK-DAG: [[MAP2:#map[0-9]+]] = (d0, d1, d2) -> (d1)
// CHECK-DAG: [[MAP3:#map[0-9]+]] = (d0, d1, d2) -> (-d0 + d2)
func @two_matrix_vector_products() {
%in_matrix = alloc() : memref<10x10xf32>
%in_vec0 = alloc() : memref<10xf32>
%in_vec1 = alloc() : memref<10xf32>
%out_vec0 = alloc() : memref<10xf32>
%out_vec1 = alloc() : memref<10xf32>
%cf7 = constant 7.0 : f32
// Populate input matrix.
for %i0 = 0 to 10 {
for %i1 = 0 to 10 {
store %cf7, %in_matrix[%i0, %i1] : memref<10x10xf32>
}
}
// out_vec0 = in_matrix x in_vec0
for %i2 = 0 to 10 {
for %i3 = 0 to 10 {
%v0 = load %in_matrix[%i2, %i3] : memref<10x10xf32>
%v1 = load %in_vec0[%i3] : memref<10xf32>
%v2 = mulf %v0, %v1 : f32
%v3 = load %out_vec0[%i3] : memref<10xf32>
%v4 = addf %v2, %v3 : f32
store %v4, %out_vec0[%i3] : memref<10xf32>
}
}
// out_vec1 = in_matrix x in_vec1
for %i4 = 0 to 10 {
for %i5 = 0 to 10 {
%v5 = load %in_matrix[%i4, %i5] : memref<10x10xf32>
%v6 = load %in_vec1[%i5] : memref<10xf32>
%v7 = mulf %v5, %v6 : f32
%v8 = load %out_vec1[%i5] : memref<10xf32>
%v9 = addf %v7, %v8 : f32
store %v9, %out_vec1[%i5] : memref<10xf32>
}
}
// CHECK: for %i0 = 0 to 10 {
// CHECK-NEXT: for %i1 = 0 to 10 {
// CHECK-NEXT: %5 = affine.apply [[MAP2]](%i0, %i1, %i0)
// CHECK-NEXT: %6 = affine.apply [[MAP3]](%i0, %i1, %i0)
// CHECK-NEXT: store %cst, %0[%5, %6] : memref<10x1xf32>
// CHECK-NEXT: }
// CHECK-NEXT: for %i2 = 0 to 10 {
// CHECK-NEXT: %7 = affine.apply [[MAP2]](%i0, %i2, %i0)
// CHECK-NEXT: %8 = affine.apply [[MAP3]](%i0, %i2, %i0)
// CHECK-NEXT: %9 = load %0[%7, %8] : memref<10x1xf32>
// CHECK-NEXT: %10 = load %1[%i0] : memref<10xf32>
// CHECK-NEXT: %11 = mulf %9, %10 : f32
// CHECK-NEXT: %12 = load %3[%i0] : memref<10xf32>
// CHECK-NEXT: %13 = addf %11, %12 : f32
// CHECK-NEXT: store %13, %3[%i0] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: for %i3 = 0 to 10 {
// CHECK-NEXT: %14 = affine.apply [[MAP2]](%i0, %i3, %i0)
// CHECK-NEXT: %15 = affine.apply [[MAP3]](%i0, %i3, %i0)
// CHECK-NEXT: %16 = load %0[%14, %15] : memref<10x1xf32>
// CHECK-NEXT: %17 = load %2[%i0] : memref<10xf32>
// CHECK-NEXT: %18 = mulf %16, %17 : f32
// CHECK-NEXT: %19 = load %4[%i0] : memref<10xf32>
// CHECK-NEXT: %20 = addf %18, %19 : f32
// CHECK-NEXT: store %20, %4[%i0] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
return
}