forked from OSchip/llvm-project
[MLIR][Affine] Add utility to check if the slice is valid
Fixes a bug in affine fusion pipeline where an incorrect slice is computed. After the slice computation is done, original domain of the the source is compared with the new domain that will result if the fusion succeeds. If the new domain must be a subset of the original domain for the slice to be valid. If the slice computed is incorrect, fusion based on such a slice is avoided. Relevant test cases are added/edited. Fixes https://bugs.llvm.org/show_bug.cgi?id=49203 Differential Revision: https://reviews.llvm.org/D98239
This commit is contained in:
parent
b468f0e165
commit
dc537158d5
|
@ -54,6 +54,18 @@ unsigned getNestingDepth(Operation *op);
|
|||
void getSequentialLoops(AffineForOp forOp,
|
||||
llvm::SmallDenseSet<Value, 8> *sequentialLoops);
|
||||
|
||||
/// Enumerates different result statuses of slice computation by
|
||||
/// `computeSliceUnion`
|
||||
// TODO: Identify and add different kinds of failures during slice computation.
|
||||
struct SliceComputationResult {
|
||||
enum ResultEnum {
|
||||
Success,
|
||||
IncorrectSliceFailure, // Slice is computed, but it is incorrect.
|
||||
GenericFailure, // Unable to compute src loop computation slice.
|
||||
} value;
|
||||
SliceComputationResult(ResultEnum v) : value(v) {}
|
||||
};
|
||||
|
||||
/// ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their
|
||||
/// associated operands for a set of loops within a loop nest (typically the
|
||||
/// set of loops surrounding a store operation). Loop bound AffineMaps which
|
||||
|
@ -80,6 +92,12 @@ struct ComputationSliceState {
|
|||
// Returns failure if we cannot add loop bounds because of unsupported cases.
|
||||
LogicalResult getAsConstraints(FlatAffineConstraints *cst);
|
||||
|
||||
/// Adds to 'cst' constraints which represent the original loop bounds on
|
||||
/// 'ivs' in 'this'. This corresponds to the original domain of the loop nest
|
||||
/// from which the slice is being computed. Returns failure if we cannot add
|
||||
/// loop bounds because of unsupported cases.
|
||||
LogicalResult getSourceAsConstraints(FlatAffineConstraints &cst);
|
||||
|
||||
// Clears all bounds and operands in slice state.
|
||||
void clearBounds();
|
||||
|
||||
|
@ -93,6 +111,22 @@ struct ComputationSliceState {
|
|||
// information hasn't changed.
|
||||
Optional<bool> isMaximal() const;
|
||||
|
||||
/// Checks the validity of the slice computed. This is done using the
|
||||
/// following steps:
|
||||
/// 1. Get the new domain of the slice that would be created if fusion
|
||||
/// succeeds. This domain gets constructed with source loop IVS and
|
||||
/// destination loop IVS as dimensions.
|
||||
/// 2. Project out the dimensions of the destination loop from the domain
|
||||
/// above calculated in step(1) to express it purely in terms of the source
|
||||
/// loop IVs.
|
||||
/// 3. Calculate a set difference between the iterations of the new domain and
|
||||
/// the original domain of the source loop.
|
||||
/// If this difference is empty, the slice is declared to be valid. Otherwise,
|
||||
/// return false as it implies that the effective fusion results in at least
|
||||
/// one iteration of the slice that was not originally in the source's domain.
|
||||
/// If the validity cannot be determined, returns llvm:None.
|
||||
Optional<bool> isSliceValid();
|
||||
|
||||
void dump() const;
|
||||
|
||||
private:
|
||||
|
@ -151,21 +185,21 @@ void getComputationSliceState(Operation *depSourceOp, Operation *depSinkOp,
|
|||
ComputationSliceState *sliceState);
|
||||
|
||||
/// Computes in 'sliceUnion' the union of all slice bounds computed at
|
||||
/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'.
|
||||
/// The parameter 'numCommonLoops' is the number of loops common to the
|
||||
/// operations in 'opsA' and 'opsB'.
|
||||
/// If 'isBackwardSlice' is true, computes slice bounds for loop nest
|
||||
/// surrounding ops in 'opsA', as a function of IVs and symbols of loop nest
|
||||
/// surrounding ops in 'opsB' at 'loopDepth'.
|
||||
/// If 'isBackwardSlice' is false, computes slice bounds for loop nest
|
||||
/// surrounding ops in 'opsB', as a function of IVs and symbols of loop nest
|
||||
/// surrounding ops in 'opsA' at 'loopDepth'.
|
||||
/// Returns 'success' if union was computed, 'failure' otherwise.
|
||||
/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
|
||||
/// then verifies if it is valid. The parameter 'numCommonLoops' is the number
|
||||
/// of loops common to the operations in 'opsA' and 'opsB'. If 'isBackwardSlice'
|
||||
/// is true, computes slice bounds for loop nest surrounding ops in 'opsA', as a
|
||||
/// function of IVs and symbols of loop nest surrounding ops in 'opsB' at
|
||||
/// 'loopDepth'. If 'isBackwardSlice' is false, computes slice bounds for loop
|
||||
/// nest surrounding ops in 'opsB', as a function of IVs and symbols of loop
|
||||
/// nest surrounding ops in 'opsA' at 'loopDepth'. Returns
|
||||
/// 'SliceComputationResult::Success' if union was computed correctly, an
|
||||
/// appropriate 'failure' otherwise.
|
||||
// TODO: Change this API to take 'forOpA'/'forOpB'.
|
||||
LogicalResult computeSliceUnion(ArrayRef<Operation *> opsA,
|
||||
ArrayRef<Operation *> opsB, unsigned loopDepth,
|
||||
unsigned numCommonLoops, bool isBackwardSlice,
|
||||
ComputationSliceState *sliceUnion);
|
||||
SliceComputationResult
|
||||
computeSliceUnion(ArrayRef<Operation *> opsA, ArrayRef<Operation *> opsB,
|
||||
unsigned loopDepth, unsigned numCommonLoops,
|
||||
bool isBackwardSlice, ComputationSliceState *sliceUnion);
|
||||
|
||||
/// Creates a clone of the computation contained in the loop nest surrounding
|
||||
/// 'srcOpInst', slices the iteration space of src loop based on slice bounds
|
||||
|
|
|
@ -35,6 +35,7 @@ struct FusionResult {
|
|||
FailBlockDependence, // Fusion would violate another dependence in block.
|
||||
FailFusionDependence, // Fusion would reverse dependences between loops.
|
||||
FailComputationSlice, // Unable to compute src loop computation slice.
|
||||
FailIncorrectSlice, // Slice is computed, but it is incorrect.
|
||||
} value;
|
||||
FusionResult(ResultEnum v) : value(v) {}
|
||||
};
|
||||
|
|
|
@ -2128,13 +2128,22 @@ LogicalResult FlatAffineConstraints::addSliceBounds(ArrayRef<Value> values,
|
|||
continue;
|
||||
}
|
||||
|
||||
if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
|
||||
/*lower=*/true)))
|
||||
return failure();
|
||||
|
||||
if (ubMap && failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false,
|
||||
/*lower=*/false)))
|
||||
return failure();
|
||||
// If lower or upper bound maps are null or provide no results, it implies
|
||||
// that the source loop was not at all sliced, and the entire loop will be a
|
||||
// part of the slice.
|
||||
if (lbMap && lbMap.getNumResults() != 0 && ubMap &&
|
||||
ubMap.getNumResults() != 0) {
|
||||
if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
|
||||
/*lower=*/true)))
|
||||
return failure();
|
||||
if (failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false,
|
||||
/*lower=*/false)))
|
||||
return failure();
|
||||
} else {
|
||||
auto loop = getForInductionVarOwner(values[i]);
|
||||
if (failed(this->addAffineForOpDomain(loop)))
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -61,6 +61,21 @@ void mlir::getEnclosingAffineForAndIfOps(Operation &op,
|
|||
std::reverse(ops->begin(), ops->end());
|
||||
}
|
||||
|
||||
// Populates 'cst' with FlatAffineConstraints which represent original domain of
|
||||
// the loop bounds that define 'ivs'.
|
||||
LogicalResult
|
||||
ComputationSliceState::getSourceAsConstraints(FlatAffineConstraints &cst) {
|
||||
assert(!ivs.empty() && "Cannot have a slice without its IVs");
|
||||
cst.reset(/*numDims=*/ivs.size(), /*numSymbols=*/0, /*numLocals=*/0, ivs);
|
||||
for (Value iv : ivs) {
|
||||
AffineForOp loop = getForInductionVarOwner(iv);
|
||||
assert(loop && "Expected affine for");
|
||||
if (failed(cst.addAffineForOpDomain(loop)))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// Populates 'cst' with FlatAffineConstraints which represent slice bounds.
|
||||
LogicalResult
|
||||
ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
|
||||
|
@ -75,9 +90,10 @@ ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
|
|||
values.append(lbOperands[0].begin(), lbOperands[0].end());
|
||||
cst->reset(numDims, numSymbols, 0, values);
|
||||
|
||||
// Add loop bound constraints for values which are loop IVs and equality
|
||||
// constraints for symbols which are constants.
|
||||
for (const auto &value : values) {
|
||||
// Add loop bound constraints for values which are loop IVs of the destination
|
||||
// of fusion and equality constraints for symbols which are constants.
|
||||
for (unsigned i = numDims, end = values.size(); i < end; ++i) {
|
||||
Value value = values[i];
|
||||
assert(cst->containsId(value) && "value expected to be present");
|
||||
if (isValidSymbol(value)) {
|
||||
// Check if the symbol is a constant.
|
||||
|
@ -196,6 +212,76 @@ Optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
|
|||
return true;
|
||||
}
|
||||
|
||||
/// Returns true if it is deterministically verified that the original iteration
|
||||
/// space of the slice is contained within the new iteration space that is
|
||||
/// created after fusing 'this' slice into its destination.
|
||||
Optional<bool> ComputationSliceState::isSliceValid() {
|
||||
// Fast check to determine if the slice is valid. If the following conditions
|
||||
// are verified to be true, slice is declared valid by the fast check:
|
||||
// 1. Each slice loop is a single iteration loop bound in terms of a single
|
||||
// destination loop IV.
|
||||
// 2. Loop bounds of the destination loop IV (from above) and those of the
|
||||
// source loop IV are exactly the same.
|
||||
// If the fast check is inconclusive or false, we proceed with a more
|
||||
// expensive analysis.
|
||||
// TODO: Store the result of the fast check, as it might be used again in
|
||||
// `canRemoveSrcNodeAfterFusion`.
|
||||
Optional<bool> isValidFastCheck = isSliceMaximalFastCheck();
|
||||
if (isValidFastCheck.hasValue() && isValidFastCheck.getValue())
|
||||
return true;
|
||||
|
||||
// Create constraints for the source loop nest using which slice is computed.
|
||||
FlatAffineConstraints srcConstraints;
|
||||
// TODO: Store the source's domain to avoid computation at each depth.
|
||||
if (failed(getSourceAsConstraints(srcConstraints))) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n");
|
||||
return llvm::None;
|
||||
}
|
||||
// As the set difference utility currently cannot handle symbols in its
|
||||
// operands, validity of the slice cannot be determined.
|
||||
if (srcConstraints.getNumSymbolIds() > 0) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n");
|
||||
return llvm::None;
|
||||
}
|
||||
// TODO: Handle local ids in the source domains while using the 'projectOut'
|
||||
// utility below. Currently, aligning is not done assuming that there will be
|
||||
// no local ids in the source domain.
|
||||
if (srcConstraints.getNumLocalIds() != 0) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n");
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
// Create constraints for the slice loop nest that would be created if the
|
||||
// fusion succeeds.
|
||||
FlatAffineConstraints sliceConstraints;
|
||||
if (failed(getAsConstraints(&sliceConstraints))) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n");
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
// Projecting out every dimension other than the 'ivs' to express slice's
|
||||
// domain completely in terms of source's IVs.
|
||||
sliceConstraints.projectOut(ivs.size(),
|
||||
sliceConstraints.getNumIds() - ivs.size());
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n");
|
||||
LLVM_DEBUG(srcConstraints.dump());
|
||||
LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds "
|
||||
"(expressed in terms of its source's IVs):\n");
|
||||
LLVM_DEBUG(sliceConstraints.dump());
|
||||
|
||||
// TODO: Store 'srcSet' to avoid recalculating for each depth.
|
||||
PresburgerSet srcSet(srcConstraints);
|
||||
PresburgerSet sliceSet(sliceConstraints);
|
||||
PresburgerSet diffSet = sliceSet.subtract(srcSet);
|
||||
|
||||
if (!diffSet.isIntegerEmpty()) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Returns true if the computation slice encloses all the iterations of the
|
||||
/// sliced loop nest. Returns false if it does not. Returns llvm::None if it
|
||||
/// cannot determine if the slice is maximal or not.
|
||||
|
@ -715,14 +801,14 @@ unsigned mlir::getInnermostCommonLoopDepth(
|
|||
}
|
||||
|
||||
/// Computes in 'sliceUnion' the union of all slice bounds computed at
|
||||
/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'.
|
||||
/// Returns 'Success' if union was computed, 'failure' otherwise.
|
||||
LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
|
||||
ArrayRef<Operation *> opsB,
|
||||
unsigned loopDepth,
|
||||
unsigned numCommonLoops,
|
||||
bool isBackwardSlice,
|
||||
ComputationSliceState *sliceUnion) {
|
||||
/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
|
||||
/// then verifies if it is valid. Returns 'SliceComputationResult::Success' if
|
||||
/// union was computed correctly, an appropriate failure otherwise.
|
||||
SliceComputationResult
|
||||
mlir::computeSliceUnion(ArrayRef<Operation *> opsA, ArrayRef<Operation *> opsB,
|
||||
unsigned loopDepth, unsigned numCommonLoops,
|
||||
bool isBackwardSlice,
|
||||
ComputationSliceState *sliceUnion) {
|
||||
// Compute the union of slice bounds between all pairs in 'opsA' and
|
||||
// 'opsB' in 'sliceUnionCst'.
|
||||
FlatAffineConstraints sliceUnionCst;
|
||||
|
@ -738,7 +824,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
|
|||
if ((!isBackwardSlice && loopDepth > getNestingDepth(opsA[i])) ||
|
||||
(isBackwardSlice && loopDepth > getNestingDepth(opsB[j]))) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
|
||||
return failure();
|
||||
return SliceComputationResult::GenericFailure;
|
||||
}
|
||||
|
||||
bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
|
||||
|
@ -751,7 +837,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
|
|||
/*allowRAR=*/readReadAccesses);
|
||||
if (result.value == DependenceResult::Failure) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n");
|
||||
return failure();
|
||||
return SliceComputationResult::GenericFailure;
|
||||
}
|
||||
if (result.value == DependenceResult::NoDependence)
|
||||
continue;
|
||||
|
@ -768,7 +854,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
|
|||
if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Unable to compute slice bound constraints\n");
|
||||
return failure();
|
||||
return SliceComputationResult::GenericFailure;
|
||||
}
|
||||
assert(sliceUnionCst.getNumDimAndSymbolIds() > 0);
|
||||
continue;
|
||||
|
@ -779,7 +865,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
|
|||
if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Unable to compute slice bound constraints\n");
|
||||
return failure();
|
||||
return SliceComputationResult::GenericFailure;
|
||||
}
|
||||
|
||||
// Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
|
||||
|
@ -802,9 +888,9 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
|
|||
// to unionBoundingBox below expects constraints for each Loop IV, even
|
||||
// if they are the unsliced full loop bounds added here.
|
||||
if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
|
||||
return failure();
|
||||
return SliceComputationResult::GenericFailure;
|
||||
if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
|
||||
return failure();
|
||||
return SliceComputationResult::GenericFailure;
|
||||
}
|
||||
// Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
|
||||
if (sliceUnionCst.getNumLocalIds() > 0 ||
|
||||
|
@ -812,14 +898,14 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
|
|||
failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Unable to compute union bounding box of slice bounds\n");
|
||||
return failure();
|
||||
return SliceComputationResult::GenericFailure;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Empty union.
|
||||
if (sliceUnionCst.getNumDimAndSymbolIds() == 0)
|
||||
return failure();
|
||||
return SliceComputationResult::GenericFailure;
|
||||
|
||||
// Gather loops surrounding ops from loop nest where slice will be inserted.
|
||||
SmallVector<Operation *, 4> ops;
|
||||
|
@ -831,7 +917,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
|
|||
getInnermostCommonLoopDepth(ops, &surroundingLoops);
|
||||
if (loopDepth > innermostCommonLoopDepth) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
|
||||
return failure();
|
||||
return SliceComputationResult::GenericFailure;
|
||||
}
|
||||
|
||||
// Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
|
||||
|
@ -868,7 +954,18 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
|
|||
// canonicalization.
|
||||
sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
|
||||
sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
|
||||
return success();
|
||||
|
||||
// Check if the slice computed is valid. Return success only if it is verified
|
||||
// that the slice is valid, otherwise return appropriate failure status.
|
||||
Optional<bool> isSliceValid = sliceUnion->isSliceValid();
|
||||
if (!isSliceValid.hasValue()) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n");
|
||||
return SliceComputationResult::GenericFailure;
|
||||
}
|
||||
if (!isSliceValid.getValue())
|
||||
return SliceComputationResult::IncorrectSliceFailure;
|
||||
|
||||
return SliceComputationResult::Success;
|
||||
}
|
||||
|
||||
const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
|
||||
|
|
|
@ -347,12 +347,18 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
|
|||
|
||||
// Compute union of computation slices computed between all pairs of ops
|
||||
// from 'forOpA' and 'forOpB'.
|
||||
if (failed(mlir::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth,
|
||||
numCommonLoops, isSrcForOpBeforeDstForOp,
|
||||
srcSlice))) {
|
||||
SliceComputationResult sliceComputationResult =
|
||||
mlir::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth, numCommonLoops,
|
||||
isSrcForOpBeforeDstForOp, srcSlice);
|
||||
if (sliceComputationResult.value == SliceComputationResult::GenericFailure) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
|
||||
return FusionResult::FailPrecondition;
|
||||
}
|
||||
if (sliceComputationResult.value ==
|
||||
SliceComputationResult::IncorrectSliceFailure) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n");
|
||||
return FusionResult::FailIncorrectSlice;
|
||||
}
|
||||
|
||||
return FusionResult::Success;
|
||||
}
|
||||
|
@ -400,7 +406,7 @@ bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
|
|||
auto *parentForOp = forOp->getParentOp();
|
||||
if (!llvm::isa<FuncOp>(parentForOp)) {
|
||||
if (!isa<AffineForOp>(parentForOp)) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp");
|
||||
LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
// Add mapping to 'forOp' from its parent AffineForOp.
|
||||
|
@ -421,7 +427,7 @@ bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
|
|||
Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
|
||||
if (!maybeConstTripCount.hasValue()) {
|
||||
// Currently only constant trip count loop nests are supported.
|
||||
LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported");
|
||||
LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n");
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
|
||||
|
@ -519,7 +525,11 @@ static bool buildSliceTripCountMap(
|
|||
auto *op = forOp.getOperation();
|
||||
AffineMap lbMap = slice.lbs[i];
|
||||
AffineMap ubMap = slice.ubs[i];
|
||||
if (lbMap == AffineMap() || ubMap == AffineMap()) {
|
||||
// If lower or upper bound maps are null or provide no results, it implies
|
||||
// that source loop was not at all sliced, and the entire loop will be a
|
||||
// part of the slice.
|
||||
if (!lbMap || lbMap.getNumResults() == 0 || !ubMap ||
|
||||
ubMap.getNumResults() == 0) {
|
||||
// The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
|
||||
if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
|
||||
(*tripCountMap)[op] =
|
||||
|
|
|
@ -7,7 +7,7 @@ func @slice_depth1_loop_nest() {
|
|||
%0 = memref.alloc() : memref<100xf32>
|
||||
%cst = constant 7.000000e+00 : f32
|
||||
affine.for %i0 = 0 to 16 {
|
||||
// expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] )}}
|
||||
// expected-remark@-1 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] )}}
|
||||
affine.store %cst, %0[%i0] : memref<100xf32>
|
||||
}
|
||||
affine.for %i1 = 0 to 5 {
|
||||
|
@ -19,6 +19,23 @@ func @slice_depth1_loop_nest() {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @forward_slice_slice_depth1_loop_nest() {
|
||||
func @forward_slice_slice_depth1_loop_nest() {
|
||||
%0 = memref.alloc() : memref<100xf32>
|
||||
%cst = constant 7.000000e+00 : f32
|
||||
affine.for %i0 = 0 to 5 {
|
||||
// expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] )}}
|
||||
affine.store %cst, %0[%i0] : memref<100xf32>
|
||||
}
|
||||
affine.for %i1 = 0 to 16 {
|
||||
// expected-remark@-1 {{Incorrect slice ( src loop: 0, dst loop: 1, depth: 1 : insert point: (1, 0) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] )}}
|
||||
%1 = affine.load %0[%i1] : memref<100xf32>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Loop %i0 writes to locations [2, 17] and loop %i0 reads from locations [3, 6]
|
||||
// Slice loop bounds should be adjusted such that the load/store are for the
|
||||
// same location.
|
||||
|
@ -27,7 +44,7 @@ func @slice_depth1_loop_nest_with_offsets() {
|
|||
%0 = memref.alloc() : memref<100xf32>
|
||||
%cst = constant 7.000000e+00 : f32
|
||||
affine.for %i0 = 0 to 16 {
|
||||
// expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 2) loop bounds: [(d0) -> (d0 + 3), (d0) -> (d0 + 4)] )}}
|
||||
// expected-remark@-1 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 2) loop bounds: [(d0) -> (d0 + 3), (d0) -> (d0 + 4)] )}}
|
||||
%a0 = affine.apply affine_map<(d0) -> (d0 + 2)>(%i0)
|
||||
affine.store %cst, %0[%a0] : memref<100xf32>
|
||||
}
|
||||
|
@ -48,8 +65,8 @@ func @slice_depth2_loop_nest() {
|
|||
%0 = memref.alloc() : memref<100x100xf32>
|
||||
%cst = constant 7.000000e+00 : f32
|
||||
affine.for %i0 = 0 to 16 {
|
||||
// expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}}
|
||||
// expected-remark@-2 {{slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}}
|
||||
// expected-remark@-1 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}}
|
||||
// expected-remark@-2 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}}
|
||||
affine.for %i1 = 0 to 16 {
|
||||
affine.store %cst, %0[%i0, %i1] : memref<100x100xf32>
|
||||
}
|
||||
|
@ -75,8 +92,8 @@ func @slice_depth2_loop_nest_two_loads() {
|
|||
%c0 = constant 0 : index
|
||||
%cst = constant 7.000000e+00 : f32
|
||||
affine.for %i0 = 0 to 16 {
|
||||
// expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}}
|
||||
// expected-remark@-2 {{slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (0), (d0, d1) -> (8)] )}}
|
||||
// expected-remark@-1 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}}
|
||||
// expected-remark@-2 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (0), (d0, d1) -> (16)] )}}
|
||||
affine.for %i1 = 0 to 16 {
|
||||
affine.store %cst, %0[%i0, %i1] : memref<100x100xf32>
|
||||
}
|
||||
|
@ -103,7 +120,7 @@ func @slice_depth2_loop_nest_two_stores() {
|
|||
%c0 = constant 0 : index
|
||||
%cst = constant 7.000000e+00 : f32
|
||||
affine.for %i0 = 0 to 16 {
|
||||
// expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 2) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}}
|
||||
// expected-remark@-1 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 2) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}}
|
||||
affine.for %i1 = 0 to 16 {
|
||||
affine.store %cst, %0[%i0, %i1] : memref<100x100xf32>
|
||||
}
|
||||
|
@ -128,8 +145,8 @@ func @slice_loop_nest_with_smaller_outer_trip_count() {
|
|||
%c0 = constant 0 : index
|
||||
%cst = constant 7.000000e+00 : f32
|
||||
affine.for %i0 = 0 to 16 {
|
||||
// expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (10)] )}}
|
||||
// expected-remark@-2 {{slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}}
|
||||
// expected-remark@-1 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (10)] )}}
|
||||
// expected-remark@-2 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}}
|
||||
affine.for %i1 = 0 to 16 {
|
||||
affine.store %cst, %0[%i0, %i1] : memref<100x100xf32>
|
||||
}
|
||||
|
|
|
@ -3068,3 +3068,50 @@ func @call_op_does_not_prevent_fusion(%arg0: memref<16xf32>){
|
|||
// CHECK-LABEL: func @call_op_does_not_prevent_fusion
|
||||
// CHECK: affine.for
|
||||
// CHECK-NOT: affine.for
|
||||
|
||||
// -----
|
||||
|
||||
// Fusion is avoided when the slice computed is invalid. Comments below describe
|
||||
// incorrect backward slice computation. Similar logic applies for forward slice
|
||||
// as well.
|
||||
func @no_fusion_cannot_compute_valid_slice() {
|
||||
%A = memref.alloc() : memref<5xf32>
|
||||
%B = memref.alloc() : memref<6xf32>
|
||||
%C = memref.alloc() : memref<5xf32>
|
||||
%cst = constant 0. : f32
|
||||
|
||||
affine.for %arg0 = 0 to 5 {
|
||||
%a = affine.load %A[%arg0] : memref<5xf32>
|
||||
affine.store %a, %B[%arg0 + 1] : memref<6xf32>
|
||||
}
|
||||
|
||||
affine.for %arg0 = 0 to 5 {
|
||||
// Backward slice computed will be:
|
||||
// slice ( src loop: 0, dst loop: 1, depth: 1 : insert point: (1, 0)
|
||||
// loop bounds: [(d0) -> (d0 - 1), (d0) -> (d0)] )
|
||||
|
||||
// Resulting fusion would be as below. It is easy to note the out-of-bounds
|
||||
// access by 'affine.load'.
|
||||
|
||||
// #map0 = affine_map<(d0) -> (d0 - 1)>
|
||||
// #map1 = affine_map<(d0) -> (d0)>
|
||||
// affine.for %arg1 = #map0(%arg0) to #map1(%arg0) {
|
||||
// %5 = affine.load %1[%arg1] : memref<5xf32>
|
||||
// ...
|
||||
// ...
|
||||
// }
|
||||
|
||||
%a = affine.load %B[%arg0] : memref<6xf32>
|
||||
%b = mulf %a, %cst : f32
|
||||
affine.store %b, %C[%arg0] : memref<5xf32>
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @no_fusion_cannot_compute_valid_slice
|
||||
// CHECK: affine.for
|
||||
// CHECK-NEXT: affine.load
|
||||
// CHECK-NEXT: affine.store
|
||||
// CHECK: affine.for
|
||||
// CHECK-NEXT: affine.load
|
||||
// CHECK-NEXT: mulf
|
||||
// CHECK-NEXT: affine.store
|
||||
|
|
|
@ -99,10 +99,11 @@ static std::string getSliceStr(const mlir::ComputationSliceState &sliceUnion) {
|
|||
return os.str();
|
||||
}
|
||||
|
||||
// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
|
||||
// in range ['loopDepth' + 1, 'maxLoopDepth'].
|
||||
// Emits a string representation of the slice union as a remark on 'loops[j]'.
|
||||
// Returns false as IR is not transformed.
|
||||
/// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
|
||||
/// in range ['loopDepth' + 1, 'maxLoopDepth'].
|
||||
/// Emits a string representation of the slice union as a remark on 'loops[j]'
|
||||
/// and marks this as incorrect slice if the slice is invalid. Returns false as
|
||||
/// IR is not transformed.
|
||||
static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB,
|
||||
unsigned i, unsigned j, unsigned loopDepth,
|
||||
unsigned maxLoopDepth) {
|
||||
|
@ -113,6 +114,10 @@ static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB,
|
|||
forOpB->emitRemark("slice (")
|
||||
<< " src loop: " << i << ", dst loop: " << j << ", depth: " << d
|
||||
<< " : " << getSliceStr(sliceUnion) << ")";
|
||||
} else if (result.value == FusionResult::FailIncorrectSlice) {
|
||||
forOpB->emitRemark("Incorrect slice (")
|
||||
<< " src loop: " << i << ", dst loop: " << j << ", depth: " << d
|
||||
<< " : " << getSliceStr(sliceUnion) << ")";
|
||||
}
|
||||
}
|
||||
return false;
|
||||
|
|
Loading…
Reference in New Issue