[mlir] Add a generic while/do-while loop to the SCF dialect

The new construct represents a generic loop with two regions: one executed
before the loop condition is verifier and another after that. This construct
can be used to express both a "while" loop and a "do-while" loop, depending on
where the main payload is located. It is intended as an intermediate
abstraction for lowering, which will be added later. This form is relatively
easy to target from higher-level abstractions and supports transformations such
as loop rotation and LICM.

Differential Revision: https://reviews.llvm.org/D90255
This commit is contained in:
Alex Zinenko 2020-11-04 09:41:55 +01:00
parent 3bec07f91f
commit 79716559b5
7 changed files with 484 additions and 38 deletions

View File

@ -36,6 +36,25 @@ class SCF_Op<string mnemonic, list<OpTrait> traits = []> :
let parser = [{ return ::parse$cppClass(parser, result); }];
}
def ConditionOp : SCF_Op<"condition",
[HasParent<"WhileOp">, NoSideEffect, Terminator]> {
let summary = "loop continuation condition";
let description = [{
This operation accepts the continuation (i.e., inverse of exit) condition
of the `scf.while` construct. If its first argument is true, the "after"
region of `scf.while` is executed, with the remaining arguments forwarded
to the entry block of the region. Otherwise, the loop terminates.
}];
let arguments = (ins I1:$condition, Variadic<AnyType>:$args);
let assemblyFormat =
[{ `(` $condition `)` attr-dict ($args^ `:` type($args))? }];
// Override the default verifier, everything is checked by traits.
let verifier = ?;
}
def ForOp : SCF_Op<"for",
[DeclareOpInterfaceMethods<LoopLikeOpInterface>,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
@ -413,8 +432,135 @@ def ReduceReturnOp :
let assemblyFormat = "$result attr-dict `:` type($result)";
}
def WhileOp : SCF_Op<"while",
[DeclareOpInterfaceMethods<RegionBranchOpInterface>,
RecursiveSideEffects]> {
let summary = "a generic 'while' loop";
let description = [{
This operation represents a generic "while"/"do-while" loop that keeps
iterating as long as a condition is satisfied. There is no restriction on
the complexity of the condition. It consists of two regions (with single
block each): "before" region and "after" region. The names of regions
indicates whether they execute before or after the condition check.
Therefore, if the main loop payload is located in the "before" region, the
operation is a "do-while" loop. Otherwise, it is a "while" loop.
The "before" region terminates with a special operation, `scf.condition`,
that accepts as its first operand an `i1` value indicating whether to
proceed to the "after" region (value is `true`) or not. The two regions
communicate by means of region arguments. Initially, the "before" region
accepts as arguments the operands of the `scf.while` operation and uses them
to evaluate the condition. It forwards the trailing, non-condition operands
of the `scf.condition` terminator either to the "after" region if the
control flow is transferred there or to results of the `scf.while` operation
otherwise. The "after" region takes as arguments the values produced by the
"before" region and uses `scf.yield` to supply new arguments for the "after"
region, into which it transfers the control flow unconditionally.
A simple "while" loop can be represented as follows.
```mlir
%res = scf.while (%arg1 = %init1) : (f32) -> f32 {
/* "Before" region.
* In a "while" loop, this region computes the condition. */
%condition = call @evaluate_condition(%arg1) : (f32) -> i1
/* Forward the argument (as result or "after" region argument). */
scf.condition(%condition) %arg1 : f32
} do {
^bb0(%arg2: f32):
/* "After region.
* In a "while" loop, this region is the loop body. */
%next = call @payload(%arg2) : (f32) -> f32
/* Forward the new value to the "before" region.
* The operand types must match the types of the `scf.while` operands. */
scf.yield %next : f32
}
```
A simple "do-while" loop can be represented by reducing the "after" block
to a simple forwarder.
```mlir
%res = scf.while (%arg1 = %init1) : (f32) -> f32 {
/* "Before" region.
* In a "do-while" loop, this region contains the loop body. */
%next = call @payload(%arg1) : (f32) -> f32
/* And also evalutes the condition. */
%condition = call @evaluate_condition(%arg1) : (f32) -> i1
/* Loop through the "after" region. */
scf.condition(%condition) %next : f32
} do {
^bb0(%arg2: f32):
/* "After" region.
* Forwards the values back to "before" region unmodified. */
scf.yield %arg2 : f32
}
```
Note that the types of region arguments need not to match with each other.
The op expects the operand types to match with argument types of the
"before" region"; the result types to match with the trailing operand types
of the terminator of the "before" region, and with the argument types of the
"after" region. The following scheme can be used to share the results of
some operations executed in the "before" region with the "after" region,
avoiding the need to recompute them.
```mlir
%res = scf.while (%arg1 = %init1) : (f32) -> i64 {
/* One can perform some computations, e.g., necessary to evaluate the
* condition, in the "before" region and forward their results to the
* "after" region. */
%shared = call @shared_compute(%arg1) : (f32) -> i64
/* Evalute the condition. */
%condition = call @evaluate_condition(%arg1, %shared) : (f32, i64) -> i1
/* Forward the result of the shared computation to the "after" region.
* The types must match the arguments of the "after" region as well as
* those of the `scf.while` results. */
scf.condition(%condition) %shared : i64
} do {
^bb0(%arg2: i64) {
/* Use the partial result to compute the rest of the payload in the
* "after" region. */
%res = call @payload(%arg2) : (i64) -> f32
/* Forward the new value to the "before" region.
* The operand types must match the types of the `scf.while` operands. */
scf.yield %res : f32
}
```
The custom syntax for this operation is as follows.
```
op ::= `scf.while` assignments `:` function-type region `do` region
`attributes` attribute-dict
initializer ::= /* empty */ | `(` assignment-list `)`
assignment-list ::= assignment | assignment `,` assignment-list
assignment ::= ssa-value `=` ssa-value
```
}];
let arguments = (ins Variadic<AnyType>:$inits);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);
let extraClassDeclaration = [{
OperandRange getSuccessorEntryOperands(unsigned index);
}];
}
def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
ParentOneOf<["IfOp, ForOp", "ParallelOp"]>]> {
ParentOneOf<["IfOp, ForOp", "ParallelOp",
"WhileOp"]>]> {
let summary = "loop yield and termination operation";
let description = [{
"scf.yield" yields an SSA value from the SCF dialect op region and
@ -434,4 +580,5 @@ def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
// needed.
let verifier = ?;
}
#endif // MLIR_DIALECT_SCF_SCFOPS

View File

@ -755,11 +755,18 @@ public:
parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
/// Parse a list of assignments of the form
/// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...).
/// The list must contain at least one entry
virtual ParseResult
parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
SmallVectorImpl<OperandType> &rhs) = 0;
/// (%x1 = %y1, %x2 = %y2, ...)
ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
SmallVectorImpl<OperandType> &rhs) {
OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
if (!result.hasValue())
return emitError(getCurrentLocation(), "expected '('");
return result.getValue();
}
virtual OptionalParseResult
parseOptionalAssignmentList(SmallVectorImpl<OperandType> &lhs,
SmallVectorImpl<OperandType> &rhs) = 0;
/// Parse a keyword followed by a type.
ParseResult parseKeywordType(const char *keyword, Type &result) {

View File

@ -140,26 +140,37 @@ static LogicalResult verify(ForOp op) {
return RegionBranchOpInterface::verifyTypes(op);
}
/// Prints the initialization list in the form of
/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
/// where 'inner' values are assumed to be region arguments and 'outer' values
/// are regular SSA values.
static void printInitializationList(OpAsmPrinter &p,
Block::BlockArgListType blocksArgs,
ValueRange initializers,
StringRef prefix = "") {
assert(blocksArgs.size() == initializers.size() &&
"expected same length of arguments and initializers");
if (initializers.empty())
return;
p << prefix << '(';
llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
p << std::get<0>(it) << " = " << std::get<1>(it);
});
p << ")";
}
static void print(OpAsmPrinter &p, ForOp op) {
bool printBlockTerminators = false;
p << op.getOperationName() << " " << op.getInductionVar() << " = "
<< op.lowerBound() << " to " << op.upperBound() << " step " << op.step();
if (op.hasIterOperands()) {
p << " iter_args(";
auto regionArgs = op.getRegionIterArgs();
auto operands = op.getIterOperands();
llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
p << std::get<0>(it) << " = " << std::get<1>(it);
});
p << ")";
p << " -> (" << op.getResultTypes() << ")";
printBlockTerminators = true;
}
printInitializationList(p, op.getRegionIterArgs(), op.getIterOperands(),
" iter_args");
if (!op.getIterOperands().empty())
p << " -> (" << op.getIterOperands().getTypes() << ')';
p.printRegion(op.region(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/printBlockTerminators);
/*printBlockTerminators=*/op.hasIterOperands());
p.printOptionalAttrDict(op.getAttrs());
}
@ -933,6 +944,158 @@ static LogicalResult verify(ReduceReturnOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// WhileOp
//===----------------------------------------------------------------------===//
OperandRange WhileOp::getSuccessorEntryOperands(unsigned index) {
assert(index == 0 &&
"WhileOp is expected to branch only to the first region");
return inits();
}
void WhileOp::getSuccessorRegions(Optional<unsigned> index,
ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {
(void)operands;
if (!index.hasValue()) {
regions.emplace_back(&before(), before().getArguments());
return;
}
assert(*index < 2 && "there are only two regions in a WhileOp");
if (*index == 0) {
regions.emplace_back(&after(), after().getArguments());
regions.emplace_back(getResults());
return;
}
regions.emplace_back(&before(), before().getArguments());
}
/// Parses a `while` op.
///
/// op ::= `scf.while` assignments `:` function-type region `do` region
/// `attributes` attribute-dict
/// initializer ::= /* empty */ | `(` assignment-list `)`
/// assignment-list ::= assignment | assignment `,` assignment-list
/// assignment ::= ssa-value `=` ssa-value
static ParseResult parseWhileOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
Region *before = result.addRegion();
Region *after = result.addRegion();
OptionalParseResult listResult =
parser.parseOptionalAssignmentList(regionArgs, operands);
if (listResult.hasValue() && failed(listResult.getValue()))
return failure();
FunctionType functionType;
llvm::SMLoc typeLoc = parser.getCurrentLocation();
if (failed(parser.parseColonType(functionType)))
return failure();
result.addTypes(functionType.getResults());
if (functionType.getNumInputs() != operands.size()) {
return parser.emitError(typeLoc)
<< "expected as many input types as operands "
<< "(expected " << operands.size() << " got "
<< functionType.getNumInputs() << ")";
}
// Resolve input operands.
if (failed(parser.resolveOperands(operands, functionType.getInputs(),
parser.getCurrentLocation(),
result.operands)))
return failure();
return failure(
parser.parseRegion(*before, regionArgs, functionType.getInputs()) ||
parser.parseKeyword("do") || parser.parseRegion(*after) ||
parser.parseOptionalAttrDictWithKeyword(result.attributes));
}
/// Prints a `while` op.
static void print(OpAsmPrinter &p, scf::WhileOp op) {
p << op.getOperationName();
printInitializationList(p, op.before().front().getArguments(), op.inits(),
" ");
p << " : ";
p.printFunctionalType(op.inits().getTypes(), op.results().getTypes());
p.printRegion(op.before(), /*printEntryBlockArgs=*/false);
p << " do";
p.printRegion(op.after());
p.printOptionalAttrDictWithKeyword(op.getAttrs());
}
/// Verifies that two ranges of types match, i.e. have the same number of
/// entries and that types are pairwise equals. Reports errors on the given
/// operation in case of mismatch.
template <typename OpTy>
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
TypeRange right, StringRef message) {
if (left.size() != right.size())
return op.emitOpError("expects the same number of ") << message;
for (unsigned i = 0, e = left.size(); i < e; ++i) {
if (left[i] != right[i]) {
InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
<< message;
diag.attachNote() << "for argument " << i << ", found " << left[i]
<< " and " << right[i];
return diag;
}
}
return success();
}
/// Verifies that the first block of the given `region` is terminated by a
/// YieldOp. Reports errors on the given operation if it is not the case.
template <typename TerminatorTy>
static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region &region,
StringRef errorMessage) {
Operation *terminatorOperation = region.front().getTerminator();
if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
return yield;
auto diag = op.emitOpError(errorMessage);
if (terminatorOperation)
diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
return nullptr;
}
static LogicalResult verify(scf::WhileOp op) {
if (failed(RegionBranchOpInterface::verifyTypes(op)))
return failure();
auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
op, op.before(),
"expects the 'before' region to terminate with 'scf.condition'");
if (!beforeTerminator)
return failure();
TypeRange trailingTerminatorOperands = beforeTerminator.args().getTypes();
if (failed(verifyTypeRangesMatch(op, trailingTerminatorOperands,
op.after().getArgumentTypes(),
"trailing operands of the 'before' block "
"terminator and 'after' region arguments")))
return failure();
if (failed(verifyTypeRangesMatch(
op, trailingTerminatorOperands, op.getResultTypes(),
"trailing operands of the 'before' block terminator and op results")))
return failure();
auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
op, op.after(),
"expects the 'after' region to terminate with 'scf.yield'");
return success(afterTerminator != nullptr);
}
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

