forked from OSchip/llvm-project
[mlir][interfaces] Add helpers for detecting recursive regions
Add helper functions to check if an op may be executed multiple times based on RegionBranchOpInterface. Differential Revision: https://reviews.llvm.org/D123789
This commit is contained in:
parent
c5cac48549
commit
0f4ba02db3
|
@ -216,6 +216,16 @@ private:
|
|||
/// RegionBranchOpInterface.
|
||||
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b);
|
||||
|
||||
/// Return the first enclosing region of the given op that may be executed
|
||||
/// repetitively as per RegionBranchOpInterface or `nullptr` if no such region
|
||||
/// exists.
|
||||
Region *getEnclosingRepetitiveRegion(Operation *op);
|
||||
|
||||
/// Return the first enclosing region of the given Value that may be executed
|
||||
/// repetitively as per RegionBranchOpInterface or `nullptr` if no such region
|
||||
/// exists.
|
||||
Region *getEnclosingRepetitiveRegion(Value value);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RegionBranchTerminatorOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -211,6 +211,11 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
|
|||
SmallVector<Attribute, 2> nullAttrs(getOperation()->getNumOperands());
|
||||
getSuccessorRegions(index, nullAttrs, regions);
|
||||
}
|
||||
|
||||
/// Return `true` if control flow originating from the given region may
|
||||
/// eventually branch back to the same region. (Maybe after passing through
|
||||
/// other regions.)
|
||||
bool isRepetitiveRegion(unsigned index);
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -309,6 +309,57 @@ bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
|
||||
SmallVector<bool> visited(getOperation()->getNumRegions(), false);
|
||||
visited[index] = true;
|
||||
|
||||
// Retrieve all successors of the region and enqueue them in the worklist.
|
||||
SmallVector<unsigned> worklist;
|
||||
auto enqueueAllSuccessors = [&](unsigned index) {
|
||||
SmallVector<RegionSuccessor> successors;
|
||||
this->getSuccessorRegions(index, successors);
|
||||
for (RegionSuccessor successor : successors)
|
||||
if (!successor.isParent())
|
||||
worklist.push_back(successor.getSuccessor()->getRegionNumber());
|
||||
};
|
||||
enqueueAllSuccessors(index);
|
||||
|
||||
// Process all regions in the worklist via DFS.
|
||||
while (!worklist.empty()) {
|
||||
unsigned nextRegion = worklist.pop_back_val();
|
||||
if (nextRegion == index)
|
||||
return true;
|
||||
if (visited[nextRegion])
|
||||
continue;
|
||||
visited[nextRegion] = true;
|
||||
enqueueAllSuccessors(nextRegion);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
|
||||
while (Region *region = op->getParentRegion()) {
|
||||
op = region->getParentOp();
|
||||
if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
|
||||
if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
|
||||
return region;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Region *mlir::getEnclosingRepetitiveRegion(Value value) {
|
||||
Region *region = value.getParentRegion();
|
||||
while (region) {
|
||||
Operation *op = region->getParentOp();
|
||||
if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
|
||||
if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
|
||||
return region;
|
||||
region = op->getParentRegion();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RegionBranchTerminatorOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -42,6 +42,29 @@ struct MutuallyExclusiveRegionsOp
|
|||
SmallVectorImpl<RegionSuccessor> ®ions) {}
|
||||
};
|
||||
|
||||
/// All regions of this op call each other in a large circle.
|
||||
struct LoopRegionsOp
|
||||
: public Op<LoopRegionsOp, RegionBranchOpInterface::Trait> {
|
||||
using Op::Op;
|
||||
static const unsigned kNumRegions = 3;
|
||||
|
||||
static ArrayRef<StringRef> getAttributeNames() { return {}; }
|
||||
|
||||
static StringRef getOperationName() { return "cftest.loop_regions_op"; }
|
||||
|
||||
void getSuccessorRegions(Optional<unsigned> index,
|
||||
ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
if (index) {
|
||||
if (*index == 1)
|
||||
// This region also branches back to the parent.
|
||||
regions.push_back(RegionSuccessor());
|
||||
regions.push_back(
|
||||
RegionSuccessor(&getOperation()->getRegion(*index % kNumRegions)));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Regions are executed sequentially.
|
||||
struct SequentialRegionsOp
|
||||
: public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> {
|
||||
|
@ -65,7 +88,8 @@ struct SequentialRegionsOp
|
|||
struct CFTestDialect : Dialect {
|
||||
explicit CFTestDialect(MLIRContext *ctx)
|
||||
: Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) {
|
||||
addOperations<DummyOp, MutuallyExclusiveRegionsOp, SequentialRegionsOp>();
|
||||
addOperations<DummyOp, MutuallyExclusiveRegionsOp, LoopRegionsOp,
|
||||
SequentialRegionsOp>();
|
||||
}
|
||||
static StringRef getDialectNamespace() { return "cftest"; }
|
||||
};
|
||||
|
@ -142,3 +166,52 @@ TEST(RegionBranchOpInterface, NestedMutuallyExclusiveOps) {
|
|||
EXPECT_TRUE(insideMutuallyExclusiveRegions(op3, op2));
|
||||
EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op3));
|
||||
}
|
||||
|
||||
TEST(RegionBranchOpInterface, RecursiveRegions) {
|
||||
const char *ir = R"MLIR(
|
||||
"cftest.loop_regions_op"() (
|
||||
{"cftest.dummy_op"() : () -> ()}, // op1
|
||||
{"cftest.dummy_op"() : () -> ()}, // op2
|
||||
{"cftest.dummy_op"() : () -> ()} // op3
|
||||
) : () -> ()
|
||||
)MLIR";
|
||||
|
||||
DialectRegistry registry;
|
||||
registry.insert<CFTestDialect>();
|
||||
MLIRContext ctx(registry);
|
||||
|
||||
OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
|
||||
Operation *testOp = &module->getBody()->getOperations().front();
|
||||
auto regionOp = cast<RegionBranchOpInterface>(testOp);
|
||||
Operation *op1 = &testOp->getRegion(0).front().front();
|
||||
Operation *op2 = &testOp->getRegion(1).front().front();
|
||||
Operation *op3 = &testOp->getRegion(2).front().front();
|
||||
|
||||
EXPECT_TRUE(regionOp.isRepetitiveRegion(0));
|
||||
EXPECT_TRUE(regionOp.isRepetitiveRegion(1));
|
||||
EXPECT_TRUE(regionOp.isRepetitiveRegion(2));
|
||||
EXPECT_NE(getEnclosingRepetitiveRegion(op1), nullptr);
|
||||
EXPECT_NE(getEnclosingRepetitiveRegion(op2), nullptr);
|
||||
EXPECT_NE(getEnclosingRepetitiveRegion(op3), nullptr);
|
||||
}
|
||||
|
||||
TEST(RegionBranchOpInterface, NotRecursiveRegions) {
|
||||
const char *ir = R"MLIR(
|
||||
"cftest.sequential_regions_op"() (
|
||||
{"cftest.dummy_op"() : () -> ()}, // op1
|
||||
{"cftest.dummy_op"() : () -> ()} // op2
|
||||
) : () -> ()
|
||||
)MLIR";
|
||||
|
||||
DialectRegistry registry;
|
||||
registry.insert<CFTestDialect>();
|
||||
MLIRContext ctx(registry);
|
||||
|
||||
OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
|
||||
Operation *testOp = &module->getBody()->getOperations().front();
|
||||
Operation *op1 = &testOp->getRegion(0).front().front();
|
||||
Operation *op2 = &testOp->getRegion(1).front().front();
|
||||
|
||||
EXPECT_EQ(getEnclosingRepetitiveRegion(op1), nullptr);
|
||||
EXPECT_EQ(getEnclosingRepetitiveRegion(op2), nullptr);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue