[MLIR] Introduce generic visitors.

- Generic visitors invoke operation callbacks before/in-between/after visiting the regions
  attached to an operation and use a `WalkStage` to indicate which regions have been
  visited.
- This can be useful for cases where we need to visit the operation in between visiting
  regions attached to the operation.

Differential Revision: https://reviews.llvm.org/D116230
This commit is contained in:
Rahul Joshi 2022-01-13 13:32:14 -08:00
parent 67076ebb60
commit 8067ced144
8 changed files with 548 additions and 6 deletions

View File

@ -510,10 +510,40 @@ public:
/// });
template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
typename RetT = detail::walkResultType<FnT>>
RetT walk(FnT &&callback) {
typename std::enable_if<
llvm::function_traits<std::decay_t<FnT>>::num_args == 1, RetT>::type
walk(FnT &&callback) {
return detail::walk<Order>(this, std::forward<FnT>(callback));
}
/// Generic walker with a stage aware callback. Walk the operation by calling
/// the callback for each nested operation (including this one) N+1 times,
/// where N is the number of regions attached to that operation.
///
/// The callback method can take any of the following forms:
/// void(Operation *, const WalkStage &) : Walk all operation opaquely
/// * op->walk([](Operation *nestedOp, const WalkStage &stage) { ...});
/// void(OpT, const WalkStage &) : Walk all operations of the given derived
/// type.
/// * op->walk([](ReturnOp returnOp, const WalkStage &stage) { ...});
/// WalkResult(Operation*|OpT, const WalkStage &stage) : Walk operations,
/// but allow for interruption/skipping.
/// * op->walk([](... op, const WalkStage &stage) {
/// // Skip the walk of this op based on some invariant.
/// if (some_invariant)
/// return WalkResult::skip();
/// // Interrupt, i.e cancel, the walk based on some invariant.
/// if (another_invariant)
/// return WalkResult::interrupt();
/// return WalkResult::advance();
/// });
template <typename FnT, typename RetT = detail::walkResultType<FnT>>
typename std::enable_if<
llvm::function_traits<std::decay_t<FnT>>::num_args == 2, RetT>::type
walk(FnT &&callback) {
return detail::walk(this, std::forward<FnT>(callback));
}
//===--------------------------------------------------------------------===//
// Uses
//===--------------------------------------------------------------------===//

View File

@ -61,13 +61,49 @@ public:
/// Traversal order for region, block and operation walk utilities.
enum class WalkOrder { PreOrder, PostOrder };
/// A utility class to encode the current walk stage for "generic" walkers.
/// When walking an operation, we can either choose a Pre/Post order walker
/// which invokes the callback on an operation before/after all its attached
/// regions have been visited, or choose a "generic" walker where the callback
/// is invoked on the operation N+1 times where N is the number of regions
/// attached to that operation. The `WalkStage` class below encodes the current
/// stage of the walk, i.e., which regions have already been visited, and the
/// callback accepts an additional argument for the current stage. Such
/// generic walkers that accept stage-aware callbacks are only applicable when
/// the callback operates on an operation (i.e., not applicable for callbacks
/// on Blocks or Regions).
class WalkStage {
public:
explicit WalkStage(Operation *op);
/// Return true if parent operation is being visited before all regions.
bool isBeforeAllRegions() const { return nextRegion == 0; }
/// Returns true if parent operation is being visited just before visiting
/// region number `region`.
bool isBeforeRegion(int region) const { return nextRegion == region; }
/// Returns true if parent operation is being visited just after visiting
/// region number `region`.
bool isAfterRegion(int region) const { return nextRegion == region + 1; }
/// Return true if parent operation is being visited after all regions.
bool isAfterAllRegions() const { return nextRegion == numRegions; }
/// Advance the walk stage.
void advance() { nextRegion++; }
/// Returns the next region that will be visited.
int getNextRegion() const { return nextRegion; }
private:
const int numRegions;
int nextRegion;
};
namespace detail {
/// Helper templates to deduce the first argument of a callback parameter.
template <typename Ret, typename Arg> Arg first_argument_type(Ret (*)(Arg));
template <typename Ret, typename F, typename Arg>
Arg first_argument_type(Ret (F::*)(Arg));
template <typename Ret, typename F, typename Arg>
Arg first_argument_type(Ret (F::*)(Arg) const);
template <typename Ret, typename Arg, typename... Rest>
Arg first_argument_type(Ret (*)(Arg, Rest...));
template <typename Ret, typename F, typename Arg, typename... Rest>
Arg first_argument_type(Ret (F::*)(Arg, Rest...));
template <typename Ret, typename F, typename Arg, typename... Rest>
Arg first_argument_type(Ret (F::*)(Arg, Rest...) const);
template <typename F>
decltype(first_argument_type(&F::operator())) first_argument_type(F);
@ -197,6 +233,87 @@ walk(Operation *op, FuncTy &&callback) {
return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn), Order);
}
/// Generic walkers with stage aware callbacks.
/// Walk all the operations nested under (and including) the given operation,
/// with the callback being invoked on each operation N+1 times, where N is the
/// number of regions attached to the operation. The `stage` input to the
/// callback indicates the current walk stage. This method is invoked for void
/// returning callbacks.
void walk(Operation *op,
function_ref<void(Operation *, const WalkStage &stage)> callback);
/// Walk all the operations nested under (and including) the given operation,
/// with the callback being invoked on each operation N+1 times, where N is the
/// number of regions attached to the operation. The `stage` input to the
/// callback indicates the current walk stage. This method is invoked for
/// skippable or interruptible callbacks.
WalkResult
walk(Operation *op,
function_ref<WalkResult(Operation *, const WalkStage &stage)> callback);
/// Walk all of the operations nested under and including the given operation.
/// This method is selected for stage-aware callbacks that operate on
/// Operation*.
///
/// Example:
/// op->walk([](Operation *op, const WalkStage &stage) { ... });
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
typename RetT = decltype(std::declval<FuncTy>()(
std::declval<ArgT>(), std::declval<const WalkStage &>()))>
typename std::enable_if<std::is_same<ArgT, Operation *>::value, RetT>::type
walk(Operation *op, FuncTy &&callback) {
return detail::walk(op,
function_ref<RetT(ArgT, const WalkStage &)>(callback));
}
/// Walk all of the operations of type 'ArgT' nested under and including the
/// given operation. This method is selected for void returning callbacks that
/// operate on a specific derived operation type.
///
/// Example:
/// op->walk([](ReturnOp op, const WalkStage &stage) { ... });
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
typename RetT = decltype(std::declval<FuncTy>()(
std::declval<ArgT>(), std::declval<const WalkStage &>()))>
typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
std::is_same<RetT, void>::value,
RetT>::type
walk(Operation *op, FuncTy &&callback) {
auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
if (auto derivedOp = dyn_cast<ArgT>(op))
callback(derivedOp, stage);
};
return detail::walk(
op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
}
/// Walk all of the operations of type 'ArgT' nested under and including the
/// given operation. This method is selected for WalkReturn returning
/// interruptible callbacks that operate on a specific derived operation type.
///
/// Example:
/// op->walk(op, [](ReturnOp op, const WalkStage &stage) {
/// if (some_invariant)
/// return WalkResult::interrupt();
/// return WalkResult::advance();
/// });
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
typename RetT = decltype(std::declval<FuncTy>()(
std::declval<ArgT>(), std::declval<const WalkStage &>()))>
typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
std::is_same<RetT, WalkResult>::value,
RetT>::type
walk(Operation *op, FuncTy &&callback) {
auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
if (auto derivedOp = dyn_cast<ArgT>(op))
return callback(derivedOp, stage);
return WalkResult::advance();
};
return detail::walk(
op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
}
/// Utility to provide the return type of a templated walk method.
template <typename FnT>
using walkResultType = decltype(walk(nullptr, std::declval<FnT>()));

