[mlir][SCF] Add utility to outline the then and else branches of an scf.IfOp

Differential Revision: https://reviews.llvm.org/D85449
This commit is contained in:
Nicolas Vasilache 2020-08-07 14:35:33 -04:00
parent aedaa077f5
commit 2a01d7f7b6
6 changed files with 166 additions and 5 deletions

View File

@ -13,11 +13,15 @@
#ifndef MLIR_DIALECT_SCF_UTILS_H_
#define MLIR_DIALECT_SCF_UTILS_H_
#include "mlir/Support/LLVM.h"
namespace mlir {
class FuncOp;
class OpBuilder;
class ValueRange;
namespace scf {
class IfOp;
class ForOp;
class ParallelOp;
} // end namespace scf
@ -46,5 +50,12 @@ scf::ForOp cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
ValueRange newYieldedValues,
bool replaceLoopResults = true);
/// Outline the then and/or else regions of `ifOp` as follows:
/// - if `thenFn` is not null, `thenFnName` must be specified and the `then`
/// region is inlined into a new FuncOp that is captured by the pointer.
/// - if `elseFn` is not null, `elseFnName` must be specified and the `else`
/// region is inlined into a new FuncOp that is captured by the pointer.
void outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
StringRef thenFnName, FuncOp *elseFn, StringRef elseFnName);
} // end namespace mlir
#endif // MLIR_DIALECT_SCF_UTILS_H_

View File

@ -17,4 +17,5 @@ add_mlir_dialect_library(MLIRSCFTransforms
MLIRSCF
MLIRStandardOps
MLIRSupport
)
MLIRTransformUtils
)

View File

@ -13,7 +13,12 @@
#include "mlir/Dialect/SCF/Utils.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Function.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
using namespace mlir;
@ -71,3 +76,50 @@ scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
return newLoop;
}
void mlir::outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
StringRef thenFnName, FuncOp *elseFn,
StringRef elseFnName) {
Location loc = ifOp.getLoc();
MLIRContext *ctx = ifOp.getContext();
auto outline = [&](Region &ifOrElseRegion, StringRef funcName) {
assert(!funcName.empty() && "Expected function name for outlining");
assert(ifOrElseRegion.getBlocks().size() <= 1 &&
"Expected at most one block");
// Outline before current function.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(ifOp.getParentOfType<FuncOp>());
llvm::SetVector<Value> captures;
getUsedValuesDefinedAbove(ifOrElseRegion, captures);
ValueRange values(captures.getArrayRef());
FunctionType type =
FunctionType::get(values.getTypes(), ifOp.getResultTypes(), ctx);
auto outlinedFunc = b.create<FuncOp>(loc, funcName, type);
b.setInsertionPointToStart(outlinedFunc.addEntryBlock());
BlockAndValueMapping bvm;
for (auto it : llvm::zip(values, outlinedFunc.getArguments()))
bvm.map(std::get<0>(it), std::get<1>(it));
for (Operation &op : ifOrElseRegion.front().without_terminator())
b.clone(op, bvm);
Operation *term = ifOrElseRegion.front().getTerminator();
SmallVector<Value, 4> terminatorOperands;
for (auto op : term->getOperands())
terminatorOperands.push_back(bvm.lookup(op));
b.create<ReturnOp>(loc, term->getResultTypes(), terminatorOperands);
ifOrElseRegion.front().clear();
b.setInsertionPointToEnd(&ifOrElseRegion.front());
Operation *call = b.create<CallOp>(loc, outlinedFunc, values);
b.create<scf::YieldOp>(loc, call->getResults());
return outlinedFunc;
};
if (thenFn && !ifOp.thenRegion().empty())
*thenFn = outline(ifOp.thenRegion(), thenFnName);
if (elseFn && !ifOp.elseRegion().empty())
*elseFn = outline(ifOp.elseRegion(), elseFnName);
}

View File

