forked from OSchip/llvm-project
[mlir][Affine] Refactor affine fusion code in pass to utilities
Refactoring/clean-up step needed to add support for producer-consumer fusion with multi-store producer loops and, in general, to implement more general loop fusion strategies in Affine. It introduces the following changes: - AffineLoopFusion pass now uses loop fusion utilities more broadly to compute fusion legality (canFuseLoops utility) and perform the fusion transformation (fuseLoops utility). - Loop fusion utilities have been extended to deal with AffineLoopFusion requirements and assumptions while preserving both loop fusion utilities and AffineLoopFusion current functionality within a unified implementation. 'FusionStrategy' has been introduced for this purpose and, in the future, it will allow us to have a single loop fusion core implementation that will produce different fusion outputs depending on the strategy used. - Improve separation of concerns for legality and profitability analysis: 'isFusionProfitable' no longer filters out illegal scenarios that 'canFuse' didn't detect, or the other way around. 'canFuse' now takes loop dependences into account to determine the fusion loop depth (producer-consumer fusion only). - As a result, maximal fusion now doesn't require any profitability analysis. - Slices are now computed only once and reused across the legality, profitability and fusion transformation steps (producer-consumer). - Refactor some utilities and remove redundant copies of them. This patch is NFCI and should preserve the existing functionality of both the AffineLoopFusion pass and the affine fusion utilities. Reviewed By: andydavis1, bondhugula Differential Revision: https://reviews.llvm.org/D90798
This commit is contained in:
parent
5f2c5541f7
commit
c1ba9c43ad
|
@ -82,6 +82,11 @@ struct ComputationSliceState {
|
||||||
|
|
||||||
// Clears all bounds and operands in slice state.
|
// Clears all bounds and operands in slice state.
|
||||||
void clearBounds();
|
void clearBounds();
|
||||||
|
|
||||||
|
/// Return true if the computation slice is empty.
|
||||||
|
bool isEmpty() const { return ivs.empty(); }
|
||||||
|
|
||||||
|
void dump() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Computes the computation slice loop bounds for one loop nest as affine maps
|
/// Computes the computation slice loop bounds for one loop nest as affine maps
|
||||||
|
@ -212,7 +217,7 @@ struct MemRefRegion {
|
||||||
/// The last field is a 2-d FlatAffineConstraints symbolic in %i.
|
/// The last field is a 2-d FlatAffineConstraints symbolic in %i.
|
||||||
///
|
///
|
||||||
LogicalResult compute(Operation *op, unsigned loopDepth,
|
LogicalResult compute(Operation *op, unsigned loopDepth,
|
||||||
ComputationSliceState *sliceState = nullptr,
|
const ComputationSliceState *sliceState = nullptr,
|
||||||
bool addMemRefDimBounds = true);
|
bool addMemRefDimBounds = true);
|
||||||
|
|
||||||
FlatAffineConstraints *getConstraints() { return &cst; }
|
FlatAffineConstraints *getConstraints() { return &cst; }
|
||||||
|
@ -309,6 +314,11 @@ bool isLoopParallel(AffineForOp forOp);
|
||||||
/// number of constraints.
|
/// number of constraints.
|
||||||
IntegerSet simplifyIntegerSet(IntegerSet set);
|
IntegerSet simplifyIntegerSet(IntegerSet set);
|
||||||
|
|
||||||
|
/// Returns the innermost common loop depth for the set of operations in 'ops'.
|
||||||
|
unsigned getInnermostCommonLoopDepth(
|
||||||
|
ArrayRef<Operation *> ops,
|
||||||
|
SmallVectorImpl<AffineForOp> *surroundingLoops = nullptr);
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
#endif // MLIR_ANALYSIS_UTILS_H
|
#endif // MLIR_ANALYSIS_UTILS_H
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
#ifndef MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
|
#ifndef MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
|
||||||
#define MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
|
#define MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
|
||||||
|
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
@ -38,6 +39,45 @@ struct FusionResult {
|
||||||
FusionResult(ResultEnum v) : value(v) {}
|
FusionResult(ResultEnum v) : value(v) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Describes the fusion strategy to be used in the Affine loop fusion
|
||||||
|
/// utilities. Currently, it is used to specialize the loop fusion utilities
|
||||||
|
/// with the assumptions made in the AffineLoopFusion pass for producer-consumer
|
||||||
|
/// and sibling fusion, while sharing a single implementation. The latter
|
||||||
|
/// strategies are also limited to scenarios where a single memref is involved
|
||||||
|
/// in the producer-consumer or sibling relationship between the candidate
|
||||||
|
/// loops. We use 'memref' to keep track of such a memref.
|
||||||
|
// TODO: Remove 'memref' when we support more generic scenarios.
|
||||||
|
// TODO: Generalize utilities so that producer-consumer and sibling fusion
|
||||||
|
// strategies can be used without the assumptions made in the AffineLoopFusion
|
||||||
|
// pass.
|
||||||
|
// Pairs a fusion strategy kind ('strategy') with the memref that the fusion
// is targeting ('memref').
struct FusionStrategy {
|
||||||
|
// The kinds of fusion strategy that can be requested.
enum StrategyEnum {
|
||||||
|
// Generic loop fusion: Arbitrary loops are considered for fusion. No
|
||||||
|
// assumptions about a specific fusion strategy from AffineLoopFusion pass
|
||||||
|
// are made.
|
||||||
|
// TODO: Generic fusion is not fully implemented by fusion utilities yet.
|
||||||
|
// It should only be used for testing.
|
||||||
|
Generic,
|
||||||
|
// Producer-consumer fusion: Only loops with a producer-consumer
|
||||||
|
// memref dependence are considered for fusion. Currently, assumptions from
|
||||||
|
// the producer-consumer fusion implementation in AffineLoopFusion pass are
|
||||||
|
// made. See pass for specific details.
|
||||||
|
ProducerConsumer,
|
||||||
|
// Sibling fusion: Only sibling loops with no producer-consumer memref
|
||||||
|
// dependences are considered for fusion. Memref reuse is taken into account
|
||||||
|
// for profitability. Currently, assumptions from the sibling fusion
|
||||||
|
// implementation in AffineLoopFusion pass are made. See pass for specific
|
||||||
|
// details.
|
||||||
|
Sibling
|
||||||
|
// The strategy kind selected for this instance.
} strategy;
|
||||||
|
|
||||||
|
// Target memref for this fusion transformation.
|
||||||
|
Value memref;
|
||||||
|
|
||||||
|
// Stores the given strategy kind and target memref; no validation is done
// here.
FusionStrategy(StrategyEnum strategy, Value memref)
|
||||||
|
: strategy(strategy), memref(memref) {}
|
||||||
|
};
|
||||||
|
|
||||||
/// Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the
|
/// Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the
|
||||||
/// loop nest rooted at 'dstForOp' at 'dstLoopDepth'. Returns FusionResult
|
/// loop nest rooted at 'dstForOp' at 'dstLoopDepth'. Returns FusionResult
|
||||||
/// 'Success' if fusion of the src/dst loop nests is feasible (i.e. they are
|
/// 'Success' if fusion of the src/dst loop nests is feasible (i.e. they are
|
||||||
|
@ -48,12 +88,14 @@ struct FusionResult {
|
||||||
/// TODO: Update comments when this function is fully implemented.
|
/// TODO: Update comments when this function is fully implemented.
|
||||||
FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
|
FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
|
||||||
unsigned dstLoopDepth,
|
unsigned dstLoopDepth,
|
||||||
ComputationSliceState *srcSlice);
|
ComputationSliceState *srcSlice,
|
||||||
|
FusionStrategy fusionStrategy = {
|
||||||
|
FusionStrategy::Generic, Value()});
|
||||||
|
|
||||||
/// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
|
/// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
|
||||||
/// and source slice loop bounds specified in 'srcSlice'.
|
/// and source slice loop bounds specified in 'srcSlice'.
|
||||||
void fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
|
void fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
|
||||||
ComputationSliceState *srcSlice);
|
const ComputationSliceState &srcSlice);
|
||||||
|
|
||||||
/// LoopNestStats aggregates various per-loop statistics (eg. loop trip count
|
/// LoopNestStats aggregates various per-loop statistics (eg. loop trip count
|
||||||
/// and operation count) for a loop nest up until (and including) the innermost
|
/// and operation count) for a loop nest up until (and including) the innermost
|
||||||
|
@ -89,7 +131,8 @@ int64_t getComputeCost(AffineForOp forOp, LoopNestStats &stats);
|
||||||
// TODO: Improve this cost model.
|
// TODO: Improve this cost model.
|
||||||
bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
|
bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
|
||||||
AffineForOp dstForOp, LoopNestStats &dstStats,
|
AffineForOp dstForOp, LoopNestStats &dstStats,
|
||||||
ComputationSliceState *slice, int64_t *computeCost);
|
const ComputationSliceState &slice,
|
||||||
|
int64_t *computeCost);
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -105,6 +105,28 @@ void ComputationSliceState::clearBounds() {
|
||||||
ubOperands.clear();
|
ubOperands.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Debug helper: writes this slice state to llvm::errs(). Output is the list
/// of IVs, then every entry of 'lbs' followed by its operands, then every
/// entry of 'ubs' followed by its operands.
void ComputationSliceState::dump() const {
|
||||||
|
llvm::errs() << "\tIVs:\n";
|
||||||
|
for (Value iv : ivs)
|
||||||
|
llvm::errs() << "\t\t" << iv << "\n";
|
||||||
|
|
||||||
|
llvm::errs() << "\tLBs:\n";
|
||||||
|
for (auto &en : llvm::enumerate(lbs)) {
|
||||||
|
llvm::errs() << "\t\t" << en.value() << "\n";
|
||||||
|
llvm::errs() << "\t\tOperands:\n";
|
||||||
|
// 'lbOperands' is indexed with the position of the bound in 'lbs', so it
// must hold at least as many entries as 'lbs'.
for (Value lbOp : lbOperands[en.index()])
|
||||||
|
llvm::errs() << "\t\t\t" << lbOp << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::errs() << "\tUBs:\n";
|
||||||
|
for (auto &en : llvm::enumerate(ubs)) {
|
||||||
|
llvm::errs() << "\t\t" << en.value() << "\n";
|
||||||
|
llvm::errs() << "\t\tOperands:\n";
|
||||||
|
// Same parallel indexing as above: 'ubOperands' is indexed with the
// position of the bound in 'ubs'.
for (Value ubOp : ubOperands[en.index()])
|
||||||
|
llvm::errs() << "\t\t\t" << ubOp << "\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
unsigned MemRefRegion::getRank() const {
|
unsigned MemRefRegion::getRank() const {
|
||||||
return memref.getType().cast<MemRefType>().getRank();
|
return memref.getType().cast<MemRefType>().getRank();
|
||||||
}
|
}
|
||||||
|
@ -211,7 +233,7 @@ LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
|
||||||
// TODO: extend this to any other memref dereferencing ops
|
// TODO: extend this to any other memref dereferencing ops
|
||||||
// (dma_start, dma_wait).
|
// (dma_start, dma_wait).
|
||||||
LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
|
LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
|
||||||
ComputationSliceState *sliceState,
|
const ComputationSliceState *sliceState,
|
||||||
bool addMemRefDimBounds) {
|
bool addMemRefDimBounds) {
|
||||||
assert((isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) &&
|
assert((isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) &&
|
||||||
"affine read/write op expected");
|
"affine read/write op expected");
|
||||||
|
@ -541,13 +563,12 @@ static LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the innermost common loop depth for the set of operations in 'ops'.
|
/// Returns the innermost common loop depth for the set of operations in 'ops'.
|
||||||
// TODO: Move this to LoopUtils.
|
// TODO: Move this to LoopUtils.
|
||||||
static unsigned
|
unsigned mlir::getInnermostCommonLoopDepth(
|
||||||
getInnermostCommonLoopDepth(ArrayRef<Operation *> ops,
|
ArrayRef<Operation *> ops, SmallVectorImpl<AffineForOp> *surroundingLoops) {
|
||||||
SmallVectorImpl<AffineForOp> &surroundingLoops) {
|
|
||||||
unsigned numOps = ops.size();
|
unsigned numOps = ops.size();
|
||||||
assert(numOps > 0);
|
assert(numOps > 0 && "Expected at least one operation");
|
||||||
|
|
||||||
std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
|
std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
|
||||||
unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
|
unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
|
||||||
|
@ -564,7 +585,8 @@ getInnermostCommonLoopDepth(ArrayRef<Operation *> ops,
|
||||||
if (loops[i - 1][d] != loops[i][d])
|
if (loops[i - 1][d] != loops[i][d])
|
||||||
return loopDepth;
|
return loopDepth;
|
||||||
}
|
}
|
||||||
surroundingLoops.push_back(loops[i - 1][d]);
|
if (surroundingLoops)
|
||||||
|
surroundingLoops->push_back(loops[i - 1][d]);
|
||||||
++loopDepth;
|
++loopDepth;
|
||||||
}
|
}
|
||||||
return loopDepth;
|
return loopDepth;
|
||||||
|
@ -684,7 +706,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
|
||||||
}
|
}
|
||||||
SmallVector<AffineForOp, 4> surroundingLoops;
|
SmallVector<AffineForOp, 4> surroundingLoops;
|
||||||
unsigned innermostCommonLoopDepth =
|
unsigned innermostCommonLoopDepth =
|
||||||
getInnermostCommonLoopDepth(ops, surroundingLoops);
|
getInnermostCommonLoopDepth(ops, &surroundingLoops);
|
||||||
if (loopDepth > innermostCommonLoopDepth) {
|
if (loopDepth > innermostCommonLoopDepth) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
|
LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
|
||||||
return failure();
|
return failure();
|
||||||
|
|
|
@ -741,77 +741,6 @@ static void moveLoadsAccessingMemrefTo(Value memref,
|
||||||
srcLoads->swap(srcLoadsToKeep);
|
srcLoads->swap(srcLoadsToKeep);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the innermost common loop depth for the set of operations in 'ops'.
|
|
||||||
static unsigned getInnermostCommonLoopDepth(ArrayRef<Operation *> ops) {
|
|
||||||
unsigned numOps = ops.size();
|
|
||||||
assert(numOps > 0);
|
|
||||||
|
|
||||||
std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
|
|
||||||
unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
|
|
||||||
for (unsigned i = 0; i < numOps; ++i) {
|
|
||||||
getLoopIVs(*ops[i], &loops[i]);
|
|
||||||
loopDepthLimit =
|
|
||||||
std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned loopDepth = 0;
|
|
||||||
for (unsigned d = 0; d < loopDepthLimit; ++d) {
|
|
||||||
unsigned i;
|
|
||||||
for (i = 1; i < numOps; ++i) {
|
|
||||||
if (loops[i - 1][d] != loops[i][d])
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (i != numOps)
|
|
||||||
break;
|
|
||||||
++loopDepth;
|
|
||||||
}
|
|
||||||
return loopDepth;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the maximum loop depth at which no dependences between 'loadOpInsts'
|
|
||||||
// and 'storeOpInsts' are satisfied.
|
|
||||||
static unsigned getMaxLoopDepth(ArrayRef<Operation *> loadOpInsts,
|
|
||||||
ArrayRef<Operation *> storeOpInsts) {
|
|
||||||
// Merge loads and stores into the same array.
|
|
||||||
SmallVector<Operation *, 2> ops(loadOpInsts.begin(), loadOpInsts.end());
|
|
||||||
ops.append(storeOpInsts.begin(), storeOpInsts.end());
|
|
||||||
|
|
||||||
// Compute the innermost common loop depth for loads and stores.
|
|
||||||
unsigned loopDepth = getInnermostCommonLoopDepth(ops);
|
|
||||||
|
|
||||||
// Return common loop depth for loads if there are no store ops.
|
|
||||||
if (storeOpInsts.empty())
|
|
||||||
return loopDepth;
|
|
||||||
|
|
||||||
// Check dependences on all pairs of ops in 'ops' and store the minimum
|
|
||||||
// loop depth at which a dependence is satisfied.
|
|
||||||
for (unsigned i = 0, e = ops.size(); i < e; ++i) {
|
|
||||||
auto *srcOpInst = ops[i];
|
|
||||||
MemRefAccess srcAccess(srcOpInst);
|
|
||||||
for (unsigned j = 0; j < e; ++j) {
|
|
||||||
auto *dstOpInst = ops[j];
|
|
||||||
MemRefAccess dstAccess(dstOpInst);
|
|
||||||
|
|
||||||
unsigned numCommonLoops =
|
|
||||||
getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
|
|
||||||
for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
|
|
||||||
FlatAffineConstraints dependenceConstraints;
|
|
||||||
// TODO: Cache dependence analysis results, check cache here.
|
|
||||||
DependenceResult result = checkMemrefAccessDependence(
|
|
||||||
srcAccess, dstAccess, d, &dependenceConstraints,
|
|
||||||
/*dependenceComponents=*/nullptr);
|
|
||||||
if (hasDependence(result)) {
|
|
||||||
// Store minimum loop depth and break because we want the min 'd' at
|
|
||||||
// which there is a dependence.
|
|
||||||
loopDepth = std::min(loopDepth, d - 1);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return loopDepth;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sinks all sequential loops to the innermost levels (while preserving
|
// Sinks all sequential loops to the innermost levels (while preserving
|
||||||
// relative order among them) and moves all parallel loops to the
|
// relative order among them) and moves all parallel loops to the
|
||||||
// outermost (while again preserving relative order among them).
|
// outermost (while again preserving relative order among them).
|
||||||
|
@ -1077,14 +1006,16 @@ canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
|
||||||
// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
|
// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
|
||||||
// the memref being produced and consumed, which is an input to the cost model.
|
// the memref being produced and consumed, which is an input to the cost model.
|
||||||
// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
|
// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
|
||||||
// 'srcOpInst', as we are slicing w.r.t to that producer.
|
// 'srcOpInst', as we are slicing w.r.t. that producer. For input-reuse
|
||||||
// For input-reuse fusion, 'srcOpInst' will be the src loop nest LoadOp which
|
// fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
|
||||||
// reads from the same memref as dst loop nest load ops, and 'srcStoreOpInst'
|
// same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the
|
||||||
// will be the unique store op in the src node, which will be used to check
|
// unique store op in the src node, which will be used to check that the write
|
||||||
// that the write region is the same after input-reuse fusion.
|
// region is the same after input-reuse fusion. Computation slices are provided
|
||||||
// Returns true if it is profitable to fuse the candidate loop nests. Returns
|
// in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which
|
||||||
// false otherwise. `dstLoopDepth` is set to the most profitable depth at which
|
// fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is
|
||||||
// to materialize the source loop nest slice.
|
// profitable to fuse the candidate loop nests. Returns false otherwise.
|
||||||
|
// `dstLoopDepth` is set to the most profitable depth at which to materialize
|
||||||
|
// the source loop nest slice.
|
||||||
// The profitability model executes the following steps:
|
// The profitability model executes the following steps:
|
||||||
// *) Computes the backward computation slice at 'srcOpInst'. This
|
// *) Computes the backward computation slice at 'srcOpInst'. This
|
||||||
// computation slice of the loop nest surrounding 'srcOpInst' is
|
// computation slice of the loop nest surrounding 'srcOpInst' is
|
||||||
|
@ -1112,9 +1043,9 @@ canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
|
||||||
// is lower.
|
// is lower.
|
||||||
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
|
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
|
||||||
ArrayRef<Operation *> dstLoadOpInsts,
|
ArrayRef<Operation *> dstLoadOpInsts,
|
||||||
ArrayRef<Operation *> dstStoreOpInsts,
|
ArrayRef<ComputationSliceState> depthSliceUnions,
|
||||||
ComputationSliceState *sliceState,
|
unsigned maxLegalFusionDepth,
|
||||||
unsigned *dstLoopDepth, bool maximalFusion,
|
unsigned *dstLoopDepth,
|
||||||
double computeToleranceThreshold) {
|
double computeToleranceThreshold) {
|
||||||
LLVM_DEBUG({
|
LLVM_DEBUG({
|
||||||
llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
|
llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
|
||||||
|
@ -1124,10 +1055,14 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if (maxLegalFusionDepth == 0) {
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth == 0 .\n");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// Compute cost of sliced and unsliced src loop nest.
|
// Compute cost of sliced and unsliced src loop nest.
|
||||||
SmallVector<AffineForOp, 4> srcLoopIVs;
|
SmallVector<AffineForOp, 4> srcLoopIVs;
|
||||||
getLoopIVs(*srcOpInst, &srcLoopIVs);
|
getLoopIVs(*srcOpInst, &srcLoopIVs);
|
||||||
unsigned numSrcLoopIVs = srcLoopIVs.size();
|
|
||||||
|
|
||||||
// Walk src loop nest and collect stats.
|
// Walk src loop nest and collect stats.
|
||||||
LoopNestStats srcLoopNestStats;
|
LoopNestStats srcLoopNestStats;
|
||||||
|
@ -1142,19 +1077,8 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
|
||||||
if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats))
|
if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
// Compute the maximum loop depth at which we can can insert the src slice
|
|
||||||
// and still satisfy dest loop nest dependences, for producer-consumer fusion.
|
|
||||||
unsigned maxDstLoopDepth =
|
|
||||||
(srcOpInst == srcStoreOpInst)
|
|
||||||
? getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts)
|
|
||||||
: dstLoopIVs.size();
|
|
||||||
if (maxDstLoopDepth == 0) {
|
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxDstLoopDepth == 0 .\n");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Search for min cost value for 'dstLoopDepth'. At each value of
|
// Search for min cost value for 'dstLoopDepth'. At each value of
|
||||||
// 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
|
// 'dstLoopDepth' from 'maxLegalFusionDepth' to '1', compute computation slice
|
||||||
// bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
|
// bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
|
||||||
// of these bounds). Next the union slice bounds are used to calculate
|
// of these bounds). Next the union slice bounds are used to calculate
|
||||||
// the cost of the slice and the cost of the slice inserted into the dst
|
// the cost of the slice and the cost of the slice inserted into the dst
|
||||||
|
@ -1163,8 +1087,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
|
||||||
double maxStorageReduction = 0.0;
|
double maxStorageReduction = 0.0;
|
||||||
Optional<uint64_t> sliceMemEstimate = None;
|
Optional<uint64_t> sliceMemEstimate = None;
|
||||||
|
|
||||||
SmallVector<ComputationSliceState, 4> sliceStates;
|
|
||||||
sliceStates.resize(maxDstLoopDepth);
|
|
||||||
// The best loop depth at which to materialize the slice.
|
// The best loop depth at which to materialize the slice.
|
||||||
Optional<unsigned> bestDstLoopDepth = None;
|
Optional<unsigned> bestDstLoopDepth = None;
|
||||||
|
|
||||||
|
@ -1190,21 +1112,14 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
|
||||||
|
|
||||||
// Evaluate all depth choices for materializing the slice in the destination
|
// Evaluate all depth choices for materializing the slice in the destination
|
||||||
// loop nest.
|
// loop nest.
|
||||||
for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
|
for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
|
||||||
// Compute the union of slice bounds of all ops in 'dstLoadOpInsts'.
|
// Skip slice union if it wasn't computed for this depth.
|
||||||
if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts,
|
if (depthSliceUnions[i - 1].isEmpty())
|
||||||
/*loopDepth=*/i,
|
|
||||||
/*numCommonLoops=*/0,
|
|
||||||
/*isBackwardSlice=*/true,
|
|
||||||
&sliceStates[i - 1]))) {
|
|
||||||
LLVM_DEBUG(llvm::dbgs()
|
|
||||||
<< "computeSliceUnion failed for loopDepth: " << i << "\n");
|
|
||||||
continue;
|
continue;
|
||||||
}
|
|
||||||
|
|
||||||
int64_t fusedLoopNestComputeCost;
|
int64_t fusedLoopNestComputeCost;
|
||||||
if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0],
|
if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0],
|
||||||
dstLoopNestStats, &sliceStates[i - 1],
|
dstLoopNestStats, depthSliceUnions[i - 1],
|
||||||
&fusedLoopNestComputeCost)) {
|
&fusedLoopNestComputeCost)) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n.");
|
LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n.");
|
||||||
continue;
|
continue;
|
||||||
|
@ -1216,11 +1131,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
|
||||||
1;
|
1;
|
||||||
|
|
||||||
// Determine what the slice write MemRefRegion would be, if the src loop
|
// Determine what the slice write MemRefRegion would be, if the src loop
|
||||||
// nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop
|
// nest slice 'depthSliceUnions[i - 1]' were to be inserted into the dst
|
||||||
// nest at loop depth 'i'
|
// loop nest at loop depth 'i'.
|
||||||
MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
|
MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
|
||||||
if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
|
if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
|
||||||
&sliceStates[i - 1]))) {
|
&depthSliceUnions[i - 1]))) {
|
||||||
LLVM_DEBUG(llvm::dbgs()
|
LLVM_DEBUG(llvm::dbgs()
|
||||||
<< "Failed to compute slice write region at loopDepth: " << i
|
<< "Failed to compute slice write region at loopDepth: " << i
|
||||||
<< "\n");
|
<< "\n");
|
||||||
|
@ -1269,8 +1184,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
|
||||||
// (as per computeToleranceThreshold), we will simply pick the one that
|
// (as per computeToleranceThreshold), we will simply pick the one that
|
||||||
// reduces the intermediary size the most.
|
// reduces the intermediary size the most.
|
||||||
if ((storageReduction > maxStorageReduction) &&
|
if ((storageReduction > maxStorageReduction) &&
|
||||||
(maximalFusion ||
|
(additionalComputeFraction < computeToleranceThreshold)) {
|
||||||
(additionalComputeFraction < computeToleranceThreshold))) {
|
|
||||||
maxStorageReduction = storageReduction;
|
maxStorageReduction = storageReduction;
|
||||||
bestDstLoopDepth = i;
|
bestDstLoopDepth = i;
|
||||||
minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
|
minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
|
||||||
|
@ -1278,10 +1192,9 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// A simple cost model: fuse if it reduces the memory footprint. If
|
// A simple cost model: fuse if it reduces the memory footprint.
|
||||||
// -maximal-fusion is set, fuse nevertheless.
|
|
||||||
|
|
||||||
if (!maximalFusion && !bestDstLoopDepth.hasValue()) {
|
if (!bestDstLoopDepth.hasValue()) {
|
||||||
LLVM_DEBUG(
|
LLVM_DEBUG(
|
||||||
llvm::dbgs()
|
llvm::dbgs()
|
||||||
<< "All fusion choices involve more than the threshold amount of "
|
<< "All fusion choices involve more than the threshold amount of "
|
||||||
|
@ -1310,34 +1223,31 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
|
||||||
|
|
||||||
Optional<double> storageReduction = None;
|
Optional<double> storageReduction = None;
|
||||||
|
|
||||||
if (!maximalFusion) {
|
if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
|
||||||
if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
|
LLVM_DEBUG(llvm::dbgs()
|
||||||
LLVM_DEBUG(
|
<< " fusion memory benefit cannot be evaluated; NOT fusing.\n");
|
||||||
llvm::dbgs()
|
return false;
|
||||||
<< " fusion memory benefit cannot be evaluated; NOT fusing.\n");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto srcMemSizeVal = srcMemSize.getValue();
|
|
||||||
auto dstMemSizeVal = dstMemSize.getValue();
|
|
||||||
|
|
||||||
assert(sliceMemEstimate.hasValue() && "expected value");
|
|
||||||
auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();
|
|
||||||
|
|
||||||
LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
|
|
||||||
<< " dst mem: " << dstMemSizeVal << "\n"
|
|
||||||
<< " fused mem: " << fusedMem << "\n"
|
|
||||||
<< " slice mem: " << sliceMemEstimate << "\n");
|
|
||||||
|
|
||||||
if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
|
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
storageReduction =
|
|
||||||
100.0 *
|
|
||||||
(1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto srcMemSizeVal = srcMemSize.getValue();
|
||||||
|
auto dstMemSizeVal = dstMemSize.getValue();
|
||||||
|
|
||||||
|
assert(sliceMemEstimate.hasValue() && "expected value");
|
||||||
|
auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();
|
||||||
|
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
|
||||||
|
<< " dst mem: " << dstMemSizeVal << "\n"
|
||||||
|
<< " fused mem: " << fusedMem << "\n"
|
||||||
|
<< " slice mem: " << sliceMemEstimate << "\n");
|
||||||
|
|
||||||
|
if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
storageReduction =
|
||||||
|
100.0 *
|
||||||
|
(1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
|
||||||
|
|
||||||
double additionalComputeFraction =
|
double additionalComputeFraction =
|
||||||
100.0 * (minFusedLoopNestComputeCost /
|
100.0 * (minFusedLoopNestComputeCost /
|
||||||
(static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
|
(static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
|
||||||
|
@ -1355,24 +1265,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
|
||||||
llvm::dbgs() << msg.str();
|
llvm::dbgs() << msg.str();
|
||||||
});
|
});
|
||||||
|
|
||||||
// Update return parameter 'sliceState' with 'bestSliceState'.
|
|
||||||
ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1];
|
|
||||||
sliceState->lbs = bestSliceState->lbs;
|
|
||||||
sliceState->ubs = bestSliceState->ubs;
|
|
||||||
sliceState->lbOperands = bestSliceState->lbOperands;
|
|
||||||
sliceState->ubOperands = bestSliceState->ubOperands;
|
|
||||||
|
|
||||||
// Canonicalize slice bound affine maps.
|
|
||||||
for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
|
|
||||||
if (sliceState->lbs[i] != AffineMap()) {
|
|
||||||
canonicalizeMapAndOperands(&sliceState->lbs[i],
|
|
||||||
&sliceState->lbOperands[i]);
|
|
||||||
}
|
|
||||||
if (sliceState->ubs[i] != AffineMap()) {
|
|
||||||
canonicalizeMapAndOperands(&sliceState->ubs[i],
|
|
||||||
&sliceState->ubOperands[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1592,138 +1484,142 @@ public:
|
||||||
if (insertPointInst == nullptr)
|
if (insertPointInst == nullptr)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
|
auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
|
||||||
|
auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
|
||||||
|
|
||||||
// Compute the innermost common loop depth for dstNode loads/stores.
|
// Compute the innermost common loop depth for dstNode loads/stores.
|
||||||
SmallVector<Operation *, 2> dstOps(dstNode->loads.begin(),
|
SmallVector<Operation *, 2> dstMemrefOps;
|
||||||
dstNode->loads.end());
|
for (Operation *op : dstNode->loads)
|
||||||
dstOps.append(dstNode->stores.begin(), dstNode->stores.end());
|
if (cast<AffineReadOpInterface>(op).getMemRef() == memref)
|
||||||
unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstOps);
|
dstMemrefOps.push_back(op);
|
||||||
|
for (Operation *op : dstNode->stores)
|
||||||
|
if (cast<AffineWriteOpInterface>(op).getMemRef() == memref)
|
||||||
|
dstMemrefOps.push_back(op);
|
||||||
|
unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps);
|
||||||
|
|
||||||
// Check the feasibility of fusing src loop nest into dst loop nest
|
// Check the feasibility of fusing src loop nest into dst loop nest
|
||||||
// at loop depths in range [1, dstLoopDepthTest].
|
// at loop depths in range [1, dstLoopDepthTest].
|
||||||
// TODO: Use slice union computation and union of memref
|
unsigned maxLegalFusionDepth = 0;
|
||||||
// read/write regions to cost model and fusion.
|
SmallVector<ComputationSliceState, 8> depthSliceUnions;
|
||||||
bool canFuse = false;
|
depthSliceUnions.resize(dstLoopDepthTest);
|
||||||
|
FusionStrategy strategy(FusionStrategy::ProducerConsumer, memref);
|
||||||
for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
|
for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
|
||||||
ComputationSliceState sliceUnion;
|
|
||||||
FusionResult result = mlir::canFuseLoops(
|
FusionResult result = mlir::canFuseLoops(
|
||||||
cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op),
|
srcAffineForOp, dstAffineForOp,
|
||||||
/*dstLoopDepth=*/i, &sliceUnion);
|
/*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);
|
||||||
|
|
||||||
if (result.value == FusionResult::Success)
|
if (result.value == FusionResult::Success)
|
||||||
canFuse = true;
|
maxLegalFusionDepth = i;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip if fusion is not feasible at all loop depths.
|
// Skip if fusion is not feasible at any loop depths.
|
||||||
if (!canFuse)
|
if (maxLegalFusionDepth == 0)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
// Gather 'dstNode' store ops to 'memref'.
|
// Check if fusion would be profitable. We skip profitability analysis
|
||||||
SmallVector<Operation *, 2> dstStoreOpInsts;
|
// for maximal fusion since we already know the maximal legal depth to
|
||||||
for (auto *storeOpInst : dstNode->stores)
|
// fuse.
|
||||||
if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() == memref)
|
unsigned bestDstLoopDepth = maxLegalFusionDepth;
|
||||||
dstStoreOpInsts.push_back(storeOpInst);
|
if (!maximalFusion &&
|
||||||
|
!isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts,
|
||||||
unsigned bestDstLoopDepth;
|
depthSliceUnions, maxLegalFusionDepth,
|
||||||
mlir::ComputationSliceState sliceState;
|
&bestDstLoopDepth, computeToleranceThreshold))
|
||||||
// Check if fusion would be profitable.
|
|
||||||
if (!isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts,
|
|
||||||
dstStoreOpInsts, &sliceState,
|
|
||||||
&bestDstLoopDepth, maximalFusion,
|
|
||||||
computeToleranceThreshold))
|
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
|
assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
|
||||||
|
assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
|
||||||
|
"Missing slice union for depth");
|
||||||
|
|
||||||
// Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
|
// Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
|
||||||
auto sliceLoopNest = mlir::insertBackwardComputationSlice(
|
fuseLoops(srcAffineForOp, dstAffineForOp,
|
||||||
srcStoreOp, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
|
depthSliceUnions[bestDstLoopDepth - 1]);
|
||||||
if (sliceLoopNest) {
|
|
||||||
LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n"
|
|
||||||
<< *sliceLoopNest.getOperation() << "\n");
|
|
||||||
// Move 'dstAffineForOp' before 'insertPointInst' if needed.
|
|
||||||
auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
|
|
||||||
if (insertPointInst != dstAffineForOp.getOperation()) {
|
|
||||||
dstAffineForOp.getOperation()->moveBefore(insertPointInst);
|
|
||||||
}
|
|
||||||
// Update edges between 'srcNode' and 'dstNode'.
|
|
||||||
mdg->updateEdges(srcNode->id, dstNode->id, memref,
|
|
||||||
createPrivateMemref);
|
|
||||||
|
|
||||||
// Collect slice loop stats.
|
LLVM_DEBUG(llvm::dbgs()
|
||||||
LoopNestStateCollector sliceCollector;
|
<< "Fused src loop " << srcId << " into dst loop " << dstId
|
||||||
sliceCollector.collect(sliceLoopNest.getOperation());
|
<< " at depth " << bestDstLoopDepth << ":\n"
|
||||||
// Promote single iteration slice loops to single IV value.
|
<< dstAffineForOp << "\n");
|
||||||
for (auto forOp : sliceCollector.forOps) {
|
|
||||||
promoteIfSingleIteration(forOp);
|
|
||||||
}
|
|
||||||
if (createPrivateMemref) {
|
|
||||||
// Create private memref for 'memref' in 'dstAffineForOp'.
|
|
||||||
SmallVector<Operation *, 4> storesForMemref;
|
|
||||||
for (auto *storeOpInst : sliceCollector.storeOpInsts) {
|
|
||||||
if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() ==
|
|
||||||
memref)
|
|
||||||
storesForMemref.push_back(storeOpInst);
|
|
||||||
}
|
|
||||||
// TODO: Use union of memref write regions to compute
|
|
||||||
// private memref footprint.
|
|
||||||
auto newMemRef = createPrivateMemRef(
|
|
||||||
dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
|
|
||||||
fastMemorySpace, localBufSizeThreshold);
|
|
||||||
visitedMemrefs.insert(newMemRef);
|
|
||||||
// Create new node in dependence graph for 'newMemRef' alloc op.
|
|
||||||
unsigned newMemRefNodeId =
|
|
||||||
mdg->addNode(newMemRef.getDefiningOp());
|
|
||||||
// Add edge from 'newMemRef' node to dstNode.
|
|
||||||
mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Collect dst loop stats after memref privatization transformation.
|
// Move 'dstAffineForOp' before 'insertPointInst' if needed.
|
||||||
LoopNestStateCollector dstLoopCollector;
|
if (insertPointInst != dstAffineForOp.getOperation())
|
||||||
dstLoopCollector.collect(dstAffineForOp.getOperation());
|
dstAffineForOp.getOperation()->moveBefore(insertPointInst);
|
||||||
|
|
||||||
// Add new load ops to current Node load op list 'loads' to
|
// Update edges between 'srcNode' and 'dstNode'.
|
||||||
// continue fusing based on new operands.
|
mdg->updateEdges(srcNode->id, dstNode->id, memref,
|
||||||
for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
|
createPrivateMemref);
|
||||||
// NOTE: Change 'loads' to a hash set in case efficiency is an
|
|
||||||
// issue. We still use a vector since it's expected to be small.
|
// Collect slice loop stats.
|
||||||
if (!llvm::is_contained(loads, loadOpInst))
|
LoopNestStateCollector dstForCollector;
|
||||||
loads.push_back(loadOpInst);
|
dstForCollector.collect(dstAffineForOp);
|
||||||
|
if (createPrivateMemref) {
|
||||||
|
// Create private memref for 'memref' in 'dstAffineForOp'.
|
||||||
|
SmallVector<Operation *, 4> storesForMemref;
|
||||||
|
for (auto *storeOpInst : dstForCollector.storeOpInsts) {
|
||||||
|
if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() ==
|
||||||
|
memref)
|
||||||
|
storesForMemref.push_back(storeOpInst);
|
||||||
}
|
}
|
||||||
// Clear visited memrefs after fusion so that previously visited src
|
// TODO: Use union of memref write regions to compute
|
||||||
// nodes are considered for fusion again in the context of the new
|
// private memref footprint.
|
||||||
// fused node.
|
auto newMemRef = createPrivateMemRef(
|
||||||
// TODO: This shouldn't be necessary if we visited candidates in the
|
dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
|
||||||
// dependence graph in post-order or once we fully support
|
fastMemorySpace, localBufSizeThreshold);
|
||||||
// multi-store producers. Currently, in a multi-store producer
|
visitedMemrefs.insert(newMemRef);
|
||||||
// scenario such as A->B, A->C, B->C, we fail to fuse A+B due to the
|
// Create new node in dependence graph for 'newMemRef' alloc op.
|
||||||
// multiple outgoing edges. However, after fusing B+C, A has a
|
unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
|
||||||
// single outgoing edge and can be fused if we revisit it in the
|
// Add edge from 'newMemRef' node to dstNode.
|
||||||
// context of the new fused B+C node.
|
mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
|
||||||
visitedMemrefs.clear();
|
}
|
||||||
|
|
||||||
// Clear and add back loads and stores.
|
// Collect dst loop stats after memref privatization transformation.
|
||||||
mdg->clearNodeLoadAndStores(dstNode->id);
|
LoopNestStateCollector dstLoopCollector;
|
||||||
mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
|
dstLoopCollector.collect(dstAffineForOp.getOperation());
|
||||||
dstLoopCollector.storeOpInsts);
|
|
||||||
// Remove old src loop nest if it no longer has outgoing dependence
|
// Add new load ops to current Node load op list 'loads' to continue
|
||||||
// edges, and if it does not write to a memref which escapes the
|
// fusing based on new operands.
|
||||||
// function. If 'writesToLiveInOrOut' is true, then 'srcNode' has
|
for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
|
||||||
// been fused into 'dstNode' and write region of 'dstNode' covers
|
// NOTE: Change 'loads' to a hash set in case efficiency is an
|
||||||
// the write region of 'srcNode', and 'srcNode' has no other users
|
// issue. We still use a vector since it's expected to be small.
|
||||||
// so it is safe to remove.
|
if (!llvm::is_contained(loads, loadOpInst))
|
||||||
if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) {
|
loads.push_back(loadOpInst);
|
||||||
mdg->removeNode(srcNode->id);
|
}
|
||||||
srcNode->op->erase();
|
// Clear visited memrefs after fusion so that previously visited src
|
||||||
} else {
|
// nodes are considered for fusion again in the context of the new
|
||||||
// Add remaining users of 'oldMemRef' back on the worklist (if not
|
// fused node.
|
||||||
// already there), as its replacement with a local/private memref
|
// TODO: This shouldn't be necessary if we visited candidates in the
|
||||||
// has reduced dependences on 'oldMemRef' which may have created
|
// dependence graph in post-order or once we fully support multi-store
|
||||||
// new fusion opportunities.
|
// producers. Currently, in a multi-store producer scenario such as
|
||||||
if (mdg->outEdges.count(srcNode->id) > 0) {
|
// A->B, A->C, B->C, we fail to fuse A+B due to the multiple outgoing
|
||||||
SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges =
|
// edges. However, after fusing B+C, A has a single outgoing edge and
|
||||||
mdg->outEdges[srcNode->id];
|
// can be fused if we revisit it in the context of the new fused B+C
|
||||||
for (auto &outEdge : oldOutEdges) {
|
// node.
|
||||||
if (outEdge.value == memref &&
|
visitedMemrefs.clear();
|
||||||
worklistSet.count(outEdge.id) == 0) {
|
|
||||||
worklist.push_back(outEdge.id);
|
// Clear and add back loads and stores.
|
||||||
worklistSet.insert(outEdge.id);
|
mdg->clearNodeLoadAndStores(dstNode->id);
|
||||||
}
|
mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
|
||||||
|
dstLoopCollector.storeOpInsts);
|
||||||
|
// Remove old src loop nest if it no longer has outgoing dependence
|
||||||
|
// edges, and if it does not write to a memref which escapes the
|
||||||
|
// function. If 'writesToLiveInOrOut' is true, then 'srcNode' has been
|
||||||
|
// fused into 'dstNode' and write region of 'dstNode' covers the write
|
||||||
|
// region of 'srcNode', and 'srcNode' has no other users so it is safe
|
||||||
|
// to remove.
|
||||||
|
if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) {
|
||||||
|
mdg->removeNode(srcNode->id);
|
||||||
|
srcNode->op->erase();
|
||||||
|
} else {
|
||||||
|
// Add remaining users of 'oldMemRef' back on the worklist (if not
|
||||||
|
// already there), as its replacement with a local/private memref
|
||||||
|
// has reduced dependences on 'oldMemRef' which may have created new
|
||||||
|
// fusion opportunities.
|
||||||
|
if (mdg->outEdges.count(srcNode->id) > 0) {
|
||||||
|
SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges =
|
||||||
|
mdg->outEdges[srcNode->id];
|
||||||
|
for (auto &outEdge : oldOutEdges) {
|
||||||
|
if (outEdge.value == memref &&
|
||||||
|
worklistSet.count(outEdge.id) == 0) {
|
||||||
|
worklist.push_back(outEdge.id);
|
||||||
|
worklistSet.insert(outEdge.id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1759,6 +1655,8 @@ public:
|
||||||
void fuseWithSiblingNodes(Node *dstNode) {
|
void fuseWithSiblingNodes(Node *dstNode) {
|
||||||
DenseSet<unsigned> visitedSibNodeIds;
|
DenseSet<unsigned> visitedSibNodeIds;
|
||||||
std::pair<unsigned, Value> idAndMemref;
|
std::pair<unsigned, Value> idAndMemref;
|
||||||
|
auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
|
||||||
|
|
||||||
while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
|
while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
|
||||||
unsigned sibId = idAndMemref.first;
|
unsigned sibId = idAndMemref.first;
|
||||||
Value memref = idAndMemref.second;
|
Value memref = idAndMemref.second;
|
||||||
|
@ -1791,31 +1689,53 @@ public:
|
||||||
SmallVector<Operation *, 2> dstLoadOpInsts;
|
SmallVector<Operation *, 2> dstLoadOpInsts;
|
||||||
dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
|
dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
|
||||||
|
|
||||||
// Gather 'dstNode' store ops to 'memref'.
|
SmallVector<AffineForOp, 4> dstLoopIVs;
|
||||||
SmallVector<Operation *, 2> dstStoreOpInsts;
|
getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
|
||||||
dstNode->getStoreOpsForMemref(memref, &dstStoreOpInsts);
|
unsigned dstLoopDepthTest = dstLoopIVs.size();
|
||||||
|
auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
|
||||||
|
|
||||||
unsigned bestDstLoopDepth;
|
// Compute loop depth and slice union for fusion.
|
||||||
mlir::ComputationSliceState sliceState;
|
SmallVector<ComputationSliceState, 8> depthSliceUnions;
|
||||||
|
depthSliceUnions.resize(dstLoopDepthTest);
|
||||||
|
unsigned maxLegalFusionDepth = 0;
|
||||||
|
FusionStrategy strategy(FusionStrategy::Sibling, memref);
|
||||||
|
for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
|
||||||
|
FusionResult result = mlir::canFuseLoops(
|
||||||
|
sibAffineForOp, dstAffineForOp,
|
||||||
|
/*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);
|
||||||
|
|
||||||
// Check if fusion would be profitable.
|
if (result.value == FusionResult::Success)
|
||||||
if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts,
|
maxLegalFusionDepth = i;
|
||||||
dstStoreOpInsts, &sliceState, &bestDstLoopDepth,
|
}
|
||||||
maximalFusion, computeToleranceThreshold))
|
|
||||||
|
// Skip if fusion is not feasible at any loop depths.
|
||||||
|
if (maxLegalFusionDepth == 0)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
// Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
|
unsigned bestDstLoopDepth = dstLoopDepthTest;
|
||||||
auto sliceLoopNest = mlir::insertBackwardComputationSlice(
|
if (!maximalFusion) {
|
||||||
sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
|
// Check if fusion would be profitable.
|
||||||
if (sliceLoopNest != nullptr) {
|
if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts,
|
||||||
auto dstForInst = cast<AffineForOp>(dstNode->op);
|
depthSliceUnions, maxLegalFusionDepth,
|
||||||
// Update operation position of fused loop nest (if needed).
|
&bestDstLoopDepth, computeToleranceThreshold))
|
||||||
if (insertPointInst != dstForInst.getOperation()) {
|
continue;
|
||||||
dstForInst.getOperation()->moveBefore(insertPointInst);
|
|
||||||
}
|
|
||||||
// Update data dependence graph state post fusion.
|
|
||||||
updateStateAfterSiblingFusion(sliceLoopNest, sibNode, dstNode);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
|
||||||
|
assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
|
||||||
|
"Fusion depth has no computed slice union");
|
||||||
|
|
||||||
|
// Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
|
||||||
|
mlir::fuseLoops(sibAffineForOp, dstAffineForOp,
|
||||||
|
depthSliceUnions[bestDstLoopDepth - 1]);
|
||||||
|
|
||||||
|
auto dstForInst = cast<AffineForOp>(dstNode->op);
|
||||||
|
// Update operation position of fused loop nest (if needed).
|
||||||
|
if (insertPointInst != dstForInst.getOperation()) {
|
||||||
|
dstForInst.getOperation()->moveBefore(insertPointInst);
|
||||||
|
}
|
||||||
|
// Update data dependence graph state post fusion.
|
||||||
|
updateStateAfterSiblingFusion(sibNode, dstNode);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1943,19 +1863,12 @@ public:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void updateStateAfterSiblingFusion(AffineForOp sliceLoopNest, Node *sibNode,
|
/// Update data dependence graph state to reflect sibling fusion of 'sibNode'
|
||||||
Node *dstNode) {
|
/// into 'dstNode'.
|
||||||
|
void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
|
||||||
// Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
|
// Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
|
||||||
mdg->updateEdges(sibNode->id, dstNode->id);
|
mdg->updateEdges(sibNode->id, dstNode->id);
|
||||||
|
|
||||||
// Collect slice loop stats.
|
|
||||||
LoopNestStateCollector sliceCollector;
|
|
||||||
sliceCollector.collect(sliceLoopNest.getOperation());
|
|
||||||
// Promote single iteration slice loops to single IV value.
|
|
||||||
for (auto forOp : sliceCollector.forOps) {
|
|
||||||
promoteIfSingleIteration(forOp);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Collect dst loop stats after memref privatization transformation.
|
// Collect dst loop stats after memref privatization transformation.
|
||||||
auto dstForInst = cast<AffineForOp>(dstNode->op);
|
auto dstForInst = cast<AffineForOp>(dstNode->op);
|
||||||
LoopNestStateCollector dstLoopCollector;
|
LoopNestStateCollector dstLoopCollector;
|
||||||
|
|
|
@ -47,9 +47,9 @@ static void getLoadAndStoreMemRefAccesses(Operation *opA,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns true if 'op' is a load or store operation which access an memref
|
/// Returns true if 'op' is a load or store operation which access a memref
|
||||||
// accessed 'values' and at least one of the access is a store operation.
|
/// accessed 'values' and at least one of the access is a store operation.
|
||||||
// Returns false otherwise.
|
/// Returns false otherwise.
|
||||||
static bool isDependentLoadOrStoreOp(Operation *op,
|
static bool isDependentLoadOrStoreOp(Operation *op,
|
||||||
DenseMap<Value, bool> &values) {
|
DenseMap<Value, bool> &values) {
|
||||||
if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
|
if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
|
||||||
|
@ -187,26 +187,99 @@ gatherLoadsAndStores(AffineForOp forOp,
|
||||||
return !hasIfOp;
|
return !hasIfOp;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the maximum loop depth at which we could fuse producer loop
|
||||||
|
/// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences.
|
||||||
|
// TODO: Generalize this check for sibling and more generic fusion scenarios.
|
||||||
|
// TODO: Support forward slice fusion.
|
||||||
|
static unsigned getMaxLoopDepth(ArrayRef<Operation *> dstOps,
|
||||||
|
FusionStrategy fusionStrategy) {
|
||||||
|
assert(fusionStrategy.strategy == FusionStrategy::ProducerConsumer &&
|
||||||
|
"Fusion strategy not supported");
|
||||||
|
|
||||||
|
if (dstOps.empty())
|
||||||
|
// Expected at least one memory operation.
|
||||||
|
// TODO: Revisit this case with a specific example.
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
// Filter out ops in 'dstOps' that do not use the producer-consumer memref so
|
||||||
|
// that they are not considered for analysis.
|
||||||
|
// TODO: Currently, we pass the producer-consumer memref through
|
||||||
|
// fusionStrategy. We will retrieve the memrefs from 'srcOps' once we
|
||||||
|
// generalize the algorithm.
|
||||||
|
SmallVector<Operation *, 4> targetDstOps;
|
||||||
|
for (Operation *dstOp : dstOps) {
|
||||||
|
auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
|
||||||
|
Value memref = loadOp ? loadOp.getMemRef()
|
||||||
|
: cast<AffineWriteOpInterface>(dstOp).getMemRef();
|
||||||
|
if (memref == fusionStrategy.memref)
|
||||||
|
targetDstOps.push_back(dstOp);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(!targetDstOps.empty() &&
|
||||||
|
"No dependences between 'srcForOp' and 'dstForOp'?");
|
||||||
|
|
||||||
|
// Compute the innermost common loop depth for loads and stores.
|
||||||
|
unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);
|
||||||
|
|
||||||
|
// Return common loop depth for loads if there are no store ops.
|
||||||
|
if (all_of(targetDstOps,
|
||||||
|
[&](Operation *op) { return isa<AffineReadOpInterface>(op); }))
|
||||||
|
return loopDepth;
|
||||||
|
|
||||||
|
// Check dependences on all pairs of ops in 'targetDstOps' and store the
|
||||||
|
// minimum loop depth at which a dependence is satisfied.
|
||||||
|
for (unsigned i = 0, e = targetDstOps.size(); i < e; ++i) {
|
||||||
|
auto *srcOpInst = targetDstOps[i];
|
||||||
|
MemRefAccess srcAccess(srcOpInst);
|
||||||
|
for (unsigned j = 0; j < e; ++j) {
|
||||||
|
auto *dstOpInst = targetDstOps[j];
|
||||||
|
MemRefAccess dstAccess(dstOpInst);
|
||||||
|
|
||||||
|
unsigned numCommonLoops =
|
||||||
|
getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
|
||||||
|
for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
|
||||||
|
FlatAffineConstraints dependenceConstraints;
|
||||||
|
// TODO: Cache dependence analysis results, check cache here.
|
||||||
|
DependenceResult result = checkMemrefAccessDependence(
|
||||||
|
srcAccess, dstAccess, d, &dependenceConstraints,
|
||||||
|
/*dependenceComponents=*/nullptr);
|
||||||
|
if (hasDependence(result)) {
|
||||||
|
// Store minimum loop depth and break because we want the min 'd' at
|
||||||
|
// which there is a dependence.
|
||||||
|
loopDepth = std::min(loopDepth, d - 1);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return loopDepth;
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Prevent fusion of loop nests with side-effecting operations.
|
// TODO: Prevent fusion of loop nests with side-effecting operations.
|
||||||
|
// TODO: This pass performs some computation that is the same for all the depths
|
||||||
|
// (e.g., getMaxLoopDepth). Implement a version of this utility that processes
|
||||||
|
// all the depths at once or only the legal maximal depth for maximal fusion.
|
||||||
FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
|
FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
|
||||||
unsigned dstLoopDepth,
|
unsigned dstLoopDepth,
|
||||||
ComputationSliceState *srcSlice) {
|
ComputationSliceState *srcSlice,
|
||||||
|
FusionStrategy fusionStrategy) {
|
||||||
// Return 'failure' if 'dstLoopDepth == 0'.
|
// Return 'failure' if 'dstLoopDepth == 0'.
|
||||||
if (dstLoopDepth == 0) {
|
if (dstLoopDepth == 0) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n.");
|
LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
|
||||||
return FusionResult::FailPrecondition;
|
return FusionResult::FailPrecondition;
|
||||||
}
|
}
|
||||||
// Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
|
// Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
|
||||||
auto *block = srcForOp.getOperation()->getBlock();
|
auto *block = srcForOp.getOperation()->getBlock();
|
||||||
if (block != dstForOp.getOperation()->getBlock()) {
|
if (block != dstForOp.getOperation()->getBlock()) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n.");
|
LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
|
||||||
return FusionResult::FailPrecondition;
|
return FusionResult::FailPrecondition;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return 'failure' if no valid insertion point for fused loop nest in 'block'
|
// Return 'failure' if no valid insertion point for fused loop nest in 'block'
|
||||||
// exists which would preserve dependences.
|
// exists which would preserve dependences.
|
||||||
if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
|
if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n.");
|
LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
|
||||||
return FusionResult::FailBlockDependence;
|
return FusionResult::FailBlockDependence;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -220,25 +293,68 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
|
||||||
// Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'.
|
// Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'.
|
||||||
SmallVector<Operation *, 4> opsA;
|
SmallVector<Operation *, 4> opsA;
|
||||||
if (!gatherLoadsAndStores(forOpA, opsA)) {
|
if (!gatherLoadsAndStores(forOpA, opsA)) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n.");
|
LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
|
||||||
return FusionResult::FailPrecondition;
|
return FusionResult::FailPrecondition;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'.
|
// Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'.
|
||||||
SmallVector<Operation *, 4> opsB;
|
SmallVector<Operation *, 4> opsB;
|
||||||
if (!gatherLoadsAndStores(forOpB, opsB)) {
|
if (!gatherLoadsAndStores(forOpB, opsB)) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n.");
|
LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
|
||||||
return FusionResult::FailPrecondition;
|
return FusionResult::FailPrecondition;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return 'failure' if fusing loops at depth 'dstLoopDepth' wouldn't preserve
|
||||||
|
// loop dependences.
|
||||||
|
// TODO: Enable this check for sibling and more generic loop fusion
|
||||||
|
// strategies.
|
||||||
|
if (fusionStrategy.strategy == FusionStrategy::ProducerConsumer) {
|
||||||
|
// TODO: 'getMaxLoopDepth' does not support forward slice fusion.
|
||||||
|
assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
|
||||||
|
if (getMaxLoopDepth(opsB, fusionStrategy) < dstLoopDepth) {
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
|
||||||
|
return FusionResult::FailFusionDependence;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'.
|
// Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'.
|
||||||
unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops(
|
unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops(
|
||||||
*srcForOp.getOperation(), *dstForOp.getOperation());
|
*srcForOp.getOperation(), *dstForOp.getOperation());
|
||||||
|
|
||||||
|
// Filter out ops in 'opsA' to compute the slice union based on the
|
||||||
|
// assumptions made by the fusion strategy.
|
||||||
|
SmallVector<Operation *, 4> strategyOpsA;
|
||||||
|
switch (fusionStrategy.strategy) {
|
||||||
|
case FusionStrategy::Generic:
|
||||||
|
// Generic fusion. Take into account all the memory operations to compute
|
||||||
|
// the slice union.
|
||||||
|
strategyOpsA.append(opsA.begin(), opsA.end());
|
||||||
|
break;
|
||||||
|
case FusionStrategy::ProducerConsumer:
|
||||||
|
// Producer-consumer fusion (AffineLoopFusion pass) only takes into
|
||||||
|
// account stores to 'memref' in 'srcForOp' to compute the slice union.
|
||||||
|
for (Operation *op : opsA) {
|
||||||
|
auto store = dyn_cast<AffineWriteOpInterface>(op);
|
||||||
|
if (store && store.getMemRef() == fusionStrategy.memref)
|
||||||
|
strategyOpsA.push_back(op);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case FusionStrategy::Sibling:
|
||||||
|
// Sibling fusion (AffineLoopFusion pass) only takes into account the loads
|
||||||
|
// to 'memref' in 'srcForOp' to compute the slice union.
|
||||||
|
for (Operation *op : opsA) {
|
||||||
|
auto load = dyn_cast<AffineReadOpInterface>(op);
|
||||||
|
if (load && load.getMemRef() == fusionStrategy.memref)
|
||||||
|
strategyOpsA.push_back(op);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
// Compute union of computation slices computed between all pairs of ops
|
// Compute union of computation slices computed between all pairs of ops
|
||||||
// from 'forOpA' and 'forOpB'.
|
// from 'forOpA' and 'forOpB'.
|
||||||
if (failed(mlir::computeSliceUnion(opsA, opsB, dstLoopDepth, numCommonLoops,
|
if (failed(mlir::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth,
|
||||||
isSrcForOpBeforeDstForOp, srcSlice))) {
|
numCommonLoops, isSrcForOpBeforeDstForOp,
|
||||||
|
srcSlice))) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
|
LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
|
||||||
return FusionResult::FailPrecondition;
|
return FusionResult::FailPrecondition;
|
||||||
}
|
}
|
||||||
|
@ -249,24 +365,30 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
|
||||||
/// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
|
/// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
|
||||||
/// and source slice loop bounds specified in 'srcSlice'.
|
/// and source slice loop bounds specified in 'srcSlice'.
|
||||||
void mlir::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
|
void mlir::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
|
||||||
ComputationSliceState *srcSlice) {
|
const ComputationSliceState &srcSlice) {
|
||||||
// Clone 'srcForOp' into 'dstForOp' at 'srcSlice->insertPoint'.
|
// Clone 'srcForOp' into 'dstForOp' at 'srcSlice->insertPoint'.
|
||||||
OpBuilder b(srcSlice->insertPoint->getBlock(), srcSlice->insertPoint);
|
OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint);
|
||||||
BlockAndValueMapping mapper;
|
BlockAndValueMapping mapper;
|
||||||
b.clone(*srcForOp, mapper);
|
b.clone(*srcForOp, mapper);
|
||||||
|
|
||||||
// Update 'sliceLoopNest' upper and lower bounds from computed 'srcSlice'.
|
// Update 'sliceLoopNest' upper and lower bounds from computed 'srcSlice'.
|
||||||
SmallVector<AffineForOp, 4> sliceLoops;
|
SmallVector<AffineForOp, 4> sliceLoops;
|
||||||
for (unsigned i = 0, e = srcSlice->ivs.size(); i < e; ++i) {
|
for (unsigned i = 0, e = srcSlice.ivs.size(); i < e; ++i) {
|
||||||
auto loopIV = mapper.lookupOrNull(srcSlice->ivs[i]);
|
auto loopIV = mapper.lookupOrNull(srcSlice.ivs[i]);
|
||||||
if (!loopIV)
|
if (!loopIV)
|
||||||
continue;
|
continue;
|
||||||
auto forOp = getForInductionVarOwner(loopIV);
|
auto forOp = getForInductionVarOwner(loopIV);
|
||||||
sliceLoops.push_back(forOp);
|
sliceLoops.push_back(forOp);
|
||||||
if (AffineMap lbMap = srcSlice->lbs[i])
|
if (AffineMap lbMap = srcSlice.lbs[i]) {
|
||||||
forOp.setLowerBound(srcSlice->lbOperands[i], lbMap);
|
auto lbOperands = srcSlice.lbOperands[i];
|
||||||
if (AffineMap ubMap = srcSlice->ubs[i])
|
canonicalizeMapAndOperands(&lbMap, &lbOperands);
|
||||||
forOp.setUpperBound(srcSlice->ubOperands[i], ubMap);
|
forOp.setLowerBound(lbOperands, lbMap);
|
||||||
|
}
|
||||||
|
if (AffineMap ubMap = srcSlice.ubs[i]) {
|
||||||
|
auto ubOperands = srcSlice.ubOperands[i];
|
||||||
|
canonicalizeMapAndOperands(&ubMap, &ubOperands);
|
||||||
|
forOp.setUpperBound(ubOperands, ubMap);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Promote any single iteration slice loops.
|
// Promote any single iteration slice loops.
|
||||||
|
@ -393,15 +515,15 @@ static uint64_t getSliceIterationCount(
|
||||||
// was encountered).
|
// was encountered).
|
||||||
// TODO: Make this work with non-unit step loops.
|
// TODO: Make this work with non-unit step loops.
|
||||||
static bool buildSliceTripCountMap(
|
static bool buildSliceTripCountMap(
|
||||||
ComputationSliceState *slice,
|
const ComputationSliceState &slice,
|
||||||
llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
|
llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
|
||||||
unsigned numSrcLoopIVs = slice->ivs.size();
|
unsigned numSrcLoopIVs = slice.ivs.size();
|
||||||
// Populate map from AffineForOp -> trip count
|
// Populate map from AffineForOp -> trip count
|
||||||
for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
|
for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
|
||||||
AffineForOp forOp = getForInductionVarOwner(slice->ivs[i]);
|
AffineForOp forOp = getForInductionVarOwner(slice.ivs[i]);
|
||||||
auto *op = forOp.getOperation();
|
auto *op = forOp.getOperation();
|
||||||
AffineMap lbMap = slice->lbs[i];
|
AffineMap lbMap = slice.lbs[i];
|
||||||
AffineMap ubMap = slice->ubs[i];
|
AffineMap ubMap = slice.ubs[i];
|
||||||
if (lbMap == AffineMap() || ubMap == AffineMap()) {
|
if (lbMap == AffineMap() || ubMap == AffineMap()) {
|
||||||
// The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
|
// The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
|
||||||
if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
|
if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
|
||||||
|
@ -442,7 +564,7 @@ int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
|
||||||
/// the entire loop nest.
|
/// the entire loop nest.
|
||||||
bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
|
bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
|
||||||
AffineForOp dstForOp, LoopNestStats &dstStats,
|
AffineForOp dstForOp, LoopNestStats &dstStats,
|
||||||
ComputationSliceState *slice,
|
const ComputationSliceState &slice,
|
||||||
int64_t *computeCost) {
|
int64_t *computeCost) {
|
||||||
llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
|
llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
|
||||||
DenseMap<Operation *, int64_t> computeCostMap;
|
DenseMap<Operation *, int64_t> computeCostMap;
|
||||||
|
@ -454,7 +576,7 @@ bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
|
||||||
int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
|
int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
|
||||||
assert(sliceIterationCount > 0);
|
assert(sliceIterationCount > 0);
|
||||||
bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
|
bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
|
||||||
auto *insertPointParent = slice->insertPoint->getParentOp();
|
auto *insertPointParent = slice.insertPoint->getParentOp();
|
||||||
|
|
||||||
// The store and loads to this memref will disappear.
|
// The store and loads to this memref will disappear.
|
||||||
// TODO: Add load coalescing to memref data flow opt pass.
|
// TODO: Add load coalescing to memref data flow opt pass.
|
||||||
|
|
|
@ -129,7 +129,7 @@ static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB,
|
||||||
mlir::ComputationSliceState sliceUnion;
|
mlir::ComputationSliceState sliceUnion;
|
||||||
FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
|
FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
|
||||||
if (result.value == FusionResult::Success) {
|
if (result.value == FusionResult::Success) {
|
||||||
mlir::fuseLoops(forOpA, forOpB, &sliceUnion);
|
mlir::fuseLoops(forOpA, forOpB, sliceUnion);
|
||||||
// Note: 'forOpA' is removed to simplify test output. A proper loop
|
// Note: 'forOpA' is removed to simplify test output. A proper loop
|
||||||
// fusion pass should check the data dependence graph and run memref
|
// fusion pass should check the data dependence graph and run memref
|
||||||
// region analysis to ensure removing 'forOpA' is safe.
|
// region analysis to ensure removing 'forOpA' is safe.
|
||||||
|
|
Loading…
Reference in New Issue