View File

@ -11,6 +11,9 @@
using namespace mlir;
WalkStage::WalkStage(Operation *op)
: numRegions(op->getNumRegions()), nextRegion(0) {}
/// Walk all of the regions/blocks/operations nested under and including the
/// given operation. Regions, blocks and operations at the same nesting level
/// are visited in lexicographical order. The walk order for enclosing regions,
@ -67,6 +70,25 @@ void detail::walk(Operation *op, function_ref<void(Operation *)> callback,
callback(op);
}
void detail::walk(Operation *op,
function_ref<void(Operation *, const WalkStage &)> callback) {
WalkStage stage(op);
for (Region &region : op->getRegions()) {
// Invoke callback on the parent op before visiting each child region.
callback(op, stage);
stage.advance();
for (Block &block : region) {
for (Operation &nestedOp : block)
walk(&nestedOp, callback);
}
}
// Invoke callback after all regions have been visited.
callback(op, stage);
}
/// Walk all of the regions/blocks/operations nested under and including the
/// given operation. These functions walk operations until an interrupt result
/// is returned by the callback. Walks on regions, blocks and operations may
@ -157,3 +179,29 @@ WalkResult detail::walk(Operation *op,
return callback(op);
return WalkResult::advance();
}
WalkResult detail::walk(
Operation *op,
function_ref<WalkResult(Operation *, const WalkStage &)> callback) {
WalkStage stage(op);
for (Region &region : op->getRegions()) {
// Invoke callback on the parent op before visiting each child region.
WalkResult result = callback(op, stage);
if (result.wasSkipped())
return WalkResult::advance();
if (result.wasInterrupted())
return WalkResult::interrupt();
stage.advance();
for (Block &block : region) {
// Early increment here in the case where the operation is erased.
for (Operation &nestedOp : llvm::make_early_inc_range(block))
if (walk(&nestedOp, callback).wasInterrupted())
return WalkResult::interrupt();
}
}
return callback(op, stage);
}

