forked from OSchip/llvm-project
[mlir][SCF] Add utility to outline the then and else branches of an scf.IfOp
Differential Revision: https://reviews.llvm.org/D85449
This commit is contained in:
parent
aedaa077f5
commit
2a01d7f7b6
mlir
include/mlir/Dialect/SCF
lib/Dialect/SCF/Transforms
test
|
@ -13,11 +13,15 @@
|
||||||
#ifndef MLIR_DIALECT_SCF_UTILS_H_
|
#ifndef MLIR_DIALECT_SCF_UTILS_H_
|
||||||
#define MLIR_DIALECT_SCF_UTILS_H_
|
#define MLIR_DIALECT_SCF_UTILS_H_
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
class FuncOp;
|
||||||
class OpBuilder;
|
class OpBuilder;
|
||||||
class ValueRange;
|
class ValueRange;
|
||||||
|
|
||||||
namespace scf {
|
namespace scf {
|
||||||
|
class IfOp;
|
||||||
class ForOp;
|
class ForOp;
|
||||||
class ParallelOp;
|
class ParallelOp;
|
||||||
} // end namespace scf
|
} // end namespace scf
|
||||||
|
@ -46,5 +50,12 @@ scf::ForOp cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
|
||||||
ValueRange newYieldedValues,
|
ValueRange newYieldedValues,
|
||||||
bool replaceLoopResults = true);
|
bool replaceLoopResults = true);
|
||||||
|
|
||||||
|
/// Outline the then and/or else regions of `ifOp` as follows:
|
||||||
|
/// - if `thenFn` is not null, `thenFnName` must be specified and the `then`
|
||||||
|
/// region is inlined into a new FuncOp that is captured by the pointer.
|
||||||
|
/// - if `elseFn` is not null, `elseFnName` must be specified and the `else`
|
||||||
|
/// region is inlined into a new FuncOp that is captured by the pointer.
|
||||||
|
void outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
|
||||||
|
StringRef thenFnName, FuncOp *elseFn, StringRef elseFnName);
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
#endif // MLIR_DIALECT_SCF_UTILS_H_
|
#endif // MLIR_DIALECT_SCF_UTILS_H_
|
||||||
|
|
|
@ -17,4 +17,5 @@ add_mlir_dialect_library(MLIRSCFTransforms
|
||||||
MLIRSCF
|
MLIRSCF
|
||||||
MLIRStandardOps
|
MLIRStandardOps
|
||||||
MLIRSupport
|
MLIRSupport
|
||||||
|
MLIRTransformUtils
|
||||||
)
|
)
|
||||||
|
|
|
@ -13,7 +13,12 @@
|
||||||
#include "mlir/Dialect/SCF/Utils.h"
|
#include "mlir/Dialect/SCF/Utils.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
|
#include "mlir/IR/Function.h"
|
||||||
|
#include "mlir/Transforms/RegionUtils.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
|
@ -71,3 +76,50 @@ scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
|
||||||
|
|
||||||
return newLoop;
|
return newLoop;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void mlir::outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
|
||||||
|
StringRef thenFnName, FuncOp *elseFn,
|
||||||
|
StringRef elseFnName) {
|
||||||
|
Location loc = ifOp.getLoc();
|
||||||
|
MLIRContext *ctx = ifOp.getContext();
|
||||||
|
auto outline = [&](Region &ifOrElseRegion, StringRef funcName) {
|
||||||
|
assert(!funcName.empty() && "Expected function name for outlining");
|
||||||
|
assert(ifOrElseRegion.getBlocks().size() <= 1 &&
|
||||||
|
"Expected at most one block");
|
||||||
|
|
||||||
|
// Outline before current function.
|
||||||
|
OpBuilder::InsertionGuard g(b);
|
||||||
|
b.setInsertionPoint(ifOp.getParentOfType<FuncOp>());
|
||||||
|
|
||||||
|
llvm::SetVector<Value> captures;
|
||||||
|
getUsedValuesDefinedAbove(ifOrElseRegion, captures);
|
||||||
|
|
||||||
|
ValueRange values(captures.getArrayRef());
|
||||||
|
FunctionType type =
|
||||||
|
FunctionType::get(values.getTypes(), ifOp.getResultTypes(), ctx);
|
||||||
|
auto outlinedFunc = b.create<FuncOp>(loc, funcName, type);
|
||||||
|
b.setInsertionPointToStart(outlinedFunc.addEntryBlock());
|
||||||
|
BlockAndValueMapping bvm;
|
||||||
|
for (auto it : llvm::zip(values, outlinedFunc.getArguments()))
|
||||||
|
bvm.map(std::get<0>(it), std::get<1>(it));
|
||||||
|
for (Operation &op : ifOrElseRegion.front().without_terminator())
|
||||||
|
b.clone(op, bvm);
|
||||||
|
|
||||||
|
Operation *term = ifOrElseRegion.front().getTerminator();
|
||||||
|
SmallVector<Value, 4> terminatorOperands;
|
||||||
|
for (auto op : term->getOperands())
|
||||||
|
terminatorOperands.push_back(bvm.lookup(op));
|
||||||
|
b.create<ReturnOp>(loc, term->getResultTypes(), terminatorOperands);
|
||||||
|
|
||||||
|
ifOrElseRegion.front().clear();
|
||||||
|
b.setInsertionPointToEnd(&ifOrElseRegion.front());
|
||||||
|
Operation *call = b.create<CallOp>(loc, outlinedFunc, values);
|
||||||
|
b.create<scf::YieldOp>(loc, call->getResults());
|
||||||
|
return outlinedFunc;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (thenFn && !ifOp.thenRegion().empty())
|
||||||
|
*thenFn = outline(ifOp.thenRegion(), thenFnName);
|
||||||
|
if (elseFn && !ifOp.elseRegion().empty())
|
||||||
|
*elseFn = outline(ifOp.elseRegion(), elseFnName);
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
// RUN: mlir-opt -allow-unregistered-dialect -test-scf-if-utils -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK: func @outlined_then0(%{{.*}}: i1, %{{.*}}: memref<?xf32>) -> i8 {
|
||||||
|
// CHECK-NEXT: %{{.*}} = "some_op"(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> i8
|
||||||
|
// CHECK-NEXT: return %{{.*}} : i8
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK: func @outlined_else0(%{{.*}}: i8) -> i8 {
|
||||||
|
// CHECK-NEXT: return %{{.*}}0 : i8
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK: func @outline_if_else(
|
||||||
|
// CHECK-NEXT: %{{.*}} = scf.if %{{.*}} -> (i8) {
|
||||||
|
// CHECK-NEXT: %{{.*}} = call @outlined_then0(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> i8
|
||||||
|
// CHECK-NEXT: scf.yield %{{.*}} : i8
|
||||||
|
// CHECK-NEXT: } else {
|
||||||
|
// CHECK-NEXT: %{{.*}} = call @outlined_else0(%{{.*}}) : (i8) -> i8
|
||||||
|
// CHECK-NEXT: scf.yield %{{.*}} : i8
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
func @outline_if_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
|
||||||
|
%r = scf.if %cond -> (i8) {
|
||||||
|
%r = "some_op"(%cond, %b) : (i1, memref<?xf32>) -> (i8)
|
||||||
|
scf.yield %r : i8
|
||||||
|
} else {
|
||||||
|
scf.yield %c : i8
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK: func @outlined_then0(%{{.*}}: i1, %{{.*}}: memref<?xf32>) {
|
||||||
|
// CHECK-NEXT: "some_op"(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK: func @outline_if(
|
||||||
|
// CHECK-NEXT: scf.if %{{.*}} {
|
||||||
|
// CHECK-NEXT: call @outlined_then0(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
func @outline_if(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
|
||||||
|
scf.if %cond {
|
||||||
|
"some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
|
||||||
|
scf.yield
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK: func @outlined_then0() {
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK: func @outlined_else0(%{{.*}}: i1, %{{.*}}: memref<?xf32>) {
|
||||||
|
// CHECK-NEXT: "some_op"(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK: func @outline_empty_if_else(
|
||||||
|
// CHECK-NEXT: scf.if %{{.*}} {
|
||||||
|
// CHECK-NEXT: call @outlined_then0() : () -> ()
|
||||||
|
// CHECK-NEXT: } else {
|
||||||
|
// CHECK-NEXT: call @outlined_else0(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
func @outline_empty_if_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
|
||||||
|
scf.if %cond {
|
||||||
|
} else {
|
||||||
|
"some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-opt -allow-unregistered-dialect -test-scf-utils -mlir-disable-threading %s | FileCheck %s
|
// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils -mlir-disable-threading %s | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: @hoist
|
// CHECK-LABEL: @hoist
|
||||||
// CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,
|
// CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,
|
|
@ -21,9 +21,10 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class TestSCFUtilsPass : public PassWrapper<TestSCFUtilsPass, FunctionPass> {
|
class TestSCFForUtilsPass
|
||||||
|
: public PassWrapper<TestSCFForUtilsPass, FunctionPass> {
|
||||||
public:
|
public:
|
||||||
explicit TestSCFUtilsPass() {}
|
explicit TestSCFForUtilsPass() {}
|
||||||
|
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
FuncOp func = getFunction();
|
FuncOp func = getFunction();
|
||||||
|
@ -49,10 +50,31 @@ public:
|
||||||
loop.erase();
|
loop.erase();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class TestSCFIfUtilsPass
|
||||||
|
: public PassWrapper<TestSCFIfUtilsPass, FunctionPass> {
|
||||||
|
public:
|
||||||
|
explicit TestSCFIfUtilsPass() {}
|
||||||
|
|
||||||
|
void runOnFunction() override {
|
||||||
|
int count = 0;
|
||||||
|
FuncOp func = getFunction();
|
||||||
|
func.walk([&](scf::IfOp ifOp) {
|
||||||
|
auto strCount = std::to_string(count++);
|
||||||
|
FuncOp thenFn, elseFn;
|
||||||
|
OpBuilder b(ifOp);
|
||||||
|
outlineIfOp(b, ifOp, &thenFn, std::string("outlined_then") + strCount,
|
||||||
|
&elseFn, std::string("outlined_else") + strCount);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
} // end namespace
|
} // end namespace
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
void registerTestSCFUtilsPass() {
|
void registerTestSCFUtilsPass() {
|
||||||
PassRegistration<TestSCFUtilsPass>("test-scf-utils", "test scf utils");
|
PassRegistration<TestSCFForUtilsPass>("test-scf-for-utils",
|
||||||
|
"test scf.for utils");
|
||||||
|
PassRegistration<TestSCFIfUtilsPass>("test-scf-if-utils",
|
||||||
|
"test scf.if utils");
|
||||||
}
|
}
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
Loading…
Reference in New Issue