forked from OSchip/llvm-project
Some loop fusion code cleanup/simplification post cl/229575126
- enforce the assumptions better / in a simpler way PiperOrigin-RevId: 229612424
This commit is contained in:
parent
3766332533
commit
c1ca23ef6e
|
@ -465,8 +465,8 @@ static uint64_t getComputeCost(
|
|||
} // end anonymous namespace
|
||||
|
||||
static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
|
||||
assert(lbMap.getNumResults() == 1);
|
||||
assert(ubMap.getNumResults() == 1);
|
||||
assert(lbMap.getNumResults() == 1 && "expected single result bound map");
|
||||
assert(ubMap.getNumResults() == 1 && "expected single result bound map");
|
||||
assert(lbMap.getNumDims() == ubMap.getNumDims());
|
||||
assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
|
||||
// TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
|
||||
|
@ -560,33 +560,16 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
|
|||
return loopDepth;
|
||||
}
|
||||
|
||||
// Returns true if 'map' is a single result constant or single result
|
||||
// dim expr where its corresponding loop IV in 'operands' has zero constant
|
||||
// lower bound.
|
||||
static bool hasZeroMinValue(AffineMap map, ArrayRef<Value *> operands) {
|
||||
if (map.isSingleConstant() && map.getSingleConstantResult() == 0)
|
||||
return true;
|
||||
if (map.getNumResults() != 1 || !map.getResult(0).isa<AffineDimExpr>())
|
||||
return false;
|
||||
// Get operand position of single dim expr result.
|
||||
unsigned pos = map.getResult(0).cast<AffineDimExpr>().getPosition();
|
||||
// Check if loop IV at 'pos' has zero constant lower bound.
|
||||
auto *operand = operands[pos];
|
||||
assert(isa<ForInst>(operand));
|
||||
auto *forInst = cast<ForInst>(operand);
|
||||
return forInst->hasConstantLowerBound() &&
|
||||
forInst->getConstantLowerBound() == 0;
|
||||
}
|
||||
// Returns the slice bound union of 'sliceStateA' and 'sliceStateB' in
|
||||
// 'sliceStateB'.
|
||||
// Returns the slice union of 'sliceStateA' and 'sliceStateB' in 'sliceStateB'
|
||||
// using a rectangular bounding box.
|
||||
// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA'
|
||||
// and 'sliceStateB' are aligned.
|
||||
// Specifically, when taking the union of overlapping intervals, it assumes
|
||||
// that both intervals start at zero. Support needs to be added to take into
|
||||
// account interval start offset when computing the union.
|
||||
// TODO(andydavis) Move this function to an analysis library.
|
||||
static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
|
||||
ComputationSliceState *sliceStateB) {
|
||||
static bool getSliceUnion(const ComputationSliceState &sliceStateA,
|
||||
ComputationSliceState *sliceStateB) {
|
||||
assert(sliceStateA.lbs.size() == sliceStateB->lbs.size());
|
||||
assert(sliceStateA.ubs.size() == sliceStateB->ubs.size());
|
||||
|
||||
|
@ -597,10 +580,7 @@ static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
|
|||
assert(ubMapA == AffineMap::Null());
|
||||
continue;
|
||||
}
|
||||
assert(ubMapA != AffineMap::Null());
|
||||
// Validate that constant lower bounds are aligned at zero.
|
||||
if (!hasZeroMinValue(lbMapA, sliceStateA.lbOperands[i]))
|
||||
return false;
|
||||
assert(ubMapA && "expected non-null ub map");
|
||||
|
||||
AffineMap lbMapB = sliceStateB->lbs[i];
|
||||
AffineMap ubMapB = sliceStateB->ubs[i];
|
||||
|
@ -611,8 +591,13 @@ static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
|
|||
sliceStateB->ubs[i] = ubMapA;
|
||||
continue;
|
||||
}
|
||||
// Validate that constant lower bounds are aligned at zero.
|
||||
if (!hasZeroMinValue(lbMapB, sliceStateB->lbOperands[i]))
|
||||
|
||||
// TODO(andydavis) Change this code to take the min across all lower bounds
|
||||
// and max across all upper bounds for each dimension. This code can for
|
||||
// cases where a unique min or max could not be statically determined.
|
||||
|
||||
// Assumption: both lower bounds are the same.
|
||||
if (lbMapA != lbMapB)
|
||||
return false;
|
||||
|
||||
// Add bound with the largest trip count to union.
|
||||
|
@ -620,9 +605,7 @@ static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
|
|||
Optional<uint64_t> tripCountB = getConstDifference(lbMapB, ubMapB);
|
||||
if (!tripCountA.hasValue() || !tripCountB.hasValue())
|
||||
return false;
|
||||
// TODO(andydavis) Change this code to take the min across all lower bounds
|
||||
// and max across all upper bounds for each dimension. This code can for
|
||||
// cases where a unique min or max could not be statically determined.
|
||||
|
||||
if (tripCountA.getValue() > tripCountB.getValue()) {
|
||||
sliceStateB->lbs[i] = lbMapA;
|
||||
sliceStateB->ubs[i] = ubMapA;
|
||||
|
@ -720,7 +703,7 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
|
|||
&tmpSliceState))
|
||||
return false;
|
||||
// Compute slice boun dunion of 'tmpSliceState' and 'sliceStates[i - 1]'.
|
||||
getSliceBoundUnion(tmpSliceState, &sliceStates[i - 1]);
|
||||
getSliceUnion(tmpSliceState, &sliceStates[i - 1]);
|
||||
}
|
||||
// Build trip count map for computation slice.
|
||||
sliceTripCountMap.clear();
|
||||
|
|
Loading…
Reference in New Issue