From 27d067e16451da80f6b53dc90740a2238e3f4ee7 Mon Sep 17 00:00:00 2001
From: MLIR Team
Date: Wed, 16 Jan 2019 09:55:02 -0800
Subject: [PATCH] LoopFusion improvements:

*) Adds support for fusing into consumer loop nests with multiple loads from
   the same memref.
*) Adds support for reducing slice loop trip count by projecting out
   destination loop IVs greater than destination loop depth.
*) Removes dependence on src loop depth and simplifies cost model computation.

PiperOrigin-RevId: 229575126
---
 mlir/include/mlir/Analysis/Utils.h     |  18 +-
 mlir/lib/Analysis/AffineStructures.cpp |  17 +-
 mlir/lib/Analysis/Utils.cpp            |  56 ++--
 mlir/lib/Transforms/LoopFusion.cpp     | 414 +++++++++++++++----------
 mlir/test/Transforms/loop-fusion.mlir  | 150 ++++++++-
 5 files changed, 442 insertions(+), 213 deletions(-)

diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h
index 7cd30ba86abf..4e304067411b 100644
--- a/mlir/include/mlir/Analysis/Utils.h
+++ b/mlir/include/mlir/Analysis/Utils.h
@@ -173,23 +173,23 @@ struct ComputationSliceState {
 /// Returns true on success, false otherwise.
 bool getBackwardComputationSliceState(const MemRefAccess &srcAccess,
                                       const MemRefAccess &dstAccess,
+                                      unsigned dstLoopDepth,
                                       ComputationSliceState *sliceState);
 
 /// Creates a clone of the computation contained in the loop nest surrounding
-/// 'srcAccess', slices the iteration space of the first 'srcLoopDepth' src loop
-/// IVs, and inserts the computation slice at the beginning of the instruction
-/// block of the loop at 'dstLoopDepth' in the loop nest surrounding
-/// 'dstAccess'. Returns the top-level loop of the computation slice on
+/// 'srcOpInst', slices the iteration space of the src loop based on slice
+/// bounds in 'sliceState', and inserts the computation slice at the beginning
+/// of the instruction block of the loop at 'dstLoopDepth' in the loop nest
+/// surrounding 'dstOpInst'. Returns the top-level loop of the computation slice on
 /// success, returns nullptr otherwise.
 // Loop depth is a crucial optimization choice that determines where to
 // materialize the results of the backward slice - presenting a trade-off b/w
 // storage and redundant computation in several cases.
 // TODO(andydavis) Support computation slices with common surrounding loops.
-ForInst *insertBackwardComputationSlice(MemRefAccess *srcAccess,
-                                        MemRefAccess *dstAccess,
-                                        ComputationSliceState *sliceState,
-                                        unsigned srcLoopDepth,
-                                        unsigned dstLoopDepth);
+ForInst *insertBackwardComputationSlice(OperationInst *srcOpInst,
+                                        OperationInst *dstOpInst,
+                                        unsigned dstLoopDepth,
+                                        ComputationSliceState *sliceState);
 
 } // end namespace mlir
 
 #endif // MLIR_ANALYSIS_UTILS_H

diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index bf915dbbf5bb..af9252c279ce 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1101,8 +1101,21 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context,
     (*lbMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr, {});
     (*ubMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {});
   } else {
-    (*lbMaps)[pos] = AffineMap::Null();
-    (*ubMaps)[pos] = AffineMap::Null();
+    // TODO(andydavis, bondhugula) Add support for computing slice bounds
+    // symbolic in the identifiers [num, numIds).
+    auto lbConst = getConstantLowerBound(pos);
+    auto ubConst = getConstantUpperBound(pos);
+    if (lbConst.hasValue() && ubConst.hasValue()) {
+      (*lbMaps)[pos] = AffineMap::get(
+          numMapDims, numMapSymbols,
+          getAffineConstantExpr(lbConst.getValue(), context), {});
+      (*ubMaps)[pos] = AffineMap::get(
+          numMapDims, numMapSymbols,
+          getAffineConstantExpr(ubConst.getValue() + 1, context), {});
+    } else {
+      (*lbMaps)[pos] = AffineMap::Null();
+      (*ubMaps)[pos] = AffineMap::Null();
+    }
   }
   LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos) << ", expr: ");
   LLVM_DEBUG(expr.dump(););

diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 49e1e31f55d5..c003a6413113 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -346,12 +346,13 @@ static Instruction *getInstAtPosition(ArrayRef<unsigned> positions,
   return nullptr;
 }
 
-// Computes memref dependence between 'srcAccess' and 'dstAccess' and uses the
-// dependence constraint system to create AffineMaps with which to adjust the
-// loop bounds of the inserted compution slice so that they are functions of the
-// loop IVs and symbols of the loops surrounding 'dstAccess'.
+// Computes memref dependence between 'srcAccess' and 'dstAccess', projects
+// out any dst loop IVs at depth greater than 'dstLoopDepth', and computes slice
+// bounds in 'sliceState' which represent the src IVs in terms of the dst IVs,
+// symbols and constants.
 bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
                                             const MemRefAccess &dstAccess,
+                                            unsigned dstLoopDepth,
                                             ComputationSliceState *sliceState) {
   FlatAffineConstraints dependenceConstraints;
   if (!checkMemrefAccessDependence(srcAccess, dstAccess, /*loopDepth=*/1,
@@ -364,6 +365,19 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
   getLoopIVs(*srcAccess.opInst, &srcLoopIVs);
   unsigned numSrcLoopIVs = srcLoopIVs.size();
 
+  // Get loop nest surrounding dst operation.
+  SmallVector<ForInst *, 4> dstLoopIVs;
+  getLoopIVs(*dstAccess.opInst, &dstLoopIVs);
+  unsigned numDstLoopIVs = dstLoopIVs.size();
+  if (dstLoopDepth > numDstLoopIVs) {
+    dstAccess.opInst->emitError("invalid destination loop depth");
+    return false;
+  }
+
+  // Project out dimensions other than those up to 'dstLoopDepth'.
+  dependenceConstraints.projectOut(numSrcLoopIVs + dstLoopDepth,
+                                   numDstLoopIVs - dstLoopDepth);
+
   // Set up lower/upper bound affine maps for the slice.
   sliceState->lbs.resize(numSrcLoopIVs, AffineMap::Null());
   sliceState->ubs.resize(numSrcLoopIVs, AffineMap::Null());
@@ -385,12 +399,10 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
   return true;
 }
 
-/// Creates a computation slice of the loop nest surrounding 'srcAccess'
-/// utilizing slice loop bounds in 'sliceState' (for src loops up to
-/// 'srcLoopDepth'), and inserts this slice into loop nest surrounding
-/// 'dstAccess' at loop depth 'dstLoopDepth'. For all loops at loop depth
-/// greater than 'srcLoopDepth' their full loop bounds will be used in the
-/// slice.
+/// Creates a computation slice of the loop nest surrounding 'srcOpInst',
+/// updates the slice loop bounds with any non-null bound maps specified in
+/// 'sliceState', and inserts this slice into the loop nest surrounding
+/// 'dstOpInst' at loop depth 'dstLoopDepth'.
 // TODO(andydavis,bondhugula): extend the slicing utility to compute slices that
 // aren't necessarily a one-to-one relation b/w the source and destination. The
 // relation between the source and destination could be many-to-many in general.
@@ -401,33 +413,27 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
 // solution.
 // TODO(andydavis) Remove dependence on 'srcLoopDepth' here. Instead project
 // out loop IVs we don't care about and produce smaller slice.
-ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
-                                              MemRefAccess *dstAccess,
-                                              ComputationSliceState *sliceState,
-                                              unsigned srcLoopDepth,
-                                              unsigned dstLoopDepth) {
+ForInst *mlir::insertBackwardComputationSlice(
+    OperationInst *srcOpInst, OperationInst *dstOpInst, unsigned dstLoopDepth,
+    ComputationSliceState *sliceState) {
   // Get loop nest surrounding src operation.
   SmallVector<ForInst *, 4> srcLoopIVs;
-  getLoopIVs(*srcAccess->opInst, &srcLoopIVs);
+  getLoopIVs(*srcOpInst, &srcLoopIVs);
   unsigned numSrcLoopIVs = srcLoopIVs.size();
-  if (srcLoopDepth > numSrcLoopIVs) {
-    srcAccess->opInst->emitError("invalid source loop depth");
-    return nullptr;
-  }
 
   // Get loop nest surrounding dst operation.
   SmallVector<ForInst *, 4> dstLoopIVs;
-  getLoopIVs(*dstAccess->opInst, &dstLoopIVs);
+  getLoopIVs(*dstOpInst, &dstLoopIVs);
   unsigned dstLoopIVsSize = dstLoopIVs.size();
   if (dstLoopDepth > dstLoopIVsSize) {
-    dstAccess->opInst->emitError("invalid destination loop depth");
+    dstOpInst->emitError("invalid destination loop depth");
     return nullptr;
   }
 
-  // Find the inst block positions of 'srcAccess->opInst' within 'srcLoopIVs'.
+  // Find the inst block positions of 'srcOpInst' within 'srcLoopIVs'.
   SmallVector<unsigned, 4> positions;
   // TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d.
-  findInstPosition(srcAccess->opInst, srcLoopIVs[0]->getBlock(), &positions);
+  findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);
 
   // Clone src loop nest and insert it at the beginning of the instruction
   // block of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
@@ -451,7 +457,7 @@ ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
   assert(sliceLoopLimit >= sliceSurroundingLoopsSize);
 
   // Update loop bounds for loops in 'sliceLoopNest'.
-  for (unsigned i = 0; i < srcLoopDepth; ++i) {
+  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
     auto *forInst = sliceSurroundingLoops[dstLoopDepth + i];
     if (AffineMap lbMap = sliceState->lbs[i])
       forInst->setLowerBound(sliceState->lbOperands[i], lbMap);

diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 91e8d2946a63..cdd1c77f302a 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -69,17 +69,6 @@ char LoopFusion::passID = 0;
 
 FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }
 
-// FusionCandidate encapsulates source and destination memref access within
-// loop nests which are candidates for loop fusion.
-struct FusionCandidate {
-  // Load or store access within src loop nest to be fused into dst loop nest.
-  MemRefAccess srcAccess;
-  // Load or store access within dst loop nest.
-  MemRefAccess dstAccess;
-  explicit FusionCandidate(OperationInst *src, OperationInst *dst)
-      : srcAccess(MemRefAccess(src)), dstAccess(MemRefAccess(dst)) {}
-};
-
 namespace {
 
 // LoopNestStateCollector walks loop nests and collects load and store
@@ -172,10 +161,27 @@ public:
     return &it->second;
   }
 
+  // Returns true iff there is an edge from node 'srcId' to node 'dstId' for
+  // 'memref'. Returns false otherwise.
+  bool hasEdge(unsigned srcId, unsigned dstId, Value *memref) {
+    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
+      return false;
+    }
+    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
+      return edge.id == dstId && edge.memref == memref;
+    });
+    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
+      return edge.id == srcId && edge.memref == memref;
+    });
+    return hasOutEdge && hasInEdge;
+  }
+
   // Adds an edge from node 'srcId' to node 'dstId' for 'memref'.
   void addEdge(unsigned srcId, unsigned dstId, Value *memref) {
-    outEdges[srcId].push_back({dstId, memref});
-    inEdges[dstId].push_back({srcId, memref});
+    if (!hasEdge(srcId, dstId, memref)) {
+      outEdges[srcId].push_back({dstId, memref});
+      inEdges[dstId].push_back({srcId, memref});
+    }
   }
 
   // Removes an edge from node 'srcId' to node 'dstId' for 'memref'.
@@ -425,10 +431,10 @@ public:
 // inserting a sliced loop nest of known cost into the loop's body.
 // NOTE: this is used to compute the cost of fusing a slice of some loop nest
 // within another loop.
-static uint64_t
-getComputeCost(ForInst *forInst, LoopNestStats *stats,
-               DenseMap<ForInst *, uint64_t> *tripCountOverrideMap,
-               DenseMap<ForInst *, uint64_t> *computeCostMap) {
+static uint64_t getComputeCost(
+    ForInst *forInst, LoopNestStats *stats,
+    llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountOverrideMap,
+    DenseMap<ForInst *, uint64_t> *computeCostMap) {
   // 'opCount' is the total number of operations in one iteration of 'forInst' body
   uint64_t opCount = stats->opCountMap[forInst];
   if (stats->loopMap.count(forInst) > 0) {
@@ -458,17 +464,33 @@ getComputeCost(ForInst *forInst, LoopNestStats *stats,
 }
 
 } // end anonymous namespace
 
+static Optional<int64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
+  assert(lbMap.getNumResults() == 1);
+  assert(ubMap.getNumResults() == 1);
+  assert(lbMap.getNumDims() == ubMap.getNumDims());
+  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
+  // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
+  // ub_expr - lb_expr
+  AffineExpr lbExpr(lbMap.getResult(0));
+  AffineExpr ubExpr(ubMap.getResult(0));
+  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
+                                         lbMap.getNumSymbols());
+  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
+  if (!cExpr)
+    return None;
+  return cExpr.getValue();
+}
+
 // Builds a map 'tripCountMap' from ForInst to constant trip count for loop
 // nest surrounding 'srcAccess' utilizing slice loop bounds in 'sliceState'.
 // Returns true on success, false otherwise (if a non-constant trip count
 // was encountered).
 // TODO(andydavis) Make this work with non-unit step loops.
-static bool
-buildSliceTripCountMap(MemRefAccess *srcAccess,
-                       ComputationSliceState *sliceState,
-                       DenseMap<ForInst *, uint64_t> *tripCountMap) {
+static bool buildSliceTripCountMap(
+    OperationInst *srcOpInst, ComputationSliceState *sliceState,
+    llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountMap) {
   SmallVector<ForInst *, 4> srcLoopIVs;
-  getLoopIVs(*srcAccess->opInst, &srcLoopIVs);
+  getLoopIVs(*srcOpInst, &srcLoopIVs);
   unsigned numSrcLoopIVs = srcLoopIVs.size();
   // Populate map from ForInst -> trip count
   for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
@@ -485,109 +507,166 @@ buildSliceTripCountMap(MemRefAccess *srcAccess,
       }
       return false;
     }
-    // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
-    // ub_expr - lb_expr
-    AffineExpr lbExpr(lbMap.getResult(0));
-    AffineExpr ubExpr(ubMap.getResult(0));
-    auto loopSpanExpr = simplifyAffineExpr(
-        ubExpr - lbExpr, std::max(lbMap.getNumDims(), ubMap.getNumDims()),
-        std::max(lbMap.getNumSymbols(), ubMap.getNumSymbols()));
-    auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
-    if (!cExpr)
+    Optional<int64_t> tripCount = getConstDifference(lbMap, ubMap);
+    if (!tripCount.hasValue())
       return false;
-    (*tripCountMap)[srcLoopIVs[i]] = cExpr.getValue();
+    (*tripCountMap)[srcLoopIVs[i]] = tripCount.getValue();
   }
   return true;
 }
 
-// Returns the maximum loop depth within the source loop nest at which a
-// sliced loop bound is detected in 'sliceState'.
-static unsigned getMaxSrcLoopDepth(unsigned srcLoopDepthLimit,
-                                   ComputationSliceState *sliceState) {
-  unsigned maxSrcPos = 0;
-  for (unsigned i = 0; i < srcLoopDepthLimit; ++i) {
-    if (sliceState->lbs[i] != AffineMap::Null() &&
-        sliceState->ubs[i] != AffineMap::Null()) {
-      maxSrcPos = std::max(maxSrcPos, i);
-    }
+// Removes load operations from 'srcLoads' which operate on 'memref', and
+// adds them to 'dstLoads'.
+static void
+moveLoadsAccessingMemrefTo(Value *memref,
+                           SmallVectorImpl<OperationInst *> *srcLoads,
+                           SmallVectorImpl<OperationInst *> *dstLoads) {
+  dstLoads->clear();
+  SmallVector<OperationInst *, 4> srcLoadsToKeep;
+  for (auto *load : *srcLoads) {
+    if (load->cast<LoadOp>()->getMemRef() == memref)
+      dstLoads->push_back(load);
+    else
+      srcLoadsToKeep.push_back(load);
   }
-  return maxSrcPos + 1;
+  srcLoads->swap(srcLoadsToKeep);
 }
 
-// Returns the minimum loop depth within the destination loop nest at which the
-// computation slice can be inserted (based on the destination loop IVs that
-// the source slice actually depends on / is a function of).
-static unsigned getMinDstLoopDepth(unsigned srcLoopDepth,
-                                   ComputationSliceState *sliceState) {
-  // Record in 'maxDstLoopDepth' the largest position (+1) of a dst loop nest
-  // IV, which is used in a sliced loop bound in the src loop nest.
-  unsigned maxDstLoopDepth = 0;
-  for (unsigned i = 0; i < srcLoopDepth; ++i) {
-    if (AffineMap lbMap = sliceState->lbs[i]) {
-      lbMap.walkExprs([&](AffineExpr expr) {
-        if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
-          maxDstLoopDepth =
-              std::max(maxDstLoopDepth, dimExpr.getPosition() + 1);
-        }
-      });
-    }
-    if (AffineMap ubMap = sliceState->ubs[i]) {
-      ubMap.walkExprs([&](AffineExpr expr) {
-        if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
-          maxDstLoopDepth =
-              std::max(maxDstLoopDepth, dimExpr.getPosition() + 1);
-        }
-      });
-    }
+// Returns the innermost common loop depth for the set of operations in 'ops'.
+static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
+  unsigned numOps = ops.size();
+  assert(numOps > 0);
+
+  std::vector<SmallVector<ForInst *, 4>> loops(numOps);
+  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
+  for (unsigned i = 0; i < numOps; ++i) {
+    getLoopIVs(*ops[i], &loops[i]);
+    loopDepthLimit =
+        std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
   }
-  return maxDstLoopDepth;
+
+  unsigned loopDepth = 0;
+  for (unsigned d = 0; d < loopDepthLimit; ++d) {
+    unsigned i;
+    for (i = 1; i < numOps; ++i) {
+      if (loops[i - 1][d] != loops[i][d]) {
+        break;
+      }
+    }
+    if (i != numOps)
+      break;
+    ++loopDepth;
+  }
+  return loopDepth;
 }
 
-// Checks the profitability of fusion candidate 'candidate'. Returns true if it
-// profitable to fuse the candidate loop nests. Returns false otherwise.
+// Returns true if 'map' is a single-result constant or a single-result
+// dim expr whose corresponding loop IV in 'operands' has a zero constant
+// lower bound.
+static bool hasZeroMinValue(AffineMap map, ArrayRef<Value *> operands) {
+  if (map.isSingleConstant() && map.getSingleConstantResult() == 0)
+    return true;
+  if (map.getNumResults() != 1 || !map.getResult(0).isa<AffineDimExpr>())
+    return false;
+  // Get operand position of single dim expr result.
+  unsigned pos = map.getResult(0).cast<AffineDimExpr>().getPosition();
+  // Check if loop IV at 'pos' has zero constant lower bound.
+  auto *operand = operands[pos];
+  assert(isa<ForInst>(operand));
+  auto *forInst = cast<ForInst>(operand);
+  return forInst->hasConstantLowerBound() &&
+         forInst->getConstantLowerBound() == 0;
+}
+// Computes the slice bound union of 'sliceStateA' and 'sliceStateB', storing
+// the result in 'sliceStateB'.
+// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA'
+// and 'sliceStateB' are aligned.
+// Specifically, when taking the union of overlapping intervals, it assumes
+// that both intervals start at zero. Support needs to be added to take into
+// account interval start offset when computing the union.
+// TODO(andydavis) Move this function to an analysis library.
+static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
+                               ComputationSliceState *sliceStateB) {
+  assert(sliceStateA.lbs.size() == sliceStateB->lbs.size());
+  assert(sliceStateA.ubs.size() == sliceStateB->ubs.size());
+
+  for (unsigned i = 0, e = sliceStateA.lbs.size(); i < e; ++i) {
+    AffineMap lbMapA = sliceStateA.lbs[i];
+    AffineMap ubMapA = sliceStateA.ubs[i];
+    if (lbMapA == AffineMap::Null()) {
+      assert(ubMapA == AffineMap::Null());
+      continue;
+    }
+    assert(ubMapA != AffineMap::Null());
+    // Validate that constant lower bounds are aligned at zero.
+    if (!hasZeroMinValue(lbMapA, sliceStateA.lbOperands[i]))
+      return false;
+
+    AffineMap lbMapB = sliceStateB->lbs[i];
+    AffineMap ubMapB = sliceStateB->ubs[i];
+    if (lbMapB == AffineMap::Null()) {
+      assert(ubMapB == AffineMap::Null());
+      // Union 'sliceStateB' does not have a bound for 'i' so copy from A.
+      sliceStateB->lbs[i] = lbMapA;
+      sliceStateB->ubs[i] = ubMapA;
+      continue;
+    }
+    // Validate that constant lower bounds are aligned at zero.
+    if (!hasZeroMinValue(lbMapB, sliceStateB->lbOperands[i]))
+      return false;
+
+    // Add bound with the largest trip count to union.
+    Optional<int64_t> tripCountA = getConstDifference(lbMapA, ubMapA);
+    Optional<int64_t> tripCountB = getConstDifference(lbMapB, ubMapB);
+    if (!tripCountA.hasValue() || !tripCountB.hasValue())
+      return false;
+    // TODO(andydavis) Change this code to take the min across all lower bounds
+    // and max across all upper bounds for each dimension. This code can fail
+    // for cases where a unique min or max could not be statically determined.
+    if (tripCountA.getValue() > tripCountB.getValue()) {
+      sliceStateB->lbs[i] = lbMapA;
+      sliceStateB->ubs[i] = ubMapA;
+    }
+  }
+  return true;
+}
+
+// Checks the profitability of fusing a backwards slice of the loop nest
+// surrounding 'srcOpInst' into the loop nest surrounding 'dstOpInsts'.
+// Returns true if it is profitable to fuse the candidate loop nests. Returns
+// false otherwise.
 // The profitability model executes the following steps:
-// *) Computes the backward computation slice at 'candidate.srcAccess'. This
-//    computation slice of the loop nest surrounding 'candidate.srcAccess' is
+// *) Computes the backward computation slice at 'srcOpInst'. This
+//    computation slice of the loop nest surrounding 'srcOpInst' is
 //    represented by modified src loop bounds in 'sliceState', which are
-//    functions of loop IVs in the loop nest surrounding 'candidate.dstAccess'.
+//    functions of loop IVs in the loop nest surrounding the ops in
+//    'dstOpInsts'.
 // *) Computes the cost of unfused src/dst loop nests (currently the cost of a
 //    loop nest is the total number of dynamic operation instances in the loop
 //    nest).
 // *) Computes the cost of fusing a slice of the src loop nest into the dst
-//    loop nest at various values of src/dst loop depth, attempting to fuse
-//    the biggest compution slice (max src loop depth) at the maximal dst loop
-//    depth (closest to the load) to minimize reuse distance and opportunity for
-//    subsequent load/store forwarding.
-// NOTE: 'srcLoopDepth' refers to the loop depth within the source loop nest
-// at which we slice the loops bounds (all src loops below this depth will
-// utilize full loop bounds).
+//    loop nest at various values of dst loop depth, attempting to fuse
+//    the largest computation slice at the maximal dst loop depth (closest to
+//    the load) to minimize reuse distance and potentially enable subsequent
+//    load/store forwarding.
+// NOTE: If the dst loop nest includes multiple loads in 'dstOpInsts' for
+// the same memref as is written by 'srcOpInst', then the union of slice
+// loop bounds is used to compute the slice and associated slice cost.
 // NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
 // nest, at which the src computation slice is inserted/fused.
-// NOTE: We attempt to maximize the source loop depth, but there are cases
-// where a particular setting for 'dstLoopNest' might fused an unsliced
+// NOTE: We attempt to maximize the dst loop depth, but there are cases
+// where a particular setting for 'dstLoopDepth' might fuse an unsliced
 // loop (within the src computation slice) at a depth which results in
 // excessive recomputation (see unit tests for examples).
 // *) Compares the total cost of the unfused loop nests to the min cost fused
 //    loop nest computed in the previous step, and returns true if the latter
 //    is lower.
-static bool isFusionProfitable(FusionCandidate *candidate,
+static bool isFusionProfitable(OperationInst *srcOpInst,
+                               ArrayRef<OperationInst *> dstOpInsts,
                                ComputationSliceState *sliceState,
-                               unsigned *srcLoopDepth, unsigned *dstLoopDepth) {
-  // Compute backward computation slice state: src IV bounds w.r.t dst IVs, etc.
-  if (!mlir::getBackwardComputationSliceState(
-          candidate->srcAccess, candidate->dstAccess, sliceState)) {
-    return false;
-  }
-
-  // Build trip count map for src loops with sliced loop bounds in 'sliceState'.
-  DenseMap<ForInst *, uint64_t> sliceTripCountMap;
-  if (!buildSliceTripCountMap(&candidate->srcAccess, sliceState,
-                              &sliceTripCountMap))
-    return false;
-
+                               unsigned *dstLoopDepth) {
   // Compute cost of sliced and unsliced src loop nest.
   SmallVector<ForInst *, 4> srcLoopIVs;
-  getLoopIVs(*candidate->srcAccess.opInst, &srcLoopIVs);
+  getLoopIVs(*srcOpInst, &srcLoopIVs);
   unsigned numSrcLoopIVs = srcLoopIVs.size();
 
   // Walk src loop nest and collect stats.
@@ -600,8 +679,7 @@ static bool isFusionProfitable(FusionCandidate *candidate,
 
   // Compute cost of dst loop nest.
   SmallVector<ForInst *, 4> dstLoopIVs;
-  getLoopIVs(*candidate->dstAccess.opInst, &dstLoopIVs);
-  unsigned numDstLoopIVs = dstLoopIVs.size();
+  getLoopIVs(*dstOpInsts[0], &dstLoopIVs);
 
   LoopNestStats dstLoopNestStats;
   LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
@@ -610,51 +688,60 @@ static bool isFusionProfitable(FusionCandidate *candidate,
   if (dstStatsCollector.hasLoopWithNonConstTripCount)
     return false;
 
-  // Search for min cost values for 'srcLoopDepth' and 'dstLoopDepth'.
-  // This search is O(n^2) where 'n' is very small (eg. six).
-  // TODO(andydavis) Consider a solution where we just iteration through
-  // dstLoopDepth possibilities and project out IVs we do not need (remove
-  // dependence on 'srcLoopDepth'.
-  DenseMap<ForInst *, uint64_t> tripCountMap;
-  DenseMap<ForInst *, uint64_t> computeCostMap;
-  unsigned maxSrcLoopDepth = getMaxSrcLoopDepth(numSrcLoopIVs, sliceState);
+  // Compute the innermost common loop depth for ops in 'dstOpInsts'.
+  unsigned maxDstLoopDepth = getInnermostCommonLoopDepth(dstOpInsts);
+  if (maxDstLoopDepth == 0)
+    return false;
+
+  // Search for min cost value for 'dstLoopDepth'. At each value of
+  // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
+  // bounds between 'srcOpInst' and each op in 'dstOpInsts' (taking the union
+  // of these bounds). Next, the union of slice bounds is used to calculate
+  // the cost of the slice and the cost of the slice inserted into the dst
+  // loop nest at 'dstLoopDepth'.
   unsigned minFusedLoopNestComputeCost = std::numeric_limits<unsigned>::max();
-  unsigned bestSrcLoopDepth;
   unsigned bestDstLoopDepth;
-  for (unsigned i = maxSrcLoopDepth; i >= 1; --i) {
-    // Compute minDstLoopDepth based on dst loop IVs used in slice loop bounds.
-    unsigned minDstLoopDepth = getMinDstLoopDepth(i, sliceState);
-    assert(minDstLoopDepth <= numDstLoopIVs);
-    if (minDstLoopDepth == 0) {
-      // TODO(andydavis) Support inserting computation slices at top-level.
-      continue;
-    }
-    // Copy elements from slice trip count map up to src loop depth 'i'.
-    tripCountMap.clear();
-    for (unsigned k = 0; k < i; ++k) {
-      auto *forInst = srcLoopIVs[k];
-      auto it = sliceTripCountMap.find(forInst);
-      if (it != sliceTripCountMap.end()) {
-        tripCountMap[forInst] = it->second;
-      }
+  SmallVector<ComputationSliceState, 4> sliceStates;
+  sliceStates.resize(maxDstLoopDepth);
+
+  llvm::SmallDenseMap<ForInst *, uint64_t, 8> sliceTripCountMap;
+  DenseMap<ForInst *, uint64_t> computeCostMap;
+  for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
+    MemRefAccess srcAccess(srcOpInst);
+    // Handle the common case of one dst load without a copy.
+    if (!mlir::getBackwardComputationSliceState(
+            srcAccess, MemRefAccess(dstOpInsts[0]), i, &sliceStates[i - 1]))
+      return false;
+    // Compute the union of slice bounds for all ops in 'dstOpInsts'.
+    for (int j = 1, e = dstOpInsts.size(); j < e; ++j) {
+      MemRefAccess dstAccess(dstOpInsts[j]);
+      ComputationSliceState tmpSliceState;
+      if (!mlir::getBackwardComputationSliceState(srcAccess, dstAccess, i,
+                                                  &tmpSliceState))
+        return false;
+      // Compute slice bound union of 'tmpSliceState' and 'sliceStates[i - 1]'.
+      getSliceBoundUnion(tmpSliceState, &sliceStates[i - 1]);
    }
+    // Build trip count map for computation slice.
+    sliceTripCountMap.clear();
+    if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
+                                &sliceTripCountMap))
+      return false;
+
     // Compute op instance count for the src loop nest with iteration slicing.
     uint64_t sliceComputeCost =
-        getComputeCost(srcLoopIVs[0], &srcLoopNestStats, &tripCountMap,
+        getComputeCost(srcLoopIVs[0], &srcLoopNestStats, &sliceTripCountMap,
                        /*computeCostMap=*/nullptr);
 
-    for (unsigned j = numDstLoopIVs; j >= minDstLoopDepth; --j) {
-      // Compute cost of fusion for these values of 'i' and 'j'.
-      computeCostMap.clear();
-      computeCostMap[dstLoopIVs[j - 1]] = sliceComputeCost;
-      uint64_t fusedLoopNestComputeCost =
-          getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
-                         /*tripCountOverrideMap=*/nullptr, &computeCostMap);
-      if (fusedLoopNestComputeCost < minFusedLoopNestComputeCost) {
-        minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
-        bestSrcLoopDepth = i;
-        bestDstLoopDepth = j;
-      }
+    // Compute cost of fusion at dst loop depth 'i'.
+    computeCostMap.clear();
+    computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost;
+    uint64_t fusedLoopNestComputeCost =
+        getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
+                       /*tripCountOverrideMap=*/nullptr, &computeCostMap);
+    if (fusedLoopNestComputeCost < minFusedLoopNestComputeCost) {
+      minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
+      bestDstLoopDepth = i;
     }
   }
 
@@ -668,7 +755,6 @@ static bool isFusionProfitable(FusionCandidate *candidate,
                      /*computeCostMap=*/nullptr);
 
   LLVM_DEBUG(llvm::dbgs() << "LoopFusion statistics "
-                          << " bestSrcLoopDepth: " << bestSrcLoopDepth
                           << " bestDstLoopDepth: " << bestDstLoopDepth
                           << " srcLoopNestCost: " << srcLoopNestCost
                           << " dstLoopNestCost: " << dstLoopNestCost
@@ -680,25 +766,23 @@ static bool isFusionProfitable(FusionCandidate *candidate,
   // for load/store forwarding in cost model.
   if (minFusedLoopNestComputeCost > srcLoopNestCost + dstLoopNestCost)
     return false;
-  // Set src/dstLoopDepth based on best values from search.
-  *srcLoopDepth = bestSrcLoopDepth;
+  // Update return parameter 'sliceState' with 'bestSliceState'.
+  ComputationSliceState *bestSliceState = &sliceStates[bestDstLoopDepth - 1];
+  sliceState->lbs = bestSliceState->lbs;
+  sliceState->ubs = bestSliceState->ubs;
+  sliceState->lbOperands = bestSliceState->lbOperands;
+  sliceState->ubOperands = bestSliceState->ubOperands;
+  // Set dstLoopDepth based on best value from search.
   *dstLoopDepth = bestDstLoopDepth;
-  // Update 'sliceState' bounds based on computed 'srcLoopDepth':
-  // *) Canonicalize affine map now that 'srcLoopDepth' has been chosen.
-  // *) Replace slice bound maps at depth > 'srcLoopDepth' withAffineMap::Null()
+  // Canonicalize slice bound affine maps.
   for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
-    if (i < bestSrcLoopDepth) {
-      if (sliceState->lbs[i] != AffineMap::Null()) {
-        canonicalizeMapAndOperands(&sliceState->lbs[i],
-                                   &sliceState->lbOperands[i]);
-      }
-      if (sliceState->ubs[i] != AffineMap::Null()) {
-        canonicalizeMapAndOperands(&sliceState->ubs[i],
-                                   &sliceState->ubOperands[i]);
-      }
-    } else {
-      sliceState->lbs[i] = AffineMap::Null();
-      sliceState->ubs[i] = AffineMap::Null();
+    if (sliceState->lbs[i] != AffineMap::Null()) {
+      canonicalizeMapAndOperands(&sliceState->lbs[i],
+                                 &sliceState->lbOperands[i]);
+    }
+    if (sliceState->ubs[i] != AffineMap::Null()) {
+      canonicalizeMapAndOperands(&sliceState->ubs[i],
+                                 &sliceState->ubOperands[i]);
    }
  }
  return true;
@@ -767,12 +851,12 @@ public:
        continue;

      SmallVector<OperationInst *, 4> loads = dstNode->loads;
+      SmallVector<OperationInst *, 4> dstLoadOpInsts;
      while (!loads.empty()) {
-        auto *dstLoadOpInst = loads.pop_back_val();
-        auto *memref = dstLoadOpInst->cast<LoadOp>()->getMemRef();
-        // Skip 'dstLoadOpInst' if multiple loads to 'memref' in 'dstNode'.
-        if (dstNode->getLoadOpCount(memref) != 1)
-          continue;
+        // Get memref of load on top of the stack.
+        auto *memref = loads.back()->cast<LoadOp>()->getMemRef();
+        // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
+        moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
        // Skip if no input edges along which to fuse.
        if (mdg->inEdges.count(dstId) == 0)
          continue;
@@ -801,19 +885,15 @@ public:
            continue;
          // Get unique 'srcNode' store op.
          auto *srcStoreOpInst = srcNode->stores.front();
-          // Build fusion candidate out of 'srcStoreOpInst' and 'dstLoadOpInst'.
-          FusionCandidate candidate(srcStoreOpInst, dstLoadOpInst);
          // Check if fusion would be profitable.
-          unsigned srcLoopDepth;
          unsigned dstLoopDepth;
          mlir::ComputationSliceState sliceState;
-          if (!isFusionProfitable(&candidate, &sliceState, &srcLoopDepth,
+          if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, &sliceState,
                                  &dstLoopDepth))
            continue;
          // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
          auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
-              &candidate.srcAccess, &candidate.dstAccess, &sliceState,
-              srcLoopDepth, dstLoopDepth);
+              srcStoreOpInst, dstLoadOpInsts[0], dstLoopDepth, &sliceState);
          if (sliceLoopNest != nullptr) {
            // Remove edges between 'srcNode' and 'dstNode' and remove 'srcNode'
            mdg->updateEdgesAndRemoveSrcNode(srcNode->id, dstNode->id);

diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index 525c9d63ad08..61335be227f1 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -95,13 +95,18 @@ func @should_fuse_loop_nests_with_shifts() {
    }
  }

+  // Fusing the src loop nest at dst loop depth 1 is less expensive than
+  // fusing at dst loop depth 2, because at dst loop depth 1 we are able to
+  // reduce the trip count of the %i1 loop by one (the dst loop never reads
+  // the last element written by the src loop).
  // CHECK:      for %i0 = 0 to 10 {
-  // CHECK-NEXT:   for %i1 = 0 to 10 {
-  // CHECK-NEXT:     %1 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i0)
-  // CHECK-NEXT:     %2 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i1)
-  // CHECK-NEXT:     %3 = affine_apply [[MAP_SHIFT_BY_ONE]](%1, %2)
-  // CHECK-NEXT:     store %cst, %0[%3#0, %3#1] : memref<10x10xf32>
-  // CHECK-NEXT:     %4 = load %0[%i0, %i1] : memref<10x10xf32>
+  // CHECK-NEXT:   %1 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i0)
+  // CHECK-NEXT:   for %i1 = 0 to 9 {
+  // CHECK-NEXT:     %2 = affine_apply [[MAP_SHIFT_BY_ONE]](%1, %i1)
+  // CHECK-NEXT:     store %cst, %0[%2#0, %2#1] : memref<10x10xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT:   for %i2 = 0 to 10 {
+  // CHECK-NEXT:     %3 = load %0[%i0, %i2] : memref<10x10xf32>
  // CHECK-NEXT:   }
  // CHECK-NEXT: }
  // CHECK-NEXT: return
@@ -849,6 +854,7 @@ func @should_fuse_src_depth1_at_dst_depth2() {
}

// -----
+// CHECK: #map0 = ()[s0] -> (s0)

// CHECK-LABEL: func @fusion_at_depth0_not_currently_supported
func @fusion_at_depth0_not_currently_supported() {
@@ -862,10 +868,9 @@ func @fusion_at_depth0_not_currently_supported() {
    %1 = load %0[%c0] : memref<10xf32>
  }
  // CHECK:for %i0 = 0 to 10 {
-  // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: for %i1 = 0 to 10 {
-  // CHECK-NEXT: %1 = load %0[%c0] : memref<10xf32>
+  // CHECK-NEXT: %1 = affine_apply #map0()[%c0]
+  // CHECK-NEXT: store %cst, %0[%1] : memref<10xf32>
+  // CHECK-NEXT: %2 = load %0[%c0] : memref<10xf32>
  // CHECK-NEXT: }
  // CHECK-NEXT: return
  return
@@ -977,3 +982,128 @@ func @should_fuse_deep_loop_nests() {
  // CHECK-NEXT: return
  return
}
+
+// -----
+// CHECK: #map0 = (d0) -> (d0)
+
+// CHECK-LABEL: func @should_fuse_at_depth1_and_reduce_slice_trip_count
+func @should_fuse_at_depth1_and_reduce_slice_trip_count() {
+  %a = alloc() : memref<4x256xf32>
+  %b = alloc() : memref<4x256xf32>
+
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+
+  for %i0 = 0 to 4 {
+    for %i1 = 0 to 256 {
+      %v0 = load %b[%i0, %i1] : memref<4x256xf32>
+    }
+    for %i2 = 0 to 256 {
+      store %cf0, %a[%i0, %i2] : memref<4x256xf32>
+    }
+  }
+
+  for %d0 = 0 to 4 {
+    for %d1 = 0 to 16 {
+      %v1 = load %a[%d0, %d1] : memref<4x256xf32>
+    }
+  }
+  // The cost of fusing at depth 2 is greater than the cost of fusing at
+  // depth 1 for two reasons:
+  // 1) Inserting the unsliceable src loop %i1 higher up in the nest (at
+  //    depth 1) avoids redundant computation, reducing cost.
+  // 2) Even when the sliceable src loop %i2 is inserted at depth 1, its trip
+  //    count can still be reduced to 16 (from 256), reducing cost.
+  // CHECK:      for %i0 = 0 to 4 {
+  // CHECK-NEXT:   %2 = affine_apply #map0(%i0)
+  // CHECK-NEXT:   for %i1 = 0 to 256 {
+  // CHECK-NEXT:     %3 = load %1[%2, %i1] : memref<4x256xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT:   for %i2 = 0 to 16 {
+  // CHECK-NEXT:     store %cst, %0[%2, %i2] : memref<4x256xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT:   for %i3 = 0 to 16 {
+  // CHECK-NEXT:     %4 = load %0[%i0, %i3] : memref<4x256xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @should_fuse_at_depth1_with_trip_count_20
+func @should_fuse_at_depth1_with_trip_count_20() {
+  %a = alloc() : memref<100xf32>
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+
+  for %i0 = 0 to 100 {
+    store %cf0, %a[%i0]: memref<100xf32>
+  }
+
+  for %i1 = 0 to 5 {
+    for %i2 = 0 to 10 {
+      %v0 = load %a[%i2]: memref<100xf32>
+    }
+    for %i3 = 0 to 10 {
+      for %i4 = 0 to 20 {
+        %v1 = load %a[%i4]: memref<100xf32>
+      }
+    }
+  }
+  // CHECK:      for %i0 = 0 to 5 {
+  // CHECK-NEXT:   for %i1 = 0 to 20 {
+  // CHECK-NEXT:     store %cst, %0[%i1] : memref<100xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT:   for %i2 = 0 to 10 {
+  // CHECK-NEXT:     %1 = load %0[%i2] : memref<100xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT:   for %i3 = 0 to 10 {
+  // CHECK-NEXT:     for %i4 = 0 to 20 {
+  // CHECK-NEXT:       %2 = load %0[%i4] : memref<100xf32>
+  // CHECK-NEXT:     }
+  // CHECK-NEXT:   }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @should_fuse_at_depth1_with_trip_count_19
+func @should_fuse_at_depth1_with_trip_count_19() {
+  %a = alloc() : memref<100xf32>
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+
+  for %i0 = 0 to 100 {
+    store %cf0, %a[%i0]: memref<100xf32>
+  }
+
+  for %i1 = 0 to 5 {
+    for %i2 = 0 to 19 {
+      %v0 = load %a[%i2]: memref<100xf32>
+    }
+    for %i3 = 0 to 10 {
+      for %i4 = 0 to 10 {
+        %v1 = load %a[%i4]: memref<100xf32>
+      }
+    }
+  }
+  // CHECK:      for %i0 = 0 to 5 {
+  // CHECK-NEXT:   for %i1 = 0 to 19 {
+  // CHECK-NEXT:     store %cst, %0[%i1] : memref<100xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT:   for %i2 = 0 to 19 {
+  // CHECK-NEXT:     %1 = load %0[%i2] : memref<100xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT:   for %i3 = 0 to 10 {
+  // CHECK-NEXT:     for %i4 = 0 to 10 {
+  // CHECK-NEXT:       %2 = load %0[%i4] : memref<100xf32>
+  // CHECK-NEXT:     }
+  // CHECK-NEXT:   }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+  return
+}
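
The union rule implemented by getSliceBoundUnion in this patch reduces, per dimension, to keeping the candidate interval with the larger constant trip count once both lower bounds are known to start at zero. A minimal standalone sketch of that rule, in plain C++ rather than the pass's MLIR types (SliceBounds and sliceBoundUnion are hypothetical names, illustrative only):

    #include <algorithm>
    #include <cstdint>
    #include <optional>
    #include <vector>

    // Models each dimension's slice bound as an optional [0, tripCount)
    // interval, mirroring the zero-aligned-lower-bound assumption that
    // getSliceBoundUnion checks via hasZeroMinValue.
    using SliceBounds = std::vector<std::optional<int64_t>>;

    // Folds the bounds of 'a' into 'b' dimension by dimension, as the pass
    // folds each dst load's slice into the accumulated union.
    static bool sliceBoundUnion(const SliceBounds &a, SliceBounds *b) {
      if (a.size() != b->size())
        return false;
      for (size_t i = 0; i < a.size(); ++i) {
        if (!a[i])
          continue; // 'a' leaves dimension 'i' unsliced; keep b's bound.
        if (!(*b)[i]) {
          (*b)[i] = a[i]; // 'b' has no bound for 'i'; copy from 'a'.
          continue;
        }
        // Both intervals start at zero, so the union is the larger trip count.
        (*b)[i] = std::max(*a[i], *(*b)[i]);
      }
      return true;
    }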
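The depth search in isFusionProfitable is capped by the innermost loop that all dst loads share. The longest-common-prefix computation behind getInnermostCommonLoopDepth can be sketched self-containedly as follows (loops represented by integer IDs, outermost first; names hypothetical):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Returns the number of leading loops shared by every chain, i.e. the
    // maximum depth at which all ops still sit inside the same loop nest.
    static size_t innermostCommonLoopDepth(
        const std::vector<std::vector<int>> &loopChains) {
      if (loopChains.empty())
        return 0;
      size_t limit = loopChains[0].size();
      for (const auto &chain : loopChains)
        limit = std::min(limit, chain.size());
      for (size_t d = 0; d < limit; ++d)
        for (const auto &chain : loopChains)
          if (chain[d] != loopChains[0][d])
            return d; // First mismatch: the common prefix ends here.
      return limit;
    }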
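Finally, the cost comparison rests on getComputeCost, which scales a loop body's operation count by its (possibly overridden) trip count and lets computeCostMap inject the cost of a fused slice into a chosen loop's body. A simplified model of that recursion, assuming constant trip counts (LoopModel and both maps are hypothetical stand-ins for the pass's LoopNestStats and override maps):

    #include <cstdint>
    #include <map>
    #include <vector>

    struct LoopModel {
      int64_t tripCount; // Constant trip count of this loop.
      int64_t opCount;   // Non-loop operations in one iteration of the body.
      std::vector<const LoopModel *> inner; // Directly nested loops.
    };

    // Cost of one full execution of 'loop'. 'tripCountOverride' plays the
    // role of the slice trip count map; 'extraBodyCost' plays the role of
    // computeCostMap, adding a fused slice's cost into a loop's body.
    static int64_t computeCost(
        const LoopModel *loop,
        const std::map<const LoopModel *, int64_t> &tripCountOverride,
        const std::map<const LoopModel *, int64_t> &extraBodyCost) {
      int64_t bodyCost = loop->opCount;
      for (const LoopModel *child : loop->inner)
        bodyCost += computeCost(child, tripCountOverride, extraBodyCost);
      auto extra = extraBodyCost.find(loop);
      if (extra != extraBodyCost.end())
        bodyCost += extra->second;
      auto tc = tripCountOverride.find(loop);
      int64_t trips =
          (tc != tripCountOverride.end()) ? tc->second : loop->tripCount;
      return bodyCost * trips;
    }

Under this model, the search in isFusionProfitable amounts to evaluating computeCost on the dst nest once per candidate depth, with the slice's cost attached to the loop at that depth, and keeping the depth with the smallest total.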