LoopFusion improvements:

*) Adds support for fusing into consumer loop nests with multiple loads from
   the same memref.
*) Adds support for reducing slice loop trip count by projecting out
   destination loop IVs greater than destination loop depth.
*) Removes dependence on src loop depth and simplifies cost model computation.

PiperOrigin-RevId: 229575126
commit 27d067e164 (parent 9d4bb57189)
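The second bullet is the key mechanical change: slice bounds are computed in a dependence constraint system whose identifiers are laid out as [src IVs | dst IVs | ...], and every dst IV deeper than the chosen insertion depth is projected out. Below is a minimal standalone C++ sketch of just that index arithmetic (illustrative names, not the MLIR API), mirroring the projectOut(numSrcLoopIVs + dstLoopDepth, numDstLoopIVs - dstLoopDepth) call added later in this diff:

#include <cassert>
#include <utility>

// Identifiers are laid out as [src IVs | dst IVs | symbols]. Given the
// destination insertion depth, the dst IVs deeper than that depth are
// projected out so slice bounds depend only on the outer dst IVs.
std::pair<unsigned, unsigned>
projectionRange(unsigned numSrcLoopIVs, unsigned numDstLoopIVs,
                unsigned dstLoopDepth) {
  assert(dstLoopDepth <= numDstLoopIVs);
  unsigned start = numSrcLoopIVs + dstLoopDepth; // first eliminated id
  unsigned count = numDstLoopIVs - dstLoopDepth; // number of ids eliminated
  return {start, count};
}

int main() {
  // E.g. 2 src IVs, 3 dst IVs, inserting at dst depth 1: ids 3..4 go away.
  auto [start, count] = projectionRange(2, 3, 1);
  assert(start == 3 && count == 2);
}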
@@ -173,23 +173,23 @@ struct ComputationSliceState {
 /// Returns true on success, false otherwise.
 bool getBackwardComputationSliceState(const MemRefAccess &srcAccess,
                                       const MemRefAccess &dstAccess,
+                                      unsigned dstLoopDepth,
                                       ComputationSliceState *sliceState);

 /// Creates a clone of the computation contained in the loop nest surrounding
-/// 'srcAccess', slices the iteration space of the first 'srcLoopDepth' src loop
-/// IVs, and inserts the computation slice at the beginning of the instruction
-/// block of the loop at 'dstLoopDepth' in the loop nest surrounding
-/// 'dstAccess'. Returns the top-level loop of the computation slice on
+/// 'srcOpInst', slices the iteration space of src loop based on slice bounds
+/// in 'sliceState', and inserts the computation slice at the beginning of the
+/// instruction block of the loop at 'dstLoopDepth' in the loop nest surrounding
+/// 'dstOpInst'. Returns the top-level loop of the computation slice on
 /// success, returns nullptr otherwise.
 // Loop depth is a crucial optimization choice that determines where to
 // materialize the results of the backward slice - presenting a trade-off b/w
 // storage and redundant computation in several cases.
 // TODO(andydavis) Support computation slices with common surrounding loops.
-ForInst *insertBackwardComputationSlice(MemRefAccess *srcAccess,
-                                        MemRefAccess *dstAccess,
-                                        ComputationSliceState *sliceState,
-                                        unsigned srcLoopDepth,
-                                        unsigned dstLoopDepth);
+ForInst *insertBackwardComputationSlice(OperationInst *srcOpInst,
+                                        OperationInst *dstOpInst,
+                                        unsigned dstLoopDepth,
+                                        ComputationSliceState *sliceState);

 } // end namespace mlir

 #endif // MLIR_ANALYSIS_UTILS_H
@@ -1101,8 +1101,21 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context,
       (*lbMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr, {});
       (*ubMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {});
     } else {
-      (*lbMaps)[pos] = AffineMap::Null();
-      (*ubMaps)[pos] = AffineMap::Null();
+      // TODO(andydavis, bondhugula) Add support for computing slice bounds
+      // symbolic in the identifiers [num, numIds).
+      auto lbConst = getConstantLowerBound(pos);
+      auto ubConst = getConstantUpperBound(pos);
+      if (lbConst.hasValue() && ubConst.hasValue()) {
+        (*lbMaps)[pos] = AffineMap::get(
+            numMapDims, numMapSymbols,
+            getAffineConstantExpr(lbConst.getValue(), context), {});
+        (*ubMaps)[pos] = AffineMap::get(
+            numMapDims, numMapSymbols,
+            getAffineConstantExpr(ubConst.getValue() + 1, context), {});
+      } else {
+        (*lbMaps)[pos] = AffineMap::Null();
+        (*ubMaps)[pos] = AffineMap::Null();
+      }
     }
     LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos) << ", expr: ");
     LLVM_DEBUG(expr.dump(););
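The new fallback emits constant bound maps when a position has known constant bounds. A standalone C++ model of just that decision (hypothetical helper, not the MLIR API), assuming the half-open [lb, ub + 1) map convention visible in the `expr + 1` upper-bound line above:

#include <cassert>
#include <optional>
#include <utility>

using ConstantBounds = std::pair<std::optional<long>, std::optional<long>>;

// Returns {lbMap, ubMap} as constants, or empty optionals (AffineMap::Null()
// in the real code) when either constant bound is unknown.
ConstantBounds makeConstantSliceBounds(std::optional<long> lbConst,
                                       std::optional<long> ubConst) {
  if (lbConst && ubConst)
    return {*lbConst, *ubConst + 1}; // inclusive ub -> exclusive map bound
  return {std::nullopt, std::nullopt};
}

int main() {
  auto [lb, ub] = makeConstantSliceBounds(0, 15);
  assert(*lb == 0 && *ub == 16); // the sliced loop runs over [0, 16)
}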
@@ -346,12 +346,13 @@ static Instruction *getInstAtPosition(ArrayRef<unsigned> positions,
   return nullptr;
 }

-// Computes memref dependence between 'srcAccess' and 'dstAccess' and uses the
-// dependence constraint system to create AffineMaps with which to adjust the
-// loop bounds of the inserted compution slice so that they are functions of the
-// loop IVs and symbols of the loops surrounding 'dstAccess'.
+// Computes memref dependence between 'srcAccess' and 'dstAccess', projects
+// out any dst loop IVs at depth greater than 'dstLoopDepth', and computes slice
+// bounds in 'sliceState' which represent the src IVs in terms of the dst IVs,
+// symbols and constants.
 bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
                                             const MemRefAccess &dstAccess,
+                                            unsigned dstLoopDepth,
                                             ComputationSliceState *sliceState) {
   FlatAffineConstraints dependenceConstraints;
   if (!checkMemrefAccessDependence(srcAccess, dstAccess, /*loopDepth=*/1,
@@ -364,6 +365,19 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
   getLoopIVs(*srcAccess.opInst, &srcLoopIVs);
   unsigned numSrcLoopIVs = srcLoopIVs.size();

+  // Get loop nest surrounding dst operation.
+  SmallVector<ForInst *, 4> dstLoopIVs;
+  getLoopIVs(*dstAccess.opInst, &dstLoopIVs);
+  unsigned numDstLoopIVs = dstLoopIVs.size();
+  if (dstLoopDepth > numDstLoopIVs) {
+    dstAccess.opInst->emitError("invalid destination loop depth");
+    return false;
+  }
+
+  // Project out dimensions other than those up to 'dstLoopDepth'.
+  dependenceConstraints.projectOut(numSrcLoopIVs + dstLoopDepth,
+                                   numDstLoopIVs - dstLoopDepth);
+
   // Set up lower/upper bound affine maps for the slice.
   sliceState->lbs.resize(numSrcLoopIVs, AffineMap::Null());
   sliceState->ubs.resize(numSrcLoopIVs, AffineMap::Null());
@@ -385,12 +399,10 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
   return true;
 }

-/// Creates a computation slice of the loop nest surrounding 'srcAccess'
-/// utilizing slice loop bounds in 'sliceState' (for src loops up to
-/// 'srcLoopDepth'), and inserts this slice into loop nest surrounding
-/// 'dstAccess' at loop depth 'dstLoopDepth'. For all loops at loop depth
-/// greater than 'srcLoopDepth' their full loop bounds will be used in the
-/// slice.
+/// Creates a computation slice of the loop nest surrounding 'srcOpInst',
+/// updates the slice loop bounds with any non-null bound maps specified in
+/// 'sliceState', and inserts this slice into the loop nest surrounding
+/// 'dstOpInst' at loop depth 'dstLoopDepth'.
 // TODO(andydavis,bondhugula): extend the slicing utility to compute slices that
 // aren't necessarily a one-to-one relation b/w the source and destination. The
 // relation between the source and destination could be many-to-many in general.
@@ -401,33 +413,27 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
 // solution.
 // TODO(andydavis) Remove dependence on 'srcLoopDepth' here. Instead project
 // out loop IVs we don't care about and produce smaller slice.
-ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
-                                              MemRefAccess *dstAccess,
-                                              ComputationSliceState *sliceState,
-                                              unsigned srcLoopDepth,
-                                              unsigned dstLoopDepth) {
+ForInst *mlir::insertBackwardComputationSlice(
+    OperationInst *srcOpInst, OperationInst *dstOpInst, unsigned dstLoopDepth,
+    ComputationSliceState *sliceState) {
   // Get loop nest surrounding src operation.
   SmallVector<ForInst *, 4> srcLoopIVs;
-  getLoopIVs(*srcAccess->opInst, &srcLoopIVs);
+  getLoopIVs(*srcOpInst, &srcLoopIVs);
   unsigned numSrcLoopIVs = srcLoopIVs.size();
-  if (srcLoopDepth > numSrcLoopIVs) {
-    srcAccess->opInst->emitError("invalid source loop depth");
-    return nullptr;
-  }

   // Get loop nest surrounding dst operation.
   SmallVector<ForInst *, 4> dstLoopIVs;
-  getLoopIVs(*dstAccess->opInst, &dstLoopIVs);
+  getLoopIVs(*dstOpInst, &dstLoopIVs);
   unsigned dstLoopIVsSize = dstLoopIVs.size();
   if (dstLoopDepth > dstLoopIVsSize) {
-    dstAccess->opInst->emitError("invalid destination loop depth");
+    dstOpInst->emitError("invalid destination loop depth");
     return nullptr;
   }

-  // Find the inst block positions of 'srcAccess->opInst' within 'srcLoopIVs'.
+  // Find the inst block positions of 'srcOpInst' within 'srcLoopIVs'.
   SmallVector<unsigned, 4> positions;
   // TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d.
-  findInstPosition(srcAccess->opInst, srcLoopIVs[0]->getBlock(), &positions);
+  findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);

   // Clone src loop nest and insert it a the beginning of the instruction block
   // of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
@@ -451,7 +457,7 @@ ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
   assert(sliceLoopLimit >= sliceSurroundingLoopsSize);

   // Update loop bounds for loops in 'sliceLoopNest'.
-  for (unsigned i = 0; i < srcLoopDepth; ++i) {
+  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
     auto *forInst = sliceSurroundingLoops[dstLoopDepth + i];
     if (AffineMap lbMap = sliceState->lbs[i])
       forInst->setLowerBound(sliceState->lbOperands[i], lbMap);
@@ -69,17 +69,6 @@ char LoopFusion::passID = 0;

 FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }

-// FusionCandidate encapsulates source and destination memref access within
-// loop nests which are candidates for loop fusion.
-struct FusionCandidate {
-  // Load or store access within src loop nest to be fused into dst loop nest.
-  MemRefAccess srcAccess;
-  // Load or store access within dst loop nest.
-  MemRefAccess dstAccess;
-  explicit FusionCandidate(OperationInst *src, OperationInst *dst)
-      : srcAccess(MemRefAccess(src)), dstAccess(MemRefAccess(dst)) {}
-};
-
 namespace {

 // LoopNestStateCollector walks loop nests and collects load and store
@@ -172,10 +161,27 @@ public:
     return &it->second;
   }

+  // Returns true iff there is an edge from node 'srcId' to node 'dstId' for
+  // 'memref'. Returns false otherwise.
+  bool hasEdge(unsigned srcId, unsigned dstId, Value *memref) {
+    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
+      return false;
+    }
+    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
+      return edge.id == dstId && edge.memref == memref;
+    });
+    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
+      return edge.id == srcId && edge.memref == memref;
+    });
+    return hasOutEdge && hasInEdge;
+  }
+
   // Adds an edge from node 'srcId' to node 'dstId' for 'memref'.
   void addEdge(unsigned srcId, unsigned dstId, Value *memref) {
-    outEdges[srcId].push_back({dstId, memref});
-    inEdges[dstId].push_back({srcId, memref});
+    if (!hasEdge(srcId, dstId, memref)) {
+      outEdges[srcId].push_back({dstId, memref});
+      inEdges[dstId].push_back({srcId, memref});
+    }
   }

   // Removes an edge from node 'srcId' to node 'dstId' for 'memref'.
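The new hasEdge check makes addEdge idempotent, which matters now that multiple loads on one memref can induce the same dependence edge repeatedly. A self-contained C++ model of the two methods (simplified Edge and graph types, not MLIR's MemRefDependenceGraph):

#include <algorithm>
#include <cassert>
#include <map>
#include <vector>

struct Edge { unsigned id; int memref; };

struct Graph {
  std::map<unsigned, std::vector<Edge>> inEdges, outEdges;

  // True iff the (srcId, dstId, memref) edge exists on both adjacency lists.
  bool hasEdge(unsigned srcId, unsigned dstId, int memref) {
    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0)
      return false;
    bool out = std::any_of(
        outEdges[srcId].begin(), outEdges[srcId].end(),
        [=](const Edge &e) { return e.id == dstId && e.memref == memref; });
    bool in = std::any_of(
        inEdges[dstId].begin(), inEdges[dstId].end(),
        [=](const Edge &e) { return e.id == srcId && e.memref == memref; });
    return out && in;
  }

  // Adds the edge only if it is not already present.
  void addEdge(unsigned srcId, unsigned dstId, int memref) {
    if (!hasEdge(srcId, dstId, memref)) {
      outEdges[srcId].push_back({dstId, memref});
      inEdges[dstId].push_back({srcId, memref});
    }
  }
};

int main() {
  Graph g;
  g.addEdge(0, 1, /*memref=*/42);
  g.addEdge(0, 1, /*memref=*/42); // deduplicated by hasEdge
  assert(g.outEdges[0].size() == 1 && g.inEdges[1].size() == 1);
}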
@@ -425,10 +431,10 @@ public:
 // inserting a sliced loop nest of known cost into the loop's body.
 // NOTE: this is used to compute the cost of fusing a slice of some loop nest
 // within another loop.
-static uint64_t
-getComputeCost(ForInst *forInst, LoopNestStats *stats,
-               DenseMap<ForInst *, uint64_t> *tripCountOverrideMap,
+static uint64_t getComputeCost(
+    ForInst *forInst, LoopNestStats *stats,
+    llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountOverrideMap,
     DenseMap<ForInst *, uint64_t> *computeCostMap) {
   // 'opCount' is the total number operations in one iteration of 'forInst' body
   uint64_t opCount = stats->opCountMap[forInst];
   if (stats->loopMap.count(forInst) > 0) {
@@ -458,17 +464,33 @@ getComputeCost(ForInst *forInst, LoopNestStats *stats,

 } // end anonymous namespace

+static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
+  assert(lbMap.getNumResults() == 1);
+  assert(ubMap.getNumResults() == 1);
+  assert(lbMap.getNumDims() == ubMap.getNumDims());
+  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
+  // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
+  // ub_expr - lb_expr
+  AffineExpr lbExpr(lbMap.getResult(0));
+  AffineExpr ubExpr(ubMap.getResult(0));
+  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
+                                         lbMap.getNumSymbols());
+  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
+  if (!cExpr)
+    return None;
+  return cExpr.getValue();
+}
+
 // Builds a map 'tripCountMap' from ForInst to constant trip count for loop
 // nest surrounding 'srcAccess' utilizing slice loop bounds in 'sliceState'.
 // Returns true on success, false otherwise (if a non-constant trip count
 // was encountered).
 // TODO(andydavis) Make this work with non-unit step loops.
-static bool
-buildSliceTripCountMap(MemRefAccess *srcAccess,
-                       ComputationSliceState *sliceState,
-                       DenseMap<ForInst *, uint64_t> *tripCountMap) {
+static bool buildSliceTripCountMap(
+    OperationInst *srcOpInst, ComputationSliceState *sliceState,
+    llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountMap) {
   SmallVector<ForInst *, 4> srcLoopIVs;
-  getLoopIVs(*srcAccess->opInst, &srcLoopIVs);
+  getLoopIVs(*srcOpInst, &srcLoopIVs);
   unsigned numSrcLoopIVs = srcLoopIVs.size();
   // Populate map from ForInst -> trip count
   for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
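getConstDifference, factored out above, computes 'ub - lb' and succeeds only when the loop span simplifies to a constant. A standalone C++ model over one-dimensional affine bounds of the form a*d0 + c (illustrative types, not mlir::AffineMap):

#include <cassert>
#include <optional>

// Model an affine bound a*d0 + c in one dimension.
struct Bound { long dimCoeff; long constant; };

// The span (ub - lb) is constant only if the d0 coefficients cancel;
// otherwise the trip count depends on an outer IV and is unknown here.
std::optional<long> constDifference(Bound lb, Bound ub) {
  if (ub.dimCoeff - lb.dimCoeff != 0)
    return std::nullopt;
  return ub.constant - lb.constant;
}

int main() {
  // lb = d0, ub = d0 + 16: the slice trip count is the constant 16.
  assert(*constDifference({1, 0}, {1, 16}) == 16);
  // lb = 0, ub = d0: non-constant span, no trip count.
  assert(!constDifference({0, 0}, {1, 0}).has_value());
}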
@@ -485,109 +507,166 @@ buildSliceTripCountMap(MemRefAccess *srcAccess,
       }
       return false;
     }
-    // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
-    // ub_expr - lb_expr
-    AffineExpr lbExpr(lbMap.getResult(0));
-    AffineExpr ubExpr(ubMap.getResult(0));
-    auto loopSpanExpr = simplifyAffineExpr(
-        ubExpr - lbExpr, std::max(lbMap.getNumDims(), ubMap.getNumDims()),
-        std::max(lbMap.getNumSymbols(), ubMap.getNumSymbols()));
-    auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
-    if (!cExpr)
+    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
+    if (!tripCount.hasValue())
       return false;
-    (*tripCountMap)[srcLoopIVs[i]] = cExpr.getValue();
+    (*tripCountMap)[srcLoopIVs[i]] = tripCount.getValue();
   }
   return true;
 }

-// Returns the maximum loop depth within the source loop nest at which a
-// sliced loop bound is detected in 'sliceState'.
-static unsigned getMaxSrcLoopDepth(unsigned srcLoopDepthLimit,
-                                   ComputationSliceState *sliceState) {
-  unsigned maxSrcPos = 0;
-  for (unsigned i = 0; i < srcLoopDepthLimit; ++i) {
-    if (sliceState->lbs[i] != AffineMap::Null() &&
-        sliceState->ubs[i] != AffineMap::Null()) {
-      maxSrcPos = std::max(maxSrcPos, i);
-    }
+// Removes load operations from 'srcLoads' which operate on 'memref', and
+// adds them to 'dstLoads'.
+static void
+moveLoadsAccessingMemrefTo(Value *memref,
+                           SmallVectorImpl<OperationInst *> *srcLoads,
+                           SmallVectorImpl<OperationInst *> *dstLoads) {
+  dstLoads->clear();
+  SmallVector<OperationInst *, 4> srcLoadsToKeep;
+  for (auto *load : *srcLoads) {
+    if (load->cast<LoadOp>()->getMemRef() == memref)
+      dstLoads->push_back(load);
+    else
+      srcLoadsToKeep.push_back(load);
   }
-  return maxSrcPos + 1;
+  srcLoads->swap(srcLoadsToKeep);
 }

-// Returns the minimum loop depth within the destination loop nest at which the
-// computation slice can be inserted (based on the destination loop IVs that
-// the source slice actually depends on / is a function of).
-static unsigned getMinDstLoopDepth(unsigned srcLoopDepth,
-                                   ComputationSliceState *sliceState) {
-  // Record in 'maxDstLoopDepth' the largest position (+1) of a dst loop nest
-  // IV, which is used in a sliced loop bound in the src loop nest.
-  unsigned maxDstLoopDepth = 0;
-  for (unsigned i = 0; i < srcLoopDepth; ++i) {
-    if (AffineMap lbMap = sliceState->lbs[i]) {
-      lbMap.walkExprs([&](AffineExpr expr) {
-        if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
-          maxDstLoopDepth =
-              std::max(maxDstLoopDepth, dimExpr.getPosition() + 1);
-        }
-      });
-    }
-    if (AffineMap ubMap = sliceState->ubs[i]) {
-      ubMap.walkExprs([&](AffineExpr expr) {
-        if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
-          maxDstLoopDepth =
-              std::max(maxDstLoopDepth, dimExpr.getPosition() + 1);
-        }
-      });
-    }
+// Returns the innermost common loop depth for the set of operations in 'ops'.
+static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
+  unsigned numOps = ops.size();
+  assert(numOps > 0);
+
+  std::vector<SmallVector<ForInst *, 4>> loops(numOps);
+  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
+  for (unsigned i = 0; i < numOps; ++i) {
+    getLoopIVs(*ops[i], &loops[i]);
+    loopDepthLimit =
+        std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
   }
-  return maxDstLoopDepth;
+
+  unsigned loopDepth = 0;
+  for (unsigned d = 0; d < loopDepthLimit; ++d) {
+    unsigned i;
+    for (i = 1; i < numOps; ++i) {
+      if (loops[i - 1][d] != loops[i][d]) {
+        break;
+      }
+    }
+    if (i != numOps)
+      break;
+    ++loopDepth;
+  }
+  return loopDepth;
 }

-// Checks the profitability of fusion candidate 'candidate'. Returns true if it
-// profitable to fuse the candidate loop nests. Returns false otherwise.
+// Returns true if 'map' is a single result constant or single result
+// dim expr where its corresponding loop IV in 'operands' has zero constant
+// lower bound.
+static bool hasZeroMinValue(AffineMap map, ArrayRef<Value *> operands) {
+  if (map.isSingleConstant() && map.getSingleConstantResult() == 0)
+    return true;
+  if (map.getNumResults() != 1 || !map.getResult(0).isa<AffineDimExpr>())
+    return false;
+  // Get operand position of single dim expr result.
+  unsigned pos = map.getResult(0).cast<AffineDimExpr>().getPosition();
+  // Check if loop IV at 'pos' has zero constant lower bound.
+  auto *operand = operands[pos];
+  assert(isa<ForInst>(operand));
+  auto *forInst = cast<ForInst>(operand);
+  return forInst->hasConstantLowerBound() &&
+         forInst->getConstantLowerBound() == 0;
+}
+
+// Returns the slice bound union of 'sliceStateA' and 'sliceStateB' in
+// 'sliceStateB'.
+// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA'
+// and 'sliceStateB' are aligned.
+// Specifically, when taking the union of overlapping intervals, it assumes
+// that both intervals start at zero. Support needs to be added to take into
+// account interval start offset when computing the union.
+// TODO(andydavis) Move this function to an analysis library.
+static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
+                               ComputationSliceState *sliceStateB) {
+  assert(sliceStateA.lbs.size() == sliceStateB->lbs.size());
+  assert(sliceStateA.ubs.size() == sliceStateB->ubs.size());
+
+  for (unsigned i = 0, e = sliceStateA.lbs.size(); i < e; ++i) {
+    AffineMap lbMapA = sliceStateA.lbs[i];
+    AffineMap ubMapA = sliceStateA.ubs[i];
+    if (lbMapA == AffineMap::Null()) {
+      assert(ubMapA == AffineMap::Null());
+      continue;
+    }
+    assert(ubMapA != AffineMap::Null());
+    // Validate that constant lower bounds are aligned at zero.
+    if (!hasZeroMinValue(lbMapA, sliceStateA.lbOperands[i]))
+      return false;
+
+    AffineMap lbMapB = sliceStateB->lbs[i];
+    AffineMap ubMapB = sliceStateB->ubs[i];
+    if (lbMapB == AffineMap::Null()) {
+      assert(ubMapB == AffineMap::Null());
+      // Union 'sliceStateB' does not have a bound for 'i' so copy from A.
+      sliceStateB->lbs[i] = lbMapA;
+      sliceStateB->ubs[i] = ubMapA;
+      continue;
+    }
+    // Validate that constant lower bounds are aligned at zero.
+    if (!hasZeroMinValue(lbMapB, sliceStateB->lbOperands[i]))
+      return false;
+
+    // Add bound with the largest trip count to union.
+    Optional<uint64_t> tripCountA = getConstDifference(lbMapA, ubMapA);
+    Optional<uint64_t> tripCountB = getConstDifference(lbMapB, ubMapB);
+    if (!tripCountA.hasValue() || !tripCountB.hasValue())
+      return false;
+    // TODO(andydavis) Change this code to take the min across all lower bounds
+    // and max across all upper bounds for each dimension. This code can fail
+    // for cases where a unique min or max could not be statically determined.
+    if (tripCountA.getValue() > tripCountB.getValue()) {
+      sliceStateB->lbs[i] = lbMapA;
+      sliceStateB->ubs[i] = ubMapA;
+    }
+  }
+  return true;
+}
+
+// Checks the profitability of fusing a backwards slice of the loop nest
+// surrounding 'srcOpInst' into the loop nest surrounding 'dstOpInsts'.
+// Returns true if it is profitable to fuse the candidate loop nests. Returns
+// false otherwise.
 // The profitability model executes the following steps:
-// *) Computes the backward computation slice at 'candidate.srcAccess'. This
-//    computation slice of the loop nest surrounding 'candidate.srcAccess' is
+// *) Computes the backward computation slice at 'srcOpInst'. This
+//    computation slice of the loop nest surrounding 'srcOpInst' is
 //    represented by modified src loop bounds in 'sliceState', which are
-//    functions of loop IVs in the loop nest surrounding 'candidate.dstAccess'.
+//    functions of loop IVs in the loop nest surrounding 'srcOpInst'.
 // *) Computes the cost of unfused src/dst loop nests (currently the cost of a
 //    loop nest is the total number of dynamic operation instances in the loop
 //    nest).
 // *) Computes the cost of fusing a slice of the src loop nest into the dst
-//    loop nest at various values of src/dst loop depth, attempting to fuse
-//    the biggest compution slice (max src loop depth) at the maximal dst loop
-//    depth (closest to the load) to minimize reuse distance and opportunity for
-//    subsequent load/store forwarding.
-// NOTE: 'srcLoopDepth' refers to the loop depth within the source loop nest
-//    at which we slice the loops bounds (all src loops below this depth will
-//    utilize full loop bounds).
+//    loop nest at various values of dst loop depth, attempting to fuse
+//    the largest computation slice at the maximal dst loop depth (closest to
+//    the load) to minimize reuse distance and potentially enable subsequent
+//    load/store forwarding.
+// NOTE: If the dst loop nest includes multiple loads in 'dstOpInsts' for
+//    the same memref as is written by 'srcOpInst', then the union of slice
+//    loop bounds is used to compute the slice and associated slice cost.
 // NOTE: 'dstLoopDepth' refers the loop depth within the destination loop
 //    nest, at which the src computation slice is inserted/fused.
-// NOTE: We attempt to maximize the source loop depth, but there are cases
-//    where a particular setting for 'dstLoopNest' might fused an unsliced
+// NOTE: We attempt to maximize the dst loop depth, but there are cases
+//    where a particular setting for 'dstLoopNest' might fuse an unsliced
 //    loop (within the src computation slice) at a depth which results in
 //    execessive recomputation (see unit tests for examples).
 // *) Compares the total cost of the unfused loop nests to the min cost fused
 //    loop nest computed in the previous step, and returns true if the latter
 //    is lower.
-static bool isFusionProfitable(FusionCandidate *candidate,
+static bool isFusionProfitable(OperationInst *srcOpInst,
+                               ArrayRef<OperationInst *> dstOpInsts,
                                ComputationSliceState *sliceState,
-                               unsigned *srcLoopDepth, unsigned *dstLoopDepth) {
-  // Compute backward computation slice state: src IV bounds w.r.t dst IVs, etc.
-  if (!mlir::getBackwardComputationSliceState(
-          candidate->srcAccess, candidate->dstAccess, sliceState)) {
-    return false;
-  }
-
-  // Build trip count map for src loops with sliced loop bounds in 'sliceState'.
-  DenseMap<ForInst *, uint64_t> sliceTripCountMap;
-  if (!buildSliceTripCountMap(&candidate->srcAccess, sliceState,
-                              &sliceTripCountMap))
-    return false;
-
+                               unsigned *dstLoopDepth) {
   // Compute cost of sliced and unsliced src loop nest.
   SmallVector<ForInst *, 4> srcLoopIVs;
-  getLoopIVs(*candidate->srcAccess.opInst, &srcLoopIVs);
+  getLoopIVs(*srcOpInst, &srcLoopIVs);
   unsigned numSrcLoopIVs = srcLoopIVs.size();

   // Walk src loop nest and collect stats.
@@ -600,8 +679,7 @@ static bool isFusionProfitable(FusionCandidate *candidate,

   // Compute cost of dst loop nest.
   SmallVector<ForInst *, 4> dstLoopIVs;
-  getLoopIVs(*candidate->dstAccess.opInst, &dstLoopIVs);
-  unsigned numDstLoopIVs = dstLoopIVs.size();
+  getLoopIVs(*dstOpInsts[0], &dstLoopIVs);

   LoopNestStats dstLoopNestStats;
   LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
@@ -610,51 +688,60 @@ static bool isFusionProfitable(FusionCandidate *candidate,
   if (dstStatsCollector.hasLoopWithNonConstTripCount)
     return false;

-  // Search for min cost values for 'srcLoopDepth' and 'dstLoopDepth'.
-  // This search is O(n^2) where 'n' is very small (eg. six).
-  // TODO(andydavis) Consider a solution where we just iteration through
-  // dstLoopDepth possibilities and project out IVs we do not need (remove
-  // dependence on 'srcLoopDepth'.
-  DenseMap<ForInst *, uint64_t> tripCountMap;
-  DenseMap<ForInst *, uint64_t> computeCostMap;
-  unsigned maxSrcLoopDepth = getMaxSrcLoopDepth(numSrcLoopIVs, sliceState);
+  // Compute the innermost common loop for ops in 'dstOpInsts'.
+  unsigned maxDstLoopDepth = getInnermostCommonLoopDepth(dstOpInsts);
+  if (maxDstLoopDepth == 0)
+    return false;
+
+  // Search for min cost value for 'dstLoopDepth'. At each value of
+  // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
+  // bounds between 'srcOpInst' and each op in 'dstOpInsts' (taking the union
+  // of these bounds). Next the union slice bounds are used to calculate
+  // the cost of the slice and the cost of the slice inserted into the dst
+  // loop nest at 'dstLoopDepth'.
   unsigned minFusedLoopNestComputeCost = std::numeric_limits<unsigned>::max();
-  unsigned bestSrcLoopDepth;
   unsigned bestDstLoopDepth;
-  for (unsigned i = maxSrcLoopDepth; i >= 1; --i) {
-    // Compute minDstLoopDepth based on dst loop IVs used in slice loop bounds.
-    unsigned minDstLoopDepth = getMinDstLoopDepth(i, sliceState);
-    assert(minDstLoopDepth <= numDstLoopIVs);
-    if (minDstLoopDepth == 0) {
-      // TODO(andydavis) Support inserting computation slices at top-level.
-      continue;
-    }
-    // Copy elements from slice trip count map up to src loop depth 'i'.
-    tripCountMap.clear();
-    for (unsigned k = 0; k < i; ++k) {
-      auto *forInst = srcLoopIVs[k];
-      auto it = sliceTripCountMap.find(forInst);
-      if (it != sliceTripCountMap.end()) {
-        tripCountMap[forInst] = it->second;
-      }
+  SmallVector<ComputationSliceState, 4> sliceStates;
+  sliceStates.resize(maxDstLoopDepth);
+
+  llvm::SmallDenseMap<ForInst *, uint64_t, 8> sliceTripCountMap;
+  DenseMap<ForInst *, uint64_t> computeCostMap;
+  for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
+    MemRefAccess srcAccess(srcOpInst);
+    // Handle the common case of one dst load without a copy.
+    if (!mlir::getBackwardComputationSliceState(
+            srcAccess, MemRefAccess(dstOpInsts[0]), i, &sliceStates[i - 1]))
+      return false;
+    // Compute the union of slice bounds of all ops in 'dstOpInsts'.
+    for (int j = 1, e = dstOpInsts.size(); j < e; ++j) {
+      MemRefAccess dstAccess(dstOpInsts[j]);
+      ComputationSliceState tmpSliceState;
+      if (!mlir::getBackwardComputationSliceState(srcAccess, dstAccess, i,
+                                                  &tmpSliceState))
+        return false;
+      // Compute slice bound union of 'tmpSliceState' and 'sliceStates[i - 1]'.
+      getSliceBoundUnion(tmpSliceState, &sliceStates[i - 1]);
     }
+    // Build trip count map for computation slice.
+    sliceTripCountMap.clear();
+    if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
+                                &sliceTripCountMap))
+      return false;
+
     // Compute op instance count for the src loop nest with iteration slicing.
     uint64_t sliceComputeCost =
-        getComputeCost(srcLoopIVs[0], &srcLoopNestStats, &tripCountMap,
+        getComputeCost(srcLoopIVs[0], &srcLoopNestStats, &sliceTripCountMap,
                        /*computeCostMap=*/nullptr);

-    for (unsigned j = numDstLoopIVs; j >= minDstLoopDepth; --j) {
-      // Compute cost of fusion for these values of 'i' and 'j'.
-      computeCostMap.clear();
-      computeCostMap[dstLoopIVs[j - 1]] = sliceComputeCost;
-      uint64_t fusedLoopNestComputeCost =
-          getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
-                         /*tripCountOverrideMap=*/nullptr, &computeCostMap);
-      if (fusedLoopNestComputeCost < minFusedLoopNestComputeCost) {
-        minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
-        bestSrcLoopDepth = i;
-        bestDstLoopDepth = j;
-      }
+    // Compute cost of fusion for these values of 'i' and 'j'.
+    computeCostMap.clear();
+    computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost;
+    uint64_t fusedLoopNestComputeCost =
+        getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
+                       /*tripCountOverrideMap=*/nullptr, &computeCostMap);
+    if (fusedLoopNestComputeCost < minFusedLoopNestComputeCost) {
+      minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
+      bestDstLoopDepth = i;
     }
   }
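The depth limit driving this search, getInnermostCommonLoopDepth, is a longest-common-prefix computation over the loop chains surrounding each destination op. A standalone C++ model with loop nests represented as vectors of loop ids (illustrative types, not MLIR's ForInst):

#include <algorithm>
#include <cassert>
#include <limits>
#include <vector>

// Returns how many outermost loops are shared by all ops' loop chains.
unsigned innermostCommonLoopDepth(const std::vector<std::vector<int>> &loops) {
  assert(!loops.empty());
  unsigned limit = std::numeric_limits<unsigned>::max();
  for (const auto &chain : loops)
    limit = std::min(limit, static_cast<unsigned>(chain.size()));
  unsigned depth = 0;
  for (unsigned d = 0; d < limit; ++d) {
    bool allEqual = true;
    for (unsigned i = 1; i < loops.size(); ++i) {
      if (loops[i - 1][d] != loops[i][d]) {
        allEqual = false;
        break;
      }
    }
    if (!allEqual)
      break;
    ++depth;
  }
  return depth;
}

int main() {
  // Two loads under loop chains (0, 1, 2) and (0, 1, 3): common depth is 2,
  // so the slice can be inserted at dst depth 1 or 2 but no deeper.
  assert(innermostCommonLoopDepth({{0, 1, 2}, {0, 1, 3}}) == 2);
}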
@@ -668,7 +755,6 @@ static bool isFusionProfitable(FusionCandidate *candidate,
                      /*computeCostMap=*/nullptr);

   LLVM_DEBUG(llvm::dbgs() << "LoopFusion statistics "
-                          << " bestSrcLoopDepth: " << bestSrcLoopDepth
                           << " bestDstLoopDepth: " << bestDstLoopDepth
                           << " srcLoopNestCost: " << srcLoopNestCost
                           << " dstLoopNestCost: " << dstLoopNestCost
@@ -680,25 +766,23 @@ static bool isFusionProfitable(FusionCandidate *candidate,
   // for load/store forwarding in cost model.
   if (minFusedLoopNestComputeCost > srcLoopNestCost + dstLoopNestCost)
     return false;
-  // Set src/dstLoopDepth based on best values from search.
-  *srcLoopDepth = bestSrcLoopDepth;
+  // Update return parameter 'sliceState' with 'bestSliceState'.
+  ComputationSliceState *bestSliceState = &sliceStates[bestDstLoopDepth - 1];
+  sliceState->lbs = bestSliceState->lbs;
+  sliceState->ubs = bestSliceState->ubs;
+  sliceState->lbOperands = bestSliceState->lbOperands;
+  sliceState->ubOperands = bestSliceState->ubOperands;
+  // Set dstLoopDepth based on best values from search.
   *dstLoopDepth = bestDstLoopDepth;
-  // Update 'sliceState' bounds based on computed 'srcLoopDepth':
-  // *) Canonicalize affine map now that 'srcLoopDepth' has been chosen.
-  // *) Replace slice bound maps at depth > 'srcLoopDepth' withAffineMap::Null()
+  // Canonicalize slice bound affine maps.
   for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
-    if (i < bestSrcLoopDepth) {
-      if (sliceState->lbs[i] != AffineMap::Null()) {
-        canonicalizeMapAndOperands(&sliceState->lbs[i],
-                                   &sliceState->lbOperands[i]);
-      }
-      if (sliceState->ubs[i] != AffineMap::Null()) {
-        canonicalizeMapAndOperands(&sliceState->ubs[i],
-                                   &sliceState->ubOperands[i]);
-      }
-    } else {
-      sliceState->lbs[i] = AffineMap::Null();
-      sliceState->ubs[i] = AffineMap::Null();
-    }
+    if (sliceState->lbs[i] != AffineMap::Null()) {
+      canonicalizeMapAndOperands(&sliceState->lbs[i],
+                                 &sliceState->lbOperands[i]);
+    }
+    if (sliceState->ubs[i] != AffineMap::Null()) {
+      canonicalizeMapAndOperands(&sliceState->ubs[i],
+                                 &sliceState->ubOperands[i]);
     }
   }
   return true;
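The final acceptance check above keeps fusion only when the best fused cost found in the dstLoopDepth search does not exceed the combined cost of the unfused loop nests. A minimal standalone model of that predicate (illustrative helper, not the MLIR code):

#include <cassert>
#include <cstdint>

// Mirrors: if (minFusedLoopNestComputeCost > srcLoopNestCost + dstLoopNestCost)
//            return false;
bool fusionIsProfitable(uint64_t minFusedCost, uint64_t srcCost,
                        uint64_t dstCost) {
  return minFusedCost <= srcCost + dstCost;
}

int main() {
  assert(fusionIsProfitable(90, 50, 50));   // fused cost is lower: fuse
  assert(!fusionIsProfitable(120, 50, 50)); // excessive recomputation: reject
}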
@@ -767,12 +851,12 @@ public:
         continue;

       SmallVector<OperationInst *, 4> loads = dstNode->loads;
+      SmallVector<OperationInst *, 4> dstLoadOpInsts;
       while (!loads.empty()) {
-        auto *dstLoadOpInst = loads.pop_back_val();
-        auto *memref = dstLoadOpInst->cast<LoadOp>()->getMemRef();
-        // Skip 'dstLoadOpInst' if multiple loads to 'memref' in 'dstNode'.
-        if (dstNode->getLoadOpCount(memref) != 1)
-          continue;
+        // Get memref of load on top of the stack.
+        auto *memref = loads.back()->cast<LoadOp>()->getMemRef();
+        // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
+        moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
         // Skip if no input edges along which to fuse.
         if (mdg->inEdges.count(dstId) == 0)
           continue;
@@ -801,19 +885,15 @@ public:
           continue;
         // Get unique 'srcNode' store op.
         auto *srcStoreOpInst = srcNode->stores.front();
-        // Build fusion candidate out of 'srcStoreOpInst' and 'dstLoadOpInst'.
-        FusionCandidate candidate(srcStoreOpInst, dstLoadOpInst);
         // Check if fusion would be profitable.
-        unsigned srcLoopDepth;
         unsigned dstLoopDepth;
         mlir::ComputationSliceState sliceState;
-        if (!isFusionProfitable(&candidate, &sliceState, &srcLoopDepth,
+        if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, &sliceState,
                                 &dstLoopDepth))
           continue;
         // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
         auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
-            &candidate.srcAccess, &candidate.dstAccess, &sliceState,
-            srcLoopDepth, dstLoopDepth);
+            srcStoreOpInst, dstLoadOpInsts[0], dstLoopDepth, &sliceState);
         if (sliceLoopNest != nullptr) {
           // Remove edges between 'srcNode' and 'dstNode' and remove 'srcNode'
           mdg->updateEdgesAndRemoveSrcNode(srcNode->id, dstNode->id);
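The driver change above leans on moveLoadsAccessingMemrefTo: all loads on the memref at the top of the worklist are peeled off into dstLoadOpInsts, so a single fusion decision covers every consumer load of that memref. A self-contained C++ model of the partition (illustrative Load type, not MLIR's OperationInst):

#include <cassert>
#include <vector>

struct Load { int memref; };

// Moves loads on 'memref' from *srcLoads into *dstLoads, keeping the rest.
void moveLoadsAccessingMemrefTo(int memref, std::vector<Load> *srcLoads,
                                std::vector<Load> *dstLoads) {
  dstLoads->clear();
  std::vector<Load> keep;
  for (const Load &load : *srcLoads) {
    if (load.memref == memref)
      dstLoads->push_back(load);
    else
      keep.push_back(load);
  }
  srcLoads->swap(keep);
}

int main() {
  std::vector<Load> loads = {{1}, {2}, {1}};
  std::vector<Load> dstLoads;
  moveLoadsAccessingMemrefTo(1, &loads, &dstLoads);
  // Both loads of memref 1 are grouped; the load of memref 2 stays queued.
  assert(dstLoads.size() == 2 && loads.size() == 1 && loads[0].memref == 2);
}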
@@ -95,13 +95,18 @@ func @should_fuse_loop_nests_with_shifts() {
     }
   }

+  // The cost of fusing the src loop nest at dst loop depth 1 is less expensive
+  // than fusing at dst loop depth 2, because at dst loop depth 1, we are
+  // able to reduce the trip count around the %i1 loop by one (because the
+  // dst loop never reads the last element written by the src loop).
   // CHECK:      for %i0 = 0 to 10 {
-  // CHECK-NEXT:   for %i1 = 0 to 10 {
-  // CHECK-NEXT:     %1 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i0)
-  // CHECK-NEXT:     %2 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i1)
-  // CHECK-NEXT:     %3 = affine_apply [[MAP_SHIFT_BY_ONE]](%1, %2)
-  // CHECK-NEXT:     store %cst, %0[%3#0, %3#1] : memref<10x10xf32>
-  // CHECK-NEXT:     %4 = load %0[%i0, %i1] : memref<10x10xf32>
+  // CHECK-NEXT:   %1 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i0)
+  // CHECK-NEXT:   for %i1 = 0 to 9 {
+  // CHECK-NEXT:     %2 = affine_apply [[MAP_SHIFT_BY_ONE]](%1, %i1)
+  // CHECK-NEXT:     store %cst, %0[%2#0, %2#1] : memref<10x10xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT:   for %i2 = 0 to 10 {
+  // CHECK-NEXT:     %3 = load %0[%i0, %i2] : memref<10x10xf32>
   // CHECK-NEXT:   }
   // CHECK-NEXT: }
   // CHECK-NEXT: return
@@ -849,6 +854,7 @@ func @should_fuse_src_depth1_at_dst_depth2() {
 }

 // -----
+// CHECK: #map0 = ()[s0] -> (s0)

 // CHECK-LABEL: func @fusion_at_depth0_not_currently_supported
 func @fusion_at_depth0_not_currently_supported() {
@@ -862,10 +868,9 @@ func @fusion_at_depth0_not_currently_supported() {
     %1 = load %0[%c0] : memref<10xf32>
   }
   // CHECK:      for %i0 = 0 to 10 {
-  // CHECK-NEXT:   store %cst, %0[%i0] : memref<10xf32>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: for %i1 = 0 to 10 {
-  // CHECK-NEXT:   %1 = load %0[%c0] : memref<10xf32>
+  // CHECK-NEXT:   %1 = affine_apply #map0()[%c0]
+  // CHECK-NEXT:   store %cst, %0[%1] : memref<10xf32>
+  // CHECK-NEXT:   %2 = load %0[%c0] : memref<10xf32>
   // CHECK-NEXT: }
   // CHECK-NEXT: return
   return
@@ -977,3 +982,128 @@ func @should_fuse_deep_loop_nests() {
   // CHECK-NEXT: return
   return
 }
+
+// -----
+// CHECK: #map0 = (d0) -> (d0)
+
+// CHECK-LABEL: func @should_fuse_at_depth1_and_reduce_slice_trip_count
+func @should_fuse_at_depth1_and_reduce_slice_trip_count() {
+  %a = alloc() : memref<4x256xf32>
+  %b = alloc() : memref<4x256xf32>
+
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+
+  for %i0 = 0 to 4 {
+    for %i1 = 0 to 256 {
+      %v0 = load %b[%i0, %i1] : memref<4x256xf32>
+    }
+    for %i2 = 0 to 256 {
+      store %cf0, %a[%i0, %i2] : memref<4x256xf32>
+    }
+  }
+
+  for %d0 = 0 to 4 {
+    for %d1 = 0 to 16 {
+      %v1 = load %a[%d0, %d1] : memref<4x256xf32>
+    }
+  }
+  // The cost of fusing at depth 2 is greater than the cost of fusing at depth 1
+  // for two reasons:
+  // 1) Inserting the unsliceable src loop %i1 to a higher depth removes
+  //    redundant computation and reduces costs.
+  // 2) Inserting the sliceable src loop %i2 at depth 1, we can still reduce
+  //    its trip count to 16 (from 256) reducing costs.
+  // CHECK:      for %i0 = 0 to 4 {
+  // CHECK-NEXT:   %2 = affine_apply #map0(%i0)
+  // CHECK-NEXT:   for %i1 = 0 to 256 {
+  // CHECK-NEXT:     %3 = load %1[%2, %i1] : memref<4x256xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT:   for %i2 = 0 to 16 {
+  // CHECK-NEXT:     store %cst, %0[%2, %i2] : memref<4x256xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT:   for %i3 = 0 to 16 {
+  // CHECK-NEXT:     %4 = load %0[%i0, %i3] : memref<4x256xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @should_fuse_at_depth1_with_trip_count_20
+func @should_fuse_at_depth1_with_trip_count_20() {
+  %a = alloc() : memref<100xf32>
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32

+  for %i0 = 0 to 100 {
+    store %cf0, %a[%i0]: memref<100xf32>
+  }
+
+  for %i1 = 0 to 5 {
+    for %i2 = 0 to 10 {
+      %v0 = load %a[%i2]: memref<100xf32>
+    }
+    for %i3 = 0 to 10 {
+      for %i4 = 0 to 20 {
+        %v1 = load %a[%i4]: memref<100xf32>
+      }
+    }
+  }
+  // CHECK:      for %i0 = 0 to 5 {
+  // CHECK-NEXT:   for %i1 = 0 to 20 {
+  // CHECK-NEXT:     store %cst, %0[%i1] : memref<100xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT:   for %i2 = 0 to 10 {
+  // CHECK-NEXT:     %1 = load %0[%i2] : memref<100xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT:   for %i3 = 0 to 10 {
+  // CHECK-NEXT:     for %i4 = 0 to 20 {
+  // CHECK-NEXT:       %2 = load %0[%i4] : memref<100xf32>
+  // CHECK-NEXT:     }
+  // CHECK-NEXT:   }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @should_fuse_at_depth1_with_trip_count_19
+func @should_fuse_at_depth1_with_trip_count_19() {
+  %a = alloc() : memref<100xf32>
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+
+  for %i0 = 0 to 100 {
+    store %cf0, %a[%i0]: memref<100xf32>
+  }
+
+  for %i1 = 0 to 5 {
+    for %i2 = 0 to 19 {
+      %v0 = load %a[%i2]: memref<100xf32>
+    }
+    for %i3 = 0 to 10 {
+      for %i4 = 0 to 10 {
+        %v1 = load %a[%i4]: memref<100xf32>
+      }
+    }
+  }
+  // CHECK:      for %i0 = 0 to 5 {
+  // CHECK-NEXT:   for %i1 = 0 to 19 {
+  // CHECK-NEXT:     store %cst, %0[%i1] : memref<100xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT:   for %i2 = 0 to 19 {
+  // CHECK-NEXT:     %1 = load %0[%i2] : memref<100xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT:   for %i3 = 0 to 10 {
+  // CHECK-NEXT:     for %i4 = 0 to 10 {
+  // CHECK-NEXT:       %2 = load %0[%i4] : memref<100xf32>
+  // CHECK-NEXT:     }
+  // CHECK-NEXT:   }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+  return
+}
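The two trip-count tests above exercise getSliceBoundUnion: with lower bounds aligned at zero (the stated assumption of that function), the union of the per-load slice intervals keeps the larger extent, so loads of extents 10 and 20 yield a fused slice trip count of 20, and extents 19 and 10 yield 19. A minimal C++ model of that interval union under the zero-aligned assumption:

#include <algorithm>
#include <cassert>

// Intervals are [0, tripCount) by the zero-aligned-lower-bound assumption,
// so the union simply keeps the larger trip count.
long sliceBoundUnion(long tripCountA, long tripCountB) {
  return std::max(tripCountA, tripCountB);
}

int main() {
  assert(sliceBoundUnion(10, 20) == 20); // @should_fuse_at_depth1_with_trip_count_20
  assert(sliceBoundUnion(19, 10) == 19); // @should_fuse_at_depth1_with_trip_count_19
}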