From fd26ca4e7515e7dd32ae02e777bd21693afc68ff Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 19 Oct 2021 09:08:38 +0900 Subject: [PATCH] [mlir][scf] Add insideMutuallyExclusiveBranches helper This helper function checks if two given ops are in mutually exclusive branches of the same scf::IfOp. Differential Revision: https://reviews.llvm.org/D111957 --- mlir/include/mlir/Dialect/SCF/SCF.h | 5 +++++ mlir/lib/Dialect/SCF/SCF.cpp | 20 ++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/mlir/include/mlir/Dialect/SCF/SCF.h b/mlir/include/mlir/Dialect/SCF/SCF.h index ba1caa115160..49f5be00c978 100644 --- a/mlir/include/mlir/Dialect/SCF/SCF.h +++ b/mlir/include/mlir/Dialect/SCF/SCF.h @@ -50,6 +50,11 @@ ForOp getForInductionVarOwner(Value val); /// value is not an induction variable, then return nullptr. ParallelOp getParallelForInductionVarOwner(Value val); +/// Return true if ops a and b (or their ancestors) are in mutually exclusive +/// regions/blocks of an IfOp. +// TODO: Consider moving this functionality to RegionBranchOpInterface. +bool insideMutuallyExclusiveBranches(Operation *a, Operation *b); + /// An owning vector of values, handy to return from functions. using ValueVector = std::vector; using LoopVector = std::vector; diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp index 05da1f56cf10..cc68ff622cdc 100644 --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -1010,6 +1010,26 @@ void ForOp::getCanonicalizationPatterns(RewritePatternSet &results, // IfOp //===----------------------------------------------------------------------===// +bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) { + assert(a && "expected non-empty operation"); + assert(b && "expected non-empty operation"); + + IfOp ifOp = a->getParentOfType(); + while (ifOp) { + // Check if b is inside ifOp. (We already know that a is.) + if (ifOp->isProperAncestor(b)) + // b is contained in ifOp. a and b are in mutually exclusive branches if + // they are in different blocks of ifOp. + return static_cast(ifOp.thenBlock()->findAncestorOpInBlock(*a)) != + static_cast(ifOp.thenBlock()->findAncestorOpInBlock(*b)); + // Check next enclosing IfOp. + ifOp = ifOp->getParentOfType(); + } + + // Could not find a common IfOp among a's and b's ancestors. + return false; +} + void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, bool withElseRegion) { build(builder, result, /*resultTypes=*/llvm::None, cond, withElseRegion);