[mlir][transforms] Add topological sort analysis
This change adds a helper function for computing a topological sorting of a list of ops. This can be useful, e.g., in transforms where a subset of ops should be cloned without introducing dominance errors. The analysis reuses the existing implementation in TopologicalSortUtils.cpp.

Differential Revision: https://reviews.llvm.org/D131669
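As a rough usage sketch (not part of this patch; the helper name, the "clone_me" marker attribute, and the surrounding transform are hypothetical), a caller could sort a selected subset before cloning it so the clones are created in an order that respects SSA dominance:

// Hypothetical helper: clone all ops in `block` that carry a "clone_me"
// attribute, in an order that respects SSA dominance.
static void cloneMarkedOps(Block *block, OpBuilder &builder) {
  SmallVector<Operation *> opsToClone;
  for (Operation &op : *block)
    if (op.hasAttr("clone_me"))
      opsToClone.push_back(&op);
  // Reorder the subset; returns false if it contains a cycle (possible in
  // graph regions), in which case the order is best-effort.
  (void)computeTopologicalSorting(block, opsToClone);
  BlockAndValueMapping mapping; // the mapping class name of this LLVM era
  for (Operation *op : opsToClone)
    builder.clone(*op, mapping);
}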
@@ -90,11 +90,23 @@ bool sortTopologically(
     function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
 
 /// Given a block, sort its operations in topological order, excluding its
-/// terminator if it has one.
+/// terminator if it has one. This sort is stable.
 bool sortTopologically(
     Block *block,
     function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
 
+/// Compute a topological ordering of the given ops. All ops must belong to the
+/// specified block.
+///
+/// This sort is not stable.
+///
+/// Note: If the specified ops contain incomplete/interrupted SSA use-def
+/// chains, the result may not actually be a topological sorting with respect to
+/// the entire program.
+bool computeTopologicalSorting(
+    Block *block, MutableArrayRef<Operation *> ops,
+    function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
+
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
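A hedged sketch of how the `isOperandReady` callback can be used (the wrapper name and the readiness rule below are assumptions, not part of the patch): a caller can declare certain operands as always ready, for example values defined outside the region being reordered, so they do not constrain the ordering:

// Hypothetical: sort `block`, treating operands defined outside `limit` as
// always ready, since reordering ops inside `limit` cannot break those defs.
static bool sortIgnoringOuterDefs(Block *block, Region *limit) {
  return sortTopologically(block, [&](Value operand, Operation *) {
    return !limit->isAncestor(operand.getParentRegion());
  });
}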
@@ -8,9 +8,46 @@
 #include "mlir/Transforms/TopologicalSortUtils.h"
 #include "mlir/IR/OpDefinition.h"
 #include "llvm/ADT/SetVector.h"
 
 using namespace mlir;
 
+/// Return `true` if the given operation is ready to be scheduled.
+static bool isOpReady(Block *block, Operation *op,
+                      DenseSet<Operation *> &unscheduledOps,
+                      function_ref<bool(Value, Operation *)> isOperandReady) {
+  // An operation is ready to be scheduled if all its operands are ready. An
+  // operation is ready if:
+  const auto isReady = [&](Value value, Operation *top) {
+    // - the user-provided callback marks it as ready,
+    if (isOperandReady && isOperandReady(value, op))
+      return true;
+    Operation *parent = value.getDefiningOp();
+    // - it is a block argument,
+    if (!parent)
+      return true;
+    Operation *ancestor = block->findAncestorOpInBlock(*parent);
+    // - it is an implicit capture,
+    if (!ancestor)
+      return true;
+    // - it is defined in a nested region, or
+    if (ancestor == op)
+      return true;
+    // - its ancestor in the block is scheduled.
+    return !unscheduledOps.contains(ancestor);
+  };
+
+  // An operation is recursively ready to be scheduled if it and its nested
+  // operations are ready.
+  WalkResult readyToSchedule = op->walk([&](Operation *nestedOp) {
+    return llvm::all_of(nestedOp->getOperands(),
+                        [&](Value operand) { return isReady(operand, op); })
+               ? WalkResult::advance()
+               : WalkResult::interrupt();
+  });
+  return !readyToSchedule.wasInterrupted();
+}
+
 bool mlir::sortTopologically(
     Block *block, llvm::iterator_range<Block::iterator> ops,
     function_ref<bool(Value, Operation *)> isOperandReady) {
@@ -26,27 +63,6 @@ bool mlir::sortTopologically(
   Block::iterator nextScheduledOp = ops.begin();
   Block::iterator end = ops.end();
 
-  // An operation is ready to be scheduled if all its operands are ready. An
-  // operation is ready if:
-  const auto isReady = [&](Value value, Operation *top) {
-    // - the user-provided callback marks it as ready,
-    if (isOperandReady && isOperandReady(value, top))
-      return true;
-    Operation *parent = value.getDefiningOp();
-    // - it is a block argument,
-    if (!parent)
-      return true;
-    Operation *ancestor = block->findAncestorOpInBlock(*parent);
-    // - it is an implicit capture,
-    if (!ancestor)
-      return true;
-    // - it is defined in a nested region, or
-    if (ancestor == top)
-      return true;
-    // - its ancestor in the block is scheduled.
-    return !unscheduledOps.contains(ancestor);
-  };
-
   bool allOpsScheduled = true;
   while (!unscheduledOps.empty()) {
     bool scheduledAtLeastOnce = false;
@@ -56,16 +72,7 @@ bool mlir::sortTopologically(
     // set, and "schedule" it (move it before the `nextScheduledOp`).
     for (Operation &op :
          llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
-      // An operation is recursively ready to be scheduled of it and its nested
-      // operations are ready.
-      WalkResult readyToSchedule = op.walk([&](Operation *nestedOp) {
-        return llvm::all_of(
-                   nestedOp->getOperands(),
-                   [&](Value operand) { return isReady(operand, &op); })
-                   ? WalkResult::advance()
-                   : WalkResult::interrupt();
-      });
-      if (readyToSchedule.wasInterrupted())
+      if (!isOpReady(block, &op, unscheduledOps, isOperandReady))
         continue;
 
       // Schedule the operation by moving it to the start.
@@ -96,3 +103,48 @@ bool mlir::sortTopologically(
                            isOperandReady);
   return sortTopologically(block, *block, isOperandReady);
 }
+
+bool mlir::computeTopologicalSorting(
+    Block *block, MutableArrayRef<Operation *> ops,
+    function_ref<bool(Value, Operation *)> isOperandReady) {
+  if (ops.empty())
+    return true;
+
+  // The set of operations that have not yet been scheduled.
+  DenseSet<Operation *> unscheduledOps;
+
+  // Mark all operations as unscheduled.
+  for (Operation *op : ops) {
+    assert(op->getBlock() == block && "op must belong to block");
+    unscheduledOps.insert(op);
+  }
+
+  unsigned nextScheduledOp = 0;
+
+  bool allOpsScheduled = true;
+  while (!unscheduledOps.empty()) {
+    bool scheduledAtLeastOnce = false;
+
+    // Loop over the ops that are not sorted yet, try to find the ones "ready",
+    // i.e. the ones for which there aren't any operand produced by an op in the
+    // set, and "schedule" it (swap it with the op at `nextScheduledOp`).
+    for (unsigned i = nextScheduledOp; i < ops.size(); ++i) {
+      if (!isOpReady(block, ops[i], unscheduledOps, isOperandReady))
+        continue;
+
+      // Schedule the operation by moving it to the start.
+      unscheduledOps.erase(ops[i]);
+      std::swap(ops[i], ops[nextScheduledOp]);
+      scheduledAtLeastOnce = true;
+      ++nextScheduledOp;
+    }
+
+    // If no operations were scheduled, just schedule the first op and continue.
+    if (!scheduledAtLeastOnce) {
+      allOpsScheduled = false;
+      unscheduledOps.erase(ops[nextScheduledOp++]);
+    }
+  }
+
+  return allOpsScheduled;
+}
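The return value is worth spelling out with a hedged sketch (the wrapper below is hypothetical, not part of the patch): when no remaining op is ready, the loop above marks the next unscheduled op as scheduled anyway and the function returns false, so callers can detect cycles while still receiving a full permutation of the input:

// Hypothetical illustration of the failure mode: `ops` may contain a use-def
// cycle (possible inside graph regions).
static void sortPossiblyCyclicOps(Block *block,
                                  MutableArrayRef<Operation *> ops) {
  if (!computeTopologicalSorting(block, ops)) {
    // Not a true topological order: at least one op was force-scheduled.
    // `ops` is still a permutation of the input, ordered best-effort.
  }
}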
@@ -1,27 +1,39 @@
 // RUN: mlir-opt -topological-sort %s | FileCheck %s
+// RUN: mlir-opt -test-topological-sort-analysis %s | FileCheck %s -check-prefix=CHECK-ANALYSIS
 
 // Test producer is after user.
 // CHECK-LABEL: test.graph_region
-test.graph_region {
+// CHECK-ANALYSIS-LABEL: test.graph_region
+test.graph_region attributes{"root"} {
   // CHECK-NEXT: test.foo
   // CHECK-NEXT: test.baz
   // CHECK-NEXT: test.bar
-  %0 = "test.foo"() : () -> i32
-  "test.bar"(%1, %0) : (i32, i32) -> ()
-  %1 = "test.baz"() : () -> i32
+
+  // CHECK-ANALYSIS-NEXT: test.foo{{.*}} {pos = 0
+  // CHECK-ANALYSIS-NEXT: test.bar{{.*}} {pos = 2
+  // CHECK-ANALYSIS-NEXT: test.baz{{.*}} {pos = 1
+  %0 = "test.foo"() {selected} : () -> i32
+  "test.bar"(%1, %0) {selected} : (i32, i32) -> ()
+  %1 = "test.baz"() {selected} : () -> i32
 }
 
 // Test cycles.
 // CHECK-LABEL: test.graph_region
-test.graph_region {
+// CHECK-ANALYSIS-LABEL: test.graph_region
+test.graph_region attributes{"root"} {
   // CHECK-NEXT: test.d
   // CHECK-NEXT: test.a
   // CHECK-NEXT: test.c
   // CHECK-NEXT: test.b
-  %2 = "test.c"(%1) : (i32) -> i32
+
+  // CHECK-ANALYSIS-NEXT: test.c{{.*}} {pos = 0
+  // CHECK-ANALYSIS-NEXT: test.b{{.*}} : (
+  // CHECK-ANALYSIS-NEXT: test.a{{.*}} {pos = 2
+  // CHECK-ANALYSIS-NEXT: test.d{{.*}} {pos = 1
+  %2 = "test.c"(%1) {selected} : (i32) -> i32
   %1 = "test.b"(%0, %2) : (i32, i32) -> i32
-  %0 = "test.a"(%3) : (i32) -> i32
-  %3 = "test.d"() : () -> i32
+  %0 = "test.a"(%3) {selected} : (i32) -> i32
+  %3 = "test.d"() {selected} : () -> i32
 }
 
 // Test block arguments.
@@ -5,6 +5,7 @@ add_mlir_library(MLIRTestTransforms
   TestControlFlowSink.cpp
   TestInlining.cpp
   TestIntRangeInference.cpp
+  TestTopologicalSort.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
@@ -0,0 +1,62 @@
+//===- TestTopologicalSort.cpp - Pass to test topological sort analysis ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/TopologicalSortUtils.h"
+
+using namespace mlir;
+
+namespace {
+struct TestTopologicalSortAnalysisPass
+    : public PassWrapper<TestTopologicalSortAnalysisPass,
+                         OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTopologicalSortAnalysisPass)
+
+  StringRef getArgument() const final {
+    return "test-topological-sort-analysis";
+  }
+  StringRef getDescription() const final {
+    return "Test topological sorting of ops";
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    OpBuilder builder(op->getContext());
+
+    op->walk([&](Operation *root) {
+      if (!root->hasAttr("root"))
+        return WalkResult::advance();
+
+      assert(root->getNumRegions() == 1 && root->getRegion(0).hasOneBlock() &&
+             "expected one block");
+      Block *block = &root->getRegion(0).front();
+      SmallVector<Operation *> selectedOps;
+      block->walk([&](Operation *op) {
+        if (op->hasAttr("selected"))
+          selectedOps.push_back(op);
+      });
+
+      computeTopologicalSorting(block, selectedOps);
+      for (const auto &it : llvm::enumerate(selectedOps))
+        it.value()->setAttr("pos", builder.getIndexAttr(it.index()));
+
+      return WalkResult::advance();
+    });
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestTopologicalSortAnalysisPass() {
+  PassRegistration<TestTopologicalSortAnalysisPass>();
+}
+} // namespace test
+} // namespace mlir
@@ -111,6 +111,7 @@ void registerTestSCFUtilsPass();
 void registerTestSliceAnalysisPass();
 void registerTestTensorTransforms();
 void registerTestTilingInterface();
+void registerTestTopologicalSortAnalysisPass();
 void registerTestTransformDialectInterpreterPass();
 void registerTestVectorLowerings();
 void registerTestNvgpuLowerings();
@@ -207,6 +208,7 @@ void registerTestPasses() {
  mlir::test::registerTestSliceAnalysisPass();
  mlir::test::registerTestTensorTransforms();
  mlir::test::registerTestTilingInterface();
+  mlir::test::registerTestTopologicalSortAnalysisPass();
  mlir::test::registerTestTransformDialectInterpreterPass();
  mlir::test::registerTestVectorLowerings();
  mlir::test::registerTestNvgpuLowerings();