forked from OSchip/llvm-project
[MLIR] Introduce generic visitors.
- Generic visitors invoke operation callbacks before/in-between/after visiting the regions attached to an operation and use a `WalkStage` to indicate which regions have been visited. - This can be useful for cases where we need to visit the operation in between visiting regions attached to the operation. Differential Revision: https://reviews.llvm.org/D116230
This commit is contained in:
parent
67076ebb60
commit
8067ced144
|
@ -510,10 +510,40 @@ public:
|
|||
/// });
|
||||
template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
|
||||
typename RetT = detail::walkResultType<FnT>>
|
||||
RetT walk(FnT &&callback) {
|
||||
typename std::enable_if<
|
||||
llvm::function_traits<std::decay_t<FnT>>::num_args == 1, RetT>::type
|
||||
walk(FnT &&callback) {
|
||||
return detail::walk<Order>(this, std::forward<FnT>(callback));
|
||||
}
|
||||
|
||||
/// Generic walker with a stage aware callback. Walk the operation by calling
|
||||
/// the callback for each nested operation (including this one) N+1 times,
|
||||
/// where N is the number of regions attached to that operation.
|
||||
///
|
||||
/// The callback method can take any of the following forms:
|
||||
/// void(Operation *, const WalkStage &) : Walk all operation opaquely
|
||||
/// * op->walk([](Operation *nestedOp, const WalkStage &stage) { ...});
|
||||
/// void(OpT, const WalkStage &) : Walk all operations of the given derived
|
||||
/// type.
|
||||
/// * op->walk([](ReturnOp returnOp, const WalkStage &stage) { ...});
|
||||
/// WalkResult(Operation*|OpT, const WalkStage &stage) : Walk operations,
|
||||
/// but allow for interruption/skipping.
|
||||
/// * op->walk([](... op, const WalkStage &stage) {
|
||||
/// // Skip the walk of this op based on some invariant.
|
||||
/// if (some_invariant)
|
||||
/// return WalkResult::skip();
|
||||
/// // Interrupt, i.e cancel, the walk based on some invariant.
|
||||
/// if (another_invariant)
|
||||
/// return WalkResult::interrupt();
|
||||
/// return WalkResult::advance();
|
||||
/// });
|
||||
template <typename FnT, typename RetT = detail::walkResultType<FnT>>
|
||||
typename std::enable_if<
|
||||
llvm::function_traits<std::decay_t<FnT>>::num_args == 2, RetT>::type
|
||||
walk(FnT &&callback) {
|
||||
return detail::walk(this, std::forward<FnT>(callback));
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Uses
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
|
|
@ -61,13 +61,49 @@ public:
|
|||
/// Traversal order for region, block and operation walk utilities.
|
||||
enum class WalkOrder { PreOrder, PostOrder };
|
||||
|
||||
/// A utility class to encode the current walk stage for "generic" walkers.
|
||||
/// When walking an operation, we can either choose a Pre/Post order walker
|
||||
/// which invokes the callback on an operation before/after all its attached
|
||||
/// regions have been visited, or choose a "generic" walker where the callback
|
||||
/// is invoked on the operation N+1 times where N is the number of regions
|
||||
/// attached to that operation. The `WalkStage` class below encodes the current
|
||||
/// stage of the walk, i.e., which regions have already been visited, and the
|
||||
/// callback accepts an additional argument for the current stage. Such
|
||||
/// generic walkers that accept stage-aware callbacks are only applicable when
|
||||
/// the callback operates on an operation (i.e., not applicable for callbacks
|
||||
/// on Blocks or Regions).
|
||||
class WalkStage {
|
||||
public:
|
||||
explicit WalkStage(Operation *op);
|
||||
|
||||
/// Return true if parent operation is being visited before all regions.
|
||||
bool isBeforeAllRegions() const { return nextRegion == 0; }
|
||||
/// Returns true if parent operation is being visited just before visiting
|
||||
/// region number `region`.
|
||||
bool isBeforeRegion(int region) const { return nextRegion == region; }
|
||||
/// Returns true if parent operation is being visited just after visiting
|
||||
/// region number `region`.
|
||||
bool isAfterRegion(int region) const { return nextRegion == region + 1; }
|
||||
/// Return true if parent operation is being visited after all regions.
|
||||
bool isAfterAllRegions() const { return nextRegion == numRegions; }
|
||||
/// Advance the walk stage.
|
||||
void advance() { nextRegion++; }
|
||||
/// Returns the next region that will be visited.
|
||||
int getNextRegion() const { return nextRegion; }
|
||||
|
||||
private:
|
||||
const int numRegions;
|
||||
int nextRegion;
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
/// Helper templates to deduce the first argument of a callback parameter.
|
||||
template <typename Ret, typename Arg> Arg first_argument_type(Ret (*)(Arg));
|
||||
template <typename Ret, typename F, typename Arg>
|
||||
Arg first_argument_type(Ret (F::*)(Arg));
|
||||
template <typename Ret, typename F, typename Arg>
|
||||
Arg first_argument_type(Ret (F::*)(Arg) const);
|
||||
template <typename Ret, typename Arg, typename... Rest>
|
||||
Arg first_argument_type(Ret (*)(Arg, Rest...));
|
||||
template <typename Ret, typename F, typename Arg, typename... Rest>
|
||||
Arg first_argument_type(Ret (F::*)(Arg, Rest...));
|
||||
template <typename Ret, typename F, typename Arg, typename... Rest>
|
||||
Arg first_argument_type(Ret (F::*)(Arg, Rest...) const);
|
||||
template <typename F>
|
||||
decltype(first_argument_type(&F::operator())) first_argument_type(F);
|
||||
|
||||
|
@ -197,6 +233,87 @@ walk(Operation *op, FuncTy &&callback) {
|
|||
return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn), Order);
|
||||
}
|
||||
|
||||
/// Generic walkers with stage aware callbacks.
|
||||
|
||||
/// Walk all the operations nested under (and including) the given operation,
|
||||
/// with the callback being invoked on each operation N+1 times, where N is the
|
||||
/// number of regions attached to the operation. The `stage` input to the
|
||||
/// callback indicates the current walk stage. This method is invoked for void
|
||||
/// returning callbacks.
|
||||
void walk(Operation *op,
|
||||
function_ref<void(Operation *, const WalkStage &stage)> callback);
|
||||
|
||||
/// Walk all the operations nested under (and including) the given operation,
|
||||
/// with the callback being invoked on each operation N+1 times, where N is the
|
||||
/// number of regions attached to the operation. The `stage` input to the
|
||||
/// callback indicates the current walk stage. This method is invoked for
|
||||
/// skippable or interruptible callbacks.
|
||||
WalkResult
|
||||
walk(Operation *op,
|
||||
function_ref<WalkResult(Operation *, const WalkStage &stage)> callback);
|
||||
|
||||
/// Walk all of the operations nested under and including the given operation.
|
||||
/// This method is selected for stage-aware callbacks that operate on
|
||||
/// Operation*.
|
||||
///
|
||||
/// Example:
|
||||
/// op->walk([](Operation *op, const WalkStage &stage) { ... });
|
||||
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
|
||||
typename RetT = decltype(std::declval<FuncTy>()(
|
||||
std::declval<ArgT>(), std::declval<const WalkStage &>()))>
|
||||
typename std::enable_if<std::is_same<ArgT, Operation *>::value, RetT>::type
|
||||
walk(Operation *op, FuncTy &&callback) {
|
||||
return detail::walk(op,
|
||||
function_ref<RetT(ArgT, const WalkStage &)>(callback));
|
||||
}
|
||||
|
||||
/// Walk all of the operations of type 'ArgT' nested under and including the
|
||||
/// given operation. This method is selected for void returning callbacks that
|
||||
/// operate on a specific derived operation type.
|
||||
///
|
||||
/// Example:
|
||||
/// op->walk([](ReturnOp op, const WalkStage &stage) { ... });
|
||||
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
|
||||
typename RetT = decltype(std::declval<FuncTy>()(
|
||||
std::declval<ArgT>(), std::declval<const WalkStage &>()))>
|
||||
typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
|
||||
std::is_same<RetT, void>::value,
|
||||
RetT>::type
|
||||
walk(Operation *op, FuncTy &&callback) {
|
||||
auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
|
||||
if (auto derivedOp = dyn_cast<ArgT>(op))
|
||||
callback(derivedOp, stage);
|
||||
};
|
||||
return detail::walk(
|
||||
op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
|
||||
}
|
||||
|
||||
/// Walk all of the operations of type 'ArgT' nested under and including the
|
||||
/// given operation. This method is selected for WalkReturn returning
|
||||
/// interruptible callbacks that operate on a specific derived operation type.
|
||||
///
|
||||
/// Example:
|
||||
/// op->walk(op, [](ReturnOp op, const WalkStage &stage) {
|
||||
/// if (some_invariant)
|
||||
/// return WalkResult::interrupt();
|
||||
/// return WalkResult::advance();
|
||||
/// });
|
||||
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
|
||||
typename RetT = decltype(std::declval<FuncTy>()(
|
||||
std::declval<ArgT>(), std::declval<const WalkStage &>()))>
|
||||
typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
|
||||
std::is_same<RetT, WalkResult>::value,
|
||||
RetT>::type
|
||||
walk(Operation *op, FuncTy &&callback) {
|
||||
auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
|
||||
if (auto derivedOp = dyn_cast<ArgT>(op))
|
||||
return callback(derivedOp, stage);
|
||||
return WalkResult::advance();
|
||||
};
|
||||
return detail::walk(
|
||||
op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
|
||||
}
|
||||
|
||||
/// Utility to provide the return type of a templated walk method.
|
||||
template <typename FnT>
|
||||
using walkResultType = decltype(walk(nullptr, std::declval<FnT>()));
|
||||
|
|
|
@ -11,6 +11,9 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
WalkStage::WalkStage(Operation *op)
|
||||
: numRegions(op->getNumRegions()), nextRegion(0) {}
|
||||
|
||||
/// Walk all of the regions/blocks/operations nested under and including the
|
||||
/// given operation. Regions, blocks and operations at the same nesting level
|
||||
/// are visited in lexicographical order. The walk order for enclosing regions,
|
||||
|
@ -67,6 +70,25 @@ void detail::walk(Operation *op, function_ref<void(Operation *)> callback,
|
|||
callback(op);
|
||||
}
|
||||
|
||||
void detail::walk(Operation *op,
|
||||
function_ref<void(Operation *, const WalkStage &)> callback) {
|
||||
WalkStage stage(op);
|
||||
|
||||
for (Region ®ion : op->getRegions()) {
|
||||
// Invoke callback on the parent op before visiting each child region.
|
||||
callback(op, stage);
|
||||
stage.advance();
|
||||
|
||||
for (Block &block : region) {
|
||||
for (Operation &nestedOp : block)
|
||||
walk(&nestedOp, callback);
|
||||
}
|
||||
}
|
||||
|
||||
// Invoke callback after all regions have been visited.
|
||||
callback(op, stage);
|
||||
}
|
||||
|
||||
/// Walk all of the regions/blocks/operations nested under and including the
|
||||
/// given operation. These functions walk operations until an interrupt result
|
||||
/// is returned by the callback. Walks on regions, blocks and operations may
|
||||
|
@ -157,3 +179,29 @@ WalkResult detail::walk(Operation *op,
|
|||
return callback(op);
|
||||
return WalkResult::advance();
|
||||
}
|
||||
|
||||
WalkResult detail::walk(
|
||||
Operation *op,
|
||||
function_ref<WalkResult(Operation *, const WalkStage &)> callback) {
|
||||
WalkStage stage(op);
|
||||
|
||||
for (Region ®ion : op->getRegions()) {
|
||||
// Invoke callback on the parent op before visiting each child region.
|
||||
WalkResult result = callback(op, stage);
|
||||
|
||||
if (result.wasSkipped())
|
||||
return WalkResult::advance();
|
||||
if (result.wasInterrupted())
|
||||
return WalkResult::interrupt();
|
||||
|
||||
stage.advance();
|
||||
|
||||
for (Block &block : region) {
|
||||
// Early increment here in the case where the operation is erased.
|
||||
for (Operation &nestedOp : llvm::make_early_inc_range(block))
|
||||
if (walk(&nestedOp, callback).wasInterrupted())
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
return callback(op, stage);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,157 @@
|
|||
// RUN: mlir-opt -test-generic-ir-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
|
||||
|
||||
// Walk is interrupted before visiting "foo"
|
||||
func @main(%arg0: f32) -> f32 {
|
||||
%v1 = "foo"() {interrupt_before_all = true} : () -> f32
|
||||
%v2 = arith.addf %v1, %arg0 : f32
|
||||
return %v2 : f32
|
||||
}
|
||||
|
||||
// CHECK: step 0 op 'builtin.module' before all regions
|
||||
// CHECK: step 1 op 'builtin.func' before all regions
|
||||
// CHECK: step 2 walk was interrupted
|
||||
|
||||
// -----
|
||||
|
||||
// Walk is interrupted after visiting "foo" (which has a single empty region)
|
||||
func @main(%arg0: f32) -> f32 {
|
||||
%v1 = "foo"() ({ "bar"() : ()-> () }) {interrupt_after_all = true} : () -> f32
|
||||
%v2 = arith.addf %v1, %arg0 : f32
|
||||
return %v2 : f32
|
||||
}
|
||||
|
||||
// CHECK: step 0 op 'builtin.module' before all regions
|
||||
// CHECK: step 1 op 'builtin.func' before all regions
|
||||
// CHECK: step 2 op 'foo' before all regions
|
||||
// CHECK: step 3 op 'bar' before all regions
|
||||
// CHECK: step 4 walk was interrupted
|
||||
|
||||
// -----
|
||||
|
||||
// Walk is interrupted after visiting "foo"'s 1st region.
|
||||
func @main(%arg0: f32) -> f32 {
|
||||
%v1 = "foo"() ({
|
||||
"bar0"() : () -> ()
|
||||
}, {
|
||||
"bar1"() : () -> ()
|
||||
}) {interrupt_after_region = 0} : () -> f32
|
||||
%v2 = arith.addf %v1, %arg0 : f32
|
||||
return %v2 : f32
|
||||
}
|
||||
|
||||
// CHECK: step 0 op 'builtin.module' before all regions
|
||||
// CHECK: step 1 op 'builtin.func' before all regions
|
||||
// CHECK: step 2 op 'foo' before all regions
|
||||
// CHECK: step 3 op 'bar0' before all regions
|
||||
// CHECK: step 4 walk was interrupted
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// Test static filtering.
|
||||
func @main() {
|
||||
"foo"() : () -> ()
|
||||
"test.two_region_op"()(
|
||||
{"work"() : () -> ()},
|
||||
{"work"() : () -> ()}
|
||||
) {interrupt_after_all = true} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: step 0 op 'builtin.module' before all regions
|
||||
// CHECK: step 1 op 'builtin.func' before all regions
|
||||
// CHECK: step 2 op 'foo' before all regions
|
||||
// CHECK: step 3 op 'test.two_region_op' before all regions
|
||||
// CHECK: step 4 op 'work' before all regions
|
||||
// CHECK: step 5 op 'test.two_region_op' before region #1
|
||||
// CHECK: step 6 op 'work' before all regions
|
||||
// CHECK: step 7 walk was interrupted
|
||||
// CHECK: step 8 op 'test.two_region_op' before all regions
|
||||
// CHECK: step 9 op 'test.two_region_op' before region #1
|
||||
// CHECK: step 10 walk was interrupted
|
||||
|
||||
// -----
|
||||
|
||||
// Test static filtering.
|
||||
func @main() {
|
||||
"foo"() : () -> ()
|
||||
"test.two_region_op"()(
|
||||
{"work"() : () -> ()},
|
||||
{"work"() : () -> ()}
|
||||
) {interrupt_after_region = 0} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: step 0 op 'builtin.module' before all regions
|
||||
// CHECK: step 1 op 'builtin.func' before all regions
|
||||
// CHECK: step 2 op 'foo' before all regions
|
||||
// CHECK: step 3 op 'test.two_region_op' before all regions
|
||||
// CHECK: step 4 op 'work' before all regions
|
||||
// CHECK: step 5 walk was interrupted
|
||||
// CHECK: step 6 op 'test.two_region_op' before all regions
|
||||
// CHECK: step 7 walk was interrupted
|
||||
|
||||
// -----
|
||||
// Test skipping.
|
||||
|
||||
// Walk is skipped before visiting "foo".
|
||||
func @main(%arg0: f32) -> f32 {
|
||||
%v1 = "foo"() ({
|
||||
"bar0"() : () -> ()
|
||||
}, {
|
||||
"bar1"() : () -> ()
|
||||
}) {skip_before_all = true} : () -> f32
|
||||
%v2 = arith.addf %v1, %arg0 : f32
|
||||
return %v2 : f32
|
||||
}
|
||||
|
||||
// CHECK: step 0 op 'builtin.module' before all regions
|
||||
// CHECK: step 1 op 'builtin.func' before all regions
|
||||
// CHECK: step 2 op 'arith.addf' before all regions
|
||||
// CHECK: step 3 op 'std.return' before all regions
|
||||
// CHECK: step 4 op 'builtin.func' after all regions
|
||||
// CHECK: step 5 op 'builtin.module' after all regions
|
||||
|
||||
// -----
|
||||
// Walk is skipped after visiting all regions of "foo".
|
||||
func @main(%arg0: f32) -> f32 {
|
||||
%v1 = "foo"() ({
|
||||
"bar0"() : () -> ()
|
||||
}, {
|
||||
"bar1"() : () -> ()
|
||||
}) {skip_after_all = true} : () -> f32
|
||||
%v2 = arith.addf %v1, %arg0 : f32
|
||||
return %v2 : f32
|
||||
}
|
||||
|
||||
// CHECK: step 0 op 'builtin.module' before all regions
|
||||
// CHECK: step 1 op 'builtin.func' before all regions
|
||||
// CHECK: step 2 op 'foo' before all regions
|
||||
// CHECK: step 3 op 'bar0' before all regions
|
||||
// CHECK: step 4 op 'foo' before region #1
|
||||
// CHECK: step 5 op 'bar1' before all regions
|
||||
// CHECK: step 6 op 'arith.addf' before all regions
|
||||
// CHECK: step 7 op 'std.return' before all regions
|
||||
// CHECK: step 8 op 'builtin.func' after all regions
|
||||
// CHECK: step 9 op 'builtin.module' after all regions
|
||||
|
||||
// -----
|
||||
// Walk is skipped after visiting first region of "foo".
|
||||
func @main(%arg0: f32) -> f32 {
|
||||
%v1 = "foo"() ({
|
||||
"bar0"() : () -> ()
|
||||
}, {
|
||||
"bar1"() : () -> ()
|
||||
}) {skip_after_region = 0} : () -> f32
|
||||
%v2 = arith.addf %v1, %arg0 : f32
|
||||
return %v2 : f32
|
||||
}
|
||||
|
||||
// CHECK: step 0 op 'builtin.module' before all regions
|
||||
// CHECK: step 1 op 'builtin.func' before all regions
|
||||
// CHECK: step 2 op 'foo' before all regions
|
||||
// CHECK: step 3 op 'bar0' before all regions
|
||||
// CHECK: step 4 op 'arith.addf' before all regions
|
||||
// CHECK: step 5 op 'std.return' before all regions
|
||||
// CHECK: step 6 op 'builtin.func' after all regions
|
||||
// CHECK: step 7 op 'builtin.module' after all regions
|
|
@ -0,0 +1,63 @@
|
|||
// RUN: mlir-opt -test-generic-ir-visitors -allow-unregistered-dialect -split-input-file %s | FileCheck %s
|
||||
// RUN: mlir-opt -test-generic-ir-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
|
||||
|
||||
// Verify the different configurations of generic IR visitors.
|
||||
|
||||
func @structured_cfg() {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c10 = arith.constant 10 : index
|
||||
scf.for %i = %c1 to %c10 step %c1 {
|
||||
%cond = "use0"(%i) : (index) -> (i1)
|
||||
scf.if %cond {
|
||||
"use1"(%i) : (index) -> ()
|
||||
} else {
|
||||
"use2"(%i) : (index) -> ()
|
||||
}
|
||||
"use3"(%i) : (index) -> ()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: step 0 op 'builtin.module' before all regions
|
||||
// CHECK: step 1 op 'builtin.func' before all regions
|
||||
// CHECK: step 2 op 'arith.constant' before all regions
|
||||
// CHECK: step 3 op 'arith.constant' before all regions
|
||||
// CHECK: step 4 op 'arith.constant' before all regions
|
||||
// CHECK: step 5 op 'scf.for' before all regions
|
||||
// CHECK: step 6 op 'use0' before all regions
|
||||
// CHECK: step 7 op 'scf.if' before all regions
|
||||
// CHECK: step 8 op 'use1' before all regions
|
||||
// CHECK: step 9 op 'scf.yield' before all regions
|
||||
// CHECK: step 10 op 'scf.if' before region #1
|
||||
// CHECK: step 11 op 'use2' before all regions
|
||||
// CHECK: step 12 op 'scf.yield' before all regions
|
||||
// CHECK: step 13 op 'scf.if' after all regions
|
||||
// CHECK: step 14 op 'use3' before all regions
|
||||
// CHECK: step 15 op 'scf.yield' before all regions
|
||||
// CHECK: step 16 op 'scf.for' after all regions
|
||||
// CHECK: step 17 op 'std.return' before all regions
|
||||
// CHECK: step 18 op 'builtin.func' after all regions
|
||||
// CHECK: step 19 op 'builtin.module' after all regions
|
||||
|
||||
// -----
|
||||
// Test the specific operation type visitor.
|
||||
|
||||
func @correct_number_of_regions() {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c10 = arith.constant 10 : index
|
||||
scf.for %i = %c1 to %c10 step %c1 {
|
||||
"test.two_region_op"()(
|
||||
{"work"() : () -> ()},
|
||||
{"work"() : () -> ()}
|
||||
) : () -> ()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: step 0 op 'builtin.module' before all regions
|
||||
// CHECK: step 15 op 'builtin.module' after all regions
|
||||
// CHECK: step 16 op 'test.two_region_op' before all regions
|
||||
// CHECK: step 17 op 'test.two_region_op' before region #1
|
||||
// CHECK: step 18 op 'test.two_region_op' after all regions
|
|
@ -15,6 +15,7 @@ add_mlir_library(MLIRTestIR
|
|||
TestSymbolUses.cpp
|
||||
TestTypes.cpp
|
||||
TestVisitors.cpp
|
||||
TestVisitorsGeneric.cpp
|
||||
|
||||
EXCLUDE_FROM_LIBMLIR
|
||||
|
||||
|
|
|
@ -0,0 +1,123 @@
|
|||
//===- TestIRVisitorsGeneric.cpp - Pass to test the Generic IR visitors ---===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "TestDialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
static std::string getStageDescription(const WalkStage &stage) {
|
||||
if (stage.isBeforeAllRegions())
|
||||
return "before all regions";
|
||||
if (stage.isAfterAllRegions())
|
||||
return "after all regions";
|
||||
return "before region #" + std::to_string(stage.getNextRegion());
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// This pass exercises generic visitor with void callbacks and prints the order
|
||||
/// and stage in which operations are visited.
|
||||
class TestGenericIRVisitorPass
|
||||
: public PassWrapper<TestGenericIRVisitorPass, OperationPass<>> {
|
||||
public:
|
||||
StringRef getArgument() const final { return "test-generic-ir-visitors"; }
|
||||
StringRef getDescription() const final { return "Test generic IR visitors."; }
|
||||
void runOnOperation() override {
|
||||
Operation *outerOp = getOperation();
|
||||
int stepNo = 0;
|
||||
outerOp->walk([&](Operation *op, const WalkStage &stage) {
|
||||
llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
|
||||
<< getStageDescription(stage) << "\n";
|
||||
});
|
||||
|
||||
// Exercise static inference of operation type.
|
||||
outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
|
||||
llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
|
||||
<< getStageDescription(stage) << "\n";
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/// This pass exercises the generic visitor with non-void callbacks and prints
|
||||
/// the order and stage in which operations are visited. It will interrupt the
|
||||
/// walk based on attributes peesent in the IR.
|
||||
class TestGenericIRVisitorInterruptPass
|
||||
: public PassWrapper<TestGenericIRVisitorInterruptPass, OperationPass<>> {
|
||||
public:
|
||||
StringRef getArgument() const final {
|
||||
return "test-generic-ir-visitors-interrupt";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test generic IR visitors with interrupts.";
|
||||
}
|
||||
void runOnOperation() override {
|
||||
Operation *outerOp = getOperation();
|
||||
int stepNo = 0;
|
||||
|
||||
auto walker = [&](Operation *op, const WalkStage &stage) {
|
||||
if (auto interruptBeforeAall =
|
||||
op->getAttrOfType<BoolAttr>("interrupt_before_all"))
|
||||
if (interruptBeforeAall.getValue() && stage.isBeforeAllRegions())
|
||||
return WalkResult::interrupt();
|
||||
|
||||
if (auto interruptAfterAll =
|
||||
op->getAttrOfType<BoolAttr>("interrupt_after_all"))
|
||||
if (interruptAfterAll.getValue() && stage.isAfterAllRegions())
|
||||
return WalkResult::interrupt();
|
||||
|
||||
if (auto interruptAfterRegion =
|
||||
op->getAttrOfType<IntegerAttr>("interrupt_after_region"))
|
||||
if (stage.isAfterRegion(
|
||||
static_cast<int>(interruptAfterRegion.getInt())))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
if (auto skipBeforeAall = op->getAttrOfType<BoolAttr>("skip_before_all"))
|
||||
if (skipBeforeAall.getValue() && stage.isBeforeAllRegions())
|
||||
return WalkResult::skip();
|
||||
|
||||
if (auto skipAfterAll = op->getAttrOfType<BoolAttr>("skip_after_all"))
|
||||
if (skipAfterAll.getValue() && stage.isAfterAllRegions())
|
||||
return WalkResult::skip();
|
||||
|
||||
if (auto skipAfterRegion =
|
||||
op->getAttrOfType<IntegerAttr>("skip_after_region"))
|
||||
if (stage.isAfterRegion(static_cast<int>(skipAfterRegion.getInt())))
|
||||
return WalkResult::skip();
|
||||
|
||||
llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
|
||||
<< getStageDescription(stage) << "\n";
|
||||
return WalkResult::advance();
|
||||
};
|
||||
|
||||
// Interrupt the walk based on attributes on the operation.
|
||||
auto result = outerOp->walk(walker);
|
||||
|
||||
if (result.wasInterrupted())
|
||||
llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
|
||||
|
||||
// Exercise static inference of operation type.
|
||||
result = outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
|
||||
return walker(op, stage);
|
||||
});
|
||||
|
||||
if (result.wasInterrupted())
|
||||
llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestGenericIRVisitorsPass() {
|
||||
PassRegistration<TestGenericIRVisitorPass>();
|
||||
PassRegistration<TestGenericIRVisitorInterruptPass>();
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace mlir
|
|
@ -78,6 +78,8 @@ void registerTestExpandTanhPass();
|
|||
void registerTestComposeSubView();
|
||||
void registerTestGpuParallelLoopMappingPass();
|
||||
void registerTestIRVisitorsPass();
|
||||
void registerTestGenericIRVisitorsPass();
|
||||
void registerTestGenericIRVisitorsInterruptPass();
|
||||
void registerTestInterfaces();
|
||||
void registerTestLinalgCodegenStrategy();
|
||||
void registerTestLinalgControlFuseByExpansion();
|
||||
|
@ -171,6 +173,7 @@ void registerTestPasses() {
|
|||
mlir::test::registerTestComposeSubView();
|
||||
mlir::test::registerTestGpuParallelLoopMappingPass();
|
||||
mlir::test::registerTestIRVisitorsPass();
|
||||
mlir::test::registerTestGenericIRVisitorsPass();
|
||||
mlir::test::registerTestInterfaces();
|
||||
mlir::test::registerTestLinalgCodegenStrategy();
|
||||
mlir::test::registerTestLinalgControlFuseByExpansion();
|
||||
|
|
Loading…
Reference in New Issue