View File

@ -0,0 +1,157 @@
// RUN: mlir-opt -test-generic-ir-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
// Walk is interrupted before visiting "foo"
func @main(%arg0: f32) -> f32 {
%v1 = "foo"() {interrupt_before_all = true} : () -> f32
%v2 = arith.addf %v1, %arg0 : f32
return %v2 : f32
}
// CHECK: step 0 op 'builtin.module' before all regions
// CHECK: step 1 op 'builtin.func' before all regions
// CHECK: step 2 walk was interrupted
// -----
// Walk is interrupted after visiting "foo" (which has a single empty region)
func @main(%arg0: f32) -> f32 {
%v1 = "foo"() ({ "bar"() : ()-> () }) {interrupt_after_all = true} : () -> f32
%v2 = arith.addf %v1, %arg0 : f32
return %v2 : f32
}
// CHECK: step 0 op 'builtin.module' before all regions
// CHECK: step 1 op 'builtin.func' before all regions
// CHECK: step 2 op 'foo' before all regions
// CHECK: step 3 op 'bar' before all regions
// CHECK: step 4 walk was interrupted
// -----
// Walk is interrupted after visiting "foo"'s 1st region.
func @main(%arg0: f32) -> f32 {
%v1 = "foo"() ({
"bar0"() : () -> ()
}, {
"bar1"() : () -> ()
}) {interrupt_after_region = 0} : () -> f32
%v2 = arith.addf %v1, %arg0 : f32
return %v2 : f32
}
// CHECK: step 0 op 'builtin.module' before all regions
// CHECK: step 1 op 'builtin.func' before all regions
// CHECK: step 2 op 'foo' before all regions
// CHECK: step 3 op 'bar0' before all regions
// CHECK: step 4 walk was interrupted
// -----
// Test static filtering.
func @main() {
"foo"() : () -> ()
"test.two_region_op"()(
{"work"() : () -> ()},
{"work"() : () -> ()}
) {interrupt_after_all = true} : () -> ()
return
}
// CHECK: step 0 op 'builtin.module' before all regions
// CHECK: step 1 op 'builtin.func' before all regions
// CHECK: step 2 op 'foo' before all regions
// CHECK: step 3 op 'test.two_region_op' before all regions
// CHECK: step 4 op 'work' before all regions
// CHECK: step 5 op 'test.two_region_op' before region #1
// CHECK: step 6 op 'work' before all regions
// CHECK: step 7 walk was interrupted
// CHECK: step 8 op 'test.two_region_op' before all regions
// CHECK: step 9 op 'test.two_region_op' before region #1
// CHECK: step 10 walk was interrupted
// -----
// Test static filtering.
func @main() {
"foo"() : () -> ()
"test.two_region_op"()(
{"work"() : () -> ()},
{"work"() : () -> ()}
) {interrupt_after_region = 0} : () -> ()
return
}
// CHECK: step 0 op 'builtin.module' before all regions
// CHECK: step 1 op 'builtin.func' before all regions
// CHECK: step 2 op 'foo' before all regions
// CHECK: step 3 op 'test.two_region_op' before all regions
// CHECK: step 4 op 'work' before all regions
// CHECK: step 5 walk was interrupted
// CHECK: step 6 op 'test.two_region_op' before all regions
// CHECK: step 7 walk was interrupted
// -----
// Test skipping.
// Walk is skipped before visiting "foo".
func @main(%arg0: f32) -> f32 {
%v1 = "foo"() ({
"bar0"() : () -> ()
}, {
"bar1"() : () -> ()
}) {skip_before_all = true} : () -> f32
%v2 = arith.addf %v1, %arg0 : f32
return %v2 : f32
}
// CHECK: step 0 op 'builtin.module' before all regions
// CHECK: step 1 op 'builtin.func' before all regions
// CHECK: step 2 op 'arith.addf' before all regions
// CHECK: step 3 op 'std.return' before all regions
// CHECK: step 4 op 'builtin.func' after all regions
// CHECK: step 5 op 'builtin.module' after all regions
// -----
// Walk is skipped after visiting all regions of "foo".
func @main(%arg0: f32) -> f32 {
%v1 = "foo"() ({
"bar0"() : () -> ()
}, {
"bar1"() : () -> ()
}) {skip_after_all = true} : () -> f32
%v2 = arith.addf %v1, %arg0 : f32
return %v2 : f32
}
// CHECK: step 0 op 'builtin.module' before all regions
// CHECK: step 1 op 'builtin.func' before all regions
// CHECK: step 2 op 'foo' before all regions
// CHECK: step 3 op 'bar0' before all regions
// CHECK: step 4 op 'foo' before region #1
// CHECK: step 5 op 'bar1' before all regions
// CHECK: step 6 op 'arith.addf' before all regions
// CHECK: step 7 op 'std.return' before all regions
// CHECK: step 8 op 'builtin.func' after all regions
// CHECK: step 9 op 'builtin.module' after all regions
// -----
// Walk is skipped after visiting first region of "foo".
func @main(%arg0: f32) -> f32 {
%v1 = "foo"() ({
"bar0"() : () -> ()
}, {
"bar1"() : () -> ()
}) {skip_after_region = 0} : () -> f32
%v2 = arith.addf %v1, %arg0 : f32
return %v2 : f32
}
// CHECK: step 0 op 'builtin.module' before all regions
// CHECK: step 1 op 'builtin.func' before all regions
// CHECK: step 2 op 'foo' before all regions
// CHECK: step 3 op 'bar0' before all regions
// CHECK: step 4 op 'arith.addf' before all regions
// CHECK: step 5 op 'std.return' before all regions
// CHECK: step 6 op 'builtin.func' after all regions
// CHECK: step 7 op 'builtin.module' after all regions

