forked from OSchip/llvm-project
[mlir] LoopToStandard conversion: support "if/else" with results
Summary: A recent extension allowed the `loop.if` operation to return results yielded by its regions. However, such operations could not be lowered to a CFG of standard operations because it would have required to modify the argument list of a block, which is not allowed in a conversion pattern. Now that the conversion infrastructure supports block creation, use it to create a block with an argument list that dominates the operations following the `loop.if` and forward the results as arguments of this block. Depends On D77416 Differential Revision: https://reviews.llvm.org/D77418
This commit is contained in:
parent
b7397e81fe
commit
340e1b2077
|
@ -112,13 +112,21 @@ struct ForLowering : public OpRewritePattern<ForOp> {
|
|||
// blocks are respectively the first/last block of the enclosing region. The
|
||||
// operations following the loop.if are split into a continuation (subgraph
|
||||
// exit) block. The condition is lowered to a chain of blocks that implement the
|
||||
// short-circuit scheme. Condition blocks are created by splitting out an empty
|
||||
// block from the block that contains the loop.if operation. They
|
||||
// conditionally branch to either the first block of the "then" region, or to
|
||||
// the first block of the "else" region. If the latter is absent, they branch
|
||||
// to the continuation block instead. The last blocks of "then" and "else"
|
||||
// regions (which are known to be exit blocks thanks to the invariant we
|
||||
// maintain).
|
||||
// short-circuit scheme. The "loop.if" operation is replaced with a conditional
|
||||
// branch to either the first block of the "then" region, or to the first block
|
||||
// of the "else" region. In these blocks, "loop.yield" is unconditional branches
|
||||
// to the post-dominating block. When the "loop.if" does not return values, the
|
||||
// post-dominating block is the same as the continuation block. When it returns
|
||||
// values, the post-dominating block is a new block with arguments that
|
||||
// correspond to the values returned by the "loop.if" that unconditionally
|
||||
// branches to the continuation block. This allows block arguments to dominate
|
||||
// any uses of the hitherto "loop.if" results that they replaced. (Inserting a
|
||||
// new block allows us to avoid modifying the argument list of an existing
|
||||
// block, which is illegal in a conversion pattern). When the "else" region is
|
||||
// empty, which is only allowed for "loop.if"s that don't return values, the
|
||||
// condition branches directly to the continuation block.
|
||||
//
|
||||
// CFG for a loop.if with else and without results.
|
||||
//
|
||||
// +--------------------------------+
|
||||
// | <code before the IfOp> |
|
||||
|
@ -148,6 +156,42 @@ struct ForLowering : public OpRewritePattern<ForOp> {
|
|||
// | <code after the IfOp> |
|
||||
// +--------------------------------+
|
||||
//
|
||||
// CFG for a loop.if with results.
|
||||
//
|
||||
// +--------------------------------+
|
||||
// | <code before the IfOp> |
|
||||
// | cond_br %cond, %then, %else |
|
||||
// +--------------------------------+
|
||||
// | |
|
||||
// | --------------|
|
||||
// v |
|
||||
// +--------------------------------+ |
|
||||
// | then: | |
|
||||
// | <then contents> | |
|
||||
// | br dom(%args...) | |
|
||||
// +--------------------------------+ |
|
||||
// | |
|
||||
// |---------- |-------------
|
||||
// | V
|
||||
// | +--------------------------------+
|
||||
// | | else: |
|
||||
// | | <else contents> |
|
||||
// | | br dom(%args...) |
|
||||
// | +--------------------------------+
|
||||
// | |
|
||||
// ------| |
|
||||
// v v
|
||||
// +--------------------------------+
|
||||
// | dom(%args...): |
|
||||
// | br continue |
|
||||
// +--------------------------------+
|
||||
// |
|
||||
// v
|
||||
// +--------------------------------+
|
||||
// | continue: |
|
||||
// | <code after the IfOp> |
|
||||
// +--------------------------------+
|
||||
//
|
||||
struct IfLowering : public OpRewritePattern<IfOp> {
|
||||
using OpRewritePattern<IfOp>::OpRewritePattern;
|
||||
|
||||
|
@ -238,15 +282,25 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
|
|||
// continuation point.
|
||||
auto *condBlock = rewriter.getInsertionBlock();
|
||||
auto opPosition = rewriter.getInsertionPoint();
|
||||
auto *continueBlock = rewriter.splitBlock(condBlock, opPosition);
|
||||
auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
|
||||
Block *continueBlock;
|
||||
if (ifOp.getNumResults() == 0) {
|
||||
continueBlock = remainingOpsBlock;
|
||||
} else {
|
||||
continueBlock =
|
||||
rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes());
|
||||
rewriter.create<BranchOp>(loc, remainingOpsBlock);
|
||||
}
|
||||
|
||||
// Move blocks from the "then" region to the region containing 'loop.if',
|
||||
// place it before the continuation block, and branch to it.
|
||||
auto &thenRegion = ifOp.thenRegion();
|
||||
auto *thenBlock = &thenRegion.front();
|
||||
rewriter.eraseOp(thenRegion.back().getTerminator());
|
||||
Operation *thenTerminator = thenRegion.back().getTerminator();
|
||||
ValueRange thenTerminatorOperands = thenTerminator->getOperands();
|
||||
rewriter.setInsertionPointToEnd(&thenRegion.back());
|
||||
rewriter.create<BranchOp>(loc, continueBlock);
|
||||
rewriter.create<BranchOp>(loc, continueBlock, thenTerminatorOperands);
|
||||
rewriter.eraseOp(thenTerminator);
|
||||
rewriter.inlineRegionBefore(thenRegion, continueBlock);
|
||||
|
||||
// Move blocks from the "else" region (if present) to the region containing
|
||||
|
@ -256,9 +310,11 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
|
|||
auto &elseRegion = ifOp.elseRegion();
|
||||
if (!elseRegion.empty()) {
|
||||
elseBlock = &elseRegion.front();
|
||||
rewriter.eraseOp(elseRegion.back().getTerminator());
|
||||
Operation *elseTerminator = elseRegion.back().getTerminator();
|
||||
ValueRange elseTerminatorOperands = elseTerminator->getOperands();
|
||||
rewriter.setInsertionPointToEnd(&elseRegion.back());
|
||||
rewriter.create<BranchOp>(loc, continueBlock);
|
||||
rewriter.create<BranchOp>(loc, continueBlock, elseTerminatorOperands);
|
||||
rewriter.eraseOp(elseTerminator);
|
||||
rewriter.inlineRegionBefore(elseRegion, continueBlock);
|
||||
}
|
||||
|
||||
|
@ -268,7 +324,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
|
|||
/*falseArgs=*/ArrayRef<Value>());
|
||||
|
||||
// Ok, we're done!
|
||||
rewriter.eraseOp(ifOp);
|
||||
rewriter.replaceOp(ifOp, continueBlock->getArguments());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -148,6 +148,83 @@ func @simple_std_for_loop_with_2_ifs(%arg0 : index, %arg1 : index, %arg2 : index
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @simple_if_yield
|
||||
func @simple_if_yield(%arg0: i1) -> (i1, i1) {
|
||||
// CHECK: cond_br %{{.*}}, ^[[then:.*]], ^[[else:.*]]
|
||||
%0:2 = loop.if %arg0 -> (i1, i1) {
|
||||
// CHECK: ^[[then]]:
|
||||
// CHECK: %[[v0:.*]] = constant 0
|
||||
// CHECK: %[[v1:.*]] = constant 1
|
||||
// CHECK: br ^[[dom:.*]](%[[v0]], %[[v1]] : i1, i1)
|
||||
%c0 = constant 0 : i1
|
||||
%c1 = constant 1 : i1
|
||||
loop.yield %c0, %c1 : i1, i1
|
||||
} else {
|
||||
// CHECK: ^[[else]]:
|
||||
// CHECK: %[[v2:.*]] = constant 0
|
||||
// CHECK: %[[v3:.*]] = constant 1
|
||||
// CHECK: br ^[[dom]](%[[v3]], %[[v2]] : i1, i1)
|
||||
%c0 = constant 0 : i1
|
||||
%c1 = constant 1 : i1
|
||||
loop.yield %c1, %c0 : i1, i1
|
||||
}
|
||||
// CHECK: ^[[dom]](%[[arg1:.*]]: i1, %[[arg2:.*]]: i1):
|
||||
// CHECK: br ^[[cont:.*]]
|
||||
// CHECK: ^[[cont]]:
|
||||
// CHECK: return %[[arg1]], %[[arg2]]
|
||||
return %0#0, %0#1 : i1, i1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @nested_if_yield
|
||||
func @nested_if_yield(%arg0: i1) -> (index) {
|
||||
// CHECK: cond_br %{{.*}}, ^[[first_then:.*]], ^[[first_else:.*]]
|
||||
%0 = loop.if %arg0 -> i1 {
|
||||
// CHECK: ^[[first_then]]:
|
||||
%1 = constant 1 : i1
|
||||
// CHECK: br ^[[first_dom:.*]]({{.*}})
|
||||
loop.yield %1 : i1
|
||||
} else {
|
||||
// CHECK: ^[[first_else]]:
|
||||
%2 = constant 0 : i1
|
||||
// CHECK: br ^[[first_dom]]({{.*}})
|
||||
loop.yield %2 : i1
|
||||
}
|
||||
// CHECK: ^[[first_dom]](%[[arg1:.*]]: i1):
|
||||
// CHECK: br ^[[first_cont:.*]]
|
||||
// CHECK: ^[[first_cont]]:
|
||||
// CHECK: cond_br %[[arg1]], ^[[second_outer_then:.*]], ^[[second_outer_else:.*]]
|
||||
%1 = loop.if %0 -> index {
|
||||
// CHECK: ^[[second_outer_then]]:
|
||||
// CHECK: cond_br %arg0, ^[[second_inner_then:.*]], ^[[second_inner_else:.*]]
|
||||
%3 = loop.if %arg0 -> index {
|
||||
// CHECK: ^[[second_inner_then]]:
|
||||
%4 = constant 40 : index
|
||||
// CHECK: br ^[[second_inner_dom:.*]]({{.*}})
|
||||
loop.yield %4 : index
|
||||
} else {
|
||||
// CHECK: ^[[second_inner_else]]:
|
||||
%5 = constant 41 : index
|
||||
// CHECK: br ^[[second_inner_dom]]({{.*}})
|
||||
loop.yield %5 : index
|
||||
}
|
||||
// CHECK: ^[[second_inner_dom]](%[[arg2:.*]]: index):
|
||||
// CHECK: br ^[[second_inner_cont:.*]]
|
||||
// CHECK: ^[[second_inner_cont]]:
|
||||
// CHECK: br ^[[second_outer_dom:.*]]({{.*}})
|
||||
loop.yield %3 : index
|
||||
} else {
|
||||
// CHECK: ^[[second_outer_else]]:
|
||||
%6 = constant 42 : index
|
||||
// CHECK: br ^[[second_outer_dom]]({{.*}}
|
||||
loop.yield %6 : index
|
||||
}
|
||||
// CHECK: ^[[second_outer_dom]](%[[arg3:.*]]: index):
|
||||
// CHECK: br ^[[second_outer_cont:.*]]
|
||||
// CHECK: ^[[second_outer_cont]]:
|
||||
// CHECK: return %[[arg3]] : index
|
||||
return %1 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @parallel_loop(
|
||||
// CHECK-SAME: [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index) {
|
||||
// CHECK: [[VAL_5:%.*]] = constant 1 : index
|
||||
|
|
Loading…
Reference in New Issue