forked from OSchip/llvm-project
Revert "[MLIR][SCF] Inline ExecuteRegion if parent can contain multiple blocks"
This reverts commit 5d6240b77e
.
The commit was mistakenly landed without a PR approval, this will be
reverted now and resubmitted.
This commit is contained in:
parent
18c3c77849
commit
2ab27758d5
|
@ -108,8 +108,14 @@ def ExecuteRegionOp : SCF_Op<"execute_region"> {
|
|||
|
||||
let regions = (region AnyRegion:$region);
|
||||
|
||||
// TODO: If the parent is a func like op (which would be the case if all other
|
||||
// ops are from the std dialect), the inliner logic could be readily used to
|
||||
// inline.
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
// TODO: can fold if it returns a constant.
|
||||
// TODO: Single block execute_region ops can be readily inlined irrespective
|
||||
// of which op is a parent. Add a fold for this.
|
||||
let hasFolder = 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -143,94 +143,23 @@ static LogicalResult verify(ExecuteRegionOp op) {
|
|||
//
|
||||
// "test.foo"() : () -> ()
|
||||
// %x = "test.val"() : () -> i64
|
||||
// "test.bar"(%x) : (i64) -> ()
|
||||
// "test.bar"(%v) : (i64) -> ()
|
||||
//
|
||||
struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
|
||||
using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ExecuteRegionOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!llvm::hasSingleElement(op.region()))
|
||||
if (op.region().getBlocks().size() != 1)
|
||||
return failure();
|
||||
replaceOpWithRegion(rewriter, op, op.region());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Inline an ExecuteRegionOp if its parent can contain multiple blocks.
|
||||
// TODO generalize the conditions for operations which can be inlined into.
|
||||
// func @func_execute_region_elim() {
|
||||
// "test.foo"() : () -> ()
|
||||
// %v = scf.execute_region -> i64 {
|
||||
// %c = "test.cmp"() : () -> i1
|
||||
// cond_br %c, ^bb2, ^bb3
|
||||
// ^bb2:
|
||||
// %x = "test.val1"() : () -> i64
|
||||
// br ^bb4(%x : i64)
|
||||
// ^bb3:
|
||||
// %y = "test.val2"() : () -> i64
|
||||
// br ^bb4(%y : i64)
|
||||
// ^bb4(%z : i64):
|
||||
// scf.yield %z : i64
|
||||
// }
|
||||
// "test.bar"(%v) : (i64) -> ()
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// becomes
|
||||
//
|
||||
// func @func_execute_region_elim() {
|
||||
// "test.foo"() : () -> ()
|
||||
// %c = "test.cmp"() : () -> i1
|
||||
// cond_br %c, ^bb1, ^bb2
|
||||
// ^bb1: // pred: ^bb0
|
||||
// %x = "test.val1"() : () -> i64
|
||||
// br ^bb3(%x : i64)
|
||||
// ^bb2: // pred: ^bb0
|
||||
// %y = "test.val2"() : () -> i64
|
||||
// br ^bb3(%y : i64)
|
||||
// ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
|
||||
// "test.bar"(%z) : (i64) -> ()
|
||||
// return
|
||||
// }
|
||||
//
|
||||
struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
|
||||
using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ExecuteRegionOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!isa<FuncOp, ExecuteRegionOp>(op->getParentOp()))
|
||||
return failure();
|
||||
|
||||
Block *prevBlock = op->getBlock();
|
||||
Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
|
||||
rewriter.setInsertionPointToEnd(prevBlock);
|
||||
|
||||
rewriter.create<BranchOp>(op.getLoc(), &op.region().front());
|
||||
|
||||
for (Block &blk : op.region()) {
|
||||
if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
|
||||
rewriter.setInsertionPoint(yieldOp);
|
||||
rewriter.create<BranchOp>(yieldOp.getLoc(), postBlock,
|
||||
yieldOp.results());
|
||||
rewriter.eraseOp(yieldOp);
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.inlineRegionBefore(op.region(), postBlock);
|
||||
SmallVector<Value> blockArgs;
|
||||
|
||||
for (auto res : op.getResults())
|
||||
blockArgs.push_back(postBlock->addArgument(res.getType()));
|
||||
|
||||
rewriter.replaceOp(op, blockArgs);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
|
||||
results.add<SingleBlockExecuteInliner>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -948,70 +948,3 @@ func @execute_region_elim() {
|
|||
// CHECK-NEXT: "test.bar"(%[[VAL]]) : (i64) -> ()
|
||||
// CHECK-NEXT: }
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @func_execute_region_elim
|
||||
func @func_execute_region_elim() {
|
||||
"test.foo"() : () -> ()
|
||||
%v = scf.execute_region -> i64 {
|
||||
%c = "test.cmp"() : () -> i1
|
||||
cond_br %c, ^bb2, ^bb3
|
||||
^bb2:
|
||||
%x = "test.val1"() : () -> i64
|
||||
br ^bb4(%x : i64)
|
||||
^bb3:
|
||||
%y = "test.val2"() : () -> i64
|
||||
br ^bb4(%y : i64)
|
||||
^bb4(%z : i64):
|
||||
scf.yield %z : i64
|
||||
}
|
||||
"test.bar"(%v) : (i64) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: "test.foo"
|
||||
// CHECK: %[[cmp:.+]] = "test.cmp"
|
||||
// CHECK: cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
|
||||
// CHECK: ^[[bb1]]: // pred: ^bb0
|
||||
// CHECK: %[[x:.+]] = "test.val1"
|
||||
// CHECK: br ^[[bb3:.+]](%[[x]] : i64)
|
||||
// CHECK: ^[[bb2]]: // pred: ^bb0
|
||||
// CHECK: %[[y:.+]] = "test.val2"
|
||||
// CHECK: br ^[[bb3]](%[[y:.+]] : i64)
|
||||
// CHECK: ^[[bb3]](%[[z:.+]]: i64):
|
||||
// CHECK: "test.bar"(%[[z]])
|
||||
// CHECK: return
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @func_execute_region_elim2
|
||||
func @func_execute_region_elim2() {
|
||||
"test.foo"() : () -> ()
|
||||
%v = scf.execute_region -> i64 {
|
||||
%c = "test.cmp"() : () -> i1
|
||||
cond_br %c, ^bb2, ^bb3
|
||||
^bb2:
|
||||
%x = "test.val1"() : () -> i64
|
||||
scf.yield %x : i64
|
||||
^bb3:
|
||||
%y = "test.val2"() : () -> i64
|
||||
scf.yield %y : i64
|
||||
}
|
||||
"test.bar"(%v) : (i64) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: "test.foo"
|
||||
// CHECK: %[[cmp:.+]] = "test.cmp"
|
||||
// CHECK: cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
|
||||
// CHECK: ^[[bb1]]: // pred: ^bb0
|
||||
// CHECK: %[[x:.+]] = "test.val1"
|
||||
// CHECK: br ^[[bb3:.+]](%[[x]] : i64)
|
||||
// CHECK: ^[[bb2]]: // pred: ^bb0
|
||||
// CHECK: %[[y:.+]] = "test.val2"
|
||||
// CHECK: br ^[[bb3]](%[[y:.+]] : i64)
|
||||
// CHECK: ^[[bb3]](%[[z:.+]]: i64):
|
||||
// CHECK: "test.bar"(%[[z]])
|
||||
// CHECK: return
|
||||
|
|
Loading…
Reference in New Issue