[mlir] Fix ControlFlowInterfaces implementation for Async dialect

* Add `RegionBranchTerminatorOpInterface` to `YieldOp`.
* Implement `getSuccessorEntryOperands` in `ExecuteOp`.
* Fix `getSuccessorRegions` implementation in `ExecuteOp`.

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D108373
This commit is contained in:
Vladislav Vinogradov 2021-08-19 16:28:16 +03:00
parent 119146f8ae
commit 9775c0c9f0
2 changed files with 21 additions and 8 deletions

View File

@ -29,7 +29,8 @@ class Async_Op<string mnemonic, list<OpTrait> traits = []> :
def Async_ExecuteOp :
Async_Op<"execute", [SingleBlockImplicitTerminator<"YieldOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getNumRegionInvocations"]>,
["getSuccessorEntryOperands",
"getNumRegionInvocations"]>,
AttrSizedOperandSegments]> {
let summary = "Asynchronous execute operation";
let description = [{
@ -99,7 +100,9 @@ def Async_ExecuteOp :
}
def Async_YieldOp :
Async_Op<"yield", [HasParent<"ExecuteOp">, NoSideEffect, Terminator]> {
Async_Op<"yield", [
HasParent<"ExecuteOp">, NoSideEffect, Terminator,
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>]> {
let summary = "terminator for Async execute operation";
let description = [{
The `async.yield` is a special terminator operation for the block inside

View File

@ -48,6 +48,12 @@ static LogicalResult verify(YieldOp op) {
return success();
}
MutableOperandRange
YieldOp::getMutableSuccessorOperands(Optional<unsigned> index) {
assert(!index.hasValue());
return operandsMutable();
}
//===----------------------------------------------------------------------===//
/// ExecuteOp
//===----------------------------------------------------------------------===//
@ -55,24 +61,28 @@ static LogicalResult verify(YieldOp op) {
constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
void ExecuteOp::getNumRegionInvocations(
ArrayRef<Attribute> operands, SmallVectorImpl<int64_t> &countPerRegion) {
(void)operands;
ArrayRef<Attribute>, SmallVectorImpl<int64_t> &countPerRegion) {
assert(countPerRegion.empty());
countPerRegion.push_back(1);
}
OperandRange ExecuteOp::getSuccessorEntryOperands(unsigned index) {
assert(index == 0 && "invalid region index");
return operands();
}
void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
ArrayRef<Attribute> operands,
ArrayRef<Attribute>,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `body` region branch back to the parent operation.
if (index.hasValue()) {
assert(*index == 0);
regions.push_back(RegionSuccessor(getResults()));
assert(*index == 0 && "invalid region index");
regions.push_back(RegionSuccessor(results()));
return;
}
// Otherwise the successor is the body region.
regions.push_back(RegionSuccessor(&body()));
regions.push_back(RegionSuccessor(&body(), body().getArguments()));
}
void ExecuteOp::build(OpBuilder &builder, OperationState &result,