Loop fusion for input reuse.
*) Breaks the fusion pass into multiple sub-passes over nodes in the data dependence graph:
   - The first pass fuses single-use producers into their unique consumer.
   - The second pass enables fusion for input reuse by fusing sibling nodes which read from the same memref but which do not share dependence edges.
   - The third pass fuses remaining producers into their consumers (note that the sibling fusion pass may have transformed a producer with multiple uses into a single-use producer).
*) Fusion for input reuse is enabled by computing a sibling node slice using the load/load accesses to the same memref; fusion safety is guaranteed by checking that the sibling node's memref write region (to a different memref) is preserved.
*) Enables the output vector and output matrix computations from the KFAC patches-second-moment operation to fuse into a single loop nest and reuse input from the image patches operation.
*) Adds a generic loop utility for finding all sequential loops in a loop nest.
*) Adds and updates unit tests.

PiperOrigin-RevId: 236350987
parent 269c872ee8
commit d038e34735
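For context, here is a minimal usage sketch (not part of the commit) of the extended dependence check introduced by this change. It mirrors the call site added in getBackwardComputationSliceState further down in this diff; 'srcLoadOpInst' and 'dstLoadOpInst' are placeholder names for two load instructions.

// Sketch: query a read-after-read dependence between two loads,
// using the new 'allowRAR' parameter.
MemRefAccess srcAccess(srcLoadOpInst);
MemRefAccess dstAccess(dstLoadOpInst);
FlatAffineConstraints dependenceConstraints;
bool hasRARDependence = checkMemrefAccessDependence(
    srcAccess, dstAccess, /*loopDepth=*/1, &dependenceConstraints,
    /*dependenceComponents=*/nullptr, /*allowRAR=*/true);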
@ -94,13 +94,16 @@ struct DependenceComponent {
|
|||
/// the operation instruction, indices and memref associated with the access.
|
||||
/// Returns 'false' if it can be determined conclusively that the accesses do
|
||||
/// not access the same memref element. Returns 'true' otherwise.
|
||||
/// If 'allowRAR' is true, will consider read-after-read dependences (typically
|
||||
/// used by applications trying to optimize input reuse).
|
||||
// TODO(andydavis) Wrap 'dependenceConstraints' and 'dependenceComponents' into
|
||||
// a single struct.
|
||||
// TODO(andydavis) Make 'dependenceConstraints' optional arg.
|
||||
bool checkMemrefAccessDependence(
|
||||
const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
|
||||
unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
|
||||
llvm::SmallVector<DependenceComponent, 2> *dependenceComponents);
|
||||
llvm::SmallVector<DependenceComponent, 2> *dependenceComponents,
|
||||
bool allowRAR = false);
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_ANALYSIS_AFFINE_ANALYSIS_H
|
||||
|
|
|
@ -62,6 +62,11 @@ void getLoopIVs(const Instruction &inst,
|
|||
/// surrounding this instruction.
|
||||
unsigned getNestingDepth(const Instruction &stmt);
|
||||
|
||||
/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
|
||||
/// at 'forOp'.
|
||||
void getSequentialLoops(OpPointer<AffineForOp> forOp,
|
||||
llvm::SmallDenseSet<Value *, 8> *sequentialLoops);
|
||||
|
||||
/// ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their
|
||||
/// associated operands for a set of loops within a loop nest (typically the
|
||||
/// set of loops surrounding a store operation). Loop bound AffineMaps which
|
||||
|
|
|
@ -768,7 +768,8 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
|
|||
bool mlir::checkMemrefAccessDependence(
|
||||
const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
|
||||
unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
|
||||
llvm::SmallVector<DependenceComponent, 2> *dependenceComponents) {
|
||||
llvm::SmallVector<DependenceComponent, 2> *dependenceComponents,
|
||||
bool allowRAR) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Checking for dependence at depth: "
|
||||
<< Twine(loopDepth) << " between:\n";);
|
||||
LLVM_DEBUG(srcAccess.opInst->dump(););
|
||||
|
@ -778,7 +779,8 @@ bool mlir::checkMemrefAccessDependence(
|
|||
if (srcAccess.memref != dstAccess.memref)
|
||||
return false;
|
||||
// Return 'false' if one of these accesses is not a StoreOp.
|
||||
if (!srcAccess.opInst->isa<StoreOp>() && !dstAccess.opInst->isa<StoreOp>())
|
||||
if (!allowRAR && !srcAccess.opInst->isa<StoreOp>() &&
|
||||
!dstAccess.opInst->isa<StoreOp>())
|
||||
return false;
|
||||
|
||||
// Get composed access function for 'srcAccess'.
|
||||
|
@ -802,9 +804,11 @@ bool mlir::checkMemrefAccessDependence(
|
|||
// Return 'false' if loopDepth > numCommonLoops and if the ancestor operation
|
||||
// instruction of 'srcAccess' does not properly dominate the ancestor
|
||||
// operation instruction of 'dstAccess' in the same common instruction block.
|
||||
// Note: this check is skipped if 'allowRAR' is true, because RAR
|
||||
// deps can exist irrespective of lexicographic ordering b/w src and dst.
|
||||
unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
|
||||
assert(loopDepth <= numCommonLoops + 1);
|
||||
if (loopDepth > numCommonLoops &&
|
||||
if (!allowRAR && loopDepth > numCommonLoops &&
|
||||
!srcAppearsBeforeDstInAncestralBlock(srcAccess, dstAccess, srcDomain,
|
||||
numCommonLoops)) {
|
||||
return false;
|
||||
|
|
|
@ -1669,19 +1669,20 @@ bool FlatAffineConstraints::addSliceBounds(ArrayRef<Value *> values,
|
|||
};
|
||||
|
||||
for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
|
||||
assert(lbMaps[i].getNumInputs() == operands.size());
|
||||
assert(ubMaps[i].getNumInputs() == operands.size());
|
||||
unsigned pos;
|
||||
if (!findId(*values[i], &pos))
|
||||
continue;
|
||||
|
||||
if (AffineMap lbMap = lbMaps[i])
|
||||
if (AffineMap lbMap = lbMaps[i]) {
|
||||
assert(lbMaps[i].getNumInputs() == operands.size());
|
||||
if (!addLowerOrUpperBound(pos, lbMap, /*lower=*/true))
|
||||
return false;
|
||||
|
||||
if (AffineMap ubMap = ubMaps[i])
|
||||
}
|
||||
if (AffineMap ubMap = ubMaps[i]) {
|
||||
assert(ubMaps[i].getNumInputs() == operands.size());
|
||||
if (!addLowerOrUpperBound(pos, ubMap, /*lower=*/false))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -173,7 +173,6 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth,
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We'll first associate the dims and symbols of the access map to the dims
|
||||
// and symbols resp. of cst. This will change below once cst is
|
||||
// fully constructed out.
|
||||
|
@ -236,7 +235,7 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth,
|
|||
}
|
||||
|
||||
// Set all identifiers appearing after the first 'rank' identifiers as
|
||||
// symbolic identifiers - so that the ones correspoding to the memref
|
||||
// symbolic identifiers - so that the ones corresponding to the memref
|
||||
// dimensions are the dimensional identifiers for the memref region.
|
||||
cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - rank);
|
||||
|
||||
|
@ -442,10 +441,12 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
|
|||
const MemRefAccess &dstAccess,
|
||||
unsigned dstLoopDepth,
|
||||
ComputationSliceState *sliceState) {
|
||||
bool readReadAccesses =
|
||||
srcAccess.opInst->isa<LoadOp>() && dstAccess.opInst->isa<LoadOp>();
|
||||
FlatAffineConstraints dependenceConstraints;
|
||||
if (!checkMemrefAccessDependence(srcAccess, dstAccess, /*loopDepth=*/1,
|
||||
&dependenceConstraints,
|
||||
/*dependenceComponents=*/nullptr)) {
|
||||
if (!checkMemrefAccessDependence(
|
||||
srcAccess, dstAccess, /*loopDepth=*/1, &dependenceConstraints,
|
||||
/*dependenceComponents=*/nullptr, /*allowRAR=*/readReadAccesses)) {
|
||||
return false;
|
||||
}
|
||||
// Get loop nest surrounding src operation.
|
||||
|
@ -487,6 +488,25 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
|
|||
// canonicalization.
|
||||
sliceState->lbOperands.resize(numSrcLoopIVs, sliceBoundOperands);
|
||||
sliceState->ubOperands.resize(numSrcLoopIVs, sliceBoundOperands);
|
||||
|
||||
// For read-read access pairs, clear any slice bounds on sequential loops.
|
||||
if (readReadAccesses) {
|
||||
// Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
|
||||
llvm::SmallDenseSet<Value *, 8> sequentialLoops;
|
||||
getSequentialLoops(srcLoopIVs[0], &sequentialLoops);
|
||||
|
||||
// Clear all sliced loop bounds beginning at the first sequential loop.
|
||||
for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
|
||||
Value *iv = srcLoopIVs[i]->getInductionVar();
|
||||
if (sequentialLoops.count(iv) == 0)
|
||||
continue;
|
||||
for (unsigned j = i; j < numSrcLoopIVs; ++j) {
|
||||
sliceState->lbs[j] = AffineMap();
|
||||
sliceState->ubs[j] = AffineMap();
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -675,3 +695,43 @@ mlir::getMemoryFootprintBytes(ConstOpPointer<AffineForOp> forOp,
|
|||
*forInst->getBlock(), Block::const_iterator(forInst),
|
||||
std::next(Block::const_iterator(forInst)), memorySpace);
|
||||
}
|
||||
|
||||
/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
|
||||
/// at 'forOp'.
|
||||
void mlir::getSequentialLoops(
|
||||
OpPointer<AffineForOp> forOp,
|
||||
llvm::SmallDenseSet<Value *, 8> *sequentialLoops) {
|
||||
// Collect all load and store ops in loop nest rooted at 'forOp'.
|
||||
SmallVector<Instruction *, 4> loadAndStoreOpInsts;
|
||||
forOp->getInstruction()->walk([&](Instruction *opInst) {
|
||||
if (opInst->isa<LoadOp>() || opInst->isa<StoreOp>())
|
||||
loadAndStoreOpInsts.push_back(opInst);
|
||||
});
|
||||
|
||||
// Check dependences on all pairs of ops in 'loadAndStoreOpInsts' and record
|
||||
// loops which carry dependences in 'sequentialLoops'.
|
||||
for (unsigned i = 0, e = loadAndStoreOpInsts.size(); i < e; ++i) {
|
||||
auto *srcOpInst = loadAndStoreOpInsts[i];
|
||||
MemRefAccess srcAccess(srcOpInst);
|
||||
SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
|
||||
getLoopIVs(*srcOpInst, &srcLoopIVs);
|
||||
for (auto *dstOpInst : loadAndStoreOpInsts) {
|
||||
MemRefAccess dstAccess(dstOpInst);
|
||||
|
||||
unsigned numCommonLoops =
|
||||
getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
|
||||
for (unsigned d = 1; d <= numCommonLoops; ++d) {
|
||||
auto *iv = srcLoopIVs[d - 1]->getInductionVar();
|
||||
if (sequentialLoops->count(iv) > 0)
|
||||
continue;
|
||||
FlatAffineConstraints dependenceConstraints;
|
||||
if (checkMemrefAccessDependence(srcAccess, dstAccess, d,
|
||||
&dependenceConstraints,
|
||||
/*dependenceComponents=*/nullptr)) {
|
||||
// Record loop with carried dependence between srcAccess/dstAccess.
|
||||
sequentialLoops->insert(iv);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
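As a rough illustration (editorial, not part of the commit), the new utility is intended to be called on the root of a loop nest. 'rootForOp' below is an assumed placeholder OpPointer<AffineForOp>; the call mirrors the one in getBackwardComputationSliceState above.

llvm::SmallDenseSet<Value *, 8> sequentialLoops;
mlir::getSequentialLoops(rootForOp, &sequentialLoops);
// Induction variables present in 'sequentialLoops' belong to loops that carry
// a dependence; for read-read access pairs, slice bounds at and beyond the
// first such loop are cleared.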
|
||||
|
|
|
@ -141,6 +141,7 @@ static bool isMemRefDereferencingOp(const Instruction &op) {
|
|||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
// MemRefDependenceGraph is a graph data structure where graph nodes are
|
||||
// top-level instructions in a Function which contain load/store ops, and edges
|
||||
// are memref dependences between the nodes.
|
||||
|
@ -182,7 +183,7 @@ public:
|
|||
return storeOpCount;
|
||||
}
|
||||
|
||||
// Returns all store ups in 'storeOps' which access 'memref'.
|
||||
// Returns all store ops in 'storeOps' which access 'memref'.
|
||||
void getStoreOpsForMemref(Value *memref,
|
||||
SmallVectorImpl<Instruction *> *storeOps) {
|
||||
for (auto *storeOpInst : stores) {
|
||||
|
@ -190,6 +191,29 @@ public:
|
|||
storeOps->push_back(storeOpInst);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns all load ops in 'loadOps' which access 'memref'.
|
||||
void getLoadOpsForMemref(Value *memref,
|
||||
SmallVectorImpl<Instruction *> *loadOps) {
|
||||
for (auto *loadOpInst : loads) {
|
||||
if (memref == loadOpInst->cast<LoadOp>()->getMemRef())
|
||||
loadOps->push_back(loadOpInst);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns all memrefs in 'loadAndStoreMemrefSet' for which this node
|
||||
// has at least one load and store operation.
|
||||
void getLoadAndStoreMemrefSet(DenseSet<Value *> *loadAndStoreMemrefSet) {
|
||||
llvm::SmallDenseSet<Value *, 2> loadMemrefs;
|
||||
for (auto *loadOpInst : loads) {
|
||||
loadMemrefs.insert(loadOpInst->cast<LoadOp>()->getMemRef());
|
||||
}
|
||||
for (auto *storeOpInst : stores) {
|
||||
auto *memref = storeOpInst->cast<StoreOp>()->getMemRef();
|
||||
if (loadMemrefs.count(memref) > 0)
|
||||
loadAndStoreMemrefSet->insert(memref);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Edge represents a data dependence between nodes in the graph.
|
||||
|
@ -300,17 +324,18 @@ public:
|
|||
return true;
|
||||
}
|
||||
|
||||
// Returns true iff there is an edge from node 'srcId' to node 'dstId' for
|
||||
// 'value'. Returns false otherwise.
|
||||
bool hasEdge(unsigned srcId, unsigned dstId, Value *value) {
|
||||
// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
|
||||
// is for 'value' if non-null, or for any value otherwise. Returns false
|
||||
// otherwise.
|
||||
bool hasEdge(unsigned srcId, unsigned dstId, Value *value = nullptr) {
|
||||
if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
|
||||
return false;
|
||||
}
|
||||
bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
|
||||
return edge.id == dstId && edge.value == value;
|
||||
return edge.id == dstId && (!value || edge.value == value);
|
||||
});
|
||||
bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
|
||||
return edge.id == srcId && edge.value == value;
|
||||
return edge.id == srcId && (!value || edge.value == value);
|
||||
});
|
||||
return hasOutEdge && hasInEdge;
|
||||
}
|
||||
|
@ -349,8 +374,37 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
// Returns true if there is a path in the dependence graph from node 'srcId'
|
||||
// to node 'dstId'. Returns false otherwise.
|
||||
bool hasDependencePath(unsigned srcId, unsigned dstId) {
|
||||
// Worklist state is: <node-id, next-output-edge-index-to-visit>
|
||||
SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
|
||||
worklist.push_back({srcId, 0});
|
||||
// Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
|
||||
while (!worklist.empty()) {
|
||||
auto &idAndIndex = worklist.back();
|
||||
// Return true if we have reached 'dstId'.
|
||||
if (idAndIndex.first == dstId)
|
||||
return true;
|
||||
// Pop and continue if node has no out edges, or if all out edges have
|
||||
// already been visited.
|
||||
if (outEdges.count(idAndIndex.first) == 0 ||
|
||||
idAndIndex.second == outEdges[idAndIndex.first].size()) {
|
||||
worklist.pop_back();
|
||||
continue;
|
||||
}
|
||||
// Get graph edge to traverse.
|
||||
Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
|
||||
// Increment next output edge index for 'idAndIndex'.
|
||||
++idAndIndex.second;
|
||||
// Add node at 'edge.id' to worklist.
|
||||
worklist.push_back({edge.id, 0});
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Returns the input edge count for node 'id' and 'memref' from src nodes
|
||||
// which access 'memref'.
|
||||
// which access 'memref' with a store operation.
|
||||
unsigned getIncomingMemRefAccesses(unsigned id, Value *memref) {
|
||||
unsigned inEdgeCount = 0;
|
||||
if (inEdges.count(id) > 0)
|
||||
|
@ -358,19 +412,19 @@ public:
|
|||
if (inEdge.value == memref) {
|
||||
Node *srcNode = getNode(inEdge.id);
|
||||
// Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
|
||||
if (srcNode->getLoadOpCount(memref) > 0 ||
|
||||
srcNode->getStoreOpCount(memref) > 0)
|
||||
if (srcNode->getStoreOpCount(memref) > 0)
|
||||
++inEdgeCount;
|
||||
}
|
||||
return inEdgeCount;
|
||||
}
|
||||
|
||||
// Returns the output edge count for node 'id' and 'memref'.
|
||||
unsigned getOutEdgeCount(unsigned id, Value *memref) {
|
||||
// Returns the output edge count for node 'id' and 'memref' (if non-null),
|
||||
// otherwise returns the total output edge count from node 'id'.
|
||||
unsigned getOutEdgeCount(unsigned id, Value *memref = nullptr) {
|
||||
unsigned outEdgeCount = 0;
|
||||
if (outEdges.count(id) > 0)
|
||||
for (auto &outEdge : outEdges[id])
|
||||
if (outEdge.value == memref)
|
||||
if (!memref || outEdge.value == memref)
|
||||
++outEdgeCount;
|
||||
return outEdgeCount;
|
||||
}
|
||||
|
@ -469,6 +523,32 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
// Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
|
||||
// of sibling node 'sibId' into node 'dstId'.
|
||||
void updateEdges(unsigned sibId, unsigned dstId) {
|
||||
// For each edge in 'inEdges[sibId]':
|
||||
// *) Add new edge from source node 'inEdge.id' to 'dstNode'.
|
||||
// *) Remove edge from source node 'inEdge.id' to 'sibNode'.
|
||||
if (inEdges.count(sibId) > 0) {
|
||||
SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
|
||||
for (auto &inEdge : oldInEdges) {
|
||||
addEdge(inEdge.id, dstId, inEdge.value);
|
||||
removeEdge(inEdge.id, sibId, inEdge.value);
|
||||
}
|
||||
}
|
||||
|
||||
// For each edge in 'outEdges[sibId]' to node 'id'
|
||||
// *) Add new edge from 'dstId' to 'outEdge.id'.
|
||||
// *) Remove edge from 'sibId' to 'outEdge.id'.
|
||||
if (outEdges.count(sibId) > 0) {
|
||||
SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
|
||||
for (auto &outEdge : oldOutEdges) {
|
||||
addEdge(dstId, outEdge.id, outEdge.value);
|
||||
removeEdge(sibId, outEdge.id, outEdge.value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Adds ops in 'loads' and 'stores' to node at 'id'.
|
||||
void addToNode(unsigned id, const SmallVectorImpl<Instruction *> &loads,
|
||||
const SmallVectorImpl<Instruction *> &stores) {
|
||||
|
@ -485,6 +565,37 @@ public:
|
|||
node->stores.clear();
|
||||
}
|
||||
|
||||
// Calls 'callback' for each input edge incident to node 'id' which carries a
|
||||
// memref dependence.
|
||||
void forEachMemRefInputEdge(unsigned id,
|
||||
const std::function<void(Edge)> &callback) {
|
||||
if (inEdges.count(id) > 0)
|
||||
forEachMemRefEdge(inEdges[id], callback);
|
||||
}
|
||||
// Calls 'callback' for each output edge from node 'id' which carries a
|
||||
// memref dependence.
|
||||
void forEachMemRefOutputEdge(unsigned id,
|
||||
const std::function<void(Edge)> &callback) {
|
||||
if (outEdges.count(id) > 0)
|
||||
forEachMemRefEdge(outEdges[id], callback);
|
||||
}
|
||||
// Calls 'callback' for each edge in 'edges' which carries a memref
|
||||
// dependence.
|
||||
void forEachMemRefEdge(ArrayRef<Edge> edges,
|
||||
const std::function<void(Edge)> &callback) {
|
||||
for (auto &edge : edges) {
|
||||
// Skip if 'edge' is not a memref dependence edge.
|
||||
if (!edge.value->getType().isa<MemRefType>())
|
||||
continue;
|
||||
assert(nodes.count(edge.id) > 0);
|
||||
// Skip if 'edge.id' is not a loop nest.
|
||||
if (!getNode(edge.id)->inst->isa<AffineForOp>())
|
||||
continue;
|
||||
// Visit current input edge 'edge'.
|
||||
callback(edge);
|
||||
}
|
||||
}
|
||||
|
||||
void print(raw_ostream &os) const {
|
||||
os << "\nMemRefDependenceGraph\n";
|
||||
os << "\nNodes:\n";
|
||||
|
@ -1228,6 +1339,14 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
|
|||
|
||||
// Checks the profitability of fusing a backwards slice of the loop nest
|
||||
// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
|
||||
// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
|
||||
// the memref being produced and consumed, which is an input to the cost model.
|
||||
// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
|
||||
// 'srcOpInst', as we are slicing w.r.t to that producer.
|
||||
// For input-reuse fusion, 'srcOpInst' will be the src loop nest LoadOp which
|
||||
// reads from the same memref as dst loop nest load ops, and 'srcStoreOpInst'
|
||||
// will be the unique store op in the src node, which will be used to check
|
||||
// that the write region is the same after input-reuse fusion.
|
||||
// Returns true if it is profitable to fuse the candidate loop nests. Returns
|
||||
// false otherwise. `dstLoopDepth` is set to the most profitable depth at which
|
||||
// to materialize the source loop nest slice.
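To make the two modes concrete, here is an editorial sketch of the call patterns described above, using the argument names from the corresponding call sites later in this diff (producer-consumer fusion in fuseProducerConsumerNodes, input-reuse fusion in fuseWithSiblingNodes).

// Producer-consumer fusion: slice w.r.t. the producer's store.
isFusionProfitable(srcStoreOpInst, srcStoreOpInst, dstLoadOpInsts,
                   dstStoreOpInsts, &sliceState, &bestDstLoopDepth);
// Input-reuse (sibling) fusion: slice w.r.t. the shared load; the sibling's
// store is used to verify that its write region is preserved after fusion.
isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts,
                   dstStoreOpInsts, &sliceState, &bestDstLoopDepth);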
|
||||
|
@ -1257,6 +1376,7 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
|
|||
// loop nest computed in the previous step, and returns true if the latter
|
||||
// is lower.
|
||||
static bool isFusionProfitable(Instruction *srcOpInst,
|
||||
Instruction *srcStoreOpInst,
|
||||
ArrayRef<Instruction *> dstLoadOpInsts,
|
||||
ArrayRef<Instruction *> dstStoreOpInsts,
|
||||
ComputationSliceState *sliceState,
|
||||
|
@ -1294,8 +1414,11 @@ static bool isFusionProfitable(Instruction *srcOpInst,
|
|||
return false;
|
||||
|
||||
// Compute the maximum loop depth at which we can insert the src slice
|
||||
// and still satisfy dest loop nest dependences.
|
||||
unsigned maxDstLoopDepth = getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts);
|
||||
// and still satisfy dest loop nest dependences, for producer-consumer fusion.
|
||||
unsigned maxDstLoopDepth =
|
||||
(srcOpInst == srcStoreOpInst)
|
||||
? getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts)
|
||||
: dstLoopIVs.size();
|
||||
if (maxDstLoopDepth == 0)
|
||||
return false;
|
||||
|
||||
|
@ -1306,7 +1429,7 @@ static bool isFusionProfitable(Instruction *srcOpInst,
|
|||
// the cost of the slice and the cost of the slice inserted into the dst
|
||||
// loop nest at 'dstLoopDepth'.
|
||||
uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
|
||||
uint64_t maxStorageReduction = 0;
|
||||
double maxStorageReduction = 0.0;
|
||||
Optional<uint64_t> sliceMemEstimate = None;
|
||||
|
||||
SmallVector<ComputationSliceState, 4> sliceStates;
|
||||
|
@ -1321,8 +1444,8 @@ static bool isFusionProfitable(Instruction *srcOpInst,
|
|||
/*computeCostMap=*/nullptr);
|
||||
|
||||
// Compute src loop nest write region size.
|
||||
MemRefRegion srcWriteRegion(srcOpInst->getLoc());
|
||||
srcWriteRegion.compute(srcOpInst, /*loopDepth=*/0);
|
||||
MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
|
||||
srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0);
|
||||
Optional<int64_t> maybeSrcWriteRegionSizeBytes =
|
||||
srcWriteRegion.getRegionSize();
|
||||
if (!maybeSrcWriteRegionSizeBytes.hasValue())
|
||||
|
@ -1345,6 +1468,7 @@ static bool isFusionProfitable(Instruction *srcOpInst,
|
|||
if (!mlir::getBackwardComputationSliceState(
|
||||
srcAccess, MemRefAccess(dstLoadOpInsts[0]), i, &sliceStates[i - 1]))
|
||||
return false;
|
||||
|
||||
// Compute the union of slice bound of all ops in 'dstLoadOpInsts'.
|
||||
for (int j = 1, e = dstLoadOpInsts.size(); j < e; ++j) {
|
||||
MemRefAccess dstAccess(dstLoadOpInsts[j]);
|
||||
|
@ -1372,6 +1496,7 @@ static bool isFusionProfitable(Instruction *srcOpInst,
|
|||
computeCostMap.clear();
|
||||
|
||||
// The store and loads to this memref will disappear.
|
||||
// TODO(andydavis) Add load coalescing to memref data flow opt pass.
|
||||
if (storeLoadFwdGuaranteed) {
|
||||
// A single store disappears: -1 for that.
|
||||
computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1;
|
||||
|
@ -1403,8 +1528,9 @@ static bool isFusionProfitable(Instruction *srcOpInst,
|
|||
// Compute what the slice write MemRefRegion would be, if the src loop
|
||||
// nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop
|
||||
// nest at loop depth 'i'
|
||||
MemRefRegion sliceWriteRegion(srcOpInst->getLoc());
|
||||
sliceWriteRegion.compute(srcOpInst, /*loopDepth=*/0, &sliceStates[i - 1]);
|
||||
MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
|
||||
sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
|
||||
&sliceStates[i - 1]);
|
||||
Optional<int64_t> maybeSliceWriteRegionSizeBytes =
|
||||
sliceWriteRegion.getRegionSize();
|
||||
if (!maybeSliceWriteRegionSizeBytes.hasValue() ||
|
||||
|
@ -1413,6 +1539,14 @@ static bool isFusionProfitable(Instruction *srcOpInst,
|
|||
int64_t sliceWriteRegionSizeBytes =
|
||||
maybeSliceWriteRegionSizeBytes.getValue();
|
||||
|
||||
// If we are fusing for reuse, check that write regions remain the same.
|
||||
// TODO(andydavis) Write region check should check sizes and offsets in
|
||||
// each dimension, so that we are sure they are covering the same memref
|
||||
// region. Also, move this out to a isMemRefRegionSuperSet helper function.
|
||||
if (srcOpInst != srcStoreOpInst &&
|
||||
sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
|
||||
continue;
|
||||
|
||||
double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
|
||||
static_cast<double>(sliceWriteRegionSizeBytes);
|
||||
|
||||
|
@ -1547,12 +1681,10 @@ static bool isFusionProfitable(Instruction *srcOpInst,
|
|||
return true;
|
||||
}
|
||||
|
||||
// GreedyFusion greedily fuses loop nests which have a producer/consumer
|
||||
// relationship on a memref, with the goal of improving locality. Currently,
|
||||
// this the producer/consumer relationship is required to be unique in the
|
||||
// Function (there are TODOs to relax this constraint in the future).
|
||||
// GreedyFusion greedily fuses loop nests which have a producer/consumer or
|
||||
// input-reuse relationship on a memref, with the goal of improving locality.
|
||||
//
|
||||
// The steps of the algorithm are as follows:
|
||||
// The steps of the producer-consumer fusion algorithm are as follows:
|
||||
//
|
||||
// *) A worklist is initialized with node ids from the dependence graph.
|
||||
// *) For each node id in the worklist:
|
||||
|
@ -1560,20 +1692,32 @@ static bool isFusionProfitable(Instruction *srcOpInst,
|
|||
// candidate destination AffineForOp into which fusion will be attempted.
|
||||
// *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
|
||||
// *) For each LoadOp in 'dstLoadOps' do:
|
||||
// *) Lookup dependent loop nests at earlier positions in the Function
|
||||
// which have a single store op to the same memref.
|
||||
// *) Check if dependences would be violated by the fusion. For example,
|
||||
// the src loop nest may load from memrefs which are different than
|
||||
// the producer-consumer memref between src and dest loop nests.
|
||||
// *) Lookup dependent loop nests which have a single store op to the same
|
||||
// memref.
|
||||
// *) Check if dependences would be violated by the fusion.
|
||||
// *) Get a computation slice of 'srcLoopNest', which adjusts its loop
|
||||
// bounds to be functions of 'dstLoopNest' IVs and symbols.
|
||||
// *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
|
||||
// just before the dst load op user.
|
||||
// at a loop depth determined by the cost model in 'isFusionProfitable'.
|
||||
// *) Add the newly fused load/store operation instructions to the state,
|
||||
// and also add newly fused load ops to 'dstLoopOps' to be considered
|
||||
// as fusion dst load ops in another iteration.
|
||||
// *) Remove old src loop nest and its associated state.
|
||||
//
|
||||
// The steps of the input-reuse fusion algorithm are as follows:
|
||||
//
|
||||
// *) Initialize 'worklist' with node ids from the dependence graph.
|
||||
// *) For each 'dstNode' in the worklist:
|
||||
// *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
|
||||
// loads from the same memref, but which has no dependence paths to/from.
|
||||
// *) Get a computation slice of 'sibLoopNest', which adjusts its loop
|
||||
// bounds to be functions of 'dstLoopNest' IVs and symbols.
|
||||
// *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
|
||||
// at a loop depth determined by the cost model in 'isFusionProfitable'.
|
||||
// This function also checks that the memref write region of 'sibLoopNest'
|
||||
// is preserved in the fused loop nest.
|
||||
// *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
|
||||
//
|
||||
// Given a graph where top-level instructions are vertices in the set 'V' and
|
||||
// edges in the set 'E' are dependences between vertices, this algorithm
|
||||
// takes O(V) time for initialization, and has runtime O(V + E).
|
||||
|
@ -1582,25 +1726,54 @@ static bool isFusionProfitable(Instruction *srcOpInst,
|
|||
// fusing along single producer consumer edges, but there is a TODO to fix this.
|
||||
//
|
||||
// TODO(andydavis) Experiment with other fusion policies.
|
||||
// TODO(andydavis) Add support for fusing for input reuse (perhaps by
|
||||
// constructing a graph with edges which represent loads from the same memref
|
||||
// in two different loop nests.
|
||||
struct GreedyFusion {
|
||||
public:
|
||||
// The data dependence graph to traverse during fusion.
|
||||
MemRefDependenceGraph *mdg;
|
||||
// Worklist of graph nodes visited during the fusion pass.
|
||||
SmallVector<unsigned, 8> worklist;
|
||||
// Set of graph nodes which are present on the worklist.
|
||||
llvm::SmallDenseSet<unsigned, 16> worklistSet;
|
||||
// Parameter for local buffer size threshold.
|
||||
unsigned localBufSizeThreshold;
|
||||
// Parameter for fast memory space.
|
||||
Optional<unsigned> fastMemorySpace;
|
||||
|
||||
GreedyFusion(MemRefDependenceGraph *mdg) : mdg(mdg) {
|
||||
// Initialize worklist with nodes from 'mdg'.
|
||||
using Node = MemRefDependenceGraph::Node;
|
||||
|
||||
GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
|
||||
Optional<unsigned> fastMemorySpace)
|
||||
: mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
|
||||
fastMemorySpace(fastMemorySpace) {}
|
||||
|
||||
// Initializes 'worklist' with nodes from 'mdg'
|
||||
void init() {
|
||||
// TODO(andydavis) Add a priority queue for prioritizing nodes by different
|
||||
// metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
|
||||
worklist.resize(mdg->nodes.size());
|
||||
std::iota(worklist.begin(), worklist.end(), 0);
|
||||
worklistSet.insert(worklist.begin(), worklist.end());
|
||||
worklist.clear();
|
||||
worklistSet.clear();
|
||||
for (auto &idAndNode : mdg->nodes) {
|
||||
const Node &node = idAndNode.second;
|
||||
worklist.push_back(node.id);
|
||||
worklistSet.insert(node.id);
|
||||
}
|
||||
}
|
||||
|
||||
void run(unsigned localBufSizeThreshold, Optional<unsigned> fastMemorySpace) {
|
||||
// Run the GreedyFusion pass.
|
||||
// *) First pass through the nodes fuses single-use producer nodes into their
|
||||
// unique consumer.
|
||||
// *) Second pass fuses sibling nodes which share no dependence edges.
|
||||
// *) Third pass fuses any remaining producer nodes into their users.
|
||||
void run() {
|
||||
fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
|
||||
fuseSiblingNodes();
|
||||
fuseProducerConsumerNodes(
|
||||
/*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
|
||||
eraseUnusedMemRefAllocations();
|
||||
}
|
||||
|
||||
void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
|
||||
init();
|
||||
while (!worklist.empty()) {
|
||||
unsigned dstId = worklist.back();
|
||||
worklist.pop_back();
|
||||
|
@ -1672,6 +1845,10 @@ public:
|
|||
!canFuseSrcWhichWritesToLiveOut(srcId, dstId, memref, mdg))
|
||||
continue;
|
||||
|
||||
// Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'.
|
||||
if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount)
|
||||
continue;
|
||||
|
||||
// Compute an instruction list insertion point for the fused loop
|
||||
// nest which preserves dependences.
|
||||
Instruction *insertPointInst =
|
||||
|
@ -1690,8 +1867,8 @@ public:
|
|||
unsigned bestDstLoopDepth;
|
||||
mlir::ComputationSliceState sliceState;
|
||||
// Check if fusion would be profitable.
|
||||
if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts,
|
||||
dstStoreOpInsts, &sliceState,
|
||||
if (!isFusionProfitable(srcStoreOpInst, srcStoreOpInst,
|
||||
dstLoadOpInsts, dstStoreOpInsts, &sliceState,
|
||||
&bestDstLoopDepth))
|
||||
continue;
|
||||
|
||||
|
@ -1782,7 +1959,202 @@ public:
|
|||
}
|
||||
}
|
||||
}
|
||||
// Clean up any allocs with no users.
|
||||
}
|
||||
|
||||
// Visits each node in the graph, and for each node, attempts to fuse it with
|
||||
// its sibling nodes (nodes which share a parent, but no dependence edges).
|
||||
void fuseSiblingNodes() {
|
||||
init();
|
||||
while (!worklist.empty()) {
|
||||
unsigned dstId = worklist.back();
|
||||
worklist.pop_back();
|
||||
worklistSet.erase(dstId);
|
||||
|
||||
// Skip if this node was removed (fused into another node).
|
||||
if (mdg->nodes.count(dstId) == 0)
|
||||
continue;
|
||||
// Get 'dstNode' into which to attempt fusion.
|
||||
auto *dstNode = mdg->getNode(dstId);
|
||||
// Skip if 'dstNode' is not a loop nest.
|
||||
if (!dstNode->inst->isa<AffineForOp>())
|
||||
continue;
|
||||
// Attempt to fuse 'dstNode' with its sibling nodes in the graph.
|
||||
fuseWithSiblingNodes(dstNode);
|
||||
}
|
||||
}
|
||||
|
||||
// Attempt to fuse 'dstNode' with sibling nodes in the graph.
|
||||
void fuseWithSiblingNodes(Node *dstNode) {
|
||||
DenseSet<unsigned> visitedSibNodeIds;
|
||||
std::pair<unsigned, Value *> idAndMemref;
|
||||
while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
|
||||
unsigned sibId = idAndMemref.first;
|
||||
Value *memref = idAndMemref.second;
|
||||
// TODO(andydavis) Check that 'sibStoreOpInst' post-dominates all other
|
||||
// stores to the same memref in 'sibNode' loop nest.
|
||||
auto *sibNode = mdg->getNode(sibId);
|
||||
// Compute an instruction list insertion point for the fused loop
|
||||
// nest which preserves dependences.
|
||||
assert(sibNode->inst->getBlock() == dstNode->inst->getBlock());
|
||||
Instruction *insertPointInst =
|
||||
sibNode->inst->isBeforeInBlock(dstNode->inst)
|
||||
? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
|
||||
: mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
|
||||
if (insertPointInst == nullptr)
|
||||
continue;
|
||||
|
||||
// Check if fusion would be profitable and at what depth.
|
||||
|
||||
// Get unique 'sibNode' load op to 'memref'.
|
||||
SmallVector<Instruction *, 2> sibLoadOpInsts;
|
||||
sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
|
||||
// Currently findSiblingNodeToFuse searches for siblings with one load.
|
||||
assert(sibLoadOpInsts.size() == 1);
|
||||
Instruction *sibLoadOpInst = sibLoadOpInsts[0];
|
||||
assert(!sibNode->stores.empty());
|
||||
// TODO(andydavis) Choose the store which postdominates all other stores.
|
||||
auto *sibStoreOpInst = sibNode->stores.back();
|
||||
|
||||
// Gather 'dstNode' load ops to 'memref'.
|
||||
SmallVector<Instruction *, 2> dstLoadOpInsts;
|
||||
dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
|
||||
|
||||
// Gather 'dstNode' store ops to 'memref'.
|
||||
SmallVector<Instruction *, 2> dstStoreOpInsts;
|
||||
dstNode->getStoreOpsForMemref(memref, &dstStoreOpInsts);
|
||||
|
||||
unsigned bestDstLoopDepth;
|
||||
mlir::ComputationSliceState sliceState;
|
||||
|
||||
// Check if fusion would be profitable.
|
||||
if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts,
|
||||
dstStoreOpInsts, &sliceState, &bestDstLoopDepth))
|
||||
continue;
|
||||
|
||||
// Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
|
||||
auto sliceLoopNest = mlir::insertBackwardComputationSlice(
|
||||
sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
|
||||
if (sliceLoopNest != nullptr) {
|
||||
auto dstForInst = dstNode->inst->cast<AffineForOp>();
|
||||
// Update instruction position of fused loop nest (if needed).
|
||||
if (insertPointInst != dstForInst->getInstruction()) {
|
||||
dstForInst->getInstruction()->moveBefore(insertPointInst);
|
||||
}
|
||||
// Update data dependence graph state post fusion.
|
||||
updateStateAfterSiblingFusion(sliceLoopNest, sibNode, dstNode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Searches the graph from 'dstNode' looking for a fusion candidate sibling
|
||||
// node which shares no dependences with 'dstNode' but which loads from the
|
||||
// same memref. Returns true and sets 'idAndMemrefToFuse' on success. Returns
|
||||
// false otherwise.
|
||||
bool findSiblingNodeToFuse(Node *dstNode,
|
||||
DenseSet<unsigned> *visitedSibNodeIds,
|
||||
std::pair<unsigned, Value *> *idAndMemrefToFuse) {
|
||||
// TODO(andydavis) Currently we discover siblings by following edges
|
||||
// through an intermediate src node. We should also consider siblings
|
||||
// which load from the same memref, but which do not necessarily share
|
||||
// a src node parent (e.g. loading from a memref which is a function arg).
|
||||
// Collect candidate 'dstNode' input edges in 'inEdges'.
|
||||
SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
|
||||
mdg->forEachMemRefInputEdge(
|
||||
dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
|
||||
// Add 'inEdge' if it is a read-after-write dependence.
|
||||
if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
|
||||
mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
|
||||
inEdges.push_back(inEdge);
|
||||
});
|
||||
|
||||
// Search for sibling nodes to fuse by visiting output edges from each input
|
||||
// edge in 'inEdges'.
|
||||
for (auto &inEdge : inEdges) {
|
||||
// Collect candidate output edges from each node 'inEdge.id' in 'inEdges'.
|
||||
SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
|
||||
mdg->forEachMemRefOutputEdge(
|
||||
inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
|
||||
unsigned sibNodeId = outEdge.id;
|
||||
if (visitedSibNodeIds->count(sibNodeId) > 0)
|
||||
return;
|
||||
// Skip output edge if not a sibling using the same memref.
|
||||
if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
|
||||
return;
|
||||
auto *sibNode = mdg->getNode(sibNodeId);
|
||||
if (!sibNode->inst->isa<AffineForOp>())
|
||||
return;
|
||||
// Skip if 'outEdge' is not a read-after-write dependence.
|
||||
// TODO(andydavis) Remove the single load op restriction.
|
||||
if (sibNode->getLoadOpCount(inEdge.value) != 1)
|
||||
return;
|
||||
// Skip if there exists a path of dependent edges between
|
||||
// 'sibNode' and 'dstNode'.
|
||||
if (mdg->hasDependencePath(sibNodeId, dstNode->id) ||
|
||||
mdg->hasDependencePath(dstNode->id, sibNodeId))
|
||||
return;
|
||||
// Skip sib node if it loads from (and stores to) the same memref on
|
||||
// which it also has an input dependence edge.
|
||||
DenseSet<Value *> loadAndStoreMemrefSet;
|
||||
sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
|
||||
if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) {
|
||||
return mdg->getIncomingMemRefAccesses(sibNode->id, memref) >
|
||||
0;
|
||||
}))
|
||||
return;
|
||||
// Check that all stores are to the same memref.
|
||||
DenseSet<Value *> storeMemrefs;
|
||||
for (auto *storeOpInst : sibNode->stores) {
|
||||
storeMemrefs.insert(storeOpInst->cast<StoreOp>()->getMemRef());
|
||||
}
|
||||
if (storeMemrefs.size() != 1)
|
||||
return;
|
||||
// Add candidate 'outEdge' to sibling node.
|
||||
outEdges.push_back(outEdge);
|
||||
});
|
||||
|
||||
// Add first candidate if any were returned.
|
||||
if (!outEdges.empty()) {
|
||||
visitedSibNodeIds->insert(outEdges[0].id);
|
||||
idAndMemrefToFuse->first = outEdges[0].id;
|
||||
idAndMemrefToFuse->second = outEdges[0].value;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void updateStateAfterSiblingFusion(OpPointer<AffineForOp> sliceLoopNest,
|
||||
Node *sibNode, Node *dstNode) {
|
||||
// Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
|
||||
mdg->updateEdges(sibNode->id, dstNode->id);
|
||||
|
||||
// Collect slice loop stats.
|
||||
LoopNestStateCollector sliceCollector;
|
||||
sliceCollector.collect(sliceLoopNest->getInstruction());
|
||||
// Promote single iteration slice loops to single IV value.
|
||||
for (auto forOp : sliceCollector.forOps) {
|
||||
promoteIfSingleIteration(forOp);
|
||||
}
|
||||
|
||||
// Collect dst loop stats after memref privatization transformation.
|
||||
auto dstForInst = dstNode->inst->cast<AffineForOp>();
|
||||
LoopNestStateCollector dstLoopCollector;
|
||||
dstLoopCollector.collect(dstForInst->getInstruction());
|
||||
// Clear and add back loads and stores
|
||||
mdg->clearNodeLoadAndStores(dstNode->id);
|
||||
mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
|
||||
dstLoopCollector.storeOpInsts);
|
||||
// Remove old sibling loop nest if it no longer has outgoing dependence
|
||||
// edges, and it does not write to a memref which escapes the
|
||||
// function.
|
||||
if (mdg->getOutEdgeCount(sibNode->id) == 0) {
|
||||
mdg->removeNode(sibNode->id);
|
||||
sibNode->inst->cast<AffineForOp>()->erase();
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up any allocs with no users.
|
||||
void eraseUnusedMemRefAllocations() {
|
||||
for (auto &pair : mdg->memrefEdgeCount) {
|
||||
if (pair.second > 0)
|
||||
continue;
|
||||
|
@ -1813,7 +2185,7 @@ void LoopFusion::runOnFunction() {
|
|||
|
||||
MemRefDependenceGraph g;
|
||||
if (g.init(&getFunction()))
|
||||
GreedyFusion(&g).run(localBufSizeThreshold, fastMemorySpace);
|
||||
GreedyFusion(&g, localBufSizeThreshold, fastMemorySpace).run();
|
||||
}
|
||||
|
||||
static PassRegistration<LoopFusion> pass("loop-fusion", "Fuse loop nests");
|
||||
|
|
|
@ -1228,13 +1228,13 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() {
|
|||
// by loops %i1 and %i2.
|
||||
// CHECK-DAG: %0 = alloc() : memref<1xf32>
|
||||
// CHECK-DAG: %1 = alloc() : memref<1xf32>
|
||||
// CHECK: for %i0 = 0 to 82 {
|
||||
// CHECK: for %i0 = 0 to 17 {
|
||||
// CHECK-NEXT: %2 = affine.apply [[MAP0]](%i0, %i0)
|
||||
// CHECK-NEXT: store %cst, %1[%2] : memref<1xf32>
|
||||
// CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i0)
|
||||
// CHECK-NEXT: %4 = load %1[%3] : memref<1xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: for %i1 = 0 to 17 {
|
||||
// CHECK-NEXT: for %i1 = 0 to 82 {
|
||||
// CHECK-NEXT: %5 = affine.apply [[MAP0]](%i1, %i1)
|
||||
// CHECK-NEXT: store %cst, %0[%5] : memref<1xf32>
|
||||
// CHECK-NEXT: %6 = affine.apply [[MAP0]](%i1, %i1)
|
||||
|
@ -1915,3 +1915,150 @@ func @test_add_slice_bounds() {
|
|||
// CHECK-NEXT: }
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d0 + d2)
|
||||
// CHECK-DAG: [[MAP1:#map[0-9]+]] = (d0, d1, d2, d3) -> (-d1 + d3)
|
||||
|
||||
func @should_fuse_init_loops_siblings_then_shared_producer(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>) {
|
||||
%0 = alloc() : memref<10x10xf32>
|
||||
%cst = constant 0.000000e+00 : f32
|
||||
%cst_0 = constant 1.000000e+00 : f32
|
||||
%cst_1 = constant 7.000000e+00 : f32
|
||||
for %i0 = 0 to 10 {
|
||||
for %i1 = 0 to 10 {
|
||||
store %cst_1, %0[%i0, %i1] : memref<10x10xf32>
|
||||
}
|
||||
}
|
||||
for %i2 = 0 to 3 {
|
||||
for %i3 = 0 to 3 {
|
||||
store %cst, %arg0[%i2, %i3] : memref<10x10xf32>
|
||||
}
|
||||
}
|
||||
for %i4 = 0 to 3 {
|
||||
for %i5 = 0 to 3 {
|
||||
%1 = load %0[%i4, %i5] : memref<10x10xf32>
|
||||
%2 = load %arg0[%i4, %i5] : memref<10x10xf32>
|
||||
%3 = mulf %1, %2 : f32
|
||||
store %3, %arg0[%i4, %i5] : memref<10x10xf32>
|
||||
}
|
||||
}
|
||||
for %i6 = 0 to 3 {
|
||||
for %i7 = 0 to 3 {
|
||||
store %cst_0, %arg1[%i6, %i7] : memref<10x10xf32>
|
||||
}
|
||||
}
|
||||
for %i8 = 0 to 3 {
|
||||
for %i9 = 0 to 3 {
|
||||
%4 = load %0[%i8, %i9] : memref<10x10xf32>
|
||||
%5 = load %arg1[%i8, %i9] : memref<10x10xf32>
|
||||
%6 = addf %4, %5 : f32
|
||||
store %6, %arg1[%i8, %i9] : memref<10x10xf32>
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 1: should fuse single-use producer loop nests into their unique user,
|
||||
// so '%i2' will fuse into '%i4' and '%i6' will fuse into '%i8'.
|
||||
// Pass 2: should fuse sibling loop nests which share no dependence edges,
|
||||
// so should fuse '%i4' into '%i8'.
|
||||
// Pass 3: should fuse single-use producer loop nest '%i0' into '%i8'. Note
|
||||
// that loop nest '%i0' now has a single user after Pass 2 fused its
|
||||
// two users together.
|
||||
|
||||
// CHECK: for %i0 = 0 to 3 {
|
||||
// CHECK-NEXT: for %i1 = 0 to 3 {
|
||||
// CHECK-NEXT: %1 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1)
|
||||
// CHECK-NEXT: %2 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1)
|
||||
// CHECK-NEXT: store %cst_1, %0[%1, %2] : memref<1x1xf32>
|
||||
// CHECK-NEXT: store %cst, %arg0[%i0, %i1] : memref<10x10xf32>
|
||||
// CHECK-NEXT: %3 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1)
|
||||
// CHECK-NEXT: %4 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1)
|
||||
// CHECK-NEXT: %5 = load %0[%3, %4] : memref<1x1xf32>
|
||||
// CHECK-NEXT: %6 = load %arg0[%i0, %i1] : memref<10x10xf32>
|
||||
// CHECK-NEXT: %7 = mulf %5, %6 : f32
|
||||
// CHECK-NEXT: store %7, %arg0[%i0, %i1] : memref<10x10xf32>
|
||||
// CHECK-NEXT: store %cst_0, %arg1[%i0, %i1] : memref<10x10xf32>
|
||||
// CHECK-NEXT: %8 = affine.apply [[MAP0]](%i0, %i1, %i0, %i1)
|
||||
// CHECK-NEXT: %9 = affine.apply [[MAP1]](%i0, %i1, %i0, %i1)
|
||||
// CHECK-NEXT: %10 = load %0[%8, %9] : memref<1x1xf32>
|
||||
// CHECK-NEXT: %11 = load %arg1[%i0, %i1] : memref<10x10xf32>
|
||||
// CHECK-NEXT: %12 = addf %10, %11 : f32
|
||||
// CHECK-NEXT: store %12, %arg1[%i0, %i1] : memref<10x10xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-DAG: [[MAP2:#map[0-9]+]] = (d0, d1, d2) -> (d1)
|
||||
// CHECK-DAG: [[MAP3:#map[0-9]+]] = (d0, d1, d2) -> (-d0 + d2)
|
||||
|
||||
func @two_matrix_vector_products() {
|
||||
%in_matrix = alloc() : memref<10x10xf32>
|
||||
%in_vec0 = alloc() : memref<10xf32>
|
||||
%in_vec1 = alloc() : memref<10xf32>
|
||||
%out_vec0 = alloc() : memref<10xf32>
|
||||
%out_vec1 = alloc() : memref<10xf32>
|
||||
%cf7 = constant 7.0 : f32
|
||||
|
||||
// Populate input matrix.
|
||||
for %i0 = 0 to 10 {
|
||||
for %i1 = 0 to 10 {
|
||||
store %cf7, %in_matrix[%i0, %i1] : memref<10x10xf32>
|
||||
}
|
||||
}
|
||||
// out_vec0 = in_matrix x in_vec0
|
||||
for %i2 = 0 to 10 {
|
||||
for %i3 = 0 to 10 {
|
||||
%v0 = load %in_matrix[%i2, %i3] : memref<10x10xf32>
|
||||
%v1 = load %in_vec0[%i3] : memref<10xf32>
|
||||
%v2 = mulf %v0, %v1 : f32
|
||||
%v3 = load %out_vec0[%i3] : memref<10xf32>
|
||||
%v4 = addf %v2, %v3 : f32
|
||||
store %v4, %out_vec0[%i3] : memref<10xf32>
|
||||
}
|
||||
}
|
||||
// out_vec1 = in_matrix x in_vec1
|
||||
for %i4 = 0 to 10 {
|
||||
for %i5 = 0 to 10 {
|
||||
%v5 = load %in_matrix[%i4, %i5] : memref<10x10xf32>
|
||||
%v6 = load %in_vec1[%i5] : memref<10xf32>
|
||||
%v7 = mulf %v5, %v6 : f32
|
||||
%v8 = load %out_vec1[%i5] : memref<10xf32>
|
||||
%v9 = addf %v7, %v8 : f32
|
||||
store %v9, %out_vec1[%i5] : memref<10xf32>
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK: for %i0 = 0 to 10 {
|
||||
// CHECK-NEXT: for %i1 = 0 to 10 {
|
||||
// CHECK-NEXT: %5 = affine.apply [[MAP2]](%i0, %i1, %i0)
|
||||
// CHECK-NEXT: %6 = affine.apply [[MAP3]](%i0, %i1, %i0)
|
||||
// CHECK-NEXT: store %cst, %0[%5, %6] : memref<10x1xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: for %i2 = 0 to 10 {
|
||||
// CHECK-NEXT: %7 = affine.apply [[MAP2]](%i0, %i2, %i0)
|
||||
// CHECK-NEXT: %8 = affine.apply [[MAP3]](%i0, %i2, %i0)
|
||||
// CHECK-NEXT: %9 = load %0[%7, %8] : memref<10x1xf32>
|
||||
// CHECK-NEXT: %10 = load %1[%i0] : memref<10xf32>
|
||||
// CHECK-NEXT: %11 = mulf %9, %10 : f32
|
||||
// CHECK-NEXT: %12 = load %3[%i0] : memref<10xf32>
|
||||
// CHECK-NEXT: %13 = addf %11, %12 : f32
|
||||
// CHECK-NEXT: store %13, %3[%i0] : memref<10xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: for %i3 = 0 to 10 {
|
||||
// CHECK-NEXT: %14 = affine.apply [[MAP2]](%i0, %i3, %i0)
|
||||
// CHECK-NEXT: %15 = affine.apply [[MAP3]](%i0, %i3, %i0)
|
||||
// CHECK-NEXT: %16 = load %0[%14, %15] : memref<10x1xf32>
|
||||
// CHECK-NEXT: %17 = load %2[%i0] : memref<10xf32>
|
||||
// CHECK-NEXT: %18 = mulf %16, %17 : f32
|
||||
// CHECK-NEXT: %19 = load %4[%i0] : memref<10xf32>
|
||||
// CHECK-NEXT: %20 = addf %18, %19 : f32
|
||||
// CHECK-NEXT: store %20, %4[%i0] : memref<10xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
return
|
||||
}
|