View File

@ -76,10 +76,13 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
/// Verify that types match along all region control flow edges originating from
/// `sourceNo` (region # if source is a region, llvm::None if source is parent
/// op). `getInputsTypesForRegion` is a function that returns the types of the
/// inputs that flow from `sourceIndex' to the given region.
static LogicalResult verifyTypesAlongAllEdges(
Operation *op, Optional<unsigned> sourceNo,
function_ref<TypeRange(Optional<unsigned>)> getInputsTypesForRegion) {
/// inputs that flow from `sourceIndex' to the given region, or llvm::None if
/// the exact type match verification is not necessary (e.g., if the Op verifies
/// the match itself).
static LogicalResult
verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo,
function_ref<Optional<TypeRange>(Optional<unsigned>)>
getInputsTypesForRegion) {
auto regionInterface = cast<RegionBranchOpInterface>(op);
SmallVector<RegionSuccessor, 2> successors;
@ -113,17 +116,20 @@ static LogicalResult verifyTypesAlongAllEdges(
return diag;
};
TypeRange sourceTypes = getInputsTypesForRegion(succRegionNo);
Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
if (!sourceTypes.hasValue())
continue;
TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
if (sourceTypes.size() != succInputsTypes.size()) {
if (sourceTypes->size() != succInputsTypes.size()) {
InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
return printEdgeName(diag) << ": source has " << sourceTypes.size()
return printEdgeName(diag) << ": source has " << sourceTypes->size()
<< " operands, but target successor needs "
<< succInputsTypes.size();
}
for (auto typesIdx :
llvm::enumerate(llvm::zip(sourceTypes, succInputsTypes))) {
llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
Type sourceType = std::get<0>(typesIdx.value());
Type inputType = std::get<1>(typesIdx.value());
if (sourceType != inputType) {
@ -191,10 +197,15 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
<< " operands mismatch between return-like terminators";
}
auto inputTypesFromRegion = [&](Optional<unsigned> regionNo) -> TypeRange {
auto inputTypesFromRegion =
[&](Optional<unsigned> regionNo) -> Optional<TypeRange> {
// If there is no return-like terminator, the op itself should verify
// type consistency.
if (!regionReturn)
return llvm::None;
// All successors get the same set of operands.
return regionReturn ? TypeRange(regionReturn->getOperands().getTypes())
: TypeRange();
return TypeRange(regionReturn->getOperands().getTypes());
};
if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))

View File

@ -1480,10 +1480,13 @@ public:
}
/// Parse a list of assignments of the form
/// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...).
/// The list must contain at least one entry
ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
SmallVectorImpl<OperandType> &rhs) override {
/// (%x1 = %y1, %x2 = %y2, ...).
OptionalParseResult
parseOptionalAssignmentList(SmallVectorImpl<OperandType> &lhs,
SmallVectorImpl<OperandType> &rhs) override {
if (failed(parseOptionalLParen()))
return llvm::None;
auto parseElt = [&]() -> ParseResult {
OperandType regionArg, operand;
if (parseRegionArgument(regionArg) || parseEqual() ||
@ -1493,8 +1496,6 @@ public:
rhs.push_back(operand);
return success();
};
if (parseLParen())
return failure();
return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt);
}

View File

@ -425,10 +425,88 @@ func @parallel_invalid_yield(
}
// -----
func @yield_invalid_parent_op() {
"my.op"() ({
// expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.if, scf.for, scf.parallel'}}
// expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.if, scf.for, scf.parallel, scf.while'}}
scf.yield
}) : () -> ()
return
}
// -----
func @while_parser_type_mismatch() {
%true = constant true
// expected-error@+1 {{expected as many input types as operands (expected 0 got 1)}}
scf.while : (i32) -> () {
scf.condition(%true)
} do {
scf.yield
}
}
// -----
func @while_bad_terminator() {
// expected-error@+1 {{expects the 'before' region to terminate with 'scf.condition'}}
scf.while : () -> () {
// expected-note@+1 {{terminator here}}
"some.other_terminator"() : () -> ()
} do {
scf.yield
}
}
// -----
func @while_cross_region_type_mismatch() {
%true = constant true
// expected-error@+1 {{expects the same number of trailing operands of the 'before' block terminator and 'after' region arguments}}
scf.while : () -> () {
scf.condition(%true)
} do {
^bb0(%arg0: i32):
scf.yield
}
}
// -----
func @while_cross_region_type_mismatch() {
%true = constant true
// expected-error@+2 {{expects the same types for trailing operands of the 'before' block terminator and 'after' region arguments}}
// expected-note@+1 {{for argument 0, found 'i1' and 'i32}}
scf.while : () -> () {
scf.condition(%true) %true : i1
} do {
^bb0(%arg0: i32):
scf.yield
}
}
// -----
func @while_result_type_mismatch() {
%true = constant true
// expected-error@+1 {{expects the same number of trailing operands of the 'before' block terminator and op results}}
scf.while : () -> () {
scf.condition(%true) %true : i1
} do {
^bb0(%arg0: i1):
scf.yield
}
}
// -----
func @while_bad_terminator() {
%true = constant true
// expected-error@+1 {{expects the 'after' region to terminate with 'scf.yield'}}
scf.while : () -> () {
scf.condition(%true)
} do {
// expected-note@+1 {{terminator here}}
"some.other_terminator"() : () -> ()
}
}

View File

@ -240,3 +240,42 @@ func @conditional_reduce(%buffer: memref<1024xf32>, %lb: index, %ub: index, %ste
// CHECK-NEXT: scf.yield %[[IFRES]] : f32
// CHECK-NEXT: }
// CHECK-NEXT: return %[[RESULT]]
// CHECK-LABEL: @while
func @while() {
%0 = "test.get_some_value"() : () -> i32
%1 = "test.get_some_value"() : () -> f32
// CHECK: = scf.while (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) : (i32, f32) -> (i64, f64) {
%2:2 = scf.while (%arg0 = %0, %arg1 = %1) : (i32, f32) -> (i64, f64) {
%3:2 = "test.some_operation"(%arg0, %arg1) : (i32, f32) -> (i64, f64)
%4 = "test.some_condition"(%arg0, %arg1) : (i32, f32) -> i1
// CHECK: scf.condition(%{{.*}}) %{{.*}}, %{{.*}} : i64, f64
scf.condition(%4) %3#0, %3#1 : i64, f64
// CHECK: } do {
} do {
// CHECK: ^{{.*}}(%{{.*}}: i64, %{{.*}}: f64):
^bb0(%arg2: i64, %arg3: f64):
%5:2 = "test.some_operation"(%arg2, %arg3): (i64, f64) -> (i32, f32)
// CHECK: scf.yield %{{.*}}, %{{.*}} : i32, f32
scf.yield %5#0, %5#1 : i32, f32
// CHECK: attributes {foo = "bar"}
} attributes {foo="bar"}
return
}
// CHECK-LABEL: @infinite_while
func @infinite_while() {
%true = constant true
// CHECK: scf.while : () -> () {
scf.while : () -> () {
// CHECK: scf.condition(%{{.*}})
scf.condition(%true)
// CHECK: } do {
} do {
// CHECK: scf.yield
scf.yield
}
return
}