Adds the ability to compute the MemRefRegion of a sliced loop nest. Utilizes this feature during loop fusion cost computation to compute what the write region of a fusion candidate loop nest slice would be, without having to materialize the slice or change the IR.

*) Adds a parameter to the public API of MemRefRegion::compute for passing in slice loop bounds, to compute the memref region of a loop nest slice.
*) Exposes the public method MemRefRegion::getRegionSize for computing the size of the memref region in bytes.

PiperOrigin-RevId: 232706165
commit b9dde91ea6 (parent 31f2b3ffa1)

@@ -61,6 +61,45 @@ void getLoopIVs(const Instruction &inst,
/// surrounding this instruction.
unsigned getNestingDepth(const Instruction &stmt);

/// ComputationSliceState aggregates loop bound AffineMaps and their associated
/// operands for a set of loops within a loop nest (typically the set of loops
/// surrounding a store operation). Loop bound AffineMaps which are non-null
/// represent slices of that loop's iteration space.
struct ComputationSliceState {
  // List of lower bound AffineMaps.
  SmallVector<AffineMap, 4> lbs;
  // List of upper bound AffineMaps.
  SmallVector<AffineMap, 4> ubs;
  // List of lower bound operands (lbOperands[i] are used by 'lbs[i]').
  std::vector<SmallVector<Value *, 4>> lbOperands;
  // List of upper bound operands (ubOperands[i] are used by 'ubs[i]').
  std::vector<SmallVector<Value *, 4>> ubOperands;
};

/// Computes computation slice loop bounds for the loop nest surrounding
/// 'srcAccess', where the returned loop bound AffineMaps are functions of
/// loop IVs from the loop nest surrounding 'dstAccess'.
/// Returns true on success, false otherwise.
bool getBackwardComputationSliceState(const MemRefAccess &srcAccess,
                                      const MemRefAccess &dstAccess,
                                      unsigned dstLoopDepth,
                                      ComputationSliceState *sliceState);

/// Creates a clone of the computation contained in the loop nest surrounding
/// 'srcOpInst', slices the iteration space of the src loop based on the slice
/// bounds in 'sliceState', and inserts the computation slice at the beginning
/// of the instruction block of the loop at 'dstLoopDepth' in the loop nest
/// surrounding 'dstOpInst'. Returns the top-level loop of the computation
/// slice on success, returns nullptr otherwise.
// Loop depth is a crucial optimization choice that determines where to
// materialize the results of the backward slice, presenting a trade-off
// between storage and redundant computation in several cases.
// TODO(andydavis) Support computation slices with common surrounding loops.
OpPointer<AffineForOp>
insertBackwardComputationSlice(Instruction *srcOpInst, Instruction *dstOpInst,
                               unsigned dstLoopDepth,
                               ComputationSliceState *sliceState);
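
A rough usage sketch of these two entry points (hypothetical 'srcAccess' and
'dstAccess' values; the 'opInst' field is assumed from the surrounding
MemRefAccess API):

  // Hypothetical sketch, not part of the patch: compute slice bounds for the
  // source nest as functions of the destination nest's IVs, then materialize
  // the slice at depth 2 of the destination nest.
  ComputationSliceState sliceState;
  if (getBackwardComputationSliceState(srcAccess, dstAccess,
                                       /*dstLoopDepth=*/2, &sliceState)) {
    auto sliceLoopNest = insertBackwardComputationSlice(
        srcAccess.opInst, dstAccess.opInst, /*dstLoopDepth=*/2, &sliceState);
    // 'sliceLoopNest' is the top-level loop of the inserted slice, or
    // nullptr on failure.
  }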

/// A region of a memref's data space; this is typically constructed by
/// analyzing load/store op's on this memref and the index space of loops
/// surrounding such op's.

@@ -86,7 +125,17 @@ struct MemRefRegion {
  /// symbolic identifiers which could include any of the loop IVs surrounding
  /// opInst up until 'loopDepth' and any additional Function symbols
  /// involved with the access (e.g., those appearing in affine_apply's, loop
  /// bounds, etc.).
  /// bounds, etc.). If 'sliceState' is non-null, operands from 'sliceState'
  /// are added as symbols, and the following constraints are added to the
  /// system:
  /// *) Inequality constraints which represent loop bounds for 'sliceState'
  ///    operands which are loop IVs (these represent the destination loop IVs
  ///    of the slice, and are added as symbols to MemRefRegion's constraint
  ///    system).
  /// *) Inequality constraints for the slice bounds in 'sliceState', which
  ///    represent the bounds on the loop IVs in this constraint system w.r.t.
  ///    the slice operands (which correspond to symbols).
  ///
  /// For example, the memref region for this operation at loopDepth = 1 will
  /// be:
  ///

@@ -99,7 +148,8 @@ struct MemRefRegion {
  /// {memref = %A, write = false, {%i <= m0 <= %i + 7} }
  /// The last field is a 2-d FlatAffineConstraints symbolic in %i.
  ///
  bool compute(Instruction *inst, unsigned loopDepth);
  bool compute(Instruction *inst, unsigned loopDepth,
               ComputationSliceState *sliceState = nullptr);

  FlatAffineConstraints *getConstraints() { return &cst; }
  const FlatAffineConstraints *getConstraints() const { return &cst; }

@@ -128,6 +178,9 @@ struct MemRefRegion {
    return cst.getConstantBoundOnDimSize(pos, lb);
  }

  /// Returns the size of this MemRefRegion in bytes.
  Optional<int64_t> getRegionSize();

  bool unionBoundingBox(const MemRefRegion &other);

  /// Returns the rank of the memref that this region corresponds to.

@@ -169,52 +222,12 @@ bool boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp,
unsigned getNumCommonSurroundingLoops(const Instruction &A,
                                      const Instruction &B);

/// ComputationSliceState aggregates loop bound AffineMaps and their associated
/// operands for a set of loops within a loop nest (typically the set of loops
/// surrounding a store operation). Loop bound AffineMaps which are non-null
/// represent slices of that loop's iteration space.
struct ComputationSliceState {
  // List of lower bound AffineMaps.
  SmallVector<AffineMap, 4> lbs;
  // List of upper bound AffineMaps.
  SmallVector<AffineMap, 4> ubs;
  // List of lower bound operands (lbOperands[i] are used by 'lbs[i]').
  std::vector<SmallVector<Value *, 4>> lbOperands;
  // List of upper bound operands (ubOperands[i] are used by 'ubs[i]').
  std::vector<SmallVector<Value *, 4>> ubOperands;
};

/// Computes computation slice loop bounds for the loop nest surrounding
/// 'srcAccess', where the returned loop bound AffineMaps are functions of
/// loop IVs from the loop nest surrounding 'dstAccess'.
/// Returns true on success, false otherwise.
bool getBackwardComputationSliceState(const MemRefAccess &srcAccess,
                                      const MemRefAccess &dstAccess,
                                      unsigned dstLoopDepth,
                                      ComputationSliceState *sliceState);

/// Creates a clone of the computation contained in the loop nest surrounding
/// 'srcOpInst', slices the iteration space of the src loop based on the slice
/// bounds in 'sliceState', and inserts the computation slice at the beginning
/// of the instruction block of the loop at 'dstLoopDepth' in the loop nest
/// surrounding 'dstOpInst'. Returns the top-level loop of the computation
/// slice on success, returns nullptr otherwise.
// Loop depth is a crucial optimization choice that determines where to
// materialize the results of the backward slice, presenting a trade-off
// between storage and redundant computation in several cases.
// TODO(andydavis) Support computation slices with common surrounding loops.
OpPointer<AffineForOp>
insertBackwardComputationSlice(Instruction *srcOpInst, Instruction *dstOpInst,
                               unsigned dstLoopDepth,
                               ComputationSliceState *sliceState);

/// Gets the memory footprint of all data touched in the specified memory space
/// in bytes; if the memory space is unspecified, considers all memory spaces.
Optional<int64_t> getMemoryFootprintBytes(ConstOpPointer<AffineForOp> forOp,
                                          int memorySpace = -1);
Optional<int64_t> getMemoryFootprintBytes(const Block &block,
                                          int memorySpace = -1);

} // end namespace mlir

#endif // MLIR_ANALYSIS_UTILS_H
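
Taken together, a sketch of how these utilities might be combined in a cost
model (illustrative only; 'storeOpInst', 'sliceState', and 'forOp' are assumed
to exist in the caller):

  // Write region of the un-sliced source nest (loopDepth = 0 spans the nest).
  MemRefRegion srcWriteRegion(storeOpInst->getLoc());
  srcWriteRegion.compute(storeOpInst, /*loopDepth=*/0);
  Optional<int64_t> srcBytes = srcWriteRegion.getRegionSize();

  // Write region the slice *would* have, without materializing any IR.
  MemRefRegion sliceWriteRegion(storeOpInst->getLoc());
  sliceWriteRegion.compute(storeOpInst, /*loopDepth=*/0, &sliceState);
  Optional<int64_t> sliceBytes = sliceWriteRegion.getRegionSize();

  // Total footprint of a loop nest in bytes, across all memory spaces.
  Optional<int64_t> footprint = getMemoryFootprintBytes(forOp);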

@@ -378,6 +378,17 @@ public:
                     SmallVectorImpl<AffineMap> *lbMaps,
                     SmallVectorImpl<AffineMap> *ubMaps);

  /// Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
  /// bounds in 'ubMaps' to the constraint system. Note that both lower/upper
  /// bounds share the same operand list 'operands'.
  /// This function assumes that 'lbMaps.size()' == 'ubMaps.size()', and that
  /// positions [0, lbMaps.size()) represent dimensional identifiers which
  /// correspond to the loop IVs whose iteration bounds are being sliced.
  /// Returns true on success, returns false for unimplemented cases.
  bool addSliceBounds(ArrayRef<AffineMap> lbMaps, ArrayRef<AffineMap> ubMaps,
                      ArrayRef<Value *> operands);
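  // Hypothetical illustration (not in the original header): slicing the IV at
  // position 0 to the half-open range [%i, %i + 8), with 'operands' = {%i},
  // adds inequalities equivalent to  %i <= iv0 <= %i + 7.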

  // Adds an inequality (>= 0) from the coefficients specified in inEq.
  void addInequality(ArrayRef<int64_t> inEq);
  // Adds an equality from the coefficients specified in eq.

@@ -122,7 +122,8 @@ bool MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
//
// TODO(bondhugula): extend this to any other memref dereferencing ops
// (dma_start, dma_wait).
bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth) {
bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth,
                           ComputationSliceState *sliceState) {
  assert((inst->isa<LoadOp>() || inst->isa<StoreOp>()) &&
         "load/store op expected");

@@ -147,18 +148,33 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth) {
  access.getAccessMap(&accessValueMap);
  AffineMap accessMap = accessValueMap.getAffineMap();

  unsigned numDims = accessMap.getNumDims();
  unsigned numSymbols = accessMap.getNumSymbols();
  unsigned numOperands = accessValueMap.getNumOperands();
  // Merge operands with slice operands.
  SmallVector<Value *, 4> operands;
  operands.resize(numOperands);
  for (unsigned i = 0; i < numOperands; ++i)
    operands[i] = accessValueMap.getOperand(i);

  if (sliceState != nullptr) {
    // Append slice operands to 'operands' as symbols.
    operands.append(sliceState->lbOperands[0].begin(),
                    sliceState->lbOperands[0].end());
    // Update 'numSymbols' to include the operands from 'sliceState'.
    numSymbols += sliceState->lbOperands[0].size();
  }

  // We'll first associate the dims and symbols of the access map to the dims
  // and symbols resp. of cst. This will change below once cst is
  // fully constructed.
  cst.reset(accessMap.getNumDims(), accessMap.getNumSymbols(), 0,
            accessValueMap.getOperands());
  cst.reset(numDims, numSymbols, 0, operands);

  // Add equality constraints.
  unsigned numDims = accessMap.getNumDims();
  unsigned numSymbols = accessMap.getNumSymbols();
  // Add inequalities for loop lower/upper bounds.
  for (unsigned i = 0; i < numDims + numSymbols; ++i) {
    if (auto loop = getForInductionVarOwner(accessValueMap.getOperand(i))) {
    auto *operand = operands[i];
    if (auto loop = getForInductionVarOwner(operand)) {
      // Note that cst can now have more dimensions than accessMap if the
      // bounds expressions involve outer loops or other symbols.
      // TODO(bondhugula): rewrite this to use getInstIndexSet; this way

@@ -167,7 +183,7 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth) {
        return false;
    } else {
      // Has to be a valid symbol.
      auto *symbol = accessValueMap.getOperand(i);
      auto *symbol = operand;
      assert(isValidSymbol(symbol));
      // Check if the symbol is a constant.
      if (auto *inst = symbol->getDefiningInst()) {

@@ -178,6 +194,33 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth) {
    }
  }

  // Add lower/upper bounds on loop IVs using bounds from 'sliceState'.
  if (sliceState != nullptr) {
    // Add dim and symbol slice operands.
    for (const auto &operand : sliceState->lbOperands[0]) {
      unsigned loc;
      if (!cst.findId(*operand, &loc)) {
        if (isValidSymbol(operand)) {
          cst.addSymbolId(cst.getNumSymbolIds(), const_cast<Value *>(operand));
          loc = cst.getNumDimIds() + cst.getNumSymbolIds() - 1;
          // Check if the symbol is a constant.
          if (auto *opInst = operand->getDefiningInst()) {
            if (auto constOp = opInst->dyn_cast<ConstantIndexOp>()) {
              cst.setIdToConstant(*operand, constOp->getValue());
            }
          }
        } else {
          cst.addDimId(cst.getNumDimIds(), const_cast<Value *>(operand));
          loc = cst.getNumDimIds() - 1;
        }
      }
    }
    // Add upper/lower bounds from 'sliceState' to 'cst'.
    if (!cst.addSliceBounds(sliceState->lbs, sliceState->ubs,
                            sliceState->lbOperands[0]))
      return false;
  }
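  // At this point, any 'sliceState' operands have been added to 'cst' (dst
  // loop IVs as dims, valid symbols as symbols), and the slice bound maps
  // have been converted into inequalities on the sliced IVs.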

  // Add access function equalities to connect loop IVs to data dimensions.
  if (!cst.composeMap(&accessValueMap)) {
    LLVM_DEBUG(llvm::dbgs() << "getMemRefRegion: compose affine map failed\n");

@@ -233,6 +276,32 @@ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  return llvm::divideCeil(sizeInBits, 8);
}

// Returns the size of the region in bytes.
Optional<int64_t> MemRefRegion::getRegionSize() {
  auto memRefType = memref->getType().cast<MemRefType>();

  auto layoutMaps = memRefType.getAffineMaps();
  if (layoutMaps.size() > 1 ||
      (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
    LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
    return None;
  }

  // Compute the extents of the buffer.
  Optional<int64_t> numElements = getConstantBoundingSizeAndShape();
  if (!numElements.hasValue()) {
    LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
    return None;
  }
  return getMemRefEltSizeInBytes(memRefType) * numElements.getValue();
}

/// Returns the size of memref data in bytes if it's statically shaped, None
/// otherwise. If the element of the memref has vector type, takes into account
/// size of the vector as well.

@@ -420,8 +489,6 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
// entire destination index set. Subtract out the dependent destination
// iterations from the destination index set and check for emptiness --- this
// is one solution.
// TODO(andydavis) Remove dependence on 'srcLoopDepth' here. Instead, project
// out the loop IVs we don't care about and produce a smaller slice.
OpPointer<AffineForOp> mlir::insertBackwardComputationSlice(
    Instruction *srcOpInst, Instruction *dstOpInst, unsigned dstLoopDepth,
    ComputationSliceState *sliceState) {

@@ -537,33 +604,6 @@ unsigned mlir::getNumCommonSurroundingLoops(const Instruction &A,
  return numCommonLoops;
}

// Returns the size of the region.
static Optional<int64_t> getRegionSize(const MemRefRegion &region) {
  auto *memref = region.memref;
  auto memRefType = memref->getType().cast<MemRefType>();

  auto layoutMaps = memRefType.getAffineMaps();
  if (layoutMaps.size() > 1 ||
      (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
    LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
    return false;
  }

  // Indices to use for the DmaStart op.
  // Indices for the original memref being DMAed from/to.
  SmallVector<Value *, 4> memIndices;
  // Indices for the faster buffer being DMAed into/from.
  SmallVector<Value *, 4> bufIndices;

  // Compute the extents of the buffer.
  Optional<int64_t> numElements = region.getConstantBoundingSizeAndShape();
  if (!numElements.hasValue()) {
    LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
    return None;
  }
  return getMemRefEltSizeInBytes(memRefType) * numElements.getValue();
}

Optional<int64_t>
mlir::getMemoryFootprintBytes(ConstOpPointer<AffineForOp> forOp,
                              int memorySpace) {

@@ -601,7 +641,7 @@ Optional<int64_t> mlir::getMemoryFootprintBytes(const Block &block,

  int64_t totalSizeInBytes = 0;
  for (const auto &region : regions) {
    auto size = getRegionSize(*region);
    auto size = region->getRegionSize();
    if (!size.hasValue())
      return None;
    totalSizeInBytes += size.getValue();

@@ -1129,6 +1129,66 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context,
  }
}

// Adds slice lower/upper bounds from 'lbMaps'/'ubMaps' to the constraint
// system. This function assumes that 'lbMaps.size()' == 'ubMaps.size()', and
// that positions [0, lbMaps.size()) represent dimensional identifiers which
// correspond to the loop IVs whose iteration bounds are being sliced.
// Note that both lower/upper bounds use operands from 'operands'.
// Returns true on success. Returns false for unimplemented cases such as
// semi-affine expressions or expressions with mod/floordiv.
bool FlatAffineConstraints::addSliceBounds(ArrayRef<AffineMap> lbMaps,
                                           ArrayRef<AffineMap> ubMaps,
                                           ArrayRef<Value *> operands) {
  assert(lbMaps.size() == ubMaps.size());
  // Record positions of the operands in the constraint system.
  SmallVector<unsigned, 8> positions;
  for (const auto &operand : operands) {
    unsigned loc;
    if (!findId(*operand, &loc))
      assert(0 && "expected to be found");
    positions.push_back(loc);
  }

  auto addLowerOrUpperBound = [&](unsigned pos, AffineMap boundMap,
                                  bool lower) -> bool {
    FlatAffineConstraints localVarCst;
    std::vector<SmallVector<int64_t, 8>> flatExprs;
    if (!getFlattenedAffineExprs(boundMap, &flatExprs, &localVarCst)) {
      LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n");
      return false;
    }
    if (localVarCst.getNumLocalIds() > 0) {
      LLVM_DEBUG(llvm::dbgs()
                 << "loop bounds with mod/floordiv expr's not yet supported\n");
      return false;
    }

    for (const auto &flatExpr : flatExprs) {
      SmallVector<int64_t, 4> ineq(getNumCols(), 0);
      ineq[pos] = lower ? 1 : -1;
      for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) {
        ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j];
      }
      // Constant term.
      ineq[getNumCols() - 1] =
          lower ? -flatExpr[flatExpr.size() - 1]
                // Upper bound in flattenedExpr is an exclusive one.
                : flatExpr[flatExpr.size() - 1] - 1;
      addInequality(ineq);
    }
    return true;
  };

  for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
    if (!addLowerOrUpperBound(i, lbMaps[i], /*lower=*/true))
      return false;
    if (!addLowerOrUpperBound(i, ubMaps[i], /*lower=*/false))
      return false;
  }

  return true;
}
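
To make the flattened-inequality construction concrete, a small worked
instance (values invented for illustration, consistent with the region example
in the header):

// With columns (iv0, %i, const), slicing iv0 using lbMap (d0) -> (d0) and
// ubMap (d0) -> (d0 + 8) over operands = {%i} adds the rows:
//   { 1, -1, 0}   i.e.  iv0 - %i >= 0        (iv0 >= %i)
//   {-1,  1, 7}   i.e. -iv0 + %i + 7 >= 0    (iv0 <= %i + 7; the flattened
//                                             upper bound is exclusive)
// which is exactly the half-open slice range [%i, %i + 8).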

void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) {
  assert(eq.size() == getNumCols());
  unsigned offset = equalities.size();

@@ -1118,12 +1118,23 @@ static bool isFusionProfitable(Instruction *srcOpInst,
                     /*tripCountOverrideMap=*/nullptr,
                     /*computeCostMap=*/nullptr);

  // Compute src loop nest write region size.
  MemRefRegion srcWriteRegion(srcOpInst->getLoc());
  srcWriteRegion.compute(srcOpInst, /*loopDepth=*/0);
  Optional<int64_t> maybeSrcWriteRegionSizeBytes =
      srcWriteRegion.getRegionSize();
  if (!maybeSrcWriteRegionSizeBytes.hasValue())
    return false;
  int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue();

  // Compute op instance count for the dst loop nest.
  uint64_t dstLoopNestCost =
      getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats,
                     /*tripCountOverrideMap=*/nullptr,
                     /*computeCostMap=*/nullptr);

  // Evaluate all depth choices for materializing the slice in the destination
  // loop nest.
  llvm::SmallDenseMap<Instruction *, uint64_t, 8> sliceTripCountMap;
  DenseMap<Instruction *, int64_t> computeCostMap;
  for (unsigned i = maxDstLoopDepth; i >= 1; --i) {

@@ -1187,11 +1198,21 @@ static bool isFusionProfitable(Instruction *srcOpInst,
            (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
        1;

    // TODO(bondhugula): This is an ugly approximation. Fix this by finding a
    // good way to calculate the footprint of the memref in the slice and
    // divide it by the total memory footprint of the fused computation.
    double storageReduction =
        static_cast<double>(srcLoopNestCost) / sliceIterationCount;
    // Compute what the slice write MemRefRegion would be, if the src loop
    // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop
    // nest at loop depth 'i'.
    MemRefRegion sliceWriteRegion(srcOpInst->getLoc());
    sliceWriteRegion.compute(srcOpInst, /*loopDepth=*/0, &sliceStates[i - 1]);
    Optional<int64_t> maybeSliceWriteRegionSizeBytes =
        sliceWriteRegion.getRegionSize();
    if (!maybeSliceWriteRegionSizeBytes.hasValue() ||
        maybeSliceWriteRegionSizeBytes.getValue() == 0)
      continue;
    int64_t sliceWriteRegionSizeBytes =
        maybeSliceWriteRegionSizeBytes.getValue();

    double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
                              static_cast<double>(sliceWriteRegionSizeBytes);
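
    // For intuition (illustrative numbers, not from the patch): if the full
    // src nest writes an 8192-byte region while the slice at this depth would
    // write a 1024-byte region, storageReduction = 8192 / 1024 = 8, so
    // insertion points that shrink the materialized buffer score higher.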

    LLVM_DEBUG({
      std::stringstream msg;

@@ -1219,12 +1240,7 @@ static bool isFusionProfitable(Instruction *srcOpInst,
        maxStorageReduction = storageReduction;
        bestDstLoopDepth = i;
        minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
        // TODO(bondhugula,andydavis): find a good way to compute the memory
        // footprint of the materialized slice.
        // Approximating this to the compute cost of the slice. This could be
        // an under-approximation or an over-approximation, but in many cases
        // accurate.
        sliceMemEstimate = sliceIterationCount;
        sliceMemEstimate = sliceWriteRegionSizeBytes;
      }
    }