View File

@ -0,0 +1,63 @@
// RUN: mlir-opt -test-generic-ir-visitors -allow-unregistered-dialect -split-input-file %s | FileCheck %s
// RUN: mlir-opt -test-generic-ir-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
// Verify the different configurations of generic IR visitors.
func @structured_cfg() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
scf.for %i = %c1 to %c10 step %c1 {
%cond = "use0"(%i) : (index) -> (i1)
scf.if %cond {
"use1"(%i) : (index) -> ()
} else {
"use2"(%i) : (index) -> ()
}
"use3"(%i) : (index) -> ()
}
return
}
// CHECK: step 0 op 'builtin.module' before all regions
// CHECK: step 1 op 'builtin.func' before all regions
// CHECK: step 2 op 'arith.constant' before all regions
// CHECK: step 3 op 'arith.constant' before all regions
// CHECK: step 4 op 'arith.constant' before all regions
// CHECK: step 5 op 'scf.for' before all regions
// CHECK: step 6 op 'use0' before all regions
// CHECK: step 7 op 'scf.if' before all regions
// CHECK: step 8 op 'use1' before all regions
// CHECK: step 9 op 'scf.yield' before all regions
// CHECK: step 10 op 'scf.if' before region #1
// CHECK: step 11 op 'use2' before all regions
// CHECK: step 12 op 'scf.yield' before all regions
// CHECK: step 13 op 'scf.if' after all regions
// CHECK: step 14 op 'use3' before all regions
// CHECK: step 15 op 'scf.yield' before all regions
// CHECK: step 16 op 'scf.for' after all regions
// CHECK: step 17 op 'std.return' before all regions
// CHECK: step 18 op 'builtin.func' after all regions
// CHECK: step 19 op 'builtin.module' after all regions
// -----
// Test the specific operation type visitor.
func @correct_number_of_regions() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
scf.for %i = %c1 to %c10 step %c1 {
"test.two_region_op"()(
{"work"() : () -> ()},
{"work"() : () -> ()}
) : () -> ()
}
return
}
// CHECK: step 0 op 'builtin.module' before all regions
// CHECK: step 15 op 'builtin.module' after all regions
// CHECK: step 16 op 'test.two_region_op' before all regions
// CHECK: step 17 op 'test.two_region_op' before region #1
// CHECK: step 18 op 'test.two_region_op' after all regions

View File