@ -0,0 +1,75 @@
// RUN: mlir-opt -allow-unregistered-dialect -test-scf-if-utils -split-input-file %s | FileCheck %s
// -----
// CHECK: func @outlined_then0(%{{.*}}: i1, %{{.*}}: memref<?xf32>) -> i8 {
// CHECK-NEXT: %{{.*}} = "some_op"(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> i8
// CHECK-NEXT: return %{{.*}} : i8
// CHECK-NEXT: }
// CHECK: func @outlined_else0(%{{.*}}: i8) -> i8 {
// CHECK-NEXT: return %{{.*}}0 : i8
// CHECK-NEXT: }
// CHECK: func @outline_if_else(
// CHECK-NEXT: %{{.*}} = scf.if %{{.*}} -> (i8) {
// CHECK-NEXT: %{{.*}} = call @outlined_then0(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> i8
// CHECK-NEXT: scf.yield %{{.*}} : i8
// CHECK-NEXT: } else {
// CHECK-NEXT: %{{.*}} = call @outlined_else0(%{{.*}}) : (i8) -> i8
// CHECK-NEXT: scf.yield %{{.*}} : i8
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-NEXT: }
func @outline_if_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
%r = scf.if %cond -> (i8) {
%r = "some_op"(%cond, %b) : (i1, memref<?xf32>) -> (i8)
scf.yield %r : i8
} else {
scf.yield %c : i8
}
return
}
// -----
// CHECK: func @outlined_then0(%{{.*}}: i1, %{{.*}}: memref<?xf32>) {
// CHECK-NEXT: "some_op"(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
// CHECK-NEXT: return
// CHECK-NEXT: }
// CHECK: func @outline_if(
// CHECK-NEXT: scf.if %{{.*}} {
// CHECK-NEXT: call @outlined_then0(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-NEXT: }
func @outline_if(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
scf.if %cond {
"some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
scf.yield
}
return
}
// -----
// CHECK: func @outlined_then0() {
// CHECK-NEXT: return
// CHECK-NEXT: }
// CHECK: func @outlined_else0(%{{.*}}: i1, %{{.*}}: memref<?xf32>) {
// CHECK-NEXT: "some_op"(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
// CHECK-NEXT: return
// CHECK-NEXT: }
// CHECK: func @outline_empty_if_else(
// CHECK-NEXT: scf.if %{{.*}} {
// CHECK-NEXT: call @outlined_then0() : () -> ()
// CHECK-NEXT: } else {
// CHECK-NEXT: call @outlined_else0(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-NEXT: }
func @outline_empty_if_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
scf.if %cond {
} else {
"some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
}
return
}

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt -allow-unregistered-dialect -test-scf-utils -mlir-disable-threading %s | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils -mlir-disable-threading %s | FileCheck %s
// CHECK-LABEL: @hoist
// CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,

View File

@ -21,9 +21,10 @@
using namespace mlir;
namespace {
class TestSCFUtilsPass : public PassWrapper<TestSCFUtilsPass, FunctionPass> {
class TestSCFForUtilsPass
: public PassWrapper<TestSCFForUtilsPass, FunctionPass> {
public:
explicit TestSCFUtilsPass() {}
explicit TestSCFForUtilsPass() {}
void runOnFunction() override {
FuncOp func = getFunction();
@ -49,10 +50,31 @@ public:
loop.erase();
}
};
class TestSCFIfUtilsPass
: public PassWrapper<TestSCFIfUtilsPass, FunctionPass> {
public:
explicit TestSCFIfUtilsPass() {}
void runOnFunction() override {
int count = 0;
FuncOp func = getFunction();
func.walk([&](scf::IfOp ifOp) {
auto strCount = std::to_string(count++);
FuncOp thenFn, elseFn;
OpBuilder b(ifOp);
outlineIfOp(b, ifOp, &thenFn, std::string("outlined_then") + strCount,
&elseFn, std::string("outlined_else") + strCount);
});
}
};
} // end namespace
namespace mlir {
void registerTestSCFUtilsPass() {
PassRegistration<TestSCFUtilsPass>("test-scf-utils", "test scf utils");
PassRegistration<TestSCFForUtilsPass>("test-scf-for-utils",
"test scf.for utils");
PassRegistration<TestSCFIfUtilsPass>("test-scf-if-utils",
"test scf.if utils");
}
} // namespace mlir