forked from OSchip/llvm-project
[mlir][SCF] Add utility to outline the then and else branches of an scf.IfOp
Differential Revision: https://reviews.llvm.org/D85449
This commit is contained in:
parent
aedaa077f5
commit
2a01d7f7b6
|
@ -13,11 +13,15 @@
|
|||
#ifndef MLIR_DIALECT_SCF_UTILS_H_
|
||||
#define MLIR_DIALECT_SCF_UTILS_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace mlir {
|
||||
class FuncOp;
|
||||
class OpBuilder;
|
||||
class ValueRange;
|
||||
|
||||
namespace scf {
|
||||
class IfOp;
|
||||
class ForOp;
|
||||
class ParallelOp;
|
||||
} // end namespace scf
|
||||
|
@ -46,5 +50,12 @@ scf::ForOp cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
|
|||
ValueRange newYieldedValues,
|
||||
bool replaceLoopResults = true);
|
||||
|
||||
/// Outline the then and/or else regions of `ifOp` as follows:
|
||||
/// - if `thenFn` is not null, `thenFnName` must be specified and the `then`
|
||||
/// region is inlined into a new FuncOp that is captured by the pointer.
|
||||
/// - if `elseFn` is not null, `elseFnName` must be specified and the `else`
|
||||
/// region is inlined into a new FuncOp that is captured by the pointer.
|
||||
void outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
|
||||
StringRef thenFnName, FuncOp *elseFn, StringRef elseFnName);
|
||||
} // end namespace mlir
|
||||
#endif // MLIR_DIALECT_SCF_UTILS_H_
|
||||
|
|
|
@ -17,4 +17,5 @@ add_mlir_dialect_library(MLIRSCFTransforms
|
|||
MLIRSCF
|
||||
MLIRStandardOps
|
||||
MLIRSupport
|
||||
)
|
||||
MLIRTransformUtils
|
||||
)
|
||||
|
|
|
@ -13,7 +13,12 @@
|
|||
#include "mlir/Dialect/SCF/Utils.h"
|
||||
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
@ -71,3 +76,50 @@ scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
|
|||
|
||||
return newLoop;
|
||||
}
|
||||
|
||||
void mlir::outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
|
||||
StringRef thenFnName, FuncOp *elseFn,
|
||||
StringRef elseFnName) {
|
||||
Location loc = ifOp.getLoc();
|
||||
MLIRContext *ctx = ifOp.getContext();
|
||||
auto outline = [&](Region &ifOrElseRegion, StringRef funcName) {
|
||||
assert(!funcName.empty() && "Expected function name for outlining");
|
||||
assert(ifOrElseRegion.getBlocks().size() <= 1 &&
|
||||
"Expected at most one block");
|
||||
|
||||
// Outline before current function.
|
||||
OpBuilder::InsertionGuard g(b);
|
||||
b.setInsertionPoint(ifOp.getParentOfType<FuncOp>());
|
||||
|
||||
llvm::SetVector<Value> captures;
|
||||
getUsedValuesDefinedAbove(ifOrElseRegion, captures);
|
||||
|
||||
ValueRange values(captures.getArrayRef());
|
||||
FunctionType type =
|
||||
FunctionType::get(values.getTypes(), ifOp.getResultTypes(), ctx);
|
||||
auto outlinedFunc = b.create<FuncOp>(loc, funcName, type);
|
||||
b.setInsertionPointToStart(outlinedFunc.addEntryBlock());
|
||||
BlockAndValueMapping bvm;
|
||||
for (auto it : llvm::zip(values, outlinedFunc.getArguments()))
|
||||
bvm.map(std::get<0>(it), std::get<1>(it));
|
||||
for (Operation &op : ifOrElseRegion.front().without_terminator())
|
||||
b.clone(op, bvm);
|
||||
|
||||
Operation *term = ifOrElseRegion.front().getTerminator();
|
||||
SmallVector<Value, 4> terminatorOperands;
|
||||
for (auto op : term->getOperands())
|
||||
terminatorOperands.push_back(bvm.lookup(op));
|
||||
b.create<ReturnOp>(loc, term->getResultTypes(), terminatorOperands);
|
||||
|
||||
ifOrElseRegion.front().clear();
|
||||
b.setInsertionPointToEnd(&ifOrElseRegion.front());
|
||||
Operation *call = b.create<CallOp>(loc, outlinedFunc, values);
|
||||
b.create<scf::YieldOp>(loc, call->getResults());
|
||||
return outlinedFunc;
|
||||
};
|
||||
|
||||
if (thenFn && !ifOp.thenRegion().empty())
|
||||
*thenFn = outline(ifOp.thenRegion(), thenFnName);
|
||||
if (elseFn && !ifOp.elseRegion().empty())
|
||||
*elseFn = outline(ifOp.elseRegion(), elseFnName);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
// RUN: mlir-opt -allow-unregistered-dialect -test-scf-if-utils -split-input-file %s | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @outlined_then0(%{{.*}}: i1, %{{.*}}: memref<?xf32>) -> i8 {
|
||||
// CHECK-NEXT: %{{.*}} = "some_op"(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> i8
|
||||
// CHECK-NEXT: return %{{.*}} : i8
|
||||
// CHECK-NEXT: }
|
||||
// CHECK: func @outlined_else0(%{{.*}}: i8) -> i8 {
|
||||
// CHECK-NEXT: return %{{.*}}0 : i8
|
||||
// CHECK-NEXT: }
|
||||
// CHECK: func @outline_if_else(
|
||||
// CHECK-NEXT: %{{.*}} = scf.if %{{.*}} -> (i8) {
|
||||
// CHECK-NEXT: %{{.*}} = call @outlined_then0(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> i8
|
||||
// CHECK-NEXT: scf.yield %{{.*}} : i8
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %{{.*}} = call @outlined_else0(%{{.*}}) : (i8) -> i8
|
||||
// CHECK-NEXT: scf.yield %{{.*}} : i8
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
func @outline_if_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
|
||||
%r = scf.if %cond -> (i8) {
|
||||
%r = "some_op"(%cond, %b) : (i1, memref<?xf32>) -> (i8)
|
||||
scf.yield %r : i8
|
||||
} else {
|
||||
scf.yield %c : i8
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @outlined_then0(%{{.*}}: i1, %{{.*}}: memref<?xf32>) {
|
||||
// CHECK-NEXT: "some_op"(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
// CHECK: func @outline_if(
|
||||
// CHECK-NEXT: scf.if %{{.*}} {
|
||||
// CHECK-NEXT: call @outlined_then0(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
func @outline_if(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
|
||||
scf.if %cond {
|
||||
"some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
|
||||
scf.yield
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @outlined_then0() {
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
// CHECK: func @outlined_else0(%{{.*}}: i1, %{{.*}}: memref<?xf32>) {
|
||||
// CHECK-NEXT: "some_op"(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
// CHECK: func @outline_empty_if_else(
|
||||
// CHECK-NEXT: scf.if %{{.*}} {
|
||||
// CHECK-NEXT: call @outlined_then0() : () -> ()
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: call @outlined_else0(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
func @outline_empty_if_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
|
||||
scf.if %cond {
|
||||
} else {
|
||||
"some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
|
||||
}
|
||||
return
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt -allow-unregistered-dialect -test-scf-utils -mlir-disable-threading %s | FileCheck %s
|
||||
// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils -mlir-disable-threading %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @hoist
|
||||
// CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,
|
|
@ -21,9 +21,10 @@
|
|||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
class TestSCFUtilsPass : public PassWrapper<TestSCFUtilsPass, FunctionPass> {
|
||||
class TestSCFForUtilsPass
|
||||
: public PassWrapper<TestSCFForUtilsPass, FunctionPass> {
|
||||
public:
|
||||
explicit TestSCFUtilsPass() {}
|
||||
explicit TestSCFForUtilsPass() {}
|
||||
|
||||
void runOnFunction() override {
|
||||
FuncOp func = getFunction();
|
||||
|
@ -49,10 +50,31 @@ public:
|
|||
loop.erase();
|
||||
}
|
||||
};
|
||||
|
||||
class TestSCFIfUtilsPass
|
||||
: public PassWrapper<TestSCFIfUtilsPass, FunctionPass> {
|
||||
public:
|
||||
explicit TestSCFIfUtilsPass() {}
|
||||
|
||||
void runOnFunction() override {
|
||||
int count = 0;
|
||||
FuncOp func = getFunction();
|
||||
func.walk([&](scf::IfOp ifOp) {
|
||||
auto strCount = std::to_string(count++);
|
||||
FuncOp thenFn, elseFn;
|
||||
OpBuilder b(ifOp);
|
||||
outlineIfOp(b, ifOp, &thenFn, std::string("outlined_then") + strCount,
|
||||
&elseFn, std::string("outlined_else") + strCount);
|
||||
});
|
||||
}
|
||||
};
|
||||
} // end namespace
|
||||
|
||||
namespace mlir {
|
||||
void registerTestSCFUtilsPass() {
|
||||
PassRegistration<TestSCFUtilsPass>("test-scf-utils", "test scf utils");
|
||||
PassRegistration<TestSCFForUtilsPass>("test-scf-for-utils",
|
||||
"test scf.for utils");
|
||||
PassRegistration<TestSCFIfUtilsPass>("test-scf-if-utils",
|
||||
"test scf.if utils");
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
Loading…
Reference in New Issue