@ -15,6 +15,7 @@ add_mlir_library(MLIRTestIR
TestSymbolUses.cpp
TestTypes.cpp
TestVisitors.cpp
TestVisitorsGeneric.cpp
EXCLUDE_FROM_LIBMLIR

View File

@ -0,0 +1,123 @@
//===- TestIRVisitorsGeneric.cpp - Pass to test the Generic IR visitors ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
static std::string getStageDescription(const WalkStage &stage) {
if (stage.isBeforeAllRegions())
return "before all regions";
if (stage.isAfterAllRegions())
return "after all regions";
return "before region #" + std::to_string(stage.getNextRegion());
}
namespace {
/// This pass exercises generic visitor with void callbacks and prints the order
/// and stage in which operations are visited.
class TestGenericIRVisitorPass
: public PassWrapper<TestGenericIRVisitorPass, OperationPass<>> {
public:
StringRef getArgument() const final { return "test-generic-ir-visitors"; }
StringRef getDescription() const final { return "Test generic IR visitors."; }
void runOnOperation() override {
Operation *outerOp = getOperation();
int stepNo = 0;
outerOp->walk([&](Operation *op, const WalkStage &stage) {
llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
<< getStageDescription(stage) << "\n";
});
// Exercise static inference of operation type.
outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
<< getStageDescription(stage) << "\n";
});
}
};
/// This pass exercises the generic visitor with non-void callbacks and prints
/// the order and stage in which operations are visited. It will interrupt the
/// walk based on attributes peesent in the IR.
class TestGenericIRVisitorInterruptPass
: public PassWrapper<TestGenericIRVisitorInterruptPass, OperationPass<>> {
public:
StringRef getArgument() const final {
return "test-generic-ir-visitors-interrupt";
}
StringRef getDescription() const final {
return "Test generic IR visitors with interrupts.";
}
void runOnOperation() override {
Operation *outerOp = getOperation();
int stepNo = 0;
auto walker = [&](Operation *op, const WalkStage &stage) {
if (auto interruptBeforeAall =
op->getAttrOfType<BoolAttr>("interrupt_before_all"))
if (interruptBeforeAall.getValue() && stage.isBeforeAllRegions())
return WalkResult::interrupt();
if (auto interruptAfterAll =
op->getAttrOfType<BoolAttr>("interrupt_after_all"))
if (interruptAfterAll.getValue() && stage.isAfterAllRegions())
return WalkResult::interrupt();
if (auto interruptAfterRegion =
op->getAttrOfType<IntegerAttr>("interrupt_after_region"))
if (stage.isAfterRegion(
static_cast<int>(interruptAfterRegion.getInt())))
return WalkResult::interrupt();
if (auto skipBeforeAall = op->getAttrOfType<BoolAttr>("skip_before_all"))
if (skipBeforeAall.getValue() && stage.isBeforeAllRegions())
return WalkResult::skip();
if (auto skipAfterAll = op->getAttrOfType<BoolAttr>("skip_after_all"))
if (skipAfterAll.getValue() && stage.isAfterAllRegions())
return WalkResult::skip();
if (auto skipAfterRegion =
op->getAttrOfType<IntegerAttr>("skip_after_region"))
if (stage.isAfterRegion(static_cast<int>(skipAfterRegion.getInt())))
return WalkResult::skip();
llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
<< getStageDescription(stage) << "\n";
return WalkResult::advance();
};
// Interrupt the walk based on attributes on the operation.
auto result = outerOp->walk(walker);
if (result.wasInterrupted())
llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
// Exercise static inference of operation type.
result = outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
return walker(op, stage);
});
if (result.wasInterrupted())
llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
}
};
} // namespace
namespace mlir {
namespace test {
void registerTestGenericIRVisitorsPass() {
PassRegistration<TestGenericIRVisitorPass>();
PassRegistration<TestGenericIRVisitorInterruptPass>();
}
} // namespace test
} // namespace mlir

View File

@ -78,6 +78,8 @@ void registerTestExpandTanhPass();
void registerTestComposeSubView();
void registerTestGpuParallelLoopMappingPass();
void registerTestIRVisitorsPass();
void registerTestGenericIRVisitorsPass();
void registerTestGenericIRVisitorsInterruptPass();
void registerTestInterfaces();
void registerTestLinalgCodegenStrategy();
void registerTestLinalgControlFuseByExpansion();
@ -171,6 +173,7 @@ void registerTestPasses() {
mlir::test::registerTestComposeSubView();
mlir::test::registerTestGpuParallelLoopMappingPass();
mlir::test::registerTestIRVisitorsPass();
mlir::test::registerTestGenericIRVisitorsPass();
mlir::test::registerTestInterfaces();
mlir::test::registerTestLinalgCodegenStrategy();
mlir::test::registerTestLinalgControlFuseByExpansion();