forked from OSchip/llvm-project
[mlir][transforms] Add topological sort analysis
This change add a helper function for computing a topological sorting of a list of ops. E.g. this can be useful in transforms where a subset of ops should be cloned without dominance errors. The analysis reuses the existing implementation in TopologicalSortUtils.cpp. Differential Revision: https://reviews.llvm.org/D131669
This commit is contained in:
parent
e86119b4ff
commit
31fbdab376
|
@ -90,11 +90,23 @@ bool sortTopologically(
|
||||||
function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
|
function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
|
||||||
|
|
||||||
/// Given a block, sort its operations in topological order, excluding its
|
/// Given a block, sort its operations in topological order, excluding its
|
||||||
/// terminator if it has one.
|
/// terminator if it has one. This sort is stable.
|
||||||
bool sortTopologically(
|
bool sortTopologically(
|
||||||
Block *block,
|
Block *block,
|
||||||
function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
|
function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
|
||||||
|
|
||||||
|
/// Compute a topological ordering of the given ops. All ops must belong to the
|
||||||
|
/// specified block.
|
||||||
|
///
|
||||||
|
/// This sort is not stable.
|
||||||
|
///
|
||||||
|
/// Note: If the specified ops contain incomplete/interrupted SSA use-def
|
||||||
|
/// chains, the result may not actually be a topological sorting with respect to
|
||||||
|
/// the entire program.
|
||||||
|
bool computeTopologicalSorting(
|
||||||
|
Block *block, MutableArrayRef<Operation *> ops,
|
||||||
|
function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
#endif // MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
|
#endif // MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
|
||||||
|
|
|
@ -8,9 +8,46 @@
|
||||||
|
|
||||||
#include "mlir/Transforms/TopologicalSortUtils.h"
|
#include "mlir/Transforms/TopologicalSortUtils.h"
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
|
/// Return `true` if the given operation is ready to be scheduled.
|
||||||
|
static bool isOpReady(Block *block, Operation *op,
|
||||||
|
DenseSet<Operation *> &unscheduledOps,
|
||||||
|
function_ref<bool(Value, Operation *)> isOperandReady) {
|
||||||
|
// An operation is ready to be scheduled if all its operands are ready. An
|
||||||
|
// operation is ready if:
|
||||||
|
const auto isReady = [&](Value value, Operation *top) {
|
||||||
|
// - the user-provided callback marks it as ready,
|
||||||
|
if (isOperandReady && isOperandReady(value, op))
|
||||||
|
return true;
|
||||||
|
Operation *parent = value.getDefiningOp();
|
||||||
|
// - it is a block argument,
|
||||||
|
if (!parent)
|
||||||
|
return true;
|
||||||
|
Operation *ancestor = block->findAncestorOpInBlock(*parent);
|
||||||
|
// - it is an implicit capture,
|
||||||
|
if (!ancestor)
|
||||||
|
return true;
|
||||||
|
// - it is defined in a nested region, or
|
||||||
|
if (ancestor == op)
|
||||||
|
return true;
|
||||||
|
// - its ancestor in the block is scheduled.
|
||||||
|
return !unscheduledOps.contains(ancestor);
|
||||||
|
};
|
||||||
|
|
||||||
|
// An operation is recursively ready to be scheduled of it and its nested
|
||||||
|
// operations are ready.
|
||||||
|
WalkResult readyToSchedule = op->walk([&](Operation *nestedOp) {
|
||||||
|
return llvm::all_of(nestedOp->getOperands(),
|
||||||
|
[&](Value operand) { return isReady(operand, op); })
|
||||||
|
? WalkResult::advance()
|
||||||
|
: WalkResult::interrupt();
|
||||||
|
});
|
||||||
|
return !readyToSchedule.wasInterrupted();
|
||||||
|
}
|
||||||
|
|
||||||
bool mlir::sortTopologically(
|
bool mlir::sortTopologically(
|
||||||
Block *block, llvm::iterator_range<Block::iterator> ops,
|
Block *block, llvm::iterator_range<Block::iterator> ops,
|
||||||
function_ref<bool(Value, Operation *)> isOperandReady) {
|
function_ref<bool(Value, Operation *)> isOperandReady) {
|
||||||
|
@ -26,27 +63,6 @@ bool mlir::sortTopologically(
|
||||||
Block::iterator nextScheduledOp = ops.begin();
|
Block::iterator nextScheduledOp = ops.begin();
|
||||||
Block::iterator end = ops.end();
|
Block::iterator end = ops.end();
|
||||||
|
|
||||||
// An operation is ready to be scheduled if all its operands are ready. An
|
|
||||||
// operation is ready if:
|
|
||||||
const auto isReady = [&](Value value, Operation *top) {
|
|
||||||
// - the user-provided callback marks it as ready,
|
|
||||||
if (isOperandReady && isOperandReady(value, top))
|
|
||||||
return true;
|
|
||||||
Operation *parent = value.getDefiningOp();
|
|
||||||
// - it is a block argument,
|
|
||||||
if (!parent)
|
|
||||||
return true;
|
|
||||||
Operation *ancestor = block->findAncestorOpInBlock(*parent);
|
|
||||||
// - it is an implicit capture,
|
|
||||||
if (!ancestor)
|
|
||||||
return true;
|
|
||||||
// - it is defined in a nested region, or
|
|
||||||
if (ancestor == top)
|
|
||||||
return true;
|
|
||||||
// - its ancestor in the block is scheduled.
|
|
||||||
return !unscheduledOps.contains(ancestor);
|
|
||||||
};
|
|
||||||
|
|
||||||
bool allOpsScheduled = true;
|
bool allOpsScheduled = true;
|
||||||
while (!unscheduledOps.empty()) {
|
while (!unscheduledOps.empty()) {
|
||||||
bool scheduledAtLeastOnce = false;
|
bool scheduledAtLeastOnce = false;
|
||||||
|
@ -56,16 +72,7 @@ bool mlir::sortTopologically(
|
||||||
// set, and "schedule" it (move it before the `nextScheduledOp`).
|
// set, and "schedule" it (move it before the `nextScheduledOp`).
|
||||||
for (Operation &op :
|
for (Operation &op :
|
||||||
llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
|
llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
|
||||||
// An operation is recursively ready to be scheduled of it and its nested
|
if (!isOpReady(block, &op, unscheduledOps, isOperandReady))
|
||||||
// operations are ready.
|
|
||||||
WalkResult readyToSchedule = op.walk([&](Operation *nestedOp) {
|
|
||||||
return llvm::all_of(
|
|
||||||
nestedOp->getOperands(),
|
|
||||||
[&](Value operand) { return isReady(operand, &op); })
|
|
||||||
? WalkResult::advance()
|
|
||||||
: WalkResult::interrupt();
|
|
||||||
});
|
|
||||||
if (readyToSchedule.wasInterrupted())
|
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
// Schedule the operation by moving it to the start.
|
// Schedule the operation by moving it to the start.
|
||||||
|
@ -96,3 +103,48 @@ bool mlir::sortTopologically(
|
||||||
isOperandReady);
|
isOperandReady);
|
||||||
return sortTopologically(block, *block, isOperandReady);
|
return sortTopologically(block, *block, isOperandReady);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool mlir::computeTopologicalSorting(
|
||||||
|
Block *block, MutableArrayRef<Operation *> ops,
|
||||||
|
function_ref<bool(Value, Operation *)> isOperandReady) {
|
||||||
|
if (ops.empty())
|
||||||
|
return true;
|
||||||
|
|
||||||
|
// The set of operations that have not yet been scheduled.
|
||||||
|
DenseSet<Operation *> unscheduledOps;
|
||||||
|
|
||||||
|
// Mark all operations as unscheduled.
|
||||||
|
for (Operation *op : ops) {
|
||||||
|
assert(op->getBlock() == block && "op must belong to block");
|
||||||
|
unscheduledOps.insert(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned nextScheduledOp = 0;
|
||||||
|
|
||||||
|
bool allOpsScheduled = true;
|
||||||
|
while (!unscheduledOps.empty()) {
|
||||||
|
bool scheduledAtLeastOnce = false;
|
||||||
|
|
||||||
|
// Loop over the ops that are not sorted yet, try to find the ones "ready",
|
||||||
|
// i.e. the ones for which there aren't any operand produced by an op in the
|
||||||
|
// set, and "schedule" it (swap it with the op at `nextScheduledOp`).
|
||||||
|
for (unsigned i = nextScheduledOp; i < ops.size(); ++i) {
|
||||||
|
if (!isOpReady(block, ops[i], unscheduledOps, isOperandReady))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
// Schedule the operation by moving it to the start.
|
||||||
|
unscheduledOps.erase(ops[i]);
|
||||||
|
std::swap(ops[i], ops[nextScheduledOp]);
|
||||||
|
scheduledAtLeastOnce = true;
|
||||||
|
++nextScheduledOp;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no operations were scheduled, just schedule the first op and continue.
|
||||||
|
if (!scheduledAtLeastOnce) {
|
||||||
|
allOpsScheduled = false;
|
||||||
|
unscheduledOps.erase(ops[nextScheduledOp++]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return allOpsScheduled;
|
||||||
|
}
|
||||||
|
|
|
@ -1,27 +1,39 @@
|
||||||
// RUN: mlir-opt -topological-sort %s | FileCheck %s
|
// RUN: mlir-opt -topological-sort %s | FileCheck %s
|
||||||
|
// RUN: mlir-opt -test-topological-sort-analysis %s | FileCheck %s -check-prefix=CHECK-ANALYSIS
|
||||||
|
|
||||||
// Test producer is after user.
|
// Test producer is after user.
|
||||||
// CHECK-LABEL: test.graph_region
|
// CHECK-LABEL: test.graph_region
|
||||||
test.graph_region {
|
// CHECK-ANALYSIS-LABEL: test.graph_region
|
||||||
|
test.graph_region attributes{"root"} {
|
||||||
// CHECK-NEXT: test.foo
|
// CHECK-NEXT: test.foo
|
||||||
// CHECK-NEXT: test.baz
|
// CHECK-NEXT: test.baz
|
||||||
// CHECK-NEXT: test.bar
|
// CHECK-NEXT: test.bar
|
||||||
%0 = "test.foo"() : () -> i32
|
|
||||||
"test.bar"(%1, %0) : (i32, i32) -> ()
|
// CHECK-ANALYSIS-NEXT: test.foo{{.*}} {pos = 0
|
||||||
%1 = "test.baz"() : () -> i32
|
// CHECK-ANALYSIS-NEXT: test.bar{{.*}} {pos = 2
|
||||||
|
// CHECK-ANALYSIS-NEXT: test.baz{{.*}} {pos = 1
|
||||||
|
%0 = "test.foo"() {selected} : () -> i32
|
||||||
|
"test.bar"(%1, %0) {selected} : (i32, i32) -> ()
|
||||||
|
%1 = "test.baz"() {selected} : () -> i32
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test cycles.
|
// Test cycles.
|
||||||
// CHECK-LABEL: test.graph_region
|
// CHECK-LABEL: test.graph_region
|
||||||
test.graph_region {
|
// CHECK-ANALYSIS-LABEL: test.graph_region
|
||||||
|
test.graph_region attributes{"root"} {
|
||||||
// CHECK-NEXT: test.d
|
// CHECK-NEXT: test.d
|
||||||
// CHECK-NEXT: test.a
|
// CHECK-NEXT: test.a
|
||||||
// CHECK-NEXT: test.c
|
// CHECK-NEXT: test.c
|
||||||
// CHECK-NEXT: test.b
|
// CHECK-NEXT: test.b
|
||||||
%2 = "test.c"(%1) : (i32) -> i32
|
|
||||||
|
// CHECK-ANALYSIS-NEXT: test.c{{.*}} {pos = 0
|
||||||
|
// CHECK-ANALYSIS-NEXT: test.b{{.*}} : (
|
||||||
|
// CHECK-ANALYSIS-NEXT: test.a{{.*}} {pos = 2
|
||||||
|
// CHECK-ANALYSIS-NEXT: test.d{{.*}} {pos = 1
|
||||||
|
%2 = "test.c"(%1) {selected} : (i32) -> i32
|
||||||
%1 = "test.b"(%0, %2) : (i32, i32) -> i32
|
%1 = "test.b"(%0, %2) : (i32, i32) -> i32
|
||||||
%0 = "test.a"(%3) : (i32) -> i32
|
%0 = "test.a"(%3) {selected} : (i32) -> i32
|
||||||
%3 = "test.d"() : () -> i32
|
%3 = "test.d"() {selected} : () -> i32
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test block arguments.
|
// Test block arguments.
|
||||||
|
|
|
@ -5,6 +5,7 @@ add_mlir_library(MLIRTestTransforms
|
||||||
TestControlFlowSink.cpp
|
TestControlFlowSink.cpp
|
||||||
TestInlining.cpp
|
TestInlining.cpp
|
||||||
TestIntRangeInference.cpp
|
TestIntRangeInference.cpp
|
||||||
|
TestTopologicalSort.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_LIBMLIR
|
EXCLUDE_FROM_LIBMLIR
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
//===- TestTopologicalSort.cpp - Pass to test topological sort analysis ---===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/TopologicalSortUtils.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
struct TestTopologicalSortAnalysisPass
|
||||||
|
: public PassWrapper<TestTopologicalSortAnalysisPass,
|
||||||
|
OperationPass<ModuleOp>> {
|
||||||
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTopologicalSortAnalysisPass)
|
||||||
|
|
||||||
|
StringRef getArgument() const final {
|
||||||
|
return "test-topological-sort-analysis";
|
||||||
|
}
|
||||||
|
StringRef getDescription() const final {
|
||||||
|
return "Test topological sorting of ops";
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
Operation *op = getOperation();
|
||||||
|
OpBuilder builder(op->getContext());
|
||||||
|
|
||||||
|
op->walk([&](Operation *root) {
|
||||||
|
if (!root->hasAttr("root"))
|
||||||
|
return WalkResult::advance();
|
||||||
|
|
||||||
|
assert(root->getNumRegions() == 1 && root->getRegion(0).hasOneBlock() &&
|
||||||
|
"expected one block");
|
||||||
|
Block *block = &root->getRegion(0).front();
|
||||||
|
SmallVector<Operation *> selectedOps;
|
||||||
|
block->walk([&](Operation *op) {
|
||||||
|
if (op->hasAttr("selected"))
|
||||||
|
selectedOps.push_back(op);
|
||||||
|
});
|
||||||
|
|
||||||
|
computeTopologicalSorting(block, selectedOps);
|
||||||
|
for (const auto &it : llvm::enumerate(selectedOps))
|
||||||
|
it.value()->setAttr("pos", builder.getIndexAttr(it.index()));
|
||||||
|
|
||||||
|
return WalkResult::advance();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace test {
|
||||||
|
void registerTestTopologicalSortAnalysisPass() {
|
||||||
|
PassRegistration<TestTopologicalSortAnalysisPass>();
|
||||||
|
}
|
||||||
|
} // namespace test
|
||||||
|
} // namespace mlir
|
|
@ -111,6 +111,7 @@ void registerTestSCFUtilsPass();
|
||||||
void registerTestSliceAnalysisPass();
|
void registerTestSliceAnalysisPass();
|
||||||
void registerTestTensorTransforms();
|
void registerTestTensorTransforms();
|
||||||
void registerTestTilingInterface();
|
void registerTestTilingInterface();
|
||||||
|
void registerTestTopologicalSortAnalysisPass();
|
||||||
void registerTestTransformDialectInterpreterPass();
|
void registerTestTransformDialectInterpreterPass();
|
||||||
void registerTestVectorLowerings();
|
void registerTestVectorLowerings();
|
||||||
void registerTestNvgpuLowerings();
|
void registerTestNvgpuLowerings();
|
||||||
|
@ -207,6 +208,7 @@ void registerTestPasses() {
|
||||||
mlir::test::registerTestSliceAnalysisPass();
|
mlir::test::registerTestSliceAnalysisPass();
|
||||||
mlir::test::registerTestTensorTransforms();
|
mlir::test::registerTestTensorTransforms();
|
||||||
mlir::test::registerTestTilingInterface();
|
mlir::test::registerTestTilingInterface();
|
||||||
|
mlir::test::registerTestTopologicalSortAnalysisPass();
|
||||||
mlir::test::registerTestTransformDialectInterpreterPass();
|
mlir::test::registerTestTransformDialectInterpreterPass();
|
||||||
mlir::test::registerTestVectorLowerings();
|
mlir::test::registerTestVectorLowerings();
|
||||||
mlir::test::registerTestNvgpuLowerings();
|
mlir::test::registerTestNvgpuLowerings();
|
||||||
|
|
Loading…
Reference in New Issue