forked from OSchip/llvm-project
Some loop fusion code cleanup/simplification post cl/229575126
- enforce the assumptions better / in a simpler way PiperOrigin-RevId: 229612424
This commit is contained in:
parent
3766332533
commit
c1ca23ef6e
|
@ -465,8 +465,8 @@ static uint64_t getComputeCost(
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
|
static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
|
||||||
assert(lbMap.getNumResults() == 1);
|
assert(lbMap.getNumResults() == 1 && "expected single result bound map");
|
||||||
assert(ubMap.getNumResults() == 1);
|
assert(ubMap.getNumResults() == 1 && "expected single result bound map");
|
||||||
assert(lbMap.getNumDims() == ubMap.getNumDims());
|
assert(lbMap.getNumDims() == ubMap.getNumDims());
|
||||||
assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
|
assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
|
||||||
// TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
|
// TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
|
||||||
|
@ -560,32 +560,15 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
|
||||||
return loopDepth;
|
return loopDepth;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns true if 'map' is a single result constant or single result
|
// Returns the slice union of 'sliceStateA' and 'sliceStateB' in 'sliceStateB'
|
||||||
// dim expr where its corresponding loop IV in 'operands' has zero constant
|
// using a rectangular bounding box.
|
||||||
// lower bound.
|
|
||||||
static bool hasZeroMinValue(AffineMap map, ArrayRef<Value *> operands) {
|
|
||||||
if (map.isSingleConstant() && map.getSingleConstantResult() == 0)
|
|
||||||
return true;
|
|
||||||
if (map.getNumResults() != 1 || !map.getResult(0).isa<AffineDimExpr>())
|
|
||||||
return false;
|
|
||||||
// Get operand position of single dim expr result.
|
|
||||||
unsigned pos = map.getResult(0).cast<AffineDimExpr>().getPosition();
|
|
||||||
// Check if loop IV at 'pos' has zero constant lower bound.
|
|
||||||
auto *operand = operands[pos];
|
|
||||||
assert(isa<ForInst>(operand));
|
|
||||||
auto *forInst = cast<ForInst>(operand);
|
|
||||||
return forInst->hasConstantLowerBound() &&
|
|
||||||
forInst->getConstantLowerBound() == 0;
|
|
||||||
}
|
|
||||||
// Returns the slice bound union of 'sliceStateA' and 'sliceStateB' in
|
|
||||||
// 'sliceStateB'.
|
|
||||||
// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA'
|
// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA'
|
||||||
// and 'sliceStateB' are aligned.
|
// and 'sliceStateB' are aligned.
|
||||||
// Specifically, when taking the union of overlapping intervals, it assumes
|
// Specifically, when taking the union of overlapping intervals, it assumes
|
||||||
// that both intervals start at zero. Support needs to be added to take into
|
// that both intervals start at zero. Support needs to be added to take into
|
||||||
// account interval start offset when computing the union.
|
// account interval start offset when computing the union.
|
||||||
// TODO(andydavis) Move this function to an analysis library.
|
// TODO(andydavis) Move this function to an analysis library.
|
||||||
static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
|
static bool getSliceUnion(const ComputationSliceState &sliceStateA,
|
||||||
ComputationSliceState *sliceStateB) {
|
ComputationSliceState *sliceStateB) {
|
||||||
assert(sliceStateA.lbs.size() == sliceStateB->lbs.size());
|
assert(sliceStateA.lbs.size() == sliceStateB->lbs.size());
|
||||||
assert(sliceStateA.ubs.size() == sliceStateB->ubs.size());
|
assert(sliceStateA.ubs.size() == sliceStateB->ubs.size());
|
||||||
|
@ -597,10 +580,7 @@ static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
|
||||||
assert(ubMapA == AffineMap::Null());
|
assert(ubMapA == AffineMap::Null());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
assert(ubMapA != AffineMap::Null());
|
assert(ubMapA && "expected non-null ub map");
|
||||||
// Validate that constant lower bounds are aligned at zero.
|
|
||||||
if (!hasZeroMinValue(lbMapA, sliceStateA.lbOperands[i]))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
AffineMap lbMapB = sliceStateB->lbs[i];
|
AffineMap lbMapB = sliceStateB->lbs[i];
|
||||||
AffineMap ubMapB = sliceStateB->ubs[i];
|
AffineMap ubMapB = sliceStateB->ubs[i];
|
||||||
|
@ -611,8 +591,13 @@ static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
|
||||||
sliceStateB->ubs[i] = ubMapA;
|
sliceStateB->ubs[i] = ubMapA;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// Validate that constant lower bounds are aligned at zero.
|
|
||||||
if (!hasZeroMinValue(lbMapB, sliceStateB->lbOperands[i]))
|
// TODO(andydavis) Change this code to take the min across all lower bounds
|
||||||
|
// and max across all upper bounds for each dimension. This code can for
|
||||||
|
// cases where a unique min or max could not be statically determined.
|
||||||
|
|
||||||
|
// Assumption: both lower bounds are the same.
|
||||||
|
if (lbMapA != lbMapB)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
// Add bound with the largest trip count to union.
|
// Add bound with the largest trip count to union.
|
||||||
|
@ -620,9 +605,7 @@ static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
|
||||||
Optional<uint64_t> tripCountB = getConstDifference(lbMapB, ubMapB);
|
Optional<uint64_t> tripCountB = getConstDifference(lbMapB, ubMapB);
|
||||||
if (!tripCountA.hasValue() || !tripCountB.hasValue())
|
if (!tripCountA.hasValue() || !tripCountB.hasValue())
|
||||||
return false;
|
return false;
|
||||||
// TODO(andydavis) Change this code to take the min across all lower bounds
|
|
||||||
// and max across all upper bounds for each dimension. This code can for
|
|
||||||
// cases where a unique min or max could not be statically determined.
|
|
||||||
if (tripCountA.getValue() > tripCountB.getValue()) {
|
if (tripCountA.getValue() > tripCountB.getValue()) {
|
||||||
sliceStateB->lbs[i] = lbMapA;
|
sliceStateB->lbs[i] = lbMapA;
|
||||||
sliceStateB->ubs[i] = ubMapA;
|
sliceStateB->ubs[i] = ubMapA;
|
||||||
|
@ -720,7 +703,7 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
|
||||||
&tmpSliceState))
|
&tmpSliceState))
|
||||||
return false;
|
return false;
|
||||||
// Compute slice boun dunion of 'tmpSliceState' and 'sliceStates[i - 1]'.
|
// Compute slice boun dunion of 'tmpSliceState' and 'sliceStates[i - 1]'.
|
||||||
getSliceBoundUnion(tmpSliceState, &sliceStates[i - 1]);
|
getSliceUnion(tmpSliceState, &sliceStates[i - 1]);
|
||||||
}
|
}
|
||||||
// Build trip count map for computation slice.
|
// Build trip count map for computation slice.
|
||||||
sliceTripCountMap.clear();
|
sliceTripCountMap.clear();
|
||||||
|
|
Loading…
Reference in New Issue