[mlir][interfaces] Fix infinite loop in insideMutuallyExclusiveRegions

This function was missing a termination condition.
This commit is contained in:
Matthias Springer 2022-04-19 16:21:08 +09:00
parent 4e01184ad5
commit a3005a406e
2 changed files with 82 additions and 54 deletions

View File

@ -237,6 +237,40 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
return success();
}
/// Return `true` if region `r` is reachable from region `begin` according to
/// the RegionBranchOpInterface (by taking a branch).
static bool isRegionReachable(Region *begin, Region *r) {
assert(begin->getParentOp() == r->getParentOp() &&
"expected that both regions belong to the same op");
auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
SmallVector<bool> visited(op->getNumRegions(), false);
visited[begin->getRegionNumber()] = true;
// Retrieve all successors of the region and enqueue them in the worklist.
SmallVector<unsigned> worklist;
auto enqueueAllSuccessors = [&](unsigned index) {
SmallVector<RegionSuccessor> successors;
op.getSuccessorRegions(index, successors);
for (RegionSuccessor successor : successors)
if (!successor.isParent())
worklist.push_back(successor.getSuccessor()->getRegionNumber());
};
enqueueAllSuccessors(begin->getRegionNumber());
// Process all regions in the worklist via DFS.
while (!worklist.empty()) {
unsigned nextRegion = worklist.pop_back_val();
if (nextRegion == r->getRegionNumber())
return true;
if (visited[nextRegion])
continue;
visited[nextRegion] = true;
enqueueAllSuccessors(nextRegion);
}
return false;
}
/// Return `true` if `a` and `b` are in mutually exclusive regions.
///
/// 1. Find the first common of `a` and `b` (ancestor) that implements
@ -274,33 +308,9 @@ bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
}
assert(regionA && regionB && "could not find region of op");
// Helper function that checks if region `r` is reachable from region
// `begin`.
std::function<bool(Region *, Region *)> isRegionReachable =
[&](Region *begin, Region *r) {
if (begin == r)
return true;
if (begin == nullptr)
return false;
// Compute index of region.
int64_t beginIndex = -1;
for (const auto &it : llvm::enumerate(branchOp->getRegions()))
if (&it.value() == begin)
beginIndex = it.index();
assert(beginIndex != -1 && "could not find region in op");
// Retrieve all successors of the region.
SmallVector<RegionSuccessor> successors;
branchOp.getSuccessorRegions(beginIndex, successors);
// Call function recursively on all successors.
for (RegionSuccessor successor : successors)
if (isRegionReachable(successor.getSuccessor(), r))
return true;
return false;
};
// `a` and `b` are in mutually exclusive regions if neither region is
// reachable from the other region.
return !isRegionReachable(regionA, regionB) &&
// `a` and `b` are in mutually exclusive regions if both regions are
// distinct and neither region is reachable from the other region.
return regionA != regionB && !isRegionReachable(regionA, regionB) &&
!isRegionReachable(regionB, regionA);
}
@ -310,32 +320,8 @@ bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
}
bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
SmallVector<bool> visited(getOperation()->getNumRegions(), false);
visited[index] = true;
// Retrieve all successors of the region and enqueue them in the worklist.
SmallVector<unsigned> worklist;
auto enqueueAllSuccessors = [&](unsigned index) {
SmallVector<RegionSuccessor> successors;
this->getSuccessorRegions(index, successors);
for (RegionSuccessor successor : successors)
if (!successor.isParent())
worklist.push_back(successor.getSuccessor()->getRegionNumber());
};
enqueueAllSuccessors(index);
// Process all regions in the worklist via DFS.
while (!worklist.empty()) {
unsigned nextRegion = worklist.pop_back_val();
if (nextRegion == index)
return true;
if (visited[nextRegion])
continue;
visited[nextRegion] = true;
enqueueAllSuccessors(nextRegion);
}
return false;
Region *region = &getOperation()->getRegion(index);
return isRegionReachable(region, region);
}
Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {

View File

@ -65,6 +65,27 @@ struct LoopRegionsOp
}
};
/// Each region branches back it itself or the parent.
struct DoubleLoopRegionsOp
: public Op<DoubleLoopRegionsOp, RegionBranchOpInterface::Trait> {
using Op::Op;
static ArrayRef<StringRef> getAttributeNames() { return {}; }
static StringRef getOperationName() {
return "cftest.double_loop_regions_op";
}
void getSuccessorRegions(Optional<unsigned> index,
ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {
if (index.hasValue()) {
regions.push_back(RegionSuccessor());
regions.push_back(RegionSuccessor(&getOperation()->getRegion(*index)));
}
}
};
/// Regions are executed sequentially.
struct SequentialRegionsOp
: public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> {
@ -89,7 +110,7 @@ struct CFTestDialect : Dialect {
explicit CFTestDialect(MLIRContext *ctx)
: Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) {
addOperations<DummyOp, MutuallyExclusiveRegionsOp, LoopRegionsOp,
SequentialRegionsOp>();
DoubleLoopRegionsOp, SequentialRegionsOp>();
}
static StringRef getDialectNamespace() { return "cftest"; }
};
@ -115,6 +136,27 @@ TEST(RegionBranchOpInterface, MutuallyExclusiveOps) {
EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
}
TEST(RegionBranchOpInterface, MutuallyExclusiveOps2) {
const char *ir = R"MLIR(
"cftest.double_loop_regions_op"() (
{"cftest.dummy_op"() : () -> ()}, // op1
{"cftest.dummy_op"() : () -> ()} // op2
) : () -> ()
)MLIR";
DialectRegistry registry;
registry.insert<CFTestDialect>();
MLIRContext ctx(registry);
OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
Operation *testOp = &module->getBody()->getOperations().front();
Operation *op1 = &testOp->getRegion(0).front().front();
Operation *op2 = &testOp->getRegion(1).front().front();
EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
}
TEST(RegionBranchOpInterface, NotMutuallyExclusiveOps) {
const char *ir = R"MLIR(
"cftest.sequential_regions_op"() (